CS336 Lecture Notes 2
本文为本人学习相关开源课程过程中,整理的个人学习笔记及作业解答,核心目的仅用于记录个人学习轨迹、巩固所学知识、梳理学习思路,全程为个人自主学习使用,不具备任何商业用途,也不构成任何形式的课程辅导或标准答案参考。
需特别说明的是,由于本人学习进度及知识储备有限,笔记内容及作业解答中可能存在大量纰漏、思路偏差甚至错误,仅代表本人当时的学习理解,不具备权威性和准确性。
在此郑重提醒:请勿将本文中的任何作业解答复制粘贴,作为自身所修课程的提交答案。任何因抄袭本文内容导致的课程成绩问题、学术诚信问题,均由抄袭者自行承担全部责任,本人不承担任何相关连带责任。
同时,本文所分享的内容均基于开源课程的公开内容整理,尊重原课程创作者的知识产权,若涉及相关内容的版权问题,请及时联系本人,本人将第一时间进行调整或删除。
感谢各位读者的理解与支持,也欢迎大家针对笔记及解答中的问题提出宝贵建议,共同交流学习、共同进步。
- 课程网站: https://cs336.stanford.edu/
- Lec03 资料: lecture_03.pdf
- Lec04 资料: lecture_04.pdf
Architectures, hyperparameters
本节内容基于 Lecture 3,主题为 “Everything You Didn’t Want to Know About LM Architecture and Hyperparameters”,系统性地介绍了现代大语言模型的架构设计选择和超参数配置共识。
Transformer 基础与现代变体
原始 Transformer

- 位置编码:正弦余弦
- FFN 激活:ReLU
- 归一化:Post-LayerNorm
现代简化变体(CS336 课程实现)

- 归一化:Pre-LayerNorm
- 位置编码:RoPE
- FFN 激活:SwiGLU
- 线性层与 Norm:无偏置项
Pre-norm vs Post-norm
现代 LLM 几乎全部采用 Pre-norm(LayerNorm 放在 block 前面),而非原始 Transformer 的 Post-norm。


核心原因:
- Pre-norm 保持残差连接的主信号路径不被 LayerNorm 打断
- 梯度传播更平滑,避免梯度衰减和梯度尖峰问题
- 支持更大的学习率和更稳定的训练


一些近期模型(如 Grok、Gemma 2、OLMo 2)还会在残差流外部添加额外的 post-norm。

LayerNorm vs RMSNorm
原始 Transformer 使用 LayerNorm,对特征维度(d_model)上的激活值,做均值归零 + 方差归一化,再加上可学习的缩放()和偏置()参数。
GPT-1/2/3、OPT、GPT-J、BLOOM 等早期 / 开源模型,就是使用的 LayerNorm。
现代的 LLM 普遍采用 RMSNorm,不减去均值、不添加偏置项,只做基于均方根(RMS)的归一化,再乘上缩放参数 。 LLaMA 系列、PaLM、Chinchilla、T5 等主流模型就是采用的 RMSNorm。
优势:
- 不计算均值,操作更少
- 没有 bias 参数,存储更少
- 实测显示:FLOPs 不是唯一考量,数据移动(memory access)同样重要,RMSNorm 在实际运行时间上有优势


更广泛地,现代 transformer 大多数去掉 bias 项,原因同样是内存和优化稳定性。


激活函数:从 ReLU 到 SwiGLU
| 激活函数 | 公式 | 代表模型 |
|---|---|---|
| ReLU | 原始 Transformer, T5, Gopher, Chinchilla, OPT | |
| GeLU | GPT-1/2/3, GPTJ, GPT-Neox, BLOOM | |
| GeGLU | T5 v1.1, mT5, LaMDA, Phi3, Gemma 2/3/4 | |
| SwiGLU | LLaMA 1/2/3, PaLM, Mistral, OlMo, 大多数 2023+ 模型 |
GLU(Gated Linear Units) 的核心思想是将线性层改为门控形式:
其中 是额外的参数矩阵。实证研究表明 SwiGLU/GeGLU 有稳定的性能提升。
注意:GLU 变体的 FFN 维度通常缩小为 ,而非标准的 。


位置编码:RoPE 的崛起
正弦位置编码(Sine Embeddings):用固定的、非可学习的正弦 / 余弦函数生成位置信息,和词向量相加。
代表模型:原始 Transformer。
绝对位置嵌入(Absolute Embeddings):为每个位置单独训练一个可学习的向量,直接加到词向量上。
代表模型:GPT-1/2/3、OPT。
相对位置嵌入(Relative Embeddings):不直接给每个位置编码,而是把 “相对位置差” 作为偏置,加到注意力分数的计算中。
代表模型:T5、Gopher、Chinchilla。
旋转位置编码(RoPE Embeddings)
代表模型:GPT-J、PaLM、LLaMA 系列、绝大多数 2024 年后的模型。
RoPE 提出了一个理想的相对位置编码应该满足的条件:
- 左边:是位置 的词向量 和位置 的词向量 ,经过位置编码函数 后的内积(也就是注意力计算的核心)。
- 右边:这个内积的结果,应该只和两个词的内容 ,,以及它们的相对位置差 有关,而和它们的绝对位置 , 无关。
简单说:注意力分数只由 “两个词离得有多远” 决定,而不是它们在序列里的绝对位置,这样模型才能真正学到 “相对位置关系”,并且能外推到更长的序列。
现有方法的缺陷
- Sine:它的内积展开式里面的 和 会引入和绝对位置 , 相关的项,导致内积不仅和 有关,还和绝对位置有关,破坏了纯相对位置的特性。
-
Absolute:直接给每个位置加一个可学习的向量,内积结果会直接依赖于 和 ,完全不具备 “相对位置” 的性质。而且无法外推到训练时没见过的序列长度。
-
Relative embeddings: 让注意力计算不再是纯的内积形式。
RoPE 要解决的问题,就是让注意力计算只依赖相对位置差,而不是绝对位置。它的灵感来自一个关键数学性质:
向量的内积,在整体旋转下是不变的。
也就是说,两个向量一起旋转相同角度,它们的夹角和内积都不会变。
RoPE 就是利用这个性质,给不同位置的向量加上 “旋转角度”,让内积结果天然只和位置差有关。




超参数共识
FFN 维度比例
几乎所有模型遵循 (GLU 变体为 )。
| Model | |
|---|---|
| 标准 Transformer | 4 |
| PaLM | 4 |
| Mistral 7B | 3.5 |
| LLaMA-2 70B | 3.5 |
| LLaMA 70B | 2.68 |
| Qwen 14B | 2.67 |
| DeepSeek 67B | 2.68 |
| Yi 34B | 2.85 |
| T5 v1.1 | 2.5 |
极端案例:T5 11B 曾用 ,但后续 T5 v1.1 已回归标准值。
Head 维度
主流选择:
大多数模型这个比例为 1,少数 Google 模型(如 LaMDA, PaLM)有例外。
| Model | Num heads | Head dim | Model dim | Ratio |
|---|---|---|---|---|
| GPT3 | 96 | 128 | 12288 | 1 |
| T5 | 128 | 128 | 1024 | 16 |
| T5 v1.1 | 64 | 64 | 4096 | 1 |
| LaMDA | 128 | 128 | 8192 | 2 |
| PaLM | 48 | 258 | 18432 | 1.48 |
| LLaMA2 | 64 | 128 | 8192 | 1 |
| Qwen 3.5 (27B) | 24 | 256 | 5120 | 1.2 |
宽深比 (Aspect Ratio)
模型应该变宽还是变深?多宽以及多深?
大多数模型的 在 100-200 范围:
| Model | |
|---|---|
| BLOOM | 205 |
| T5 v1.1 | 171 |
| PaLM (540B) | 156 |
| GPT-3/OPT/Mistral/Qwen/OLMo 3 | 128 |
| LLaMA/LLaMA2 | 102 |
| Gamma 3 | 87 |
| Gamma 4 | 61 |
| T5 11B | 33 (极端深) |
过深的模型并行化困难、延迟高,但 [Kaplan et al 2020] 显示小模型时深一点更好。
词汇表大小
- 单语言模型:30k-50k
- 多语言/生产模型:100k-250k(GPT-4 ~100k, Gemma 4 ~262k)

Dropout 与 Weight Decay

趋势:新模型大多不用 Dropout,仅靠 Weight Decay。
原因:数据量大(万亿 tokens),单遍 SGD 不易过拟合;Weight Decay 主要影响优化动力学而非泛化。
稳定性技巧

训练大模型时,loss spike 是常见问题。以下技巧用于稳定训练:
Z-loss
Transformer 语言模型的输出层用 Softmax 计算 token 概率:
其中归一化项(也叫配分函数)是:
取对数后,我们得到模型的对数概率:
训练时,我们优化的目标就是最大化这个对数概率(交叉熵损失,等价于最小化负对数概率)。
随着模型变大、logits(也就是 )的数值范围变宽, 会出现两种极端情况:
- logits 太大:指数项 爆炸, 变得极大,导致 也变得非常大, 被压得很小,梯度消失。
- logits 太小:指数项 接近 0, 接近 1, 接近 0,但此时模型可能过于自信,也容易出现梯度问题。
当 出现剧烈波动时,损失和梯度都会变得不稳定,导致训练过程震荡甚至崩溃。
为了解决这个问题,我们在交叉熵损失上额外加一个正则项,也就是 Z-loss:
PaLM 首次大规模验证了 Z-loss 的有效性,后续很多模型,如 Baichuan 2、DCLM、OLMo 2/3 等都沿用了这个做法。
QK-norm
在 attention softmax 前对 Query 和 Key 做 LayerNorm/RMSNorm。
标准自注意力公式:
QK Norm 定义(对 Q 和 K 分别归一化):
加入 QK Norm 的最终注意力公式:
使用模型:DCLM, OLMo 2, Gemma 2, Qwen 3。
Logit Soft-capping
用 tanh 限制 logits 范围:
- 先把 logits 除以一个缩放因子
soft_cap(比如 50 或 30); - 再用
tanh函数做非线性变换,将值限制在[-1, 1]之间; - 最后乘回
soft_cap,最终 logits 就被限制在[-soft_cap, soft_cap]之间。
注意力优化
默认情况下,LLM 都用Multi-Head Attention(MHA),每个 token 生成 num_heads 个独立的 Query/Key/Value 头,所有的头的维度加起来等于d_model。
这种设计是通用的、简单且有效的,所以绝大多数模型都没改动。
GQA/MQA:降低推理成本
推理时的增量计算依赖 KV Cache,内存访问量大。MQA(Multi-Query Attention)和 GQA(Grouped-Query Attention)通过减少 KV heads 来降低内存开销。
问题背景:Attention 的计算与内存特性


在自回归推理中,每次生成一个新 token 时需要访问之前所有 token 的 Key 和 Value。为了理解为什么 KV Cache 是瓶颈,我们需要分析 Attention 的 FLOPs(浮点运算次数) 和 内存访问量,计算 Arithmetic Intensity(算术强度)。
推理时的 FLOPs 估计
生成第 个 token 时,主要张量的形状:
- 当前 Query :( 个头,每头 维)
- 历史 Key Cache :( 个头, 个历史位置,每头 维)
- 历史 Value Cache :
Attention 的主要计算:
-
Query-Key 点积: 得到注意力分数矩阵
- 形状变化:
- 每个头:,即 维向量与 个 维向量做点积
- FLOPs: 次乘法 + 次加法
- 个头总计:
-
Attention 加权 Value:用注意力分数 加权历史 Value
- 形状变化:
- 每个头: 个权重加权 的 Value,得到
- FLOPs: 次乘法 + 次加法
- 个头总计:
单层 Attention 的总 FLOPs(不含投影):
推理时的内存访问量
设数据类型为 dtype(如 fp16、fp32),字节大小为 。生成第 个 token 时:
| 操作 | 张量形状 | 访问量(bytes) |
|---|---|---|
| 读取当前 Query | ||
| 读取历史 Key Cache | ||
| 读取历史 Value Cache | ||
| 写入当前 K/V 到 Cache | ||
| 写入输出 |
总内存访问量:
主要开销来自读取历史 KV Cache:。
Arithmetic Intensity(算术强度)
算术强度定义为 FLOPs 与内存访问量的比值:
代入上述估计:
当序列长度 较大时:
对于 fp16, FLOP/byte。
关键洞察:
典型 GPU 的峰值算力与内存带宽对比:
- NVIDIA A100:峰值 312 TFLOPS (fp16),带宽 2 TB/s → 理论峰值强度 156 FLOPs/byte
- NVIDIA RTX 4090:峰值 83 TFLOPS (fp16),带宽 1 TB/s → 理论峰值强度 83 FLOPs/byte
Attention 推理的算术强度 FLOPs/byte,远低于 GPU 峰值强度,属于 memory-bound(内存受限) 操作:计算单元大部分时间在等待数据加载,而非真正执行计算。
这就是 KV Cache 成为推理瓶颈的根本原因——降低内存访问量比优化计算更关键。
MQA:极端压缩方案
Multi-Query Attention(MQA)由 Shazeer 2019 提出,核心思想是所有 Query 头共享同一个 Key 和 Value 头。
MQA 的 FLOPs 估计
MQA 的计算量与 MHA 相同,因为每个 Query 头仍需与 Key 做点积、加权 Value:
- QK 点积: FLOPs
- Attention 加权: FLOPs
- 总计: FLOPs
MQA 的内存访问量

MQA 的关键变化:Key 和 Value 只有1个头(而非 h 个头)。
设数据类型为 dtype,字节大小为 。生成第 个 token 时:
| 操作 | 张量形状 | 访问量(bytes) |
|---|---|---|
| 读取当前 Query | ||
| 读取历史 Key Cache | ||
| 读取历史 Value Cache | ||
| 写入当前 K/V 到 Cache | ||
| 写入输出 |
总内存访问量:
主要开销来自读取历史 KV Cache:,相比 MHA 的 减少 h 倍。
MQA 的 Arithmetic Intensity
FLOPs 与 MHA 相同(),内存访问量大幅减少:
当序列长度 较大时, 可忽略:
以 fp16 为例():
对比 MHA 的 (fp16 下约为 1 FLOP/byte),MQA 将算术强度提升 h 倍。
以 LLaMA 2 70B 为例(h=64, fp16):MHA ≈ 1 FLOP/byte,MQA ≈ 64 FLOPs/byte,显著更接近 GPU 峰值强度(A100: 156 FLOPs/byte),从 memory-bound 向 compute-bound 过渡。
代价:多个 Query 头被迫使用相同的 K/V 表示,表达能力受限。实验显示 MQA 在某些任务上性能下降明显。
GQA:折中方案

Grouped-Query Attention(GQA)由 Ainslie et al. 2023 提出,将 个 Query 头分成 组,每组共享一个 KV 头。典型取值 g = 8 或 g = 4。
GQA 的 FLOPs 估计
与 MHA、MQA 相同,FLOPs 不受 KV 头数影响。
GQA 的内存访问量
GQA 的关键变化:Key 和 Value 有 g 个头(介于 MHA 的 h 个和 MQA 的 1 个之间)。
设数据类型为 dtype,字节大小为 。生成第 个 token 时:
| 操作 | 张量形状 | 访问量(bytes) |
|---|---|---|
| 读取当前 Query | ||
| 读取历史 Key Cache | ||
| 读取历史 Value Cache | ||
| 写入当前 K/V 到 Cache | ||
| 写入输出 |
总内存访问量:
主要开销来自读取历史 KV Cache:,相比 MHA 的 减少 h/g 倍。
GQA 的 Arithmetic Intensity
FLOPs 与 MHA 相同(),内存访问量介于 MHA 和 MQA 之间:
当序列长度 较大时, 可忽略:
以 fp16 为例():
GQA 的算术强度介于 MHA(~1)和 MQA(~h)之间。以 LLaMA 2 70B 为例(h=64, g=8, fp16):GQA ≈ 8 FLOPs/byte,虽仍为 memory-bound,但比 MHA 好 8 倍。
实现示意
def gqa_attention(Q, K, V, h=64, g=8): """ Q: [h, d_h] - h 个 query 头 K: [g, d_h] - g 个 key 头 V: [g, d_h] - g 个 value 头 """ heads_per_group = h // g
outputs = [] for i in range(h): group_idx = i // heads_per_group # 确定 query 头属于哪个组 scores = (Q[i] @ K[group_idx]) / math.sqrt(d_h) weight = softmax(scores) out = weight * V[group_idx] outputs.append(out)
return outputs # [h, d_h]性能:实验表明 g=8 时与 MHA 差异在 1% 以内。
Sliding Window Attention
标准 Attention 的计算复杂度为 ,当序列长度 很大时(如长文档、长对话),计算和内存开销成为瓶颈。
Sliding Window Attention(滑动窗口注意力) 由 Longformer (Beltagy et al., 2020) 提出,核心思想是限制每个 token 只能 attend 窗口范围内的相邻 token,而非全部历史。后续 Mistral (Jiang et al., 2023) 将 SWA 与滚动缓存结合,实现了高效的无限长度生成。
原理与计算复杂度
设窗口大小为 ,每个 token 只与最近的 个 token 计算注意力:
其中 的形状从 变为 (只取窗口内的 Key)。
计算复杂度变化:
| 方案 | QK 点积复杂度 | 总复杂度 |
|---|---|---|
| Full Attention | ||
| Sliding Window |
当 时,复杂度接近线性。例如 时,计算量减少一半; 时,减少 16 倍。
实现示意
def sliding_window_attention(Q, K, V, window_size): """ Q: [n, d_k] - n 个 query K: [n, d_k] - n 个 key V: [n, d_v] - n 个 value window_size: 窗口大小 w """ n = Q.shape[0] outputs = []
for i in range(n): # 确定窗口范围: [i - window_size + 1, i] start = max(0, i - window_size + 1) end = i + 1
# 只取窗口内的 K 和 V K_window = K[start:end] # [w, d_k] V_window = V[start:end] # [w, d_v]
# 计算 attention scores = Q[i] @ K_window.T / math.sqrt(d_k) weights = softmax(scores) out = weights @ V_window outputs.append(out)
return outputs # [n, d_v]窗口设计的权衡
| 窗口大小 | 计算节省 | 信息捕获 | 适用场景 |
|---|---|---|---|
| 无 | 全局信息 | 短序列 | |
| 50% | 半全局 | 中等长度 | |
| 显著 | 局部上下文 | 长文档 | |
| 极大 | 极局部 | 极长序列 |
窗口过小会丢失全局依赖(如段落间引用、跨句语义),窗口过大则计算收益有限。
现代混合策略
单纯使用 SWA 会限制全局信息流动。现代做法是交错使用全注意力和局部注意力层:
| 模型 | 混合策略 | 设计考量 |
|---|---|---|
| Gemma 2/4 | 每 4 层一层全注意力 | 大部分层用 SWA 降低开销,少量全注意力层传递全局信息 |
| Mistral | SWA()+ 滚动缓存 | 用旋转缓存实现无限生成,窗口内信息不断滚动更新 |
| OLMo 3 | SWA + Full RoPE | 局部注意力配合 RoPE 位置编码,保持相对位置信息 |
| Longformer | 局部 + 全局 token | 对关键位置(如 [CLS]、标题)用全局注意力,其余用 SWA |
关键洞察:混合策略在计算效率与全局信息捕获之间取得平衡。全注意力层充当”信息枢纽”,将局部信息汇聚后传递到全局。
Attention alternatives and mixture of experts
本节内容基于 Lecture 4,主题为 “Attention Alternatives and Mixtures of Experts”,探讨了大语言模型中降低注意力计算成本的替代方案,以及混合专家模型(MoE)的设计与训练方法。
Attention Alternatives

随着上下文长度增加,标准注意力机制的计算成本()成为瓶颈。解决方案分为两个方向:工程优化(如结合局部+全局注意力)和架构革新。
Linear Attention
核心思想:利用矩阵乘法的结合律重排注意力计算:
从 降到 ,实现线性时间复杂度。
递归形式(RNN duality):
这种”对偶性”允许:
- 训练时使用并行二次形式(效率高)
- 推理时使用串行线性形式(内存友好)
若给 加权重因子 ,得到 RetNet。
实际应用:MiniMax M1 使用 7:1 混合(7 层线性注意力 + 1 层全注意力),在长上下文场景表现良好。
Mamba-2:带门控的线性注意力
对线性注意力引入逐位置权重:
其中 是数据依赖的门控因子。相比线性注意力,门控机制提供了更强的表达能力,同时保持对偶性质(并行计算 ,应用对偶变换)。
实际应用:Nemotron 3 采用 Mamba/Attention 3:1 混合,性能与同规模稠密模型相当或更好。
Gated Delta Net (GDN)
进一步泛化:门控输入 + 选择性擦除状态:
关键特性:
- 控制”无输入操作”(时不更新)
- 除当前 key 方向的信息(选择性遗忘)
与 fast weight programming / test-time training 思想密切相关。
实际应用:Qwen 3.5 / Qwen Next 采用 3:1 GDN/Attention 混合,推理效率良好。
Hybrid 性能总结
混合架构(线性注意力/SSM + 全注意力)的主流配置:
| 模型 | 混合比例 | 特点 |
|---|---|---|
| MiniMax M1 | 7:1 | 线性注意力为主 |
| Nemotron 3 | 3:1 | Mamba + Attention |
| Qwen 3.5 | 3:1 | GDN + Attention |
实证表明:较小的混合比例(少量全注意力层)通常能保持接近全注意力的性能,同时显著降低长上下文成本。
稀疏注意力 (DSA)
替代混合方案:不 attend 所有 token,使用稀疏注意力(DeepSeek Sparse Attention, v3.2 / GLM5)。
特点:
- Indexer 可很轻量,显著降低开销
- 可在稠密短上下文预训练后”事后适配”
Mixture of Experts (MoE)
MoE 的核心思想:用多个专家 FFN + 路由层替代单一 FFN,激活参数量保持不变但总参数量大幅增加。
为什么 MoE 受欢迎?
- 相同 FLOPs,更多参数 → 更好性能:实证表明增加专家数量不改变推理成本,但提升质量
- 训练更快:相同数据量下,MoE 达到相同 loss 需更少步数
- 竞争性强:与稠密等价模型相比,MoE 性能优异
- 易于并行化:每个专家可放置在不同设备
MoE 架构
典型设计:替换 MLP 层为 MoE 层。较少见:对 attention heads 使用 MoE(ModuleFormer, JetMoE)。
关键变量:
- 路由函数(Routing function)
- 专家规模
- 训练目标
路由机制
几乎所有 MoE 使用 Top-K 路由:token 选择 top-k 个专家。
| 路由类型 | 描述 | 使用模型 |
|---|---|---|
| Top-k | Token 选择 top-k 专家 | 大多数 MoE |
| Hashing | 基础对比方法 | - |
| RL routing | 学习路由策略 | 早期工作 (Bengio 2013),现已不常用 |
| Linear assignment | 求解匹配问题 | Clark 2022 等 |
Top-K 路由细节:
Gate 由 logistic regressor 选择(DeepSeek V1-2, Grok, Qwen)。Mixtral、DBRX、DeepSeek v3 在 TopK 后做 softmax。
近期变体(DeepSeek / Qwen):
- 更小、更多专家 + 少量”共享专家”(始终激活)
- 共享专家来自 DeepSpeed MoE,确保通用知识始终可用
专家配置对比:
| 模型 | 总专家数 | 激活数 | 共享专家 | 细粒度比例 |
|---|---|---|---|---|
| GShard | 2048 | 2 | 0 | - |
| Switch Transformer | 64 | 1 | 0 | - |
| Mixtral | 8 | 2 | 0 | - |
| DBRX | 16 | 4 | 0 | - |
| Grok | 8 | 2 | 0 | - |
| DeepSeek v1 | 64 | 6 | 2 | 1/4 |
| Qwen 1.5 | 60 | 4 | 4 | 1/8 |
| DeepSeek v3 | 256 | 8 | 1 | 1/14 |
| OlMoE | 64 | 8 | 0 | 1/8 |
| MiniMax | 32 | 2 | 0 | ~1/4 |
| Llama 4 (maverick) | 128 | 1 | 1 | 1/2 |
消融实验结论:
- 更多专家 + 共享专家通常有益(DeepSeek 消融)
- OlMoE:细粒度专家有增益,共享专家无增益(不同架构结论可能不同)
MoE 训练挑战
核心问题:稀疏路由决策不可微分。
解决方案对比:
| 方法 | 原理 | 使用情况 |
|---|---|---|
| RL (REINFORCE) | 学习路由策略 | 有效但梯度方差大、复杂,不常用 |
| 随机扰动 | 添加 Gaussian/Uniform 噪声 | Shazeer 2017, Fedus 2022(后移除) |
| 启发式平衡损失 | 惩罚专家使用不均 | 主流方法 |
平衡损失(Auxiliary Loss):
来自 Switch Transformer,惩罚专家使用频率不均:
其中 是专家 接收的 token 比例, 是路由概率。
DeepSeek v1-2:额外添加设备级平衡损失(按设备聚合)。
DeepSeek v3 变体:per-expert bias + online learning,称为 “auxiliary-loss-free balancing”(实际仍有辅助损失)。
系统并行性
MoE 天然支持并行:每个 FFN/专家可放置在单独设备。
挑战:
- Token dropping(batch 级路由可能导致其他请求的 token 被丢弃)
- 稀疏矩阵操作复杂
解决方案:MegaBlocks 等库使用智能稀疏矩阵乘法;Nemotron 3 下投影激活以减少通信。
训练稳定性
MoE 路由器可能导致 loss spike。
解决方案:对路由器使用 Float32 + auxiliary z-loss(类似注意力中的 z-loss,防止 logits 过大)。
微调问题
稀疏 MoE 在小规模微调数据上易过拟合。
解决方案:
- Zoph 等:微调非 MoE MLP
- DeepSeek:使用大量 SFT 数据(1.4M)
Upcycling:从稠密模型初始化 MoE
用预训练稠密模型初始化 MoE,避免从头训练。
成功案例:
- MiniCPM MoE:基于 MiniCPM,top-k=2, 8 专家,~4B 激活参数,520B tokens
- Qwen MoE:基于 Qwen 1.8B,top-k=4, 60 专家 + 4 共享,首批确认的 upcycling 成功案例
DeepSeek MoE 演进
| 版本 | 总参数 | 激活参数 | 关键特性 |
|---|---|---|---|
| V1 | 16B | 2.8B | 标准 top-k;2 共享 + 64/4 细粒度;标准 aux-loss |
| V2 | 236B | 21B | 2 共享 + 160/10 细粒度,6 激活;通信平衡损失;Top-M 设备路由 |
| V3 | 671B | 37B | 1 共享 + 258 细粒度,8 激活;Sigmoid+Softmax topK + topM;Aux-loss-free + seq-wise aux |
DeepSeek v3 关键技术:
-
MLA (Multi-head Latent Attention):将 Q/K/V 表示为低维”潜在激活”的函数。KV-cache 只需存储 ,显著压缩。RoPE 与 MLA 缓存冲突的解决方案:保留少量非潜在 key 维度用于旋转。
-
MTP (Multi-Token Prediction):小型轻量模型预测多步 ahead(类似 EAGLE),v3 只用单步预测。
MoE 总结
- MoE 利用稀疏性:并非所有输入需要完整模型
- 离散路由困难,但 Top-K 启发式方法已被证明有效
- 大量实证证据表明 MoE 有效且成本效益高
- 训练稳定性、微调过拟合、系统复杂性是主要挑战,已有成熟解决方案
支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!