猿问

具有伯努利分布的 TensorFlow Probability MCMC

我需要使用 TensorFlow Probability 从伯努利分布中采样来实现马尔可夫链蒙特卡罗。但是,我的尝试显示的结果与我对伯努利分布的期望不一致。


我在这里修改了 tfp.mcmc.sample_chain(从对角线方差高斯采样)示例的文档中给出的示例,以从伯努利分布中提取。由于伯努利分布是离散的,我使用了 RandomWalkMetropolis 转换内核而不是 Hamiltonian Monte Carlo 内核,我预计它不会工作,因为它计算梯度。


这是代码:


import numpy as np

import matplotlib.pyplot as plt

import seaborn as sns

import tensorflow as tf

import tensorflow_probability as tfp

tfd = tfp.distributions


def make_likelihood(event_prob):

    return tfd.Bernoulli(probs=event_prob,dtype=tf.float32)



dims=1

event_prob = 0.3

num_results = 30000

likelihood = make_likelihood(event_prob)



states, kernel_results = tfp.mcmc.sample_chain(

    num_results=num_results,

    current_state=tf.zeros(dims),

    kernel = tfp.mcmc.RandomWalkMetropolis(

              target_log_prob_fn=likelihood.log_prob,

              new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=1.0),

              seed=124

             ),

    num_burnin_steps=5000)


chain_vals = states


# Compute sample stats.

sample_mean = tf.reduce_mean(states, axis=0)

sample_var = tf.reduce_mean(

    tf.squared_difference(states, sample_mean),

    axis=0)


#initialize the variable

init_op = tf.global_variables_initializer()


#run the graph

with tf.Session() as sess:

    sess.run(init_op) 

    [sample_mean_, sample_var_, chain_vals_] = sess.run([sample_mean,sample_var,chain_vals])


chain_samples = (chain_vals_[:] )   

print ('Sample mean = {}'.format(sample_mean_))

print ('Sample var = {}'.format(sample_var_))

fig, axes = plt.subplots(2, 1)

fig.set_size_inches(12, 10)


axes[0].plot(chain_samples[:])

axes[0].title.set_text("values sample chain tfd.Bernoulli")

sns.kdeplot(chain_samples[:,0], ax=axes[1], shade=True)

axes[1].title.set_text("chain tfd.Bernoulli distribution")

fig.tight_layout()

plt.show()

我希望看到区间 [0,1] 中马尔可夫链状态的值。


马尔可夫链值的结果看起来不像伯努利分布的预期结果,KDE 图也不是,如下图所示:

我的示例是否存在概念上的缺陷,或者在使用 TensorFlow Probability API 时是否存在错误?

或者使用离散分布(例如伯努利分布)的马尔可夫链蒙特卡罗的 TF.Probability 实现可能存在问题?


一只萌萌小番薯
浏览 325回答 1
1回答
随时随地看视频慕课网APP

相关分类

Python
我要回答