我正在尝试用TensorFlow 2实现一个注意力网络。因此,对于每个图像,我只想瞥见一些,即图像的一小部分。为此,我从tensorflow.keras.models.Model中实现了一个子类,这里有一个片段。
class RecurrentAttentionModel(models.Model):
# ...
def call(self, inputs):
l = tf.random.uniform((40,2,), minval=0, maxval=1)
for _ in range(0, self.glimpses):
glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True)
# some other code...
# update l to take a glimpse somewhere else
return result
现在,上面的代码可以完美地工作和训练,但我的问题是,我有硬编码的40,这是我在数据集中定义的batch_size。我无法在调用方法中读取/获取batch_size,因为变量“inputs”的形式是batch_size似乎是预期行为。当我只用下面的代码初始化l(没有batch_size)Tensor("input_1_77:0", shape=(None, 250, 500, 1), dtype=float32)None
l = tf.random.uniform((2,), minval=0, maxval=1)
它抛出此错误
ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]
我完全理解,但我不知道如何根据batch_size实现初始值。
守着一只汪
相关分类