tf.function我正在尝试使用贪婪解码方法保存模型。
该代码经过测试并按预期在急切模式(调试)下工作。但是,它在非急切执行中不起作用。
该方法被调用namedtuple,Hyp如下所示:
Hyp = namedtuple(
'Hyp',
field_names='score, yseq, encoder_state, decoder_state, decoder_output'
)
while 循环的调用方式如下:
_, hyp = tf.while_loop(
cond=condition_,
body=body_,
loop_vars=(tf.constant(0, dtype=tf.int32), hyp),
shape_invariants=(
tf.TensorShape([]),
tf.nest.map_structure(get_shape_invariants, hyp),
)
)
这是以下的相关部分body_:
def body_(i_, hypothesis_: Hyp):
# [:] Collapsed some code ..
def update_from_next_id_():
return Hyp(
# Update values ..
)
# The only place where I generate a new hypothesis_ namedtuple
hypothesis_ = tf.cond(
tf.not_equal(next_id, blank),
true_fn=lambda: update_from_next_id_(),
false_fn=lambda: hypothesis_
)
return i_ + 1, hypothesis_
我得到的是ValueError:
ValueError: Input tensor 'hypotheses:0' enters the loop with shape (), but has shape <unknown> after one iteration. To allow the shape to vary across iterations, use the 形状不变量 argument of tf.while_loop to specify a less-specific shape.
这里可能有什么问题?
以下是如何input_signature定义tf.function我想序列化的。
这self.greedy_decode_impl是实际的实现 - 我知道这有点难看,但这self.greedy_decode就是我所说的。
MM们
相关分类