使用 [:, :, 0] 的 TF2 / Keras 切片张量

在 TF 2.0 Beta 中,我正在尝试:


x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)

print(x.shape) # (None, 240, 2)

a = x[:, :, 0]

print(a.shape) # <unknown>

在 TF 1.x 中,我可以这样做:


x = tf1.placeholder(tf1.float32, (None, 240, 2)

a = x[:, :, 0]

它会正常工作。如何在 TF 2.0 中实现这一点?我认为


tf.split(x, 2, axis=2)

可能有效,但是我想使用切片而不是硬编码 2(轴 2 的暗淡)。


holdtom
浏览 139回答 1
1回答

慕工程0101907

不同之处在于返回的对象Input代表一个层,而不是任何类似于占位符或张量的东西。因此,x上面的 tf 2.0 代码中的 是层对象,而xtf 1.x 代码中的 是张量的占位符。您可以定义一个切片层来执行该操作。有现成可用的层,但是对于像这样的简单切片,Lambda层非常容易阅读,并且可能最接近您在 tf 1.x 中使用的切片方式。像这样的东西:input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])您可以像这样在您的 keras 模型中使用它:model = tf.keras.models.Sequential([&nbsp; &nbsp; input_lyr,&nbsp; &nbsp; sliced_lyr,&nbsp; &nbsp; # ...&nbsp; &nbsp; # <other layers>&nbsp; &nbsp; # ...])当然,以上是特定于keras模型的。相反,如果你有一个张量而不是一个 keras 层对象,那么切片就像以前一样工作。像这样的东西:my_tensor = tf.random.uniform((8,240,2))sliced = my_tensor[:,:,0]print(my_tensor.shape)print(sliced.shape)输出:(8, 240, 2)(8, 240)正如预期的那样
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python