在深度学习的前向传播中,最重要的是理解每个计算步骤的输入前与输入后的维度形状。与之对应的时需要熟悉一些常见的维度操作方法,根据项目的学习总结记录如下:

1. torch方法

1
2
3
4
import torch
data = torch.randn(32, 100, 8, 128) # c(batch_size, seq_len, head_num, head_embed)
data.shape
# torch.Size([32, 100, 8, 128])
  • transpose 交换指定的两个维度
1
2
data.transpose(2, 0).shape
# torch.Size([8, 100, 32, 128])
  • permute 指定所有维度的新顺序。
1
2
data.permute(0, 2, 1, 3).shape
# torch.Size([32, 8, 100, 128])
  • reshape/view 变更张量的维度
1
2
3
4
5
6
data.reshape(32 * 8, 100, 128).shape
# torch.Size([256, 100, 128])

# -1表示根据已设置维度,计算该维度的大小
data.reshape(32 * 8, -1).shape
# torch.Size([256, 12800])

.view.reshape操作基本一致,前者要求输入张量是内存连续的。

  • squeeze/unsqueeze 管理大小为1的维度
1
2
3
4
5
x = torch.randn(1, 3, 1, 4)
y = x.squeeze()  # 移除所有大小为1的维度,结果为 (3, 4)
y1 = x.squeeze(0) # 结果为 (3, 1, 4)
z = x.unsqueeze(2)  # 在第2维添加一个维度,结果为 (1, 3, 1, 1, 4)
z1 = x.unsqueeze(-1)  # 结果为 (1, 3, 1, 1, 4, 1)
  • repeat : 沿指定维度重复张量
1
2
3
4
x = torch.tensor([[1, 2, 3]]) # (1, 3)

y = x.repeat(2, 3)  # 结果为 (2, 9)
y1 = x.repeat(1, 3) # 结果为 (1, 9)
  • cat/stack: 合并多个张量。cat:沿指定维度连接张量;stack:在新维度上堆叠张量。
1
2
3
4
5
6
a = torch.randn(3, 4)
b = torch.randn(3, 4)

c = torch.cat((a, b), dim=0)  # 结果为 (6, 4)
d = torch.stack((a, b), dim=0)  # 结果为 (2, 3, 4)
d1 = torch.stack((a, b), dim=2) # (3, 4, 2)

2. einops

1
2
3
4
5
6
import torch
from einops import rearrange

data = torch.randn(32, 100, 8, 128) # c(batch_size, seq_len, head_num, head_embed)
data.shape
# torch.Size([32, 100, 8, 128])
  • 维度顺序调换
1
2
rearrange(data, 'b s h d -> b s h d').shape
# torch.Size([32, 100, 8, 128])

可以使用任何符合表示任何维度,但使用常见符号可以让代码更易于理解。例如在LLM模型中,b表示batch,s表示sequence,h表示head number,d表示head dimension

1
2
rearrange(data, 'b s h d -> b h s d').shape
# torch.Size([32, 8, 100, 128])
  • 维度合并与拆分
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
rearrange(data, 'b s h d -> (b s) h d').shape
# torch.Size([3200, 8, 128])

# 维度合并的不同顺序会对产生不同的结果(维度拆分同理)
rearrange(data, 'b s h d -> (s b) h d').shape

rearrange(data, 'b s h d -> b (s h d)').shape
# torch.Size([32, 102400])

# 维度拆分
rearrange(data, 'b (s s1) h d -> b s s1 h d', s1=5)
# torch.Size([32, 20, 5, 8, 128])