首先,这可能是一个非常基本和愚蠢的问题,但我已经尝试了很多事情并四处搜索无济于事,所以我来了。问题如下:
我有一个张量,出于各种原因,我想找到“它通向何处”。理论上这样做的方法是根据文档等查看 my_tensor.op.outputs,但这似乎总是指向 my_tensor本身!
我以前很容易走另一条路,这意味着我可以通过使用 my_tensor.op.inputs 来获得输入张量,但由于某种原因,“输出”没有达到预期的效果。
这是一个简单的例子:
import tensorflow as tf
a = tf.placeholder(tf.uint8, name='a')
b = tf.placeholder(tf.uint8, name='b')
my_sum = tf.identity(a + b, name='my_sum')
graph = tf.get_default_graph()
# I should have 4 ops at this point, as validated by:
print(graph.get_operations())
>> [<tf.Operation 'a' type=Placeholder>, <tf.Operation 'b' type=Placeholder>, <tf.Operation 'add' type=Add>, <tf.Operation 'my_sum' type=Identity>]
# So let's try get the output of 'a':
print(list(a.op.outputs))
>> [<tf.Tensor 'a:0' shape=<unknown> dtype=uint8>]
如果您尝试了上述操作,您会看到您回到了 'a'...
再次,运行 my_sum.op.inputs 给出了 'add' 操作,并且运行得更远让我们回到 'a' 和 'b'正如预期的那样:
input_to_my_sum = list(my_sum.op.inputs)[0]
print(input_to_my_sum)
>> Tensor("add:0", dtype=uint8)
print(list(input_to_my_sum.op.inputs))
>> [<tf.Tensor 'a:0' shape=<unknown> dtype=uint8>, <tf.Tensor 'b:0' shape=<unknown> dtype=uint8>]
但反过来呢?没有这样的运气:
print(list(input_to_my_sum.op.outputs))
>> [<tf.Tensor 'add:0' shape=<unknown> dtype=uint8>]
print('This is no fun at all')
>> This is no fun at all
那么我做错了什么?
我也试过使用(不推荐使用的)op.values() 没有成功,我很困惑,因为文档明确指出这应该给我操作的输出(来自https://www.tensorflow.org /api_docs/python/tf/Operation):
输出
表示此操作输出的张量对象列表。
(我检查了 a.op.__class__ 是正确的类,并且我正在阅读正确的文档)。
(总结一下,操作的 node_def 也没有显示输出字段的迹象......)。
提前感谢您的任何建议!
编辑(由于 Yuxin 的回答):
只是为了澄清,将等的输出的输出保持在相同的张量上。我正在尝试达到下一个张量/操作。
紫衣仙女
相关分类