Problem with an LSTM model

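I'm trying to implement an LSTM model in PyTorch and have run into a problem: the loss does not decrease. My task is as follows. I have sessions described by various features, and the session length is fixed at 20. My goal is to predict whether the last element of a session was skipped. I tried scaling the input features, and I tried passing the target in as a feature (in case the provided features carry no information at all; I expected this to cause overfitting and drive the loss close to 0), but my loss always behaves like this: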

http://img2.mukewang.com/614d74970001b38907790511.jpg
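One quick way to read a flat loss curve like this is to compare it with the loss of a model that ignores its inputs and always predicts the fraction of positive labels. The sketch below assumes binary 0/1 labels and a binary cross-entropy loss with the natural log (as PyTorch's `nn.BCELoss` and `nn.BCEWithLogitsLoss` use); the question does not say which loss is used, and `constant_baseline_bce` is a hypothetical helper. If the training loss stalls near this value, the model is likely just predicting the label frequency.

```python
import numpy as np

def constant_baseline_bce(y_last):
    """Mean binary cross-entropy of a model that always predicts mean(y_last).

    y_last: 1-D array of 0/1 labels (hypothetical name, e.g. y[:, -1]).
    For a constant prediction p = mean(y), the mean BCE equals the
    entropy of the label distribution: -(p*ln p + (1-p)*ln(1-p)).
    """
    p = float(np.mean(y_last))
    # Clip to avoid log(0) when all labels are identical
    p = float(np.clip(p, 1e-7, 1 - 1e-7))
    return -(p * np.log(p) + (1 - p) * np.log(1 - p))

# Labels that are positive 30% of the time
labels = np.array([1] * 3 + [0] * 7)
print(round(float(constant_baseline_bce(labels)), 4))  # 0.6109
```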

print(X.shape)
# (82770, 20, 31), where 82770 is the number of sessions, 20 is seq_len, and 31 is the number of features

print(y.shape)
# (82770, 20)
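The target-as-feature sanity check described above can be built directly from these shapes: append the per-step label as a 32nd feature channel. This is a minimal NumPy sketch; `add_target_as_feature` is a hypothetical helper and the random arrays only stand in for the real data. If a model trained on the result still cannot push the loss toward 0, the problem is most likely in the model or training loop rather than in the features.

```python
import numpy as np

def add_target_as_feature(X, y):
    """Append the per-step target as an extra feature channel.

    X: (n_sessions, seq_len, n_features), y: (n_sessions, seq_len).
    Returns an array of shape (n_sessions, seq_len, n_features + 1).
    """
    assert X.shape[:2] == y.shape
    # y[:, :, np.newaxis] has shape (n_sessions, seq_len, 1)
    return np.concatenate([X, y[:, :, np.newaxis].astype(X.dtype)], axis=2)

# Small random stand-in for the real (82770, 20, 31) data
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(5, 20, 31))
y_demo = rng.integers(0, 2, size=(5, 20))
print(add_target_as_feature(X_demo, y_demo).shape)  # (5, 20, 32)
```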

I also defined a get_batches function. Yes, I know about the problem with the last batch in this generator.


def get_batches(X, y, batch_size):
    '''Create a generator that yields (batch_x, batch_y) pairs of shape
    (batch_size, seq_len, n_features) and (batch_size, seq_len).
    A trailing partial batch is dropped.
    '''
    assert X.shape[0] == y.shape[0]
    assert X.shape[1] == y.shape[1]
    assert len(X.shape) == 3
    assert len(y.shape) == 2

    # Number of full batches along the session axis (depends on
    # batch_size, not on seq_len)
    n_batches = X.shape[0] // batch_size

    for batch_number in range(n_batches):
        start = batch_number * batch_size
        end = start + batch_size
        batch_x = X[start:end, :, :]
        batch_y = y[start:end, :]
        yield batch_x, batch_y
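get_batches never yields the final, shorter batch. If that batch should be kept, one option is to step through the sessions in strides of `batch_size` and clip the last slice. The sketch below is a hypothetical `iterate_batches` helper; its `drop_last` flag mirrors the parameter of the same name on PyTorch's `torch.utils.data.DataLoader`, which (together with `TensorDataset`) is the usual built-in alternative to a hand-written generator.

```python
import numpy as np

def iterate_batches(X, y, batch_size, drop_last=False):
    """Yield (batch_x, batch_y) slices along the session axis.

    The final batch may be smaller than batch_size unless drop_last=True.
    """
    assert X.shape[0] == y.shape[0]
    n = X.shape[0]
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        if drop_last and end - start < batch_size:
            break  # skip the short trailing batch
        yield X[start:end], y[start:end]

X_demo = np.zeros((10, 20, 31))
y_demo = np.zeros((10, 20))
print([bx.shape[0] for bx, _ in iterate_batches(X_demo, y_demo, 4)])  # [4, 4, 2]
```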


慕尼黑的夜晚无繁华
1 Answer

紫衣仙女

My bad: I forgot to scale the input features. It works fine now.
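The answer does not say which scaling was used. A common choice for input of shape (n_sessions, seq_len, n_features) is per-feature standardization, with the statistics computed over both the session and time axes. This is a minimal NumPy sketch; `fit_scaler` and `transform` are hypothetical helper names, and the random data is only a stand-in.

```python
import numpy as np

def fit_scaler(X_train):
    """Per-feature mean and std over all sessions and time steps."""
    # X_train: (n_sessions, seq_len, n_features); reduce over axes 0 and 1
    mean = X_train.mean(axis=(0, 1), keepdims=True)
    std = X_train.std(axis=(0, 1), keepdims=True)
    # Avoid division by zero: constant features become all zeros after centering
    std[std == 0] = 1.0
    return mean, std

def transform(X, mean, std):
    """Standardize X with statistics from fit_scaler (broadcast over axes 0 and 1)."""
    return (X - mean) / std

rng = np.random.default_rng(0)
X_train = rng.normal(loc=50.0, scale=10.0, size=(100, 20, 3))
mean, std = fit_scaler(X_train)
X_scaled = transform(X_train, mean, std)
print(mean.shape)  # (1, 1, 3)
```

The mean and std should be fitted on the training split only and then reused to transform the validation and test data, so that no information from those splits leaks into training.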