1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
d_model = 512 # 输入特征维度
nhead = 8 # 注意力头数
# d_model: 输入和输出的特征维度
# nhead: 注意力头数
# dim_feedforward: FFN的神经元个数(MLP)
# dropout: 丢失率,默认0.1
# batch_first: 默认为False,即输入为(seq, batch, feature); 设置为True,则输入为(batch, seq, feature)
# norm_first: 默认为False,即在Attention和FFN之前进行norm
# activation: 前馈网络中使用的激活函数,默认是 relu,可以选择其他激活函数如 gelu。
# 创建 Transformer 编码器层
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model,
nhead=nhead,
dim_feedforward=d_model*2,
dropout=0.1,
batch_first=True)
input = torch.randn(2, 10, 512) # (批量大小, 序列长度, 特征维度)
encoder_layer(input).shape
# torch.Size([2, 10, 512])
|