如何对火炬张量应用变换

我有一个火炬张量z,我想应用一个变换矩阵,mat并使z输出的大小与 完全相同z。这是我正在运行的代码:


def trans(z):

    print(z)

    mat = transforms.Compose([transforms.ToPILImage(),transforms.RandomRotation(90),transforms.ToTensor()])

    z = Variable(mat(z.cpu()).cuda())

    z = nnf.interpolate(z, size=(28, 28), mode='linear', align_corners=False)

    return z

z = trans(z)

但是,我收到此错误:


RuntimeError                              Traceback (most recent call last)

<ipython-input-12-e2fc36889ba5> in <module>()

      3 inputs,targs=next(iter(tst_loader))

      4 recon, mean, var = vae.predict(model, inputs[img_idx])

----> 5 out = vae.generate(model, mean, var)


4 frames

/content/vae.py in generate(model, mean, var)

     90     z = trans(z)

     91     z = Variable(z.cpu().cuda())

---> 92     out = model.decode(z)

     93     return out.data.cpu()

     94 


/content/vae.py in decode(self, z)

     56 

     57     def decode(self, z):

---> 58         out = self.z_develop(z)

     59         out = out.view(z.size(0), 64, self.z_dim, self.z_dim)

     60         out = self.decoder(out)


/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)

    720             result = self._slow_forward(*input, **kwargs)

    721         else:

--> 722             result = self.forward(*input, **kwargs)

    723         for hook in itertools.chain(

    724                 _global_forward_hooks.values(),


/usr/local/lib/python3.6/dist-packages/torch/nn/modules/linear.py in forward(self, input)

     89 

     90     def forward(self, input: Tensor) -> Tensor:

---> 91         return F.linear(input, self.weight, self.bias)

     92 

     93     def extra_repr(self) -> str:


/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)

   1674         ret = torch.addmm(bias, input, weight.t())

   1675     else:

-> 1676         output = input.matmul(weight.t())

   1677         if bias is not None:

   1678             output += bias


RuntimeError: mat1 dim 1 must match mat2 dim 0

如何成功应用此旋转变换mat并且不会出现任何错误?


郎朗坤
浏览 96回答 1
1回答

蝴蝶刀刀

问题是interpolate需要一个批次维度,但根据错误消息和transforms. 由于您的输入是空间的(基于size=(28, 28)),您可以通过添加批量维度并更改 来解决该问题mode,因为linear没有针对空间输入实现:z = nnf.interpolate(z.unsqueeze(0), size=(28, 28), mode='bilinear', align_corners=False)如果你z仍然想拥有像 (C, H, W) 这样的形状,那么:z = nnf.interpolate(z.unsqueeze(0), size=(28, 28), mode='bilinear', align_corners=False).squeeze(0)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python