猿问

为什么tensorflow SimpleRNNCell的输入是3D的?

我正在测试 tensorflow tf.keras.layers.SimpleRNNCell。我觉得这太奇怪了。我认为 RNN 单元是接收先前状态a^{<t-1>}和当前数据输入的单元x^{<t>}。它将输出一个新的状态a^{<t>}和当前的 predict \hat{y}^{<t>}

因此,SimpleRNNCell如果设置了 batch_size,则输入应该是 2d。我认为输入应该是[batch_size,feature_size]. 但是,如果输入是 2D,则会引发错误。而之前的状态也需要3D。


正确的代码如下:


batch_data = tf.ones((batch_size, time_steps, label_num))    

simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)

initial_state = tf.zeros((batch_size, time_steps, units))

output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)

但是,我认为以下代码是正确的。但我错了


batch_data = tf.ones((batch_size, label_num))    

simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)

initial_state = tf.zeros((batch_size, units))

output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)

所以我的问题是为什么输入SimpleRNNCell是3D?


临摹微笑
浏览 228回答 2
2回答

侃侃无极

第三维是许多特征,用于多元时间序列。在您的情况下,对于特征数使用 1。例如,您可以认为张量 [1,2,3] 是 1D,[[1,2,3]] 是形状为 (1,3) 的 2D, [[[1,2,3]]] 是具有形状 (1,1,3) 等的 3D。因此,如果我们取一个输入样本,一个变量时间序列将是 [[1,2,3]],但两个变量时间序列可能看起来像 [[1,2,3], [7,8,9]]。

心有法竹

RNN(或 LSTM)的输入应该具有 [batch_size, timesteps, nbr_features] 的形状
随时随地看视频慕课网APP

相关分类

Python
我要回答