1. 注意力提示

1.1 生物学的注意力提示

如下的观察实验:

  • 受试者的注意力往往首先被颜色鲜艳的红色咖啡杯吸引(非自主性);
    • 客观存在的,对于观察者的吸引特征。
  • 喝完咖啡,处于兴奋状态的大脑经思考后,相比看报等,可能更想要读一本书(自主性权重更高);
    • 在受试者的主观意愿推动下所做的决定。

image-20240812204130442

1.2 查询、键和值

  • 上述的非自主性提示,可以类比之前的全连接层、卷积层等。
    • 红色的咖啡杯可以理解为高权重值的神经元,对输出有较大的影响。
  • 而注意力(Attention)机制可通过注意力汇聚,将查询(Query)与所有的键值(Key-Value)对进行关联,得到输出。
    • 查询:分别与所有键Key计算’相似度’,表示权重值,得到注意力的抽象表示;
    • 键-值对:基于上述权重,对值value进行加权平均求和,得到输出。(二者可以是同一数据)
    • 一个Query得到一个输出

image-20240812205100062

1.3 注意力的可视化

  • 对注意力的权重进行热图可视化
    • 每行表示一次Query与所有的Key计算的权重结果。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from d2l import torch as d2l

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);
    
#四维矩阵的输入,前两个维度表示子图的行数与列数
attention_weights = torch.eye(10).reshape((1, 1, 10, 10)) 
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

../_images/output_attention-cues_054b1a_36_0.svg

2. 注意力汇聚

  • 1964年提出的Nadaraya-Watson核回归,本质上可以理解为带有加权平均的注意力机制

2.1 生成数据集

  • x_train :键Key
  • y_train:值Value
  • x_test:查询Query
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本(0~5分布范围)
def f(x):
    return 2 * torch.sin(x) + x**0.8
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出

x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test
# 50

#定义一个绘图函数
def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

2.2 平均汇聚

  • 最简单的做法是对于任意Query(x),都直接计算所有训练样本输出值(yi)的均值

image-20240812213916763

1
2
3
4
5
y_hat = torch.repeat_interleave(y_train.mean(), n_test) #两个参数,值以及重复的次数
y_hat.shape, y_hat[:5]
# (torch.Size([50]), tensor([2.4023, 2.4023, 2.4023, 2.4023, 2.4023]))

plot_kernel_reg(y_hat)

../_images/output_nadaraya-waston_736177_39_0.svg

2.3 非参数注意力汇聚

  • 根据query(x)与key(xi)的关系度量α,计算当key为xi时,值yi的权重。
  • 一种关系度量方式是将x与xi间的距离进行高斯核函数转换。距离越近,则值越大。
  • 最后的Softmax操作将权重和变为1,得到最终的加权平均方式。

image-20240812214652217

  • 下述的计算为非参数的注意力汇聚,即没有可学习的模型参数。
1
2
3
4
5
6
7
8
9
# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询),以分别计算与所有key的距离(每行)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# attention_weights的形状同样是:(n_test,n_train)
# 每一行表示每个查询与所有键key之间的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# 计算最终值value的加权平均值
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat) #下图左
image-20240812220059975
1
2
3
4
5
6
# 可视化注意力权重
attention_weights.shape
# torch.Size([50, 50])
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs') # 上图右

2.4 带参数注意力汇聚

  • 如下,是带可学习参数的Nadaraya-Watson核回归实现
    • w参数用于控制高斯核的宽度,可以理解为方差。
    • 方差越大,表示越关注少数几个与query高度接近的xi,赋予较高的权重。

image-20240812220957816

(1)批量矩阵乘法

  • nn.bmm: 第一个小批量的第i个矩阵与第二个小批量的第i个矩阵相乘。
1
2
3
4
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
# torch.Size([2, 1, 6])
  • 据此,可在注意力机制背景下,计算小批量数据的加权平均值
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
weights.unsqueeze(1).shape #一次query,10个键值对
# torch.Size([2, 1, 10])
values.unsqueeze(-1).shape # 10个键值对
# torch.Size([2, 10, 1])

torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
# tensor([[[ 4.5000]],
#        [[14.5000]]])

(2)定义模型

  • w可学习参数
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        # 每个query重复keys的数量次
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1) #按行做softmax,计算权重
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

(3)训练

  • 计算keys与values
    • 因为要使用x_train作为query,所以在keys与values中的每一行中,去除自己本身的观测键值对
1
2
3
4
5
6
7
8
# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
  • 训练
    • 训练模型使用x_train作为query,而不是x_test;
    • x_train中的第i个query与keys中的第i行进行α计算;
    • 得到对应的权重结果后,再对values的第i行进行加权平均、
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))
  • 预测
1
2
3
4
5
6
7
8
9
# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
y_hat.shape
# torch.Size([50, 1])

plot_kernel_reg(y_hat)  #下图左
image-20240813083120379
1
2
3
4
5
6
# 权重可视化
net.attention_weights.shape
# torch.Size([50, 50])
d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

如上右图,可以看到w参数将注意力机制更关注少数与query权重更高的key

3. 注意力评分函数

  • 上面的α(高斯运算)可以视为注意力评分函数,计算query与每个key的’关系’。然后再经softmax操作,映射为注意力权重。
  • 引申来看,如下公式中:
    • q表示查询query,可以是一个向量;
    • (k, v)表示键值对(key-value),二者可以是长度不同的向量,也可以是同一数据;
    • 评分函数 α(q, ki)将query与每个key映射成标量,再进行softmax计算。
image-20240813090639100
1
2
3
4
import math
import torch
from torch import nn
from d2l import torch as d2l

3.1 掩蔽softmax操作

  • 在上一章的seq2seq学习中,为了保证子序列长度相同,对于原来较短的序列进行了填充。
  • 同样,这里需要将这些填充词元的注意分数设置为很小的值,从而在softmax操作时计算得到权重值为0。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
#@save
def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
  • 示例操作
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# valid_lens.dim() == 1
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
# tensor([[[0.5599, 0.4401, 0.0000, 0.0000],
#          [0.4361, 0.5639, 0.0000, 0.0000]],

#         [[0.2928, 0.4262, 0.2810, 0.0000],
#          [0.3205, 0.3216, 0.3579, 0.0000]]])

# valid_lens.dim() == 2
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
# tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
#          [0.3533, 0.3007, 0.3460, 0.0000]],

#         [[0.5339, 0.4661, 0.0000, 0.0000],
#          [0.2291, 0.1934, 0.3177, 0.2598]]])

3.2 加性注意力

  • Activation attention:将查询向量与键的向量相加后,输入到一个多层感知机中
    • 单隐藏层,tanh激活函数,禁用偏置项,输出层的神经元个数为1

image-20240813115720313

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#@save
class AdditiveAttention(nn.Module):
    """加性注意力"""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 在维度扩展后,
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使用广播方式进行求和:对于 queries 的每一个查询(query),将其与所有键(key)相加。
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        # (batch_size, num_queries, num_keys, num_hiddens)
        features = torch.tanh(features)
        # self.w_v仅有一个输出,因此从形状中移除最后那个维度。
        scores = self.w_v(features).squeeze(-1) 
        # scores的形状:(batch_size, num_queries, num_keys)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        # 最后返回结果:(batch_size,num_queries,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)
    
  • 示例操作:批量大小为2,每个批量
    • 1个query,其向量长度为20
    • 10对key-value,key向量长度为2,value向量长度为4
1
2
3
4
5
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])
  • 注意力汇聚输出的形状为(批量大小,查询数,值value的维度)
1
2
3
4
5
6
7
8
9
# 实例化一个attention
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
# tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],
#         [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
attention(queries, keys, values, valid_lens).shape
# torch.Size([2, 1, 4])
  • 每个query对于所有key的注意力权重可视化
1
2
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

../_images/output_attention-scoring-functions_2a8fdc_96_0.svg

3.3 缩放点积注意力

  • Scaled dot-product attention:当query向量与key向量长度一致时,可进行点积操作;再根据向量长度进行缩放,作为注意力分数。
    • 相比于加性注意力,模型参数较少(只有dropout)

image-20240813121308655

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
#@save
class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
  
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 交换keys的最后两个维度
        # queries的形状:(batch_size,查询的个数,d)
        # keys的形状:(batch_size,“键-值”对的个数,d)        
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)
  • 示例操作
1
2
3
4
5
6
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)
# tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],
#         [[10.0000, 11.0000, 12.0000, 13.0000]]])

4. Bahdanau注意力

  • 在上一章学习seq2seq时,将编码器RNN中最后一个时间步的隐状态作为上下文变量传递给了解码器。
    • 此时,假设该隐状态能够学习到编码器序列的全部信息,但对于较长的序列,实际情况可能并非如此。
  • 对此,Bahdanau等提出了注意力机制的seq2seq模型,具体实现方式如下:

4.1 模型

  • 特定解码器词元的context上下文变量来自于编码器序列所有词元隐状态的加权平均
    • Query:编码器RNN序列中,前一个时间步的最后一层隐状态输出
    • Key/Value:解码器RNN序列中,每个时间步的最后一层隐状态输出(既作为Key,也作为Value)。
  • 其它操作与之前学习基本一致。
image-20240813134102502
1
2
3
import torch
from torch import nn
from d2l import torch as d2l

4.2 定义注意力解码器

编码器不用重新定义,直接使用之前的就行

  • 首先定义一个基本接口
1
2
3
4
5
6
7
8
9
#@save
class AttentionDecoder(d2l.Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError
  • 然后是具体的代码实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout) 
        # 编码器与解码器的隐状态的神经元个数一致
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size,num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # enc_outputs的形状为(batch_size,num_steps,num_hiddens).
        # hidden_state的形状为(num_layers,batch_size, num_hiddens)
        # 解码器序列的第一个Query隐状态来自于编码器最后一个时间步
        enc_outputs, hidden_state, enc_valid_lens = state
        # 输出X的形状为(num_steps,batch_size,embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # query的形状为(batch_size,1,num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # 注意力机制的context的形状为(batch_size,1,num_hiddens)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # 在特征维度上连结
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # 将x变形为(1,batch_size,embed_size+num_hiddens)
            # 更新Query隐状态
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        outputs = self.dense(torch.cat(outputs, dim=0))
        # 全连接层变换后,outputs的形状为(num_steps,batch_size,vocab_size)
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights
  • 示例
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                  num_layers=2)
decoder.eval()

X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape # (batch_size,num_steps, vocab_size)
# torch.Size([4, 7, 10])
len(state)
# 3, 分别是[enc_outputs, hidden_state, enc_valid_lens]
state[0].shape #编码器所有时间步的最终层隐状态
# torch.Size([4, 7, 16])
state[1].shape #解码器最后一个时间步的所有层隐状态
# torch.Size([2, 4, 16])

4.3 训练

  • 带有注意力机制的解码器会增加训练的时间
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)

net = d2l.EncoderDecoder(encoder, decoder)

d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
# loss 0.021, 4352.7 tokens/sec on cuda:0
  • 预测
    • predict_seq2seq函数返回翻译后的outputs,以及相应的注意力权重结果
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
# go . => va !,  bleu 1.000
# i lost . => j'ai perdu .,  bleu 1.000
# he's calm . => il est paresseux .,  bleu 0.658
# i'm home . => je suis chez moi .,  bleu 1.000
  • 查看注意力权重
1
2
3
4
5
6
7
8
9
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((
    1, 1, -1, num_steps))
attention_weights.shape
# torch.Size([1, 1, 6, 10])

# 加上一个包含序列结束词元
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')

../_images/output_bahdanau-attention_7f08d9_87_0.svg

5. 多头注意力

5.1 模型

  • 类似于CNN的多输出通道,多头注意力旨在通过多个独立的注意力汇聚学习到不同角度的信息
  • 如下图所示:
    • 首先将query,key,value向量进行线性投影(全连接层转换);
    • 然后并行地分送到多个不同的注意力汇聚中;
    • 最后将多头注意力输出结果拼接在一起,再经最后一个线性投影转换。

image-20240813202604493

5.2 实现

  • 通常选择缩放点积注意力作为每个注意力头;
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状: (batch_size,查询或者“键-值”对的个数,num_hiddens)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        # 经transpose_qkv变换后,输出的queries,keys,values的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
        
        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads) 除
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
  • 如上操作,虽然是多头计算,但可以通过数据处理技巧节省运算。
  • 简单来说,在计算式,通过合并num_head维度到batch_size维度,一次性计算多头的注意力结果。
  • 最后再将结果在num_hiddens维度cat拼接到一起。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)

    X = X.permute(0, 2, 1, 3)
	# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    # 三维变四维
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    # 四维再变三维:(batch_size,查询的个数,num_hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)
  • 示例演示
    • 5个注意力头
    • qkv隐藏层神经元长度都设置为一样,为100
    • 批量大小为2,每个批量4次query,6个键值对
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens)) #Query
Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) #Key/Value是相同的
X.shape, Y.shape
# (torch.Size([2, 4, 100]), torch.Size([2, 6, 100]))

attention(X, Y, Y, valid_lens).shape
# torch.Size([2, 4, 100])

6. 自注意力和位置编码

6.1 自注意力

  • 可采用自注意力机制对序列词元进行编码。此时Query,以及Key-Value都来自同一组输入
    • 即每个词元查询都会关注所有的键-值对,并生成一个注意力输出
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])

# (批量大小,词元序列长度,d)
X = torch.ones((batch_size, num_queries, num_hiddens))
# 如下X分别充当queries, keys, values
attention(X, X, X, valid_lens).shape
# torch.Size([2, 4, 100])
# 输入与输出的形状相同

6.2 比较CNN, RNN, Self-attention

卷积神经网路、循环神经网络,以及自注意力架构都可以将n个词元组成的序列映射到另一个长度相同的序列表示。

../_images/cnn-rnn-self-attention.svg
  • CNN (假设卷积核大小为k,输入与输出通道为d)

    • 计算复杂度:O(knd*d)
    • 顺序操作:O(1)
    • 最大路径长度:O(n/k)
  • RNN(d×d权重矩阵,d维隐状态)

    • 计算复杂度:O(d*d)

    • 顺序操作:O(n)

    • 最大路径长度:O(n)

  • Self-attention

    • 计算复杂度:O(n*nd)

    • 顺序操作:O(1)

    • 最大路径长度:O(1)

综上:卷积神经网络和自注意力都拥有并行计算的优势。而自注意力的最大路径长度最短,其计算复杂度在很长的序列中计算会比较慢。

TIPS: 顺序操作会妨碍并行计算。而任意的序列位置组合之间的路径越短,则能更轻松地学习序列中的远距离依赖关系。

6.3 位置编码

  • 在6.1的计算过程中,忽略了序列所包含的位置信息。
  • 对于n×d的序列词元输入信息,可进行位置编码生成相同形状的表示,再进行矩阵加法,共同作为输入。
    • n表示序列中词元的个数,d表示features数
  • 如下公式,一种常见方式是基于正弦函数和余弦函数的固定位置编码。
    • 序列中第i个词元(行)的第偶数个维数(列)使用sin函数
    • 序列中第i个词元(行)的第奇数个维数(列)使用cos函数

image-20240813213636081

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
  • 如下可视化,可以看出
    • 第2j列与第2j+1列的周期/频率是一样的
    • j越大,sin/cos函数周期越大,或者说频率越低
1
2
3
4
5
6
7
8
9
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()

X = pos_encoding(torch.zeros((1, num_steps, encoding_dim))) # X+P
P = pos_encoding.P[:, :X.shape[1], :] # P

d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

../_images/output_self-attention-and-positional-encoding_d76d5a_52_0.svg

7. Transformer

  • Transformer模型完全基于注意力机制,没有任何卷积层或循环神经网络层 (Attention is all you need);
  • 它最初是应用于在文本数据上的序列到序列学习,但现在已经推广到各种现代的深度学习中,例如语言、视觉、语音和强化学习领域。

7.1 模型

  • Transformer是一个经典的编码器与解码器架构。

架构角度

  • 编码器:n个编码层组成,每个层由2个子层串联组成。每个子层后面都采用了残差连接,再应用层规范化(layer normalization)
    • 第一个子层:多头自注意力汇聚;
    • 第二个子层:基于位置的前馈网络;
  • 解码器:n个解码层组成,每个层由3个子层串联组成。每个子层后面同样都采用了残差连接,再应用层规范化(layer normalization)
    • 第一个子层:带掩码的多头自注意力汇聚;
    • 第二个子层:编码器-解码器注意力层;
    • 第三个子层:基于位置的前馈网络

数据角度

  • 编码器
    • 输入词元序列的embedding加上位置编码,输入到第一个编码层的第一个子层;
    • 经过n层编码层学习后,输出的形状一般不变(batch_size, num_steps, num_hiddens)
  • 解码器
    • 标签词元序列的embedding,加上位置编码,输入到第一个解码层的第一个子层(掩码自注意力);
    • 在第二个子层中,将第一个子层的输出作为Query,将编码器的输出作为Key和Value,进行解码,输出到前馈网络;
    • 经过如上n个解码层学习,最后输出到一个全连接层中。
image-20240814101932173
1
2
3
4
5
import math
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l

7.2 基于位置的前馈神经网络

  • Positionwise feed-forward network, FFN
    • 对序列中的每个位置词元特征,都进行相同的映射变换
  • 本质上就是两层的MLP
    • 输入形状:(批量大小,序列长度,特征维度)
    • 输出形状:(批量大小,序列长度,特征维度2)
  • FFN类的参数
    • 输入:特征维度
    • 中间:隐藏层神经元
    • 输出:特征维度2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
#@save
class PositionWiseFFN(nn.Module):
    """基于位置的前馈网络"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

inputs = torch.ones((2, 3, 4))
ffn = PositionWiseFFN(4, 4, 8)
outputs = ffn(inputs)
outputs.shape
# torch.Size([2, 3, 8])

7.3 残差连接和层规范化

  • 层规范化,Layer Normalization
    • 对每个序列中所有词元的特征数据进行规范化
    • e.g. 对于每个序列的二维矩阵(序列长度,特征数)的整体求均值与方差

TIPS: 之前学习的BatchNorm是对一个特征在所有批量样本的规范化。

1
2
3
4
5
6
7
8
9
ln = nn.LayerNorm(normalized_shape=(5, 10)) #一个序列所有词元的特征
# ln = nn.LayerNorm(normalized_shape=(5, 10)) 单个词元所有特征的规范化

# 输入数据
X = torch.randn(3, 5, 10)
# 归一化
output = ln(X)
output.shape
# torch.Size([3, 5, 10])
  • 残差连接的定义仍是加上原始输入X,以便计算深层网络
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
#@save
class AddNorm(nn.Module):
    """残差连接后进行层规范化"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X) #先计算残差连接,再Norm
    
# 实例化一个类
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
# torch.Size([2, 3, 4])

7.4 编码器

  • 首先定义编码层,如上所述包含了两个子层
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
#@save
class EncoderBlock(nn.Module):
    """Transformer编码器块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # 第一个子层
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # 第二个子层
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
X = torch.ones((2, 100, 24)) #输入X
valid_lens = torch.tensor([3, 2])

# qkv以及hidden均为24
# [100, 24]表示对于后两个维度进行LayerNorm
# 24, 48为FNN的输入与隐藏,8个多头注意力
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()

encoder_blk(X, valid_lens).shape
# torch.Size([2, 100, 24])
  • 构造编码器类
    • 输入X加上位置编码,输入到n个编码层中得到输出
    • num_layers设置n
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#@save
class TransformerEncoder(d2l.Encoder):
    """Transformer编码器"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # 因为位置编码值在-1和1之间,因此嵌入值乘以嵌入维度的平方根进行缩放,然后再与位置编码相加。
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # 保存注意力权重
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X

示例

1
2
3
4
5
6
# n=2
encoder = TransformerEncoder(
    200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape
# torch.Size([2, 100, 24])

7.5 解码器

  • 相对编码器,解码器的构造比较复杂。(1)每个解码层包括3个子层;(2)将编码器的输出结合到解码器中;(3)训练与预测的处理方式有差异

  • 关于第一个子层,即带掩码的自注意力层:

    • 在训练时,序列中所有位置的词元理论上都是已知的。

    • 但是在真实的应用/预测场景中,只有生成的词元才能用于解码器的自注意力计算中。

    • dec_valid_lens参数可以使得查询都只会与解码器中所有已经生成词元的位置进行注意力计算。

  • 关于第二个子层,即编码器-解码器注意力层

    • 上一的解码层子层的单个词元输出将作为Query
    • 来自编码器的输出将同时作为Key和Value
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class DecoderBlock(nn.Module):
    """解码器中第i个块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        # i表示n个解码层中的第i个层
        self.i = i
        # 第一个子层
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # 第二个子层
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        # 第三个子层
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 训练阶段,输出序列的所有词元都在同一时间处理,因此state[2][self.i]初始化为None。
        if state[2][self.i] is None:
            key_values = X
        # 预测阶段,输出序列是通过词元一个接着一个解码的,因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
        else:
            # (Batch_size, Num_steps + 1, Num_hiddens) 逐步累计每个时间步的解码表示
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens的开头:(batch_size,num_steps), 其中每一行是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
        # 第一个子层:自注意力,dec_valid_lens表示掩码
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # 第二个子层:编码器-解码器注意力。
        # enc_outputs的开头:(batch_size,num_steps,num_hiddens)作为key/value
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        # 第三个子层,state主要更新了state[2]中的内容
        return self.addnorm3(Z, self.ffn(Z)), state

示例

1
2
3
4
5
6
# 8个head
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state)[0].shape
  • 构造解码器类
    • 最后一个全连接层输出序列中每个词元的vocab_size个可能输出词元的概率
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        # 序列embedding+位置编码
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        # n个解码层
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        #最后一个全连接层
        self.dense = nn.Linear(num_hiddens, vocab_size)

    # 初始化的state值
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        # 序列embedding+位置编码
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
        # n个解码层
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # 解码器自注意力权重
            self._attention_weights[0][
                i] = blk.attention1.attention.attention_weights
            # “编码器-解码器”自注意力权重
            self._attention_weights[1][
                i] = blk.attention2.attention.attention_weights
        #最后一个全连接层
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights

7.6 训练

  • 同样以之前的’英语-法语’的机器翻译任务为例,演示Transformer的训练
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

# 训练数据
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
# 编码器
encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
# 解码器
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
# 实例化模型
net = d2l.EncoderDecoder(encoder, decoder)
# 训练
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
# loss 0.032, 5852.4 tokens/sec on cuda:0

尽管Transformer架构是为了序列到序列的学习而提出的,但正如本书后面将提及的那样,Transformer编码器或Transformer解码器通常被单独用于不同的深度学习任务中。