
Batch Training

Key points:

Torch provides a handy tool for organizing your data called DataLoader. You can use it to wrap your own dataset for mini-batch training, and it supports many different batching strategies.

DataLoader

DataLoader is the tool torch gives you to wrap your data: first convert your data (a numpy array or some other format) into a Tensor, then put it into this wrapper. DataLoader then lets you iterate over the data efficiently.
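
For instance, if your raw data starts out as numpy arrays, the conversion might look like this minimal sketch (the array names here are made up for illustration):

import numpy as np
import torch
import torch.utils.data as Data

# hypothetical raw data as numpy arrays
features = np.arange(10, dtype=np.float32)
targets = features[::-1].copy()    # .copy() because from_numpy needs a contiguous array

# convert the numpy arrays to torch Tensors, then wrap them in a Dataset
dataset = Data.TensorDataset(torch.from_numpy(features), torch.from_numpy(targets))
loader = Data.DataLoader(dataset, batch_size=5, shuffle=True)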

Example:

import torch
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible

# number of samples per training batch
BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)    # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)    # this is y data (torch tensor)

# first convert into a Dataset that torch can recognize
torch_dataset = Data.TensorDataset(x, y)

# put the dataset into a DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,       # torch TensorDataset format
    batch_size=BATCH_SIZE,       # mini batch size
    shuffle=True,                # whether to shuffle the data (usually a good idea)
    num_workers=2,               # subprocesses for loading the data
)

def show_batch():
    # train over the entire dataset 3 times
    for epoch in range(3):
        # at each step the loader yields one mini-batch for training
        for step, (batch_x, batch_y) in enumerate(loader):
            # training your data...
            print('Epoch: ', epoch, '| Step: ', step,
                  '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
    show_batch()

The resulting output:

Epoch: 0 | Step: 0 | batch x: [ 5. 7. 10. 3. 4.] | batch y: [6. 4. 1. 8. 7.]

Epoch: 0 | Step: 1 | batch x: [2. 1. 8. 9. 6.] | batch y: [ 9. 10. 3. 2. 5.]

Epoch: 1 | Step: 0 | batch x: [ 4. 6. 7. 10. 8.] | batch y: [7. 5. 4. 1. 3.]

Epoch: 1 | Step: 1 | batch x: [5. 3. 2. 1. 9.] | batch y: [ 6. 8. 9. 10. 2.]

Epoch: 2 | Step: 0 | batch x: [ 4. 2. 5. 6. 10.] | batch y: [7. 9. 6. 5. 1.]

Epoch: 2 | Step: 1 | batch x: [3. 9. 1. 8. 7.] | batch y: [ 8. 2. 10. 3. 4.]

If shuffle = False instead, the output becomes:

Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5.] | batch y: [10. 9. 8. 7. 6.]

Epoch: 0 | Step: 1 | batch x: [ 6. 7. 8. 9. 10.] | batch y: [5. 4. 3. 2. 1.]

Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5.] | batch y: [10. 9. 8. 7. 6.]

Epoch: 1 | Step: 1 | batch x: [ 6. 7. 8. 9. 10.] | batch y: [5. 4. 3. 2. 1.]

Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5.] | batch y: [10. 9. 8. 7. 6.]

Epoch: 2 | Step: 1 | batch x: [ 6. 7. 8. 9. 10.] | batch y: [5. 4. 3. 2. 1.]

As you can see, each step yields 5 samples for training, and with shuffle=True each epoch's data is reshuffled before being served, which is why the batches differ from epoch to epoch; with shuffle=False the order stays fixed.
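
Conceptually, the shuffled mini-batching above does something like the following manual loop (a simplified sketch for intuition only, not how DataLoader is implemented internally):

import torch

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
BATCH_SIZE = 5

for epoch in range(3):
    # draw a fresh random permutation each epoch, like shuffle=True
    indices = torch.randperm(len(x))
    for step in range(0, len(x), BATCH_SIZE):
        batch_idx = indices[step:step + BATCH_SIZE]
        batch_x, batch_y = x[batch_idx], y[batch_idx]    # train on this mini-batch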

If you change BATCH_SIZE = 8, then step 0 will yield 8 samples; but at step 1 there are fewer than 8 samples left in the dataset, so step 1 can only return whatever remains in that epoch (see the drop_last sketch after this output):

Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]

Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]

Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]

Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]

Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]

Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
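
If an incomplete final batch is undesirable (for example, for layers that assume a fixed batch size), DataLoader accepts a drop_last flag that simply discards it. A minimal sketch, reusing the torch_dataset from above:

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,    # discard the incomplete final batch instead of returning it
)

With drop_last=True, each epoch here would produce only a single step of 8 samples, and the remaining 2 samples would be skipped for that epoch.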