了解 Flatten 在 Keras 中的作用并确定何时使用它

我正在尝试了解为时间序列预测开发的模型。它使用一个 Con1D 层和两个 LSTM 层,然后是一个密集层。我的问题是,它应该Flatten() 在 LSTM 和 Denser 层之间使用吗?在我看来,输出应该只有一个值,形状为(None, 1),可以通过Flatten()在 LSTM 和 Dense 层之间使用来实现。没有Flatten(),输出形状将为(None, 30, 1)。或者,我可以从第二个 LSTM 层中删除return_sequences=True,我认为这与Flatten(). 哪种方式更合适?它们会影响损失吗?这是模型。


model = tf.keras.models.Sequential([

    tf.keras.layers.Conv1D(filters=32, kernel_size=3, strides=1, padding="causal", activation="relu", input_shape=(30 ,1)),

    tf.keras.layers.LSTM(32, return_sequences=True),

    tf.keras.layers.LSTM(32, return_sequences=True),

    # tf.keras.layers.Flatten(),

    tf.keras.layers.Dense(1),

    ])

这是没有的模型摘要Flatten()


Model: "sequential"

_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

conv1d (Conv1D)              (None, 30, 32)            128       

_________________________________________________________________

lstm (LSTM)                  (None, 30, 32)            8320      

_________________________________________________________________

lstm_1 (LSTM)                (None, 30, 32)            8320      

_________________________________________________________________

dense (Dense)                (None, 30, 1)             33        

=================================================================

Total params: 16,801

Trainable params: 16,801

Non-trainable params: 0

_________________________________________________________________


繁星coding
浏览 218回答 2
2回答

DIEA

嗯,这取决于你想要达到的目标。我试着给你一些提示,因为我不是 100% 清楚你想要获得什么。如果您的 LSTM 使用return_sequences=True,那么您将返回每个 LSTM 单元格的输出,即每个时间戳的输出。如果您随后添加一个密集层,其中一个将添加到每个 LSTM 层的顶部。如果您将展平层与 一起使用return_sequences=True,那么您基本上是在删除时间维度,就像(None, 30)您的情况一样。然后,您可以添加一个致密层或任何您需要的层。如果你设置return_sequences=False,你只会在你的 LSTM 的最后得到输出(请注意,在任何情况下,由于 LSTM 的功能,它都是基于在之前的时间戳发生的计算),输出将是这样的(None, dim)其中dim等于您在 LSTM 中使用的隐藏单元数(即 32)。同样,在这里,您可以简单地添加一个带有一个隐藏单元的密集层,以获得您正在寻找的东西。

holdtom

请在此处参考此链接>>类似问题。flatten()一般用在输出层之前。最好在 LSTM 层的全部输出上使用 flatten ...它可以在密集层之后而不是在 LSTM 层之后使用吗?我想通过这里的其他答案和评论向专柜学习。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python