如何加入编码器和解码器

我构建了以下编码器-解码器架构,编码器和解码器都可以单独工作:


from tensorflow.keras.layers import LSTM, Input, Reshape, Lambda

from tensorflow.keras.models import Model

from tensorflow.keras import backend as K


WORD_TO_INDEX = {"foo": 0, "bar": 1}


MAX_QUERY_WORD_COUNT = 10

QUERY_ENCODING_SIZE = 15


# ENCODER

query_encoder_input = Input(shape=(None, len(WORD_TO_INDEX)), name="query_encoder_input")

query_encoder_output = LSTM(QUERY_ENCODING_SIZE, name="query_encoder_lstm")(query_encoder_input)

query_encoder = Model(inputs=query_encoder_input, outputs=query_encoder_output)

# DECODER

query_decoder_input = Input(shape=(QUERY_ENCODING_SIZE,), name="query_decoder_input")

query_decoder_reshape = Reshape((1, QUERY_ENCODING_SIZE), name="query_decoder_reshape")(query_decoder_input)

query_decoder_lstm = LSTM(QUERY_ENCODING_SIZE, name="query_decoder_lstm", return_sequences=True, return_state=True)

recurrent_input, state_h, state_c = query_decoder_lstm(query_decoder_reshape)

states = [state_h, state_c]

query_decoder_outputs = []

for _ in range(MAX_QUERY_WORD_COUNT):

    recurrent_input, state_h, state_c = query_decoder_lstm(recurrent_input, initial_state=states)

    query_decoder_outputs.append(recurrent_input)

    states = [state_h, state_c]

query_decoder_output = Lambda(lambda x: K.concatenate(x, axis=1), name="query_decoder_concat")(query_decoder_outputs)

query_decoder = Model(inputs=query_decoder_input, outputs=query_decoder_output)



是我用于解码器的模板。(请参阅“如果我不想使用教师强制进行培训怎么办?”部分。)

我依靠这些 StackOverflow 问题(尤其是最后一个)来弄清楚如何将模型组合在一起。

这个错误是什么意思,我该如何解决?


LEATH
浏览 74回答 0
0回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python