火炬中的一些操作是就地执行的。例如,像 += 这样的速记运算符。
是否可以就地执行其他操作,例如softmax?
我目前正在研究语言处理。该模型在大量词汇表上生成一长串概率分布。这个最终输出张量负责大约 60% 的分配内存。这是一个巨大的问题,因为我需要计算一个 softmax 并且将所需的内存加倍。
这是问题的一个例子。我对张量 t 不感兴趣,只对它的 softmax 感兴趣:
import numpy as np
import torch
import torch.nn.functional as F
t = torch.tensor(np.zeros((30000,30000))).cuda() #allocates 6.71 GB of GPU
softmax = F.softmax(t, 1) #out of memory error
del t #too late, program crashed
即使以下也不起作用:
F.softmax(torch.tensor(np.zeros((30000,30000))).cuda(), 1)
繁花如伊
相关分类