Understanding Transformer Architecture: From Attention Mechanism to LLM Foundation
A systematic exploration of Transformer architecture — from the original "Attention is All You Need" paper to modern LLM foundations, with code implementations and hardware-aware insights for EDA researchers.
Transformer架构是现代大语言模型(LLM)的基石,从2017年Google提出的”Attention is All You Need”开始,彻底改变了自然语言处理(NLP)领域。本文将从AI Infra视角系统性地解析Transformer的核心组件、计算特性及其在硬件加速中的意义。
一、背景与动机
1.1 序列建模的演进
在Transformer出现之前,序列建模主要依赖RNN(循环神经网络)及其变体LSTM和GRU。
RNN/LSTM的局限性
RNN的核心问题在于其顺序计算特性——当前隐藏状态依赖于前一时刻的隐藏状态,这导致:
- 梯度消失/爆炸:长序列训练困难
- 无法并行:时间步之间存在数据依赖
- 长距离依赖衰减:信息在传递过程中逐渐丢失
即便LSTM通过门控机制部分缓解了梯度问题,但其本质仍是顺序计算,计算效率低下。
为什么需要注意力机制?
注意力机制(Attention Mechanism)最早应用于序列到序列(Seq2Seq)模型中,通过直接建模任意位置之间的依赖关系,解决了RNN的长距离依赖问题。
核心思想:“我应该关注输入序列中的哪些部分?”
注意力机制允许模型在处理每个位置时,都能”看到”输入序列的所有位置,并根据相关性动态分配权重。
1.2 Transformer的诞生
2017年,Google Brain团队发表了里程碑论文《Attention Is All You Need》,首次提出了完全基于注意力机制的Transformer架构:
- 摒弃了RNN:全部使用注意力机制建模序列关系
- 支持并行计算:大幅提升训练效率
- 可扩展性强:为后续BERT、GPT等预训练模型奠定基础
为什么Transformer如此重要?对于AI Infra研究者而言,理解其计算特性是设计AI加速器(如CIM、Chiplet架构)的关键基础。
二、注意力机制详解
2.1 Query-Key-Value 三元组
注意力机制的核心是Query(查询)、Key(键)、Value(值)三个向量:
| 符号 | 含义 | 直观理解 |
|---|---|---|
| Q | Query | “我在找什么信息?” |
| K | Key | “我有哪些信息?” |
| V | Value | “信息的实际内容是什么?” |
对于输入序列中的每个token,我们用Query去”查询”所有位置的Key,计算与各位置的相关性(注意力权重),再用这些权重对Value进行加权求和。
2.2 缩放点积注意力(Scaled Dot-Product Attention)
Transformer使用的是缩放点积注意力,公式如下:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]计算步骤分解
- 计算点积:$QK^T$ — 计算Query与所有Key的相似度
- 缩放:$\div \sqrt{d_k}$ — 防止点积值过大导致梯度消失
- Softmax:归一化得到注意力权重(和为1)
- 加权求和:权重与Value相乘得到输出
为什么需要缩放?当 $d_k$ 较大时,点积的值会方差变大,导致Softmax输出趋于one-hot(梯度接近0)。除以 $\sqrt{d_k}$ 可以稳定梯度。
数学推导
假设 $q$ 和 $k$ 是独立随机变量,均值为0,方差为1,则点积 $q \cdot k$ 的均值为0,方差为 $d_k$。
缩放后: \(\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1\)
这确保了Softmax函数工作在梯度较稳定的区域。
2.3 代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力实现"""
def __init__(self):
super().__init__()
def forward(self, Q, K, V, mask=None):
"""
Args:
Q: [batch_size, num_heads, seq_len, d_k]
K: [batch_size, num_heads, seq_len, d_k]
V: [batch_size, num_heads, seq_len, d_v]
mask: [batch_size, 1, seq_len, seq_len] - 可选掩码
Returns:
output: [batch_size, num_heads, seq_len, d_v]
attn_weights: [batch_size, num_heads, seq_len, seq_len]
"""
d_k = Q.size(-1)
# 1. 计算点积并缩放
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# 2. 应用掩码(如果提供)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 3. Softmax得到注意力权重
attn_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attn_weights, V)
return output, attn_weights
在AI硬件加速中,这个运算的核心计算是矩阵乘法 $QK^T$ 和后续的矩阵乘法,这正是CIM(计算存内)架构擅长加速的场景。
三、多头注意力(Multi-Head Attention)
3.1 为什么需要多头?
单头注意力有一个局限:所有注意力头只能学习同一种类型的相关性。
多头注意力的核心思想:将Q、K、V分别投影到多个低维空间,每个空间独立学习不同类型的相关性,然后拼接输出。
数学表达: \(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\)
其中每个 $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$
3.2 维度分析
| 参数 | 含义 |
|---|---|
| $d_{model}$ | 模型维度(通常为512、768、1024等) |
| $h$ | 注意力头数量(通常为8、12、16) |
| $d_k = d_v = d_{model} / h$ | 每个头的维度 |
计算复杂度对比:
- 单头注意力:$O(d_{model}^2)$
- 多头注意力:仍然是 $O(d_{model}^2)$
- 多头只是将计算”分块”,总计算量不变,但表达能力增强
3.3 代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class MultiHeadAttention(nn.Module):
"""多头注意力实现"""
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 线性投影层
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
self.attention = ScaledDotProductAttention()
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 1. 线性投影 + 分头
Q = self.W_Q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_K(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_V(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 2. 计算注意力
x, attn_weights = self.attention(Q, K, V, mask)
# 3. 拼接多头输出
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 4. 最终线性投影
output = self.W_O(x)
return output, attn_weights
硬件视角:多头注意力中各头之间是独立的,这意味着在硬件层面可以实现并行计算,是设计AI加速器时的重要优化点。
四、位置编码(Positional Encoding)
4.1 为什么需要位置信息?
Transformer本身是排列不变(permutation invariant)的——输入序列 [A, B, C] 和 [C, B, A] 在Transformer看来没有区别,因为自注意力只建模token之间的关系,不建模位置关系。
为了让模型感知序列顺序,需要人为注入位置信息,这就是位置编码(Positional Encoding)。
4.2 三角函数位置编码
原始Transformer使用正弦和余弦函数生成位置编码:
\(PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)\) \(PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)\)
其中 $pos$ 是位置,$i$ 是维度索引。
性质
| 性质 | 含义 |
|---|---|
| 周期性 | 不同频率的正弦/余弦波 |
| 唯一性 | 每个位置有唯一的编码向量 |
| 相对距离 | $PE(pos+k)$ 可以表示为 $PE(pos)$ 的线性函数 |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class PositionalEncoding(nn.Module):
"""位置编码实现"""
def __init__(self, d_model, max_len=5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2, dtype=torch.float) *
(-math.log(10000.0) / d_model)
)
# 偶数维度用sin,奇数维度用cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 添加batch维度并注册为buffer(不参与训练)
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
"""
# 将位置编码加到输入上
return x + self.pe[:, :x.size(1), :]
4.3 现代位置编码方案
| 方案 | 提出时间 | 特点 |
|---|---|---|
| Sinusoidal PE | 2017 | 原始Transformer使用,固定不变 |
| Learnable PE | 2018 | BERT等采用,可学习参数 |
| RoPE | 2021 | 旋转位置编码,LLM常用(如LLaMA) |
| ALiBi | 2022 | 无需显式位置编码,外推性好 |
RoPE(Rotary Position Embedding) 是现代LLM(如LLaMA)的标配,它通过旋转操作将位置信息融入Query和Key向量,而非简单加法。这种方式具有更好的长度外推性。
五、前馈神经网络(FFN)
5.1 FFN在Transformer中的作用
每个Transformer层除了注意力模块,还包含一个前馈神经网络(Feed-Forward Network):
\[\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\]FFN的核心作用:
- 非线性变换:为每个token独立应用非线性激活
- 特征提取:增加模型容量和表达能力
- 位置独立:与注意力不同,FFN对每个位置独立计算
5.2 FFN的计算特性
1
2
3
4
5
6
7
8
9
10
11
12
13
class FeedForward(nn.Module):
"""Transformer中的FFN模块"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.activation = nn.GELU() # 现代Transformer常用GELU
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(
self.linear2(self.activation(self.linear1(x)))
)
参数配置:
- $d_{model}$:通常为512、768、1024
- $d_{ff}$:通常是 $4 \times d_{model}$(如2048、3072、4096)
- 激活函数:ReLU(原始)→ GELU(现代BERT、GPT)
硬件视角:FFN本质是两个矩阵乘法(GEMM),与注意力计算占比约为 2:1。在AI加速器设计中,需要同时优化两种计算模式。
六、完整Transformer架构
6.1 编码器(Encoder)
结构组成
每个编码器层包含:
- 多头自注意力层(Multi-Head Self-Attention)
- 残差连接 + 层归一化(Add & Norm)
- FFN层
- 残差连接 + 层归一化(Add & Norm)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class EncoderLayer(nn.Module):
"""单层Transformer编码器"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力 + 残差归一化
attn_output, _ = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# FFN + 残差归一化
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
Layer Norm的作用:
- 稳定训练:使每层输入分布接近标准正态分布
- 加速收敛:减少Internal Covariate Shift
6.2 解码器(Decoder)
与编码器的区别
| 组件 | 编码器 | 解码器 |
|---|---|---|
| 自注意力 | ✅ 标准 | ✅ 带掩码(防止看到未来) |
| 交叉注意力 | ❌ | ✅ Q来自解码器,K/V来自编码器 |
| FFN | ✅ | ✅ |
掩码机制(MAsked Multi-Head Attention)
解码器的自注意力必须是因果的(Causal),即每个位置只能看到当前及之前的位置:
1
2
3
4
5
6
输入: "<start> A B C"
掩码:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
1
2
3
4
5
6
7
8
def create_causal_mask(seq_len):
"""创建因果掩码"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
# 使用时
mask = create_causal_mask(seq_len)
attn_output, _ = self.self_attn(Q, K, V, mask)
6.3 计算复杂度分析
这是AI Infra研究的核心关注点!
自注意力的复杂度
对于序列长度为 $n$、模型维度为 $d$ 的情况:
| 操作 | 复杂度 | 说明 |
|---|---|---|
| $QK^T$ | $O(n^2 \cdot d)$ | 最大的瓶颈! |
| Softmax | $O(n^2)$ | 与序列长度平方成正比 |
| 加权求和 | $O(n^2 \cdot d)$ | 同样需要 $n^2$ 操作 |
与RNN的对比
| 架构 | 时间复杂度 | 空间复杂度 | 可并行性 |
|---|---|---|---|
| RNN | $O(n \cdot d^2)$ | $O(d^2)$ | ❌ 顺序依赖 |
| Transformer | $O(n^2 \cdot d + n \cdot d^2)$ | $O(n^2 + d^2)$ | ✅ 完全并行 |
关键洞察:Transformer的计算量随序列长度呈二次方增长,这带来了巨大的优化空间和挑战,也是设计AI加速器时必须考虑的问题。
七、典型变体与应用
7.1 BERT:仅编码器模型
BERT(Bidirectional Encoder Representations from Transformers)仅使用Transformer编码器,核心创新是双向上下文建模。
预训练任务
- MLM(Masked Language Model):随机遮盖15%的token,让模型预测被遮盖的词
- NSP(Next Sentence Prediction):判断句子对是否连续
BERT系列
| 模型 | 参数规模 | 特点 |
|---|---|---|
| BERT-Base | 110M | 12层,768维,12头 |
| BERT-Large | 340M | 24层,1024维,16头 |
| RoBERTa | 125M | 去NSP,更大数据集 |
| ALBERT | 12M-235M | 参数共享,轻量化 |
7.2 GPT系列:仅解码器模型
GPT(Generative Pre-trained Transformer)仅使用Transformer解码器,核心是自回归语言建模。
与BERT的区别
| 特性 | BERT | GPT |
|---|---|---|
| 注意力方向 | 双向 | 单向(因果) |
| 预训练任务 | MLM + NSP | 语言建模 |
| 典型应用 | 理解任务 | 生成任务 |
| 生成能力 | 有限 | 强大 |
GPT演进
| 模型 | 年份 | 参数 | 关键创新 |
|---|---|---|---|
| GPT-1 | 2018 | 117M | 开创性工作 |
| GPT-2 | 2019 | 1.5B | 扩大规模,zero-shot |
| GPT-3 | 2020 | 175B | In-context learning |
| GPT-4 | 2023 | 未公开 | 多模态,RLHF |
7.3 Transformer在AI硬件加速中的角色
这部分与你的EDA研究直接相关!
计算特性总结
| 组件 | 主要计算 | 存储需求 |
|---|---|---|
| 自注意力 | $QK^T$、$softmax$、$WV$ | $O(n^2)$ 注意力矩阵 |
| FFN | 两次GEMM | $O(d^2)$ 权重矩阵 |
硬件优化方向
- 稀疏注意力:减少 $O(n^2)$ 的计算量(如Sparse Transformer)
- Flash Attention:IO感知的注意力计算优化
- 计算存内(CIM):在存储单元内完成矩阵乘法(如3D-CIMlet)
- 混合精度:FP16/BF16/INT8量化减少带宽
结合你的DAC2025论文《3D-CIMlet》,可以看到存算一体架构在加速Transformer推理中的潜力:RRAM CIM适合处理静态权重(Q/K/V投影),eDRAM适合动态激活值。
八、代码实战:完整Transformer实现
8.1 简化版Transformer
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import math
class SimpleTransformer(nn.Module):
"""简化版Transformer(编码器+解码器)"""
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len=5000):
super().__init__()
self.d_model = d_model
# 词嵌入 + 位置编码
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len)
# 编码器
self.encoder_layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
# 解码器
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)
])
# 输出层
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
# 编码器
enc_output = self.embedding(src) * math.sqrt(self.d_model)
enc_output = self.pos_encoding(enc_output)
for enc_layer in self.encoder_layers:
enc_output = enc_layer(enc_output, src_mask)
# 解码器
dec_output = self.embedding(tgt) * math.sqrt(self.d_model)
dec_output = self.pos_encoding(dec_output)
for dec_layer in self.decoder_layers:
dec_output = dec_layer(dec_output, enc_output, tgt_mask)
# 输出投影
output = self.fc(dec_output)
return output
8.2 使用示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 模型参数配置(以BERT-Base为例)
VOCAB_SIZE = 30522
D_MODEL = 768
NUM_HEADS = 12
NUM_LAYERS = 12
D_FF = 3072 # 4 * 768
# 创建模型
model = SimpleTransformer(
vocab_size=VOCAB_SIZE,
d_model=D_MODEL,
num_heads=NUM_HEADS,
num_layers=NUM_LAYERS,
d_ff=D_FF
)
# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}") # ~110M
九、总结
核心要点回顾
| 模块 | 关键公式/概念 | 计算特点 |
|---|---|---|
| 注意力机制 | $\text{softmax}(QK^T / \sqrt{d_k})V$ | $O(n^2 \cdot d)$,序列越长开销越大 |
| 多头注意力 | 多个头并行学习不同相关性 | 表达能力增强,计算量不变 |
| 位置编码 | Sinusoidal/RoPE等 | 注入序列顺序信息 |
| FFN | $\max(0, xW_1)W_2$ | $O(n \cdot d_{ff})$,通常 $d_{ff} = 4d$ |
| Layer Norm | $\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$ | 稳定训练,加速收敛 |
AI Infra视角的关键洞察
- 计算瓶颈:自注意力的 $O(n^2)$ 复杂度是主要瓶颈
- 存储需求:注意力矩阵需要 $O(n^2)$ 显存,长序列场景下不可忽视
- 并行潜力:Transformer天然支持并行,为硬件加速提供机会
- 量化友好:矩阵乘法天然适合低精度计算(INT8/FP16)
Transformer不仅是深度学习的里程碑,更是AI Infra研究的核心载体。理解其架构原理,是设计下一代AI加速器的前提。
📚 参考资料
- Vaswani et al. “Attention Is All You Need.” NeurIPS 2017
- Devlin et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” NAACL 2019
- Radford et al. “Language Models are Unsupervised Multitask Learners.” OpenAI Technical Report 2019
- Brown et al. “Language Models are Few-Shot Learners.” NeurIPS 2020
- Dao et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022