
pymc模型中,当使用自定义pytensor op定义对数似然并尝试结合blackjax采样器时,可能遭遇jax转换兼容性错误。本文将深入探讨如何实现自定义对数似然,分析blackjax集成时的挑战,并提供一种通过数学表达式重构来显著提升核心计算函数性能的通用优化策略,即使无法利用jax加速,也能有效缩短采样时间。
在贝叶斯建模中,有时标准分布无法满足特定需求,需要引入自定义的对数似然函数。PyMC(基于PyTensor)提供了一种机制,允许用户通过定义自定义的pytensor.Op来集成任意Python函数及其梯度。
要将一个复杂的Python函数(例如,涉及外部库或数值求解器)集成到PyMC的计算图中,需要创建两个pytensor.Op类:一个用于计算函数值(对数似然),另一个用于计算其梯度。
LogLikeWithGrad (对数似然函数)
这个Op负责计算给定参数的对数似然值。它需要实现perform方法来执行实际的对数似然计算,并重载grad方法来指定如何计算梯度。
import pytensor.tensor as pt
import numpy as np
from scipy.optimize import approx_fprime # 用于数值梯度
class LogLikeWithGrad(pt.Op):
itypes = [pt.dvector] # 输入是一个参数向量
otypes = [pt.dscalar] # 输出是一个标量(对数似然值)
def __init__(self, loglike_function):
self.likelihood = loglike_function
self.loglike_grad_op = LogLikeGrad(loglike_function) # 初始化梯度Op
def perform(self, node, inputs, outputs):
(theta,) = inputs
logl = self.likelihood(theta)
outputs[0][0] = np.array(logl)
def grad(self, inputs, grad_outputs):
(theta,) = inputs
# 调用自定义的梯度Op来计算梯度
grads = self.loglike_grad_op(theta)
return [grad_outputs[0] * grads]LogLikeGrad (对数似然梯度函数)
这个Op专门用于计算对数似然函数相对于其输入的梯度。在缺乏解析梯度的情况下,可以使用数值近似方法,例如scipy.optimize.approx_fprime。
class LogLikeGrad(pt.Op):
itypes = [pt.dvector] # 输入是一个参数向量
otypes = [pt.dvector] # 输出是一个梯度向量
def __init__(self, loglike_function):
self.likelihood = loglike_function
def perform(self, node, inputs, outputs):
(theta,) = inputs
# 使用数值方法近似梯度
grads = approx_fprime(theta, self.likelihood, epsilon=1e-8)
outputs[0][0] = grads一旦定义了自定义的LogLikeWithGrad Op,就可以将其作为pm.Potential添加到PyMC模型中。pm.Potential允许用户在模型中引入任意的对数概率贡献。
import pymc as pm
# 假设 applyMCMC 是你的核心对数似然计算函数
# 并且 param_names 和 lower/upper_boundaries 已经定义
# logl = LogLikeWithGrad(applyMCMC)
with pm.Model() as model:
# 定义模型参数
for i, name in enumerate(param_names):
pm.Uniform(name, lower=lower_boundaries[0][i], upper=upper_boundaries[0][i])
# 将所有参数组合成一个PyTensor向量
theta = pt.as_tensor_variable([model[param] for param in param_names])
# 将自定义对数似然作为潜力项添加到模型中
pm.Potential("likelihood", logl(theta))
# 执行采样
# trace = pm.sample(draws=niter, step=pm.NUTS(), tune=500, cores=64, init="jitter+adapt_diag", progressbar=True)PyMC 5.x 版本支持多种NUTS采样器后端,包括其默认的PyTensor后端以及基于JAX的Blackjax采样器,后者在GPU等加速设备上表现出色。然而,当模型中包含自定义的pytensor.Op时,尝试使用Blackjax采样器可能会遇到兼容性问题。
当尝试通过 pm.sample(nuts_sampler="blackjax") 使用Blackjax时,如果自定义的LogLikeWithGrad Op没有对应的JAX转换实现,PyTensor的JAX后端会抛出 NotImplementedError,错误信息通常为 No JAX conversion for the given Op: LogLikeWithGrad。
Primeshot
专业级AI人像摄影工作室
36
查看详情
这是因为Blackjax采样器依赖于JAX的即时编译(JIT)能力,而JAX只能编译其能够理解的操作。自定义的pytensor.Op本质上是一个Python对象,PyTensor需要一个明确的规则来告诉JAX如何将其转换为JAX可执行的操作。对于标准PyTensor操作,这些转换已经内置,但对于用户自定义的Op,则需要手动提供。
解决此问题通常需要以下两种方法之一:
在许多情况下,特别是当自定义似然函数依赖于复杂外部库(如物理模拟器)时,直接将其完全转换为JAX操作可能非常困难或不可能。此时,即使无法利用Blackjax的JAX加速,我们仍然可以通过优化核心计算逻辑来提升采样性能。
即使无法直接利用JAX的GPU加速,通过对核心数学计算函数进行细致的优化,也能显著提升PyMC模型的采样速度。这种优化策略侧重于减少冗余计算、避免重复的函数调用以及利用局部变量缓存中间结果。
在复杂的数学表达式中,往往存在重复计算相同子表达式的情况。通过将这些子表达式的结果存储在局部变量中,可以避免多次计算,从而提高效率。
以原始代码中的dH和du函数为例,它们包含大量重复的幂运算和乘法:
以下是针对 dH 和 du 函数的优化版本,通过引入局部变量来缓存重复计算的中间结果:
import math
import timeit # 用于性能测试
# 假设 Rho_m, Phi, u, omega_BD, Omega_k, z 为示例输入
# (为了测试方便,这里使用任意值,实际应是模型参数)
Rho_m = -1.0
Phi = 0.1
u = 3.0
omega_BD = 4.0
Omega_k = -5.0
z = 6.0
# 原始 du 函数 (为对比而保留,实际代码中应替换为优化版本)
def du_original(Rho_m, Phi, u, omega_BD, Omega_k, z):
return (
24 * math.pi * Rho_m * Phi**3
+ (1 + z)
* u
* Phi**2
* (
8 * math.pi * (-3 + omega_BD) * Rho_m
- 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
)
- 3
* (1 + z) ** 2
* u**2
* Phi
* (
-4 * math.pi * omega_BD * Rho_m
+ (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
)
- omega_BD
* u**3
* (
4 * math.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
+ (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
)
) / (
(1 + z) ** 2
* (3 + 2 * omega_BD)
* Phi**2
* (8 * math.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
)
# 优化后的 du 函数
def du_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):
# 缓存幂次和重复乘法
Phi_pow2 = Phi * Phi
Phi_pow3 = Phi_pow2 * Phi
one_plus_z = 1 + z
one_plus_z_pow2 = one_plus_z * one_plus_z
one_plus_z_pow3 = one_plus_z_pow2 * one_plus_z
one_plus_z_pow4 = one_plus_z_pow3 * one_plus_z
one_plus_z_pow5 = one_plus_z_pow4 * one_plus_z
# 缓存其他重复子表达式
one_plus_z_pow2_times_3 = 3 * one_plus_z_pow2
pi_times_Rho_m = math.pi * Rho_m
Omega_k_times_Phi = Omega_k * Phi
u_pow2 = u * u
u_pow3 = u_pow2 * u
omg1 = (3 + 2 * omega_BD) # (3 + 2 * omega_BD)
omg = omg1 * Omega_k_times_Phi # (3 + 2 * omega_BD) * Omega_k * Phi
omg2 = omega_BD * pi_times_Rho_m # omega_BD * math.pi * Rho_m
return (
24 * pi_times_Rho_m * Phi_pow3
+ one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3 * omg)
- one_plus_z_pow2_times_3 * u_pow2 * Phi * (-4 * omg2 + one_plus_z_pow4 * omg)
- omega_BD * u_pow3 * (4 * one_plus_z_pow3 * (pi_times_Rho_m + omg2) + one_plus_z_pow5 * omg)
) / (
one_plus_z_pow2 * omg1 * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3 * Omega_k_times_Phi)
)
# 原始 dH 函数 (为对比而保留)
def dH_original(Rho_m, Phi, u, omega_BD, Omega_k, z):
val = (-16 * math.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
)
if val >= 0:
return -(
(
(1 + z)
* (16 * math.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
* (
(1 + z) * omega_BD * u**3
- 2
* omega_BD
* u
* ((1 + z) * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)
* Phi
- 6 * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2
)
+ (
6
* Phi
* (
-8 * math.pi * Rho_m
+ (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
)
* (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
)
/ (1 + z)
)
/ (
2
* math.sqrt(val)
* (
(1 + z) ** 2 * omega_BD * u**2
+ 6 * (1 + z) * u * Phi
- 6 * Phi**2
)
** 2
)
)
else:
return None
# 优化后的 dH 函数
def dH_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):
# 缓存常用变量和幂次
Phi_pow2 = Phi * Phi
Phi_pow2_times_6 = Phi_pow2 * 6
Phi_times_6 = Phi * 6
one_plus_z = 1 + z
one_plus_z_pow2 = one_plus_z * one_plus_z
one_plus_z_times_u = one_plus_z * u
pi_times_Rho_m = math.pi * Rho_m
Omega_k_times_Phi = Omega_k * Phi
u_pow2 = u * u
u_pow3 = u_pow2 * u
# 重新计算 duu (如果 duu 在 dH 内部被调用多次,直接内联其计算可进一步优化)
# 此处为简洁起见,仍调用 du_optimized,但注意实际场景可内联
# 或者,如原答案所示,直接将 du_optimized 的计算逻辑复制到此处
# duu 的内联计算部分 (来自 du_optimized)
Phi_pow3_du = Phi_pow2 * Phi
one_plus_z_pow3_du = one_plus_z_pow2 * one_plus_z
one_plus_z_pow4_du = one_plus_z_pow3_du * one_plus_z
one_plus_z_pow5_du = one_plus_z_pow4_du * one_plus_z
one_plus_z_pow2_times_3_du = 3 * one_plus_z_pow2
omg1_du = (3 + 2 * omega_BD)
omg_du = omg1_du * Omega_k_times_Phi
omg2_du = omega_BD * pi_times_Rho_m
duu = (
24 * pi_times_Rho_m * Phi_pow3_du
+ one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3_du * omg_du)
- one_plus_z_pow2_times_3_du * u_pow2 * Phi * (-4 * omg2_du + one_plus_z_pow4_du * omg_du)
- omega_BD * u_pow3 * (4 * one_plus_z_pow3_du * (pi_times_Rho_m + omg2_du) + one_plus_z_pow5_du * omg_du)
) / (
one_plus_z_pow2 * omg1_du * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3_du * Omega_k_times_Phi)
)
# duu 内联计算结束
val1 = (-16 * pi_times_Rho_m - 6 * one_plus_z_pow2 * Omega_k_times_Phi)
val = val1 / (6 * one_plus_z_times_u + (one_plus_z_pow2 * omega_BD * u_pow2) / Phi - Phi_times_6)
if val >= 0:
Phi_times_2 = Phi + Phi
val2 = (one_plus_z_pow2 * omega_BD * u_pow2 + 6 * (one_plus_z_times_u * Phi - Phi_pow2))
# 优化分子中的复杂项
term1_numerator = one_plus_z_pow2 * val1 * (omega_BD * u * (one_plus_z * u_pow2 - Phi_times_2 * (one_plus_z * duu + u)) - duu * Phi_pow2_times_6)
term2_numerator = Phi_times_6 * (-8 * pi_times_Rho_m + one_plus_z_pow2 * Omega_k * (one_plus_z_times_u + Phi_times_2)) * (Phi_pow2_times_6 - one_plus_z_times_u * (one_plus_z_times_u * omega_BD + Phi_times_6))
return (term1_numerator - term2_numerator) / (2 * one_plus_z * math.sqrt(val) * val2 * val2)
else:
return None
# 性能测试
t_original = timeit.timeit('dH_original(Rho_m, Phi, u, omega_以上就是PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构的详细内容,更多请关注其它相关文章!
# 转换为
# 网站建设什么因素重要
# 河源sem网站优化报价
# 网站建设优化推广安徽
# seo文章url
# 石狮seo推广
# 网站产品推广活动方案
# 铜山区品质网站建设优势
# 面团儿seo博客
# 网贷网站建设公司
# 赞助营销推广的策略
# 也能
# 浮点
# python
# 将其
# 采样器
# 重构
# 是一个
# 自定义
# 模拟器
# python函数
# 性能测试
# 后端
# app
# node
相关栏目:
【
Google疑问12 】
【
Facebook疑问10 】
【
优化推广96088 】
【
技术知识133117 】
【
IDC资讯59369 】
【
网络运营7196 】
【
IT资讯61894 】
相关推荐:
《幻兽帕鲁》手游帕鲁捕捉技巧分享
Win10怎么设置快速启动 Win10开启快速启动设置方法
《虎扑》取消评分记录方法
荣耀盒子应用管理技巧
如何查询国外邮政编码_国外邮政编码查询的多种有效途径
怎样设置开机后自动运行某个程序_Windows启动文件夹与任务计划【自动化】
在PySimpleGUI中实现键盘按键绑定按钮事件
创建您的便携版VS Code:让配置随身携带
VS Code快捷键when上下文子句的妙用
sublime如何撤销关闭的标签页_sublime重新打开已关闭文件技巧
智慧职教mooc平台登录网址 智慧职教mooc官网直达
外卖小程序对接第三方配送
餐馆菜篮选购指南
b站怎么用微信登录_b站微信登录方法
我居然低估了 DeepSeek,这次更新它做到了这些!
Flexbox布局:实现粘性导航与底部页脚的完美结合
支付宝登录刷脸不是本人如何解决
Sublime怎么格式化HTML代码_Sublime前端代码美化插件使用指南
如何在Golang中处理表单文件上传_Golang 表单文件上传示例
Excel怎么用XLOOKUP函数实现双向查找_ExcelXLOOKUP替代VLOOKUP+HLOOKUP的高级用法
太平年在哪个平台播出
《理想汽车》权限管理设置方法
《波斯王子:失落的王冠》剑术大师打法攻略
《红果免费短剧》下载观看方法
使用Selenium在无头Chrome中交互动态菜单和复选框的策略
宝妈做视频号该写什么标签话题?宝妈关注的话题有哪些?
金牛福袋获取攻略
解决CSS容器溢出问题:使用calc()实现精确布局与边距控制
J*aScript包管理器_Npm与Yarn对比
顺丰快递怎么查物流_顺丰快递物流信息实时查询操作指南
汽水音乐网页端访问 汽水音乐官方网页直达
Selenium自动化:利用键盘模拟解决复杂日期输入框输入问题
动漫岛汉化官网网 动漫岛官方动漫汉化地址
VB表达式书写规则解析
《单词速记宝》设置学习计划方法
Win10通知横幅停留时间修改 Win10自定义通知显示时长【技巧】
139邮箱登录入口官网 139邮箱登录入口官网网址
《三国:谋定天下》平民全阶段通用阵容
手机坏了微信聊天记录怎么导出来 新手机恢复聊天记录技巧
电脑“无法访问指定设备、路径或文件”怎么办?五种权限设置方法
动漫之家观看全集库 动漫之家免费资源网地址
之了课堂app做题入口
《顺丰同城骑士》查看我的技能方法
解决Flex容器横向滚动内容截断与偏移问题
windows10怎么更改下载路径_windows10默认存储位置修改教程
解决异步Python机器人中同步操作的阻塞问题
QQ邮箱官方登录页_腾讯出品安全稳定的邮箱服务
Python测试中模块导入路径解析的最佳实践
VS Code的时间线(Timeline)视图:您的代码时光机
解决PHP MySQL数据库更新无响应:SQL查询语法错误解析
2025-11-19
运城市盐湖区信雨科技有限公司是一家深耕海外推广领域十年的专业服务商,作为谷歌推广与Facebook广告全球合作伙伴,聚焦外贸企业出海痛点,以数字化营销为核心,提供一站式海外营销解决方案。公司凭借十年行业沉淀与平台官方资源加持,打破传统外贸获客壁垒,助力企业高效开拓全球市场,成为中小企业出海的可靠合作伙伴。