在Keras上使用解码器输入seq2seq模型连接关注层

我正在尝试使用Keras库注意实现序列2序列模型。该模型的框图如下

http://img2.mukewang.com/609a30e700012ca303180572.jpg

模型将输入序列嵌入3D张量。然后,双向lstm创建编码层。接下来,将编码后的序列发送到自定义关注层,该层返回具有每个隐藏节点的关注权重的2D张量。


解码器输入作为一个热矢量注入模型中。现在在解码器(另一个bistlm)中,解码器输入和注意力权重都作为输入传递。解码器的输出被发送到具有softmax激活函数的时间分布密集层,以概率的方式获得每个时间步长的输出。该模型的代码如下:


encoder_input = Input(shape=(MAX_LENGTH_Input, ))


embedded = Embedding(input_dim=vocab_size_input, output_dim= embedding_width, trainable=False)(encoder_input)


encoder = Bidirectional(LSTM(units= hidden_size, input_shape=(MAX_LENGTH_Input,embedding_width), return_sequences=True, dropout=0.25, recurrent_dropout=0.25))(embedded)


attention = Attention(MAX_LENGTH_Input)(encoder)


decoder_input = Input(shape=(MAX_LENGTH_Output,vocab_size_output))


merge = concatenate([attention, decoder_input])


decoder = Bidirectional(LSTM(units=hidden_size, input_shape=(MAX_LENGTH_Output,vocab_size_output))(merge))


output = TimeDistributed(Dense(MAX_LENGTH_Output, activation="softmax"))(decoder)

问题是当我连接注意层和解码器输入时。由于解码器输入是3D张量,而注意是2D张量,因此显示以下错误:


ValueError:Concatenate图层需要输入的形状与concat轴一致,但匹配的轴除外。得到了输入形状:[(无,1024),(无,10,8281)]


如何将2D注意张量转换为3D张量?


函数式编程
浏览 203回答 1
1回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python