attention
也读了一些论文, 见过一些attention机制的变体, 我发现下面这种解释可以用来理解绝大多数的attention的变体
也读了一些论文, 见过一些attention机制的变体, 我发现下面这种解释可以用来理解绝大多数的attention的变体
为这篇博客的翻译
在深度学习中,“注意力”这一概念起源于改进循环神经网络(RNNs)以处理更长的序列或句子的努力。例如,考虑将一个句子从一种语言翻译成另一种语言。逐词翻译并不能有效工作。

为了解决这个问题,引入了注意力机制,使每个时间步骤都能访问所有序列元素。关键在于选择性地确定在特定上下文中哪些词最重要。2017 年,Transformer 架构引入了一个独立的自注意力机制,完全消除了对 RNNs 的需求。
我们可以将自注意力机制视为一种通过包含输入上下文信息来增强输入嵌入信息内容的机制。换句话说,自注意力机制使模型能够权衡输入序列中不同元素的重要性,并动态调整它们对输出的影响。这对于语言处理任务尤为重要,因为一个词的意义可能会根据其在句子或文档中的上下文而改变。
需要注意的是,自注意力机制有许多变体。特别关注的是如何使自注意力机制更高效。然而,大多数论文仍然采用本文讨论的原始缩放点积注意力机制,因为它通常能带来更高的准确性,并且对于大多数训练大规模变换器的公司来说,自注意力机制很少成为计算瓶颈。
在本文中,我们重点关注原始的缩放点积注意力机制(称为自注意力机制),它仍然是实践中最流行和最广泛使用的注意力机制。
在开始之前,让我们考虑一个输入句子“Life is short, eat dessert first”,我们希望将其通过自注意力机制进行处理。类似于其他处理文本的建模方法(例如使用循环神经网络或卷积神经网络),我们首先创建一个句子嵌入。
为了简单起见,这里我们的词典 dc 仅限于输入句子中出现的单词。在实际应用中,我们会考虑训练数据集中的所有单词(典型的词汇表大小在 3 万到 5 万之间)。
输入:
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)输出:
{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}接下来,我们使用这个字典为每个词分配一个整数索引:
输入:
import torch
sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)输出:
tensor([0, 4, 5, 2, 1, 3])现在,使用输入句子的整数向量表示,我们可以使用嵌入层将输入编码为实向量嵌入。这里,我们将使用 16 维嵌入,使得每个输入词由一个 16 维向量表示。由于句子由 6 个词组成,这将生成一个 维嵌入:
输入:
torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)输出:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880, 0.3486, 0.6603, -0.2196, -0.3792,
0.7671, -1.1925, 0.6984, -1.4097, 0.1794, 1.8951, 0.4954, 0.2692],
[ 0.5146, 0.9938, -0.2587, -1.0826, -0.0444, 1.6236, -2.3229, 1.0878,
0.6716, 0.6933, -0.9487, -0.0765, -0.1526, 0.1167, 0.4403, -1.4465],
[ 0.2553, -0.5496, 1.0042, 0.8272, -0.3948, 0.4892, -0.2168, -1.7472,
-1.6025, -1.0764, 0.9031, -0.7218, -0.5951, -0.7112, 0.6230, -1.3729],
[-1.3250, 0.1784, -2.1338, 1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
0.8805, 1.5542, 0.6266, -0.1755, 0.0983, -0.0935, 0.2662, -0.5850],
[-0.0770, -1.0205, -0.1690, 0.9178, 1.5810, 1.3010, 1.2753, -0.2010,
0.4965, -1.5723, 0.9666, -1.1481, -1.1589, 0.3255, -0.6315, -2.8400],
[ 0.8768, 1.6221, -1.4779, 1.1331, -1.2203, 1.3139, 1.0533, 0.1388,
2.2473, -0.8036, -0.2808, 0.7697, -0.6596, -0.7979, 0.1838, 0.2293]])
torch.Size([6, 16])现在,让我们来讨论广泛使用的自注意力机制,即缩放点积注意力,它被集成到变压器架构中。
自注意力机制使用三个权重矩阵,分别称为 、 和 ,这些矩阵在训练过程中作为模型参数进行调整。这些矩阵分别用于将输入投影到序列的查询、键和值组件中。
相应的查询、键和值序列是通过权重矩阵 和嵌入输入 之间的矩阵乘法获得的:
索引 指的是输入序列中的标记索引位置,该序列的长度为 。

这里, 和 都是维度为 的向量。投影矩阵 和 的形状为 ,而 的形状为 。
(需要注意的是, 表示每个词向量的大小, 。)
由于我们计算的是查询向量和键向量之间的点积,因此这两个向量必须包含相同数量的元素( )。然而,值向量中的元素数量 是任意的,它决定了结果上下文向量的大小。
因此,在接下来的代码讲解中,我们将设置 并使用 ,初始化投影矩阵如下:
输入:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 24, 24, 28
W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))现在,假设我们感兴趣的是计算第二个输入元素的注意力向量——这里的第二个输入元素充当查询

在代码中,这看起来如下:
输入:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)
print(query_2.shape)
print(key_2.shape)
print(value_2.shape)torch.Size([24])
torch.Size([24])
torch.Size([28])然后,我们可以将其推广到计算所有输入的其余键和值元素,因为在下一步计算未归一化的注意力权重 时,我们需要它们
输入:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)输出:
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])现在我们已经拥有了所有必需的键和值,可以继续下一步,计算未归一化的注意力权重 ,如下图所示:

如上图所示,我们将 计算为查询序列和键序列之间的点积,即 。
例如,我们可以计算查询与第 5 个输入元素(对应于索引位置 4)之间的未归一化注意力权重如下:
输入:
omega_24 = query_2.dot(keys[4])
print(omega_24)输出:
由于我们稍后需要计算注意力分数,让我们根据前图所示为所有输入标记计算 值:
输入:
omega_2 = query_2.matmul(keys.T)
print(omega_2)输出:
tensor([ 8.5808, -7.6597, 3.2558, 1.0395, 11.1466, -0.4800])自注意力机制的下一步是对未归一化的注意力权重 进行归一化,以获得归一化的注意力权重 ,方法是应用 softmax 函数。此外,在通过 softmax 函数进行归一化之前,使用 对 进行缩放,如下所示:

通过平方根对 进行缩放,可以确保权重向量的欧几里得长度大致保持在同一数量级。这有助于防止注意力权重变得过小或过大,从而避免数值不稳定或影响模型在训练过程中的收敛能力。
为什么特别选择 ?q 和 k 之间的点积是 个独立项的和,每个项的方差约为 1。这意味着原始分数的方差会随着 线性增长。通过除以 ,我们抵消了这种增长,并将方差恢复到约 1。
在代码中,我们可以如下实现注意力权重的计算:
输入:
import torch.nn.functional as F
attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)输出:
tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458])最后一步是计算上下文向量 ,这是原始查询输入 的注意力加权版本,通过注意力权重将所有其他输入元素作为其上下文包括在内

在代码中,这看起来如下:
输入:
context_vector_2 = attention_weights_2.matmul(values)
print(context_vector_2.shape)
print(context_vector_2)输出:
torch.Size([28])
tensor(torch.Size([28])
tensor([-1.5993, 0.0156, 1.2670, 0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
0.4747, 1.1926, 0.4506, -0.7110, 0.0602, 0.7125, -0.1628, -2.0184,
0.3838, -2.1188, -0.8136, -1.5694, 0.7934, -0.2911, -1.3640, -0.2366,
-0.9564, -0.5265, 0.0624, 1.7084])请注意,由于我们之前指定了 ,这个输出向量的维度( )比原始输入向量的维度( )要多;然而,嵌入大小的选择是任意的。
在本文开头的第一张图中,我们看到变压器使用了一个叫做多头注意力的模块。这与我们上面介绍的自注意力机制(缩放点积注意力)有什么关系?
在缩放点积注意力中,输入序列通过三个矩阵进行变换,这三个矩阵分别代表查询、键和值。在这三个矩阵可以被视为多头注意力中的一个单一注意力头。下图总结了我们之前讨论的这个单一注意力头:

顾名思义,多头注意力机制涉及多个这样的头,每个头由查询、键和值矩阵组成。这个概念类似于卷积神经网络中使用多个核。

为了用代码说明这一点,假设我们有 3 个注意力头,因此我们现在扩展 维的权重矩阵,使其变为 :
输入:
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))因此,每个查询元素现在是 维的,其中 (这里,让我们重点关注对应于索引位置 2 的第 3 个元素):
输入:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2.shape)输出:
然后我们可以用类似的方法获取键和值:
输入:
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)现在,这些键和值元素是特定于查询元素的。但是,类似于之前的情况,为了计算查询的注意力分数,我们还需要其他序列元素的值和键。我们可以通过将输入序列嵌入扩展到大小 3 来实现这一点,即注意力头的数量:
输入:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs.shape)输出:
现在,我们可以使用 via torch.bmm() (批量矩阵乘法)来计算所有的键和值:
输入:
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)输出:
multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])我们现在有了表示三个注意力头的张量,它们的第一维度代表注意力头。第三和第二维度分别表示单词数量和嵌入大小。为了使值和键更易于解释,我们将交换第二和第三维度,从而得到与原始输入序列具有相同维度结构的张量, embedded_sentence :
输入:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)输出:
multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])然后,我们按照之前的步骤计算未缩放的注意力权重 和注意力权重 ,接着通过缩放的 softmax 计算来获得一个 (这里: )维的上下文向量 ,用于输入元素 。
在上面的代码讲解中,我们设置了 和 。换句话说,我们对查询和键序列使用了相同的维度。虽然值矩阵 的维度通常与查询和键矩阵的维度相同(例如在 PyTorch 的 MultiHeadAttention 类中),但我们也可以为值的维度选择任意大小。
由于维度有时难以跟踪,让我们在下图中总结我们迄今为止所涵盖的所有内容,该图展示了单个注意力头的各种张量大小。

上面的图示对应于变压器中使用的自注意力机制。我们还没有讨论的一种特定的注意力机制是交叉注意力。

什么是交叉注意力,它与自注意力有什么不同?
在自注意力机制中,我们处理的是同一个输入序列。而在交叉注意力机制中,我们混合或组合两个不同的输入序列。对于上述原始的 Transformer 架构来说,这两个序列分别是左侧编码器模块返回的序列和右侧解码器部分正在处理的输入序列。
请注意,在交叉注意力机制中,两个输入序列 和 可以具有不同数量的元素。然而,它们的嵌入维度必须匹配。
下图说明了交叉注意力的概念。如果我们设置 ,这就等同于自注意力机制。

(注意,查询通常来自解码器,而键和值通常来自编码器。)
在代码中是如何实现的?之前,在本文开头实现自注意力机制时,我们使用了以下代码来计算第二个输入元素的查询以及所有键和值:
输入:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
print("embedded_sentence.shape:", embedded_sentence.shape:)
d_q, d_k, d_v = 24, 24, 28
W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
print("query.shape", query_2.shape)
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)输出:
embedded_sentence.shape: torch.Size([6, 16])
queries.shape: torch.Size([24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])在交叉注意力中唯一改变的是我们现在有了第二个输入序列,例如,一个有 8 个而不是 6 个输入元素的第二个句子。这里假设这是一个包含 8 个标记的句子。
输入:
embedded_sentence_2 = torch.rand(8, 16) # 2nd input sequence
keys = W_key.matmul(embedded_sentence_2.T).T
values = W_value.matmul(embedded_sentence_2.T).T
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)输出:
keys.shape: torch.Size([8, 24])
values.shape: torch.Size([8, 28])请注意,与自注意力机制相比,现在的键和值有 8 行而不是 6 行。其他一切保持不变。
我们在上面谈了很多关于语言转换器的内容。在原始的转换器架构中,当我们从输入句子到输出句子进行语言翻译时,交叉注意力机制非常有用。输入句子代表一个输入序列,而翻译则代表第二个输入序列(这两个句子可以包含不同数量的词)。
另一个广泛使用交叉注意力的模型是 Stable Diffusion。Stable Diffusion 在 U-Net 模型生成的图像和用于条件约束的文本提示之间使用交叉注意力,这在《使用潜在扩散模型进行高分辨率图像合成》这篇原始论文中有详细描述,该论文后来被 Stability AI 采纳以实现广受欢迎的 Stable Diffusion 模型。
