PyTorch中的dim
dim概念
dim的不同值表示不同维度。特别的在dim=0表示二维中的行,dim=1在二维矩阵中表示行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:,那么dim=0就表示对应到也就是第一个维度,dim=1,表示对应到也就是第二个维度,以此类推
dim在函数中的作用
例一. torch.argmax()
函数中dim表示该维度会消失。
这个消失是什么意思?官方英文解释是:dim (int) – the dimension to reduce.
我们知道argmax就是得到最大值的序号索引,对于一个维度为的矩阵来说,我们想要求每一行中最大数的在该行中的列号,最后我们得到的就是一个维度为的一矩阵。这时候,列就要消失了。
因此,我们想要求每一行最大的列标号,我们就要指定dim=1,表示我们不要列了,保留行的size就可以了。
假如我们想求每一列的最大行标,就可以指定dim=0,表示我们不要行了。
1 | import torch |
1 | torch.argmax(input, dim=None, keepdim=False) |
返回指定维度最大的序号
dim给定的定义是:the dimention to reduce.也就是吧dim这个维度的,变成这个维度的最大值
如果上面的代码改成:
1 | b = torch.argmax(a, dim=1, keepdim=True) |