<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Attention on 酒中仙</title><link>https://hanguangwu.github.io/blog/tags/attention/</link><description>Recent content in Attention on 酒中仙</description><generator>Hugo -- gohugo.io</generator><language>zh-cn</language><copyright>hanguangwu</copyright><lastBuildDate>Sat, 07 Feb 2026 20:34:25 -0800</lastBuildDate><atom:link href="https://hanguangwu.github.io/blog/tags/attention/index.xml" rel="self" type="application/rss+xml"/><item><title>深入解析 Transformer 架构</title><link>https://hanguangwu.github.io/blog/p/%E6%B7%B1%E5%85%A5%E8%A7%A3%E6%9E%90-transformer-%E6%9E%B6%E6%9E%84/</link><pubDate>Sat, 07 Feb 2026 20:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/%E6%B7%B1%E5%85%A5%E8%A7%A3%E6%9E%90-transformer-%E6%9E%B6%E6%9E%84/</guid><description>&lt;h1 id="深入解析-transformer"&gt;深入解析 Transformer
&lt;/h1&gt;&lt;p&gt;注意力机制通过动态加权的方式，克服了传统 Seq2Seq 模型中的“信息瓶颈”问题。但是，这些模型依然依赖于 RNN 来处理序列信息，也就是说它们必须按顺序，一个词元接一个词元地进行计算，这在处理长序列时效率低下，并且存在长距离依赖信息丢失的问题。&lt;/p&gt;
&lt;p&gt;2017年，Google 的研究团队发表了一篇名为《Attention Is All You Need》的论文，提出了一种全新的架构——&lt;strong&gt;Transformer&lt;/strong&gt; &lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;。这篇论文的标题很有冲击力，其思想也同样有颠覆性。它抛弃了传统的 RNN 和卷积网络，整个模型基于注意力机制来构建。Transformer 的提出在自然语言处理领域具有划时代的意义。它不仅凭借其出色的并行计算能力极大地提升了训练效率，还更有效地捕捉了文本中的长距离依赖关系，为后续的 BERT、GPT 等大规模预训练模型的诞生提供了架构基础。&lt;/p&gt;
&lt;h2 id="一自注意力机制"&gt;一、自注意力机制
&lt;/h2&gt;&lt;p&gt;从根本上说，要让模型理解一段文本，就需要提取其“序列特征”，即将文本中所有词元的信息以某种方式整合起来。RNN 通过依次传递隐藏状态来顺序地整合信息，而 Transformer 则选择了一条截然不同的道路。其核心是 &lt;strong&gt;自注意力机制&lt;/strong&gt;。它不再依赖于顺序计算，而是将提取序列特征的过程看作是输入序列“自己对自己进行注意力计算”。序列中的每个词元都会“审视”序列中的所有其他词元，来动态地计算出最能代表当前词元上下文含义的新表示。与上一节介绍的交叉注意力不同，在自注意力中，&lt;strong&gt;Query、Key、Value 均来源于同一个输入序列&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;举个例子，在句子“苹果公司发布了新款手机，&lt;strong&gt;它&lt;/strong&gt;采用了最新的芯片”中，要理解代词“它”指的是“新款手机”而不是“苹果公司”，模型就需要将“它”与句子中的其他词元进行关联。自注意力机制正是通过计算“它”对句中其他所有词的注意力权重来实现这一点的。&lt;/p&gt;
&lt;h3 id="11-自注意力与交叉注意力的区别"&gt;1.1 自注意力与交叉注意力的区别
&lt;/h3&gt;&lt;p&gt;从结构上看，自注意力与交叉注意力的区别在于&lt;strong&gt;信息的来源和流动方向&lt;/strong&gt;。在交叉注意力机制中，信息在两个不同的序列之间流动。通常，&lt;strong&gt;Query&lt;/strong&gt; 来自解码器（代表当前的目标序列状态），而 &lt;strong&gt;Key&lt;/strong&gt; 和 &lt;strong&gt;Value&lt;/strong&gt; 来自编码器的所有输出（代表完整的源序列信息）。其目的是在生成目标序列的每一步时，从源序列中寻找最相关的信息。&lt;/p&gt;
&lt;p&gt;&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/20260207205125471.png"
loading="lazy"
&gt;&lt;/p&gt;
&lt;p&gt;而在&lt;strong&gt;自注意力&lt;/strong&gt;机制中，信息则是在&lt;strong&gt;同一个序列内部&lt;/strong&gt;进行流动和重组。它的 &lt;strong&gt;Query, Key, 和 Value 都来自同一个输入序列&lt;/strong&gt;。其目的是为了捕捉输入序列内部的依赖关系，重新计算序列中每个词元的表示，使其包含更丰富的上下文信息。&lt;/p&gt;
&lt;p&gt;&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/20260207205104142.png"
loading="lazy"
&gt;&lt;/p&gt;
&lt;p&gt;总结来说，尽管底层的加权求和计算方式相似，但两者在架构上的目标完全不同：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;交叉注意力&lt;/strong&gt;：用于&lt;strong&gt;对齐&lt;/strong&gt;和&lt;strong&gt;整合&lt;/strong&gt;两个&lt;strong&gt;不同&lt;/strong&gt;序列之间的信息。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;自注意力&lt;/strong&gt;：用于&lt;strong&gt;理解&lt;/strong&gt;和&lt;strong&gt;重构&lt;/strong&gt;单个序列&lt;strong&gt;内部&lt;/strong&gt;的依赖关系。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="12-自注意力的计算过程"&gt;1.2 自注意力的计算过程
&lt;/h3&gt;&lt;p&gt;自注意力的计算过程与上一节介绍的 QKV 范式完全一致，关键区别在于 Q, K, V 的来源。&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;生成 Q, K, V 向量&lt;/strong&gt;：&lt;/p&gt;
&lt;p&gt;对于输入序列中的&lt;strong&gt;每一个&lt;/strong&gt;词元，首先获取其词嵌入向量 $x_i$。然后，将该向量分别与三个可学习的、在整个模型中共享的权重矩阵 $W^Q, W^K, W^V$ 相乘，生成该词元专属的 Query 向量 $q_i$、Key 向量 $k_i$ 和 Value 向量 $v_i$。&lt;/p&gt;
$$
q_i = x_i W^Q \\
k_i = x_i W^K \\
v_i = x_i W^V
$$&lt;p&gt;这三个矩阵的作用是将原始的词嵌入向量投影到不同的、专门用于注意力计算的表示空间中，赋予了模型更大的灵活性。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;计算注意力分数&lt;/strong&gt;：&lt;/p&gt;
&lt;p&gt;为了计算第 $i$ 个词元的新表示，需要用它的 Query 向量 $q_i$ 去和&lt;strong&gt;所有&lt;/strong&gt;词元（包括它自己）的 Key 向量 $k_j$ 计算点积，得到注意力分数。&lt;/p&gt;
$$
\text{score}(i, j) = q_i \cdot k_j
$$&lt;p&gt;（3）&lt;strong&gt;缩放与归一化&lt;/strong&gt;：&lt;/p&gt;
&lt;p&gt;将得到的分数除以一个缩放因子 $\sqrt{d_k}$（$d_k$ 是 Key 向量的维度），然后通过 Softmax 函数进行归一化，得到最终的注意力权重 $\alpha_{ij}$。这个缩放步骤的目的与上一节中介绍的一致，都是为了在训练过程中保持梯度稳定。当向量维度 $d_k$ 较大时，点积结果的方差会增大，可能将 Softmax 函数推向其梯度极小的区域，从而导致梯度消失，影响模型学习。进行缩放可以有效缓解这个问题。&lt;/p&gt;
$$
\alpha_{ij} = \text{softmax}\left(\frac{q_i \cdot k_j}{\sqrt{d_k}}\right)
$$&lt;ol start="4"&gt;
&lt;li&gt;&lt;strong&gt;加权求和&lt;/strong&gt;：&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;使用计算出的权重 $\alpha_{ij}$ 对&lt;strong&gt;所有&lt;/strong&gt;词元的 Value 向量 $v_j$ 进行加权求和，得到第 $i$ 个词元经过自注意力计算后得到的新表示 $z_i$。&lt;/p&gt;
$$
z_i = \sum_j \alpha_{ij} v_j
$$&lt;p&gt;通过这个过程，输出向量 $z_i$ 不再仅仅包含原始词元 $x_i$ 的信息，而是融合了整个序列中所有与之相关词元的信息，成为一个上下文感知的、更丰富的表示。其本质可以理解为：序列中的每个词元都同时扮演着“查询（Q）”、“键（K）”和“值（V）”三种角色。通过计算查询与其他所有词元的键之间的相关性，来决定如何加权融合所有词元的值，从而为每个词元生成一个全新的、深度融合了全局上下文信息的表示。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;既然 Q, K, V 都来自同一个输入 X，为什么不直接用 X 计算，而要引入三个独立的权重矩阵 $W^Q, W^K, W^V$？甚至，为什么是三个，而不是两个或四个？&lt;/p&gt;
&lt;p&gt;这可以类比在图书馆查资料的过程：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Query (Q) - 要问的问题&lt;/strong&gt;：代表了我们主动想查询的意图。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Key (K) - 书的索引/标签&lt;/strong&gt;：代表了书本内容的关键特征，用于被动地和你的问题进行匹配。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Value (V) - 书的具体内容&lt;/strong&gt;：代表了书本实际包含的信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;我们的“问题”和书本的“索引”可能都源于同一个知识领域（同一个输入 X），但它们在信息检索这个任务中扮演的角色是截然不同的。$W^Q, W^K, W^V$ 这三个矩阵的作用，就是让模型学会将原始输入 X 投影到三个功能不同的空间中，分别去扮演好“查询者”、“被查询的索引”和“信息提供者”这三种角色。Q-K 配对解决了“如何定位相关信息”的问题，而 V 提供了“应该提取什么信息”的答案。这个三元组结构在功能上是完备且高效的，所以成为了注意力机制的标准范式。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="13-矩阵运算与并行化"&gt;1.3 矩阵运算与并行化
&lt;/h3&gt;&lt;p&gt;上述步骤描述的是单个词元 $i$ 的计算过程。在实际应用中，如果采用循环的方式逐个计算每个词元的 $z_i$，效率会非常低下。自注意力的巨大优势在于其&lt;strong&gt;并行计算&lt;/strong&gt;能力，这通过将整个过程表达为矩阵运算来实现。&lt;/p&gt;
&lt;p&gt;假设整个输入序列的词嵌入矩阵为 $X$（维度为 &lt;code&gt;[sequence_length, embedding_dim]&lt;/code&gt;），可以一次性计算出所有词元的 Q, K, V 矩阵：&lt;/p&gt;
$$
Q = X W^Q \\
K = X W^K \\
V = X W^V
$$&lt;p&gt;然后，整个自注意力的输出矩阵 $Z$ 可以通过一个公式完成计算：&lt;/p&gt;
$$
Z = \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$&lt;p&gt;这个公式与上一节中介绍的通用注意力公式完全相同。这里的主要区别不在于数学运算，而在于&lt;strong&gt;输入的来源&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;在上一节的&lt;strong&gt;交叉注意力&lt;/strong&gt;中，Q 来自一个序列（解码器），而 K 和 V 来自另一个序列（编码器）。&lt;/li&gt;
&lt;li&gt;在当前的&lt;strong&gt;自注意力&lt;/strong&gt;中，矩阵 Q、K 和 V &lt;strong&gt;全部派生自同一个输入序列 X&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;所以，同一个数学范式，根据输入来源的不同，被用于解决两个不同的问题，一个是两个序列之间的对齐，另一个是单个序列内部的依赖关系建模。在这个公式中， $QK^T$ 的计算结果是一个维度为 &lt;code&gt;[sequence_length, sequence_length]&lt;/code&gt; 的注意力分数矩阵，其中第 $i$ 行第 $j$ 列的元素表示第 $i$ 个词元对第 $j$ 个词元的注意力分数（未归一化的 logits）。注意力权重来自对缩放分数应用 Softmax 后得到的归一化系数。&lt;/p&gt;
&lt;h3 id="14-pytorch-实现自注意力"&gt;1.4 PyTorch 实现自注意力
&lt;/h3&gt;&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/FutureUnreal/base-nlp/blob/main/code/C4/03_Self-Attention.py" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;从概念上讲，自注意力的计算可以分解为对序列中每个词元进行循环操作，这种方式虽然直观但效率极低。因此，现代深度学习框架中的实现都采用了&lt;strong&gt;矩阵运算&lt;/strong&gt;的方式。通过将整个序列的 Q, K, V 看作矩阵，利用一次大规模的矩阵乘法，就能并行地完成所有词元之间的相关性计算。下面是这种并行化版本的实现：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SelfAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;自注意力模块&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;SelfAttention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attention_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;: 初始化了三个 &lt;code&gt;nn.Linear&lt;/code&gt; 层，它们分别对应将输入映射到 Q, K, V 空间的权重矩阵 $W^Q, W^K, W^V$。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;q_linear(x)&lt;/code&gt;, &lt;code&gt;k_linear(x)&lt;/code&gt;, &lt;code&gt;v_linear(x)&lt;/code&gt;：将形状为 &lt;code&gt;[batch_size, seq_len, hidden_size]&lt;/code&gt; 的输入张量 &lt;code&gt;x&lt;/code&gt; 分别通过三个线性层，一次性地为序列中的所有词元计算出 Q, K, V 矩阵。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;torch.matmul(q, k.transpose(-2, -1))&lt;/code&gt;: 这是实现并行计算的核心。通过将 K 矩阵的最后两个维度转置（&lt;code&gt;seq_len, hidden_size&lt;/code&gt; -&amp;gt; &lt;code&gt;hidden_size, seq_len&lt;/code&gt;），再与 Q 矩阵相乘，直接得到了一个 &lt;code&gt;[batch_size, seq_len, seq_len]&lt;/code&gt; 的分数矩阵。该矩阵中的 &lt;code&gt;scores[b, i, j]&lt;/code&gt; 代表了批次 &lt;code&gt;b&lt;/code&gt; 中第 &lt;code&gt;i&lt;/code&gt; 个词元对第 &lt;code&gt;j&lt;/code&gt; 个词元的注意力分数。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;/ math.sqrt(self.hidden_size)&lt;/code&gt;：执行缩放操作，防止梯度消失。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;torch.softmax(scores, dim=-1)&lt;/code&gt;：对分数的最后一个维度（&lt;code&gt;seq_len&lt;/code&gt;）进行 Softmax，得到归一化的注意力权重。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;torch.matmul(attention_weights, v)&lt;/code&gt;：将权重矩阵与 V 矩阵相乘，完成了对所有词元的 Value 向量的加权求和，得到最终的上下文感知表示。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id="二多头注意力机制"&gt;二、多头注意力机制
&lt;/h2&gt;&lt;p&gt;仅仅用一组 $W^Q, W^K, W^V$ 矩阵进行一次自注意力计算，相当于只从一个“视角”来审视文本内在的关系。然而，文本中的关系是多层次的，例如，一组参数可能学会了关注代词（如 “它” 指向谁）的关系，但可能忽略了动作的执行者（主谓宾）等其他类型的关系。&lt;/p&gt;
&lt;p&gt;为了让模型能够综合利用从不同维度和视角提取出的信息，Transformer 引入了&lt;strong&gt;多头注意力机制 (Multi-Head Attention)&lt;/strong&gt;。其思想非常直接：并行地执行多次自注意力计算，每一次计算都是一个独立的“头 (Head)”。每个头都拥有一组自己专属的 $W^Q_i, W^K_i, W^V_i$ 权重矩阵，并且可以学习去关注一种特定类型的上下文关系。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;那么，多头注意力与我们之前讨论的“增加 A, B, C 等新角色”有什么不同呢？&lt;/p&gt;
&lt;p&gt;一个关键的区别：多头注意力&lt;strong&gt;不是&lt;/strong&gt;通过增加 A, B, C 等新角色来&lt;strong&gt;深化&lt;/strong&gt;单次注意力计算的复杂性，而是通过并行运行多个独立的 QKV 计算单元来&lt;strong&gt;拓宽&lt;/strong&gt;其广度。&lt;/p&gt;
&lt;p&gt;再次使用图书馆的类比：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;增加 A, B, C&lt;/strong&gt;：相当于给&lt;strong&gt;一个&lt;/strong&gt;图书管理员一套更复杂的工具，让他一次性处理问题(Q)、索引(K)、内容(V)之外，还要考虑主题(A)、背景(B)等，这会使单次查询过程变得非常复杂。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多头注意力&lt;/strong&gt;：相当于&lt;strong&gt;雇佣一个各有所长的专家团队&lt;/strong&gt;（比如 8 个管理员，即 8 个“头”）。每个专家都只使用标准高效的 QKV 工具，但他们各自有独特的视角（独立的 $W^Q_i, W^K_i, W^V_i$ 矩阵）。一个专家可能专攻语法，另一个专攻语义。最后，将所有专家的报告汇总起来，得到一个更全面、更丰富的结论。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;因此，多头注意力机制为模型提供了从不同子空间、不同视角审视信息的能力，而不是改变注意力计算本身的范式。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;具体流程如下：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;并行计算&lt;/strong&gt;：假设有 $h$ 个头，那么就初始化 $h$ 组不同的权重矩阵 $(W^Q_0, W^K_0, W^V_0), (W^Q_1, W^K_1, W^V_1), \dots, (W^Q_{h-1}, W^K_{h-1}, W^V_{h-1})$。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;独立注意力&lt;/strong&gt;：对于输入序列，每个头都独立地执行一次完整的自注意力计算，产生一个输出矩阵 $Z_i$。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;拼接与投影&lt;/strong&gt;：将所有 $h$ 个头的输出矩阵 $Z_0, Z_1, \dots, Z_{h-1}$ 在特征维度上进行&lt;strong&gt;拼接 (Concatenate)&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;（4）&lt;strong&gt;最终输出&lt;/strong&gt;：将拼接后的巨大矩阵乘以一个新的权重矩阵 $W^O$，将其投影回原始的输入维度，得到多头注意力机制的最终输出。&lt;/p&gt;
&lt;p&gt;多头机制允许模型在不同的表示子空间中共同学习上下文信息。例如，一个头可能专注于捕捉长距离的语法依赖，而另一个头可能更关注局部的词义关联。这种设计极大地增强了模型的表达能力。&lt;/p&gt;
&lt;p&gt;在实践中，为了保持计算总量不变，通常会将原始的词嵌入维度 &lt;code&gt;embedding_dim&lt;/code&gt; 均分给 $h$ 个头。例如，如果 &lt;code&gt;embedding_dim=512&lt;/code&gt;，有 &lt;code&gt;h=8&lt;/code&gt; 个头，那么每个头产生的 Q, K, V 向量维度就是 &lt;code&gt;d_k = d_v = 512 / 8 = 64&lt;/code&gt;。计算时，先将输入 $X$ 分别投影到 $h$ 组低维的 Q, K, V 向量，并行计算后，再将结果拼接并投影回 &lt;code&gt;embedding_dim&lt;/code&gt; 维度。&lt;/p&gt;
&lt;h3 id="21-pytorch-实现多头注意力"&gt;2.1 PyTorch 实现多头注意力
&lt;/h3&gt;&lt;p&gt;多头注意力是通过并行运行多个独立的自注意力“头”，并融合它们的输出来增强模型的表达能力。一个低效的实现是简单地创建多个 &lt;code&gt;SelfAttention&lt;/code&gt; 实例并拼接结果。而高效的实现则是将多个头的计算逻辑合并到一次矩阵运算中。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadSelfAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;多头自注意力模块&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MultiHeadSelfAttention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;hidden_size 必须能被 num_heads 整除&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 拆分多头&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 并行计算注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attention_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 合并多头结果&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contiguous&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 输出层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;head_dim&lt;/code&gt;：计算出每个头的维度，即 &lt;code&gt;hidden_size / num_heads&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;q_linear, k_linear, v_linear&lt;/code&gt;：与单头类似，但这里的线性层输出维度仍然是 &lt;code&gt;hidden_size&lt;/code&gt;。这是为了&lt;strong&gt;一次性计算出所有头所需的总特征&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;wo&lt;/code&gt;：对应于多头注意力机制中的输出权重矩阵 $W^O$，用于融合所有头的信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;线性变换&lt;/strong&gt;: 与单头版本相同，得到总的 Q, K, V 矩阵。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;拆分多头&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;.view(batch_size, seq_len, self.num_heads, self.head_dim)&lt;/code&gt;: 首先，将 &lt;code&gt;hidden_size&lt;/code&gt; 维度逻辑上拆分为 &lt;code&gt;num_heads&lt;/code&gt; 和 &lt;code&gt;head_dim&lt;/code&gt; 两个维度。此时张量形状变为 &lt;code&gt;[batch, seq_len, num_heads, head_dim]&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;.transpose(1, 2)&lt;/code&gt;: 然后，交换 &lt;code&gt;seq_len&lt;/code&gt; 和 &lt;code&gt;num_heads&lt;/code&gt; 维度，得到 &lt;code&gt;[batch, num_heads, seq_len, head_dim]&lt;/code&gt;。这一步是为了让 &lt;code&gt;num_heads&lt;/code&gt; 成为一个类似批次 (batch) 的维度，使得后续的矩阵乘法可以在每个头内部独立、并行地进行。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;并行计算注意力&lt;/strong&gt;: &lt;code&gt;torch.matmul(q, k.transpose(-2, -1))&lt;/code&gt; 现在是一个四维张量的乘法。PyTorch 会自动地将其解释为在第 0 和第 1 维（&lt;code&gt;batch&lt;/code&gt; 和 &lt;code&gt;num_heads&lt;/code&gt;）上进行批处理，而对最后两个维度执行矩阵乘法。这样就实现了所有头的注意力分数计算的并行化。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;合并多头&lt;/strong&gt;: 这是拆分操作的逆过程。
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;.transpose(1, 2)&lt;/code&gt;: 先将 &lt;code&gt;num_heads&lt;/code&gt; 和 &lt;code&gt;seq_len&lt;/code&gt; 维度换回来，形状变为 &lt;code&gt;[batch, seq_len, num_heads, head_dim]&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;.contiguous()&lt;/code&gt;: 由于 &lt;code&gt;transpose&lt;/code&gt; 操作可能导致张量在内存中不是连续存储的，需要调用 &lt;code&gt;.contiguous()&lt;/code&gt; 来确保内存连续，之后才能安全地使用 &lt;code&gt;.view()&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;.view(batch_size, seq_len, self.hidden_size)&lt;/code&gt;: 最后，将 &lt;code&gt;num_heads&lt;/code&gt; 和 &lt;code&gt;head_dim&lt;/code&gt; 两个维度重新合并成 &lt;code&gt;hidden_size&lt;/code&gt; 维度，完成了所有头输出的拼接。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;输出投影&lt;/strong&gt;: 将合并后的结果通过 &lt;code&gt;wo&lt;/code&gt; 线性层，得到最终输出。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id="三transformer-整体结构"&gt;三、Transformer 整体结构
&lt;/h2&gt;&lt;p&gt;理解了自注意力和多头注意力之后，就可以从一个更高的视角来审视 Transformer 的整体结构了。通过图 4-3 可以看出它依然是一个 Encoder-Decoder 架构，但其内部是由几个标准化的“积木”堆叠而成的。&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/4_3_1.svg" width="40%" alt="Transformer 架构" /&gt;
&lt;p&gt;图 4-3 Transformer 架构&lt;/p&gt;
&lt;/div&gt;
&lt;p&gt;Transformer 的 Encoder 和 Decoder 都是由 N 个（原论文中 N=6）功能相同的层（Layer）堆叠而成。下面我们分别来看它们的内部构造。&lt;/p&gt;
&lt;h3 id="31-编码器encoder"&gt;3.1 编码器（Encoder）
&lt;/h3&gt;&lt;p&gt;编码器的作用是“理解”和“消化”输入的整个序列，为序列中的每个词元生成一个富含上下文信息的表示。一个标准的&lt;strong&gt;编码器层&lt;/strong&gt;由两个主要的&lt;strong&gt;子层&lt;/strong&gt;构成，分别是&lt;strong&gt;多头自注意力层（Multi-Head Self-Attention Layer）&lt;strong&gt;和&lt;/strong&gt;位置前馈网络（Position-wise Feed-Forward Network）&lt;/strong&gt;。每个子层的输出都经过了**残差连接（Add）&lt;strong&gt;与&lt;/strong&gt;层归一化（Norm）**处理。所以，一个编码器层内部的数据流可以表示为 &lt;code&gt;x -&amp;gt; Sublayer1(x) -&amp;gt; Add &amp;amp; Norm -&amp;gt; Sublayer2(...) -&amp;gt; Add &amp;amp; Norm&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;关键特性&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;注意力类型&lt;/strong&gt;：编码器中的多头注意力层是&lt;strong&gt;双向的自注意力&lt;/strong&gt;。这意味着在计算时，序列中的任何一个词元都可以“看到”序列中的所有其他词元（包括它自己、它前面的和它后面的）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;功能&lt;/strong&gt;：由于其双向性，编码器非常擅长&lt;strong&gt;理解&lt;/strong&gt;完整的输入文本，并为每个词元生成一个深度融合了上下文信息的表示。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;应用&lt;/strong&gt;：通过大量堆叠编码器层而构建的模型（Encoder-Only 架构），如 &lt;strong&gt;BERT&lt;/strong&gt;，在文本分类、命名实体识别等自然语言理解（NLU）任务上取得了巨大成功。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="32-解码器decoder"&gt;3.2 解码器（Decoder）
&lt;/h3&gt;&lt;p&gt;解码器的作用是基于编码器对原始输入的理解，并结合已经生成的部分，来逐个生成下一个词元。为了完成这个更复杂的任务，一个标准的&lt;strong&gt;解码器层（Decoder Layer）&lt;strong&gt;比编码器层多了一个注意力子层，总共包含&lt;/strong&gt;三个子层&lt;/strong&gt;。分别是&lt;strong&gt;带掩码的多头自注意力层（Masked Multi-Head Self-Attention Layer）&lt;/strong&gt;、&lt;strong&gt;交叉注意力层（Cross-Attention Layer）&lt;strong&gt;和&lt;/strong&gt;位置前馈网络（Position-wise Feed-Forward Network）&lt;/strong&gt;。同样，解码器的每个子层也都采用了残差连接和层归一化。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;关键特性&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;子层 1：带掩码的自注意力&lt;/strong&gt;
&lt;ul&gt;
&lt;li&gt;与编码器层相比，这是解码器的第一个主要区别。解码器在生成序列时必须是&lt;strong&gt;自回归&lt;/strong&gt;的，即在生成第 $t$ 个词元时，只能依赖于已经生成的前 $t-1$ 个词元，而不能“看到”未来的信息。&lt;/li&gt;
&lt;li&gt;为了在并行的自注意力计算中实现这一点，需要引入&lt;strong&gt;掩码 (Masking)&lt;/strong&gt;。在计算 Softmax 之前，一个“未来词元掩码”会被应用到注意力分数上，将所有未来位置的分数设置为一个极小的负数（如 &lt;code&gt;-inf&lt;/code&gt;），这样在经过 Softmax 之后，这些位置的注意力权重就会变为 0，从而确保了模型的单向性。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;子层 2：交叉注意力&lt;/strong&gt;
&lt;ul&gt;
&lt;li&gt;这是连接编码器和解码器的桥梁。这一层的实现与上一节中描述的交叉注意力一致。&lt;/li&gt;
&lt;li&gt;它的 &lt;strong&gt;Key 和 Value 来自于编码器的最终输出&lt;/strong&gt;，而 &lt;strong&gt;Query 则来自于解码器前一个子层（即带掩码的自注意力层）的输出&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;这一层允许解码器在生成每个词元时，能够“关注”到输入序列的所有部分，从而有针对性地提取所需信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;子层 3：逐位置前馈网络 (FFN)&lt;/strong&gt;
&lt;ul&gt;
&lt;li&gt;这部分与编码器中的 FFN 相同，为模型提供非线性变换能力。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;大模型应用&lt;/strong&gt;：通过大量堆叠解码器层而构建的模型（Decoder-Only 架构），如 &lt;strong&gt;GPT&lt;/strong&gt; 系列，由于其天然的自回归生成能力，引领了当前大语言模型（LLM）的发展浪潮。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="33-组件解析"&gt;3.3 组件解析
&lt;/h3&gt;&lt;p&gt;下面来详细解析一下构成上述“层”的几个重要组件。&lt;/p&gt;
&lt;h4 id="331-位置前馈网络-ffn"&gt;3.3.1 位置前馈网络 (FFN)
&lt;/h4&gt;&lt;p&gt;这是一个由两次线性变换和一个激活函数组成的全连接网络，它独立地应用于序列中的&lt;strong&gt;每一个位置&lt;/strong&gt;。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;作用&lt;/strong&gt;：特征变换。注意力子层内部主要包含 Softmax 归一化；逐位置的非线性主要由 FFN 提供（常见 ReLU/GELU&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;内部结构&lt;/strong&gt;：其常见的结构是“升维-激活-降维”。&lt;/p&gt;
$$
\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
$$&lt;ul&gt;
&lt;li&gt;第一次线性变换（ $W_1$ ）通常会将输入维度 &lt;code&gt;embedding_dim&lt;/code&gt; 放大到 4 倍（&lt;code&gt;4 * embedding_dim&lt;/code&gt;）。这种&lt;strong&gt;升维&lt;/strong&gt;操作的作用是将特征投影到一个更高维的空间，以便提取更丰富、更复杂的模式。&lt;/li&gt;
&lt;li&gt;使用一个激活函数（如 ReLU）进行非线性处理。&lt;/li&gt;
&lt;li&gt;第二次线性变换（ $W_2$ ）再将维度从 &lt;code&gt;4 * embedding_dim&lt;/code&gt; 压缩回原始的 &lt;code&gt;embedding_dim&lt;/code&gt;。这种&lt;strong&gt;降维&lt;/strong&gt;操作可以看作是对高维特征的筛选和压缩，保留最重要的信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;将中间层维度设为 4 倍的做法，主要是继承自原始论文的经验设定，并因其良好效果而被后续模型广泛采用，并非有严格的理论证明。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4 id="332-残差连接与层归一化-add--norm"&gt;3.3.2 残差连接与层归一化 (Add &amp;amp; Norm)
&lt;/h4&gt;&lt;p&gt;为了让这些层能够成功地“堆叠”起来，每个子层的后面都连接了这个组合。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Add (残差连接)&lt;/strong&gt;：解决了深度网络的“模型退化”问题。从反向传播的角度看，子层的输出可以写成 $y = x + \text{Sublayer}(x)$。在计算梯度时， $\frac{\partial y}{\partial x} = 1 + \frac{\partial \text{Sublayer}(x)}{\partial x}$。其中“1”的存在为梯度创建了一条“高速公路”，确保无论网络有多深，梯度都能至少以大小为 1 的程度回传到最浅层，极大地稳定了训练过程。同时，这也要求模型中所有子层的输入和输出维度必须保持一致，以便进行元素相加。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Norm (层归一化)&lt;/strong&gt; &lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;：用于稳定训练过程。它独立地对&lt;strong&gt;每个样本的每个词元&lt;/strong&gt;的特征向量（即 &lt;code&gt;hidden_size&lt;/code&gt; 维度）进行标准化，使其均值变为 0，方差变为 1（但这并不假设其原始分布为正态分布）。更重要的是，它引入了两个可学习的参数 $\gamma$（缩放）和 $\beta$（偏移），让模型可以自主学习最佳的数据分布，其完整公式为：&lt;/p&gt;
$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$&lt;p&gt;这使得模型既能享受到归一化带来的稳定性，又具备了根据任务需要恢复或调整原始分布的能力。与主要用于计算机视觉的 Batch Normalization（对一个批次中所有样本的同一特征进行归一化）相比，Layer Normalization 不受批次大小的影响，更适合处理长度可变的自然语言序列。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个设计使得模型可以通过简单地增加“层”的数量（即深度）和特征维度（即宽度）来进行扩展，为后来参数量巨大的语言模型奠定了基础。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;原论文采用 Post-LN（Sublayer → Add → LayerNorm）。许多现代实现（如 GPT 系列）采用 Pre-LN（LayerNorm → Sublayer → Add），训练更稳定、更易加深，但功能等价。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="34-位置编码"&gt;3.4 位置编码
&lt;/h3&gt;&lt;p&gt;自注意力机制的主要缺陷在于其 &lt;strong&gt;位置无关性&lt;/strong&gt;。由于计算是完全并行的，模型无法感知词元的顺序。例如，“猫追狗”和“狗追猫”这两个句子，在自注意力看来，它们的词元集合完全相同，因此会为“猫”和“狗”生成相同的上下文表示，这显然是错误的。为了解决这个问题，Transformer 在将词嵌入向量输入模型之前，为它们加入了一个 &lt;strong&gt;位置编码 (Positional Encoding)&lt;/strong&gt; 向量。其工作方式非常直接：&lt;/p&gt;
$$ input_\text{embedding} = token_\text{embedding} + positional_\text{encoding} $$&lt;p&gt;这个额外注入的向量为每个词元提供了其在序列中的位置信息。这是一种 &lt;strong&gt;绝对位置编码&lt;/strong&gt;，即每个位置（如第 0、1、2 个位置）都有一个固定的编码向量。在实践中，主要有两种实现方式：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;可学习的位置编码 (Learned Positional Encoding)&lt;/strong&gt;
- 在 Encoder-only 模型（如 BERT）中常见；而近年的大型解码器式模型多采用相对/旋转类位置编码（如 RoPE&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;）。
- 其实现非常简单：创建一个 &lt;code&gt;nn.Embedding&lt;/code&gt; 层，大小为 &lt;code&gt;[max_sequence_length, hidden_size]&lt;/code&gt;。&lt;code&gt;max_sequence_length&lt;/code&gt; 是模型能处理的最大序列长度，这是一个重要的超参数（在很多模型配置文件中被称为 &lt;code&gt;max_position_embeddings&lt;/code&gt;）。在训练时，模型会像学习词嵌入一样，自动学习出每个位置（0, 1, 2, &amp;hellip;）最合适的向量表示。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;基于三角函数的固定编码 (Sinusoidal Positional Encoding)&lt;/strong&gt;
- 这是原版 Transformer 论文中使用的方法，它不需要学习。
- 使用不同频率的正弦和余弦函数来为每个位置生成一个独特的、固定的编码向量：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt; $$
PE_{(\text{pos}, 2i)} = \sin\left(\frac{\text{pos}}{10000^{\frac{2i}{d_{\text{embedding}}}}}\right) \\
PE_{(\text{pos}, 2i+1)} = \cos\left(\frac{\text{pos}}{10000^{\frac{2i}{d_{\text{embedding}}}}}\right)
$$
其中：
- $pos$ 是词元在序列中的绝对位置（如第 0 个、第 1 个词...）。
- $i$ 是编码向量中的维度索引（从 0 到 $d_{embedding}$/2）。公式通过 $i$ 来同时计算偶数维度 $2i$ 和奇数维度 $2i+1$ 的值，因此 $i$ 的取值范围只需达到维度总数的一半。
- $d_{embedding}$ 是词嵌入的维度。
公式利用不同频率的正弦和余弦函数，为 $d_{embedding}$ 维编码向量的每一个维度（$2i$ 对应偶数位，$2i+1$ 对应奇数位）计算一个特定的值。由于每个位置 $pos$ 和每个维度 $i$ 的组合都是独一无二的，所以这种方法能为序列中的每个位置生成一个完全独特的编码向量。这种方法的优势是不同位置的编码向量之间存在固定的线性关系，这可能有助于模型推断出词元间的相对位置。其主要优点是不需要训练，并且理论上可以外推到比训练时遇到的更长的序列。
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;绝对 vs. 相对位置编码&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;上述两种方法都属于&lt;strong&gt;绝对位置编码&lt;/strong&gt;，因为它们为每个绝对位置（第 1 个、第 10 个等）分配一个特定的编码。然而，这种方式在处理超长文本时可能存在泛化性问题。因此，许多现代的大语言模型（如 Transformer-XL, Llama）转而采用&lt;strong&gt;相对位置编码&lt;/strong&gt;。这种方法不再关注词元的绝对位置，而是直接在注意力计算中建模词元之间的相对距离（例如，“当前词”与“前 2 个词”之间的关系），这被证明在处理长序列时更有效、更灵活。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="35-注意力掩码"&gt;3.5 注意力掩码
&lt;/h3&gt;&lt;p&gt;掩码是 Transformer 模型中一个重要的机制。其&lt;strong&gt;主要目的&lt;/strong&gt;是确保解码器在生成序列时的&lt;strong&gt;自回归&lt;/strong&gt;特性，即不能“看到”未来的信息。此外，作为一个&lt;strong&gt;通用的工程实践&lt;/strong&gt;，掩码也被用来处理批量训练中因句子长度不同而引入的填充（Padding）问题。Transformer 主要使用以下两种掩码：&lt;/p&gt;
&lt;p&gt;（1）因果掩码：因果掩码专用于解码器的&lt;strong&gt;带掩码的自注意力（Masked Self-Attention）&lt;strong&gt;子层，是为了确保解码过程遵循&lt;/strong&gt;自回归（Auto-regressive）&lt;strong&gt;特性，即生成第 $i$ 个词元时只能依赖前 $i-1$ 个词元的信息，而绝不能“偷看”到 $i$ 及之后位置的内容。它的实现核心是确保注意力权重矩阵呈现&lt;/strong&gt;下三角矩阵&lt;/strong&gt;的形态。对于长度为 $T$ 的序列，在 $[T, T]$ 的矩阵中，主对角线及以下的位置被标记为可关注（如 &lt;code&gt;True&lt;/code&gt; 或 0），而主对角线以上的位置则被标记为屏蔽（如 &lt;code&gt;False&lt;/code&gt; 或 1）。在计算 Softmax 之前，所有被屏蔽位置的注意力分数会被加上一个极大的负数（如 &lt;code&gt;-inf&lt;/code&gt;），迫使其注意力权重归零，从而物理上切断了信息的向后传播路径。&lt;/p&gt;
&lt;p&gt;（2）填充掩码：填充掩码广泛应用于编码器和解码器的所有注意力层，目的是解决变长序列批量处理时的**填充（Padding）**问题。由于填充词元（如 &lt;code&gt;&amp;lt;pad&amp;gt;&lt;/code&gt;）本身不携带语义信息，若模型对其分配注意力，不仅浪费计算资源，还会引入噪声干扰。填充掩码的作用就是在计算注意力分数后，将所有涉及填充词元的位置（无论是作为查询 Query 还是作为键 Key）的对应分数强制设为极大的负数（如 &lt;code&gt;-1e9&lt;/code&gt; 或负无穷）。假设有一个注意力分数矩阵，维度为 &lt;code&gt;[batch_size, num_heads, seq_len, seq_len]&lt;/code&gt;。填充掩码会是一个 &lt;code&gt;[batch_size, 1, 1, seq_len]&lt;/code&gt; 的矩阵（或可广播的形状），标记了哪些位置是填充。在进行 Softmax 之前，这个掩码会被加到分数矩阵上。经过 Softmax 运
算后，这些负无穷位置的注意力权重会趋近于 0，从而在后续的加权求和中被完全忽略。&lt;/p&gt;
&lt;p&gt;在解码器的自注意力层中，这两种掩码通常会结合使用，确保模型既不会关注到未来的信息，也不会关注到填充位。&lt;/p&gt;
&lt;h3 id="36-解码器推理与-kv-缓存"&gt;3.6 解码器推理与 KV 缓存
&lt;/h3&gt;&lt;p&gt;解码器在训练和推理时的行为有很大不同。训练时，模型可以看到完整的“正确答案”序列，并通过注意力掩码来并行计算所有位置的损失。然而，在&lt;strong&gt;推理&lt;/strong&gt;时，模型必须逐个生成词元，这是一个&lt;strong&gt;自回归&lt;/strong&gt;的过程：&lt;/p&gt;
&lt;p&gt;（1）输入 &lt;code&gt;[BOS]&lt;/code&gt;（开始符），生成第一个词 &lt;code&gt;token_1&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（2）输入 &lt;code&gt;[BOS], token_1&lt;/code&gt;，生成第二个词 &lt;code&gt;token_2&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（3）输入 &lt;code&gt;[BOS], token_1, token_2&lt;/code&gt;，生成第三个词 &lt;code&gt;token_3&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（4）&amp;hellip; 直到生成 &lt;code&gt;[EOS]&lt;/code&gt;（结束符）或达到最大长度。&lt;/p&gt;
&lt;p&gt;如果按照这个流程直接计算，效率会非常低下。例如，在生成 &lt;code&gt;token_3&lt;/code&gt; 时，模型需要为 &lt;code&gt;[BOS]&lt;/code&gt; 和 &lt;code&gt;token_1&lt;/code&gt; 重新计算它们的 Q, K, V 向量并参与注意力计算。但事实上，&lt;code&gt;[BOS]&lt;/code&gt; 和 &lt;code&gt;token_1&lt;/code&gt; 的 Key 和 Value 向量在之前的步骤中已经被计算过了。&lt;/p&gt;
&lt;p&gt;为解决这种冗余计算，推理时会采用一项关键的优化技术：&lt;strong&gt;KV 缓存&lt;/strong&gt;。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;基本原理&lt;/strong&gt;：对于解码器的每一层，都缓存下截至当前时刻已经计算出的所有词元的 &lt;strong&gt;Key&lt;/strong&gt; 和 &lt;strong&gt;Value&lt;/strong&gt; 向量。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;工作流程&lt;/strong&gt;：在生成第 $t$ 个词元时，模型只需要为当前输入的第 $t-1$ 个词元计算出它自己的 $q_{t-1}, k_{t-1}, v_{t-1}$。然后，它从缓存中取出历史的 $K_{cache} = [k_0, k_1, &amp;hellip;, k_{t-2}]$ 和 $V_{cache} = [v_0, v_1, &amp;hellip;, v_{t-2}]$。最后，将新的 $k_{t-1}, v_{t-1}$ 追加到缓存中，并用 $q_{t-1}$ 与更新后的完整 $K_{cache}, V_{cache}$ 进行注意力计算。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;通过 KV 缓存，每次解码步骤的计算量从与整个已生成序列长度的平方（$O(T^2)$）相关，降低到只与序列长度（$O(T)$）线性相关，极大地加速了文本生成的速度，是实现高效大模型推理的常用技术之一。需要注意，KV 缓存占用会随步数线性增长（$O(T)$），在多层多头设置下需关注显存开销。&lt;/p&gt;
&lt;h2 id="四transformer-代码实践"&gt;四、Transformer 代码实践
&lt;/h2&gt;&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/datawhalechina/base-nlp/tree/main/code/C4/transformer" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="41-项目结构设计"&gt;4.1 项目结构设计
&lt;/h3&gt;&lt;p&gt;为了更好地理解 Transformer 的内部工作机制，接下来尝试从零实现一个完整的 Transformer 模型。我们会采用**“先整体框架，后组件实现”&lt;strong&gt;的思路，拆分多个文件来构建项目。在前面我们详细分析了 Transformer 的几大核心组件，分别是&lt;/strong&gt;位置编码**、&lt;strong&gt;多头注意力&lt;/strong&gt;、&lt;strong&gt;前馈网络 &lt;strong&gt;以及&lt;/strong&gt;归一化&lt;/strong&gt;。为了体现这些组件的独立性和复用性，我们将遵循&lt;strong&gt;模块化&lt;/strong&gt;的设计原则，将它们拆分到 &lt;code&gt;src/&lt;/code&gt; 目录下的独立文件中，而将模型的组装和运行逻辑放在根目录的 &lt;code&gt;main.py&lt;/code&gt; 中。目录设计如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-text" data-lang="text"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;code/C4/transformer/
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;├── src/
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── transformer.py # 核心框架：定义 Transformer、EncoderLayer 和 DecoderLayer
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── attention.py # 核心组件：多头注意力机制 (MultiHeadAttention)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── ffn.py # 核心组件：前馈神经网络 (FeedForward)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── norm.py # 辅助组件：层归一化 (LayerNorm)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ └── pos.py # 辅助组件：位置编码 (PositionalEncoding)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;└── main.py # 入口脚本：组装模型并演示前向传播
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="42-搭建整体框架"&gt;4.2 搭建整体框架
&lt;/h3&gt;&lt;p&gt;在开始编写具体的注意力机制或前馈网络之前，我们可以先在 &lt;code&gt;src/transformer.py&lt;/code&gt; 中勾勒出模型的高层架构。这种**“自顶向下”**的编程方式有助于我们理清数据流向。通过前面的学习我们知道，Transformer 宏观上是一个 Encoder-Decoder 架构，所以首先要实现的主要是以下几个部分：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Embedding 层&lt;/strong&gt;：将输入的 token ID 转换为连续的向量表示，并加上位置编码以保留序列顺序信息。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Encoder 堆叠&lt;/strong&gt;：由 $N$ 个 &lt;code&gt;EncoderLayer&lt;/code&gt; 串联而成，负责深度提取和理解输入序列的特征。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Decoder 堆叠&lt;/strong&gt;：由 $N$ 个 &lt;code&gt;DecoderLayer&lt;/code&gt; 串联而成，负责基于 Encoder 的输出逐步生成目标序列。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Output 层&lt;/strong&gt;：一个线性层，将解码器的最终输出映射回词表大小，用于计算下一个词的概率分布。&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# src/transformer.py&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.pos&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt; &lt;span class="c1"&gt;# 稍后实现&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# ... 导入其他组件&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 嵌入层与位置编码&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# src_embedding: 将源语言序列映射为向量 (Encoder输入)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;src_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# tgt_embedding: 将目标语言序列映射为向量 (Decoder输入)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tgt_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 编码器与解码器堆叠&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 使用 ModuleList 来存储层列表，支持按索引访问和自动注册参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ModuleList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ModuleList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 输出头&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 生成掩码 (Padding Mask &amp;amp; Causal Mask)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;generate_mask&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 编码器前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 解码器前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dec_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 输出 Logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dec_output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;有了这个骨架，接下来的任务就是填充 &lt;code&gt;EncoderLayer&lt;/code&gt; 和 &lt;code&gt;DecoderLayer&lt;/code&gt;，而它们又依赖于更底层的组件。&lt;/p&gt;
&lt;h3 id="43-实现核心组件"&gt;4.3 实现核心组件
&lt;/h3&gt;&lt;p&gt;（1）位置编码 (src/pos.py)&lt;/p&gt;
&lt;p&gt;在 &lt;code&gt;src/transformer.py&lt;/code&gt; 中我们引入了 &lt;code&gt;PositionalEncoding&lt;/code&gt;，它是 Transformer 处理序列顺序的关键。这里我们实现论文中的正弦位置编码。位置编码的核心在于初始化阶段，我们会预先计算好一个足够长的编码矩阵。它的计算公式使用了不同频率的正弦和余弦函数：&lt;/p&gt;
$$
PE(pos, 2i) = \sin(pos / 10000^{2i/d_{model}}) \\
PE(pos, 2i+1) = \cos(pos / 10000^{2i/d_{model}})
$$&lt;p&gt;在 &lt;code&gt;__init__&lt;/code&gt; 方法中，我们一次性生成这个矩阵，并将其注册为 buffer。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 正弦位置编码
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Transformer 论文中使用固定公式计算位置编码，不涉及可学习参数。
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 创建一个足够长的 PE 矩阵 [max_seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 生成位置索引 [0, 1, ..., max_seq_len-1] -&amp;gt; [max_seq_len, 1]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 计算分母中的 div_term: 10000^(2i/dim) = exp(2i * -log(10000)/dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 这种对数变换的计算方式在数值上更稳定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;div_term&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;10000.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 填充 PE 矩阵&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 偶数维度用 sin，奇数维度用 cos&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 增加 batch 维度: [1, max_seq_len, dim] 以便广播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 注册为 buffer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# register_buffer 的作用是告诉 PyTorch：&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. &amp;#39;pe&amp;#39; 是模型状态的一部分，会随模型保存和加载 (state_dict)。&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. &amp;#39;pe&amp;#39; 不是模型参数 (Parameter)，优化器更新时不会更新它。&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;register_buffer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;pe&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;在前向传播中，我们的任务就是将位置编码加到输入的词嵌入上。由于我们预先生成的 &lt;code&gt;pe&lt;/code&gt; 矩阵可能比当前的输入序列 &lt;code&gt;x&lt;/code&gt; 要长，所以需要根据 &lt;code&gt;x&lt;/code&gt; 的实际长度对 &lt;code&gt;pe&lt;/code&gt; 进行切片。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Args:
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; x: 输入的词嵌入序列 [batch_size, seq_len, dim]
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Returns:
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 加上位置编码后的序列 [batch_size, seq_len, dim]
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 截取与输入序列长度对应的位置编码并相加&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x.size(1) 是 seq_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# self.pe 的形状是 [1, max_seq_len, dim]，切片后会自动广播到 batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;最后，我们可以编写一段简单的测试代码来验证维度是否正确。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 输入为0，直接观察PE值&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- PositionalEncoding Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- PositionalEncoding Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（2）多头注意力 (src/attention.py)&lt;/p&gt;
&lt;p&gt;这是 Transformer 中最复杂的组件，用于从不同的“表示子空间”中提取信息。在初始化阶段，我们需要定义四个主要的线性层：&lt;code&gt;wq&lt;/code&gt;, &lt;code&gt;wk&lt;/code&gt;, &lt;code&gt;wv&lt;/code&gt; 用于将输入投影到 Q, K, V 空间，&lt;code&gt;wo&lt;/code&gt; 用于将多头注意力的输出投影回原始维度。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# src/attention.py&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 定义 Wq, Wk, Wv 矩阵&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 这里我们使用一个大的线性层一次性计算所有头的 Q, K, V&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 最终输出的线性层 Wo&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;这部分前向传播的重点是“分头”操作。我们不直接对 [batch, seq_len, dim] 进行计算，而是将其 reshape 为 [batch, n_heads, seq_len, head_dim]，这样就可以利用矩阵运算并行地处理所有头。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;span class="lnt"&gt;51
&lt;/span&gt;&lt;span class="lnt"&gt;52
&lt;/span&gt;&lt;span class="lnt"&gt;53
&lt;/span&gt;&lt;span class="lnt"&gt;54
&lt;/span&gt;&lt;span class="lnt"&gt;55
&lt;/span&gt;&lt;span class="lnt"&gt;56
&lt;/span&gt;&lt;span class="lnt"&gt;57
&lt;/span&gt;&lt;span class="lnt"&gt;58
&lt;/span&gt;&lt;span class="lnt"&gt;59
&lt;/span&gt;&lt;span class="lnt"&gt;60
&lt;/span&gt;&lt;span class="lnt"&gt;61
&lt;/span&gt;&lt;span class="lnt"&gt;62
&lt;/span&gt;&lt;span class="lnt"&gt;63
&lt;/span&gt;&lt;span class="lnt"&gt;64
&lt;/span&gt;&lt;span class="lnt"&gt;65
&lt;/span&gt;&lt;span class="lnt"&gt;66
&lt;/span&gt;&lt;span class="lnt"&gt;67
&lt;/span&gt;&lt;span class="lnt"&gt;68
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 线性投影&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# [batch, seq_len, dim] -&amp;gt; [batch, seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wk&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 分头 (Split Heads)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 变换形状: [batch, seq_len, n_heads, head_dim] &lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 然后转置: [batch, n_heads, seq_len, head_dim] 以便并行计算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 计算缩放点积注意力 (Scaled Dot-Product Attention)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# scores: [batch, n_heads, seq_len, seq_len]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 应用掩码 (Masking)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# mask == 0 的位置被填充为负无穷，Softmax 后变为 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;-inf&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 5. Softmax 与加权求和&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# context: [batch, n_heads, seq_len, head_dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 6. 合并多头 (Concat Heads)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# [batch, n_heads, seq_len, head_dim] -&amp;gt; [batch, seq_len, n_heads, head_dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# -&amp;gt; [batch, seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contiguous&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 7. 输出层投影&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 单元测试&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入 (Query, Key, Value 相同)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mha&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- MultiHeadAttention Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- MultiHeadAttention Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（3）前馈神经网络 (src/ffn.py)&lt;/p&gt;
&lt;p&gt;标准的 Transformer FFN 是一个简单的两层全连接网络，中间包含激活函数。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# src/ffn.py&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 升维&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 降维&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 线性变换 -&amp;gt; ReLU -&amp;gt; Dropout -&amp;gt; 线性变换&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2048&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ffn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ffn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- FeedForward Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- FeedForward Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（4）层归一化 (src/norm.py)&lt;/p&gt;
&lt;p&gt;层归一化 (Layer Normalization) 是 Transformer 中用来稳定训练的组件。与 Batch Normalization 不同，它是在&lt;strong&gt;最后一个维度&lt;/strong&gt;（即特征维度 &lt;code&gt;dim&lt;/code&gt;）上进行归一化的。公式如下：&lt;/p&gt;
$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$&lt;p&gt;其中 $\gamma$ 和 $\beta$ 是可学习的缩放和平移参数。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 层归一化 (Layer Normalization)
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 公式: y = (x - mean) / sqrt(var + eps) * gamma + beta
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 可学习参数 gamma (缩放) 和 beta (偏移)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# nn.Parameter 会被自动注册为模型参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x: [batch_size, seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 在最后一个维度 (dim) 上计算均值和方差&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# keepdim=True 保持维度以便进行广播计算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# unbiased=False 使用有偏估计 (分母为 N)，与 PyTorch 默认行为一致&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;var&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;unbiased&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 归一化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;var&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 缩放和平移&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_norm&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 单元测试&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ln&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- LayerNorm Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- LayerNorm Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="44-组装与运行"&gt;4.4 组装与运行
&lt;/h3&gt;&lt;p&gt;（1）完善核心框架 (src/transformer.py)&lt;/p&gt;
&lt;p&gt;之前我们只搭建了 &lt;code&gt;Transformer&lt;/code&gt; 类的骨架，现在我们利用已经实现好的组件，按“编码器层 → 解码器层 → 辅助方法”的顺序来补全 &lt;code&gt;src/transformer.py&lt;/code&gt;。编码器层，这部分包含一个多头自注意力子层和一个前馈网络子层，每个子层后面都接残差连接和层归一化（Post-LN 结构），代码如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 导入组件&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.attention&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.ffn&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.norm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.pos&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 多头自注意力子层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前馈网络子层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 1：自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Q=K=V=x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 2：前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;接下来是解码器层，这部分比编码器层多了一个“交叉注意力”子层，先是带掩码的自注意力，再是对编码器输出的交叉注意力，最后是前馈网络。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 带掩码的自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 交叉注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 1：带掩码的自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 2：交叉注意力（Q 来自解码器，K/V 来自编码器输出）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 3：前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;最后在 &lt;code&gt;Transformer&lt;/code&gt; 主类中，我们需要补全相关的辅助方法。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2048&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# ... 初始化嵌入层、位置编码、编码器/解码器堆叠以及输出层等 ...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_init_parameters&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_init_parameters&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;xavier_uniform_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;generate_mask&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# src_mask: [batch, 1, 1, src_len]，pad token 假设为 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# tgt_mask: [batch, 1, tgt_len, tgt_len]，结合 pad mask 和 causal mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_pad_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# [batch, 1, 1, tgt_len]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_subsequent_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tgt_pad_mask&lt;/span&gt; &lt;span class="o"&gt;&amp;amp;&lt;/span&gt; &lt;span class="n"&gt;tgt_subsequent_mask&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;encode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;src_embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder_layers&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tgt_embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder_layers&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（2）运行主程序 (main.py)&lt;/p&gt;
&lt;p&gt;现在所有的零件都准备好了，我们可以在 &lt;code&gt;main.py&lt;/code&gt; 中将它们组装起来，并运行一个简单的前向传播测试。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;src.transformer&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Transformer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;main&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 超参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2048&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;50&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 实例化模型&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 模拟输入数据&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 随机生成 src 和 tgt 序列 (假设 pad_token_id=0)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 确保没有 pad token 影响简单测试，或者手动插入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_len&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Model Architecture:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# print(model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;Test Input:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Source Shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Target Shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;Model Output:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output Shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 预期 [batch_size, tgt_len, tgt_vocab_size]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;main&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Model Architecture:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Test Input:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Source Shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Target Shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 12&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Model Output:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output Shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 12, 100&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;hr&gt;
&lt;h2 id="参考文献"&gt;参考文献
&lt;/h2&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1706.03762" target="_blank" rel="noopener"
&gt;Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). &lt;em&gt;Attention Is All You Need&lt;/em&gt;. NeurIPS 2017.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1606.08415" target="_blank" rel="noopener"
&gt;Hendrycks, D., Gimpel, K. (2016). &lt;em&gt;Gaussian Error Linear Units (GELUs)&lt;/em&gt;.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1607.06450" target="_blank" rel="noopener"
&gt;Ba, J. L., Kiros, J. R., Hinton, G. E. (2016). &lt;em&gt;Layer Normalization&lt;/em&gt;.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/2104.09864" target="_blank" rel="noopener"
&gt;Su, J., Lu, Y., Pan, S., Wen, Y. (2021). &lt;em&gt;RoFormer: Enhanced Transformer with Rotary Position Embedding&lt;/em&gt;.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>注意力机制</title><link>https://hanguangwu.github.io/blog/p/%E6%B3%A8%E6%84%8F%E5%8A%9B%E6%9C%BA%E5%88%B6/</link><pubDate>Sat, 07 Feb 2026 19:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/%E6%B3%A8%E6%84%8F%E5%8A%9B%E6%9C%BA%E5%88%B6/</guid><description>&lt;h1 id="注意力机制"&gt;注意力机制
&lt;/h1&gt;&lt;p&gt;在上一节的结尾，讨论了标准 Seq2Seq 架构存在的一个核心缺陷：&lt;strong&gt;信息瓶颈&lt;/strong&gt;。编码器需要将源序列的所有信息，不论长短，全部压缩成一个固定长度的上下文向量 $C$。这种机制在处理长序列时，很容易丢失序列开头的关键信息，同时也无法让解码器在生成不同词元时，有选择性地关注输入的不同部分。&lt;/p&gt;
&lt;p&gt;用上一节提到的对联任务举例，当上联是“两个黄鹂鸣翠柳”时，期望模型在生成下联时：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;生成第一个词“一行”时，主要关注上联的“两个”。&lt;/li&gt;
&lt;li&gt;生成第二个词“白鹭”时，主要关注上联的“黄鹂”。&lt;/li&gt;
&lt;li&gt;&amp;hellip;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;但是标准的 Seq2Seq 架构的模型在生成“一行”、“白鹭”、“上青天”的每一个词时，所依赖的都是&lt;strong&gt;同一个、包含了整个上联概要&lt;/strong&gt;的上下文向量 $C$。模型缺乏一种动态的、有倾向性的“关注”能力。为了解决这个问题，&lt;strong&gt;注意力机制 (Attention Mechanism)&lt;/strong&gt; &lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;被提出。&lt;/p&gt;
&lt;h2 id="一注意力机制的设计原理"&gt;一、注意力机制的设计原理
&lt;/h2&gt;&lt;p&gt;注意力机制的原理，可以通俗地理解为从“一言以蔽之”到“择其要者而观之”的转变。人类在进行阅读理解或翻译时，并不会将整个句子或段落的信息平均地记在脑海里。当回答特定问题或翻译特定词组时，我们的注意力会自然地聚焦到原文中的相关部分。&lt;/p&gt;
&lt;p&gt;注意力机制就是对这种认知行为的模拟。它的原理是&lt;strong&gt;在解码器生成每一个词元时，不再依赖一个固定的上下文向量，而是允许它“回头看”一遍完整的输入序列，并根据当前解码的需求，自主地为输入序列的每个部分分配不同的注意力权重，然后基于这些权重将输入信息加权求和，生成一个动态的、专属当前时间步的上下文向量&lt;/strong&gt;。通过这种方式，模型便获得了“择其要者而观之”的能力：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;在生成“一行”时，模型可以学会将最大的权重分配给“两个”所对应的编码器状态。&lt;/li&gt;
&lt;li&gt;在生成“白鹭”时，则将最大的权重分配给“黄鹂”所对应的状态。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个动态计算的权重，就是&lt;strong&gt;注意力权重&lt;/strong&gt;；而整个动态计算上下文向量的过程，就是&lt;strong&gt;注意力机制&lt;/strong&gt;。&lt;/p&gt;
&lt;h2 id="二注意力机制的动机与推导"&gt;二、注意力机制的动机与推导
&lt;/h2&gt;&lt;p&gt;为了更直观地理解注意力机制的必要性，可以跟随一个逐步深入的思路。&lt;/p&gt;
&lt;h3 id="21-问题的根源与固定对齐策略的局限"&gt;2.1 问题的根源与固定对齐策略的局限
&lt;/h3&gt;&lt;p&gt;标准的 Seq2Seq 模型之所以表现不佳，根源是它试图将源序列的所有信息&lt;strong&gt;无差别&lt;/strong&gt;地压缩进一个向量。但在对联这类任务中，输入和输出之间存在着明显的&lt;strong&gt;局部对应关系&lt;/strong&gt;。一个直观的想法是能不能建立一种固定的对齐策略？例如，在生成下联第一个词时，就只使用上联第一个词的编码信息；生成第二个词时，就只用第二个词的信息，以此类推。&lt;/p&gt;
&lt;p&gt;这个想法可以表示为：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$C_1 = h_1$ (生成第一个词的上下文是第一个编码状态)&lt;/li&gt;
&lt;li&gt;$C_2 = h_2$ (生成第二个词的上下文是第二个编码状态)&lt;/li&gt;
&lt;li&gt;&amp;hellip;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这种方法在处理像对联这样长度相等、词序对应的“特例”时似乎是可行的。但它的局限性非常明显：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;要求序列等长&lt;/strong&gt;：对于不等长的序列（如中英文翻译），这种一对一的映射关系立刻失效。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;对齐关系僵化&lt;/strong&gt;：它假设了输入和输出的对齐关系是固定不变的，但实际任务中的对应关系可能非常复杂（如一对多、多对一）。&lt;/p&gt;
&lt;p&gt;这种固定对齐策略过于理想化，缺乏通用性。我们需要一种更灵活、更具普适性的方法。&lt;/p&gt;
&lt;h3 id="22-注意力机制的动态加权原理"&gt;2.2 注意力机制的动态加权原理
&lt;/h3&gt;&lt;p&gt;既然只取一个输入信息过于绝对，那么退一步，是否可以把所有输入信息都利用起来，但给它们分配不同的“重要性”呢？这就是通过动态加权进行对齐的思想，即&lt;strong&gt;加权求和&lt;/strong&gt;。我们可以为解码的第 $t$ 步，动态地计算一个上下文向量 $C_t$，它由编码器&lt;strong&gt;所有&lt;/strong&gt;的隐藏状态 $(h_1, h_2, \dots, h_{T_x})$ 加权求和得到：&lt;/p&gt;
$$
C_t = \sum_{j=1}^{T_x} \alpha_{tj} h_j
$$&lt;p&gt;其中， $\alpha_{tj}$ 就是在解码第 $t$ 个词时，分配给输入第 $j$ 个词的&lt;strong&gt;注意力权重&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;在这个思路下，前面提到的“固定对齐策略”可以看作是它的一个特例。例如，当 $\alpha_{11}=1$ 且其他所有 $\alpha_{1j}=0$ 时，就实现了 $C_1 = h_1$ 的效果。&lt;/p&gt;
&lt;h3 id="23-如何确定权重"&gt;2.3 如何确定权重？
&lt;/h3&gt;&lt;p&gt;加权求和的思路虽然灵活，但它引入了一个新的问题：&lt;strong&gt;权重 $\alpha_{tj}$ 从何而来？&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这个权重显然不能是固定的。它必须是&lt;strong&gt;动态的&lt;/strong&gt;，应该根据当前的解码需求来决定。例如，当解码器正要生成与“黄鹂”对应的词时，权重 $\alpha_{t, \text{黄鹂}}$ 就应该最大。所以我们需要一个额外的模块或机制，它能够：&lt;/p&gt;
&lt;p&gt;（1）审视当前解码器的状态（例如，解码器上一时刻的隐藏状态 $h^{\prime}_{t-1}$）。&lt;/p&gt;
&lt;p&gt;（2）将这个状态与编码器的每一个隐藏状态 $h_j$ 进行比较。&lt;/p&gt;
&lt;p&gt;（3）根据比较结果，生成一组相应的权重 $(\alpha_{t1}, \alpha_{t2}, \dots, \alpha_{t,T_x})$。&lt;/p&gt;
&lt;p&gt;让模型&lt;strong&gt;自行学习&lt;/strong&gt;如何根据当前上下文来计算这组权重，正是注意力机制的关键。&lt;/p&gt;
&lt;h2 id="三注意力机制详解"&gt;三、注意力机制详解
&lt;/h2&gt;&lt;p&gt;带有注意力机制的 Encoder-Decoder 模型，其整体结构与标准 Seq2Seq 类似，主要区别在于解码器部分。编码器的工作保持不变，但是需要向解码器提供&lt;strong&gt;所有时间步&lt;/strong&gt;的隐藏状态序列 $(h_1, h_2, \dots, h_{T_x})$，而不仅仅是最后一个时间步的状态。解码器在生成第 $t$ 个目标词元 $y_t$ 时，会通过三步进行“注意力计算”，来动态生成该时刻的上下文向量 $C_t$。这个过程通常以上一时刻的解码器隐藏状态 $h^{\prime}_{t-1}$ 为起点。&lt;/p&gt;
&lt;h3 id="31-注意力计算三部曲"&gt;3.1 注意力计算三部曲
&lt;/h3&gt;&lt;p&gt;（1）计算相似度&lt;/p&gt;
&lt;p&gt;使用解码器上一时刻的隐藏状态 $h^{\prime}_{t-1}$ 与编码器的每一个隐藏状态 $h_j$ 计算一个分数，这个分数衡量了在当前解码时刻，应当对第 $j$ 个输入词元投入多少“关注”。&lt;/p&gt;
$$
e_{tj} = \text{score}(h^{\prime}_{t-1}, h_j)
$$&lt;p&gt;这个分数越高，代表关联性越强。计算这个分数的方式有很多种，例如简单的点积、或者引入一个可学习的神经网络层。&lt;/p&gt;
&lt;p&gt;（2）计算注意力权重&lt;/p&gt;
&lt;p&gt;得到输入序列所有位置的注意力分数 $(e_{t1}, e_{t2}, \dots, e_{t,T_x})$ 后，为了将它们转换成一种“权重”的表示，可使用 &lt;strong&gt;Softmax&lt;/strong&gt; 函数对其进行归一化。这样，就能得到一组总和为 1、且均为正数的注意力权重 $(\alpha_{t1}, \alpha_{t2}, \dots, \alpha_{t,T_x})$。&lt;/p&gt;
$$
\alpha_{tj} = \text{softmax}(e_{tj}) = \frac{\exp(e_{tj})}{\sum_{i=1}^{T_x} \exp(e_{ti})}
$$&lt;p&gt;这组权重 $\alpha_t$ 构成了一个概率分布，清晰地表明了在当前解码步骤 $t$，注意力应该如何分配在输入序列的各个位置上。&lt;/p&gt;
&lt;p&gt;（3）加权求和，生成上下文向量&lt;/p&gt;
&lt;p&gt;最后，使用上一步得到的注意力权重 $\alpha_{tj}$，对编码器的所有隐藏状态 $h_j$ 进行加权求和，从而得到当前解码时刻 $t$ 专属的上下文向量 $C_t$。&lt;/p&gt;
$$
C_t = \sum_{j=1}^{T_x} \alpha_{tj} h_j
$$&lt;p&gt;这个 $C_t$ 向量，由于是根据当前解码需求动态生成的，它比原始 Seq2Seq 的那个固定向量 $C$ 包含了更具针对性的信息。&lt;/p&gt;
&lt;h3 id="32-结合上下文进行预测"&gt;3.2 结合上下文进行预测
&lt;/h3&gt;&lt;p&gt;得到动态上下文向量 $C_t$ 后，模型会将其与当前解码器自身的输入词元 $y_{t-1}$ 的词嵌入结合起来（最常见的方式是将两者&lt;strong&gt;拼接&lt;/strong&gt;），形成一个新的、信息更丰富的向量。&lt;/p&gt;
&lt;p&gt;最后，将这个拼接后的向量连同上一时刻的状态 $h^\prime_{t-1}$ 一起送入解码器的 RNN 单元，计算出当前时刻的状态 $h^\prime_{t}$，并基于 $h^\prime_{t}$ 预测出最有可能的输出词元 $y_t$。整个过程可以通过图 4-2 来概括：&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/4_2_1.svg" width="90%" alt="Attention 工作流程" /&gt;
&lt;p&gt;图 4-2 Attention 工作流程&lt;/p&gt;
&lt;/div&gt;
&lt;h3 id="33-一种高效的注意力打分函数"&gt;3.3 一种高效的注意力打分函数
&lt;/h3&gt;&lt;p&gt;计算相关性分数的函数有多种设计，其中一种非常高效的方法，是直接计算查询向量（ $h^\prime_{t-1}$ ）和键向量（ $h_j$ ）的&lt;strong&gt;点积&lt;/strong&gt;，并对其进行&lt;strong&gt;缩放&lt;/strong&gt;。这种思想也是后续通用注意力框架的核心。其计算方式非常简洁：&lt;/p&gt;
$$
\text{score}(h^\prime_{t-1}, h_j) = \frac{{h^\prime_{t-1}}^T \cdot h_j}{\sqrt{d_k}}
$$&lt;p&gt;其中：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;${h^\prime_{t-1}}^T \cdot h_j$ 就是两个向量的点积。点积是衡量向量相似度的一种有效方式。&lt;/li&gt;
&lt;li&gt;$d_k$ 是键向量（在这里是编码器隐藏状态）的维度。&lt;/li&gt;
&lt;li&gt;除以 $\sqrt{d_k}$ 是一个关键的缩放步骤。当向量维度 $d_k$ 很大时，点积的结果的方差也会很大，这可能导致一些维度的值非常大，从而将 Softmax 函数推向其梯度极小的区域（即概率值极端地趋近于 0 或 1），造成梯度消失，使模型难以训练。通过除以 $\sqrt{d_k}$ 进行缩放，可以有效缓解这个问题，使训练过程更加稳定。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="34-注意力机制的价值"&gt;3.4 注意力机制的价值
&lt;/h3&gt;&lt;p&gt;引入注意力机制，不仅仅是对 Seq2Seq 架构的一个小修补，它还带来了一个全新的视角。&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;克服信息瓶颈，提升性能&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;最直接的好处是，注意力机制彻底打破了信息必须被压缩成一个固定长度向量的限制。解码器在每一步都可以直接访问到源序列的全部信息，并根据需要动态聚焦。这使得模型在处理尤其是长序列时，性能得到了巨大的提升。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;提供可解释性，实现“词对齐”&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;注意力机制的另一个巨大价值在于它提供了很好的&lt;strong&gt;可解释性&lt;/strong&gt;。注意力权重矩阵 $\alpha$ 本身就蕴含了丰富的信息。可以将这个矩阵可视化，来观察当模型生成某个输出词时，它的“注意力”主要集中在输入的哪些词上。&lt;/p&gt;
&lt;h2 id="四查询-键-值-qkv-范式"&gt;四、查询-键-值 (QKV) 范式
&lt;/h2&gt;&lt;p&gt;为了将刚刚描述的注意力计算过程抽象出来，形成一个更通用的思想，可以引入一个概念框架，即 &lt;strong&gt;查询-键-值 （Query-Key-Value, QKV）&lt;/strong&gt; 。这个范式将注意力的计算过程类比为一次信息检索：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;查询（Query）&lt;/strong&gt;：代表了当前的需求或意图。在 Seq2Seq 中，这就是解码器在生成下一个词元前的状态 $h^\prime_{t-1}$，可以理解为它在“查询”：“根据我现在的情况，我最需要输入序列的哪部分信息？”&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;键（Key）&lt;/strong&gt;：可以看作是输入序列中各个信息片段的“标签”或“索引”，用于和查询进行匹配。在 Seq2Seq 中，输入序列的每个词元的隐藏状态 $h_j$ 都对应一个“键”。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;值（Value）&lt;/strong&gt;：是与“键”对应的实际信息内容。在基础的注意力机制中，“键”和“值”通常是相同的，都来自于编码器的隐藏状态 $h_j$。&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;无论形式如何变化，注意力机制的本质都可以概括为：通过&lt;strong&gt;查询（Q）&lt;/strong&gt; 和一系列&lt;strong&gt;键（K）&lt;/strong&gt; 计算相关性（权重），然后利用这个权重，对与各个键对应的&lt;strong&gt;值（V）&lt;/strong&gt; 进行加权求和，得到最终的输出。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;具体计算过程，可以用一个凝练的数学公式来统一表达，这就是 &lt;strong&gt;缩放点积注意力（Scaled Dot-Product Attention）&lt;/strong&gt; ，它也是 Transformer 模型的核心组件之一&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;：&lt;/p&gt;
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$&lt;p&gt;这个公式准确概括了注意力的计算步骤：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;$QK^T$&lt;/strong&gt;：计算查询矩阵 $Q$ 和键矩阵 $K$ 的转置的点积，得到原始的&lt;strong&gt;注意力分数&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;$\sqrt{d_k}$&lt;/strong&gt;：对分数进行缩放，以维持训练稳定性，其中 $d_k$ 是键向量的维度。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;softmax(&amp;hellip;)&lt;/strong&gt;：通过 Softmax 函数将分数归一化，得到&lt;strong&gt;注意力权重&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;（4）&lt;strong&gt;&amp;hellip;V&lt;/strong&gt;：将得到的权重矩阵与值矩阵 $V$ 相乘，进行&lt;strong&gt;加权求和&lt;/strong&gt;，得到最终的输出。&lt;/p&gt;
&lt;p&gt;这个通用的范式很实用，是理解后续 Transformer 等更先进模型的基础。在不同的任务中，只需要思考如何定义场景中的 Q、K、V 即可应用注意力机制。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;在刚刚讨论的 Seq2Seq 中：Q 是解码器状态，K 和 V 都是编码器状态序列。&lt;/li&gt;
&lt;li&gt;在后续的自注意力 (Self-Attention) 机制中，Q, K, V 将全部来源于同一个序列自身。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;此外，为了增加模型的表达能力，还可以在计算注意力之前，对原始的 Q, K, V 向量各自通过一个独立的全连接层进行&lt;strong&gt;线性变换&lt;/strong&gt;，得到新的 Q&amp;rsquo;, K&amp;rsquo;, V&amp;rsquo;，再用它们进行注意力的计算。这种做法可以让模型学习到在不同的“子空间”中进行信息匹配和聚合。&lt;/p&gt;
&lt;h2 id="五pytorch-实现与代码解析"&gt;五、PyTorch 实现与代码解析
&lt;/h2&gt;&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/FutureUnreal/base-nlp/blob/main/code/C4/02_attention.py" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="51-整体思路"&gt;5.1 整体思路
&lt;/h3&gt;&lt;p&gt;要在 PyTorch 中实现一个带注意力的 Seq2Seq 模型，需要对上节的代码进行一些关键的调整：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;编码器 &lt;code&gt;Encoder&lt;/code&gt;&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;forward&lt;/code&gt; 函数的返回值需要改变。除了最后一个时间步的隐藏状态 &lt;code&gt;(hidden, cell)&lt;/code&gt;，还需要返回&lt;strong&gt;所有时间步的输出&lt;/strong&gt; &lt;code&gt;outputs&lt;/code&gt;，这正是注意力机制计算所需要的 Key 和 Value。&lt;/li&gt;
&lt;li&gt;如果编码器是&lt;strong&gt;双向 (Bidirectional)&lt;/strong&gt; 的，其输出维度会是 &lt;code&gt;hidden_size * 2&lt;/code&gt;。那么就需要增加一个线性层对其进行降维，或将其状态进行合并，以便与单向的解码器状态维度相匹配。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（2）&lt;strong&gt;新增 &lt;code&gt;Attention&lt;/code&gt; 模块&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;创建一个独立的 &lt;code&gt;nn.Module&lt;/code&gt; 类来实现注意力的计算逻辑。其 &lt;code&gt;forward&lt;/code&gt; 方法接收解码器状态 (Query) 和编码器所有输出 (Keys/Values)，返回计算得到的上下文向量。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（3）&lt;strong&gt;解码器 &lt;code&gt;Decoder&lt;/code&gt;&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;解码器的结构变化是最大的。它需要实例化一个 &lt;code&gt;Attention&lt;/code&gt; 模块。&lt;/li&gt;
&lt;li&gt;其 &lt;code&gt;forward&lt;/code&gt; 函数通常&lt;strong&gt;以循环的方式&lt;/strong&gt;逐个时间步解码，因为在第 $t$ 步计算注意力时需要依赖第 $t-1$ 步的解码器状态。需要强调的是，RNN 解码本身就是按时间步顺序计算，不能在时间维度并行；Attention 并未改变这一点。相较于“整序列一次性送入 RNN”的写法，逐步解码更便于在每步显式计算注意力并灵活插入教师强制等策略。在循环的每一步，它都会调用 &lt;code&gt;Attention&lt;/code&gt; 模块计算上下文向量，并将其与当前词元的词嵌入融合后，再送入 RNN 单元。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="52-编码器"&gt;5.2 编码器
&lt;/h3&gt;&lt;p&gt;为支持 Attention，编码器通常使用双向 RNN 以捕获更丰富的上下文，并需要返回所有时间步的输出序列作为注意力计算的 Key 和 Value。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt; &lt;span class="c1"&gt;# 使用双向LSTM&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将双向RNN的输出通过线性层降维，使其与解码器维度匹配&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;bidirectional=True&lt;/code&gt;&lt;/strong&gt;：启用双向 LSTM，使原始 RNN &lt;code&gt;outputs&lt;/code&gt; 维度变为 &lt;code&gt;(batch, src_len, hidden_size * 2)&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;self.fc&lt;/code&gt;&lt;/strong&gt;：定义一个线性层，将拼接后的双向输出映射回 &lt;code&gt;hidden_size&lt;/code&gt; 维度；经过 &lt;code&gt;self.fc&lt;/code&gt; 和 &lt;code&gt;tanh&lt;/code&gt; 后，&lt;code&gt;outputs&lt;/code&gt; 维度回到 &lt;code&gt;(batch, src_len, hidden_size)&lt;/code&gt;，方便后续计算。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;return outputs, ...&lt;/code&gt;&lt;/strong&gt;：返回降维后的所有时间步输出 &lt;code&gt;outputs&lt;/code&gt; (作为后续的 K 和 V)，以及原始的最终状态 &lt;code&gt;hidden&lt;/code&gt; 和 &lt;code&gt;cell&lt;/code&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="53-注意力模块的两种实现"&gt;5.3 注意力模块的两种实现
&lt;/h3&gt;&lt;p&gt;这是模型的核心，我们通过两个版本的实现来体现具体的演进思路。&lt;/p&gt;
&lt;h4 id="531-无参数的注意力"&gt;5.3.1 无参数的注意力
&lt;/h4&gt;&lt;p&gt;这个版本直接使用缩放点积来计算注意力，不引入额外的可学习参数，对应了注意力机制最基础的数学思想。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;AttentionSimple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;1: 无参数的注意力模块&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;AttentionSimple&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 确保缩放因子是一个 non-learnable buffer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;register_buffer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;scale_factor&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;FloatTensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;])))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# hidden shape: (num_layers, batch_size, hidden_size)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# encoder_outputs shape: (batch_size, src_len, hidden_size)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# Q: 解码器最后一层的隐藏状态&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;query&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; (batch, 1, hidden)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# K/V: 编码器的所有输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; (batch, src_len, hidden)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# energy shape: (batch, 1, src_len)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bmm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale_factor&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# attention_weights shape: (batch, src_len)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;此方法的思路是：&lt;strong&gt;相似度越高的向量，其点积越大&lt;/strong&gt;。直接利用这一数学特性来衡量 Query 与各个 Key 的关联程度。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;query = hidden[-1].unsqueeze(1)&lt;/code&gt;: 提取解码器上一时间步的最终隐藏状态，作为当前解码需求的&lt;strong&gt;查询 (Query)&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;energy = torch.bmm(...)&lt;/code&gt;: 通过矩阵乘法，一次性计算出 Query 向量与所有 Key 向量（即编码器各时刻的输出）的点积。这个点积结果 &lt;code&gt;energy&lt;/code&gt; 直接反映了它们之间的原始相似度分数。除以一个缩放因子是为了让训练过程更稳定。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;return torch.softmax(...)&lt;/code&gt;: 使用 Softmax 函数将原始分数转换成一个标准的概率分布，即最终的&lt;strong&gt;注意力权重&lt;/strong&gt;。分数越高的位置，获得的权重也越大。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h4 id="532-带参数的注意力"&gt;5.3.2 带参数的注意力
&lt;/h4&gt;&lt;p&gt;这个版本引入了可学习的参数（一个线性层和一个向量 &lt;code&gt;v&lt;/code&gt;），让模型可以自主学习如何更好地对齐 Query 和 Keys。从 QKV 的来源来看，由于查询（Query）来自解码器，而键（Key）和值（Value）来自编码器，因此这种机制也可以称为&lt;strong&gt;交叉注意力 (Cross-Attention)&lt;/strong&gt;。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;“交叉”体现在哪里？&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;“交叉”一词形象说明了注意力计算的信息来源是&lt;strong&gt;两个不同的序列&lt;/strong&gt;。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;查询&lt;/strong&gt;：来自解码器序列的当前状态（例如 $h^\prime_{t-1}$），代表了“我现在需要什么信息”。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;键和值&lt;/strong&gt;：来自编码器处理完&lt;strong&gt;整个源序列&lt;/strong&gt;后产生的所有状态（例如 $h_1, h_2, \dots, h_{T_x}$），代表了“这里有全部的原始信息可供查询”。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;信息从编码器序列流向解码器序列，两者通过注意力机制进行互动和对齐，因此被称为“交叉注意力”。与之相对的是自注意力 (Self-Attention)，其查询、键、值均来自同一个序列。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;AttentionParams&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;2: 带参数的注意力模块&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;AttentionParams&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_last_layer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;energy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;hidden_last_layer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;energy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;此方法的核心是创建一个小型的神经网络（&lt;code&gt;self.attn&lt;/code&gt; 和 &lt;code&gt;self.v&lt;/code&gt;），让模型&lt;strong&gt;自主学习&lt;/strong&gt;如何判断 Query 和 Key 之间的相关性，而不是使用固定的点积运算。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;self.attn&lt;/code&gt;: 一个线性层，它将解码器状态（Query）和编码器状态（Key）拼接后的信息进行变换，学习它们之间复杂的对齐关系。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;self.v&lt;/code&gt;: 一个可学习的向量，它的作用是将 &lt;code&gt;self.attn&lt;/code&gt; 计算出的多维对齐信息，最终转化为一个单一的注意力分数。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;hidden_last_layer = ...&lt;/code&gt;: 将代表当前 Query 的解码器状态复制，使其能与每一个编码器状态进行配对。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;energy = torch.tanh(...)&lt;/code&gt;: 将配对好的 Query 和 Key 拼接起来，一同送入 &lt;code&gt;self.attn&lt;/code&gt; 线性层。这一步会计算出一个“能量”向量，&lt;code&gt;tanh&lt;/code&gt; 激活函数则为其增加了非线性表达能力。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;attention = torch.sum(...)&lt;/code&gt;: 利用可学习的 &lt;code&gt;v&lt;/code&gt; 向量与这个向量进行点积，将多维的能量信息“压缩”成一个最终的、未经归一化的注意力分数。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;return torch.softmax(attention, dim=1)&lt;/code&gt;: 和之前一样，使用 &lt;code&gt;softmax&lt;/code&gt; 将分数转换为标准的注意力权重。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="54-通用解码器"&gt;5.4 通用解码器
&lt;/h3&gt;&lt;p&gt;为了能同时适配上述两种 Attention 模块，我们设计一个通用的解码器。其核心改动是在每个时间步都&lt;strong&gt;调用 Attention 模块&lt;/strong&gt;，并将计算出的上下文向量&lt;strong&gt;融入到当前步的输入&lt;/strong&gt;中。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DecoderWithAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attention_module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;DecoderWithAttention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attention_module&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 输入维度是 词嵌入(hidden_size) + 上下文向量(hidden_size)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 计算注意力权重&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# a shape: [batch, src_len]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 计算上下文向量&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bmm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 将上下文向量与当前输入拼接&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;rnn_input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 传入RNN解码&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rnn_input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 5. 预测输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;predictions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;self.attention&lt;/code&gt;: 持有传入的 Attention 实例（可以是 &lt;code&gt;AttentionSimple&lt;/code&gt; 或 &lt;code&gt;AttentionParams&lt;/code&gt;）。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;self.rnn&lt;/code&gt;: &lt;code&gt;input_size&lt;/code&gt; 变为 &lt;code&gt;hidden_size * 2&lt;/code&gt;，因为它接收的是&lt;strong&gt;词嵌入向量&lt;/strong&gt;和&lt;strong&gt;上下文向量&lt;/strong&gt;拼接后的结果。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;: 完整地演示了注意力的应用流程。
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;a = self.attention(...)&lt;/code&gt;: 调用 attention 模块计算权重 &lt;code&gt;a&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;context = torch.bmm(a, encoder_outputs)&lt;/code&gt;: 对应注意力计算的&lt;strong&gt;第三步&lt;/strong&gt;。使用矩阵乘法，通过权重 &lt;code&gt;a&lt;/code&gt; 对 &lt;code&gt;encoder_outputs&lt;/code&gt; (Values) 进行加权求和，得到上下文向量 &lt;code&gt;context&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;rnn_input = torch.cat(...)&lt;/code&gt;: 将动态生成的上下文向量和当前词的嵌入向量拼接起来，形成一个信息更丰富的输入，再送入 RNN 进行解码。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="55-seq2seq-包装模块"&gt;5.5 Seq2Seq 包装模块
&lt;/h3&gt;&lt;p&gt;这个模块负责将上述组件串联起来，并处理好双向编码器与单向解码器之间的状态传递问题。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Seq2Seq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;带注意力的Seq2Seq&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Seq2Seq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;teacher_forcing_ratio&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;out_features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 适配Encoder(双向)和Decoder(单向)的状态维度&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 在循环的每一步，都将 encoder_outputs 传递给解码器&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 这是 Attention 机制能够&amp;#34;回顾&amp;#34;整个输入序列的关键&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_outputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;teacher_force&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;teacher_forcing_ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;top1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;teacher_force&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;top1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;状态适配&lt;/strong&gt;: 编码器是双向的，其 &lt;code&gt;hidden&lt;/code&gt; 状态形状为 &lt;code&gt;(num_layers * 2, ...)&lt;/code&gt;。解码器是单向的，需要 &lt;code&gt;(num_layers, ...)&lt;/code&gt; 的初始状态。这里的 &lt;code&gt;hidden.view(...).sum(dim=1)&lt;/code&gt; 通过 &lt;code&gt;view&lt;/code&gt; 操作将状态拆分为 &lt;code&gt;(层数, 方向, ...)&lt;/code&gt;，然后在方向维度上求和，巧妙地将双向状态合并为单向状态。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;循环解码&lt;/strong&gt;: 正如之前所强调的，Attention 机制下的解码必须是串行的。在 &lt;code&gt;for&lt;/code&gt; 循环的每一步，都将 &lt;code&gt;encoder_outputs&lt;/code&gt; 完整地传递给解码器，确保解码器在每个时间步都能基于上一时刻的状态，动态计算出当前最需要的上下文信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个实现完整地展示了 Attention 机制如何克服信息瓶颈问题：解码器不再只依赖于一个固定的上下文向量，而是在生成的每一步，都通过 Attention 模块动态地计算出一个与当前解码状态最相关的上下文向量，极大地提升了模型性能。&lt;/p&gt;
&lt;h3 id="完整python代码"&gt;完整Python代码
&lt;/h3&gt;&lt;h2 id="六注意力机制的类型"&gt;六、注意力机制的类型
&lt;/h2&gt;&lt;p&gt;在注意力机制发展的早期，受限于当时的硬件计算能力，研究者们为了降低计算开销，提出了一些不同类型的注意力机制。&lt;/p&gt;
&lt;h3 id="61-soft-attention-vs-hard-attention"&gt;6.1 Soft Attention vs. Hard Attention
&lt;/h3&gt;&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Soft Attention&lt;/strong&gt;：这就是前文一直在详细讨论的机制。它为输入序列的&lt;strong&gt;所有&lt;/strong&gt;位置都计算一个注意力权重，这些权重是 0 到 1 之间的浮点数（经 Softmax 归一化），然后进行加权求和。这种方式的优点是模型是端到端可微的，可以使用标准的梯度下降法进行训练。其缺点是在处理非常长的序列时，计算开销会很大。因为解码的每一步，都需要计算当前状态与所有输入状态的相似度。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Hard Attention&lt;/strong&gt;&lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;：与 Soft Attention 对所有输入进行加权不同，Hard Attention 在每一步只&lt;strong&gt;选择一个&lt;/strong&gt;最相关的输入位置。可以看作是一种“非 0 即 1”的注意力分配，即选中的位置权重为 1，其他所有位置的权重均为 0。这样做的好处是计算量大大减少，因为不再需要进行全面的加权求和。但它的缺点也很突出：选择过程是离散的、不可微的，因此无法使用常规的反向传播算法进行训练，通常需要借助强化学习等更复杂的技巧。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="62-global-attention-vs-local-attention"&gt;6.2 Global Attention vs. Local Attention
&lt;/h3&gt;&lt;p&gt;这是另一组从计算范围角度区分的概念，出自于另一篇开创性的论文&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Global Attention (全局注意力)&lt;/strong&gt;：其思想与 Soft Attention 基本一致，即在计算注意力时，会考虑编码器的&lt;strong&gt;所有&lt;/strong&gt;隐藏状态。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Local Attention (局部注意力)&lt;/strong&gt;：这是一种介于 Soft Attention 和 Hard Attention 之间的折中方案。能够减少计算量，但又不像 Hard Attention 那样极端。其核心思想是，在每个解码时间步，只关注输入序列的一个&lt;strong&gt;局部窗口&lt;/strong&gt;。它的工作流程通常是：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;预测对齐位置&lt;/strong&gt;：首先，模型需要预测一个当前解码步最关注的源序列位置 $p_t$。这个位置可以通过一个小型神经网络，仅依赖于当前解码器状态 $h^\prime_t$ 来预测，从而避免了与所有编码器状态进行比较，降低了计算成本。预测公式可以设计为： $p_t = T_x \cdot \text{sigmoid}(W_p h&amp;rsquo;_t + b_p)$，其中 $T_x$ 是源序列长度， $W_p$ 和 $b_p$ 是可学习的参数。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;定义窗口&lt;/strong&gt;：以预测出的 $p_t$ 为中心，定义一个大小为 $2D+1$ 的窗口，其中 $D$ 是一个超参数。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;局部计算&lt;/strong&gt;：最后，模型只在这个窗口内的编码器状态上应用 Soft Attention 机制，计算权重并生成上下文向量。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;早期 Local Attention 通过局部窗口显著降低复杂度并保持良好性能。尽管硬件与内核优化推动了全局注意力在常规长度任务中的普及，但其 $O(N^2)$ 成本在长序列、低延迟或资源受限场景仍是瓶颈，因此局部/稀疏/窗口化/混合注意力在这些场景依然常用。&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id="参考文献"&gt;参考文献
&lt;/h2&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1409.0473" target="_blank" rel="noopener"
&gt;Bahdanau, D., Cho, K., &amp;amp; Bengio, Y. (2014). &lt;em&gt;Neural machine translation by jointly learning to align and translate&lt;/em&gt;. arXiv preprint arXiv:1409.0473.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1706.03762" target="_blank" rel="noopener"
&gt;Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). &lt;em&gt;Attention Is All You Need&lt;/em&gt;. NeurIPS 2017.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1502.03044" target="_blank" rel="noopener"
&gt;Xu, K., Ba, J., Kiros, R., Cho, K., Courville, A., Salakhudinov, R., &amp;hellip; &amp;amp; Bengio, Y. (2015). &lt;em&gt;Show, attend and tell: Neural image caption generation with visual attention&lt;/em&gt;. arXiv preprint arXiv:1502.03044.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1508.04025" target="_blank" rel="noopener"
&gt;Luong, M. T., Pham, H., &amp;amp; Manning, C. D. (2015). &lt;em&gt;Effective approaches to attention-based neural machine translation&lt;/em&gt;. arXiv preprint arXiv:1508.04025.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>