在PyTorch中使用多元正态分布时,我决定将其与精确的解析表达式进行比较。
令我惊讶的是,它们之间存在着微小的差异。
这种行为有什么原因吗?
首先,使用 MultivariateNormal 计算概率:
1 from torch.distributions.multivariate_normal import MultivariateNormal
2 import torch
3 sigma = 2
4 m = MultivariateNormal(torch.zeros(2, dtype=torch.float32), torch.eye(2, dtype=torch.float32)*sigma**2)
5 values_temp = torch.zeros(size=(1,2), dtype=torch.float32)
6 out_torch = torch.exp(m.log_prob(values_temp))
7 out_torch
Out: tensor([0.0398])
其次,可以为这种情况写出精确的公式:
1 import numpy as np
2 out_exact = 1/(2*np.pi*sigma**2) * torch.exp(-torch.pow(values_temp, 2).sum(dim=-1)/(2*sigma**2))
3 out_exact
Out: tensor([0.0398])
他们之间有一个区别:
1 (out_torch - out_exact).sum()
Out: tensor(3.7253e-09)
有人可以帮助我理解这两个片段的行为吗?这两种表达方式哪个更准确呢?也许有人可以在代码的任何部分强调我的错误?
温温酱
相关分类