PyTorch 和精确解析表达式之间的差异

在PyTorch中使用多元正态分布时,我决定将其与精确的解析表达式进行比较。

令我惊讶的是,它们之间存在着微小的差异。

这种行为有什么原因吗?

首先,使用 MultivariateNormal 计算概率:

1 from torch.distributions.multivariate_normal import MultivariateNormal

2 import torch

3 sigma = 2

4 m = MultivariateNormal(torch.zeros(2, dtype=torch.float32), torch.eye(2, dtype=torch.float32)*sigma**2)

5 values_temp = torch.zeros(size=(1,2), dtype=torch.float32)

6 out_torch = torch.exp(m.log_prob(values_temp))

7 out_torch 

Out: tensor([0.0398])

其次,可以为这种情况写出精确的公式:

1 import numpy as np

2 out_exact = 1/(2*np.pi*sigma**2) * torch.exp(-torch.pow(values_temp, 2).sum(dim=-1)/(2*sigma**2))

3 out_exact

Out: tensor([0.0398])

他们之间有一个区别:


1 (out_torch - out_exact).sum()

Out: tensor(3.7253e-09)

有人可以帮助我理解这两个片段的行为吗?这两种表达方式哪个更准确呢?也许有人可以在代码的任何部分强调我的错误?


蝴蝶不菲
浏览 109回答 1
1回答

温温酱

大多数现代系统使用IEEE 754标准来表示固定精度浮点值。因此,我们可以确定 pytorch 提供的结果或您计算的“精确”值实际上并不完全等于解析表达式。我们知道这一点是因为表达式的实际值肯定是无理数,而 IEEE 754 不能精确地表示任何无理数。这是使用固定精度浮点表示时的普遍现象.经过进一步分析,我们发现您看到的归一化差异约为机器 epsilon 的量级(即3.7253e-09 / 0.0398约等于torch.finfo(torch.float32).eps),表明差异可能只是浮点运算不准确的结果。为了进一步演示,我们可以编写一个与您所拥有的数学等效的表达式:out_exact = torch.exp(np.log(1/ (2*np.pi*sigma**2)) + (-torch.pow(values_temp, 2).sum(dim=-1)/2/sigma**2))这与我当前安装的 pytorch 给出的值完全一致。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python