要点:
Torch中提供了一种帮你整理你的数据结构的好东西,叫做DataLoader,我们能用它来包装自己的数据,进行批训练,而且批训练可以有很多种途径
DataLoader
DataLoader是torch给你用来包装你的数据的工具,所以你要将自己的(numpy array或其他)数据形式装换成Tensor,然后再放进这个包装器中,使用DataLoader可以帮你有效地迭代数据
举例:
1 | import torch |
得到的结果
Epoch: 0 | Step: 0 | batch x: [ 5. 7. 10. 3. 4.] | batch y: [6. 4. 1. 8. 7.]
Epoch: 0 | Step: 1 | batch x: [2. 1. 8. 9. 6.] | batch y: [ 9. 10. 3. 2. 5.]
Epoch: 1 | Step: 0 | batch x: [ 4. 6. 7. 10. 8.] | batch y: [7. 5. 4. 1. 3.]
Epoch: 1 | Step: 1 | batch x: [5. 3. 2. 1. 9.] | batch y: [ 6. 8. 9. 10. 2.]
Epoch: 2 | Step: 0 | batch x: [ 4. 2. 5. 6. 10.] | batch y: [7. 9. 6. 5. 1.]
Epoch: 2 | Step: 1 | batch x: [3. 9. 1. 8. 7.] | batch y: [ 8. 2. 10. 3. 4.]
如果 shuffle = True,则导出的数据是
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5.] | batch y: [10. 9. 8. 7. 6.]
Epoch: 0 | Step: 1 | batch x: [ 6. 7. 8. 9. 10.] | batch y: [5. 4. 3. 2. 1.]
Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5.] | batch y: [10. 9. 8. 7. 6.]
Epoch: 1 | Step: 1 | batch x: [ 6. 7. 8. 9. 10.] | batch y: [5. 4. 3. 2. 1.]
Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5.] | batch y: [10. 9. 8. 7. 6.]
Epoch: 2 | Step: 1 | batch x: [ 6. 7. 8. 9. 10.] | batch y: [5. 4. 3. 2. 1.]
可以看出来,每步都导出了5个数据进行学习,然后每个epoch的导出数据都是先打乱了以后再导出
如果改变 BATCH_SIZE = 8 ,这样 step = 0,就回导出8个数据 ,但是,step = 1 时数据库中的数据不够8个,则step1,只能给你返回epoch中剩下的数据了
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]