使用 ModuleDict,我有:输入类型(torch.cuda.FloatTensor)

我正在尝试我的__init__功能:



        self.downscale_time_conv = np.empty(8, dtype=object)

        for i in range(8):

            self.downscale_time_conv[i] = torch.nn.ModuleDict({})

但在我的forward,我有:


        down_out = False

        for i in range(8):

            if not down_out:

                down_out = self.downscale_time_conv[i][side](inputs)

            else:

                down_out += self.downscale_time_conv[i][side](inputs)

我得到:


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Python火炬张量


Smart猫小萌
浏览 114回答 1
1回答

萧十郎

      self.downscale_time_conv = torch.nn.ModuleList()        for i in range(8):            self.downscale_time_conv.append(torch.nn.ModuleDict({}))这解决了它。显然我需要使用ModuleList
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python