从向量投影到概率分布:彻底搞懂 Transformer 注意力机制的数学本质

引言

自 2017 年 Google Brain 团队发表那篇著名的《Attention Is All You Need》论文以来,Transformer 架构已经彻底颠覆了深度学习的面貌。从自然语言处理(NLP)的绝对霸主 BERT 和 GPT 系列,到如今席卷计算机视觉(CV)、多模态乃至科学计算领域的各种大模型,Transformer 已经成为了现代人工智能的基石。

而在 Transformer 架构的核心中,最精妙、最关键的设计莫过于自注意力机制

对于许多初学者甚至有一定经验的算法工程师来说,注意力机制的代码实现可能倒背如流:QQ(Query)、KK(Key)、VV(Value)三个矩阵,接一个缩放点积,最后过 Softmax。但是,我们是否真正停下来思考过:为什么是这三个矩阵?为什么要点积?为什么要除以 dk\sqrt{d_k}(缩放)?为什么 Softmax 之前和之后的数学意义是什么?

本文将剥开代码的外衣,从纯粹的数学几何与概率论视角出发,带你一步步推演和彻底搞懂 Transformer 注意力机制的数学本质。准备好你的线性代数和微积分直觉,我们开始。


一、 拨开迷雾:什么是“注意力”?

在探讨数学之前,我们先建立直观。在处理一段文本(或一张图像)时,不同的元素(Token)之间是存在依赖关系的。

例如在句子 “The animal didn’t cross the street because it was too tired” 中,“it” 指代的是 “animal” 还是 “street”?人类一眼就能看出是 “animal”,但对机器来说,这需要建立一种机制:在处理 “it” 时,能够“注意”到上下文中的其他词,并赋予 “animal” 更高的权重。

用更抽象的术语来说,注意力机制本质上是一种信息检索机制的泛化。它允许模型在处理当前输入(Query,查询)时,动态地从一组上下文信息(Key-Value,键值对)中提取(Retrieve)所需的信息。


二、 宏观视角的数学框架

在 Transformer 中,注意力函数可以被定义为将一个查询(Query)和一组键值对映射到一个输出的过程。其标准的数学公式如下:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中:

  • QRN×dkQ \in \mathbb{R}^{N \times d_k}:查询矩阵
  • KRM×dkK \in \mathbb{R}^{M \times d_k}:键矩阵
  • VRM×dvV \in \mathbb{R}^{M \times d_v}:值矩阵
  • NNMM 分别代表查询序列和上下文序列的长度。
  • dkd_kdvd_v 分别是键和值的维度。

为了自下而上地理解这个公式,我们将它拆解为四个核心数学步骤:

  1. 线性投影(构造 Q, K, V)
  2. 相似度计算(点积 QKTQK^T
  3. 分布归一化(缩放与 Softmax)
  4. 特征聚合(乘以 V)

三、 第一步:线性变换与高维空间的旋转(Q, K, V 的诞生)

在输入注意力模块之前,我们通常有一个嵌入矩阵 XRN×dX \in \mathbb{R}^{N \times d}。Transformer 并没有直接使用 XX 去做计算,而是引入了三个参数矩阵 WQ,WK,WVW^Q, W^K, W^V,通过线性变换得到 Q,K,VQ, K, V

Q=XWQK=XWKV=XWVQ = X W^Q \\ K = X W^K \\ V = X W^V

数学本质:从公共语义空间到特定任务空间的投影

初始的词向量 XX 处于一个公共的高维语义空间中。在这个空间里,"苹果"的向量包含了水果和科技公司两种语义。但是,在一个特定的句子中,"苹果"的含义是确定的。

线性变换 WQ,WK,WKW^Q, W^K, W^K 本质上是对高维空间进行了一次基的变换(旋转与缩放)

  • Key 空间 (WKW^K):将词语投影到一个“易于被检索”的坐标系中。在这个坐标系里,词语的特征被提炼出来作为“标签”。
  • Query 空间 (WQW^Q):将词语投影到一个“用于发出检索请求”的坐标系中。它与 Key 空间必须是对齐的,因为 Query 和 Key 需要计算相似度。
  • Value 空间 (WVW^V):将词语投影到一个“承载实际内容”的坐标系中。当某个 Key 被命中时,需要把对应的 Value 传递出去。

将同一个输入 XX 映射到三个不同的子空间,赋予了模型强大的表达能力。模型在训练过程中,通过反向传播不断微调 WW 矩阵,学会了如何最好地组织这些“索引”和“内容”。


四、 第二步:点积的几何意义——相似度度量

得到 QQKK 后,注意力机制的核心操作是计算它们的点积:QKTQ K^T

假设 qiq_iQQ 的第 ii 行向量(代表第 ii 个词的查询向量),kjk_jKK 的第 jj 行向量。它们点积的结果构成了注意力矩阵 SS 的一个元素:

Si,j=qikj=t=1dkqi,tkj,tS_{i,j} = q_i \cdot k_j = \sum_{t=1}^{d_k} q_{i,t} k_{j,t}

数学本质:余弦相似度与模长的结合

在线性代数中,两个向量的点积可以表示为:

qikj=qikjcos(θ)q_i \cdot k_j = \|q_i\| \|k_j\| \cos(\theta)

这里 θ\theta 是两个向量在 dkd_k 维空间中的夹角。

  1. cos(θ)\cos(\theta) (方向):反映了两个向量方向上的相似性。如果方向一致(夹角小),余弦值大;方向正交(无关),余弦值为 0;方向相反,余弦值为负。这完美契合了“语义相关性”的度量。
  2. qikj\|q_i\| \|k_j\| (模长):代表了这两个向量的“能量”或“置信度”。模长越长,代表该词作为 Query 的意愿越强烈,或者作为 Key 的特征越显著。

矩阵相乘 QKTQ K^T 在一次操作中,并行地计算了序列中每一个词与其他所有词的相似度。这产生了一个 N×NN \times N(在自注意力中)的热力图矩阵,矩阵中的每一个数值都代表了一对词之间的原始吸引力。


五、 第三步:缩放与 Softmax——概率视角的重构

直接使用点积的结果作为权重是不稳定的,因此 Transformer 引入了缩放和 Softmax 操作:

A=softmax(Sdk)A = \text{softmax}\left(\frac{S}{\sqrt{d_k}}\right)

这里的数学设计极其巧妙。

1. 为什么需要缩放(Scaling by 1/dk1/\sqrt{d_k})?

这是论文中经常被问到的核心问题。假设 qqkk 的分量都是相互独立的、均值为 0、方差为 1 的随机变量。我们来计算点积 S=qkS = q \cdot k 的方差:

Var(qk)=Var(i=1dkqiki)=i=1dkVar(qiki)(假设相互独立)\begin{aligned} \text{Var}(q \cdot k) &= \text{Var}\left(\sum_{i=1}^{d_k} q_i k_i\right) \\ &= \sum_{i=1}^{d_k} \text{Var}(q_i k_i) \quad \text{(假设相互独立)} \end{aligned}

对于独立随机变量的乘积,其方差近似为(因为均值为 0):

Var(qiki)=E[qi2]E[ki2](E[qi]E[ki])2=1×10=1\text{Var}(q_i k_i) = \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] - (\mathbb{E}[q_i]\mathbb{E}[k_i])^2 = 1 \times 1 - 0 = 1

因此:

Var(qk)=dk\text{Var}(q \cdot k) = d_k

推演结果:随着维度 dkd_k 的增大(在原论文中 dk=64d_k=64),点积的方差会线性增大,这导致点积的绝对值会变得非常大!

为什么绝对值变大是个问题? 因为接下来的 Softmax 函数对输入的绝对值大小非常敏感。Softmax 的公式为 exiexj\frac{e^{x_i}}{\sum e^{x_j}}。当输入 xx 的方差极大时,最大值的指数将占主导地位,导致 Softmax 的输出逼近 One-Hot 向量(即最大值趋近于 1,其他趋近于 0)。

这会导致梯度消失。Softmax 函数的梯度为 piyip_i - y_iyiy_i 为真实标签,此处为 1 或 0)。当输出接近 0 或 1 时,梯度趋近于 0,模型在反向传播时几乎学不到任何信息。

解决方案:除以 dk\sqrt{d_k}。由于方差变为原来的 1/dk1/d_k,除以 dk\sqrt{d_k} 后,方差被完美地拉回到了 1。这就保证了无论特征维度 dkd_k 有多大,输入到 Softmax 的数值都能保持在一个合理的范围内,确保梯度的稳定传播。

2. Softmax:将相关性转化为概率分布

经过缩放后的矩阵,输入到 Softmax 中。在数学上,Softmax 实现了两个目的:

  1. 归一化:使得权重之和为 1。这可以理解为当前词 ii 在上下文词 jj 上分配的“注意力概率”。你分配给各个词的注意力总和必须是 100%。
  2. 非线性放大:将原本差异不大的相似度差异放大,使得模型能够更加“集中”注意力在最相关的少数词上,而不是对所有人都平均用力。

六、 第四步:加权求和——上下文感知的特征重构

最后一步,是将得到的注意力权重矩阵 AA 乘以 Value 矩阵 VV

Output=AV\text{Output} = A V

假设我们正在计算第 ii 个词的新表征 outiout_i

outi=j=1NAi,jvjout_i = \sum_{j=1}^{N} A_{i,j} v_j

数学本质:动态的高维向量混合

从几何上看,这是一个向量的凸组合(因为 Ai,jA_{i,j} 之和为 1 且非负)。我们正在以 Ai,jA_{i,j} 为权重,把上下文中所有的 Value 向量 vjv_j 加权求和。

  • 如果第 ii 个词(Query)与第 jj 个词相关性极高,那么 Ai,jA_{i,j} 接近 1。
  • 这意味着在合成 outiout_i 时,我们大量混入了 vjv_j 的信息。
  • 结果:原本 xix_i 只包含当前词的静态字典含义,而经过注意力机制输出的 outiout_i,已经融合了整个句子的上下文信息,成为了一个动态的、上下文感知的表征。

这就是为什么 Transformer 能够完美处理一词多义和长距离依赖的原因。每一个词的特征不再是孤立的,而是整个句子信息的浓缩。


七、 进阶:多头注意力的几何意义

原版 Transformer 使用的是多头注意力。数学公式如下:

MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O

其中

headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)

为什么要多头?

如果只有单头注意力,整个句子只能学习到一种相似度模式(例如只能学到语法上的主谓关系,或者只能学到空间距离上的邻近关系)。

从流形学习的角度来看,现实世界的数据往往存在于高维空间的非线性流形上。单一的线性投影 WW 只能捕捉数据在一个超平面上的投影特征。

引入 hh 个头,相当于在 hh 个不同的子空间中独立地进行注意力计算。每个头拥有自己的 WQ,WK,WVW^Q, W^K, W^V 矩阵,这意味着模型可以同时从不同的维度观察句子:

  • Head 1 可能关注语法上的依赖(如形容词修饰名词)。
  • Head 2 可能关注指代消解(如代词指向的实体)。
  • Head 3 可能关注长距离的语义共现。

最后,将这 hh 个不同视角的特征拼接起来,再通过一个线性层 WOW^O 进行一次总体的特征融合与降维。这种设计极大地提升了模型捕捉复杂特征的能力。


八、 代码实战:用 PyTorch 从零手写自注意力

理论讲完了,让我们用代码把数学公式固化下来。下面是一个不依赖任何封装、纯用 PyTorch 实现的自注意力机制。通过代码,你可以更清晰地看到矩阵维度的变化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads=1, dropout=0.1):
"""
embed_dim: 输入特征的维度 (d_model)
num_heads: 注意力头数
"""
super(SelfAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads

# 在多头注意力中,通常 d_k = d_v = embed_dim / num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim 必须能被 num_heads 整除"

# 定义 Q, K, V 的线性投影矩阵 (合并了多头)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)

# 最终输出的线性变换矩阵
self.out_proj = nn.Linear(embed_dim, embed_dim)

self.dropout = nn.Dropout(dropout)

def forward(self, x, mask=None):
"""
x: 输入张量,形状为 (batch_size, seq_len, embed_dim)
mask: 掩码张量,用于防止某些位置参与计算 (如 padding mask 或 causal mask)
"""
batch_size, seq_len, embed_dim = x.size()

# 1. 线性投影并分割多头
# 形状变化: (B, S, E) -> (B, S, E) -> (B, S, H, D) -> (B, H, S, D)
# 其中 H = num_heads, D = head_dim, E = embed_dim
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

# 2. 计算点积相似度
# q @ k.transpose(-2, -1) 的形状: (B, H, S, D) @ (B, H, D, S) -> (B, H, S, S)
scores = torch.matmul(q, k.transpose(-2, -1))

# 3. 缩放
# 数学意义:防止方差随维度增大而发散,导致 softmax 梯度消失
scores = scores / math.sqrt(self.head_dim)

# (可选) 应用 Mask
# 将 mask=0 的位置替换为一个极小的负数,使其在 softmax 后概率接近 0
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

# 4. Softmax 归一化得到注意力概率
# 在最后一个维度 (Key 的维度) 上进行 softmax
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)

# 5. 乘以 Value 矩阵进行特征聚合
# (B, H, S, S) @ (B, H, S, D) -> (B, H, S, D)
out = torch.matmul(attn_weights, v)

# 6. 拼接多头并做最终的线性投影
# (B, H, S, D) -> (B, S, H, D) -> (B, S, E)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
out = self.out_proj(out)

return out, attn_weights

# 测试代码
if __name__ == "__main__":
batch_size = 2
seq_len = 10
embed_dim = 64
num_heads = 4

# 模拟输入
x = torch.randn(batch_size, seq_len, embed_dim)

# 初始化自注意力层
attn = SelfAttention(embed_dim=embed_dim, num_heads=num_heads)

# 前向传播
output, weights = attn(x)

print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"第 1 个样本,第 1 个词的注意力概率分布 (和应为1): \n{weights[0, 0, 0, :].detach().numpy()}")
print(f"概率分布求和: {weights[0, 0, 0, :].sum().item()}")

代码解析:
在这段代码中,有几个非常关键的维度操作:

  1. transpose(1, 2):将 batch_size 后面的 num_heads 维度提到前面,使得在后续矩阵乘法中,每个头独立计算(并行化)。
  2. k.transpose(-2, -1):转置 Key 矩阵的最后两个维度,这是实现 Q×KTQ \times K^T 矩阵乘法的前提。
  3. math.sqrt(self.head_dim):这就是数学推导中至关重要的 dk\sqrt{d_k}

九、 总结

透过繁杂的公式和代码,Transformer 注意力机制的数学本质其实非常清晰且优雅:

  1. 投影:通过矩阵乘法将原始静态特征投影到三个相互协同的子空间(Q, K, V),赋予系统“查询”、“匹配”和“表达”的能力。
  2. 度量:利用向量的内积(点积)几何意义,衡量不同特征向量之间的相似度。
  3. 稳定与分配:通过缩放因子 dk\sqrt{d_k} 抑制高维方差,通过 Softmax 函数将相似度转化为总和为 1 的概率分布,实现资源的竞争与集中。
  4. 融合:利用得到的概率分布对上下文的 Value 向量进行加权求和(凸组合),完成信息的动态提取与重组。

注意力机制之所以伟大,是因为它打破了传统 RNN 或 CNN 固定的局部感受野限制,用一种全连接但权重动态分配的方式,让模型拥有了根据输入内容自适应调整网络结构的能力。它不仅解决了长距离依赖问题,更为深度学习提供了一种通用的、基于内容的寻址范式。

理解了底层的数学本质,我们在未来面对各种各样的注意力变体(如 Sparse Attention, Linear Attention, Flash Attention 等)时,就能一眼看穿它们的设计动机——它们不过是在不同的维度上,对计算复杂度、信息保留程度和梯度流动性做着各种精妙的妥协与优化。

数学,永远是揭示 AI 魔法的最佳解码器。