调用函数中的变量batch_size

我正在尝试用TensorFlow 2实现一个注意力网络。因此,对于每个图像,我只想瞥见一些,即图像的一小部分。为此,我从tensorflow.keras.models.Model中实现了一个子类,这里有一个片段。


class RecurrentAttentionModel(models.Model):

# ...


def call(self, inputs):


    l = tf.random.uniform((40,2,), minval=0, maxval=1)


    for _ in range(0, self.glimpses):

        glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True)


        # some other code...

        # update l to take a glimpse somewhere else



    return result           

现在,上面的代码可以完美地工作和训练,但我的问题是,我有硬编码的40,这是我在数据集中定义的batch_size。我无法在调用方法中读取/获取batch_size,因为变量“inputs”的形式是batch_size似乎是预期行为。当我只用下面的代码初始化l(没有batch_size)Tensor("input_1_77:0", shape=(None, 250, 500, 1), dtype=float32)None


l = tf.random.uniform((2,), minval=0, maxval=1)

它抛出此错误


ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]

我完全理解,但我不知道如何根据batch_size实现初始值。


富国沪深
浏览 118回答 1
1回答

守着一只汪

您可以使用 动态提取批大小维度。tf.shapel = tf.random.normal(tf.stack([tf.shape(inputs)[0], 2]), minval=0, maxval=1))
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python