猿问

Keras:收集张量更改批量维度

我有一个形状为 (5, 2) 的输入张量,代表 2D 空间中的五个点。


我想取第一点,然后从所有五点中减去它。


仔细阅读,我想我可以用它K.gather来切片和重复第一层。


在 Lambda 层中应用它后,批处理维度被覆盖:


_input = Input(shape=(5, 2))

x = Reshape((5 * 2,))(_input)

x_ = Lambda(lambda t: K.gather(t, [0, 1] * 5))(x)

结果是:


__________________________________________________________________________________________________

Layer (type)                    Output Shape         Param #     Connected to                     

==================================================================================================

input_1 (InputLayer)            (None, 5, 2)         0                                            

__________________________________________________________________________________________________

reshape_1 (Reshape)             (None, 10)           0           input_1[0][0]                    

__________________________________________________________________________________________________

lambda_1 (Lambda)               (10, 10)             0           reshape_1[0][0]                  

__________________________________________________________________________________________________

我究竟做错了什么?


另外,有没有更简单的方法来做到这一点?


墨色风雨
浏览 219回答 2
2回答

幕布斯6054654

gather函数从批处理(0th)轴返回提供的索引值。因此,它为我们提供了形状为 (10, 10) 的批次中的第一个 (index:0) 和第二个 (index:1) 样本 (形状 (10,)) 的列表 (length=10) 而我们想要第一个批次中每个样本的(索引:0)和第二(索引:1)特征点。为了解决这个问题,我们可以在使用gather函数之前转置张量,以便gather函数选择正确的值,最后生成的张量应该再次转置。_input = Input(shape=(5, 2))x = Reshape((5 * 2,))(_input)x_ = Lambda(lambda t: K.transpose(K.gather(K.transpose(t), [0, 1]*5)))(x)输出:_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_1 (InputLayer)         [(None, 5, 2)]            0         _________________________________________________________________reshape (Reshape)            (None, 10)                0         _________________________________________________________________lambda (Lambda)              (None, 10)                0         =================================================================

慕雪6442864

如果你使用tf.gather(),你可以避免使用@bit01 描述的转置操作。中有一个axis论点tf.gather()。_input = Input(shape=(5, 2))x = Reshape((5 * 2,))(_input)x_ = Lambda(lambda t: tf.gather(t, [0, 1]*5, axis=1))(x)# Layer (type)                 Output Shape              Param #   # =================================================================# input_2 (InputLayer)         (None, 5, 2)              0         # _________________________________________________________________# reshape_2 (Reshape)          (None, 10)                0         # _________________________________________________________________# lambda_1 (Lambda)            (None, 10)                0         # =================================================================
随时随地看视频慕课网APP

相关分类

Python
我要回答