Chapter 4 — step3 多头自注意力

对应实践course/practice/labs/lab04-step3/
主要修改文件course/practice/labs/lab04-step3/framework/student.c
验证命令make clean && make test

前一章结束时,你已经拿到了模型真正会处理的输入形式:一个 shape 为 [seq_len, hidden_dim] 的浮点向量序列。
这是 embedding 层交出来的结果。每一行代表一个位置的向量表示,其中既包含“这个 token 是谁”,也包含“它出现在第几个位置”。

但到此为止,这些位置之间仍然是彼此孤立的。

第 0 个位置并不知道第 1 个位置写了什么,第 5 个位置也不会自动去参考第 2 个位置的内容。换句话说,你现在只有一串“带位置的局部表示”,还没有“位置之间互相读取信息”的机制。

attention 正是为这个问题而生。

4.1 这一章真正要解决的问题

这一章最值得先说清楚的,不是公式,而是动机。

假设你现在要让模型处理这样一段序列。第 6 个 token 想判断自己更应该参考前面的哪个位置:是第 1 个位置里的主语,还是第 4 个位置里的限定词,还是刚刚出现的另一个关键词?

如果模型没有一种“读别人”的能力,它就只能把每个位置当成独立样本,根本无法建立上下文依赖。

attention 做的事,就是给每个位置一套主动查阅其它位置的机制。
更准确地说,它允许当前位置:

  1. 发出一个“我现在在找什么”的查询;
  2. 用这个查询去和其它位置做匹配;
  3. 按匹配强弱,从其它位置收集信息。

这三步,正是 Q、K、V 背后的直觉来源。

4.2 本章你要建立哪些判断

这一章完成后,你应当能够:

  1. 用工程语言解释 Q、K、V:不是抽象字母,而是“查询”“标签”“内容”三种不同角色。
  2. 明白 attention 分数矩阵为什么是 Q @ K^T,以及它的 shape 为什么是 [seq_len, seq_len]
  3. 明白为什么要乘 1 / sqrt(head_dim) 这个缩放因子。
  4. 明白因果掩码为什么必须在 softmax 之前加,而不是之后再处理。
  5. 在本章实验代码中实现单头 attention 的三个关键步骤:
    • student_attention_scores
    • student_apply_mask
    • student_softmax

和前三章一样,这里真正训练的不是“你见过 attention 公式”,而是“你能把这个公式拆成几段可观察、可验证、可落代码的步骤”。

4.3 先看 practice target:这章改哪里

你这章的主工作区是:

course/practice/labs/lab04-step3/
├── TASK.md
├── Makefile
└── framework/
    ├── student.c      <- 主要修改这里
    ├── student.h
    ├── verify.c       <- 自动验证,不改
    └── verify.h

本章 student 文件里留出的待实现函数有三个:

  • student_attention_scores
  • student_apply_mask
  • student_softmax

这套拆法非常有教学意义。因为完整 attention 虽然听起来像一个“大模块”,但它最核心的数学动作其实只有三段:

  1. 先算相似度分数;
  2. 再把不该看的位置屏蔽掉;
  3. 最后把分数归一化成概率。

课程没有一上来让你实现整套 attention_forward 的多头重排和输出投影,而是先把最本质的三段拆出来。这是非常合理的,因为如果你连这三段都没有建立清晰直觉,多头、拼接、输出线性层只会显得更乱。

4.4 Q、K、V 的直觉到底是什么

Q、K、V 这三个字母第一次出现时,很容易被初学者当成“又一套新名词”。
但如果只记名词,你很快就会混乱。更好的方式是先记住它们各自承担的角色。

Query

当前位置发出的“问题”。
你可以把它理解成:我现在在找哪类信息?

Key

每个位置携带的“标签”。
它回答的是:如果别人来查我,我属于哪一类、我适合在哪些问题下被匹配上?

Value

真正被读取出来的“内容”。
它回答的是:一旦当前位置决定要参考我,那它究竟从我这里拿走什么信息?

从这个角度看,attention 的第一步 Q @ K^T 根本不神秘。它只是在问:

当前位置的问题,和其它位置的标签,匹配得有多强?

如果你把这个直觉建立起来,后面的矩阵乘法就不再只是符号操作,而会开始有明确语义。

4.5 为什么 Q @ K^T 的 shape 是 [seq_len, seq_len]

这一点值得专门讲,因为它是很多人第一次真正“看懂 attention 里面矩阵维度”的转折点。

假设:

  • Q 的 shape 是 [seq_len, head_dim]
  • K 的 shape 也是 [seq_len, head_dim]

那么 K^T 的 shape 就是 [head_dim, seq_len]
所以:

Q @ K^T : [seq_len, head_dim] @ [head_dim, seq_len]
        -> [seq_len, seq_len]

这个结果矩阵的第 i, j 个元素,表示的是:

i 个位置发出的 query,与第 j 个位置的 key 的匹配分数。

也就是说,这张矩阵本质上不是“内容矩阵”,而是一张“谁更值得看”的关系表。

一旦你接受这一点,后面 mask 和 softmax 的意义就会立刻清楚很多:

  • mask 决定“哪些位置根本不能看”;
  • softmax 决定“在可看的位置里,各自分到多少权重”。

4.6 为什么还要缩放 1 / sqrt(head_dim)

如果没有这一步,随着 head_dim 变大,Q 和 K 的点积规模也会越来越大。分数一旦过大,softmax 就会很容易变得极端:最大的那一项几乎独占全部概率,其它项接近 0。

从某种角度说,这会让 attention 过早进入“近似 one-hot”的状态。
而一旦 softmax 太尖锐,梯度、数值稳定性和训练行为都会变得更难处理。

所以 attention 里引入:

scale = 1 / sqrt(head_dim)

它的作用不是改变排序,而是把分数范围压回一个更适合 softmax 的数值尺度。

在本章实验代码里,这一步会直接体现在 student_attention_scoresscale 参数上。课程把它显式留给你,是为了让你知道:这不是公式里可有可无的小系数,而是 attention 计算能否稳定工作的关键一部分。

4.7 为什么因果掩码必须在 softmax 之前加

这一章还有一个非常容易“看懂文字却没真正想明白”的点:mask 为什么要提前加?

对于 decoder-only 模型来说,位置 i 不应该看到未来位置 j > i
实现这个约束的常见方式是:在这些未来位置上加一个极大的负数,比如 -1e9f

这样一来,softmax 之前的分数矩阵里,未来位置就会变成几乎不可能被选中的候选项。
经过 softmax 之后,它们的权重自然会接近 0。

如果你把 mask 放到 softmax 之后再处理,会有两个问题:

  1. softmax 已经把这些非法位置也纳入概率归一化;
  2. 你后面再“清零”它们,会破坏一整行权重和为 1 的性质。

这就是为什么本章把 student_apply_maskstudent_softmax 拆成两个步骤,并且明确要求顺序是:

scores -> mask -> softmax

不是别的顺序。

4.8 -1e9f 为什么比 -INFINITY 更常见

这也是一个很好的“工程实现和数学理想并不完全相同”的例子。

从数学上说,被屏蔽位置的分数当然可以理解成负无穷,这样 softmax 后它的概率正好就是 0。
但在真实浮点实现里,直接混入 -INFINITY 有时会和某些后续运算、某些编译器行为、某些数值路径形成更脆弱的组合,尤其是在你自己手写不同风格的 softmax 或调试时。

-1e9f 这种“足够小的大负数”在实践里通常就已经够用:

  • 它仍然会让 expf(x) 近似 0;
  • 但在工程上往往更容易和普通浮点流程兼容。

这就是为什么 create_causal_mask 在当前框架里返回的是大负数,而不是严格的 -INFINITY

4.9 本章实践步骤

task 4.1:先读 student.cverify.c

进入:

cd course/practice/labs/lab04-step3

建议先读:

  • framework/student.h
  • framework/student.c
  • framework/verify.c

你会发现当前 lab 的验证器并不是只做 4 个粗粒度 PASS/FAIL,而是在每个大测试里还拆出很多细项。
这非常适合 attention 这种容易“哪里都像有点对、但整体就是不对”的模块。

例如它会单独检查:

  • out[0][0] 是否等于 0.5;
  • out[0][1] 是否等于 0;
  • mask 后某个位置是不是小于 -1e8f
  • softmax 后每一行和是不是 1。

这意味着你做这章时,应该利用这些局部反馈,而不是只盯住最后“结果对不对”。

task 4.2:实现 student_attention_scores

这是 attention 公式的第一步,也是最容易通过一个极小例子验证的部分。

建议用三层循环思考:

  1. 外层枚举 query 位置 i
  2. 中层枚举 key 位置 j
  3. 内层沿 head_dim 做点积累加。

然后再乘以 scale

这里一个很关键的意识是:你并不需要真的去构造 K^T 这个新张量。
只要你在读 K 时按照“第 j 行、第 d 列”访问,它在语义上就已经等价于 Q @ K^T 里的那个转置读取了。

task 4.3:实现 student_apply_mask

这一段反而是最短的。

它做的事情就是把 mask 逐元素加到 scores 上,而且是原地修改。
从代码量上看可能只有一两个循环,但从概念上看它很重要,因为它把“模型结构约束”显式地注入到了分数矩阵里。

以后你会不断遇到这种模式:真正决定模型行为的,并不总是复杂算子,有时恰恰是这些在关键节点插入的约束张量。

task 4.4:实现 student_softmax

这一步会把“匹配分数”变成“注意力权重”。

推荐仍然按数值稳定版的标准流程来写:

  1. 按行找最大值;
  2. 计算 exp(x - max)
  3. 求和;
  4. 归一化。

这一章最有价值的一点是:你会再次看到 Chapter 1 里 softmax 稳定性那套思想,但它这次不再是孤立技巧,而是 attention 里真正的核心组成部分。

这正说明前面学的东西不是学完就丢,而是会不断进入更大的结构里。

task 4.5:运行当前真实基线

在 student 实现还没写完之前,执行:

make clean && make test

当前这个 lab 的真实基线表现是:

  • TEST 1 里部分断言已经 PASS,但关键数值断言 FAIL;
  • TEST 2 里主对角项可能 PASS,但被 mask 的位置 FAIL;
  • TEST 3 中 softmax 相关断言大多 FAIL;
  • TEST 4 会出现一部分 PASS、一部分 FAIL。

这和前几章不太一样。前几章的基线更像“全部失败”。而这一章的基线更像“框架已经帮你做了输入检查和结构检查,但核心数学逻辑还没实现,所以局部现象对、关键结果不对”。

这类基线对 attention 反而很有帮助,因为它能把错误位置缩得更细。

task 4.6:完成后重新验证

当你补完三个函数后,再执行:

make clean && make test

理想结果应当是:

  • TEST 1 的所有局部断言都 PASS;
  • TEST 2 的 mask 相关断言都 PASS;
  • TEST 3 的行和与均分断言都 PASS;
  • TEST 4 的端到端断言都 PASS。

这时你不只是“写过一个 attention 公式”,而是已经把 attention 最核心的三段拆开、落地、验证了一次。

4.10 常见错误与排查顺序

现象 更可能的问题 优先检查
对角项不对 点积或 scale 没乘对 student_attention_scores
上三角没有被压到极小 mask 没有真正加进 scores student_apply_mask
每行和不是 1 softmax 归一化逻辑有误 student_softmax
全部是 0 softmax 结果没写进 out,或 sum 路径错误 student_softmax
端到端只剩零星 PASS 某一步局部逻辑对,但顺序或组合错了 检查 scores -> mask -> softmax 顺序

这一章最重要的排查原则是:不要从“完整 attention 为什么不对”开始想,而要从“分数、mask、softmax 三段里,究竟哪一段先开始偏了”开始想。

4.11 思考题

  1. 如果去掉 1 / sqrt(head_dim) 的缩放,softmax 权重为什么更容易变得极端?
  2. 为什么因果掩码必须放在 softmax 之前,而不是之后?
  3. student_softmax 再次用到了“先减最大值”的技巧。它和 Chapter 1 里的 softmax 有什么共同本质?
  4. 这一章只让你实现单头 attention 的核心三段,没有让你直接写完整多头前向。为什么这是更适合教学的拆分?

4.12 本章小结

这一章你第一次真正让序列里的位置彼此“看见了对方”。

embedding 解决的是“每个位置如何拥有自己的表示”;attention 解决的则是“每个位置如何读取别人的表示”。
从模型结构上看,这是一道很关键的分界线:从这里开始,序列不再只是排成一列的向量,而变成了一个会互相交换信息的系统。

后面的 Transformer block,正是在 attention 外面再包上层归一化、前馈网络和残差连接,让这套读取机制能稳定叠很多层。

继续阅读:Chapter 5
对应实践:Lab05