我目前正在深入研究 tensorflow,我对正确使用tf.nn.Conv2d(input, filter, strides, padding). 尽管乍一看看起来很简单,但我无法听到以下问题:
的使用filter, strides, padding对我来说很清楚。然而,不清楚的是input.
我来自强化学习 Atari (Pong) 问题,在该问题中我想使用网络进行批量训练,并且(以一定的概率)也用于每一步的预测。这意味着,对于训练,我正在为网络提供一整批(假设为 100 ),每个单元由 3 个帧组成,大小为 160、128。使用 tensorflow 的 NHWC 格式,我的输入input将是 atf.placeholder形状(100,160,128,3)。所以为了训练,我喂了 100 个 160x128x3 的包。
但是,在某种情况下预测来自我的网络的输出(用乒乓球拍向上或向下)时,我只向网络馈送一包160x128x3(即一包三帧)。现在这是 tensorflow 崩溃的地方。它期望(100,160,128,3)但接收(1,160,128,3).
现在我很困惑。我显然不想将批量大小设置为 1 并且总是只提供一个包进行训练。但是我怎么能在这里继续呢?这将如何实施tf.nn.conv2d?
如果有人能在这里引导我走向正确的方向,我将不胜感激
繁星点点滴滴
慕斯709654
相关分类