猿问

是否有任何pytorch函数可以将张量的特定连续尺寸合并为一个?

让我们称呼我正在寻找的函数“ magic_combine”,它可以组合我赋予它的张量的连续尺寸。更具体地说,我希望它执行以下操作:


a = torch.zeros(1, 2, 3, 4, 5, 6)  

b = a.magic_combine(2, 5) # combine dimension 2, 3, 4 

print(b.size()) # should be (1, 2, 60, 6)

我知道torch.view()可以做类似的事情。但我只是想知道是否还有其他更优雅的方法可以达成目标?


慕的地8271018
浏览 248回答 2
2回答

犯罪嫌疑人X

我不确定“更优雅的方式”是什么想法,但是Tensor.view()优点是不为视图重新分配数据(原始张量和视图共享相同的数据),从而使此操作轻巧。如@UmangGupta所述,但是包装此函数以实现您想要的内容很简单,例如:import torchdef magic_combine(x, dim_begin, dim_end):    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])    return x.view(combined_shape)a = torch.zeros(1, 2, 3, 4, 5, 6)b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4print(b.size())# torch.Size([1, 2, 60, 6])
随时随地看视频慕课网APP

相关分类

Python
我要回答