1 The Dataset base class
```python
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
```
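Subclasses must override two methods:
* `__len__`, which returns the size of the dataset
* `__getitem__`, which takes an integer index and returns a single sample; valid indices run from 0 to len(self) - 1, i.e. the half-open range [0, len(self))

The base class also defines `__add__`, so `a + b` returns `ConcatDataset([a, b])`.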
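As a minimal sketch of that contract, the example below subclasses a pure-Python stand-in for the base class (an assumption so it runs without PyTorch installed; with PyTorch you would subclass `torch.utils.data.Dataset` instead). `SquaresDataset` is a hypothetical example. Raising `IndexError` outside [0, len(self)) also lets plain Python iteration stop in the right place, because `for` and `list()` fall back to calling `__getitem__(0)`, `__getitem__(1)`, and so on until `IndexError` is raised.

```python
class Dataset:
    """Stand-in mirroring the abstract base above (assumption: used so the
    sketch runs without PyTorch installed)."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class SquaresDataset(Dataset):
    """Hypothetical example dataset: sample i is i * i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Size of the dataset.
        return self.n

    def __getitem__(self, index):
        # Valid indices are 0 .. len(self) - 1, i.e. [0, len(self)).
        if not 0 <= index < self.n:
            raise IndexError(f"index {index} out of range for dataset of size {self.n}")
        return index * index


ds = SquaresDataset(4)
print(len(ds))    # 4
print(ds[3])      # 9
print(list(ds))   # [0, 1, 4, 9]  (iteration stops at the IndexError for index 4)
```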
2 Concatenating datasets
```python
import bisect
import warnings


class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running totals of the dataset lengths, e.g. [100, 230] -> [100, 330]
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Find which sub-dataset idx falls into, then the offset inside it
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Deprecated, misspelled alias kept for backward compatibility
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
```
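The index lookup in `__getitem__` can be traced with plain numbers. The sketch below re-implements `cumsum` and the `bisect_right` lookup as standalone functions (`locate` is a hypothetical name), using the lengths 100, 230 and 250 from the example in section 3. `cumulative_sizes[k]` is the first global index that does not belong to dataset k, which is why `bisect_right` is used: index 100 must land in dataset 1 at position 0, not in dataset 0.

```python
import bisect


def cumsum(sizes):
    """Running totals, e.g. [100, 230, 250] -> [100, 330, 580]."""
    totals, running = [], 0
    for n in sizes:
        running += n
        totals.append(running)
    return totals


def locate(cumulative_sizes, idx):
    """Map a global index to (dataset_idx, sample_idx), as ConcatDataset.__getitem__ does."""
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx


sizes = cumsum([100, 230, 250])
print(sizes)               # [100, 330, 580]
print(locate(sizes, 99))   # (0, 99)
print(locate(sizes, 100))  # (1, 0)
print(locate(sizes, 579))  # (2, 249)
```

Note that in the version quoted above a negative index is not converted: `bisect_right` returns 0 and the negative index is passed straight to the first dataset, so `ds[-1]` would return the last sample of the first dataset rather than of the whole concatenation.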
3 Sub-datasets: Subset
```python
import torch
from torch.utils.data import Dataset, ConcatDataset, Subset, random_split


class MyDataset(Dataset):
    def __init__(self, t=0, name="myDataset"):
        super(MyDataset, self).__init__()
        self.nums = []
        if t == 0:
            self.nums = [torch.randn(1).item() for _ in range(100)]  # 100 random floats
        elif t == 1:
            self.nums = list(range(230))                              # 230 integers
        elif t == 2:
            self.nums = torch.linspace(-1, 1, 250).data.numpy()       # 250 evenly spaced values
        self.name = name
        self.t = t

    def __getitem__(self, i):
        return self.nums[i]

    def __len__(self):
        return len(self.nums)


if __name__ == "__main__":
    ds0 = MyDataset(0, "type_0")
    ds1 = MyDataset(1, "type_1")
    ds2 = MyDataset(2, "type_2")
    ds = ds0 + ds1  # ConcatDataset([ds0, ds1])
    ds = ds + ds2   # ConcatDataset([ConcatDataset([ds0, ds1]), ds2])
    print(ds.datasets[0].datasets[0].name, ds.datasets[0].datasets[1].name, ds.datasets[1].name)
    print(len(ds))  # 100 + 230 + 250 = 580
    # Second argument: the length of each split; the lengths must sum to len(ds)
    dss = random_split(ds, [310, 270])
```
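Note that ds0 and ds1 are concatenated first, producing a ConcatDataset, which is then concatenated with ds2 to produce another ConcatDataset. As a result, ds.datasets has length 2: the first element is the inner ConcatDataset and the second is MyDataset(2, "type_2").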
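The nesting follows from Python evaluating `ds0 + ds1 + ds2` left to right as `(ds0 + ds1) + ds2`. The sketch below reproduces just that behaviour with minimal pure-Python stand-ins (not the PyTorch classes; `Named` and `leaf_names` are hypothetical helpers) and contrasts it with passing all three datasets to `ConcatDataset` at once, which gives a single level.

```python
class Dataset:
    """Stand-in for the base class above; only __add__ matters here."""

    def __add__(self, other):
        return ConcatDataset([self, other])


class ConcatDataset(Dataset):
    """Stand-in that only records its children (no indexing logic)."""

    def __init__(self, datasets):
        self.datasets = list(datasets)


class Named(Dataset):
    """Hypothetical leaf dataset that only carries a name."""

    def __init__(self, name):
        self.name = name


def leaf_names(ds):
    """Depth-first list of leaf dataset names, unwrapping nested ConcatDatasets."""
    if isinstance(ds, ConcatDataset):
        return [name for child in ds.datasets for name in leaf_names(child)]
    return [ds.name]


ds0, ds1, ds2 = Named("type_0"), Named("type_1"), Named("type_2")
nested = ds0 + ds1 + ds2                   # parsed as (ds0 + ds1) + ds2
print(len(nested.datasets))                # 2
print(type(nested.datasets[0]).__name__)   # ConcatDataset
print(leaf_names(nested))                  # ['type_0', 'type_1', 'type_2']

flat = ConcatDataset([ds0, ds1, ds2])      # one level, three children
print(len(flat.datasets))                  # 3
```

With PyTorch, `ConcatDataset([ds0, ds1, ds2])` likewise yields a flat `datasets` list of length 3, which avoids an extra level of index lookup.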
4 TensorDataset: wrapping tensors
```python
class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
```
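`TensorDataset` indexes every wrapped tensor at the same position and returns the results as a tuple. The sketch below mirrors that behaviour with plain Python lists so it runs without PyTorch; `ColumnsDataset` is a hypothetical name, and it raises `ValueError` instead of using `assert` so the length check is not skipped when Python runs with `-O`.

```python
class ColumnsDataset:
    """Hypothetical pure-Python analogue of TensorDataset: wraps equal-length
    sequences and returns the i-th element of each as a tuple."""

    def __init__(self, *columns):
        # Same check as TensorDataset: all columns share the first dimension.
        if not columns:
            raise ValueError("need at least one column")
        n = len(columns[0])
        if any(len(c) != n for c in columns):
            raise ValueError("all columns must have the same length")
        self.columns = columns

    def __getitem__(self, index):
        # One sample = the element at `index` from every column.
        return tuple(c[index] for c in self.columns)

    def __len__(self):
        return len(self.columns[0])


features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
labels = [0, 1, 0]
ds = ColumnsDataset(features, labels)
print(len(ds))  # 3
print(ds[1])    # ([0.3, 0.4], 1)
```

With PyTorch installed, `TensorDataset(torch.randn(3, 2), torch.tensor([0, 1, 0]))` behaves analogously, returning `(features[i], labels[i])` tuples of tensors.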
`random_split`, used in section 3, splits a dataset into non-overlapping `Subset`s:

```python
def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    # One random permutation of all indices (randperm is torch.randperm),
    # cut into consecutive slices; _accumulate yields the running sums of lengths
    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
```
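Each split is a contiguous slice of a single random permutation: for lengths [310, 270] the running offsets are [310, 580], so the slices are `indices[0:310]` and `indices[310:580]`, which cannot overlap. The sketch below re-implements that logic in plain Python, with `random.Random(...).sample` standing in for `torch.randperm`, `itertools.accumulate` for `_accumulate`, and a minimal `Subset` stand-in (a view that maps local index `i` to `dataset[indices[i]]`). The `seed` parameter is a hypothetical addition for reproducibility.

```python
import itertools
import random


class Subset:
    """Pure-Python stand-in for torch.utils.data.Subset: a view that maps
    local index i to dataset[indices[i]]."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]

    def __len__(self):
        return len(self.indices)


def random_split(dataset, lengths, seed=None):
    """Sketch of the logic above; `seed` is a hypothetical extra parameter."""
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    n = len(dataset)
    indices = random.Random(seed).sample(range(n), n)  # a random permutation of 0..n-1
    # accumulate([7, 3]) -> 7, 10, so the slices are [0:7] and [7:10]
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(itertools.accumulate(lengths), lengths)]


data = list(range(10))
a, b = random_split(data, [7, 3], seed=0)
print(len(a), len(b))  # 7 3
merged = [a[i] for i in range(len(a))] + [b[i] for i in range(len(b))]
print(sorted(merged) == data)  # True: together the splits cover every sample exactly once
```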
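Built-in datasets in `torchvision.datasets`:
* LSUN: Large-scale Scene Understanding
* LSUNClass
* ImageFolder: a dataset built from a directory of images
* DatasetFolder: a dataset built from a directory of files
* CocoCaptions: Microsoft MS COCO image captioning
* CocoDetection: MS COCO object detection
* CIFAR10: an image classification dataset of 60,000 color images
* CIFAR100: 100 fine-grained classes with 600 images each (500 training and 100 test images). The 100 classes are grouped into 20 superclasses, and each image carries one "fine" label (its class) and one "coarse" label (its superclass).
* STL10: 10 classes (airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck); 96x96-pixel color images; 500 training images (10 predefined folds) and 800 test images per class
* MNIST: a handwritten digit dataset
* EMNIST: an extended handwriting dataset
* FashionMNIST: an image dataset intended as a replacement for the MNIST handwritten digit set [1], released by the research division of Zalando, a German fashion technology company. It contains front-view images of 70,000 different products from 10 categories.
* SVHN
* PhotoTour
* FakeData
* SEMEION: the Semeion Handwritten Digit Data Set
* Omniglot: based on Omniglot, an online encyclopedia of languages and scripts covering the world's known writing systems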