在深度学习中,尤其是Transformer架构中,注意力机制起到了至关重要的作用。多头注意力(MHA)作为最初的注意力机制,通过对输入序列进行多头并行处理,允许模型捕捉序列中不同部分之间的复杂关系。然而,随着模型规模的增大和高效推理需求的增加,新的注意力机制不断涌现,例如分组查询注意力(GQA)和多头潜在注意力(MLA)。本文将对这三种注意力机制进行深入比较,详细介绍它们的数学公式、计算过程以及代码实现差异,同时探讨它们在内存占用、计算性能和模型表达能力上的权衡。
在解释公式和代码实现之前,我们先对三种机制做一个基本的概述:
对于MHA,每个头独立计算注意力,其中第 \(i\) 个头的计算过程如下:
\( Q_i = X \cdot W_Q^i \)
\( K_i = X \cdot W_K^i \)
\( V_i = X \cdot W_V^i \)
计算缩放点积注意力:
\( \text{Attention}_i = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right) \cdot V_i \)
最终拼接所有头并投影得到输出:
\( \text{MHA}(X) = \text{Concat}(\text{Attention}_1, \text{Attention}_2, \ldots, \text{Attention}_h) \cdot W_O \)
在GQA中,假设有 \(h\) 个总查询头被分成 \(g\) 组,每组中所有头共享一组键和值投影,因此对于每个查询头 \(i\) 属于组 \(j\),公式如下:
\( Q_i = X \cdot W_Q^i \)
\( K_j = X \cdot W_K^j \)
\( V_j = X \cdot W_V^j \)
注意力计算为:
\( \text{Attention}_i = \text{softmax}\left(\frac{Q_iK_j^T}{\sqrt{d_k}}\right) \cdot V_j \)
最终输出:
\( \text{GQA}(X) = \text{Concat}(\text{Attention}_1, \text{Attention}_2, \ldots, \text{Attention}_h) \cdot W_O \)
MLA的核心在于对键和值进行压缩。首先计算常规的查询、键和值投影:
\( Q = X \cdot W_Q \)
\( K = X \cdot W_K \)
\( V = X \cdot W_V \)
将键和值压缩到潜在空间:
\( K_{\text{latent}} = K \cdot W_{\text{compress}}^K \)
\( V_{\text{latent}} = V \cdot W_{\text{compress}}^V \)
采用缩放点积注意力:
\( \text{Attention} = \text{softmax}\left(\frac{Q (K_{\text{latent}})^T}{\sqrt{d_{\text{latent}}}}\right) \cdot V_{\text{latent}} \)
最后输出:
\( \text{MLA}(X) = \text{Attention} \cdot W_O \)
这种方法利用潜在向量在内存与计算量上的优势,同时保持令牌级别细粒度信息。
# 导入必要模块
import torch
import torch.nn as nn
import math
import torch.nn.functional as F # 用于softmax
# 多头注意力实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % num_heads == 0, "d_model必须是num_heads的整数倍"
self.depth = d_model // num_heads
# 查询、键、值以及输出投影层
self.wq = nn.Linear(d_model, d_model)
self.wk = nn.Linear(d_model, d_model)
self.wv = nn.Linear(d_model, d_model)
self.wo = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
# 将最后一维分割为(num_heads, depth)并交换轴
x = x.view(batch_size, -1, self.num_heads, self.depth)
return x.permute(0, 2, 1, 3) # (batch_size, num_heads, seq_len, depth)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 对输入进行线性变换
q = self.wq(Q)
k = self.wk(K)
v = self.wv(V)
# 分割为多个头
q = self.split_heads(q, batch_size)
k = self.split_heads(k, batch_size)
v = self.split_heads(v, batch_size)
# 计算缩放点积注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.depth)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, v)
# 恢复原始形状后进行输出投影
output = output.permute(0, 2, 1, 3).contiguous()
output = output.view(batch_size, -1, self.d_model)
output = self.wo(output)
return output, attention_weights
# 分组查询注意力实现
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_groups):
super(GroupedQueryAttention, self).__init__()
self.num_heads = num_heads
self.num_groups = num_groups # 分组数, 每组共享一个键和值
self.d_model = d_model
assert d_model % num_heads == 0, "d_model必须是num_heads的整数倍"
assert num_heads % num_groups == 0, "总头数必须能被组数整除"
self.depth = d_model // num_heads
self.wq = nn.Linear(d_model, d_model)
# 注意:键和值的投影维度较低,取决于组数
self.wk = nn.Linear(d_model, d_model // num_groups)
self.wv = nn.Linear(d_model, d_model // num_groups)
self.wo = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size, is_query=False):
if is_query:
# 对查询按头拆分
x = x.view(batch_size, -1, self.num_heads, self.depth)
else:
# 对键和值按组拆分
x = x.view(batch_size, -1, self.num_groups, self.depth * self.num_heads // self.num_groups)
return x.permute(0, 2, 1, 3)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 查询转换
q = self.wq(Q)
# 键和值转换 (注意组内共享)
k = self.wk(K)
v = self.wv(V)
q = self.split_heads(q, batch_size, is_query=True)
k = self.split_heads(k, batch_size)
v = self.split_heads(v, batch_size)
# 计算注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.depth)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, v)
# 恢复形状并投影
output = output.permute(0, 2, 1, 3).contiguous()
output = output.view(batch_size, -1, self.d_model)
output = self.wo(output)
return output, attention_weights
# 多头潜在注意力实现
class MultiHeadLatentAttention(nn.Module):
def __init__(self, d_model, num_heads, latent_dim):
super(MultiHeadLatentAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
self.latent_dim = latent_dim # 潜在空间维度
assert d_model % num_heads == 0, "d_model必须是num_heads的整数倍"
self.depth = d_model // num_heads
self.wq = nn.Linear(d_model, d_model)
# 原始键和值的投影
self.wk = nn.Linear(d_model, d_model)
self.wv = nn.Linear(d_model, d_model)
# 压缩矩阵
self.w_compress_K = nn.Linear(d_model, latent_dim)
self.w_compress_V = nn.Linear(d_model, latent_dim)
# 解压矩阵
self.w_decompress_K = nn.Linear(latent_dim, d_model)
self.w_decompress_V = nn.Linear(latent_dim, d_model)
self.wo = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
# 将输入拆分为多个头并交换轴
x = x.view(batch_size, -1, self.num_heads, self.depth)
return x.permute(0, 2, 1, 3)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 查询投影与拆分
q = self.wq(Q)
q = self.split_heads(q, batch_size)
# 键和值正常投影
k_proj = self.wk(K)
v_proj = self.wv(V)
# 对键和值进行潜在空间压缩
latent_k = self.w_compress_K(k_proj)
latent_v = self.w_compress_V(v_proj)
# 对每个头重复压缩后的结果
latent_k = latent_k.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
latent_v = latent_v.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
# 计算缩放点积注意力(使用潜在维度)
scores = torch.matmul(q, latent_k.transpose(-2, -1)) / math.sqrt(self.latent_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
output_latent = torch.matmul(attention_weights, latent_v)
# 将潜在空间注意力结果解压回高维空间
output_latent = output_latent.view(batch_size, -1, self.latent_dim)
output = self.w_decompress_V(output_latent)
output = self.wo(output)
return output, attention_weights
下表展示了三种注意力机制在计算和内存方面的主要区别:
注意力机制 | KV缓存大小 | 计算复杂度 | 表达能力 | 适用场景 |
---|---|---|---|---|
MHA | 较大(每个头有独立KV) | 较高 (\(O(n^2)\)) | 捕捉多样化信息 | 资源充足训练与推理 |
GQA | 中等(组内共享KV) | 降低部分计算成本 | 平衡表达能力与效率 | 内存受限与模型推理 |
MLA | 最小(潜在空间:低维) | 显著减少计算开销 | 保持甚至增强表达能力 | 大模型推理与低内存场景 |
从上表可以看出,MHA在表现上最为全面,但由于每个注意力头都分配独立的键和值,内存占用及计算成本较高;GQA通过分组策略部分缓解了此问题;而MLA在引入潜在嵌入后,不仅大幅降低KV缓存需求,同时保存了丰富的上下文信息,适合低内存高性能模型场景。
每种注意力机制都有其独特的优缺点和适用场景:
MHA的优势在于每个头独立计算,即使同一输入可以在不同子空间捕捉多样化的信息。这种机制适合模型训练时资源充足的情况,可以捕获更丰富的语义和上下文信息。
由于每个头单独计算KV,计算复杂度高且内存消耗巨大,尤其对于长序列输入来说,计算成本呈二次增长,这使得在大规模推理场景下可能不够高效。
GQA采用查询头分组策略,由于组内共享键和值,这显著降低了KV存储的需求和计算量。对于需要同时兼顾内存限制和模型质量的应用来说,GQA提供了较好的折中方案。
共享的键和值虽然降低了计算成本,但也可能导致注意力权重的选择不够细粒度,从而在一些对细节捕捉要求高的任务中略逊于完全独立的MHA。
MLA引入压缩和解压矩阵,将高维的键和值数据压缩到低维潜在表示中,这样做能够大幅降低内存消耗,同时保持甚至提升模型表达能力。其优势在于能够在大模型推理和资源受限环境中运行,保持高效且精确的注意力计算。
尽管理论上MLA能实现和MHA相媲美的表达能力,训练和实现过程中对压缩和解压矩阵的学习要求较高,可能在调参和优化上更具挑战性。
本文详细比较了三种核心的注意力机制:
根据具体应用场景和资源约束,可以选择适合的注意力机制。对于追求模型最优表达能力且资源充足的情况,MHA是不二之选;在需要平衡内存和计算效率时,GQA提供了一个较好的方案;当面对大规模模型和推理任务时,MLA由于其对KV缓存的优化,显得更加吸引人,其性能提升在实际应用中已被证明具有重要意义。
通过本文的详细解析,我们可以看出:
总体来说,选择哪种机制需要综合考虑模型的规模、训练与推理的资源限制、实时性要求以及对细粒度信息捕捉的需求。在当前大多数大模型场景下,MLA和GQA正逐步获得更多关注,而传统的MHA仍在需要极高表达能力的任务中占据重要的地位。开发者可以根据具体的数据和应用背景,选择合适的注意力机制,并进一步对代码实现和参数设置进行优化,从而实现更高效、更精准的深度学习模型。