我正在测试 tensorflow tf.keras.layers.SimpleRNNCell
。我觉得这太奇怪了。我认为 RNN 单元是接收先前状态a^{<t-1>}
和当前数据输入的单元x^{<t>}
。它将输出一个新的状态a^{<t>}
和当前的 predict \hat{y}^{<t>}
。
因此,SimpleRNNCell如果设置了 batch_size,则输入应该是 2d。我认为输入应该是[batch_size,feature_size]. 但是,如果输入是 2D,则会引发错误。而之前的状态也需要3D。
正确的代码如下:
batch_data = tf.ones((batch_size, time_steps, label_num))
simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)
initial_state = tf.zeros((batch_size, time_steps, units))
output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)
但是,我认为以下代码是正确的。但我错了
batch_data = tf.ones((batch_size, label_num))
simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)
initial_state = tf.zeros((batch_size, units))
output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)
所以我的问题是为什么输入SimpleRNNCell是3D?
侃侃无极
心有法竹
相关分类