输入张量 <name> 进入形状为 () 的循环,但在一次迭代后形状为 <unknown>

tf.function我正在尝试使用贪婪解码方法保存模型。


该代码经过测试并按预期在急切模式(调试)下工作。但是,它在非急切执行中不起作用。


该方法被调用namedtuple,Hyp如下所示:


Hyp = namedtuple(

    'Hyp',

    field_names='score, yseq, encoder_state, decoder_state, decoder_output'

)

while 循环的调用方式如下:


_, hyp = tf.while_loop(

    cond=condition_,

    body=body_,

    loop_vars=(tf.constant(0, dtype=tf.int32), hyp),

    shape_invariants=(

        tf.TensorShape([]),

        tf.nest.map_structure(get_shape_invariants, hyp),

    )

)

这是以下的相关部分body_:


def body_(i_, hypothesis_: Hyp):


    # [:] Collapsed some code ..


    def update_from_next_id_():

        return Hyp(

            # Update values ..

        )


    # The only place where I generate a new hypothesis_ namedtuple

    hypothesis_ = tf.cond(

        tf.not_equal(next_id, blank),

        true_fn=lambda: update_from_next_id_(),

        false_fn=lambda: hypothesis_

    )


    return i_ + 1, hypothesis_

我得到的是ValueError:


ValueError: Input tensor 'hypotheses:0' enters the loop with shape (), but has shape <unknown> after one iteration. To allow the shape to vary across iterations, use the 形状不变量 argument of tf.while_loop to specify a less-specific shape.


这里可能有什么问题?


以下是如何input_signature定义tf.function我想序列化的。


这self.greedy_decode_impl是实际的实现 - 我知道这有点难看,但这self.greedy_decode就是我所说的。

慕尼黑8549860
浏览 94回答 1
1回答

MM们

好吧,事实证明tf.concat([hypothesis_.yseq,&nbsp;next_id],&nbsp;axis=0),本来应该是tf.concat([hypothesis_.yseq,&nbsp;next_id],&nbsp;axis=-1),公平地说,错误消息有点提示您在哪里查看,但“有帮助”不足以描述它。我TensorSpec通过连接错误的轴来违反了,仅此而已,但 Tensorflow 还无法直接指向受影响的张量。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python