DataLoader的使用
1 | import torch |
1 | # GPU version |
Dataset和DataLoader总结:
- Dataset是一个抽象类,需要派生一个子类构造数据集,需要改写的方法有
__init__, __getitem__, __len__
等等。 - DataLoader是一个迭代器,方便我们访问Dataset里的对象,值得注意的
num_workers
的参数设置:如果放在CPU上跑,可以不管,但是放在GPU上则需要设置为0;或者在DataLoader操作之后将Tensor放在GPU上 - 数据和标签是tuple元组的形式,使用DataLoader然后使用enumerate函数访问他们
顺便带一下CIFAR-10数据集的说明
该数据集共有60000张彩色图像,这些图像时32 * 32,分为10个类,每类6000张图片。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的随机排列就组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看,每一类都有5000张图
下面这幅图就是列举了10各类,每一类展示了随机的10张图片: