<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Transformer on 酒中仙</title><link>https://hanguangwu.github.io/blog/tags/transformer/</link><description>Recent content in Transformer on 酒中仙</description><generator>Hugo -- gohugo.io</generator><language>zh-cn</language><copyright>hanguangwu</copyright><lastBuildDate>Sat, 07 Feb 2026 20:34:25 -0800</lastBuildDate><atom:link href="https://hanguangwu.github.io/blog/tags/transformer/index.xml" rel="self" type="application/rss+xml"/><item><title>深入解析 Transformer 架构</title><link>https://hanguangwu.github.io/blog/p/%E6%B7%B1%E5%85%A5%E8%A7%A3%E6%9E%90-transformer-%E6%9E%B6%E6%9E%84/</link><pubDate>Sat, 07 Feb 2026 20:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/%E6%B7%B1%E5%85%A5%E8%A7%A3%E6%9E%90-transformer-%E6%9E%B6%E6%9E%84/</guid><description>&lt;h1 id="深入解析-transformer"&gt;深入解析 Transformer
&lt;/h1&gt;&lt;p&gt;注意力机制通过动态加权的方式，克服了传统 Seq2Seq 模型中的“信息瓶颈”问题。但是，这些模型依然依赖于 RNN 来处理序列信息，也就是说它们必须按顺序，一个词元接一个词元地进行计算，这在处理长序列时效率低下，并且存在长距离依赖信息丢失的问题。&lt;/p&gt;
&lt;p&gt;2017年，Google 的研究团队发表了一篇名为《Attention Is All You Need》的论文，提出了一种全新的架构——&lt;strong&gt;Transformer&lt;/strong&gt; &lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;。这篇论文的标题很有冲击力，其思想也同样有颠覆性。它抛弃了传统的 RNN 和卷积网络，整个模型基于注意力机制来构建。Transformer 的提出在自然语言处理领域具有划时代的意义。它不仅凭借其出色的并行计算能力极大地提升了训练效率，还更有效地捕捉了文本中的长距离依赖关系，为后续的 BERT、GPT 等大规模预训练模型的诞生提供了架构基础。&lt;/p&gt;
&lt;h2 id="一自注意力机制"&gt;一、自注意力机制
&lt;/h2&gt;&lt;p&gt;从根本上说，要让模型理解一段文本，就需要提取其“序列特征”，即将文本中所有词元的信息以某种方式整合起来。RNN 通过依次传递隐藏状态来顺序地整合信息，而 Transformer 则选择了一条截然不同的道路。其核心是 &lt;strong&gt;自注意力机制&lt;/strong&gt;。它不再依赖于顺序计算，而是将提取序列特征的过程看作是输入序列“自己对自己进行注意力计算”。序列中的每个词元都会“审视”序列中的所有其他词元，来动态地计算出最能代表当前词元上下文含义的新表示。与上一节介绍的交叉注意力不同，在自注意力中，&lt;strong&gt;Query、Key、Value 均来源于同一个输入序列&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;举个例子，在句子“苹果公司发布了新款手机，&lt;strong&gt;它&lt;/strong&gt;采用了最新的芯片”中，要理解代词“它”指的是“新款手机”而不是“苹果公司”，模型就需要将“它”与句子中的其他词元进行关联。自注意力机制正是通过计算“它”对句中其他所有词的注意力权重来实现这一点的。&lt;/p&gt;
&lt;h3 id="11-自注意力与交叉注意力的区别"&gt;1.1 自注意力与交叉注意力的区别
&lt;/h3&gt;&lt;p&gt;从结构上看，自注意力与交叉注意力的区别在于&lt;strong&gt;信息的来源和流动方向&lt;/strong&gt;。在交叉注意力机制中，信息在两个不同的序列之间流动。通常，&lt;strong&gt;Query&lt;/strong&gt; 来自解码器（代表当前的目标序列状态），而 &lt;strong&gt;Key&lt;/strong&gt; 和 &lt;strong&gt;Value&lt;/strong&gt; 来自编码器的所有输出（代表完整的源序列信息）。其目的是在生成目标序列的每一步时，从源序列中寻找最相关的信息。&lt;/p&gt;
&lt;p&gt;&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/20260207205125471.png"
loading="lazy"
&gt;&lt;/p&gt;
&lt;p&gt;而在&lt;strong&gt;自注意力&lt;/strong&gt;机制中，信息则是在&lt;strong&gt;同一个序列内部&lt;/strong&gt;进行流动和重组。它的 &lt;strong&gt;Query, Key, 和 Value 都来自同一个输入序列&lt;/strong&gt;。其目的是为了捕捉输入序列内部的依赖关系，重新计算序列中每个词元的表示，使其包含更丰富的上下文信息。&lt;/p&gt;
&lt;p&gt;&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/20260207205104142.png"
loading="lazy"
&gt;&lt;/p&gt;
&lt;p&gt;总结来说，尽管底层的加权求和计算方式相似，但两者在架构上的目标完全不同：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;交叉注意力&lt;/strong&gt;：用于&lt;strong&gt;对齐&lt;/strong&gt;和&lt;strong&gt;整合&lt;/strong&gt;两个&lt;strong&gt;不同&lt;/strong&gt;序列之间的信息。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;自注意力&lt;/strong&gt;：用于&lt;strong&gt;理解&lt;/strong&gt;和&lt;strong&gt;重构&lt;/strong&gt;单个序列&lt;strong&gt;内部&lt;/strong&gt;的依赖关系。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="12-自注意力的计算过程"&gt;1.2 自注意力的计算过程
&lt;/h3&gt;&lt;p&gt;自注意力的计算过程与上一节介绍的 QKV 范式完全一致，关键区别在于 Q, K, V 的来源。&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;生成 Q, K, V 向量&lt;/strong&gt;：&lt;/p&gt;
&lt;p&gt;对于输入序列中的&lt;strong&gt;每一个&lt;/strong&gt;词元，首先获取其词嵌入向量 $x_i$。然后，将该向量分别与三个可学习的、在整个模型中共享的权重矩阵 $W^Q, W^K, W^V$ 相乘，生成该词元专属的 Query 向量 $q_i$、Key 向量 $k_i$ 和 Value 向量 $v_i$。&lt;/p&gt;
$$
q_i = x_i W^Q \\
k_i = x_i W^K \\
v_i = x_i W^V
$$&lt;p&gt;这三个矩阵的作用是将原始的词嵌入向量投影到不同的、专门用于注意力计算的表示空间中，赋予了模型更大的灵活性。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;计算注意力分数&lt;/strong&gt;：&lt;/p&gt;
&lt;p&gt;为了计算第 $i$ 个词元的新表示，需要用它的 Query 向量 $q_i$ 去和&lt;strong&gt;所有&lt;/strong&gt;词元（包括它自己）的 Key 向量 $k_j$ 计算点积，得到注意力分数。&lt;/p&gt;
$$
\text{score}(i, j) = q_i \cdot k_j
$$&lt;p&gt;（3）&lt;strong&gt;缩放与归一化&lt;/strong&gt;：&lt;/p&gt;
&lt;p&gt;将得到的分数除以一个缩放因子 $\sqrt{d_k}$（$d_k$ 是 Key 向量的维度），然后通过 Softmax 函数进行归一化，得到最终的注意力权重 $\alpha_{ij}$。这个缩放步骤的目的与上一节中介绍的一致，都是为了在训练过程中保持梯度稳定。当向量维度 $d_k$ 较大时，点积结果的方差会增大，可能将 Softmax 函数推向其梯度极小的区域，从而导致梯度消失，影响模型学习。进行缩放可以有效缓解这个问题。&lt;/p&gt;
$$
\alpha_{ij} = \text{softmax}\left(\frac{q_i \cdot k_j}{\sqrt{d_k}}\right)
$$&lt;ol start="4"&gt;
&lt;li&gt;&lt;strong&gt;加权求和&lt;/strong&gt;：&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;使用计算出的权重 $\alpha_{ij}$ 对&lt;strong&gt;所有&lt;/strong&gt;词元的 Value 向量 $v_j$ 进行加权求和，得到第 $i$ 个词元经过自注意力计算后得到的新表示 $z_i$。&lt;/p&gt;
$$
z_i = \sum_j \alpha_{ij} v_j
$$&lt;p&gt;通过这个过程，输出向量 $z_i$ 不再仅仅包含原始词元 $x_i$ 的信息，而是融合了整个序列中所有与之相关词元的信息，成为一个上下文感知的、更丰富的表示。其本质可以理解为：序列中的每个词元都同时扮演着“查询（Q）”、“键（K）”和“值（V）”三种角色。通过计算查询与其他所有词元的键之间的相关性，来决定如何加权融合所有词元的值，从而为每个词元生成一个全新的、深度融合了全局上下文信息的表示。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;既然 Q, K, V 都来自同一个输入 X，为什么不直接用 X 计算，而要引入三个独立的权重矩阵 $W^Q, W^K, W^V$？甚至，为什么是三个，而不是两个或四个？&lt;/p&gt;
&lt;p&gt;这可以类比在图书馆查资料的过程：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Query (Q) - 要问的问题&lt;/strong&gt;：代表了我们主动想查询的意图。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Key (K) - 书的索引/标签&lt;/strong&gt;：代表了书本内容的关键特征，用于被动地和你的问题进行匹配。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Value (V) - 书的具体内容&lt;/strong&gt;：代表了书本实际包含的信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;我们的“问题”和书本的“索引”可能都源于同一个知识领域（同一个输入 X），但它们在信息检索这个任务中扮演的角色是截然不同的。$W^Q, W^K, W^V$ 这三个矩阵的作用，就是让模型学会将原始输入 X 投影到三个功能不同的空间中，分别去扮演好“查询者”、“被查询的索引”和“信息提供者”这三种角色。Q-K 配对解决了“如何定位相关信息”的问题，而 V 提供了“应该提取什么信息”的答案。这个三元组结构在功能上是完备且高效的，所以成为了注意力机制的标准范式。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="13-矩阵运算与并行化"&gt;1.3 矩阵运算与并行化
&lt;/h3&gt;&lt;p&gt;上述步骤描述的是单个词元 $i$ 的计算过程。在实际应用中，如果采用循环的方式逐个计算每个词元的 $z_i$，效率会非常低下。自注意力的巨大优势在于其&lt;strong&gt;并行计算&lt;/strong&gt;能力，这通过将整个过程表达为矩阵运算来实现。&lt;/p&gt;
&lt;p&gt;假设整个输入序列的词嵌入矩阵为 $X$（维度为 &lt;code&gt;[sequence_length, embedding_dim]&lt;/code&gt;），可以一次性计算出所有词元的 Q, K, V 矩阵：&lt;/p&gt;
$$
Q = X W^Q \\
K = X W^K \\
V = X W^V
$$&lt;p&gt;然后，整个自注意力的输出矩阵 $Z$ 可以通过一个公式完成计算：&lt;/p&gt;
$$
Z = \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$&lt;p&gt;这个公式与上一节中介绍的通用注意力公式完全相同。这里的主要区别不在于数学运算，而在于&lt;strong&gt;输入的来源&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;在上一节的&lt;strong&gt;交叉注意力&lt;/strong&gt;中，Q 来自一个序列（解码器），而 K 和 V 来自另一个序列（编码器）。&lt;/li&gt;
&lt;li&gt;在当前的&lt;strong&gt;自注意力&lt;/strong&gt;中，矩阵 Q、K 和 V &lt;strong&gt;全部派生自同一个输入序列 X&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;所以，同一个数学范式，根据输入来源的不同，被用于解决两个不同的问题，一个是两个序列之间的对齐，另一个是单个序列内部的依赖关系建模。在这个公式中， $QK^T$ 的计算结果是一个维度为 &lt;code&gt;[sequence_length, sequence_length]&lt;/code&gt; 的注意力分数矩阵，其中第 $i$ 行第 $j$ 列的元素表示第 $i$ 个词元对第 $j$ 个词元的注意力分数（未归一化的 logits）。注意力权重来自对缩放分数应用 Softmax 后得到的归一化系数。&lt;/p&gt;
&lt;h3 id="14-pytorch-实现自注意力"&gt;1.4 PyTorch 实现自注意力
&lt;/h3&gt;&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/FutureUnreal/base-nlp/blob/main/code/C4/03_Self-Attention.py" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;从概念上讲，自注意力的计算可以分解为对序列中每个词元进行循环操作，这种方式虽然直观但效率极低。因此，现代深度学习框架中的实现都采用了&lt;strong&gt;矩阵运算&lt;/strong&gt;的方式。通过将整个序列的 Q, K, V 看作矩阵，利用一次大规模的矩阵乘法，就能并行地完成所有词元之间的相关性计算。下面是这种并行化版本的实现：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SelfAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;自注意力模块&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;SelfAttention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attention_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;: 初始化了三个 &lt;code&gt;nn.Linear&lt;/code&gt; 层，它们分别对应将输入映射到 Q, K, V 空间的权重矩阵 $W^Q, W^K, W^V$。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;q_linear(x)&lt;/code&gt;, &lt;code&gt;k_linear(x)&lt;/code&gt;, &lt;code&gt;v_linear(x)&lt;/code&gt;：将形状为 &lt;code&gt;[batch_size, seq_len, hidden_size]&lt;/code&gt; 的输入张量 &lt;code&gt;x&lt;/code&gt; 分别通过三个线性层，一次性地为序列中的所有词元计算出 Q, K, V 矩阵。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;torch.matmul(q, k.transpose(-2, -1))&lt;/code&gt;: 这是实现并行计算的核心。通过将 K 矩阵的最后两个维度转置（&lt;code&gt;seq_len, hidden_size&lt;/code&gt; -&amp;gt; &lt;code&gt;hidden_size, seq_len&lt;/code&gt;），再与 Q 矩阵相乘，直接得到了一个 &lt;code&gt;[batch_size, seq_len, seq_len]&lt;/code&gt; 的分数矩阵。该矩阵中的 &lt;code&gt;scores[b, i, j]&lt;/code&gt; 代表了批次 &lt;code&gt;b&lt;/code&gt; 中第 &lt;code&gt;i&lt;/code&gt; 个词元对第 &lt;code&gt;j&lt;/code&gt; 个词元的注意力分数。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;/ math.sqrt(self.hidden_size)&lt;/code&gt;：执行缩放操作，防止梯度消失。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;torch.softmax(scores, dim=-1)&lt;/code&gt;：对分数的最后一个维度（&lt;code&gt;seq_len&lt;/code&gt;）进行 Softmax，得到归一化的注意力权重。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;torch.matmul(attention_weights, v)&lt;/code&gt;：将权重矩阵与 V 矩阵相乘，完成了对所有词元的 Value 向量的加权求和，得到最终的上下文感知表示。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id="二多头注意力机制"&gt;二、多头注意力机制
&lt;/h2&gt;&lt;p&gt;仅仅用一组 $W^Q, W^K, W^V$ 矩阵进行一次自注意力计算，相当于只从一个“视角”来审视文本内在的关系。然而，文本中的关系是多层次的，例如，一组参数可能学会了关注代词（如 “它” 指向谁）的关系，但可能忽略了动作的执行者（主谓宾）等其他类型的关系。&lt;/p&gt;
&lt;p&gt;为了让模型能够综合利用从不同维度和视角提取出的信息，Transformer 引入了&lt;strong&gt;多头注意力机制 (Multi-Head Attention)&lt;/strong&gt;。其思想非常直接：并行地执行多次自注意力计算，每一次计算都是一个独立的“头 (Head)”。每个头都拥有一组自己专属的 $W^Q_i, W^K_i, W^V_i$ 权重矩阵，并且可以学习去关注一种特定类型的上下文关系。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;那么，多头注意力与我们之前讨论的“增加 A, B, C 等新角色”有什么不同呢？&lt;/p&gt;
&lt;p&gt;一个关键的区别：多头注意力&lt;strong&gt;不是&lt;/strong&gt;通过增加 A, B, C 等新角色来&lt;strong&gt;深化&lt;/strong&gt;单次注意力计算的复杂性，而是通过并行运行多个独立的 QKV 计算单元来&lt;strong&gt;拓宽&lt;/strong&gt;其广度。&lt;/p&gt;
&lt;p&gt;再次使用图书馆的类比：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;增加 A, B, C&lt;/strong&gt;：相当于给&lt;strong&gt;一个&lt;/strong&gt;图书管理员一套更复杂的工具，让他一次性处理问题(Q)、索引(K)、内容(V)之外，还要考虑主题(A)、背景(B)等，这会使单次查询过程变得非常复杂。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多头注意力&lt;/strong&gt;：相当于&lt;strong&gt;雇佣一个各有所长的专家团队&lt;/strong&gt;（比如 8 个管理员，即 8 个“头”）。每个专家都只使用标准高效的 QKV 工具，但他们各自有独特的视角（独立的 $W^Q_i, W^K_i, W^V_i$ 矩阵）。一个专家可能专攻语法，另一个专攻语义。最后，将所有专家的报告汇总起来，得到一个更全面、更丰富的结论。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;因此，多头注意力机制为模型提供了从不同子空间、不同视角审视信息的能力，而不是改变注意力计算本身的范式。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;具体流程如下：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;并行计算&lt;/strong&gt;：假设有 $h$ 个头，那么就初始化 $h$ 组不同的权重矩阵 $(W^Q_0, W^K_0, W^V_0), (W^Q_1, W^K_1, W^V_1), \dots, (W^Q_{h-1}, W^K_{h-1}, W^V_{h-1})$。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;独立注意力&lt;/strong&gt;：对于输入序列，每个头都独立地执行一次完整的自注意力计算，产生一个输出矩阵 $Z_i$。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;拼接与投影&lt;/strong&gt;：将所有 $h$ 个头的输出矩阵 $Z_0, Z_1, \dots, Z_{h-1}$ 在特征维度上进行&lt;strong&gt;拼接 (Concatenate)&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;（4）&lt;strong&gt;最终输出&lt;/strong&gt;：将拼接后的巨大矩阵乘以一个新的权重矩阵 $W^O$，将其投影回原始的输入维度，得到多头注意力机制的最终输出。&lt;/p&gt;
&lt;p&gt;多头机制允许模型在不同的表示子空间中共同学习上下文信息。例如，一个头可能专注于捕捉长距离的语法依赖，而另一个头可能更关注局部的词义关联。这种设计极大地增强了模型的表达能力。&lt;/p&gt;
&lt;p&gt;在实践中，为了保持计算总量不变，通常会将原始的词嵌入维度 &lt;code&gt;embedding_dim&lt;/code&gt; 均分给 $h$ 个头。例如，如果 &lt;code&gt;embedding_dim=512&lt;/code&gt;，有 &lt;code&gt;h=8&lt;/code&gt; 个头，那么每个头产生的 Q, K, V 向量维度就是 &lt;code&gt;d_k = d_v = 512 / 8 = 64&lt;/code&gt;。计算时，先将输入 $X$ 分别投影到 $h$ 组低维的 Q, K, V 向量，并行计算后，再将结果拼接并投影回 &lt;code&gt;embedding_dim&lt;/code&gt; 维度。&lt;/p&gt;
&lt;h3 id="21-pytorch-实现多头注意力"&gt;2.1 PyTorch 实现多头注意力
&lt;/h3&gt;&lt;p&gt;多头注意力是通过并行运行多个独立的自注意力“头”，并融合它们的输出来增强模型的表达能力。一个低效的实现是简单地创建多个 &lt;code&gt;SelfAttention&lt;/code&gt; 实例并拼接结果。而高效的实现则是将多个头的计算逻辑合并到一次矩阵运算中。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadSelfAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;多头自注意力模块&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MultiHeadSelfAttention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;hidden_size 必须能被 num_heads 整除&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;k_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;v_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 拆分多头&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 并行计算注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attention_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 合并多头结果&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contiguous&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 输出层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;head_dim&lt;/code&gt;：计算出每个头的维度，即 &lt;code&gt;hidden_size / num_heads&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;q_linear, k_linear, v_linear&lt;/code&gt;：与单头类似，但这里的线性层输出维度仍然是 &lt;code&gt;hidden_size&lt;/code&gt;。这是为了&lt;strong&gt;一次性计算出所有头所需的总特征&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;wo&lt;/code&gt;：对应于多头注意力机制中的输出权重矩阵 $W^O$，用于融合所有头的信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;线性变换&lt;/strong&gt;: 与单头版本相同，得到总的 Q, K, V 矩阵。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;拆分多头&lt;/strong&gt;:
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;.view(batch_size, seq_len, self.num_heads, self.head_dim)&lt;/code&gt;: 首先，将 &lt;code&gt;hidden_size&lt;/code&gt; 维度逻辑上拆分为 &lt;code&gt;num_heads&lt;/code&gt; 和 &lt;code&gt;head_dim&lt;/code&gt; 两个维度。此时张量形状变为 &lt;code&gt;[batch, seq_len, num_heads, head_dim]&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;.transpose(1, 2)&lt;/code&gt;: 然后，交换 &lt;code&gt;seq_len&lt;/code&gt; 和 &lt;code&gt;num_heads&lt;/code&gt; 维度，得到 &lt;code&gt;[batch, num_heads, seq_len, head_dim]&lt;/code&gt;。这一步是为了让 &lt;code&gt;num_heads&lt;/code&gt; 成为一个类似批次 (batch) 的维度，使得后续的矩阵乘法可以在每个头内部独立、并行地进行。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;并行计算注意力&lt;/strong&gt;: &lt;code&gt;torch.matmul(q, k.transpose(-2, -1))&lt;/code&gt; 现在是一个四维张量的乘法。PyTorch 会自动地将其解释为在第 0 和第 1 维（&lt;code&gt;batch&lt;/code&gt; 和 &lt;code&gt;num_heads&lt;/code&gt;）上进行批处理，而对最后两个维度执行矩阵乘法。这样就实现了所有头的注意力分数计算的并行化。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;合并多头&lt;/strong&gt;: 这是拆分操作的逆过程。
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;.transpose(1, 2)&lt;/code&gt;: 先将 &lt;code&gt;num_heads&lt;/code&gt; 和 &lt;code&gt;seq_len&lt;/code&gt; 维度换回来，形状变为 &lt;code&gt;[batch, seq_len, num_heads, head_dim]&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;.contiguous()&lt;/code&gt;: 由于 &lt;code&gt;transpose&lt;/code&gt; 操作可能导致张量在内存中不是连续存储的，需要调用 &lt;code&gt;.contiguous()&lt;/code&gt; 来确保内存连续，之后才能安全地使用 &lt;code&gt;.view()&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;.view(batch_size, seq_len, self.hidden_size)&lt;/code&gt;: 最后，将 &lt;code&gt;num_heads&lt;/code&gt; 和 &lt;code&gt;head_dim&lt;/code&gt; 两个维度重新合并成 &lt;code&gt;hidden_size&lt;/code&gt; 维度，完成了所有头输出的拼接。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;输出投影&lt;/strong&gt;: 将合并后的结果通过 &lt;code&gt;wo&lt;/code&gt; 线性层，得到最终输出。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id="三transformer-整体结构"&gt;三、Transformer 整体结构
&lt;/h2&gt;&lt;p&gt;理解了自注意力和多头注意力之后，就可以从一个更高的视角来审视 Transformer 的整体结构了。通过图 4-3 可以看出它依然是一个 Encoder-Decoder 架构，但其内部是由几个标准化的“积木”堆叠而成的。&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/4_3_1.svg" width="40%" alt="Transformer 架构" /&gt;
&lt;p&gt;图 4-3 Transformer 架构&lt;/p&gt;
&lt;/div&gt;
&lt;p&gt;Transformer 的 Encoder 和 Decoder 都是由 N 个（原论文中 N=6）功能相同的层（Layer）堆叠而成。下面我们分别来看它们的内部构造。&lt;/p&gt;
&lt;h3 id="31-编码器encoder"&gt;3.1 编码器（Encoder）
&lt;/h3&gt;&lt;p&gt;编码器的作用是“理解”和“消化”输入的整个序列，为序列中的每个词元生成一个富含上下文信息的表示。一个标准的&lt;strong&gt;编码器层&lt;/strong&gt;由两个主要的&lt;strong&gt;子层&lt;/strong&gt;构成，分别是&lt;strong&gt;多头自注意力层（Multi-Head Self-Attention Layer）&lt;strong&gt;和&lt;/strong&gt;位置前馈网络（Position-wise Feed-Forward Network）&lt;/strong&gt;。每个子层的输出都经过了**残差连接（Add）&lt;strong&gt;与&lt;/strong&gt;层归一化（Norm）**处理。所以，一个编码器层内部的数据流可以表示为 &lt;code&gt;x -&amp;gt; Sublayer1(x) -&amp;gt; Add &amp;amp; Norm -&amp;gt; Sublayer2(...) -&amp;gt; Add &amp;amp; Norm&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;关键特性&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;注意力类型&lt;/strong&gt;：编码器中的多头注意力层是&lt;strong&gt;双向的自注意力&lt;/strong&gt;。这意味着在计算时，序列中的任何一个词元都可以“看到”序列中的所有其他词元（包括它自己、它前面的和它后面的）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;功能&lt;/strong&gt;：由于其双向性，编码器非常擅长&lt;strong&gt;理解&lt;/strong&gt;完整的输入文本，并为每个词元生成一个深度融合了上下文信息的表示。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;应用&lt;/strong&gt;：通过大量堆叠编码器层而构建的模型（Encoder-Only 架构），如 &lt;strong&gt;BERT&lt;/strong&gt;，在文本分类、命名实体识别等自然语言理解（NLU）任务上取得了巨大成功。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="32-解码器decoder"&gt;3.2 解码器（Decoder）
&lt;/h3&gt;&lt;p&gt;解码器的作用是基于编码器对原始输入的理解，并结合已经生成的部分，来逐个生成下一个词元。为了完成这个更复杂的任务，一个标准的&lt;strong&gt;解码器层（Decoder Layer）&lt;strong&gt;比编码器层多了一个注意力子层，总共包含&lt;/strong&gt;三个子层&lt;/strong&gt;。分别是&lt;strong&gt;带掩码的多头自注意力层（Masked Multi-Head Self-Attention Layer）&lt;/strong&gt;、&lt;strong&gt;交叉注意力层（Cross-Attention Layer）&lt;strong&gt;和&lt;/strong&gt;位置前馈网络（Position-wise Feed-Forward Network）&lt;/strong&gt;。同样，解码器的每个子层也都采用了残差连接和层归一化。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;关键特性&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;子层 1：带掩码的自注意力&lt;/strong&gt;
&lt;ul&gt;
&lt;li&gt;与编码器层相比，这是解码器的第一个主要区别。解码器在生成序列时必须是&lt;strong&gt;自回归&lt;/strong&gt;的，即在生成第 $t$ 个词元时，只能依赖于已经生成的前 $t-1$ 个词元，而不能“看到”未来的信息。&lt;/li&gt;
&lt;li&gt;为了在并行的自注意力计算中实现这一点，需要引入&lt;strong&gt;掩码 (Masking)&lt;/strong&gt;。在计算 Softmax 之前，一个“未来词元掩码”会被应用到注意力分数上，将所有未来位置的分数设置为一个极小的负数（如 &lt;code&gt;-inf&lt;/code&gt;），这样在经过 Softmax 之后，这些位置的注意力权重就会变为 0，从而确保了模型的单向性。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;子层 2：交叉注意力&lt;/strong&gt;
&lt;ul&gt;
&lt;li&gt;这是连接编码器和解码器的桥梁。这一层的实现与上一节中描述的交叉注意力一致。&lt;/li&gt;
&lt;li&gt;它的 &lt;strong&gt;Key 和 Value 来自于编码器的最终输出&lt;/strong&gt;，而 &lt;strong&gt;Query 则来自于解码器前一个子层（即带掩码的自注意力层）的输出&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;这一层允许解码器在生成每个词元时，能够“关注”到输入序列的所有部分，从而有针对性地提取所需信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;子层 3：逐位置前馈网络 (FFN)&lt;/strong&gt;
&lt;ul&gt;
&lt;li&gt;这部分与编码器中的 FFN 相同，为模型提供非线性变换能力。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;大模型应用&lt;/strong&gt;：通过大量堆叠解码器层而构建的模型（Decoder-Only 架构），如 &lt;strong&gt;GPT&lt;/strong&gt; 系列，由于其天然的自回归生成能力，引领了当前大语言模型（LLM）的发展浪潮。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="33-组件解析"&gt;3.3 组件解析
&lt;/h3&gt;&lt;p&gt;下面来详细解析一下构成上述“层”的几个重要组件。&lt;/p&gt;
&lt;h4 id="331-位置前馈网络-ffn"&gt;3.3.1 位置前馈网络 (FFN)
&lt;/h4&gt;&lt;p&gt;这是一个由两次线性变换和一个激活函数组成的全连接网络，它独立地应用于序列中的&lt;strong&gt;每一个位置&lt;/strong&gt;。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;作用&lt;/strong&gt;：特征变换。注意力子层内部主要包含 Softmax 归一化；逐位置的非线性主要由 FFN 提供（常见 ReLU/GELU&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;内部结构&lt;/strong&gt;：其常见的结构是“升维-激活-降维”。&lt;/p&gt;
$$
\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
$$&lt;ul&gt;
&lt;li&gt;第一次线性变换（ $W_1$ ）通常会将输入维度 &lt;code&gt;embedding_dim&lt;/code&gt; 放大到 4 倍（&lt;code&gt;4 * embedding_dim&lt;/code&gt;）。这种&lt;strong&gt;升维&lt;/strong&gt;操作的作用是将特征投影到一个更高维的空间，以便提取更丰富、更复杂的模式。&lt;/li&gt;
&lt;li&gt;使用一个激活函数（如 ReLU）进行非线性处理。&lt;/li&gt;
&lt;li&gt;第二次线性变换（ $W_2$ ）再将维度从 &lt;code&gt;4 * embedding_dim&lt;/code&gt; 压缩回原始的 &lt;code&gt;embedding_dim&lt;/code&gt;。这种&lt;strong&gt;降维&lt;/strong&gt;操作可以看作是对高维特征的筛选和压缩，保留最重要的信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;将中间层维度设为 4 倍的做法，主要是继承自原始论文的经验设定，并因其良好效果而被后续模型广泛采用，并非有严格的理论证明。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4 id="332-残差连接与层归一化-add--norm"&gt;3.3.2 残差连接与层归一化 (Add &amp;amp; Norm)
&lt;/h4&gt;&lt;p&gt;为了让这些层能够成功地“堆叠”起来，每个子层的后面都连接了这个组合。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Add (残差连接)&lt;/strong&gt;：解决了深度网络的“模型退化”问题。从反向传播的角度看，子层的输出可以写成 $y = x + \text{Sublayer}(x)$。在计算梯度时， $\frac{\partial y}{\partial x} = 1 + \frac{\partial \text{Sublayer}(x)}{\partial x}$。其中“1”的存在为梯度创建了一条“高速公路”，确保无论网络有多深，梯度都能至少以大小为 1 的程度回传到最浅层，极大地稳定了训练过程。同时，这也要求模型中所有子层的输入和输出维度必须保持一致，以便进行元素相加。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Norm (层归一化)&lt;/strong&gt; &lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt;：用于稳定训练过程。它独立地对&lt;strong&gt;每个样本的每个词元&lt;/strong&gt;的特征向量（即 &lt;code&gt;hidden_size&lt;/code&gt; 维度）进行标准化，使其均值变为 0，方差变为 1（但这并不假设其原始分布为正态分布）。更重要的是，它引入了两个可学习的参数 $\gamma$（缩放）和 $\beta$（偏移），让模型可以自主学习最佳的数据分布，其完整公式为：&lt;/p&gt;
$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$&lt;p&gt;这使得模型既能享受到归一化带来的稳定性，又具备了根据任务需要恢复或调整原始分布的能力。与主要用于计算机视觉的 Batch Normalization（对一个批次中所有样本的同一特征进行归一化）相比，Layer Normalization 不受批次大小的影响，更适合处理长度可变的自然语言序列。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个设计使得模型可以通过简单地增加“层”的数量（即深度）和特征维度（即宽度）来进行扩展，为后来参数量巨大的语言模型奠定了基础。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;原论文采用 Post-LN（Sublayer → Add → LayerNorm）。许多现代实现（如 GPT 系列）采用 Pre-LN（LayerNorm → Sublayer → Add），训练更稳定、更易加深，但功能等价。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="34-位置编码"&gt;3.4 位置编码
&lt;/h3&gt;&lt;p&gt;自注意力机制的主要缺陷在于其 &lt;strong&gt;位置无关性&lt;/strong&gt;。由于计算是完全并行的，模型无法感知词元的顺序。例如，“猫追狗”和“狗追猫”这两个句子，在自注意力看来，它们的词元集合完全相同，因此会为“猫”和“狗”生成相同的上下文表示，这显然是错误的。为了解决这个问题，Transformer 在将词嵌入向量输入模型之前，为它们加入了一个 &lt;strong&gt;位置编码 (Positional Encoding)&lt;/strong&gt; 向量。其工作方式非常直接：&lt;/p&gt;
$$ input_\text{embedding} = token_\text{embedding} + positional_\text{encoding} $$&lt;p&gt;这个额外注入的向量为每个词元提供了其在序列中的位置信息。这是一种 &lt;strong&gt;绝对位置编码&lt;/strong&gt;，即每个位置（如第 0、1、2 个位置）都有一个固定的编码向量。在实践中，主要有两种实现方式：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;可学习的位置编码 (Learned Positional Encoding)&lt;/strong&gt;
- 在 Encoder-only 模型（如 BERT）中常见；而近年的大型解码器式模型多采用相对/旋转类位置编码（如 RoPE&lt;sup id="fnref:4"&gt;&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref"&gt;4&lt;/a&gt;&lt;/sup&gt;）。
- 其实现非常简单：创建一个 &lt;code&gt;nn.Embedding&lt;/code&gt; 层，大小为 &lt;code&gt;[max_sequence_length, hidden_size]&lt;/code&gt;。&lt;code&gt;max_sequence_length&lt;/code&gt; 是模型能处理的最大序列长度，这是一个重要的超参数（在很多模型配置文件中被称为 &lt;code&gt;max_position_embeddings&lt;/code&gt;）。在训练时，模型会像学习词嵌入一样，自动学习出每个位置（0, 1, 2, &amp;hellip;）最合适的向量表示。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;基于三角函数的固定编码 (Sinusoidal Positional Encoding)&lt;/strong&gt;
- 这是原版 Transformer 论文中使用的方法，它不需要学习。
- 使用不同频率的正弦和余弦函数来为每个位置生成一个独特的、固定的编码向量：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt; $$
PE_{(\text{pos}, 2i)} = \sin\left(\frac{\text{pos}}{10000^{\frac{2i}{d_{\text{embedding}}}}}\right) \\
PE_{(\text{pos}, 2i+1)} = \cos\left(\frac{\text{pos}}{10000^{\frac{2i}{d_{\text{embedding}}}}}\right)
$$
其中：
- $pos$ 是词元在序列中的绝对位置（如第 0 个、第 1 个词...）。
- $i$ 是编码向量中的维度索引（从 0 到 $d_{embedding}$/2）。公式通过 $i$ 来同时计算偶数维度 $2i$ 和奇数维度 $2i+1$ 的值，因此 $i$ 的取值范围只需达到维度总数的一半。
- $d_{embedding}$ 是词嵌入的维度。
公式利用不同频率的正弦和余弦函数，为 $d_{embedding}$ 维编码向量的每一个维度（$2i$ 对应偶数位，$2i+1$ 对应奇数位）计算一个特定的值。由于每个位置 $pos$ 和每个维度 $i$ 的组合都是独一无二的，所以这种方法能为序列中的每个位置生成一个完全独特的编码向量。这种方法的优势是不同位置的编码向量之间存在固定的线性关系，这可能有助于模型推断出词元间的相对位置。其主要优点是不需要训练，并且理论上可以外推到比训练时遇到的更长的序列。
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;绝对 vs. 相对位置编码&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;上述两种方法都属于&lt;strong&gt;绝对位置编码&lt;/strong&gt;，因为它们为每个绝对位置（第 1 个、第 10 个等）分配一个特定的编码。然而，这种方式在处理超长文本时可能存在泛化性问题。因此，许多现代的大语言模型（如 Transformer-XL, Llama）转而采用&lt;strong&gt;相对位置编码&lt;/strong&gt;。这种方法不再关注词元的绝对位置，而是直接在注意力计算中建模词元之间的相对距离（例如，“当前词”与“前 2 个词”之间的关系），这被证明在处理长序列时更有效、更灵活。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="35-注意力掩码"&gt;3.5 注意力掩码
&lt;/h3&gt;&lt;p&gt;掩码是 Transformer 模型中一个重要的机制。其&lt;strong&gt;主要目的&lt;/strong&gt;是确保解码器在生成序列时的&lt;strong&gt;自回归&lt;/strong&gt;特性，即不能“看到”未来的信息。此外，作为一个&lt;strong&gt;通用的工程实践&lt;/strong&gt;，掩码也被用来处理批量训练中因句子长度不同而引入的填充（Padding）问题。Transformer 主要使用以下两种掩码：&lt;/p&gt;
&lt;p&gt;（1）因果掩码：因果掩码专用于解码器的&lt;strong&gt;带掩码的自注意力（Masked Self-Attention）&lt;strong&gt;子层，是为了确保解码过程遵循&lt;/strong&gt;自回归（Auto-regressive）&lt;strong&gt;特性，即生成第 $i$ 个词元时只能依赖前 $i-1$ 个词元的信息，而绝不能“偷看”到 $i$ 及之后位置的内容。它的实现核心是确保注意力权重矩阵呈现&lt;/strong&gt;下三角矩阵&lt;/strong&gt;的形态。对于长度为 $T$ 的序列，在 $[T, T]$ 的矩阵中，主对角线及以下的位置被标记为可关注（如 &lt;code&gt;True&lt;/code&gt; 或 0），而主对角线以上的位置则被标记为屏蔽（如 &lt;code&gt;False&lt;/code&gt; 或 1）。在计算 Softmax 之前，所有被屏蔽位置的注意力分数会被加上一个极大的负数（如 &lt;code&gt;-inf&lt;/code&gt;），迫使其注意力权重归零，从而物理上切断了信息的向后传播路径。&lt;/p&gt;
&lt;p&gt;（2）填充掩码：填充掩码广泛应用于编码器和解码器的所有注意力层，目的是解决变长序列批量处理时的**填充（Padding）**问题。由于填充词元（如 &lt;code&gt;&amp;lt;pad&amp;gt;&lt;/code&gt;）本身不携带语义信息，若模型对其分配注意力，不仅浪费计算资源，还会引入噪声干扰。填充掩码的作用就是在计算注意力分数后，将所有涉及填充词元的位置（无论是作为查询 Query 还是作为键 Key）的对应分数强制设为极大的负数（如 &lt;code&gt;-1e9&lt;/code&gt; 或负无穷）。假设有一个注意力分数矩阵，维度为 &lt;code&gt;[batch_size, num_heads, seq_len, seq_len]&lt;/code&gt;。填充掩码会是一个 &lt;code&gt;[batch_size, 1, 1, seq_len]&lt;/code&gt; 的矩阵（或可广播的形状），标记了哪些位置是填充。在进行 Softmax 之前，这个掩码会被加到分数矩阵上。经过 Softmax 运
算后，这些负无穷位置的注意力权重会趋近于 0，从而在后续的加权求和中被完全忽略。&lt;/p&gt;
&lt;p&gt;在解码器的自注意力层中，这两种掩码通常会结合使用，确保模型既不会关注到未来的信息，也不会关注到填充位。&lt;/p&gt;
&lt;h3 id="36-解码器推理与-kv-缓存"&gt;3.6 解码器推理与 KV 缓存
&lt;/h3&gt;&lt;p&gt;解码器在训练和推理时的行为有很大不同。训练时，模型可以看到完整的“正确答案”序列，并通过注意力掩码来并行计算所有位置的损失。然而，在&lt;strong&gt;推理&lt;/strong&gt;时，模型必须逐个生成词元，这是一个&lt;strong&gt;自回归&lt;/strong&gt;的过程：&lt;/p&gt;
&lt;p&gt;（1）输入 &lt;code&gt;[BOS]&lt;/code&gt;（开始符），生成第一个词 &lt;code&gt;token_1&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（2）输入 &lt;code&gt;[BOS], token_1&lt;/code&gt;，生成第二个词 &lt;code&gt;token_2&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（3）输入 &lt;code&gt;[BOS], token_1, token_2&lt;/code&gt;，生成第三个词 &lt;code&gt;token_3&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（4）&amp;hellip; 直到生成 &lt;code&gt;[EOS]&lt;/code&gt;（结束符）或达到最大长度。&lt;/p&gt;
&lt;p&gt;如果按照这个流程直接计算，效率会非常低下。例如，在生成 &lt;code&gt;token_3&lt;/code&gt; 时，模型需要为 &lt;code&gt;[BOS]&lt;/code&gt; 和 &lt;code&gt;token_1&lt;/code&gt; 重新计算它们的 Q, K, V 向量并参与注意力计算。但事实上，&lt;code&gt;[BOS]&lt;/code&gt; 和 &lt;code&gt;token_1&lt;/code&gt; 的 Key 和 Value 向量在之前的步骤中已经被计算过了。&lt;/p&gt;
&lt;p&gt;为解决这种冗余计算，推理时会采用一项关键的优化技术：&lt;strong&gt;KV 缓存&lt;/strong&gt;。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;基本原理&lt;/strong&gt;：对于解码器的每一层，都缓存下截至当前时刻已经计算出的所有词元的 &lt;strong&gt;Key&lt;/strong&gt; 和 &lt;strong&gt;Value&lt;/strong&gt; 向量。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;工作流程&lt;/strong&gt;：在生成第 $t$ 个词元时，模型只需要为当前输入的第 $t-1$ 个词元计算出它自己的 $q_{t-1}, k_{t-1}, v_{t-1}$。然后，它从缓存中取出历史的 $K_{cache} = [k_0, k_1, &amp;hellip;, k_{t-2}]$ 和 $V_{cache} = [v_0, v_1, &amp;hellip;, v_{t-2}]$。最后，将新的 $k_{t-1}, v_{t-1}$ 追加到缓存中，并用 $q_{t-1}$ 与更新后的完整 $K_{cache}, V_{cache}$ 进行注意力计算。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;通过 KV 缓存，每次解码步骤的计算量从与整个已生成序列长度的平方（$O(T^2)$）相关，降低到只与序列长度（$O(T)$）线性相关，极大地加速了文本生成的速度，是实现高效大模型推理的常用技术之一。需要注意，KV 缓存占用会随步数线性增长（$O(T)$），在多层多头设置下需关注显存开销。&lt;/p&gt;
&lt;h2 id="四transformer-代码实践"&gt;四、Transformer 代码实践
&lt;/h2&gt;&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/datawhalechina/base-nlp/tree/main/code/C4/transformer" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="41-项目结构设计"&gt;4.1 项目结构设计
&lt;/h3&gt;&lt;p&gt;为了更好地理解 Transformer 的内部工作机制，接下来尝试从零实现一个完整的 Transformer 模型。我们会采用**“先整体框架，后组件实现”&lt;strong&gt;的思路，拆分多个文件来构建项目。在前面我们详细分析了 Transformer 的几大核心组件，分别是&lt;/strong&gt;位置编码**、&lt;strong&gt;多头注意力&lt;/strong&gt;、&lt;strong&gt;前馈网络 &lt;strong&gt;以及&lt;/strong&gt;归一化&lt;/strong&gt;。为了体现这些组件的独立性和复用性，我们将遵循&lt;strong&gt;模块化&lt;/strong&gt;的设计原则，将它们拆分到 &lt;code&gt;src/&lt;/code&gt; 目录下的独立文件中，而将模型的组装和运行逻辑放在根目录的 &lt;code&gt;main.py&lt;/code&gt; 中。目录设计如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-text" data-lang="text"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;code/C4/transformer/
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;├── src/
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── transformer.py # 核心框架：定义 Transformer、EncoderLayer 和 DecoderLayer
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── attention.py # 核心组件：多头注意力机制 (MultiHeadAttention)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── ffn.py # 核心组件：前馈神经网络 (FeedForward)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ ├── norm.py # 辅助组件：层归一化 (LayerNorm)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;│ └── pos.py # 辅助组件：位置编码 (PositionalEncoding)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;└── main.py # 入口脚本：组装模型并演示前向传播
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="42-搭建整体框架"&gt;4.2 搭建整体框架
&lt;/h3&gt;&lt;p&gt;在开始编写具体的注意力机制或前馈网络之前，我们可以先在 &lt;code&gt;src/transformer.py&lt;/code&gt; 中勾勒出模型的高层架构。这种**“自顶向下”**的编程方式有助于我们理清数据流向。通过前面的学习我们知道，Transformer 宏观上是一个 Encoder-Decoder 架构，所以首先要实现的主要是以下几个部分：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Embedding 层&lt;/strong&gt;：将输入的 token ID 转换为连续的向量表示，并加上位置编码以保留序列顺序信息。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Encoder 堆叠&lt;/strong&gt;：由 $N$ 个 &lt;code&gt;EncoderLayer&lt;/code&gt; 串联而成，负责深度提取和理解输入序列的特征。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Decoder 堆叠&lt;/strong&gt;：由 $N$ 个 &lt;code&gt;DecoderLayer&lt;/code&gt; 串联而成，负责基于 Encoder 的输出逐步生成目标序列。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Output 层&lt;/strong&gt;：一个线性层，将解码器的最终输出映射回词表大小，用于计算下一个词的概率分布。&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# src/transformer.py&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.pos&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt; &lt;span class="c1"&gt;# 稍后实现&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# ... 导入其他组件&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 嵌入层与位置编码&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# src_embedding: 将源语言序列映射为向量 (Encoder输入)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;src_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# tgt_embedding: 将目标语言序列映射为向量 (Decoder输入)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tgt_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 编码器与解码器堆叠&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 使用 ModuleList 来存储层列表，支持按索引访问和自动注册参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ModuleList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ModuleList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 输出头&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 生成掩码 (Padding Mask &amp;amp; Causal Mask)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;generate_mask&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 编码器前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 解码器前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dec_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 输出 Logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dec_output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;有了这个骨架，接下来的任务就是填充 &lt;code&gt;EncoderLayer&lt;/code&gt; 和 &lt;code&gt;DecoderLayer&lt;/code&gt;，而它们又依赖于更底层的组件。&lt;/p&gt;
&lt;h3 id="43-实现核心组件"&gt;4.3 实现核心组件
&lt;/h3&gt;&lt;p&gt;（1）位置编码 (src/pos.py)&lt;/p&gt;
&lt;p&gt;在 &lt;code&gt;src/transformer.py&lt;/code&gt; 中我们引入了 &lt;code&gt;PositionalEncoding&lt;/code&gt;，它是 Transformer 处理序列顺序的关键。这里我们实现论文中的正弦位置编码。位置编码的核心在于初始化阶段，我们会预先计算好一个足够长的编码矩阵。它的计算公式使用了不同频率的正弦和余弦函数：&lt;/p&gt;
$$
PE(pos, 2i) = \sin(pos / 10000^{2i/d_{model}}) \\
PE(pos, 2i+1) = \cos(pos / 10000^{2i/d_{model}})
$$&lt;p&gt;在 &lt;code&gt;__init__&lt;/code&gt; 方法中，我们一次性生成这个矩阵，并将其注册为 buffer。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 正弦位置编码
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Transformer 论文中使用固定公式计算位置编码，不涉及可学习参数。
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 创建一个足够长的 PE 矩阵 [max_seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 生成位置索引 [0, 1, ..., max_seq_len-1] -&amp;gt; [max_seq_len, 1]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 计算分母中的 div_term: 10000^(2i/dim) = exp(2i * -log(10000)/dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 这种对数变换的计算方式在数值上更稳定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;div_term&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;10000.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 填充 PE 矩阵&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 偶数维度用 sin，奇数维度用 cos&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 增加 batch 维度: [1, max_seq_len, dim] 以便广播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 注册为 buffer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# register_buffer 的作用是告诉 PyTorch：&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. &amp;#39;pe&amp;#39; 是模型状态的一部分，会随模型保存和加载 (state_dict)。&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. &amp;#39;pe&amp;#39; 不是模型参数 (Parameter)，优化器更新时不会更新它。&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;register_buffer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;pe&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;在前向传播中，我们的任务就是将位置编码加到输入的词嵌入上。由于我们预先生成的 &lt;code&gt;pe&lt;/code&gt; 矩阵可能比当前的输入序列 &lt;code&gt;x&lt;/code&gt; 要长，所以需要根据 &lt;code&gt;x&lt;/code&gt; 的实际长度对 &lt;code&gt;pe&lt;/code&gt; 进行切片。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Args:
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; x: 输入的词嵌入序列 [batch_size, seq_len, dim]
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; Returns:
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 加上位置编码后的序列 [batch_size, seq_len, dim]
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 截取与输入序列长度对应的位置编码并相加&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x.size(1) 是 seq_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# self.pe 的形状是 [1, max_seq_len, dim]，切片后会自动广播到 batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;最后，我们可以编写一段简单的测试代码来验证维度是否正确。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 输入为0，直接观察PE值&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- PositionalEncoding Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- PositionalEncoding Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（2）多头注意力 (src/attention.py)&lt;/p&gt;
&lt;p&gt;这是 Transformer 中最复杂的组件，用于从不同的“表示子空间”中提取信息。在初始化阶段，我们需要定义四个主要的线性层：&lt;code&gt;wq&lt;/code&gt;, &lt;code&gt;wk&lt;/code&gt;, &lt;code&gt;wv&lt;/code&gt; 用于将输入投影到 Q, K, V 空间，&lt;code&gt;wo&lt;/code&gt; 用于将多头注意力的输出投影回原始维度。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# src/attention.py&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 定义 Wq, Wk, Wv 矩阵&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 这里我们使用一个大的线性层一次性计算所有头的 Q, K, V&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 最终输出的线性层 Wo&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;这部分前向传播的重点是“分头”操作。我们不直接对 [batch, seq_len, dim] 进行计算，而是将其 reshape 为 [batch, n_heads, seq_len, head_dim]，这样就可以利用矩阵运算并行地处理所有头。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;span class="lnt"&gt;51
&lt;/span&gt;&lt;span class="lnt"&gt;52
&lt;/span&gt;&lt;span class="lnt"&gt;53
&lt;/span&gt;&lt;span class="lnt"&gt;54
&lt;/span&gt;&lt;span class="lnt"&gt;55
&lt;/span&gt;&lt;span class="lnt"&gt;56
&lt;/span&gt;&lt;span class="lnt"&gt;57
&lt;/span&gt;&lt;span class="lnt"&gt;58
&lt;/span&gt;&lt;span class="lnt"&gt;59
&lt;/span&gt;&lt;span class="lnt"&gt;60
&lt;/span&gt;&lt;span class="lnt"&gt;61
&lt;/span&gt;&lt;span class="lnt"&gt;62
&lt;/span&gt;&lt;span class="lnt"&gt;63
&lt;/span&gt;&lt;span class="lnt"&gt;64
&lt;/span&gt;&lt;span class="lnt"&gt;65
&lt;/span&gt;&lt;span class="lnt"&gt;66
&lt;/span&gt;&lt;span class="lnt"&gt;67
&lt;/span&gt;&lt;span class="lnt"&gt;68
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 线性投影&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# [batch, seq_len, dim] -&amp;gt; [batch, seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wk&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 分头 (Split Heads)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 变换形状: [batch, seq_len, n_heads, head_dim] &lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 然后转置: [batch, n_heads, seq_len, head_dim] 以便并行计算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 计算缩放点积注意力 (Scaled Dot-Product Attention)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# scores: [batch, n_heads, seq_len, seq_len]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 应用掩码 (Masking)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# mask == 0 的位置被填充为负无穷，Softmax 后变为 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;-inf&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 5. Softmax 与加权求和&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# context: [batch, n_heads, seq_len, head_dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 6. 合并多头 (Concat Heads)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# [batch, n_heads, seq_len, head_dim] -&amp;gt; [batch, seq_len, n_heads, head_dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# -&amp;gt; [batch, seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contiguous&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 7. 输出层投影&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 单元测试&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入 (Query, Key, Value 相同)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mha&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- MultiHeadAttention Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- MultiHeadAttention Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（3）前馈神经网络 (src/ffn.py)&lt;/p&gt;
&lt;p&gt;标准的 Transformer FFN 是一个简单的两层全连接网络，中间包含激活函数。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# src/ffn.py&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 升维&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 降维&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 线性变换 -&amp;gt; ReLU -&amp;gt; Dropout -&amp;gt; 线性变换&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;w1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2048&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ffn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ffn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- FeedForward Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- FeedForward Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（4）层归一化 (src/norm.py)&lt;/p&gt;
&lt;p&gt;层归一化 (Layer Normalization) 是 Transformer 中用来稳定训练的组件。与 Batch Normalization 不同，它是在&lt;strong&gt;最后一个维度&lt;/strong&gt;（即特征维度 &lt;code&gt;dim&lt;/code&gt;）上进行归一化的。公式如下：&lt;/p&gt;
$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$&lt;p&gt;其中 $\gamma$ 和 $\beta$ 是可学习的缩放和平移参数。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 层归一化 (Layer Normalization)
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 公式: y = (x - mean) / sqrt(var + eps) * gamma + beta
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 可学习参数 gamma (缩放) 和 beta (偏移)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# nn.Parameter 会被自动注册为模型参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x: [batch_size, seq_len, dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 在最后一个维度 (dim) 上计算均值和方差&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# keepdim=True 保持维度以便进行广播计算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# unbiased=False 使用有偏估计 (分母为 N)，与 PyTorch 默认行为一致&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;var&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;var&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;unbiased&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 归一化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;var&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 缩放和平移&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_norm&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 单元测试&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化模块&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ln&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 准备输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 验证输出&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;--- LayerNorm Test ---&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Input shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--- LayerNorm Test ---
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10, 512&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="44-组装与运行"&gt;4.4 组装与运行
&lt;/h3&gt;&lt;p&gt;（1）完善核心框架 (src/transformer.py)&lt;/p&gt;
&lt;p&gt;之前我们只搭建了 &lt;code&gt;Transformer&lt;/code&gt; 类的骨架，现在我们利用已经实现好的组件，按“编码器层 → 解码器层 → 辅助方法”的顺序来补全 &lt;code&gt;src/transformer.py&lt;/code&gt;。编码器层，这部分包含一个多头自注意力子层和一个前馈网络子层，每个子层后面都接残差连接和层归一化（Post-LN 结构），代码如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 导入组件&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.attention&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.ffn&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.norm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;.pos&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 多头自注意力子层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前馈网络子层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 1：自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Q=K=V=x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 2：前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;接下来是解码器层，这部分比编码器层多了一个“交叉注意力”子层，先是带掩码的自注意力，再是对编码器输出的交叉注意力，最后是前馈网络。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 带掩码的自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 交叉注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;FeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 1：带掩码的自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 2：交叉注意力（Q 来自解码器，K/V 来自编码器输出）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 子层 3：前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffn_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;_x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;最后在 &lt;code&gt;Transformer&lt;/code&gt; 主类中，我们需要补全相关的辅助方法。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2048&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# ... 初始化嵌入层、位置编码、编码器/解码器堆叠以及输出层等 ...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_init_parameters&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_init_parameters&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;xavier_uniform_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;generate_mask&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# src_mask: [batch, 1, 1, src_len]，pad token 假设为 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# tgt_mask: [batch, 1, tgt_len, tgt_len]，结合 pad mask 和 causal mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_pad_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# [batch, 1, 1, tgt_len]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_subsequent_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tgt_pad_mask&lt;/span&gt; &lt;span class="o"&gt;&amp;amp;&lt;/span&gt; &lt;span class="n"&gt;tgt_subsequent_mask&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;encode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;src_embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder_layers&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tgt_embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder_layers&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;enc_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（2）运行主程序 (main.py)&lt;/p&gt;
&lt;p&gt;现在所有的零件都准备好了，我们可以在 &lt;code&gt;main.py&lt;/code&gt; 中将它们组装起来，并运行一个简单的前向传播测试。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;src.transformer&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Transformer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;main&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 超参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2048&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;50&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 实例化模型&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 模拟输入数据&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 随机生成 src 和 tgt 序列 (假设 pad_token_id=0)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 确保没有 pad token 影响简单测试，或者手动插入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_len&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Model Architecture:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# print(model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;Test Input:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Source Shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Target Shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;Model Output:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Output Shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 预期 [batch_size, tgt_len, tgt_vocab_size]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;main&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Model Architecture:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Test Input:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Source Shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 10&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Target Shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 12&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Model Output:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output Shape: torch.Size&lt;span class="o"&gt;([&lt;/span&gt;2, 12, 100&lt;span class="o"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;hr&gt;
&lt;h2 id="参考文献"&gt;参考文献
&lt;/h2&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1706.03762" target="_blank" rel="noopener"
&gt;Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). &lt;em&gt;Attention Is All You Need&lt;/em&gt;. NeurIPS 2017.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1606.08415" target="_blank" rel="noopener"
&gt;Hendrycks, D., Gimpel, K. (2016). &lt;em&gt;Gaussian Error Linear Units (GELUs)&lt;/em&gt;.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/1607.06450" target="_blank" rel="noopener"
&gt;Ba, J. L., Kiros, J. R., Hinton, G. E. (2016). &lt;em&gt;Layer Normalization&lt;/em&gt;.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:4"&gt;
&lt;p&gt;&lt;a class="link" href="https://arxiv.org/abs/2104.09864" target="_blank" rel="noopener"
&gt;Su, J., Lu, Y., Pan, S., Wen, Y. (2021). &lt;em&gt;RoFormer: Enhanced Transformer with Rotary Position Embedding&lt;/em&gt;.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item><item><title>大语言模型基础</title><link>https://hanguangwu.github.io/blog/p/%E5%A4%A7%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B%E5%9F%BA%E7%A1%80/</link><pubDate>Mon, 29 Dec 2025 12:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/%E5%A4%A7%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B%E5%9F%BA%E7%A1%80/</guid><description>&lt;h1 id="大语言模型基础"&gt;大语言模型基础
&lt;/h1&gt;&lt;h2 id="前言"&gt;前言
&lt;/h2&gt;&lt;p&gt;本文来源于DataWhale组织《Hello-Agents》课程，课程链接为&lt;a class="link" href="https://datawhalechina.github.io/hello-agents/#/./chapter3/%E7%AC%AC%E4%B8%89%E7%AB%A0%20%E5%A4%A7%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B%E5%9F%BA%E7%A1%80" target="_blank" rel="noopener"
&gt;第三章大语言模型基础&lt;/a&gt;。&lt;/p&gt;
&lt;p&gt;前两章分别介绍了智能体的定义和发展历史，本章将完全聚焦于大语言模型本身解答一个关键问题：现代智能体是如何工作的？我们将从语言模型的基本定义出发，通过对这些原理的学习，为理解LLM如何获得强大的知识储备与推理能力打下坚实的基础。&lt;/p&gt;
&lt;h2 id="31-语言模型与-transformer-架构"&gt;3.1 语言模型与 Transformer 架构
&lt;/h2&gt;&lt;h3 id="311-从-n-gram-到-rnn"&gt;3.1.1 从 N-gram 到 RNN
&lt;/h3&gt;&lt;p&gt;&lt;strong&gt;语言模型 (Language Model, LM)&lt;/strong&gt; 是自然语言处理的核心，其根本任务是计算一个词序列（即一个句子）出现的概率。一个好的语言模型能够告诉我们什么样的句子是通顺的、自然的。在多智能体系统中，语言模型是智能体理解人类指令、生成回应的基础。本节将回顾从经典的统计方法到现代深度学习模型的演进历程，为理解后续的 Transformer 架构打下坚实的基础。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;（1）统计语言模型与N-gram的思想&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在深度学习兴起之前，统计方法是语言模型的主流。其核心思想是，一个句子出现的概率，等于该句子中每个词出现的条件概率的连乘。对于一个由词 $w_1,w_2,\cdots,w_m$ 构成的句子 S，其概率 P(S) 可以表示为：&lt;/p&gt;
$$P(S)=P(w_1,w_2,…,w_m)=P(w_1)⋅P(w_2∣w_1)⋅P(w_3∣w_1,w_2)⋯P(w_m∣w_1,…,w_{m−1})$$&lt;p&gt;这个公式被称为概率的链式法则。然而，直接计算这个公式几乎是不可能的，因为像 $P(w_m∣w_1,\cdots,w_{m−1})$ 这样的条件概率太难从语料库中估计了，词序列 $w_1,\cdots,w_{m−1}$ 可能从未在训练数据中出现过。&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://raw.githubusercontent.com/datawhalechina/Hello-Agents/main/docs/images/3-figures/1757249275674-0.png" alt="图片描述" width="90%"/&gt;
&lt;p&gt;图 3.1 马尔可夫假设示意图&lt;/p&gt;
&lt;/div&gt;
&lt;p&gt;为了解决这个问题，研究者引入了&lt;strong&gt;马尔可夫假设 (Markov Assumption)&lt;/strong&gt; 。其核心思想是：我们不必回溯一个词的全部历史，可以近似地认为，一个词的出现概率只与它前面有限的 $n−1$ 个词有关，如图3.1所示。基于这个假设建立的语言模型，我们称之为 &lt;strong&gt;N-gram模型&lt;/strong&gt;。这里的 &amp;ldquo;N&amp;rdquo; 代表我们考虑的上下文窗口大小。让我们来看几个最常见的例子来理解这个概念：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Bigram (当 N=2 时)&lt;/strong&gt; ：这是最简单的情况，我们假设一个词的出现只与它前面的一个词有关。因此，链式法则中复杂的条件概率 $P(w_i∣w_1,\cdots,w_{i−1})$ 就可以被近似为更容易计算的形式：&lt;/li&gt;
&lt;/ul&gt;
$$P(w_{i}∣w_{1},…,w_{i−1})≈P(w_{i}∣w_{i−1})$$&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Trigram (当 N=3 时)&lt;/strong&gt; ：类似地，我们假设一个词的出现只与它前面的两个词有关：&lt;/li&gt;
&lt;/ul&gt;
$$P(w_i∣w_1,…,w_{i−1})≈P(w_i∣w_{i−2},w_{i−1})$$&lt;p&gt;这些概率可以通过在大型语料库中进行&lt;strong&gt;最大似然估计(Maximum Likelihood Estimation,MLE)&lt;/strong&gt; 来计算。这个术语听起来很复杂，但其思想非常直观：最可能出现的，就是我们在数据中看到次数最多的。例如，对于 Bigram 模型，我们想计算在词 $w_{i−1}$ 出现后，下一个词是 $w_i$ 的概率 $P(w_i∣w_{i−1})$。根据最大似然估计，这个概率可以通过简单的计数来估算：&lt;/p&gt;
$$P(w_i∣w_{i−1})=\frac{Count(w_{i−1},w_i)}{Count(w_{i−1})}$$&lt;p&gt;这里的 &lt;code&gt;Count()&lt;/code&gt; 函数就代表“计数”：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$Count(w_i−1,w_i)$：表示词对 $(w_{i−1},w_i)$ 在语料库中连续出现的总次数。&lt;/li&gt;
&lt;li&gt;$Count(w_{i−1})$：表示单个词 $w_{i−1}$ 在语料库中出现的总次数。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;公式的含义就是：我们用“词对 $Count(w_i−1,w_i)$ 出现的次数”除以“词 $Count(w_{i−1})$ 出现的总次数”，来作为 $P(w_i∣w_{i−1})$ 的一个近似估计。&lt;/p&gt;
&lt;p&gt;为了让这个过程更具体，我们来手动进行一次计算。假设我们拥有一个仅包含以下两句话的迷你语料库：&lt;code&gt;datawhale agent learns&lt;/code&gt;, &lt;code&gt;datawhale agent works&lt;/code&gt;。我们的目标是：使用 Bigram (N=2) 模型，估算句子 &lt;code&gt;datawhale agent learns&lt;/code&gt; 出现的概率。根据 Bigram 的假设，我们每次会考察连续的两个词（即一个词对）。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;第一步：计算第一个词的概率&lt;/strong&gt; $P(datawhale)$ 这是 &lt;code&gt;datawhale&lt;/code&gt; 出现的次数除以总词数。&lt;code&gt;datawhale&lt;/code&gt; 出现了 2 次，总词数是 6。&lt;/p&gt;
$$P(\text{datawhale}) = \frac{\text{总语料中"datawhale"的数量}}{\text{总语料的词数}} = \frac{2}{6} \approx 0.333$$&lt;p&gt;&lt;strong&gt;第二步：计算条件概率&lt;/strong&gt; $P(agent∣datawhale)$ 这是词对 &lt;code&gt;datawhale agent&lt;/code&gt; 出现的次数除以 &lt;code&gt;datawhale&lt;/code&gt; 出现的总次数。&lt;code&gt;datawhale agent&lt;/code&gt; 出现了 2 次，&lt;code&gt;datawhale&lt;/code&gt; 出现了 2 次。&lt;/p&gt;
$$P(\text{agent}|\text{datawhale}) = \frac{\text{Count}(\text{datawhale agent})}{\text{Count}(\text{datawhale})} = \frac{2}{2} = 1$$&lt;p&gt;&lt;strong&gt;第三步：计算条件概率&lt;/strong&gt; $P(learns∣agent)$ 这是词对 &lt;code&gt;agent learns&lt;/code&gt; 出现的次数除以 &lt;code&gt;agent&lt;/code&gt; 出现的总次数。&lt;code&gt;agent learns&lt;/code&gt; 出现了 1 次，&lt;code&gt;agent&lt;/code&gt; 出现了 2 次。&lt;/p&gt;
$$P(\text{learns}|\text{agent}) = \frac{\text{Count(agent learns)}}{\text{Count(agent)}} = \frac{1}{2} = 0.5$$&lt;p&gt;&lt;strong&gt;最后：将概率连乘&lt;/strong&gt; 所以，整个句子的近似概率为：&lt;/p&gt;
$$P(\text{datawhale agent learns}) \approx P(\text{datawhale}) \cdot P(\text{agent}|\text{datawhale}) \cdot P(\text{learns}|\text{agent}) \approx 0.333 \cdot 1 \cdot 0.5 \approx 0.167$$&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;collections&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 示例语料库，与上方案例讲解中的语料库保持一致&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;corpus&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;datawhale agent learns datawhale agent works&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;corpus&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;total_tokens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokens&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 第一步:计算 P(datawhale) ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;count_datawhale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokens&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;count&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;datawhale&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p_datawhale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;count_datawhale&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;total_tokens&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;第一步: P(datawhale) = &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;count_datawhale&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;total_tokens&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; = &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;p_datawhale&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.3f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 第二步:计算 P(agent|datawhale) ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 先计算 bigrams 用于后续步骤&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;bigrams&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokens&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokens&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;bigram_counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;collections&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Counter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;bigrams&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;count_datawhale_agent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bigram_counts&lt;/span&gt;&lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;datawhale&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;agent&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# count_datawhale 已在第一步计算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p_agent_given_datawhale&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;count_datawhale_agent&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;count_datawhale&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;第二步: P(agent|datawhale) = &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;count_datawhale_agent&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;count_datawhale&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; = &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;p_agent_given_datawhale&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.3f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 第三步:计算 P(learns|agent) ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;count_agent_learns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bigram_counts&lt;/span&gt;&lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;agent&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;learns&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;count_agent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokens&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;count&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;agent&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p_learns_given_agent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;count_agent_learns&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;count_agent&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;第三步: P(learns|agent) = &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;count_agent_learns&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;count_agent&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; = &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;p_learns_given_agent&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.3f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 最后:将概率连乘 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;p_sentence&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p_datawhale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p_agent_given_datawhale&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p_learns_given_agent&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;最后: P(&amp;#39;datawhale agent learns&amp;#39;) ≈ &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;p_datawhale&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.3f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; * &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;p_agent_given_datawhale&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.3f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; * &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;p_learns_given_agent&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.3f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; = &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;p_sentence&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.3f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第一步: P(datawhale) = 2/6 = 0.333
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第二步: P(agent|datawhale) = 2/2 = 1.000
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第三步: P(learns|agent) = 1/2 = 0.500
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;最后: P(&amp;#39;datawhale agent learns&amp;#39;) ≈ 0.333 * 1.000 * 0.500 = 0.167
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;N-gram 模型虽然简单有效，但有两个致命缺陷：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;数据稀疏性 (Sparsity)&lt;/strong&gt; ：如果一个词序列从未在语料库中出现，其概率估计就为 0，这显然是不合理的。虽然可以通过平滑 (Smoothing) 技术缓解，但无法根除。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;泛化能力差：&lt;/strong&gt;模型无法理解词与词之间的语义相似性。例如，即使模型在语料库中见过很多次 &lt;code&gt;agent learns&lt;/code&gt;，它也无法将这个知识泛化到语义相似的词上。当我们计算 &lt;code&gt;robot learns&lt;/code&gt; 的概率时，如果 &lt;code&gt;robot&lt;/code&gt; 这个词从未出现过，或者 &lt;code&gt;robot learns&lt;/code&gt; 这个组合从未出现过，模型计算出的概率也会是零。模型无法理解 &lt;code&gt;agent&lt;/code&gt; 和 &lt;code&gt;robot&lt;/code&gt; 在语义上的相似性。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;（2）神经网络语言模型与词嵌入&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;N-gram 模型的根本缺陷在于它将词视为孤立、离散的符号。为了克服这个问题，研究者们转向了神经网络，并提出了一种思想：用连续的向量来表示词。2003年，Bengio 等人提出的&lt;strong&gt;前馈神经网络语言模型 (Feedforward Neural Network Language Model)&lt;/strong&gt; 是这一领域的里程碑&lt;sup&gt;[1]&lt;/sup&gt;。&lt;/p&gt;
&lt;p&gt;其核心思想可以分为两步：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;构建一个语义空间&lt;/strong&gt;：创建一个高维的连续向量空间，然后将词汇表中的每个词都映射为该空间中的一个点。这个点（即向量）就被称为&lt;strong&gt;词嵌入 (Word Embedding)&lt;/strong&gt; 或词向量。在这个空间里，语义上相近的词，它们对应的向量在空间中的位置也相近。例如，&lt;code&gt;agent&lt;/code&gt; 和 &lt;code&gt;robot&lt;/code&gt; 的向量会靠得很近，而 &lt;code&gt;agent&lt;/code&gt; 和 &lt;code&gt;apple&lt;/code&gt; 的向量会离得很远。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;学习从上下文到下一个词的映射&lt;/strong&gt;：利用神经网络的强大拟合能力，来学习一个函数。这个函数的输入是前 $n−1$ 个词的词向量，输出是词汇表中每个词在当前上下文后出现的概率分布。&lt;/li&gt;
&lt;/ol&gt;
&lt;div align="center"&gt;
&lt;img src="https://raw.githubusercontent.com/datawhalechina/Hello-Agents/main/docs/images/3-figures/1757249275674-1.png" alt="图片描述" width="90%"/&gt;
&lt;p&gt;图 3.2 神经网络语言模型架构示意图&lt;/p&gt;
&lt;/div&gt;
&lt;p&gt;如图3.2所示，在这个架构中，词嵌入是在模型训练过程中自动学习得到的。模型为了完成“预测下一个词”这个任务，会不断调整每个词的向量位置，最终使这些向量能够蕴含丰富的语义信息。一旦我们将词转换成了向量，我们就可以用数学工具来度量它们之间的关系。最常用的方法是&lt;strong&gt;余弦相似度 (Cosine Similarity)&lt;/strong&gt; ，它通过计算两个向量夹角的余弦值来衡量它们的相似性。&lt;/p&gt;
$$\text{similarity}(\vec{a}, \vec{b}) = \cos(\theta) = \frac{\vec{a} \cdot \vec{b}}{|\vec{a}| |\vec{b}|}$$&lt;p&gt;这个公式的含义是：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;如果两个向量方向完全相同，夹角为0°，余弦值为1，表示完全相关。&lt;/li&gt;
&lt;li&gt;如果两个向量方向正交，夹角为90°，余弦值为0，表示毫无关系。&lt;/li&gt;
&lt;li&gt;如果两个向量方向完全相反，夹角为180°，余弦值为-1，表示完全负相关。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;通过这种方式，词向量不仅能捕捉到“同义词”这类简单的关系，还能捕捉到更复杂的类比关系。&lt;/p&gt;
&lt;p&gt;一个著名的例子展示了词向量捕捉到的语义关系： &lt;code&gt;vector('King') - vector('Man') + vector('Woman')&lt;/code&gt; 这个向量运算的结果，在向量空间中与 &lt;code&gt;vector('Queen')&lt;/code&gt; 的位置惊人地接近。这好比在进行语义的平移：我们从“国王”这个点出发，减去“男性”的向量，再加上“女性”的向量，最终就抵达了“女王”的位置。这证明了词嵌入能够学习到“性别”、“王室”这类抽象概念。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 假设我们已经学习到了简化的二维词向量&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;king&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;queen&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;man&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;woman&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cosine_similarity&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vec1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vec2&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dot_product&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vec1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vec2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;norm_product&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vec1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linalg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vec2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;dot_product&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;norm_product&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# king - man + woman&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;result_vec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;king&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;man&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;woman&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 计算结果向量与 &amp;#34;queen&amp;#34; 的相似度&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;sim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cosine_similarity&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;result_vec&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;queen&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;king - man + woman 的结果向量: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;result_vec&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;该结果与 &amp;#39;queen&amp;#39; 的相似度: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;sim&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.4f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;king - man + woman 的结果向量: [0.9 0.2]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;该结果与 &amp;#39;queen&amp;#39; 的相似度: 1.0000
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;神经网络语言模型通过词嵌入，成功解决了 N-gram 模型的泛化能力差的问题。然而，它仍然有一个类似 N-gram 的限制：上下文窗口是固定的。它只能考虑固定数量的前文，这为能处理任意长序列的循环神经网络埋下了伏笔。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;（3）循环神经网络 (RNN) 与长短时记忆网络 (LSTM)&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;前一节的神经网络语言模型虽然引入了词嵌入解决了泛化问题，但它和 N-gram 模型一样，上下文窗口是固定大小的。为了预测下一个词，它只能看到前 n−1 个词，再早的历史信息就被丢弃了。这显然不符合我们人类理解语言的方式。为了打破固定窗口的限制，&lt;strong&gt;循环神经网络 (Recurrent Neural Network, RNN)&lt;/strong&gt; 应运而生，其核心思想非常直观：为网络增加“记忆”能力&lt;sup&gt;[2]&lt;/sup&gt;。&lt;/p&gt;
&lt;p&gt;如图3.3所示，RNN 的设计引入了一个&lt;strong&gt;隐藏状态 (hidden state)&lt;/strong&gt; 向量，我们可以将其理解为网络的短期记忆。在处理序列的每一步，网络都会读取当前的输入词，并结合它上一刻的记忆（即上一个时间步的隐藏状态），然后生成一个新的记忆（即当前时间步的隐藏状态）传递给下一刻。这个循环往复的过程，使得信息可以在序列中不断向后传递。&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://raw.githubusercontent.com/datawhalechina/Hello-Agents/main/docs/images/3-figures/1757249275674-2.png" alt="图片描述" width="90%"/&gt;
&lt;p&gt;图 3.3 RNN 结构示意图&lt;/p&gt;
&lt;/div&gt;
&lt;p&gt;然而，标准的 RNN 在实践中存在一个严重的问题：&lt;strong&gt;长期依赖问题 (Long-term Dependency Problem)&lt;/strong&gt; 。在训练过程中，模型需要通过反向传播算法根据输出端的误差来调整网络深处的权重。对于 RNN 而言，序列的长度就是网络的深度。当序列很长时，梯度在从后向前传播的过程中会经过多次连乘，这会导致梯度值快速趋向于零（&lt;strong&gt;梯度消失&lt;/strong&gt;）或变得极大（&lt;strong&gt;梯度爆炸&lt;/strong&gt;）。梯度消失使得模型无法有效学习到序列早期信息对后期输出的影响，即难以捕捉长距离的依赖关系。&lt;/p&gt;
&lt;p&gt;为了解决长期依赖问题，&lt;strong&gt;长短时记忆网络 (Long Short-Term Memory, LSTM)&lt;/strong&gt; 被设计出来&lt;sup&gt;[3]&lt;/sup&gt;。LSTM 是一种特殊的 RNN，其核心创新在于引入了&lt;strong&gt;细胞状态 (Cell State)&lt;/strong&gt; 和一套精密的&lt;strong&gt;门控机制 (Gating Mechanism)&lt;/strong&gt; 。细胞状态可以看作是一条独立于隐藏状态的信息通路，允许信息在时间步之间更顺畅地传递。门控机制则是由几个小型神经网络构成，它们可以学习如何有选择地让信息通过，从而控制细胞状态中信息的增加与移除。这些门包括：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;遗忘门 (Forget Gate)&lt;/strong&gt;：决定从上一时刻的细胞状态中丢弃哪些信息。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;输入门 (Input Gate)&lt;/strong&gt;：决定将当前输入中的哪些新信息存入细胞状态。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;输出门 (Output Gate)&lt;/strong&gt;：决定根据当前的细胞状态，输出哪些信息到隐藏状态。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="312-transformer-架构解析"&gt;3.1.2 Transformer 架构解析
&lt;/h3&gt;&lt;p&gt;在上一节中，我们看到RNN及LSTM通过引入循环结构来处理序列数据，这在一定程度上解决了捕捉长距离依赖的问题。然而，这种循环的计算方式也带来了新的瓶颈：它必须按顺序处理数据。第 t 个时间步的计算，必须等待第 t−1 个时间步完成后才能开始。这意味着 RNN 无法进行大规模的并行计算，在处理长序列时效率低下，这极大地限制了模型规模和训练速度的提升。Transformer在2017 年由谷歌团队提出&lt;sup&gt;[4]&lt;/sup&gt;。它完全抛弃了循环结构，转而完全依赖一种名为&lt;strong&gt;注意力 (Attention)&lt;/strong&gt; 的机制来捕捉序列内的依赖关系，从而实现了真正意义上的并行计算。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;（1）Encoder-Decoder 整体结构&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;最初的 Transformer 模型是为端到端任务机器翻译而设计的。如图3.4所示，它在宏观上遵循了一个经典的&lt;strong&gt;编码器-解码器 (Encoder-Decoder)&lt;/strong&gt; 架构。&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://raw.githubusercontent.com/datawhalechina/Hello-Agents/main/docs/images/3-figures/1757249275674-3.png" alt="图片描述" width="50%"/&gt;
&lt;p&gt;图 3.4 Transformer 整体架构图&lt;/p&gt;
&lt;/div&gt;
&lt;p&gt;我们可以将这个结构理解为一个分工明确的团队：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;编码器 (Encoder)&lt;/strong&gt; ：任务是“&lt;strong&gt;理解&lt;/strong&gt;”输入的整个句子。它会读取所有输入词元(这个概念会在3.2.2节介绍)，最终为每个词元生成一个富含上下文信息的向量表示。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;解码器 (Decoder)&lt;/strong&gt; ：任务是“&lt;strong&gt;生成&lt;/strong&gt;”目标句子。它会参考自己已经生成的前文，并“咨询”编码器的理解结果，来生成下一个词。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;为了真正理解 Transformer 的工作原理，最好的方法莫过于亲手实现它。在本节中，我们将采用一种“自顶向下”的方法：首先，我们搭建出 Transformer 完整的代码框架，定义好所有需要的类和方法。然后，我们将像完成拼图一样，逐一实现这些类的具体功能。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;span class="lnt"&gt;51
&lt;/span&gt;&lt;span class="lnt"&gt;52
&lt;/span&gt;&lt;span class="lnt"&gt;53
&lt;/span&gt;&lt;span class="lnt"&gt;54
&lt;/span&gt;&lt;span class="lnt"&gt;55
&lt;/span&gt;&lt;span class="lnt"&gt;56
&lt;/span&gt;&lt;span class="lnt"&gt;57
&lt;/span&gt;&lt;span class="lnt"&gt;58
&lt;/span&gt;&lt;span class="lnt"&gt;59
&lt;/span&gt;&lt;span class="lnt"&gt;60
&lt;/span&gt;&lt;span class="lnt"&gt;61
&lt;/span&gt;&lt;span class="lnt"&gt;62
&lt;/span&gt;&lt;span class="lnt"&gt;63
&lt;/span&gt;&lt;span class="lnt"&gt;64
&lt;/span&gt;&lt;span class="lnt"&gt;65
&lt;/span&gt;&lt;span class="lnt"&gt;66
&lt;/span&gt;&lt;span class="lnt"&gt;67
&lt;/span&gt;&lt;span class="lnt"&gt;68
&lt;/span&gt;&lt;span class="lnt"&gt;69
&lt;/span&gt;&lt;span class="lnt"&gt;70
&lt;/span&gt;&lt;span class="lnt"&gt;71
&lt;/span&gt;&lt;span class="lnt"&gt;72
&lt;/span&gt;&lt;span class="lnt"&gt;73
&lt;/span&gt;&lt;span class="lnt"&gt;74
&lt;/span&gt;&lt;span class="lnt"&gt;75
&lt;/span&gt;&lt;span class="lnt"&gt;76
&lt;/span&gt;&lt;span class="lnt"&gt;77
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 占位符模块，将在后续小节中实现 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 位置编码模块
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;pass&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 多头注意力机制模块
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;pass&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 位置前馈网络模块
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;pass&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 编码器核心层 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# 待实现&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# 待实现&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 残差连接与层归一化将在 3.1.2.4 节中详细解释&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 多头自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ff_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ff_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 解码器核心层 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# 待实现&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# 待实现&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# 待实现&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm3&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 掩码多头自注意力 (对自己)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 交叉注意力 (对编码器输出)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;cross_attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cross_attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ff_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm3&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ff_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;strong&gt;（2）从自注意力到多头注意力&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;现在，我们来填充骨架中最关键的模块，注意力机制。&lt;/p&gt;
&lt;p&gt;想象一下我们阅读这个句子：“The agent learns because &lt;strong&gt;it&lt;/strong&gt; is intelligent.”。当我们读到加粗的 &amp;ldquo;&lt;strong&gt;it&lt;/strong&gt;&amp;rdquo; 时，为了理解它的指代，我们的大脑会不自觉地将更多的注意力放在前面的 &amp;ldquo;agent&amp;rdquo; 这个词上。&lt;strong&gt;自注意力 (Self-Attention)&lt;/strong&gt; 机制就是对这种现象的数学建模。它允许模型在处理序列中的每一个词时，都能兼顾句子中的所有其他词，并为这些词分配不同的“注意力权重”。权重越高的词，代表其与当前词的关联性越强，其信息也应该在当前词的表示中占据更大的比重。&lt;/p&gt;
&lt;p&gt;为了实现上述过程，自注意力机制为每个输入的词元向量引入了三个可学习的角色：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;查询 (Query, Q)&lt;/strong&gt;：代表当前词元，它正在主动地“查询”其他词元以获取信息。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;键 (Key, K)&lt;/strong&gt;：代表句子中可被查询的词元“标签”或“索引”。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;值 (Value, V)&lt;/strong&gt;：代表词元本身所携带的“内容”或“信息”。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这三个向量都是由原始的词嵌入向量乘以三个不同的、可学习的权重矩阵 ($W^Q,W^K,W^V$) 得到的。整个计算过程可以分为以下几步，我们可以把它想象成一次高效的开卷考试：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;准备“考题”和“资料”：对于句子中的每个词，都通过权重矩阵生成其$Q,K,V$向量。&lt;/li&gt;
&lt;li&gt;计算相关性得分：要计算词$A$的新表示，就用词$A$的$Q$向量，去和句子中所有词（包括$A$自己）的$K$向量进行点积运算。这个得分反映了其他词对于理解词$A$的重要性。&lt;/li&gt;
&lt;li&gt;稳定化与归一化：将得到的所有分数除以一个缩放因子$\sqrt{d_{k}}$（$d_{k}$是$K$向量的维度），以防止梯度过小，然后用Softmax函数将分数转换成总和为1的权重，也就是归一化的过程。&lt;/li&gt;
&lt;li&gt;加权求和：将上一步得到的权重分别乘以每个词对应的$V$向量，然后将所有结果相加。最终得到的向量，就是词$A$融合了全局上下文信息后的新表示。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个过程可以用一个简洁的公式来概括：&lt;/p&gt;
$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V$$&lt;p&gt;如果只进行一次上述的注意力计算（即单头），模型可能会只学会关注一种类型的关联。比如，在处理 &amp;ldquo;it&amp;rdquo; 时，可能只学会了关注主语。但语言中的关系是复杂的，我们希望模型能同时关注多种关系（如指代关系、时态关系、从属关系等）。多头注意力机制应运而生。它的思想很简单：把一次做完变成分成几组，分开做，再合并。&lt;/p&gt;
&lt;p&gt;它将原始的 Q, K, V 向量在维度上切分成 h 份（h 就是“头”数），每一份都独立地进行一次单头注意力的计算。这就好比让 h 个不同的“专家”从不同的角度去审视句子，每个专家都能捕捉到一种不同的特征关系。最后，将这 h 个专家的“意见”（即输出向量）拼接起来，再通过一个线性变换进行整合，就得到了最终的输出。&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://raw.githubusercontent.com/datawhalechina/Hello-Agents/main/docs/images/3-figures/1757249275674-4.png" alt="图片描述" width="50%"/&gt;
&lt;p&gt;图 3.5 多头注意力机制&lt;/p&gt;
&lt;/div&gt;
&lt;p&gt;如图3.5所示，这种设计让模型能够共同关注来自不同位置、不同表示子空间的信息，极大地增强了模型的表达能力。以下是多头注意力的简单实现可供参考。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;span class="lnt"&gt;51
&lt;/span&gt;&lt;span class="lnt"&gt;52
&lt;/span&gt;&lt;span class="lnt"&gt;53
&lt;/span&gt;&lt;span class="lnt"&gt;54
&lt;/span&gt;&lt;span class="lnt"&gt;55
&lt;/span&gt;&lt;span class="lnt"&gt;56
&lt;/span&gt;&lt;span class="lnt"&gt;57
&lt;/span&gt;&lt;span class="lnt"&gt;58
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 多头注意力机制模块
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;d_model 必须能被 num_heads 整除&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 定义 Q, K, V 和输出的线性变换层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_o&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 计算注意力得分 (QK^T)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 应用掩码 (如果提供)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将掩码中为 0 的位置设置为一个非常小的负数，这样 softmax 后会接近 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1e9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 计算注意力权重 (Softmax)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 加权求和 (权重 * V)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将输入 x 的形状从 (batch_size, seq_length, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 变换为 (batch_size, num_heads, seq_length, d_k)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;combine_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将输入 x 的形状从 (batch_size, num_heads, seq_length, d_k)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 变回 (batch_size, seq_length, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contiguous&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 对 Q, K, V 进行线性变换&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;V&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 计算缩放点积注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 合并多头输出并进行最终的线性变换&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_o&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;combine_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;strong&gt;（3）前馈神经网络&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在每个 Encoder 和 Decoder 层中，多头注意力子层之后都跟着一个&lt;strong&gt;逐位置前馈网络(Position-wise Feed-Forward Network, FFN)&lt;/strong&gt; 。如果说注意力层的作用是从整个序列中“动态地聚合”相关信息，那么前馈网络的作用从这些聚合后的信息中提取更高阶的特征。&lt;/p&gt;
&lt;p&gt;这个名字的关键在于“逐位置”。它意味着这个前馈网络会独立地作用于序列中的每一个词元向量。换句话说，对于一个长度为 &lt;code&gt;seq_len&lt;/code&gt; 的序列，这个 FFN 实际上会被调用 &lt;code&gt;seq_len&lt;/code&gt; 次，每次处理一个词元。重要的是，所有位置共享的是同一组网络权重。这种设计既保持了对每个位置进行独立加工的能力，又大大减少了模型的参数量。这个网络的结构非常简单，由两个线性变换和一个 ReLU 激活函数组成：&lt;/p&gt;
$$\mathrm{FFN}(x)=\max\left(0, xW_{1}+b_{1}\right) W_{2}+b_{2}$$&lt;p&gt;其中，$x$是注意力子层的输出。 $W_1,b_1,W_2,b_2$是可学习的参数。通常，第一个线性层的输出维度 &lt;code&gt;d_ff&lt;/code&gt; 会远大于输入的维度 &lt;code&gt;d_model&lt;/code&gt;（例如 &lt;code&gt;d_ff = 4 * d_model&lt;/code&gt;），经过 ReLU 激活后再通过第二个线性层映射回 &lt;code&gt;d_model&lt;/code&gt; 维度。这种“先扩大再缩小”的模式，被认为有助于模型学习更丰富的特征表示。&lt;/p&gt;
&lt;p&gt;在我们的 PyTorch 骨架中，我们可以用以下代码来实现这个模块：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 位置前馈网络模块
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;relu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x 形状: (batch_size, seq_len, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 最终输出形状: (batch_size, seq_len, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;strong&gt;（4）残差连接与层归一化&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在 Transformer 的每个编码器和解码器层中，所有子模块（如多头注意力和前馈网络）都被一个 &lt;code&gt;Add &amp;amp; Norm&lt;/code&gt; 操作包裹。这个组合是为了保证 Transformer 能够稳定训练。&lt;/p&gt;
&lt;p&gt;这个操作由两个部分组成：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;残差连接 (Add)&lt;/strong&gt;：该操作将子模块的输入 &lt;code&gt;x&lt;/code&gt; 直接加到该子模块的输出 &lt;code&gt;Sublayer(x)&lt;/code&gt; 上。这一结构解决了深度神经网络中的&lt;strong&gt;梯度消失 (Vanishing Gradients)&lt;/strong&gt; 问题。在反向传播时，梯度可以绕过子模块直接向前传播，从而保证了即使网络层数很深，模型也能得到有效的训练。其公式可以表示为：$\text{Output} = x + \text{Sublayer}(x)$。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;层归一化 (Norm)&lt;/strong&gt;：该操作对单个样本的所有特征进行归一化，使其均值为0，方差为1。这解决了模型训练过程中的&lt;strong&gt;内部协变量偏移 (Internal Covariate Shift)&lt;/strong&gt; 问题，使每一层的输入分布保持稳定，从而加速模型收敛并提高训练的稳定性。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;3.1.2.5 位置编码&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;我们已经了解，Transformer 的核心是自注意力机制，它通过计算序列中任意两个词元之间的关系来捕捉依赖。然而，这种计算方式有一个固有的问题：它本身不包含任何关于词元顺序或位置的信息。对于自注意力来说，“agent learns” 和 “learns agent” 这两个序列是完全等价的，因为它只关心词元之间的关系，而忽略了它们的排列。为了解决这个问题，Transformer 引入了&lt;strong&gt;位置编码 (Positional Encoding)&lt;/strong&gt; 。&lt;/p&gt;
&lt;p&gt;位置编码的核心思想是，为输入序列中的每一个词元嵌入向量，都额外加上一个能代表其绝对位置和相对位置信息的“位置向量”。这个位置向量不是通过学习得到的，而是通过一个固定的数学公式直接计算得出。这样一来，即使两个词元（例如，两个都叫 &lt;code&gt;agent&lt;/code&gt; 的词元）自身的嵌入是相同的，但由于它们在句子中的位置不同，它们最终输入到 Transformer 模型中的向量就会因为加上了不同的位置编码而变得独一无二。原论文中提出的位置编码使用正弦和余弦函数来生成，其公式如下：&lt;/p&gt;
$$PE_{(pos,2i)}=\sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)，$$$$PE_{(pos,2i+1)}=\cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$&lt;p&gt;其中：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$pos$ 是词元在序列中的位置（例如，$0$，$1$，$2$，&amp;hellip;）&lt;/li&gt;
&lt;li&gt;$i$ 是位置向量中的维度索引（从 $0$ 到 $d_{\text{model}}/2$）&lt;/li&gt;
&lt;li&gt;$d_{\text{model}}$是词嵌入向量的维度（与我们模型中定义的一致）&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;现在，我们来实现 &lt;code&gt;PositionalEncoding&lt;/code&gt; 模块，并完成我们 Transformer 骨架代码的最后一部分。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 为输入序列的词嵌入向量添加位置编码。
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 创建一个足够长的位置编码矩阵&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;div_term&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;10000.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# pe (positional encoding) 的大小为 (max_len, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 偶数维度使用 sin, 奇数维度使用 cos&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将 pe 注册为 buffer，这样它就不会被视为模型参数，但会随模型移动（例如 to(device)）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;register_buffer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;pe&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x.size(1) 是当前输入的序列长度&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将位置编码加到输入向量上&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;本小节主要是为了帮助理解 Transformer 的宏观结构和内部每个模块的运作细节。&lt;/p&gt;
&lt;p&gt;完整代码如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt; 10
&lt;/span&gt;&lt;span class="lnt"&gt; 11
&lt;/span&gt;&lt;span class="lnt"&gt; 12
&lt;/span&gt;&lt;span class="lnt"&gt; 13
&lt;/span&gt;&lt;span class="lnt"&gt; 14
&lt;/span&gt;&lt;span class="lnt"&gt; 15
&lt;/span&gt;&lt;span class="lnt"&gt; 16
&lt;/span&gt;&lt;span class="lnt"&gt; 17
&lt;/span&gt;&lt;span class="lnt"&gt; 18
&lt;/span&gt;&lt;span class="lnt"&gt; 19
&lt;/span&gt;&lt;span class="lnt"&gt; 20
&lt;/span&gt;&lt;span class="lnt"&gt; 21
&lt;/span&gt;&lt;span class="lnt"&gt; 22
&lt;/span&gt;&lt;span class="lnt"&gt; 23
&lt;/span&gt;&lt;span class="lnt"&gt; 24
&lt;/span&gt;&lt;span class="lnt"&gt; 25
&lt;/span&gt;&lt;span class="lnt"&gt; 26
&lt;/span&gt;&lt;span class="lnt"&gt; 27
&lt;/span&gt;&lt;span class="lnt"&gt; 28
&lt;/span&gt;&lt;span class="lnt"&gt; 29
&lt;/span&gt;&lt;span class="lnt"&gt; 30
&lt;/span&gt;&lt;span class="lnt"&gt; 31
&lt;/span&gt;&lt;span class="lnt"&gt; 32
&lt;/span&gt;&lt;span class="lnt"&gt; 33
&lt;/span&gt;&lt;span class="lnt"&gt; 34
&lt;/span&gt;&lt;span class="lnt"&gt; 35
&lt;/span&gt;&lt;span class="lnt"&gt; 36
&lt;/span&gt;&lt;span class="lnt"&gt; 37
&lt;/span&gt;&lt;span class="lnt"&gt; 38
&lt;/span&gt;&lt;span class="lnt"&gt; 39
&lt;/span&gt;&lt;span class="lnt"&gt; 40
&lt;/span&gt;&lt;span class="lnt"&gt; 41
&lt;/span&gt;&lt;span class="lnt"&gt; 42
&lt;/span&gt;&lt;span class="lnt"&gt; 43
&lt;/span&gt;&lt;span class="lnt"&gt; 44
&lt;/span&gt;&lt;span class="lnt"&gt; 45
&lt;/span&gt;&lt;span class="lnt"&gt; 46
&lt;/span&gt;&lt;span class="lnt"&gt; 47
&lt;/span&gt;&lt;span class="lnt"&gt; 48
&lt;/span&gt;&lt;span class="lnt"&gt; 49
&lt;/span&gt;&lt;span class="lnt"&gt; 50
&lt;/span&gt;&lt;span class="lnt"&gt; 51
&lt;/span&gt;&lt;span class="lnt"&gt; 52
&lt;/span&gt;&lt;span class="lnt"&gt; 53
&lt;/span&gt;&lt;span class="lnt"&gt; 54
&lt;/span&gt;&lt;span class="lnt"&gt; 55
&lt;/span&gt;&lt;span class="lnt"&gt; 56
&lt;/span&gt;&lt;span class="lnt"&gt; 57
&lt;/span&gt;&lt;span class="lnt"&gt; 58
&lt;/span&gt;&lt;span class="lnt"&gt; 59
&lt;/span&gt;&lt;span class="lnt"&gt; 60
&lt;/span&gt;&lt;span class="lnt"&gt; 61
&lt;/span&gt;&lt;span class="lnt"&gt; 62
&lt;/span&gt;&lt;span class="lnt"&gt; 63
&lt;/span&gt;&lt;span class="lnt"&gt; 64
&lt;/span&gt;&lt;span class="lnt"&gt; 65
&lt;/span&gt;&lt;span class="lnt"&gt; 66
&lt;/span&gt;&lt;span class="lnt"&gt; 67
&lt;/span&gt;&lt;span class="lnt"&gt; 68
&lt;/span&gt;&lt;span class="lnt"&gt; 69
&lt;/span&gt;&lt;span class="lnt"&gt; 70
&lt;/span&gt;&lt;span class="lnt"&gt; 71
&lt;/span&gt;&lt;span class="lnt"&gt; 72
&lt;/span&gt;&lt;span class="lnt"&gt; 73
&lt;/span&gt;&lt;span class="lnt"&gt; 74
&lt;/span&gt;&lt;span class="lnt"&gt; 75
&lt;/span&gt;&lt;span class="lnt"&gt; 76
&lt;/span&gt;&lt;span class="lnt"&gt; 77
&lt;/span&gt;&lt;span class="lnt"&gt; 78
&lt;/span&gt;&lt;span class="lnt"&gt; 79
&lt;/span&gt;&lt;span class="lnt"&gt; 80
&lt;/span&gt;&lt;span class="lnt"&gt; 81
&lt;/span&gt;&lt;span class="lnt"&gt; 82
&lt;/span&gt;&lt;span class="lnt"&gt; 83
&lt;/span&gt;&lt;span class="lnt"&gt; 84
&lt;/span&gt;&lt;span class="lnt"&gt; 85
&lt;/span&gt;&lt;span class="lnt"&gt; 86
&lt;/span&gt;&lt;span class="lnt"&gt; 87
&lt;/span&gt;&lt;span class="lnt"&gt; 88
&lt;/span&gt;&lt;span class="lnt"&gt; 89
&lt;/span&gt;&lt;span class="lnt"&gt; 90
&lt;/span&gt;&lt;span class="lnt"&gt; 91
&lt;/span&gt;&lt;span class="lnt"&gt; 92
&lt;/span&gt;&lt;span class="lnt"&gt; 93
&lt;/span&gt;&lt;span class="lnt"&gt; 94
&lt;/span&gt;&lt;span class="lnt"&gt; 95
&lt;/span&gt;&lt;span class="lnt"&gt; 96
&lt;/span&gt;&lt;span class="lnt"&gt; 97
&lt;/span&gt;&lt;span class="lnt"&gt; 98
&lt;/span&gt;&lt;span class="lnt"&gt; 99
&lt;/span&gt;&lt;span class="lnt"&gt;100
&lt;/span&gt;&lt;span class="lnt"&gt;101
&lt;/span&gt;&lt;span class="lnt"&gt;102
&lt;/span&gt;&lt;span class="lnt"&gt;103
&lt;/span&gt;&lt;span class="lnt"&gt;104
&lt;/span&gt;&lt;span class="lnt"&gt;105
&lt;/span&gt;&lt;span class="lnt"&gt;106
&lt;/span&gt;&lt;span class="lnt"&gt;107
&lt;/span&gt;&lt;span class="lnt"&gt;108
&lt;/span&gt;&lt;span class="lnt"&gt;109
&lt;/span&gt;&lt;span class="lnt"&gt;110
&lt;/span&gt;&lt;span class="lnt"&gt;111
&lt;/span&gt;&lt;span class="lnt"&gt;112
&lt;/span&gt;&lt;span class="lnt"&gt;113
&lt;/span&gt;&lt;span class="lnt"&gt;114
&lt;/span&gt;&lt;span class="lnt"&gt;115
&lt;/span&gt;&lt;span class="lnt"&gt;116
&lt;/span&gt;&lt;span class="lnt"&gt;117
&lt;/span&gt;&lt;span class="lnt"&gt;118
&lt;/span&gt;&lt;span class="lnt"&gt;119
&lt;/span&gt;&lt;span class="lnt"&gt;120
&lt;/span&gt;&lt;span class="lnt"&gt;121
&lt;/span&gt;&lt;span class="lnt"&gt;122
&lt;/span&gt;&lt;span class="lnt"&gt;123
&lt;/span&gt;&lt;span class="lnt"&gt;124
&lt;/span&gt;&lt;span class="lnt"&gt;125
&lt;/span&gt;&lt;span class="lnt"&gt;126
&lt;/span&gt;&lt;span class="lnt"&gt;127
&lt;/span&gt;&lt;span class="lnt"&gt;128
&lt;/span&gt;&lt;span class="lnt"&gt;129
&lt;/span&gt;&lt;span class="lnt"&gt;130
&lt;/span&gt;&lt;span class="lnt"&gt;131
&lt;/span&gt;&lt;span class="lnt"&gt;132
&lt;/span&gt;&lt;span class="lnt"&gt;133
&lt;/span&gt;&lt;span class="lnt"&gt;134
&lt;/span&gt;&lt;span class="lnt"&gt;135
&lt;/span&gt;&lt;span class="lnt"&gt;136
&lt;/span&gt;&lt;span class="lnt"&gt;137
&lt;/span&gt;&lt;span class="lnt"&gt;138
&lt;/span&gt;&lt;span class="lnt"&gt;139
&lt;/span&gt;&lt;span class="lnt"&gt;140
&lt;/span&gt;&lt;span class="lnt"&gt;141
&lt;/span&gt;&lt;span class="lnt"&gt;142
&lt;/span&gt;&lt;span class="lnt"&gt;143
&lt;/span&gt;&lt;span class="lnt"&gt;144
&lt;/span&gt;&lt;span class="lnt"&gt;145
&lt;/span&gt;&lt;span class="lnt"&gt;146
&lt;/span&gt;&lt;span class="lnt"&gt;147
&lt;/span&gt;&lt;span class="lnt"&gt;148
&lt;/span&gt;&lt;span class="lnt"&gt;149
&lt;/span&gt;&lt;span class="lnt"&gt;150
&lt;/span&gt;&lt;span class="lnt"&gt;151
&lt;/span&gt;&lt;span class="lnt"&gt;152
&lt;/span&gt;&lt;span class="lnt"&gt;153
&lt;/span&gt;&lt;span class="lnt"&gt;154
&lt;/span&gt;&lt;span class="lnt"&gt;155
&lt;/span&gt;&lt;span class="lnt"&gt;156
&lt;/span&gt;&lt;span class="lnt"&gt;157
&lt;/span&gt;&lt;span class="lnt"&gt;158
&lt;/span&gt;&lt;span class="lnt"&gt;159
&lt;/span&gt;&lt;span class="lnt"&gt;160
&lt;/span&gt;&lt;span class="lnt"&gt;161
&lt;/span&gt;&lt;span class="lnt"&gt;162
&lt;/span&gt;&lt;span class="lnt"&gt;163
&lt;/span&gt;&lt;span class="lnt"&gt;164
&lt;/span&gt;&lt;span class="lnt"&gt;165
&lt;/span&gt;&lt;span class="lnt"&gt;166
&lt;/span&gt;&lt;span class="lnt"&gt;167
&lt;/span&gt;&lt;span class="lnt"&gt;168
&lt;/span&gt;&lt;span class="lnt"&gt;169
&lt;/span&gt;&lt;span class="lnt"&gt;170
&lt;/span&gt;&lt;span class="lnt"&gt;171
&lt;/span&gt;&lt;span class="lnt"&gt;172
&lt;/span&gt;&lt;span class="lnt"&gt;173
&lt;/span&gt;&lt;span class="lnt"&gt;174
&lt;/span&gt;&lt;span class="lnt"&gt;175
&lt;/span&gt;&lt;span class="lnt"&gt;176
&lt;/span&gt;&lt;span class="lnt"&gt;177
&lt;/span&gt;&lt;span class="lnt"&gt;178
&lt;/span&gt;&lt;span class="lnt"&gt;179
&lt;/span&gt;&lt;span class="lnt"&gt;180
&lt;/span&gt;&lt;span class="lnt"&gt;181
&lt;/span&gt;&lt;span class="lnt"&gt;182
&lt;/span&gt;&lt;span class="lnt"&gt;183
&lt;/span&gt;&lt;span class="lnt"&gt;184
&lt;/span&gt;&lt;span class="lnt"&gt;185
&lt;/span&gt;&lt;span class="lnt"&gt;186
&lt;/span&gt;&lt;span class="lnt"&gt;187
&lt;/span&gt;&lt;span class="lnt"&gt;188
&lt;/span&gt;&lt;span class="lnt"&gt;189
&lt;/span&gt;&lt;span class="lnt"&gt;190
&lt;/span&gt;&lt;span class="lnt"&gt;191
&lt;/span&gt;&lt;span class="lnt"&gt;192
&lt;/span&gt;&lt;span class="lnt"&gt;193
&lt;/span&gt;&lt;span class="lnt"&gt;194
&lt;/span&gt;&lt;span class="lnt"&gt;195
&lt;/span&gt;&lt;span class="lnt"&gt;196
&lt;/span&gt;&lt;span class="lnt"&gt;197
&lt;/span&gt;&lt;span class="lnt"&gt;198
&lt;/span&gt;&lt;span class="lnt"&gt;199
&lt;/span&gt;&lt;span class="lnt"&gt;200
&lt;/span&gt;&lt;span class="lnt"&gt;201
&lt;/span&gt;&lt;span class="lnt"&gt;202
&lt;/span&gt;&lt;span class="lnt"&gt;203
&lt;/span&gt;&lt;span class="lnt"&gt;204
&lt;/span&gt;&lt;span class="lnt"&gt;205
&lt;/span&gt;&lt;span class="lnt"&gt;206
&lt;/span&gt;&lt;span class="lnt"&gt;207
&lt;/span&gt;&lt;span class="lnt"&gt;208
&lt;/span&gt;&lt;span class="lnt"&gt;209
&lt;/span&gt;&lt;span class="lnt"&gt;210
&lt;/span&gt;&lt;span class="lnt"&gt;211
&lt;/span&gt;&lt;span class="lnt"&gt;212
&lt;/span&gt;&lt;span class="lnt"&gt;213
&lt;/span&gt;&lt;span class="lnt"&gt;214
&lt;/span&gt;&lt;span class="lnt"&gt;215
&lt;/span&gt;&lt;span class="lnt"&gt;216
&lt;/span&gt;&lt;span class="lnt"&gt;217
&lt;/span&gt;&lt;span class="lnt"&gt;218
&lt;/span&gt;&lt;span class="lnt"&gt;219
&lt;/span&gt;&lt;span class="lnt"&gt;220
&lt;/span&gt;&lt;span class="lnt"&gt;221
&lt;/span&gt;&lt;span class="lnt"&gt;222
&lt;/span&gt;&lt;span class="lnt"&gt;223
&lt;/span&gt;&lt;span class="lnt"&gt;224
&lt;/span&gt;&lt;span class="lnt"&gt;225
&lt;/span&gt;&lt;span class="lnt"&gt;226
&lt;/span&gt;&lt;span class="lnt"&gt;227
&lt;/span&gt;&lt;span class="lnt"&gt;228
&lt;/span&gt;&lt;span class="lnt"&gt;229
&lt;/span&gt;&lt;span class="lnt"&gt;230
&lt;/span&gt;&lt;span class="lnt"&gt;231
&lt;/span&gt;&lt;span class="lnt"&gt;232
&lt;/span&gt;&lt;span class="lnt"&gt;233
&lt;/span&gt;&lt;span class="lnt"&gt;234
&lt;/span&gt;&lt;span class="lnt"&gt;235
&lt;/span&gt;&lt;span class="lnt"&gt;236
&lt;/span&gt;&lt;span class="lnt"&gt;237
&lt;/span&gt;&lt;span class="lnt"&gt;238
&lt;/span&gt;&lt;span class="lnt"&gt;239
&lt;/span&gt;&lt;span class="lnt"&gt;240
&lt;/span&gt;&lt;span class="lnt"&gt;241
&lt;/span&gt;&lt;span class="lnt"&gt;242
&lt;/span&gt;&lt;span class="lnt"&gt;243
&lt;/span&gt;&lt;span class="lnt"&gt;244
&lt;/span&gt;&lt;span class="lnt"&gt;245
&lt;/span&gt;&lt;span class="lnt"&gt;246
&lt;/span&gt;&lt;span class="lnt"&gt;247
&lt;/span&gt;&lt;span class="lnt"&gt;248
&lt;/span&gt;&lt;span class="lnt"&gt;249
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;copy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 多头注意力机制模块
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;d_model 必须能被 num_heads 整除&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 定义 Q, K, V 和输出的线性变换层&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_o&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 计算注意力得分 (QK^T)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 应用掩码 (如果提供)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将掩码中为 0 的位置设置为一个非常小的负数，这样 softmax 后会接近 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1e9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 计算注意力权重 (Softmax)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 加权求和 (权重 * V)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将输入 x 的形状从 (batch_size, seq_length, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 变换为 (batch_size, num_heads, seq_length, d_k)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;combine_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将输入 x 的形状从 (batch_size, num_heads, seq_length, d_k)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 变回 (batch_size, seq_length, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;contiguous&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 对 Q, K, V 进行线性变换&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;Q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;K&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;V&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 计算缩放点积注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 合并多头输出并进行最终的线性变换&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_o&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;combine_heads&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 位置前馈网络模块
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;relu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x 形状: (batch_size, seq_len, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;linear2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 最终输出形状: (batch_size, seq_len, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 为输入序列的词嵌入向量添加位置编码。
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 创建一个足够长的位置编码矩阵&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;div_term&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;10000.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# pe (positional encoding) 的大小为 (max_len, d_model)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 偶数维度使用 sin, 奇数维度使用 cos&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;div_term&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将 pe 注册为 buffer，这样它就不会被视为模型参数，但会随模型移动（例如 to(device)）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;register_buffer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;pe&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x.size(1) 是当前输入的序列长度&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将位置编码加到输入向量上&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pe&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 编码器核心层
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 多头自注意力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ff_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ff_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; 解码器核心层
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="s2"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionWiseFeedForward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm3&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 掩码多头自注意力 (对自己)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;self_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 交叉注意力 (对编码器输出)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;cross_attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cross_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cross_attn_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 前馈网络&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ff_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feed_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm3&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ff_output&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ModuleList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;EncoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layers&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;PositionalEncoding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ModuleList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pos_encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layers&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final_linear&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;generate_mask&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# src_mask: (batch_size, 1, 1, src_len)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# tgt_mask: (batch_size, 1, tgt_len, tgt_len)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_pad_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (batch_size, 1, 1, tgt_len)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 下三角矩阵，用于防止看到未来的 token&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_sub_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_len&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# (tgt_len, tgt_len)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tgt_pad_mask&lt;/span&gt; &lt;span class="o"&gt;&amp;amp;&lt;/span&gt; &lt;span class="n"&gt;tgt_sub_mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;generate_mask&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;decoder_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_mask&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final_linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;decoder_output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# --- 演示如何使用模型 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 定义超参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;5000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;5000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;d_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2048&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 实例化模型&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d_ff&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 创建模拟输入数据&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 假设 batch_size=2, src_seq_len=10, tgt_seq_len=12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# (batch_size, seq_length)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tgt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# (batch_size, seq_length)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 模型前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tgt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 5. 打印输出形状&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;模型输出的形状:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 预期输出: torch.Size([2, 12, 5000]) -&amp;gt; (batch_size, tgt_seq_len, tgt_vocab_size)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;模型输出的形状: torch.Size([2, 12, 5000])
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;由于是为了补充智能体学习中大模型的知识体系，也就不再继续往下深入实现。至此，我们已经为理解现代大语言模型打下了坚实的架构基础。在下一节中，我们将探讨 Decoder-Only 架构，看看它是如何基于 Transformer 的思想演变而来。&lt;/p&gt;
&lt;h3 id="313-decoder-only-架构"&gt;3.1.3 Decoder-Only 架构
&lt;/h3&gt;&lt;p&gt;前面一节中，我们动手构建了一个完整的Transformer 模型，它能在很多端到端的场景表现出色。但是当任务转换为构建一个与人对话、创作、作为智能体大脑的通用模型时，或许我们并不需要那么复杂的结构。&lt;/p&gt;
&lt;p&gt;Transformer的设计哲学是“先理解，再生成”。编码器负责深入理解输入的整个句子，形成一个包含全局信息的上下文记忆，然后解码器基于这份记忆来生成翻译。但 OpenAI 在开发 &lt;strong&gt;GPT (Generative Pre-trained Transformer)&lt;/strong&gt; 时，提出了一个更简单的思想&lt;sup&gt;[5]&lt;/sup&gt;：&lt;strong&gt;语言的核心任务，不就是预测下一个最有可能出现的词吗？&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;无论是回答问题、写故事还是生成代码，本质上都是在一个已有的文本序列后面，一个词一个词地添加最合理的内容。基于这个思想，GPT 做了一个大胆的简化：&lt;strong&gt;它完全抛弃了编码器，只保留了解码器部分。&lt;/strong&gt; 这就是 &lt;strong&gt;Decoder-Only&lt;/strong&gt; 架构的由来。&lt;/p&gt;
&lt;p&gt;Decoder-Only 架构的工作模式被称为&lt;strong&gt;自回归 (Autoregressive)&lt;/strong&gt; 。这个听起来很专业的术语，其实描述了一个非常简单的过程：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;给模型一个起始文本（例如 “Datawhale Agent is”）。&lt;/li&gt;
&lt;li&gt;模型预测出下一个最有可能的词（例如 “a”）。&lt;/li&gt;
&lt;li&gt;模型将自己刚刚生成的词 “a” 添加到输入文本的末尾，形成新的输入（“Datawhale Agent is a”）。&lt;/li&gt;
&lt;li&gt;模型基于这个新输入，再次预测下一个词（例如 “powerful”）。&lt;/li&gt;
&lt;li&gt;不断重复这个过程，直到生成完整的句子或达到停止条件。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;模型就像一个在玩“文字接龙”的游戏，它不断地“回顾”自己已经写下的内容，然后思考下一个字该写什么。&lt;/p&gt;
&lt;p&gt;你可能会问，解码器是如何保证在预测第 &lt;code&gt;t&lt;/code&gt; 个词时，不去“偷看”第 &lt;code&gt;t+1&lt;/code&gt; 个词的答案呢？&lt;/p&gt;
&lt;p&gt;答案就是&lt;strong&gt;掩码自注意力 (Masked Self-Attention)&lt;/strong&gt; 。在 Decoder-Only 架构中，这个机制变得至关重要。它的工作原理非常巧妙：&lt;/p&gt;
&lt;p&gt;在自注意力机制计算出注意力分数矩阵（即每个词对其他所有词的关注度得分）之后，但在进行 Softmax 归一化之前，模型会应用一个“掩码”。这个掩码会将所有位于当前位置之后（即目前尚未观测到）的词元对应的分数，替换为一个非常大的负数。当这个带有负无穷分数的矩阵经过 Softmax 函数时，这些位置的概率就会变为 0。这样一来，模型在计算任何一个位置的输出时，都从数学上被阻止了去关注它后面的信息。这种机制保证了模型在预测下一个词时，能且仅能依赖它已经见过的、位于当前位置之前的所有信息，从而确保了预测的公平性和逻辑的连贯性。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Decoder-Only 架构的优势&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这种看似简单的架构，却带来了巨大的成功，其优势在于：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;训练目标统一&lt;/strong&gt;：模型的唯一任务就是“预测下一个词”，这个简单的目标非常适合在海量的无标注文本数据上进行预训练。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;结构简单，易于扩展&lt;/strong&gt;：更少的组件意味着更容易进行规模化扩展。今天的 GPT-4、Llama 等拥有数千亿甚至万亿参数的巨型模型，都是基于这种简洁的架构。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;天然适合生成任务&lt;/strong&gt;：其自回归的工作模式与所有生成式任务（对话、写作、代码生成等）完美契合，这也是它能成为构建通用智能体基础的核心原因。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;总而言之，从 Transformer 的解码器演变而来的 Decoder-Only 架构，通过“预测下一个词”这一简单的范式，开启了我们今天所处的大语言模型时代。&lt;/p&gt;
&lt;h2 id="32-与大语言模型交互"&gt;3.2 与大语言模型交互
&lt;/h2&gt;&lt;h3 id="321-提示工程"&gt;3.2.1 提示工程
&lt;/h3&gt;&lt;p&gt;如果我们把大语言模型比作一个能力极强的“大脑”，那么&lt;strong&gt;提示 (Prompt)&lt;/strong&gt; 就是我们与这个“大脑”沟通的语言。提示工程，就是研究如何设计出精准的提示，从而引导模型产生我们期望输出的回复。对于构建智能体而言，一个精心设计的提示能让智能体之间协作分工变得高效。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;（1）模型采样参数&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在使用大模型时，你会经常看到类似&lt;code&gt;Temperature&lt;/code&gt;这类的可配置参数，其本质是通过调整模型对 “概率分布” 的采样策略，让输出匹配具体场景需求，配置合适的参数可以提升Agent在特定场景的性能。&lt;/p&gt;
&lt;p&gt;传统的概率分布是由 Softmax 公式计算得到的：$p_i = \frac{e^{z_i}}{\sum_{j=1}^k e^{z_j}}$，采样参数的本质就是在此基础上，根据不同策略“重新调整”或“截断”分布，从而改变大模型输出的下一个token。&lt;/p&gt;
&lt;p&gt;&lt;code&gt;Temperature&lt;/code&gt;：温度是控制模型输出 “随机性” 与 “确定性” 的关键参数。其原理是引入温度系数$T\gt0$,将 Softmax 改写为$p_i^{(T)} = \frac{e^{z_i / T}}{\sum_{j=1}^k e^{z_j / T}}$。&lt;/p&gt;
&lt;p&gt;当T变小时，分布“更加陡峭”，高概率项权重进一步放大，生成更“保守”且重复率更高的文本。当T变大时，分布“更加平坦”，低概率项权重提升，生成更“多样”但可能出现不连贯的内容。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;低温度（0 $\leqslant$ Temperature $\lt$ 0.3）时输出更 “精准、确定”。适用场景： 事实性任务：如问答、数据计算、代码生成； 严谨性场景：法律条文解读、技术文档撰写、学术概念解释等场景。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;中温度（0.3 $\leqslant$ Temperature $\lt$ 0.7）：输出 “平衡、自然”。适用场景： 日常对话：如客服交互、聊天机器人； 常规创作：如邮件撰写、产品文案、简单故事创作。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;高温度（0.7 $\leqslant$ Temperature $\lt$ 2）：输出 “创新、发散”。适用场景： 创意性任务：如诗歌创作、科幻故事构思、广告 slogan brainstorm、艺术灵感启发； 发散性思考。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;code&gt;Top-k &lt;/code&gt;：其原理是将所有 token 按概率从高到低排序，取排名前 k 个的 token 组成 “候选集”，随后对筛选出的 k 个 token 的概率进行 “归一化”： $ \hat{p}&lt;em&gt;i = \frac{p_i}{\sum&lt;/em&gt;{j \in \text{候选集}} p_j}$&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;与温度采样的区别与联系：温度采样通过温度 T 调整所有 token 的概率分布（平滑或陡峭），不改变候选 token 的数量（仍考虑全部 N 个）。Top-k 采样通过 k 值限制候选 token 的数量（只保留前 k 个高概率 token），再从其中采样。当k=1时输出完全确定，退化为 “贪心采样”。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;code&gt;Top-p &lt;/code&gt;：其原理是将所有 token 按概率从高到低排序，从排序后的第一个 token 开始，逐步累加概率，直到累积和首次达到或超过阈值 p： $\sum_{i \in S} p_{(i)} \geq p$，此时累加过程中包含的所有 token 组成 “核集合”，最后对核集合进行归一化。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;与Top-k的区别与联系：相对于固定截断大小的 Top-k，Top-p 能动态适应不同分布的“长尾”特性，对概率分布不均匀的极端情况的适应性更好。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;在文本生成中，当同时设置 Top-p、Top-k 和温度系数时，这些参数会按照分层过滤的方式协同工作，其优先级顺序为：温度调整→Top-k→Top-p。温度调整整体分布的陡峭程度，Top-k 会先保留概率最高的 k 个候选，然后 Top-p 会从 Top-k 的结果中选取累积概率≥p 的最小集合作为最终的候选集。不过，通常 Top-k 和 Top-p 二选一即可，若同时设置，实际候选集为两者的交集。
需要注意的是，如果将温度设置为 0，则 Top-k 和 Top-p 将变得无关紧要，因为最有可能的 Token 将成为下一个预测的 Token；如果将 Top-k 设置为 1，温度和 Top-p 也将变得无关紧要，因为只有一个 Token 通过 Top-k 标准，它将是下一个预测的 Token。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;（2）零样本、单样本与少样本提示&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;根据我们给模型提供示例（Exemplar）的数量，提示可以分为三种类型。为了更好地理解它们，让我们以一个情感分类任务为例，目标是让模型判断一段文本的情感色彩（如正面、负面或中性）。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;零样本提示 (Zero-shot Prompting)&lt;/strong&gt; 这指的是我们不给模型任何示例，直接让它根据指令完成任务。这得益于模型在海量数据上预训练后获得的强大泛化能力。&lt;/p&gt;
&lt;p&gt;案例： 我们直接向模型下达指令，要求它完成情感分类任务。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;文本&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;Datawhale的AI&lt;/span&gt; &lt;span class="n"&gt;Agent课程非常棒&lt;/span&gt;&lt;span class="err"&gt;！&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;情感&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;正面&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;strong&gt;单样本提示 (One-shot Prompting)&lt;/strong&gt; 我们给模型提供一个完整的示例，向它展示任务的格式和期望的输出风格。&lt;/p&gt;
&lt;p&gt;案例： 我们先给模型一个完整的“问题-答案”对作为示范，然后提出我们的新问题。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;文本&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;这家餐厅的服务太慢了&lt;/span&gt;&lt;span class="err"&gt;。&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;情感&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;负面&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;文本&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;Datawhale的AI&lt;/span&gt; &lt;span class="n"&gt;Agent课程非常棒&lt;/span&gt;&lt;span class="err"&gt;！&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;情感&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;模型会模仿给出的示例格式，为第二段文本补全“正面”。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;少样本提示 (Few-shot Prompting)&lt;/strong&gt; 我们提供多个示例，这能让模型更准确地理解任务的细节、边界和细微差别，从而获得更好的性能。&lt;/p&gt;
&lt;p&gt;案例： 我们提供涵盖了不同情况的多个示例，让模型对任务有更全面的理解。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;文本&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;这家餐厅的服务太慢了&lt;/span&gt;&lt;span class="err"&gt;。&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;情感&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;负面&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;文本&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;这部电影的情节很平淡&lt;/span&gt;&lt;span class="err"&gt;。&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;情感&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;中性&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;文本&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;Datawhale的AI&lt;/span&gt; &lt;span class="n"&gt;Agent课程非常棒&lt;/span&gt;&lt;span class="err"&gt;！&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;情感&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;模型会综合所有示例，更准确地将最后一句的情感分类为“正面”。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;（3）指令调优的影响&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;早期的 GPT 模型（如 GPT-3）主要是“文本补全”模型，它们擅长根据前面的文本续写，但不一定能很好地理解并执行人类的指令。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;指令调优 (Instruction Tuning)&lt;/strong&gt; 是一种微调技术，它使用大量“指令-回答”格式的数据对预训练模型进行进一步的训练。经过指令调优后，模型能更好地理解并遵循用户的指令。我们今天日常工作学习中使用的所有模型（如 &lt;code&gt;ChatGPT&lt;/code&gt;, &lt;code&gt;DeepSeek&lt;/code&gt;, &lt;code&gt;Qwen&lt;/code&gt;）都是其模型家族中经过指令调优过的模型。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;对“文本补全”模型的提示(你需要用少样本提示“教会”模型做什么)：&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Plain" data-lang="Plain"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;这是一段将英文翻译成中文的程序。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;英文:Hello
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;中文:你好
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;英文:How are you?
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;中文:
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;对“指令调优”模型的提示(你可以直接下达指令)：&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Plain" data-lang="Plain"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;请将下面的英文翻译成中文:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;How are you?
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;指令调优的出现，极大地简化了我们与模型交互的方式，使得直接、清晰的自然语言指令成为可能。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;（4）基础提示技巧&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;角色扮演 (Role-playing)&lt;/strong&gt; 通过赋予模型一个特定的角色，我们可以引导它的回答风格、语气和知识范围，使其输出更符合特定场景的需求。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Plain" data-lang="Plain"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;# 案例
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;你现在是一位资深的Python编程专家。请解释一下Python中的GIL（全局解释器锁）是什么，要让一个初学者也能听懂。
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;strong&gt;上下文示例 (In-context Example)&lt;/strong&gt; 这与少样本提示的思想一致，通过在提示中提供清晰的输入输出示例，来“教会”模型如何处理我们的请求，尤其是在处理复杂格式或特定风格的任务时非常有效。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Plain" data-lang="Plain"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;# 案例
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;我需要你从产品评论中提取产品名称和用户情感。请严格按照下面的JSON格式输出。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;评论:这款“星尘”笔记本电脑的屏幕显示效果惊人，但我不太喜欢它的键盘手感。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;输出:{&amp;#34;product_name&amp;#34;: &amp;#34;星尘笔记本电脑&amp;#34;, &amp;#34;sentiment&amp;#34;: &amp;#34;混合&amp;#34;}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;评论:我刚买的“声动”耳机音质很棒，续航也超出了我的预期！
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;输出:
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;strong&gt;（5）思维链&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;对于需要逻辑推理、计算或多步骤思考的复杂问题，直接让模型给出答案往往容易出错。&lt;strong&gt;思维链 (Chain-of-Thought, CoT)&lt;/strong&gt; 是一种强大的提示技巧，它通过引导模型“一步一步地思考”，提升了模型在复杂任务上的推理能力。&lt;/p&gt;
&lt;p&gt;实现 CoT 的关键，是在提示中加入一句简单的引导语，如“请逐步思考”或“Let&amp;rsquo;s think step by step”。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Plain" data-lang="Plain"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;# 思维链提示
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;一个篮球队在一个赛季的80场比赛中赢了60%。在接下来的赛季中，他们打了15场比赛，赢了12场。两个赛季的总胜率是多少？
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;请一步一步地思考并解答。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&amp;gt;&amp;gt;&amp;gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;（模型可能会输出）
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;好的，我们来一步步计算。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第一步:计算第一个赛季赢得的比赛数。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;80场 * 60% = 48场。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第二步:计算两个赛季的总比赛数和总胜利数。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;总比赛数 = 80 + 15 = 95场。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;总胜利数 = 48 + 12 = 60场。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第三步:计算总胜率。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;总胜率 = (总胜利数 / 总比赛数) * 100% = (60 / 95) * 100% ≈ 63.16%。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;所以，两个赛季的总胜率约为63.16%。
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;通过显式地展示其推理过程，模型不仅更容易得出正确的答案，也让它的回答变得更可信、更易于我们检查和纠正。&lt;/p&gt;
&lt;h3 id="322-文本分词"&gt;3.2.2 文本分词
&lt;/h3&gt;&lt;p&gt;我们知道，计算机本质上只能理解数字。因此，在将自然语言文本喂给大语言模型之前，必须先将其转换成模型能够处理的数字格式。这个将文本序列转换为数字序列的过程，就叫做&lt;strong&gt;分词 (Tokenization)&lt;/strong&gt; 。&lt;strong&gt;分词器 (Tokenizer)&lt;/strong&gt; 的作用，就是定义一套规则，将原始文本切分成一个个最小的单元，我们称之为&lt;strong&gt;词元 (Token)&lt;/strong&gt; 。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;3.2.2.1 为何需要分词&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;早期的自然语言处理任务可能会采用简单的分词策略：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;按词分词 (Word-based)&lt;/strong&gt; ：直接用空格或标点符号将句子切分成单词。这种方法很直观，但也面临挑战：
&lt;ul&gt;
&lt;li&gt;词表爆炸与未登录词：一个语言的词汇量是巨大的，如果每个词都作为一个独立的词元，词表会变得难以管理。更糟糕的是，模型将无法处理任何未在词表中出现过的词（例如 “DatawhaleAgent”），这种现象我们称为“未登录词” (Out-Of-Vocabulary, OOV)。&lt;/li&gt;
&lt;li&gt;语义关联的缺失：模型难以捕捉词形相近的词之间的语义关系。例如，&amp;ldquo;look&amp;rdquo;、&amp;ldquo;looks&amp;rdquo; 和 &amp;ldquo;looking&amp;rdquo; 会被视为三个完全不同的词元，尽管它们有共同的核心含义。同样，训练数据中的低频词由于出现次数少，其语义也难以被模型充分学习。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;按字符分词 (Character-based)&lt;/strong&gt; ：将文本切分成单个字符。这种方法词表很小（例如英文字母、数字和标点），不存在 OOV 问题。但它的缺点是，单个字符大多不具备独立的语义，模型需要花费更多的精力去学习如何将字符组合成有意义的词，导致学习效率低下。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;为了兼顾词表大小和语义表达，现代大语言模型普遍采用&lt;strong&gt;子词分词 (Subword Tokenization)&lt;/strong&gt; 算法。它的核心思想是：将常见的词（如 &amp;ldquo;agent&amp;rdquo;）保留为完整的词元，同时将不常见的词（如 &amp;ldquo;Tokenization&amp;rdquo;）拆分成多个有意义的子词片段（如 &amp;ldquo;Token&amp;rdquo; 和 &amp;ldquo;ization&amp;rdquo;）。这样既控制了词表的大小，又能让模型通过组合子词来理解和生成新词。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;3.2.2.2 字节对编码算法解析&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;字节对编码 (Byte-Pair Encoding, BPE) 是最主流的子词分词算法之一&lt;sup&gt;[6]&lt;/sup&gt;，GPT系列模型就采用了这种算法。其核心思想非常简洁，可以理解为一个“贪心”的合并过程：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;初始化&lt;/strong&gt;：将词表初始化为所有在语料库中出现过的基本字符。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;迭代合并&lt;/strong&gt;：在语料库上，统计所有相邻词元对的出现频率，找到频率最高的一对，将它们合并成一个新的词元，并加入词表。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;重复&lt;/strong&gt;：重复第 2 步，直到词表大小达到预设的阈值。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;案例演示：&lt;/strong&gt; 假设我们的迷你语料库是 &lt;code&gt;{&amp;quot;hug&amp;quot;: 1, &amp;quot;pug&amp;quot;: 1, &amp;quot;pun&amp;quot;: 1, &amp;quot;bun&amp;quot;: 1}&lt;/code&gt;，并且我们想构建一个大小为 10 的词表。BPE 的训练过程可以用下表3.1来表示：&lt;/p&gt;
&lt;div align="center"&gt;
&lt;p&gt;表 3.1 BPE 算法合并过程示例&lt;/p&gt;
&lt;img src="https://raw.githubusercontent.com/datawhalechina/Hello-Agents/main/docs/images/3-figures/1757249275674-5.png" alt="图片描述" width="90%"/&gt;
&lt;/div&gt;
&lt;p&gt;训练结束后，词表大小达到 10，我们就得到了新的分词规则。现在，对于一个未见过的词 &amp;ldquo;bug&amp;rdquo;，分词器会先查找 &amp;ldquo;bug&amp;rdquo; 是否在词表中，发现不在；然后查找 &amp;ldquo;bu&amp;rdquo;，发现不在；最后查找 &amp;ldquo;b&amp;rdquo; 和 &amp;ldquo;ug&amp;rdquo;，发现都在，于是将其切分为 &lt;code&gt;['b', 'ug']&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;下面我们用一段简单的 Python 代码来模拟上述过程：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;re&lt;/span&gt;&lt;span class="o"&gt;,&lt;/span&gt; &lt;span class="nn"&gt;collections&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_stats&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;统计词元对频率&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pairs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;collections&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;defaultdict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;freq&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;symbols&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;symbols&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pairs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;symbols&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;&lt;span class="n"&gt;symbols&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;freq&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;pairs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;merge_vocab&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pair&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v_in&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;合并词元对&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v_out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bigram&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;escape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39; &amp;#39;&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pair&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;(?&amp;lt;!\S)&amp;#39;&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;bigram&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;(?!\S)&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;v_in&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;w_out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sub&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;&amp;#39;&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pair&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;v_out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;w_out&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v_in&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;v_out&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 准备语料库，每个词末尾加上&amp;lt;/w&amp;gt;表示结束，并切分好字符&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;h u g &amp;lt;/w&amp;gt;&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;p u g &amp;lt;/w&amp;gt;&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;p u n &amp;lt;/w&amp;gt;&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;b u n &amp;lt;/w&amp;gt;&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;num_merges&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="c1"&gt;# 设置合并次数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_merges&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pairs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;get_stats&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;pairs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;best&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pairs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;pairs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;get&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;merge_vocab&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;best&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;第&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;次合并: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; -&amp;gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;&amp;#39;&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;best&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;新词表（部分）: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;-&amp;#34;&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;这段代码清晰地展示了 BPE 算法如何通过迭代合并最高频的相邻词元对，来逐步构建和扩充词表的过程。&lt;/p&gt;
&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第1次合并: (&amp;#39;u&amp;#39;, &amp;#39;g&amp;#39;) -&amp;gt; ug
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新词表（部分）: [&amp;#39;h ug &amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p ug &amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p u n &amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;b u n &amp;lt;/w&amp;gt;&amp;#39;]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--------------------
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第2次合并: (&amp;#39;ug&amp;#39;, &amp;#39;&amp;lt;/w&amp;gt;&amp;#39;) -&amp;gt; ug&amp;lt;/w&amp;gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新词表（部分）: [&amp;#39;h ug&amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p ug&amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p u n &amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;b u n &amp;lt;/w&amp;gt;&amp;#39;]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--------------------
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第3次合并: (&amp;#39;u&amp;#39;, &amp;#39;n&amp;#39;) -&amp;gt; un
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新词表（部分）: [&amp;#39;h ug&amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p ug&amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p un &amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;b un &amp;lt;/w&amp;gt;&amp;#39;]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;--------------------
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;第4次合并: (&amp;#39;un&amp;#39;, &amp;#39;&amp;lt;/w&amp;gt;&amp;#39;) -&amp;gt; un&amp;lt;/w&amp;gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新词表（部分）: [&amp;#39;h ug&amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p ug&amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;p un&amp;lt;/w&amp;gt;&amp;#39;, &amp;#39;b un&amp;lt;/w&amp;gt;&amp;#39;]
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;后续的许多算法都是在BPE的基础上进行优化的。其中，Google 开发的 WordPiece 和 SentencePiece 是影响力最大的两种。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;WordPiece&lt;/strong&gt;：Google BERT 模型采用的算法&lt;sup&gt;[7]&lt;/sup&gt;。它与 BPE 非常相似，但合并词元的标准不是“最高频率”，而是“能最大化提升语料库的语言模型概率”。简单来说，它会优先合并那些能让整个语料库的“通顺度”提升最大的词元对。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;SentencePiece&lt;/strong&gt;：Google 开源的一款分词工具&lt;sup&gt;[8]&lt;/sup&gt;，Llama 系列模型采用了此算法。它最大的特点是，将空格也视作一个普通字符（通常用下划线 &lt;code&gt;_&lt;/code&gt; 表示）。这使得分词和解码过程完全可逆，且不依赖于特定的语言（例如，它不需要知道中文不使用空格分词）。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;3.2.2.3 分词器对开发者的意义&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;理解分词算法的细节并非目的，但作为智能体的开发者，理解分词器的实际影响是重要，这直接关系到智能体的性能、成本和稳定性：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;上下文窗口限制&lt;/strong&gt;：模型的上下文窗口（如 8K, 128K）是以 &lt;strong&gt;Token 数量&lt;/strong&gt;计算的，而不是字符数或单词数。同样一段话，在不同语言（如中英文）或不同分词器下，Token 数量可能相差巨大。精确管理输入长度、避免超出上下文限制是构建长时记忆智能体的基础。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;API 成本&lt;/strong&gt;：大多数模型 API 都是按 Token 数量计费的。了解你的文本会被如何分词，是预估和控制智能体运行成本的关键一步。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型表现的异常&lt;/strong&gt;：有时模型的奇怪表现根源在于分词。例如，模型可能很擅长计算 &lt;code&gt;2 + 2&lt;/code&gt;，但对于 &lt;code&gt;2+2&lt;/code&gt;（没有空格）就可能出错，因为后者可能被分词器视为一个独立的、不常见的词元。同样，一个词因为首字母大小写不同，也可能被切分成完全不同的 Token 序列，从而影响模型的理解。在设计提示词和解析模型输出时，考虑到这些“陷阱”有助于提升智能体的鲁棒性。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="323-调用开源大语言模型"&gt;3.2.3 调用开源大语言模型
&lt;/h3&gt;&lt;p&gt;在本书的第一章，我们通过 API 来与大语言模型进行交互，以此驱动我们的智能体。这是一种快速、便捷的方式，但并非唯一的方式。对于许多需要处理敏感数据、希望离线运行或想精细控制成本的场景，将大语言模型直接部署在本地就显得至关重要。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Hugging Face Transformers&lt;/strong&gt; 是一个强大的开源库，它提供了标准化的接口来加载和使用数以万计的预训练模型。我们将使用它来完成本次实践。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;配置环境与选择模型&lt;/strong&gt;：为了让大多数读者都能在个人电脑上顺利运行，我们特意选择了一个小规模但功能强大的模型：&lt;code&gt;Qwen/Qwen1.5-0.5B-Chat&lt;/code&gt;。这是一个由阿里巴巴达摩院开源的拥有约 5 亿参数的对话模型，它体积小、性能优异，非常适合入门学习和本地部署。&lt;/p&gt;
&lt;p&gt;首先，请确保你已经安装了必要的库：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;pip install transformers torch
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;在 &lt;code&gt;transformers&lt;/code&gt; 库中，我们通常使用 &lt;code&gt;AutoModelForCausalLM&lt;/code&gt; 和 &lt;code&gt;AutoTokenizer&lt;/code&gt; 这两个类来自动加载与模型匹配的权重和分词器。下面这段代码会自动从 Hugging Face Hub 下载所需的模型文件和分词器配置，这可能需要一些时间，具体取决于你的网络速度。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 指定模型ID&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;Qwen/Qwen1.5-0.5B-Chat&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 设置设备，优先使用GPU&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cuda&amp;#34;&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cpu&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Using device: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 加载分词器&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_id&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 加载模型，并将其移动到指定设备&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_id&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;模型和分词器加载完成！&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;我们来创建一个对话提示，Qwen1.5-Chat 模型遵循特定的对话模板。然后，可以将使用上一步加载的 &lt;code&gt;tokenizer&lt;/code&gt; 将文本提示转换为模型能够理解的数字 ID（即 Token ID）。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 准备对话输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;messages&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;role&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;system&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;content&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;You are a helpful assistant.&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;role&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;user&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;content&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;你好，请介绍你自己。&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 使用分词器的模板格式化输入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;apply_chat_template&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tokenize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;add_generation_prompt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 编码输入文本&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model_inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;pt&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;编码后的输入文本:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_inputs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;编码后的输入文本:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;{&amp;#39;input_ids&amp;#39;: tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; 151645, 198, 151644, 872, 198, 108386, 37945, 100157, 107828,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; 1773, 151645, 198, 151644, 77091, 198]]), &amp;#39;attention_mask&amp;#39;: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;现在可以调用模型的 &lt;code&gt;generate()&lt;/code&gt; 方法来生成回答了。模型会输出一系列 Token ID，这代表了它的回答。&lt;/p&gt;
&lt;p&gt;最后，我们需要使用分词器的 &lt;code&gt;decode()&lt;/code&gt; 方法，将这些数字 ID 翻译回人类可以阅读的文本。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-Python" data-lang="Python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 使用模型生成回答&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# max_new_tokens 控制了模型最多能生成多少个新的Token&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;generated_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;generate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model_inputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 将生成的 Token ID 截取掉输入部分&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 这样我们只解码模型新生成的部分&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;generated_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_ids&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;):]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_ids&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_inputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generated_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 解码生成的 Token ID&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;response&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;generated_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;skip_special_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;模型的回答:&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;response&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;我是一个人工智能，没有具体的自我认识和经历。我的设计目的是帮助用户获取信息、完成任务等。我可以回答各种问题、提供定义、解释和建议，以及生成代码或文本等服务。
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;当你运行完所有代码后，你将会在本地电脑上看到模型生成的关于Qwen模型的介绍。恭喜你，你已经成功地在本地部署并运行了一个开源大语言模型！&lt;/p&gt;
&lt;h3 id="324-模型的选择"&gt;3.2.4 模型的选择
&lt;/h3&gt;&lt;p&gt;在上一节中，我们成功地在本地运行了一个小型的开源语言模型。这自然引出了一个对于智能体开发者而言至关重要的问题：在当前数百个模型百花齐放的背景下，我们应当如何为特定的任务选择最合适的模型？&lt;/p&gt;
&lt;p&gt;选择语言模型并非简单地追求“最大、最强”，而是一个在性能、成本、速度和部署方式之间进行权衡的决策过程。本节将首先梳理模型选型的几个关键考量因素，然后对当前主流的闭源与开源模型进行梳理。&lt;/p&gt;
&lt;p&gt;由于大语言模型技术正处于高速发展阶段，新模型、新版本层出不穷，迭代速度极快。本节在撰写时力求提供当前主流模型的概览和选型考量，但请读者注意，文中所提及的具体模型版本和性能数据可能随时间推移而发生变化，且只列举了部分工作并不完整。我们更侧重于介绍其核心技术特点、发展趋势以及在智能体开发中的通用选型原则。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;3.2.4.1 模型选型的关键考量&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在为您的智能体选择大语言模型时，可以从以下几个维度进行综合评估：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;性能与能力&lt;/strong&gt;：这是最核心的考量。不同的模型擅长的任务不同，有的长于逻辑推理和代码生成，有的则在创意写作或多语言翻译上更胜一筹。您可以参考一些公开的基准测试排行榜（如 LMSys Chatbot Arena Leaderboard）来评估模型的综合能力。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;成本&lt;/strong&gt;：对于闭源模型，成本主要体现在 API 调用费用，通常按 Token 数量计费。对于开源模型，成本则体现在本地部署所需的硬件（GPU、内存）和运维上。需要根据应用的预期使用量和预算做出选择。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;速度（延迟）&lt;/strong&gt;：对于需要实时交互的智能体（如客服、游戏 NPC），模型的响应速度至关重要。一些轻量级或经过优化的模型（如 GPT-3.5 Turbo, Claude 3.5 Sonnet）在延迟上表现更优。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;上下文窗口&lt;/strong&gt;：模型能一次性处理的 Token 数量上限。对于需要理解长文档、分析代码库或维持长期对话记忆的智能体，选择一个拥有较大上下文窗口（如 128K Token 或更高）的模型是必要的。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;部署方式&lt;/strong&gt;：使用 API 的方式最简单便捷，但数据需要发送给第三方，且受限于服务商的条款。本地部署则能确保数据隐私和最高程度的自主可控，但对技术和硬件要求更高。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;生态与工具链&lt;/strong&gt;：一个模型的流行程度也决定了其周边生态的成熟度。主流模型通常拥有更丰富的社区支持、教程、预训练模型、微调工具和兼容的开发框架（如 LangChain, LlamaIndex, Hugging Face Transformers），这能极大地加速开发进程，降低开发难度。选择一个拥有活跃社区和完善工具链的模型，可以在遇到问题时更容易找到解决方案和资源。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;可微调性与定制化&lt;/strong&gt;：对于需要处理特定领域数据或执行特定任务的智能体，模型的微调能力至关重要。一些模型提供了便捷的微调接口和工具，允许开发者使用自己的数据集对模型进行定制化训练，从而显著提升模型在特定场景下的性能和准确性。开源模型在这方面通常提供更大的灵活性。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;安全性与伦理&lt;/strong&gt;：随着大语言模型的广泛应用，其潜在的安全风险和伦理问题也日益凸显。选择模型时，需要考虑其在偏见、毒性、幻觉等方面的表现，以及服务商或开源社区在模型安全和负责任AI方面的投入。对于面向公众或涉及敏感信息的应用，模型的安全性和伦理合规性是不可忽视的考量。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;3.2.4.2 闭源模型概览&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;闭源模型通常代表了当前 AI 技术的最前沿，并提供稳定、易用的 API 服务，是构建高性能智能体的首选。&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;OpenAI GPT 系列&lt;/strong&gt;：从开启大模型时代的 GPT-3，到引入 RLHF（人类反馈强化学习）、实现与人类意图对齐的 ChatGPT，再到开启多模态时代的 GPT-4，OpenAI 持续引领行业发展。最新的 GPT-5 更是将多模态能力和通用智能水平提升到新的高度，能够无缝处理文本、音频和图像输入，并生成相应的输出，其响应速度和自然度也大幅提升，尤其在实时语音对话方面表现出色。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Google Gemini 系列&lt;/strong&gt;：Google DeepMind 推出的 Gemini 系列模型是原生多模态的代表，其核心特点是能统一处理文本、代码、音视频和图像等多种模态的数据，并以其超长的上下文窗口在海量信息处理上具备优势。Gemini Ultra 是其最强大的模型，适用于高度复杂的任务；Gemini Pro 适用于广泛的任务，提供高性能和效率；Gemini Nano 则针对设备端部署进行了优化。最新的 Gemini 2.5 系列模型，如 Gemini 2.5 Pro 和 Gemini 2.5 Flash，进一步提升了推理能力和上下文窗口，特别是 Gemini 2.5 Flash 以其更快的推理速度和成本效益，适用于需要快速响应的场景。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Anthropic Claude 系列&lt;/strong&gt;：Anthropic 是一家专注于 AI 安全和负责任 AI 的公司，其 Claude 系列模型从设计之初就将 AI 安全放在首位，以其在处理长文档、减少有害输出、遵循指令方面的可靠性而闻名，深受企业级应用青睐。Claude 3 系列包括 Claude 3 Opus（最智能、性能最强）、Claude 3 Sonnet（性能与速度兼顾的平衡之选）和 Claude 3 Haiku（最快、最紧凑的模型，适用于近乎实时的交互）。最新的 Claude 4 系列模型，如 Claude 4 Opus，在通用智能、复杂推理和代码生成方面取得了显著进展，进一步提升了处理长上下文和多模态任务的能力。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;国内主流模型&lt;/strong&gt;：中国在大语言模型领域涌现出众多具有竞争力的闭源模型，以百度文心一言(ERNIE Bot)、腾讯混元(Hunyuan)、华为盘古(Pangu-α)、科大讯飞星火(SparkDesk)和月之暗面(Moonshot AI)等为代表的国产模型，在中文处理上具备天然优势，并深度赋能本土产业。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;3.2.4.3 开源模型概览&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;开源模型为开发者提供了最高程度的灵活性、透明度和自主性，催生了繁荣的社区生态。它们允许开发者在本地部署、进行定制化微调，并拥有完整的模型控制权。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Meta Llama 系列&lt;/strong&gt;：Meta 推出的 Llama 系列是开源大语言模型的重要里程碑。该系列凭借出色的综合性能、开放的许可协议和强大的社区支持，成为许多衍生项目和研究的基座。Llama 4 系列于2025年4月发布，是Meta首批采用混合专家（MoE）架构的模型，该架构通过仅激活处理特定任务所需的模型部分来显著提升计算效率。该系列包含三款定位分明的模型：LLama 4 Scout支持1000万token的上下文窗口专为长文档分析和移动端部署设计。Llama 4 Maverick专注于多模态能力，在编码、复杂推理及多语言支持方面表现卓越。Llama 4 Behemoth多项STEM基准测试中表现超越竞争对手。是Meta目前最强大的模型&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Mistral AI 系列&lt;/strong&gt;：来自法国的 Mistral AI 以其“小尺寸、高性能”的模型设计而闻名。其最新模型 Mistral Medium 3.1 于2025年8月发布，在代码生成、STEM推理和跨领域问答等任务上准确率与响应速度均有显著提升，基准测试表现优于Claude Sonnet 3.7与Llama 4 Maverick等同级模型。它具备原生多模态能力，可同时处理图像与文字混合输入，并内置“语调适配层”，帮助企业更轻松实现符合品牌调性的输出。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;国内开源力量&lt;/strong&gt;：国内厂商和科研机构也在积极拥抱开源，例如阿里巴巴的&lt;strong&gt;通义千问 (Qwen)&lt;/strong&gt; 系列和清华大学与智谱 AI 合作的 &lt;strong&gt;ChatGLM&lt;/strong&gt; 系列，它们提供了强大的中文能力，并围绕自身构建了活跃的社区。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;对于智能体开发者而言，闭源模型提供了“开箱即用”的便捷，而开源模型则赋予了我们“随心所欲”的定制自由。理解这两大阵营的特点和代表模型，是为我们的智能体项目做出明智技术选型的第一步。&lt;/p&gt;
&lt;h2 id="33-大语言模型的缩放法则与局限性"&gt;3.3 大语言模型的缩放法则与局限性
&lt;/h2&gt;&lt;p&gt;大语言模型（LLMs）在近年来取得了令人瞩目的进展，其能力边界不断拓展，应用场景日益丰富。然而，这些成就的背后，离不开对模型规模、数据量和计算资源之间关系的深刻理解，即&lt;strong&gt;缩放法则（Scaling Laws）&lt;/strong&gt;。同时，作为新兴技术，LLMs也面临着诸多挑战和局限性。本节将深入探讨这些核心概念，旨在帮助读者全面理解LLMs的能力边界，从而在构建智能体时扬长避短。&lt;/p&gt;
&lt;h3 id="331-缩放法则"&gt;3.3.1 缩放法则
&lt;/h3&gt;&lt;p&gt;&lt;strong&gt;缩放法则（Scaling Laws）&lt;/strong&gt;是近年来大语言模型领域最重要的发现之一。它揭示了模型性能与模型参数量、训练数据量以及计算资源之间存在着可预测的幂律关系。这一发现为大语言模型的持续发展提供了理论指导，阐明了增加资源投入能够系统性提升模型性能的底层逻辑。&lt;/p&gt;
&lt;p&gt;研究发现，在对数-对数坐标系下，模型的性能（通常用损失 Loss 来衡量）与参数量、数据量和计算量这三个因素都呈现出平滑的幂律关系&lt;sup&gt;[9]&lt;/sup&gt;。简单来说，只要我们持续、按比例地增加这三个要素，模型的性能就会可预测地、平滑地提升，而不会出现明显的瓶颈。这一发现为大模型的设计和训练提供了清晰的指导：在资源允许的范围内，尽可能地扩大模型规模和训练数据量。&lt;/p&gt;
&lt;p&gt;早期的研究更侧重于增加模型参数量，但 DeepMind 在 2022 年提出的“Chinchilla 定律”对此进行了重要修正&lt;sup&gt;[10]&lt;/sup&gt;。该定律指出，在给定的计算预算下，为了达到最优性能，&lt;strong&gt;模型参数量和训练数据量之间存在一个最优配比&lt;/strong&gt;。具体来说，最优的模型应该比之前普遍认为的要小，但需要用多得多的数据进行训练。例如，一个 700 亿参数的 Chinchilla 模型，由于使用了比 GPT-3（1750 亿参数）多 4 倍的数据进行训练，其性能反而超越了后者。这一发现纠正了“越大越好”的片面认知，强调了数据效率的重要性，并指导了后续许多高效大模型（如 Llama 系列）的设计。&lt;/p&gt;
&lt;p&gt;缩放法则最令人惊奇的产物是“能力的涌现”。所谓能力涌现，是指当模型规模达到一定阈值后，会突然展现出在小规模模型中完全不存在或表现不佳的全新能力。例如，&lt;strong&gt;链式思考 (Chain-of-Thought)&lt;/strong&gt; 、&lt;strong&gt;指令遵循 (Instruction Following)&lt;/strong&gt; 、多步推理、代码生成等能力，都是在模型参数量达到数百亿甚至千亿级别后才显著出现的。这种现象表明，大语言模型不仅仅是简单地记忆和复述，它们在学习过程中可能形成了某种更深层次的抽象和推理能力。对于智能体开发者而言，能力的涌现意味着选择一个足够大规模的模型，是实现复杂自主决策和规划能力的前提。&lt;/p&gt;
&lt;h3 id="332-模型幻觉"&gt;3.3.2 模型幻觉
&lt;/h3&gt;&lt;p&gt;&lt;strong&gt;模型幻觉（Hallucination）&lt;/strong&gt;通常指的是大语言模型生成的内容与客观事实、用户输入或上下文信息相矛盾，或者生成了不存在的事实、实体或事件。幻觉的本质是模型在生成过程中，过度自信地“编造”了信息，而非准确地检索或推理。根据其表现形式，幻觉可以被分为多种类型&lt;sup&gt;[11]&lt;/sup&gt;，例如：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;事实性幻觉 (Factual Hallucinations)&lt;/strong&gt; ： 模型生成与现实世界事实不符的信息。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;忠实性幻觉 (Faithfulness Hallucinations)&lt;/strong&gt; ： 在文本摘要、翻译等任务中，生成的内容未能忠实地反映源文本的含义。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;内在幻觉 (Intrinsic Hallucinations)&lt;/strong&gt; ： 模型生成的内容与输入信息直接矛盾。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;幻觉的产生是多方面因素共同作用的结果。首先，训练数据中可能包含错误或矛盾的信息。其次，模型的自回归生成机制决定了它只是在预测下一个最可能的词元，而没有内置的事实核查模块。最后，在面对需要复杂推理的任务时，模型可能会在逻辑链条中出错，从而“编造”出错误的结论。例如：一个旅游规划 Agent，可能会为你推荐一个现实中不存在的景点，或者预订一个航班号错误的机票。&lt;/p&gt;
&lt;p&gt;此外，大语言模型还面临着知识时效性不足和训练数据中存在的偏见等挑战。大语言模型的能力来源于其训练数据。这意味着模型所掌握的知识是其训练数据收集时的最新材料。对于在此日期之后发生的事件、新出现的概念或最新的事实，模型将无法感知或正确回答。与此同时训练数据往往包含了人类社会的各种偏见和刻板印象。当模型在这些数据上学习时，它不可避免地会吸收并反映出这些偏见&lt;sup&gt;[12]&lt;/sup&gt;。&lt;/p&gt;
&lt;p&gt;为了提高大语言模型的可靠性，研究人员和开发者正在积极探索多种检测和缓解幻觉的方法：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;数据层面&lt;/strong&gt;： 通过高质量数据清洗、引入事实性知识以及强化学习与人类反馈 (RLHF) 等方式&lt;sup&gt;[13]&lt;/sup&gt;，从源头减少幻觉。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型层面&lt;/strong&gt;： 探索新的模型架构，或让模型能够表达其对生成内容的不确定性。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;推理与生成层面&lt;/strong&gt;：
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;检索增强生成 (Retrieval-Augmented Generation, RAG)&lt;/strong&gt; &lt;sup&gt;[14]&lt;/sup&gt;： 这是目前缓解幻觉的有效方法之一。RAG 系统通过在生成之前从外部知识库（如文档数据库、网页）中检索相关信息，然后将检索到的信息作为上下文，引导模型生成基于事实的回答。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多步推理与验证&lt;/strong&gt;： 引导模型进行多步推理，并在每一步进行自我检查或外部验证。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;引入外部工具&lt;/strong&gt;： 允许模型调用外部工具（如搜索引擎、计算器、代码解释器）来获取实时信息或进行精确计算。&lt;/li&gt;
&lt;/ol&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;尽管幻觉问题短期内难以完全消除，但通过上述的策略，可以显著降低其发生频率和影响，提高大语言模型在实际应用中的可靠性和实用性。&lt;/p&gt;
&lt;h2 id="34-本章小结"&gt;3.4 本章小结
&lt;/h2&gt;&lt;p&gt;本章介绍了构建智能体所需的基础知识，重点围绕作为其核心组件的大语言模型 (LLM) 展开。内容从语言模型的早期发展开始，详细讲解了 Transformer 架构，并介绍了与 LLM 进行交互的方法。最后，本章对当前主流的模型生态、发展规律及其固有局限性进行了梳理。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;核心知识点回顾：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;模型演进与核心架构&lt;/strong&gt;：本章追溯了从统计语言模型 (N-gram) 到神经网络模型 (RNN, LSTM)，再到奠定现代 LLM 基础的 Transformer 架构。通过“自顶向下”的代码实现，本章拆解了 Transformer 的核心组件，并阐述了自注意力机制在并行计算和捕捉长距离依赖中的关键作用。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;与模型的交互方式&lt;/strong&gt;：本章介绍了与 LLM 交互的两个核心环节：提示工程 (Prompt Engineering) 和文本分词 (Tokenization)。前者用于指导模型的行为，后者是理解模型输入处理的基础。通过本地部署并运行开源模型的实践，将理论知识应用于实际操作。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型生态与选型&lt;/strong&gt;：本章系统地梳理了为智能体选择模型时需要权衡的关键因素，并概览了以 OpenAI GPT、Google Gemini 为代表的闭源模型和以 Llama、Mistral 为代表的开源模型的特点与定位。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;法则与局限&lt;/strong&gt;：本章探讨了驱动 LLM 能力提升的缩放法则，阐述了其背后的基本原理。同时，本章也分析了模型存在的如事实幻觉、知识过时等固有局限性，这对于构建可靠、鲁棒的智能体至关重要。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;从 LLM 基础到构建智能体：&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这一章的LLM基础主要是为了帮助大家更好的理解大模型的诞生以及发展过程，其中也蕴含了智能体设计的部分思考。例如，如何设计有效的提示词来引导 Agent 的规划与决策，如何根据任务需求选择合适的模型，以及如何在 Agent 的工作流中加入验证机制以规避模型的幻觉等问题，其解决方案均建立在本章的基础之上。我们现在已经准备好从理论转向实践。在下一章，我们将开始探索智能体经典范式构建，将本章所学的知识应用于实际的智能体设计之中。&lt;/p&gt;
&lt;h2 id="习题"&gt;习题
&lt;/h2&gt;&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;自然语言处理中，语言模型经历了从统计到神经网络的模型演进。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;请使用本章提供的迷你语料库（&lt;code&gt;datawhale agent learns&lt;/code&gt;, &lt;code&gt;datawhale agent works&lt;/code&gt;），计算句子 &lt;code&gt;agent works&lt;/code&gt; 在Bigram模型下的概率&lt;/li&gt;
&lt;li&gt;N-gram模型的核心假设是马尔可夫假设。请解释这个假设的含义，以及N-gram模型存在哪些根本性局限？&lt;/li&gt;
&lt;li&gt;神经网络语言模型（RNN/LSTM）和Transformer分别是如何克服N-gram模型局限的？它们各自的优势是什么？&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Transformer架构&lt;sup&gt;[4]&lt;/sup&gt;是现代大语言模型的基础。其中：&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;提示&lt;/strong&gt;：可以结合本章3.1.2节的代码实现来辅助理解&lt;/p&gt;
&lt;/blockquote&gt;
&lt;ul&gt;
&lt;li&gt;自注意力机制（Self-Attention）的核心思想是什么？&lt;/li&gt;
&lt;li&gt;为什么Transformer能够并行处理序列，而RNN必须串行处理？位置编码（Positional Encoding）在其中起什么作用？&lt;/li&gt;
&lt;li&gt;Decoder-Only架构与完整的Encoder-Decoder架构有什么区别？为什么现在主流的大语言模型都采用Decoder-Only架构？&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;文本子词分词算法是大语言模型的一项关键技术，负责将文本转换为模型可处理的 token 序列。那为什么不能直接以&amp;quot;字符&amp;quot;或&amp;quot;单词&amp;quot;作为模型的输入单元？BPE（Byte Pair Encoding）算法解决了什么问题？&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;本章3.2.3节介绍了如何本地部署开源大语言模型。请完成以下实践和分析：&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;提示&lt;/strong&gt;：这是一道动手实践题，建议实际操作&lt;/p&gt;
&lt;/blockquote&gt;
&lt;ul&gt;
&lt;li&gt;按照本章的指导，在本地部署一个轻量级的开源模型（推荐&lt;a class="link" href="https://modelscope.cn/models/Qwen/Qwen3-0.6B" target="_blank" rel="noopener"
&gt;Qwen3-0.6B&lt;/a&gt;），并尝试调整采样参数并观察其对输出的影响&lt;/li&gt;
&lt;li&gt;选择一个具体任务（如文本分类、信息抽取、代码生成等），设计并对比以下不同的提示策略（如Zero-shot、Few-shot、Chain-of-Thought）对输出结果的效果差异&lt;/li&gt;
&lt;li&gt;从性能、成本、可控性、隐私等维度比较闭源模型和开源模型&lt;/li&gt;
&lt;li&gt;如果你要构建一个企业级的客服智能体，你会选择哪种类型的模型？需要考虑哪些因素？&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;模型幻觉（Hallucination）&lt;sup&gt;[11]&lt;/sup&gt;是大语言模型当前存在的关键局限性之一。本章介绍了缓解幻觉的方法（如检索增强生成、多步推理、外部工具调用）&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;请选择其中一种，说明其工作原理和适用场景&lt;/li&gt;
&lt;li&gt;调研前沿的研究和论文，是否还有其他的缓解模型幻觉的方法，他们又有哪些改进和优势？&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;假设你要设计一个论文辅助阅读智能体，它能够帮助研究人员快速阅读并理解学术论文，包括：总结论文研究的核心内容、回答关于论文的问题、提取关键信息、比较多篇不同论文的观点等。请回答：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;你会选择哪个模型作为智能体设计时的基座模型？选择时需要考虑哪些因素？&lt;/li&gt;
&lt;li&gt;如何设计提示词来引导模型更好地理解学术论文？学术论文通常很长，可能超过模型的上下文窗口限制，你会如何解决这个问题？&lt;/li&gt;
&lt;li&gt;学术研究是严谨的，这意味着我们需要确保智能体生成的信息是准确客观忠于原文的。你认为系统中加入哪些设计能够更好的实现这一需求？&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;h2 id="参考文献"&gt;参考文献
&lt;/h2&gt;&lt;p&gt;[1] Bengio, Y., Ducharme, R., Vincent, P., &amp;amp; Jauvin, C. (2003). A neural probabilistic language model. &lt;em&gt;Journal of Machine Learning Research&lt;/em&gt;, 3, 1137-1155.&lt;/p&gt;
&lt;p&gt;[2] Elman, J. L. (1990). Finding structure in time. &lt;em&gt;Cognitive Science&lt;/em&gt;, 14(2), 179-211.&lt;/p&gt;
&lt;p&gt;[3] Hochreiter, S., &amp;amp; Schmidhuber, J. (1997). Long short-term memory. &lt;em&gt;Neural Computation&lt;/em&gt;, 9(8), 1735-1780.&lt;/p&gt;
&lt;p&gt;[4] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., &amp;hellip; &amp;amp; Polosukhin, I. (2017). Attention is all you need. In &lt;em&gt;Advances in neural information processing systems&lt;/em&gt; (pp. 5998-6008).&lt;/p&gt;
&lt;p&gt;[5] Radford, A., Narasimhan, K., Salimans, T., &amp;amp; Sutskever, I. (2018). Improving language understanding by generative pre-training. OpenAI.&lt;/p&gt;
&lt;p&gt;[6] Gage, P. (1994). A new algorithm for data compression. &lt;em&gt;C Users Journal&lt;/em&gt;, &lt;em&gt;12&lt;/em&gt;(2), 23-38.&lt;/p&gt;
&lt;p&gt;[7] Schuster, M., &amp;amp; Nakajima, K. (2012, March). Japanese and korean voice search. In &lt;em&gt;2012 IEEE international conference on acoustics, speech and signal processing (ICASSP)&lt;/em&gt; (pp. 5149-5152). IEEE.&lt;/p&gt;
&lt;p&gt;[8] Kudo, T., &amp;amp; Richardson, J. (2018). SentencePiece: A simple and language independent subword tokenizer and detokenizer for neural text processing. &lt;em&gt;arXiv preprint arXiv:1808.06226&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[9] Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., &amp;hellip; &amp;amp; Amodei, D. (2020). Scaling Laws for Neural Language Models. arXiv preprint arXiv:2001.08361.&lt;/p&gt;
&lt;p&gt;[10] Hoffmann, J., Borgeaud, E., Mensch, A., Buchatskaya, E., Cai, T., Rutherford, R., &amp;hellip; &amp;amp; Sifre, L. (2022). Training Compute-Optimal Large Language Models. arXiv preprint arXiv:2203.07678.&lt;/p&gt;
&lt;p&gt;[11] Ji, Z., Lee, N., Fries, R., Yu, T., &amp;amp; Su, D. (2023). Survey of Hallucination in Large Language Models.&lt;/p&gt;
&lt;p&gt;[12] Bender, E. M., Gebru, T., McMillan-Major, A., &amp;amp; Mitchell, M. (2021). On the Dangers of Stochastic Parrots: Can Language Models Be Too Big? .&lt;/p&gt;
&lt;p&gt;[13] Christiano, P., Leike, J., Brown, T. B., Martic, M., Legg, S., &amp;amp; Amodei, D. (2017). Deep reinforcement learning from human preferences. &lt;em&gt;arXiv preprint arXiv:1706.03741&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;[14] Lewis, P., Perez, E., Piktus, A., Petroni, F., Karpukhin, V., Goswami, N., &amp;hellip; &amp;amp; Kiela, D. (2020). Retrieval-augmented generation for knowledge-intensive NLP tasks. In &lt;em&gt;Advances in neural information processing systems&lt;/em&gt; (pp. 9459-9474).&lt;/p&gt;</description></item></channel></rss>