Post

Understanding Transformer Architecture: From Attention Mechanism to LLM Foundation

A systematic exploration of Transformer architecture — from the original "Attention is All You Need" paper to modern LLM foundations, with code implementations and hardware-aware insights for EDA researchers.

Understanding Transformer Architecture: From Attention Mechanism to LLM Foundation

Transformer架构是现代大语言模型(LLM)的基石,从2017年Google提出的”Attention is All You Need”开始,彻底改变了自然语言处理(NLP)领域。本文将从AI Infra视角系统性地解析Transformer的核心组件、计算特性及其在硬件加速中的意义。

一、背景与动机

1.1 序列建模的演进

在Transformer出现之前,序列建模主要依赖RNN(循环神经网络)及其变体LSTM和GRU。

RNN/LSTM的局限性

RNN的核心问题在于其顺序计算特性——当前隐藏状态依赖于前一时刻的隐藏状态,这导致:

  • 梯度消失/爆炸:长序列训练困难
  • 无法并行:时间步之间存在数据依赖
  • 长距离依赖衰减:信息在传递过程中逐渐丢失

即便LSTM通过门控机制部分缓解了梯度问题,但其本质仍是顺序计算,计算效率低下。

为什么需要注意力机制?

注意力机制(Attention Mechanism)最早应用于序列到序列(Seq2Seq)模型中,通过直接建模任意位置之间的依赖关系,解决了RNN的长距离依赖问题。

核心思想:“我应该关注输入序列中的哪些部分?”

注意力机制允许模型在处理每个位置时,都能”看到”输入序列的所有位置,并根据相关性动态分配权重。

1.2 Transformer的诞生

2017年,Google Brain团队发表了里程碑论文《Attention Is All You Need》,首次提出了完全基于注意力机制的Transformer架构:

  • 摒弃了RNN:全部使用注意力机制建模序列关系
  • 支持并行计算:大幅提升训练效率
  • 可扩展性强:为后续BERT、GPT等预训练模型奠定基础

为什么Transformer如此重要?对于AI Infra研究者而言,理解其计算特性是设计AI加速器(如CIM、Chiplet架构)的关键基础。


二、注意力机制详解

2.1 Query-Key-Value 三元组

注意力机制的核心是Query(查询)、Key(键)、Value(值)三个向量:

符号含义直观理解
QQuery“我在找什么信息?”
KKey“我有哪些信息?”
VValue“信息的实际内容是什么?”

对于输入序列中的每个token,我们用Query去”查询”所有位置的Key,计算与各位置的相关性(注意力权重),再用这些权重对Value进行加权求和。

2.2 缩放点积注意力(Scaled Dot-Product Attention)

Transformer使用的是缩放点积注意力,公式如下:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

计算步骤分解

  1. 计算点积:$QK^T$ — 计算Query与所有Key的相似度
  2. 缩放:$\div \sqrt{d_k}$ — 防止点积值过大导致梯度消失
  3. Softmax:归一化得到注意力权重(和为1)
  4. 加权求和:权重与Value相乘得到输出

为什么需要缩放?当 $d_k$ 较大时,点积的值会方差变大,导致Softmax输出趋于one-hot(梯度接近0)。除以 $\sqrt{d_k}$ 可以稳定梯度。

数学推导

假设 $q$ 和 $k$ 是独立随机变量,均值为0,方差为1,则点积 $q \cdot k$ 的均值为0,方差为 $d_k$。

缩放后: \(\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1\)

这确保了Softmax函数工作在梯度较稳定的区域。

2.3 代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    """缩放点积注意力实现"""
    def __init__(self):
        super().__init__()
    
    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: [batch_size, num_heads, seq_len, d_k]
            K: [batch_size, num_heads, seq_len, d_k]
            V: [batch_size, num_heads, seq_len, d_v]
            mask: [batch_size, 1, seq_len, seq_len] - 可选掩码
        Returns:
            output: [batch_size, num_heads, seq_len, d_v]
            attn_weights: [batch_size, num_heads, seq_len, seq_len]
        """
        d_k = Q.size(-1)
        
        # 1. 计算点积并缩放
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        
        # 2. 应用掩码(如果提供)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 3. Softmax得到注意力权重
        attn_weights = F.softmax(scores, dim=-1)
        
        # 4. 加权求和
        output = torch.matmul(attn_weights, V)
        
        return output, attn_weights

在AI硬件加速中,这个运算的核心计算是矩阵乘法 $QK^T$ 和后续的矩阵乘法,这正是CIM(计算存内)架构擅长加速的场景。


三、多头注意力(Multi-Head Attention)

3.1 为什么需要多头?

单头注意力有一个局限:所有注意力头只能学习同一种类型的相关性

多头注意力的核心思想:将Q、K、V分别投影到多个低维空间,每个空间独立学习不同类型的相关性,然后拼接输出。

数学表达: \(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\)

其中每个 $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$

3.2 维度分析

参数含义
$d_{model}$模型维度(通常为512、768、1024等)
$h$注意力头数量(通常为8、12、16)
$d_k = d_v = d_{model} / h$每个头的维度

计算复杂度对比

  • 单头注意力:$O(d_{model}^2)$
  • 多头注意力:仍然是 $O(d_{model}^2)$
  • 多头只是将计算”分块”,总计算量不变,但表达能力增强

3.3 代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class MultiHeadAttention(nn.Module):
    """多头注意力实现"""
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # 线性投影层
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention()
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # 1. 线性投影 + 分头
        Q = self.W_Q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 2. 计算注意力
        x, attn_weights = self.attention(Q, K, V, mask)
        
        # 3. 拼接多头输出
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 4. 最终线性投影
        output = self.W_O(x)
        
        return output, attn_weights

硬件视角:多头注意力中各头之间是独立的,这意味着在硬件层面可以实现并行计算,是设计AI加速器时的重要优化点。


四、位置编码(Positional Encoding)

4.1 为什么需要位置信息?

Transformer本身是排列不变(permutation invariant)的——输入序列 [A, B, C][C, B, A] 在Transformer看来没有区别,因为自注意力只建模token之间的关系,不建模位置关系。

为了让模型感知序列顺序,需要人为注入位置信息,这就是位置编码(Positional Encoding)。

4.2 三角函数位置编码

原始Transformer使用正弦和余弦函数生成位置编码:

\(PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)\) \(PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)\)

其中 $pos$ 是位置,$i$ 是维度索引。

性质

性质含义
周期性不同频率的正弦/余弦波
唯一性每个位置有唯一的编码向量
相对距离$PE(pos+k)$ 可以表示为 $PE(pos)$ 的线性函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class PositionalEncoding(nn.Module):
    """位置编码实现"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * 
            (-math.log(10000.0) / d_model)
        )
        
        # 偶数维度用sin,奇数维度用cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # 添加batch维度并注册为buffer(不参与训练)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        # 将位置编码加到输入上
        return x + self.pe[:, :x.size(1), :]

4.3 现代位置编码方案

方案提出时间特点
Sinusoidal PE2017原始Transformer使用,固定不变
Learnable PE2018BERT等采用,可学习参数
RoPE2021旋转位置编码,LLM常用(如LLaMA)
ALiBi2022无需显式位置编码,外推性好

RoPE(Rotary Position Embedding) 是现代LLM(如LLaMA)的标配,它通过旋转操作将位置信息融入Query和Key向量,而非简单加法。这种方式具有更好的长度外推性。


五、前馈神经网络(FFN)

5.1 FFN在Transformer中的作用

每个Transformer层除了注意力模块,还包含一个前馈神经网络(Feed-Forward Network):

\[\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\]

FFN的核心作用:

  • 非线性变换:为每个token独立应用非线性激活
  • 特征提取:增加模型容量和表达能力
  • 位置独立:与注意力不同,FFN对每个位置独立计算

5.2 FFN的计算特性

1
2
3
4
5
6
7
8
9
10
11
12
13
class FeedForward(nn.Module):
    """Transformer中的FFN模块"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.GELU()  # 现代Transformer常用GELU
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.dropout(
            self.linear2(self.activation(self.linear1(x)))
        )

参数配置

  • $d_{model}$:通常为512、768、1024
  • $d_{ff}$:通常是 $4 \times d_{model}$(如2048、3072、4096)
  • 激活函数:ReLU(原始)→ GELU(现代BERT、GPT)

硬件视角:FFN本质是两个矩阵乘法(GEMM),与注意力计算占比约为 2:1。在AI加速器设计中,需要同时优化两种计算模式。


六、完整Transformer架构

6.1 编码器(Encoder)

结构组成

每个编码器层包含:

  1. 多头自注意力层(Multi-Head Self-Attention)
  2. 残差连接 + 层归一化(Add & Norm)
  3. FFN层
  4. 残差连接 + 层归一化(Add & Norm)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class EncoderLayer(nn.Module):
    """单层Transformer编码器"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # 自注意力 + 残差归一化
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # FFN + 残差归一化
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        
        return x

Layer Norm的作用

  • 稳定训练:使每层输入分布接近标准正态分布
  • 加速收敛:减少Internal Covariate Shift

6.2 解码器(Decoder)

与编码器的区别

组件编码器解码器
自注意力✅ 标准带掩码(防止看到未来)
交叉注意力✅ Q来自解码器,K/V来自编码器
FFN

掩码机制(MAsked Multi-Head Attention)

解码器的自注意力必须是因果的(Causal),即每个位置只能看到当前及之前的位置:

1
2
3
4
5
6
输入: "<start> A B C"
掩码:
[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]
1
2
3
4
5
6
7
8
def create_causal_mask(seq_len):
    """创建因果掩码"""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

# 使用时
mask = create_causal_mask(seq_len)
attn_output, _ = self.self_attn(Q, K, V, mask)

6.3 计算复杂度分析

这是AI Infra研究的核心关注点!

自注意力的复杂度

对于序列长度为 $n$、模型维度为 $d$ 的情况:

操作复杂度说明
$QK^T$$O(n^2 \cdot d)$最大的瓶颈!
Softmax$O(n^2)$与序列长度平方成正比
加权求和$O(n^2 \cdot d)$同样需要 $n^2$ 操作

与RNN的对比

架构时间复杂度空间复杂度可并行性
RNN$O(n \cdot d^2)$$O(d^2)$❌ 顺序依赖
Transformer$O(n^2 \cdot d + n \cdot d^2)$$O(n^2 + d^2)$✅ 完全并行

关键洞察:Transformer的计算量随序列长度呈二次方增长,这带来了巨大的优化空间和挑战,也是设计AI加速器时必须考虑的问题。


七、典型变体与应用

7.1 BERT:仅编码器模型

BERT(Bidirectional Encoder Representations from Transformers)仅使用Transformer编码器,核心创新是双向上下文建模

预训练任务

  1. MLM(Masked Language Model):随机遮盖15%的token,让模型预测被遮盖的词
  2. NSP(Next Sentence Prediction):判断句子对是否连续

BERT系列

模型参数规模特点
BERT-Base110M12层,768维,12头
BERT-Large340M24层,1024维,16头
RoBERTa125M去NSP,更大数据集
ALBERT12M-235M参数共享,轻量化

7.2 GPT系列:仅解码器模型

GPT(Generative Pre-trained Transformer)仅使用Transformer解码器,核心是自回归语言建模

与BERT的区别

特性BERTGPT
注意力方向双向单向(因果)
预训练任务MLM + NSP语言建模
典型应用理解任务生成任务
生成能力有限强大

GPT演进

模型年份参数关键创新
GPT-12018117M开创性工作
GPT-220191.5B扩大规模,zero-shot
GPT-32020175BIn-context learning
GPT-42023未公开多模态,RLHF

7.3 Transformer在AI硬件加速中的角色

这部分与你的EDA研究直接相关!

计算特性总结

组件主要计算存储需求
自注意力$QK^T$、$softmax$、$WV$$O(n^2)$ 注意力矩阵
FFN两次GEMM$O(d^2)$ 权重矩阵

硬件优化方向

  1. 稀疏注意力:减少 $O(n^2)$ 的计算量(如Sparse Transformer)
  2. Flash Attention:IO感知的注意力计算优化
  3. 计算存内(CIM):在存储单元内完成矩阵乘法(如3D-CIMlet)
  4. 混合精度:FP16/BF16/INT8量化减少带宽

结合你的DAC2025论文《3D-CIMlet》,可以看到存算一体架构在加速Transformer推理中的潜力:RRAM CIM适合处理静态权重(Q/K/V投影),eDRAM适合动态激活值。


八、代码实战:完整Transformer实现

8.1 简化版Transformer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import math

class SimpleTransformer(nn.Module):
    """简化版Transformer(编码器+解码器)"""
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len=5000):
        super().__init__()
        
        self.d_model = d_model
        
        # 词嵌入 + 位置编码
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # 编码器
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff) 
            for _ in range(num_layers)
        ])
        
        # 解码器
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff) 
            for _ in range(num_layers)
        ])
        
        # 输出层
        self.fc = nn.Linear(d_model, vocab_size)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # 编码器
        enc_output = self.embedding(src) * math.sqrt(self.d_model)
        enc_output = self.pos_encoding(enc_output)
        
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
        
        # 解码器
        dec_output = self.embedding(tgt) * math.sqrt(self.d_model)
        dec_output = self.pos_encoding(dec_output)
        
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, tgt_mask)
        
        # 输出投影
        output = self.fc(dec_output)
        
        return output

8.2 使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 模型参数配置(以BERT-Base为例)
VOCAB_SIZE = 30522
D_MODEL = 768
NUM_HEADS = 12
NUM_LAYERS = 12
D_FF = 3072  # 4 * 768

# 创建模型
model = SimpleTransformer(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    d_ff=D_FF
)

# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # ~110M

九、总结

核心要点回顾

模块关键公式/概念计算特点
注意力机制$\text{softmax}(QK^T / \sqrt{d_k})V$$O(n^2 \cdot d)$,序列越长开销越大
多头注意力多个头并行学习不同相关性表达能力增强,计算量不变
位置编码Sinusoidal/RoPE等注入序列顺序信息
FFN$\max(0, xW_1)W_2$$O(n \cdot d_{ff})$,通常 $d_{ff} = 4d$
Layer Norm$\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$稳定训练,加速收敛

AI Infra视角的关键洞察

  1. 计算瓶颈:自注意力的 $O(n^2)$ 复杂度是主要瓶颈
  2. 存储需求:注意力矩阵需要 $O(n^2)$ 显存,长序列场景下不可忽视
  3. 并行潜力:Transformer天然支持并行,为硬件加速提供机会
  4. 量化友好:矩阵乘法天然适合低精度计算(INT8/FP16)

Transformer不仅是深度学习的里程碑,更是AI Infra研究的核心载体。理解其架构原理,是设计下一代AI加速器的前提。


📚 参考资料

  1. Vaswani et al. “Attention Is All You Need.” NeurIPS 2017
  2. Devlin et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” NAACL 2019
  3. Radford et al. “Language Models are Unsupervised Multitask Learners.” OpenAI Technical Report 2019
  4. Brown et al. “Language Models are Few-Shot Learners.” NeurIPS 2020
  5. Dao et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022
This post is licensed under CC BY 4.0 by the author.