CS336 Lecture Notes 2

6309 字
32 分钟
CS336 Lecture Notes 2
Warning

本文为本人学习相关开源课程过程中,整理的个人学习笔记及作业解答,核心目的仅用于记录个人学习轨迹、巩固所学知识、梳理学习思路,全程为个人自主学习使用,不具备任何商业用途,也不构成任何形式的课程辅导或标准答案参考。

需特别说明的是,由于本人学习进度及知识储备有限,笔记内容及作业解答中可能存在大量纰漏、思路偏差甚至错误,仅代表本人当时的学习理解,不具备权威性和准确性。

在此郑重提醒:请勿将本文中的任何作业解答复制粘贴,作为自身所修课程的提交答案。任何因抄袭本文内容导致的课程成绩问题、学术诚信问题,均由抄袭者自行承担全部责任,本人不承担任何相关连带责任。

同时,本文所分享的内容均基于开源课程的公开内容整理,尊重原课程创作者的知识产权,若涉及相关内容的版权问题,请及时联系本人,本人将第一时间进行调整或删除。

感谢各位读者的理解与支持,也欢迎大家针对笔记及解答中的问题提出宝贵建议,共同交流学习、共同进步。

Important

Architectures, hyperparameters#

本节内容基于 Lecture 3,主题为 “Everything You Didn’t Want to Know About LM Architecture and Hyperparameters”,系统性地介绍了现代大语言模型的架构设计选择和超参数配置共识。

Transformer 基础与现代变体#

原始 Transformer

  • 位置编码:正弦余弦
PE(pos,2i)=sin(pos100002idmodel)PE(pos,2i+1)=cos(pos100002idmodel)PE_{(pos, 2i)} = \sin\left( \frac{pos}{10000^{\frac{2i}{d_{model}}}} \right) \\ PE_{(pos, 2i+1)} = \cos\left( \frac{pos}{10000^{\frac{2i}{d_{model}}}} \right)
  • FFN 激活:ReLU
FFN(x)=max(0,xW1+b1)W2+b2\operatorname{FFN}(x) = \max(0, x W_1 + b_1) W_2 + b_2
  • 归一化:Post-LayerNorm

现代简化变体(CS336 课程实现)

  • 归一化:Pre-LayerNorm
  • 位置编码:RoPE
  • FFN 激活:SwiGLU
  • 线性层与 Norm:无偏置项

Pre-norm vs Post-norm#

现代 LLM 几乎全部采用 Pre-norm(LayerNorm 放在 block 前面),而非原始 Transformer 的 Post-norm。

核心原因

  • Pre-norm 保持残差连接的主信号路径不被 LayerNorm 打断
  • 梯度传播更平滑,避免梯度衰减和梯度尖峰问题
  • 支持更大的学习率和更稳定的训练

一些近期模型(如 Grok、Gemma 2、OLMo 2)还会在残差流外部添加额外的 post-norm。

LayerNorm vs RMSNorm#

原始 Transformer 使用 LayerNorm,对特征维度(d_model)上的激活值,做均值归零 + 方差归一化,再加上可学习的缩放(γ\gamma)和偏置(β\beta)参数。 GPT-1/2/3、OPT、GPT-J、BLOOM 等早期 / 开源模型,就是使用的 LayerNorm。

y=xE[x]Var[x]+ϵγ+βy = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta

现代的 LLM 普遍采用 RMSNorm不减去均值、不添加偏置项,只做基于均方根(RMS)的归一化,再乘上缩放参数 γ\gamma。 LLaMA 系列、PaLM、Chinchilla、T5 等主流模型就是采用的 RMSNorm。

y=xx22+ϵγy = \frac{x}{\sqrt{\|x\|_2^2 + \epsilon}} \cdot \gamma

优势

  • 不计算均值,操作更少
  • 没有 bias 参数,存储更少
  • 实测显示:FLOPs 不是唯一考量,数据移动(memory access)同样重要,RMSNorm 在实际运行时间上有优势

更广泛地,现代 transformer 大多数去掉 bias 项,原因同样是内存和优化稳定性。

激活函数:从 ReLU 到 SwiGLU#

激活函数公式代表模型
ReLUmax(0,xW1)W2\max(0, xW_1)W_2原始 Transformer, T5, Gopher, Chinchilla, OPT
GeLUGELU(xW1)W2\text{GELU}(xW_1)W_2GPT-1/2/3, GPTJ, GPT-Neox, BLOOM
GeGLU(GELU(xW1)xV)W2(\text{GELU}(xW_1) \otimes xV)W_2T5 v1.1, mT5, LaMDA, Phi3, Gemma 2/3/4
SwiGLU(Swish(xW1)xV)W2(\text{Swish}(xW_1) \otimes xV)W_2LLaMA 1/2/3, PaLM, Mistral, OlMo, 大多数 2023+ 模型

GLU(Gated Linear Units) 的核心思想是将线性层改为门控形式:

FFReGLU(x)=(max(0,xW1)xV)W2\text{FF}_{\text{ReGLU}}(x) = (\max(0, xW_1) \otimes xV)W_2

其中 VV 是额外的参数矩阵。实证研究表明 SwiGLU/GeGLU 有稳定的性能提升。

注意:GLU 变体的 FFN 维度通常缩小为 83dmodel\frac{8}{3}d_{\text{model}},而非标准的 4dmodel4d_{\text{model}}

位置编码:RoPE 的崛起#

正弦位置编码(Sine Embeddings):用固定的、非可学习的正弦 / 余弦函数生成位置信息,和词向量相加。

代表模型:原始 Transformer。

Embed(x,i)=vx+PEpos\text{Embed}(x, i) = v_x + PE_{pos}\\PE(pos,2i)=sin(pos100002idmodel)PE(pos,2i+1)=cos(pos100002idmodel)\begin{aligned} PE_{(pos, 2i)} &= \sin\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right) \\ PE_{(pos, 2i+1)} &= \cos\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right) \end{aligned}

绝对位置嵌入(Absolute Embeddings):为每个位置单独训练一个可学习的向量,直接加到词向量上。

代表模型:GPT-1/2/3、OPT。

Embed(x,i)=vx+ui\text{Embed}(x, i) = v_x + u_i

相对位置嵌入(Relative Embeddings):不直接给每个位置编码,而是把 “相对位置差” 作为偏置,加到注意力分数的计算中。

代表模型:T5、Gopher、Chinchilla。

eij=xiWQ(xjWK+aijK)Tdze_{ij} = \frac{x_i W^Q (x_j W^K + a_{ij}^K)^T}{\sqrt{d_z}}

旋转位置编码(RoPE Embeddings)#

代表模型:GPT-J、PaLM、LLaMA 系列、绝大多数 2024 年后的模型。

RoPE 提出了一个理想的相对位置编码应该满足的条件:

f(x,i),f(y,j)=g(x,y,ij)\langle f(x, i), f(y, j) \rangle = g(x, y, i-j)
  • 左边:是位置 ii 的词向量 xx 和位置 jj 的词向量 yy,经过位置编码函数 ff 后的内积(也就是注意力计算的核心)。
  • 右边:这个内积的结果,应该只和两个词的内容 xxyy,以及它们的相对位置差 iji−j 有关,而和它们的绝对位置 iijj 无关。

简单说:注意力分数只由 “两个词离得有多远” 决定,而不是它们在序列里的绝对位置,这样模型才能真正学到 “相对位置关系”,并且能外推到更长的序列。

现有方法的缺陷

  • Sine:它的内积展开式里面的PEi,vy\langle PE_i​,v_y \rangle​vx,PEj\langle v_x​,PE_j \rangle 会引入和绝对位置 iijj 相关的项,导致内积不仅和 iji−j 有关,还和绝对位置有关,破坏了纯相对位置的特性。
Embed(x,i),Embed(y,j)=vx,vy+PEi,vy+vx,PEj+PEi,PEj\langle \text{Embed}(x,i), \text{Embed}(y,j) \rangle = \langle v_x, v_y \rangle + \langle PE_i, v_y \rangle + \langle v_x, PE_j \rangle + \langle PE_i, PE_j \rangle
  • Absolute:直接给每个位置加一个可学习的向量,内积结果会直接依赖于 iijj,完全不具备 “相对位置” 的性质。而且无法外推到训练时没见过的序列长度。

  • Relative embeddingseij=xiWQ(xjWK+aijK)Tdze_{ij} = \frac{x_i W^Q (x_j W^K + a_{ij}^K)^T}{\sqrt{d_z}} 让注意力计算不再是纯的内积形式

RoPE 要解决的问题,就是让注意力计算只依赖相对位置差,而不是绝对位置。它的灵感来自一个关键数学性质:

向量的内积,在整体旋转下是不变的。

也就是说,两个向量一起旋转相同角度,它们的夹角和内积都不会变。

RoPE 就是利用这个性质,给不同位置的向量加上 “旋转角度”,让内积结果天然只和位置差有关。

超参数共识#

FFN 维度比例#

几乎所有模型遵循 dff=4dmodeld_{\text{ff}} = 4 \cdot d_{\text{model}}(GLU 变体为 83dmodel\frac{8}{3} \cdot d_{\text{model}})。

Modeldff/dmodeld_{\text{ff}}/d_{\text{model}}
标准 Transformer4
PaLM4
Mistral 7B3.5
LLaMA-2 70B3.5
LLaMA 70B2.68
Qwen 14B2.67
DeepSeek 67B2.68
Yi 34B2.85
T5 v1.12.5

极端案例:T5 11B 曾用 64dmodel64 \cdot d_{\text{model}},但后续 T5 v1.1 已回归标准值。

Head 维度#

主流选择:head_dim×num_heads=dmodel\text{head\_dim} \times \text{num\_heads} = d_{\text{model}}

大多数模型这个比例为 1,少数 Google 模型(如 LaMDA, PaLM)有例外。

ModelNum headsHead dimModel dimRatio head_dim×num_heads/dmodel\text{head\_dim} \times \text{num\_heads} / d_{\text{model}}
GPT396128122881
T5128128102416
T5 v1.1646440961
LaMDA12812881922
PaLM48258184321.48
LLaMA26412881921
Qwen 3.5 (27B)2425651201.2

宽深比 (Aspect Ratio)#

模型应该变宽还是变深?多宽以及多深?

大多数模型的 dmodel/nlayersd_{\text{model}} / n_{\text{layers}}100-200 范围:

Modeldmodel/nlayerd_{model}/n_{layer}
BLOOM205
T5 v1.1171
PaLM (540B)156
GPT-3/OPT/Mistral/Qwen/OLMo 3128
LLaMA/LLaMA2102
Gamma 387
Gamma 461
T5 11B33 (极端深)

过深的模型并行化困难、延迟高,但 [Kaplan et al 2020] 显示小模型时深一点更好。

词汇表大小#

  • 单语言模型:30k-50k
  • 多语言/生产模型:100k-250k(GPT-4 ~100k, Gemma 4 ~262k)

Dropout 与 Weight Decay#

趋势:新模型大多不用 Dropout,仅靠 Weight Decay。

原因:数据量大(万亿 tokens),单遍 SGD 不易过拟合;Weight Decay 主要影响优化动力学而非泛化。

稳定性技巧#

训练过程最好不要像蓝色曲线那样
训练过程最好不要像蓝色曲线那样

训练大模型时,loss spike 是常见问题。以下技巧用于稳定训练:

Z-loss#

Transformer 语言模型的输出层用 Softmax 计算 token 概率:

P(x)=eUr(x)Z(x)P(x) = \frac{e^{U_r(x)}}{Z(x)}

其中归一化项(也叫配分函数)是:

Z(x)=r=1VeUr(x)Z(x) = \sum_{r'=1}^{|V|} e^{U_{r'}(x)}

取对数后,我们得到模型的对数概率:

log(P(x))=Ur(x)log(Z(x))\log(P(x)) = U_r(x) - \log(Z(x))

训练时,我们优化的目标就是最大化这个对数概率(交叉熵损失,等价于最小化负对数概率)。

随着模型变大、logits(也就是 Ur(x)U_r​(x))的数值范围变宽,Z(x)Z(x) 会出现两种极端情况:

  1. logits 太大:指数项 eUre^{U_r}​ 爆炸,Z(x)Z(x) 变得极大,导致 log(Z(x))log(Z(x)) 也变得非常大,log(P(x))log(P(x)) 被压得很小,梯度消失。
  2. logits 太小:指数项 eUre^{U_r}​ 接近 0,Z(x)Z(x) 接近 1,log(Z(x))log(Z(x)) 接近 0,但此时模型可能过于自信,也容易出现梯度问题。

log(Z(x))log(Z(x)) 出现剧烈波动时,损失和梯度都会变得不稳定,导致训练过程震荡甚至崩溃。

为了解决这个问题,我们在交叉熵损失上额外加一个正则项,也就是 Z-loss:

L=i[log(P(xi))αlog2(Z(xi))]L = \sum_i \left[ \log(P(x_i)) - \alpha \log^2(Z(x_i)) \right]

PaLM 首次大规模验证了 Z-loss 的有效性,后续很多模型,如 Baichuan 2、DCLM、OLMo 2/3 等都沿用了这个做法。

QK-norm#

在 attention softmax 前对 Query 和 Key 做 LayerNorm/RMSNorm。

标准自注意力公式:

Attention(Q,K,V)=softmax(QKdk)V\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V

QK Norm 定义(对 Q 和 K 分别归一化):

Q^=Norm(Q),K^=Norm(K)\hat{Q} = \mathrm{Norm}(Q), \quad \hat{K} = \mathrm{Norm}(K)

加入 QK Norm 的最终注意力公式:

Attention(Q,K,V)=softmax(Q^K^dk)V\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{\hat{Q} \hat{K}^\top}{\sqrt{d_k}} \right) V

使用模型:DCLM, OLMo 2, Gemma 2, Qwen 3。

Logit Soft-capping#

用 tanh 限制 logits 范围:

logits=soft_captanh(logits/soft_cap)\text{logits} = \text{soft\_cap} \cdot \tanh(\text{logits}/\text{soft\_cap})
  1. 先把 logits 除以一个缩放因子 soft_cap(比如 50 或 30);
  2. 再用 tanh 函数做非线性变换,将值限制在 [-1, 1] 之间;
  3. 最后乘回 soft_cap,最终 logits 就被限制在 [-soft_cap, soft_cap] 之间。

注意力优化#

默认情况下,LLM 都用Multi-Head Attention(MHA),每个 token 生成 num_heads 个独立的 Query/Key/Value 头,所有的头的维度加起来等于d_model。 这种设计是通用的、简单且有效的,所以绝大多数模型都没改动。

GQA/MQA:降低推理成本#

推理时的增量计算依赖 KV Cache,内存访问量大。MQA(Multi-Query Attention)和 GQA(Grouped-Query Attention)通过减少 KV heads 来降低内存开销。

问题背景:Attention 的计算与内存特性#

https://medium.com/@joaolages/kv-caching-explained-276520203249
https://medium.com/@joaolages/kv-caching-explained-276520203249

https://medium.com/@joaolages/kv-caching-explained-276520203249
https://medium.com/@joaolages/kv-caching-explained-276520203249

在自回归推理中,每次生成一个新 token 时需要访问之前所有 token 的 Key 和 Value。为了理解为什么 KV Cache 是瓶颈,我们需要分析 Attention 的 FLOPs(浮点运算次数)内存访问量,计算 Arithmetic Intensity(算术强度)

推理时的 FLOPs 估计

生成第 tt 个 token 时,主要张量的形状:

  • 当前 Query QQRh×dh\mathbb{R}^{h \times d_h}hh 个头,每头 dhd_h 维)
  • 历史 Key Cache KKRh×t×dh\mathbb{R}^{h \times t \times d_h}hh 个头,tt 个历史位置,每头 dhd_h 维)
  • 历史 Value Cache VVRh×t×dh\mathbb{R}^{h \times t \times d_h}

Attention 的主要计算:

  1. Query-Key 点积Q×KQ \times K^\top 得到注意力分数矩阵

    • 形状变化:[h,dh]×[h,dh,t][h,t][h, d_h] \times [h, d_h, t]^\top \rightarrow [h, t]
    • 每个头:[dh]×[t,dh]=[t][d_h] \times [t, d_h]^\top = [t],即 dhd_h 维向量与 ttdhd_h 维向量做点积
    • FLOPs:t×dht \times d_h 次乘法 + t×(dh1)t \times (d_h-1) 次加法 2tdh\approx 2td_h
    • hh 个头总计:2htdh2htd_h
  2. Attention 加权 Value:用注意力分数 [h,t][h, t] 加权历史 Value [h,t,dh][h, t, d_h]

    • 形状变化:[h,t]×[h,t,dh][h,dh][h, t] \times [h, t, d_h] \rightarrow [h, d_h]
    • 每个头:[t][t] 个权重加权 [t,dh][t, d_h] 的 Value,得到 [dh][d_h]
    • FLOPs:t×dht \times d_h 次乘法 + t×(dh1)t \times (d_h-1) 次加法 2tdh\approx 2td_h
    • hh 个头总计:2htdh2htd_h

单层 Attention 的总 FLOPs(不含投影):

FLOPsattn4htdh\text{FLOPs}_{\text{attn}} \approx 4htd_h

推理时的内存访问量

设数据类型为 dtype(如 fp16、fp32),字节大小为 sizeof(dtype)\text{sizeof}(\text{dtype})。生成第 tt 个 token 时:

操作张量形状访问量(bytes)
读取当前 Query[h,dh][h, d_h]hdhsizeof(dtype)h \cdot d_h \cdot \text{sizeof}(\text{dtype})
读取历史 Key Cache[h,t,dh][h, t, d_h]htdhsizeof(dtype)h \cdot t \cdot d_h \cdot \text{sizeof}(\text{dtype})
读取历史 Value Cache[h,t,dh][h, t, d_h]htdhsizeof(dtype)h \cdot t \cdot d_h \cdot \text{sizeof}(\text{dtype})
写入当前 K/V 到 Cache[h,2,dh][h, 2, d_h]2hdhsizeof(dtype)2 \cdot h \cdot d_h \cdot \text{sizeof}(\text{dtype})
写入输出[h,dh][h, d_h]hdhsizeof(dtype)h \cdot d_h \cdot \text{sizeof}(\text{dtype})

总内存访问量:

Memoryaccess=(2ht+4h)dhsizeof(dtype)\text{Memory}_{\text{access}} = (2ht + 4h) \cdot d_h \cdot \text{sizeof}(\text{dtype})

主要开销来自读取历史 KV Cache:2htdhsizeof(dtype)2ht \cdot d_h \cdot \text{sizeof}(\text{dtype})

Arithmetic Intensity(算术强度)

算术强度定义为 FLOPs 与内存访问量的比值:

Arithmetic Intensity=FLOPsBytes accessed\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes accessed}}

代入上述估计:

I=4htdh(2ht+4h)dhsizeof(dtype)=4t(2t+4)sizeof(dtype)I = \frac{4htd_h}{(2ht+4h) \cdot d_h \cdot \text{sizeof}(\text{dtype})} = \frac{4t}{(2t+4) \cdot \text{sizeof}(\text{dtype})}

当序列长度 tt 较大时:

I2sizeof(dtype) FLOPs/byteI \rightarrow \frac{2}{\text{sizeof}(\text{dtype})} \text{ FLOPs/byte}

对于 fp16,I1I \rightarrow 1 FLOP/byte。

关键洞察

典型 GPU 的峰值算力与内存带宽对比:

  • NVIDIA A100:峰值 312 TFLOPS (fp16),带宽 2 TB/s → 理论峰值强度 156 FLOPs/byte
  • NVIDIA RTX 4090:峰值 83 TFLOPS (fp16),带宽 1 TB/s → 理论峰值强度 83 FLOPs/byte

Attention 推理的算术强度 1\approx 1 FLOPs/byte,远低于 GPU 峰值强度,属于 memory-bound(内存受限) 操作:计算单元大部分时间在等待数据加载,而非真正执行计算。

这就是 KV Cache 成为推理瓶颈的根本原因——降低内存访问量比优化计算更关键

MQA:极端压缩方案#

Multi-Query Attention(MQA)由 Shazeer 2019 提出,核心思想是所有 Query 头共享同一个 Key 和 Value 头

MQA 的 FLOPs 估计

MQA 的计算量与 MHA 相同,因为每个 Query 头仍需与 Key 做点积、加权 Value:

  • QK 点积:2htdh2htd_h FLOPs
  • Attention 加权:2htdh2htd_h FLOPs
  • 总计:4htdh4htd_h FLOPs

MQA 的内存访问量

MQA 的关键变化:Key 和 Value 只有1个头(而非 h 个头)。

设数据类型为 dtype,字节大小为 sizeof(dtype)\text{sizeof}(\text{dtype})。生成第 tt 个 token 时:

操作张量形状访问量(bytes)
读取当前 Query[h,dh][h, d_h]hdhsizeof(dtype)h \cdot d_h \cdot \text{sizeof}(\text{dtype})
读取历史 Key Cache[1,t,dh][1, t, d_h]tdhsizeof(dtype)t \cdot d_h \cdot \text{sizeof}(\text{dtype})
读取历史 Value Cache[1,t,dh][1, t, d_h]tdhsizeof(dtype)t \cdot d_h \cdot \text{sizeof}(\text{dtype})
写入当前 K/V 到 Cache[1,2,dh][1, 2, d_h]2dhsizeof(dtype)2 \cdot d_h \cdot \text{sizeof}(\text{dtype})
写入输出[h,dh][h, d_h]hdhsizeof(dtype)h \cdot d_h \cdot \text{sizeof}(\text{dtype})

总内存访问量:

MemoryMQA=(2h+2t+2)dhsizeof(dtype)\text{Memory}_{\text{MQA}} = (2h + 2t + 2) \cdot d_h \cdot \text{sizeof}(\text{dtype})

主要开销来自读取历史 KV Cache:2tdhsizeof(dtype)2t \cdot d_h \cdot \text{sizeof}(\text{dtype}),相比 MHA 的 2htdhsizeof(dtype)2ht \cdot d_h \cdot \text{sizeof}(\text{dtype}) 减少 h 倍

MQA 的 Arithmetic Intensity

FLOPs 与 MHA 相同(4htdh4htd_h),内存访问量大幅减少:

IMQA=4htdh(2h+2t+2)dhsizeof(dtype)I_{\text{MQA}} = \frac{4htd_h}{(2h + 2t + 2) \cdot d_h \cdot \text{sizeof}(\text{dtype})}

当序列长度 tt 较大时,2h+22h + 2 可忽略:

IMQA4htdh2tdhsizeof(dtype)=2hsizeof(dtype)I_{\text{MQA}} \approx \frac{4htd_h}{2td_h \cdot \text{sizeof}(\text{dtype})} = \frac{2h}{\text{sizeof}(\text{dtype})}

以 fp16 为例(sizeof(dtype)=2\text{sizeof}(\text{dtype}) = 2):

IMQAh FLOPs/byteI_{\text{MQA}} \approx h \text{ FLOPs/byte}

对比 MHA 的 IMHA2sizeof(dtype)I_{\text{MHA}} \approx \frac{2}{\text{sizeof}(\text{dtype})}(fp16 下约为 1 FLOP/byte),MQA 将算术强度提升 h 倍

以 LLaMA 2 70B 为例(h=64, fp16):MHA ≈ 1 FLOP/byte,MQA ≈ 64 FLOPs/byte,显著更接近 GPU 峰值强度(A100: 156 FLOPs/byte),从 memory-bound 向 compute-bound 过渡。

代价:多个 Query 头被迫使用相同的 K/V 表示,表达能力受限。实验显示 MQA 在某些任务上性能下降明显。

GQA:折中方案#

Grouped-Query Attention(GQA)由 Ainslie et al. 2023 提出,将 hh 个 Query 头分成 gg 组,每组共享一个 KV 头。典型取值 g = 8 或 g = 4。

GQA 的 FLOPs 估计

与 MHA、MQA 相同,FLOPs 不受 KV 头数影响。

GQA 的内存访问量

GQA 的关键变化:Key 和 Value 有 g 个头(介于 MHA 的 h 个和 MQA 的 1 个之间)。

设数据类型为 dtype,字节大小为 sizeof(dtype)\text{sizeof}(\text{dtype})。生成第 tt 个 token 时:

操作张量形状访问量(bytes)
读取当前 Query[h,dh][h, d_h]hdhsizeof(dtype)h \cdot d_h \cdot \text{sizeof}(\text{dtype})
读取历史 Key Cache[g,t,dh][g, t, d_h]gtdhsizeof(dtype)g \cdot t \cdot d_h \cdot \text{sizeof}(\text{dtype})
读取历史 Value Cache[g,t,dh][g, t, d_h]gtdhsizeof(dtype)g \cdot t \cdot d_h \cdot \text{sizeof}(\text{dtype})
写入当前 K/V 到 Cache[g,2,dh][g, 2, d_h]2gdhsizeof(dtype)2 \cdot g \cdot d_h \cdot \text{sizeof}(\text{dtype})
写入输出[h,dh][h, d_h]hdhsizeof(dtype)h \cdot d_h \cdot \text{sizeof}(\text{dtype})

总内存访问量:

MemoryGQA=(2h+2gt+2g)dhsizeof(dtype)\text{Memory}_{\text{GQA}} = (2h + 2gt + 2g) \cdot d_h \cdot \text{sizeof}(\text{dtype})

主要开销来自读取历史 KV Cache:2gtdhsizeof(dtype)2gt \cdot d_h \cdot \text{sizeof}(\text{dtype}),相比 MHA 的 2htdhsizeof(dtype)2ht \cdot d_h \cdot \text{sizeof}(\text{dtype}) 减少 h/g 倍

GQA 的 Arithmetic Intensity

FLOPs 与 MHA 相同(4htdh4htd_h),内存访问量介于 MHA 和 MQA 之间:

IGQA=4htdh(2h+2gt+2g)dhsizeof(dtype)I_{\text{GQA}} = \frac{4htd_h}{(2h + 2gt + 2g) \cdot d_h \cdot \text{sizeof}(\text{dtype})}

当序列长度 tt 较大时,2h+2g2h + 2g 可忽略:

IGQA4htdh2gtdhsizeof(dtype)=2h/gsizeof(dtype)I_{\text{GQA}} \approx \frac{4htd_h}{2gt \cdot d_h \cdot \text{sizeof}(\text{dtype})} = \frac{2h/g}{\text{sizeof}(\text{dtype})}

以 fp16 为例(sizeof(dtype)=2\text{sizeof}(\text{dtype}) = 2):

IGQAhg FLOPs/byteI_{\text{GQA}} \approx \frac{h}{g} \text{ FLOPs/byte}

GQA 的算术强度介于 MHA(~1)和 MQA(~h)之间。以 LLaMA 2 70B 为例(h=64, g=8, fp16):GQA ≈ 8 FLOPs/byte,虽仍为 memory-bound,但比 MHA 好 8 倍。

实现示意

def gqa_attention(Q, K, V, h=64, g=8):
"""
Q: [h, d_h] - h 个 query 头
K: [g, d_h] - g 个 key 头
V: [g, d_h] - g 个 value 头
"""
heads_per_group = h // g
outputs = []
for i in range(h):
group_idx = i // heads_per_group # 确定 query 头属于哪个组
scores = (Q[i] @ K[group_idx]) / math.sqrt(d_h)
weight = softmax(scores)
out = weight * V[group_idx]
outputs.append(out)
return outputs # [h, d_h]

性能:实验表明 g=8 时与 MHA 差异在 1% 以内。

Sliding Window Attention#

标准 Attention 的计算复杂度为 O(n2)O(n^2),当序列长度 nn 很大时(如长文档、长对话),计算和内存开销成为瓶颈。

Sliding Window Attention(滑动窗口注意力)Longformer (Beltagy et al., 2020) 提出,核心思想是限制每个 token 只能 attend 窗口范围内的相邻 token,而非全部历史。后续 Mistral (Jiang et al., 2023) 将 SWA 与滚动缓存结合,实现了高效的无限长度生成。

原理与计算复杂度

设窗口大小为 ww,每个 token 只与最近的 ww 个 token 计算注意力:

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

其中 KK^\top 的形状从 [n,dk][n, d_k] 变为 [w,dk][w, d_k](只取窗口内的 Key)。

计算复杂度变化:

方案QK 点积复杂度总复杂度
Full AttentionO(n2dk)O(n^2 d_k)O(n2)O(n^2)
Sliding WindowO(nwdk)O(nw d_k)O(nw)O(nw)

wnw \ll n 时,复杂度接近线性。例如 n=8192,w=4096n=8192, w=4096 时,计算量减少一半;w=512w=512 时,减少 16 倍。

实现示意

def sliding_window_attention(Q, K, V, window_size):
"""
Q: [n, d_k] - n 个 query
K: [n, d_k] - n 个 key
V: [n, d_v] - n 个 value
window_size: 窗口大小 w
"""
n = Q.shape[0]
outputs = []
for i in range(n):
# 确定窗口范围: [i - window_size + 1, i]
start = max(0, i - window_size + 1)
end = i + 1
# 只取窗口内的 K 和 V
K_window = K[start:end] # [w, d_k]
V_window = V[start:end] # [w, d_v]
# 计算 attention
scores = Q[i] @ K_window.T / math.sqrt(d_k)
weights = softmax(scores)
out = weights @ V_window
outputs.append(out)
return outputs # [n, d_v]

窗口设计的权衡

窗口大小计算节省信息捕获适用场景
w=nw = n全局信息短序列
w=n/2w = n/250%半全局中等长度
w=512w = 512显著局部上下文长文档
w=128w = 128极大极局部极长序列

窗口过小会丢失全局依赖(如段落间引用、跨句语义),窗口过大则计算收益有限。

现代混合策略

单纯使用 SWA 会限制全局信息流动。现代做法是交错使用全注意力和局部注意力层

模型混合策略设计考量
Gemma 2/4每 4 层一层全注意力大部分层用 SWA 降低开销,少量全注意力层传递全局信息
MistralSWA(w=4096w=4096)+ 滚动缓存用旋转缓存实现无限生成,窗口内信息不断滚动更新
OLMo 3SWA + Full RoPE局部注意力配合 RoPE 位置编码,保持相对位置信息
Longformer局部 + 全局 token对关键位置(如 [CLS]、标题)用全局注意力,其余用 SWA

关键洞察:混合策略在计算效率与全局信息捕获之间取得平衡。全注意力层充当”信息枢纽”,将局部信息汇聚后传递到全局。

Attention alternatives and mixture of experts#

本节内容基于 Lecture 4,主题为 “Attention Alternatives and Mixtures of Experts”,探讨了大语言模型中降低注意力计算成本的替代方案,以及混合专家模型(MoE)的设计与训练方法。

Attention Alternatives#

随着上下文长度增加,标准注意力机制的计算成本(O(n2)O(n^2))成为瓶颈。解决方案分为两个方向:工程优化(如结合局部+全局注意力)和架构革新。

Linear Attention#

核心思想:利用矩阵乘法的结合律重排注意力计算:

Attn(Q,K,V)=ρ(QK)V=Q(KV)\text{Attn}(Q, K, V) = \rho(QK^\top)V = Q(K^\top V)

n2dk+n2dvn^2 d_k + n^2 d_v 降到 2ndvdk2n d_v d_k,实现线性时间复杂度。

递归形式(RNN duality):

St=St1+ktvt,yt=qtStS_t = S_{t-1} + k_t v_t^\top, \quad y_t = q_t^\top S_t

这种”对偶性”允许:

  • 训练时使用并行二次形式(效率高)
  • 推理时使用串行线性形式(内存友好)

若给 St1S_{t-1} 加权重因子 γ\gamma,得到 RetNet

实际应用:MiniMax M1 使用 7:1 混合(7 层线性注意力 + 1 层全注意力),在长上下文场景表现良好。

Mamba-2:带门控的线性注意力#

对线性注意力引入逐位置权重:

St=γtSt1+ktvt,yt=qtSt+vtDS_t = \gamma_t S_{t-1} + k_t v_t^\top, \quad y_t = q_t^\top S_t + v_t^\top D

其中 γt=f(xt)\gamma_t = f(x_t) 是数据依赖的门控因子。相比线性注意力,门控机制提供了更强的表达能力,同时保持对偶性质(并行计算 γ\gamma,应用对偶变换)。

实际应用:Nemotron 3 采用 Mamba/Attention 3:1 混合,性能与同规模稠密模型相当或更好。

Gated Delta Net (GDN)#

进一步泛化:门控输入 + 选择性擦除状态:

St=γt(Iβtktkt)St1+βtktvt,yt=qtStS_t = \gamma_t(I - \beta_t k_t k_t^\top)S_{t-1} + \beta_t k_t v_t^\top, \quad y_t = q_t^\top S_t

关键特性:

  • βt\beta_t 控制”无输入操作”(β=0\beta = 0时不更新)
  • (Iβtktkt)(I - \beta_t k_t k_t^\top) 除当前 key 方向的信息(选择性遗忘)

与 fast weight programming / test-time training 思想密切相关。

实际应用:Qwen 3.5 / Qwen Next 采用 3:1 GDN/Attention 混合,推理效率良好。

Hybrid 性能总结#

混合架构(线性注意力/SSM + 全注意力)的主流配置:

模型混合比例特点
MiniMax M17:1线性注意力为主
Nemotron 33:1Mamba + Attention
Qwen 3.53:1GDN + Attention

实证表明:较小的混合比例(少量全注意力层)通常能保持接近全注意力的性能,同时显著降低长上下文成本。

稀疏注意力 (DSA)#

替代混合方案:不 attend 所有 token,使用稀疏注意力(DeepSeek Sparse Attention, v3.2 / GLM5)。

特点:

  • Indexer 可很轻量,显著降低开销
  • 可在稠密短上下文预训练后”事后适配”

Mixture of Experts (MoE)#

MoE 的核心思想:用多个专家 FFN + 路由层替代单一 FFN,激活参数量保持不变但总参数量大幅增加。

为什么 MoE 受欢迎?#

  1. 相同 FLOPs,更多参数 → 更好性能:实证表明增加专家数量不改变推理成本,但提升质量
  2. 训练更快:相同数据量下,MoE 达到相同 loss 需更少步数
  3. 竞争性强:与稠密等价模型相比,MoE 性能优异
  4. 易于并行化:每个专家可放置在不同设备

MoE 架构#

典型设计:替换 MLP 层为 MoE 层。较少见:对 attention heads 使用 MoE(ModuleFormer, JetMoE)。

关键变量

  • 路由函数(Routing function)
  • 专家规模
  • 训练目标

路由机制#

几乎所有 MoE 使用 Top-K 路由:token 选择 top-k 个专家。

路由类型描述使用模型
Top-kToken 选择 top-k 专家大多数 MoE
Hashing基础对比方法-
RL routing学习路由策略早期工作 (Bengio 2013),现已不常用
Linear assignment求解匹配问题Clark 2022 等

Top-K 路由细节

Gate 由 logistic regressor 选择(DeepSeek V1-2, Grok, Qwen)。Mixtral、DBRX、DeepSeek v3 在 TopK 后做 softmax。

近期变体(DeepSeek / Qwen):

  • 更小、更多专家 + 少量”共享专家”(始终激活)
  • 共享专家来自 DeepSpeed MoE,确保通用知识始终可用

专家配置对比

模型总专家数激活数共享专家细粒度比例
GShard204820-
Switch Transformer6410-
Mixtral820-
DBRX1640-
Grok820-
DeepSeek v164621/4
Qwen 1.560441/8
DeepSeek v3256811/14
OlMoE64801/8
MiniMax3220~1/4
Llama 4 (maverick)128111/2

消融实验结论

  • 更多专家 + 共享专家通常有益(DeepSeek 消融)
  • OlMoE:细粒度专家有增益,共享专家无增益(不同架构结论可能不同)

MoE 训练挑战#

核心问题:稀疏路由决策不可微分。

解决方案对比

方法原理使用情况
RL (REINFORCE)学习路由策略有效但梯度方差大、复杂,不常用
随机扰动添加 Gaussian/Uniform 噪声Shazeer 2017, Fedus 2022(后移除)
启发式平衡损失惩罚专家使用不均主流方法

平衡损失(Auxiliary Loss)

来自 Switch Transformer,惩罚专家使用频率不均:

Laux=αifiPi\mathcal{L}_{\text{aux}} = \alpha \sum_i f_i P_i

其中 fif_i 是专家 ii 接收的 token 比例,PiP_i 是路由概率。

DeepSeek v1-2:额外添加设备级平衡损失(按设备聚合)。

DeepSeek v3 变体:per-expert bias + online learning,称为 “auxiliary-loss-free balancing”(实际仍有辅助损失)。

系统并行性#

MoE 天然支持并行:每个 FFN/专家可放置在单独设备。

挑战:

  • Token dropping(batch 级路由可能导致其他请求的 token 被丢弃)
  • 稀疏矩阵操作复杂

解决方案:MegaBlocks 等库使用智能稀疏矩阵乘法;Nemotron 3 下投影激活以减少通信。

训练稳定性#

MoE 路由器可能导致 loss spike。

解决方案:对路由器使用 Float32 + auxiliary z-loss(类似注意力中的 z-loss,防止 logits 过大)。

微调问题#

稀疏 MoE 在小规模微调数据上易过拟合。

解决方案:

  • Zoph 等:微调非 MoE MLP
  • DeepSeek:使用大量 SFT 数据(1.4M)

Upcycling:从稠密模型初始化 MoE#

用预训练稠密模型初始化 MoE,避免从头训练。

成功案例

  • MiniCPM MoE:基于 MiniCPM,top-k=2, 8 专家,~4B 激活参数,520B tokens
  • Qwen MoE:基于 Qwen 1.8B,top-k=4, 60 专家 + 4 共享,首批确认的 upcycling 成功案例

DeepSeek MoE 演进#

版本总参数激活参数关键特性
V116B2.8B标准 top-k;2 共享 + 64/4 细粒度;标准 aux-loss
V2236B21B2 共享 + 160/10 细粒度,6 激活;通信平衡损失;Top-M 设备路由
V3671B37B1 共享 + 258 细粒度,8 激活;Sigmoid+Softmax topK + topM;Aux-loss-free + seq-wise aux

DeepSeek v3 关键技术

  1. MLA (Multi-head Latent Attention):将 Q/K/V 表示为低维”潜在激活”的函数。KV-cache 只需存储 ctKVc_t^{KV},显著压缩。RoPE 与 MLA 缓存冲突的解决方案:保留少量非潜在 key 维度用于旋转。

  2. MTP (Multi-Token Prediction):小型轻量模型预测多步 ahead(类似 EAGLE),v3 只用单步预测。

MoE 总结#

  • MoE 利用稀疏性:并非所有输入需要完整模型
  • 离散路由困难,但 Top-K 启发式方法已被证明有效
  • 大量实证证据表明 MoE 有效且成本效益高
  • 训练稳定性、微调过拟合、系统复杂性是主要挑战,已有成熟解决方案

支持与分享

如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!

赞助
CS336 Lecture Notes 2
https://llm-tech.com.cn/posts/cs336-lec-notes-2/
作者
Ming
发布于
2026-05-04
许可协议
CC BY-NC-SA 4.0
Profile Image of the Author
Ming
你是来找 Ming 学习的吗
🎉 欢迎来到 Ming 的博客
这里是我的个人博客,分享 AI Infra、LLM 等技术内容。欢迎关注交流!
分类
标签
站点统计
文章
19
分类
6
标签
12
总字数
69,591
运行时长
0
最后活动
0 天前

目录