Tensorflow 报告“TypeError:预期单个张量时的张量列表”

我正在使用 Tensorflow 编写模型。我的条件语句的一部分,例如:

new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: tf.constant([1, src_shape[0]]))

并且src_shape是 的结果tf.shape()

它报告TypeError: List of Tensors when single Tensor expected。我知道这是因为tf.constant([1, src_shape[0]])是张量列表,但我不知道如何以合法的方式实现我的代码。

我试图删除tf.constant()喜欢

new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: [1, src_shape[0]])

但它报告 ValueError: Incompatible return values of true_fn and false_fn: The two structures don't have the same nested structure.


婷婷同学_
浏览 284回答 2
2回答

素胚勾勒不出你

一种方法是使用 tf.stack,它将 rank-R 张量列表堆叠成一个 rank-(R+1) 张量。lambda: tf.stack([1, src_shape[0]], axis=0)另一种解决方案是使用 tf.concat 使用正确的 tf.reshape 命令。

catspeake

我试过了tf.convert_to_tensor([1, src_shape[0]])。这是一种替代解决方案。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python