0%

PyTorch 中的 dim

PyTorch中的dim

dim概念

dim的不同值表示不同维度。特别的在dim=0表示二维中的行,dim=1在二维矩阵中表示行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:d0,d1,...,dn1{d_{0},d_{1},...,d_{n-1}},那么dim=0就表示对应到d0d_{0}也就是第一个维度,dim=1,表示对应到d1d_{1}也就是第二个维度,以此类推

dim在函数中的作用

例一. torch.argmax()

函数中dim表示该维度会消失。

这个消失是什么意思?官方英文解释是:dim (int) – the dimension to reduce.

我们知道argmax就是得到最大值的序号索引,对于一个维度为(d0,d1)(d_0,d_1)的矩阵来说,我们想要求每一行中最大数的在该行中的列号,最后我们得到的就是一个维度为(d0,1)(d_0,1)的一矩阵。这时候,列就要消失了。

因此,我们想要求每一行最大的列标号,我们就要指定dim=1,表示我们不要列了,保留行的size就可以了。
假如我们想求每一列的最大行标,就可以指定dim=0,表示我们不要行了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import numpy as np
a = torch.rand(3, 4)
print(a.size())
print(a)

b = torch.argmax(a, dim=1)
print(b)
print(b.size())
"""
output:
torch.Size([3, 4])
tensor([[0.9120, 0.4805, 0.6701, 0.5446],
[0.6273, 0.1295, 0.3416, 0.2213],
[0.6068, 0.8448, 0.8452, 0.4931]])
tensor([0, 0, 2])
torch.Size([3])
"""
# 可以看到,指定dim=1时,列的size没有了
1
torch.argmax(input, dim=None, keepdim=False)

返回指定维度最大的序号

dim给定的定义是:the dimention to reduce.也就是吧dim这个维度的,变成这个维度的最大值

如果上面的代码改成:

1
2
3
4
5
6
7
8
b = torch.argmax(a, dim=1, keepdim=True)
"""
output:
tensor([[0],
[0],
[2]])
torch.Size([3, 1])
"""