PyTorch:时间序列任务的数据加载器

我有一个 Pandas 数据框,其中n行和k列加载到内存中。我想为预测任务获取批次,其中批次的第一个训练示例应该具有(q, k)参考q原始数据帧中的行数(例如 0:128)的形状。下一个例子应该是(128:256, k)等等。因此,最终,一批应该具有(32, q, k)与批量大小相对应的 32 形状。


由于TensorDatasetfromdata_utils在这里不起作用,我想知道最好的方法是什么。我尝试使用将qnp.array_split()值的可能拆分数作为第一维,以便编写自定义 DataLoader,但由于并非所有数组都具有相同的形状,因此不能保证重新整形。


这是一个更清楚的最小示例。在这种情况下,批量大小为 3,q为 2:


import pandas as pd

import numpy as np

df = pd.DataFrame(data=np.arange(0,30).reshape(10,3),columns=['A','B','C'])

数据集:


    A   B   C

0   0   1   2

1   3   4   5

2   6   7   8

3   9   10  11

4   12  13  14

5   15  16  17

6   18  19  20

7   21  22  23

8   24  25  26

9   27  28  29

在这种情况下,第一批的形状应该是 (3,2,3),看起来像:


array([[[ 0.,  1.,  2.],

        [ 3.,  4.,  5.]],


       [[ 3.,  4.,  5.],

        [ 6.,  7.,  8.]],


       [[ 6.,  7.,  8.],

        [ 9., 10., 11.]]])


幕布斯7119047
浏览 296回答 3
3回答

红糖糍粑

我最终也编写了自定义数据集,尽管它与上面的答案有点不同:class TimeseriesDataset(torch.utils.data.Dataset):       def __init__(self, X, y, seq_len=1):        self.X = X        self.y = y        self.seq_len = seq_len    def __len__(self):        return self.X.__len__() - (self.seq_len-1)    def __getitem__(self, index):        return (self.X[index:index+self.seq_len], self.y[index+self.seq_len-1])用法如下:train_dataset = TimeseriesDataset(X_lstm, y_lstm, seq_len=4)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 3, shuffle = False)for i, d in enumerate(train_loader):    print(i, d[0].shape, d[1].shape)>>># shape: tuple((batch_size, seq_len, n_features), (batch_size))0 torch.Size([3, 4, 2]) torch.Size([3])

饮歌长啸

您可以编写 TensorDataset 的模拟。为此,您需要从 Dataset 类继承。from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):    def __init__(self, data_frame, q):        self.data = data_frame.values        self.q = q    def __len__(self):        return self.data.shape[0] // self.q    def __getitem__(self, index):        return self.data[index * self.q: (index+1) * self.q]

慕工程0101907

另一种方法是使用开源库 pytorch_forecasting。时间序列数据集的链接可以在这里找到使用此数据集的摘录:该数据集自动执行常见任务,例如变量的缩放和编码标准化目标变量有效地将 pandas 数据帧中的时间序列转换为火炬张量持有关于未来已知和未知的静态和时变变量的信息持有相关类别的信息(如假期)数据增强的下采样生成推理、验证和测试数据集
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python