我正在制作一个DataLoaderfrom DataSetin PyTorch。
从加载DataFrame所有dtype作为一个np.float64
result = pd.read_csv('dummy.csv', header=0, dtype=DTYPE_CLEANED_DF)
这是我的数据集类。
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, result):
headers = list(result)
headers.remove('classes')
self.x_data = result[headers]
self.y_data = result['classes']
self.len = self.x_data.shape[0]
def __getitem__(self, index):
x = torch.tensor(self.x_data.iloc[index].values, dtype=torch.float)
y = torch.tensor(self.y_data.iloc[index], dtype=torch.float)
return (x, y)
def __len__(self):
return self.len
准备 train_loader and test_loader
train_size = int(0.5 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=1)
test_loader = DataLoader(dataset=train_dataset)
这是我的csv 文件
如何解决pandas
这里的问题?
守着星空守着你
相关分类