0%

Dataloader的使用

DataLoader的使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch 
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as Dataloader
import numpy as np
"""
# CPU version

# 注意,这里如果只写了Dataset而不是Dataset.Dataset,则会报错,因为Dataset是module模块,而不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset
class subDataset(Dataset.Dataset):
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label

def __len__(self):
return len(self.Data)

def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.Tensor(self.Label[index])
return data, label

Data = np.asarray([[1, 2], [3, 4], [5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])

if __name__ == '__main__':
dataset = subDataset(Data, Label)
print(dataset)
print('dataset.size = ', dataset.__len__())
print(dataset.__getitem__(0))
print(dataset[0])

dataloader = Dataloader.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
for index, item in enumerate(dataloader):
print('i', index)
data, label = item
print('data: {}, label: {}'.format(data, label))
"""
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# GPU version  

class subDataset(Dataset.Dataset):
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label

def __len__(self):
return len(self.Data)

def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.Tensor(self.Label[index])
return data, label

Data = np.asarray([[1, 2], [3, 4], [5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])

if __name__ == '__main__':
dataset = subDataset(Data, Label)
print(dataset)
print('dataset.size = ', dataset.__len__())
print(dataset.__getitem__(0))
print(dataset[0])

dataloader = Dataloader.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
# 这里需要质疑把num_workers改为0
for index, item in enumerate(dataloader):
print('i', index)
data, label = item
if torch.cuda.is_available():
data = data.cuda()
label = label.cuda()
print('data: {}, label: {}'.format(data, label))

Dataset和DataLoader总结:

  • Dataset是一个抽象类,需要派生一个子类构造数据集,需要改写的方法有__init__, __getitem__, __len__等等。
  • DataLoader是一个迭代器,方便我们访问Dataset里的对象,值得注意的num_workers的参数设置:如果放在CPU上跑,可以不管,但是放在GPU上则需要设置为0;或者在DataLoader操作之后将Tensor放在GPU上
  • 数据和标签是tuple元组的形式,使用DataLoader然后使用enumerate函数访问他们

顺便带一下CIFAR-10数据集的说明

该数据集共有60000张彩色图像,这些图像时32 * 32,分为10个类,每类6000张图片。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的随机排列就组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看,每一类都有5000张图

下面这幅图就是列举了10各类,每一类展示了随机的10张图片:

img