1 数据库基类
用来实现数据的大小和索引。
pytorch的Dataset类是一个抽象类,只先实现了三个魔法方法
class Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])
如描述所说,这是一个抽象类,其他数据库类应该是它的子类,所有子类应该重载如下两个函数
* __len__函数,用来提供数据库的大小 * __getitem__函数,支持一个整形索引,重来获取单个数据,范围是__len__定义的,范围是[0, len(self)]
2 数据库的合并
其中Dataset.add函数返回一个ConcatDataset类,这个类实现了数据库的合并,针对从基类DataSet派生类,ConcatDataset实现了不同源的数据库整合,数据存储在链表datasets中,通过累计长度,可以查询不同的datasets,这个类的详细描述如下:
class ConcatDataset(Dataset): """ Dataset to concatenate multiple datasets. Purpose: useful to assemble different existing datasets, possibly large-scale datasets as the concatenation operation is done in an on-the-fly manner. Arguments: datasets (sequence): List of datasets to be concatenated """ @staticmethod def cumsum(sequence): r, s = [], 0 for e in sequence: l = len(e) r.append(l + s) s += l return r def __init__(self, datasets): super(ConcatDataset, self).__init__() assert len(datasets) > 0, 'datasets should not be an empty iterable' self.datasets = list(datasets) self.cumulative_sizes = self.cumsum(self.datasets) def __len__(self): return self.cumulative_sizes[-1] def __getitem__(self, idx): dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return self.datasets[dataset_idx][sample_idx] @property def cummulative_sizes(self): warnings.warn("cummulative_sizes attribute is renamed to " "cumulative_sizes", DeprecationWarning, stacklevel=2) return self.cumulative_sizes
注意的是给定索引的时候,需要先判定是哪个数据集,然后判定数据集的索引,getitem函数使用了bitsect.bitsec_right查找数据库索引,然后计算该数据库的内部索引。
3 子数据库Subset
ConcatDataset将不同数据集组成链表,在这个大数据集的基础上,通过索引可以建立一个虚拟数据集,实现不同数据集的一个子集,如果通过随机函数实现索引,可以混合所有数据集,Subset数据集的源码如下:
import torch from torch.utils.data import Dataset, ConcatDataset, Subset, random_splitclass MyDataset(Dataset): def __init__(self, t=0, name="myDataset"): super(MyDataset, self).__init__() self.nums = [] if t == 0: self.nums = [torch.randn(1).item() for _ in range(100)] elif t == 1: self.nums = list(range(230)) elif t == 2: self.nums = torch.linspace(-1, 1, 250).data.numpy() self.name = name self.t = t def __getitem__(self, i): return self.nums[i] def __len__(self): return len(self.nums)if __name__ == "__main__": ds0 = MyDataset(0, "type_0") ds1 = MyDataset(1, "type_1") ds2 = MyDataset(2, "type_2") ds = ds0 + ds1 ds = ds + ds2 print(ds.datasets[0].datasets[0].name,ds.datasets[0].datasets[1].name,ds.datasets[1].name) print(len(ds)) dss = random_split(ds, [310, 270]) # 第二个参数是长度,累积和是数据集长度
此处要注意的是 ds0和ds1首先进行合并,形成一个ConcatDataset,然后和ds2合并,再形成一个ConcatDataset,因此ds的datasets长度为2,第一个数据是ConcatDataset,第二个数据是MyDataset(2, "type_2")
4 Tensor向量化数据库
内存数据需要转为Tensor才能使用,pytorch提供了TensorDataset类可以直接对Tensor数据进行数据库封装
class TensorDataset(Dataset): """Dataset wrapping tensors. Each sample will be retrieved by indexing tensors along the first dimension. Arguments: *tensors (Tensor): tensors that have the same size of the first dimension. """ def __init__(self, *tensors): assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) self.tensors = tensors def __getitem__(self, index): return tuple(tensor[index] for tensor in self.tensors) def __len__(self): return self.tensors[0].size(0)
最后介绍一个对数据集进行子集切分的函数
def random_split(dataset, lengths): """ Randomly split a dataset into non-overlapping new datasets of given lengths. Arguments: dataset (Dataset): Dataset to be split lengths (sequence): lengths of splits to be produced """ if sum(lengths) != len(dataset): raise ValueError("Sum of input lengths does not equal the length of the input dataset!") indices = randperm(sum(lengths)) return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
数据集源码解读完毕了,虽然这是一个基类,但是提供了一个可迭代的思想,类似于道教的一分为二,二生四,......,提供了数据索引,合并,tensor,子集的等基本功能。
torchvision.dataset可以使用的数据集
LSUN, 大规模场景理解 LSUNClass ImageFolder, 图片目录的数据集 DatasetFolder 文件目录的数据集 CocoCaptions, 微软 MS COCO 相关的 Image Captioning CocoDetection MS COCO数据集目标检测CIFAR10, 该数据集共有60000张彩色图像分类数据集CIFAR100 数据集包含100小类,每小类包含600个图像,其中有500个训练图像和100个测试图像。100类被分组为20个大类。每个图像带有1个小类的“fine”标签和1个大类“coarse”标签。 STL10 * 10个类:飞机,鸟,汽车,猫,鹿,狗,马,猴子,船,卡车。* 图像为96x96像素,颜色。* 500个训练图像(10个预定义的折叠),每个类800个测试图像。 MNIST, MNIST数据集是一个手写体数据集 EMNIST, 扩展手写体数据集 FashionMNIST FashionMNIST 是一个替代 MNIST 手写数字集[1] 的图像数据集。 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。 SVHN PhotoTour FakeData SEMEION 图像处理_Semeion Handwritten Digit Data Set(Semeion手写体数字数据集) Omniglot Omniglot是一个在线的语言文字百科,其内涵盖了已知的全部书写系统
作者:readilen
链接:https://www.jianshu.com/p/5b65c43d45c0