<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Seq2Seq on 酒中仙</title><link>https://hanguangwu.github.io/blog/tags/seq2seq/</link><description>Recent content in Seq2Seq on 酒中仙</description><generator>Hugo -- gohugo.io</generator><language>zh-cn</language><copyright>hanguangwu</copyright><lastBuildDate>Sat, 07 Feb 2026 18:34:25 -0800</lastBuildDate><atom:link href="https://hanguangwu.github.io/blog/tags/seq2seq/index.xml" rel="self" type="application/rss+xml"/><item><title>Seq2Seq 架构</title><link>https://hanguangwu.github.io/blog/p/seq2seq-%E6%9E%B6%E6%9E%84/</link><pubDate>Sat, 07 Feb 2026 18:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/seq2seq-%E6%9E%B6%E6%9E%84/</guid><description>&lt;h1 id="seq2seq-架构"&gt;Seq2Seq 架构
&lt;/h1&gt;&lt;p&gt;前面我们已经学习了如何使用 RNN 和 LSTM 处理序列数据。这些模型在三类任务中表现出色：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;多对一（Many-to-One）&lt;/strong&gt;：将整个序列信息压缩成一个特征向量，用于文本分类、情感分析等任务。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;多对多（Many-to-Many, Aligned）&lt;/strong&gt;：为输入序列的每一个词元（Token）都生成一个对应的输出，如词性标注、命名实体识别等。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;一对多（One-to-Many）&lt;/strong&gt;：从一个固定的输入（如一张图片、一个类别标签）生成一个可变长度的序列，例如图像描述生成、音乐生成等。&lt;/p&gt;
&lt;p&gt;但是，在自然语言处理中，还存在一类更复杂的、被称为&lt;strong&gt;多对多（Many-to-Many, Unaligned）&lt;/strong&gt; 的任务，它们的&lt;strong&gt;输入序列和输出序列的长度可能不相等&lt;/strong&gt;，且元素之间没有严格的对齐关系。最典型的例子就是&lt;strong&gt;机器翻译&lt;/strong&gt;，比如将“我是中国人”（3个词）翻译成 &amp;ldquo;I am Chinese&amp;rdquo;（3个词），但 “我爱人工智能”（3个词）翻译成 &amp;ldquo;I love artificial intelligence&amp;rdquo;（4个词）。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;此处将“人工智能”视为单个单元仅为方便举例，旨在说明输入与输出序列长度可能不等的概念，不代表严格的分词标准。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;对于这类问题，简单的 RNN 或 LSTM 架构难以胜任。为了解决这一挑战，2014年，研究者们提出了&lt;strong&gt;序列到序列（Sequence-to-Sequence, Seq2Seq）&lt;/strong&gt; 架构，它成功地将一种通用的&lt;strong&gt;编码器-解码器（Encoder-Decoder）&lt;/strong&gt; 架构应用于序列转换任务 &lt;sup id="fnref:1"&gt;&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref"&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;sup id="fnref:2"&gt;&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref"&gt;2&lt;/a&gt;&lt;/sup&gt;。该架构一经提出，便在机器翻译、文本摘要、对话系统等领域取得了巨大成功。&lt;/p&gt;
&lt;h2 id="一从自编码器到-seq2seq"&gt;一、从自编码器到 Seq2Seq
&lt;/h2&gt;&lt;p&gt;要理解 Seq2Seq，可以先从一种更基础的、同样使用编码器-解码器思想的无监督神经网络——&lt;strong&gt;自编码器（Autoencoder）&lt;/strong&gt; 说起。自编码器由两个部分组成：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;编码器&lt;/strong&gt;: 读取输入数据（如一张图片、一个向量），并将其压缩成一个低维度的、紧凑的&lt;strong&gt;潜在表示 (Latent Representation)&lt;/strong&gt; 。这个过程可以看作是特征提取或数据压缩。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;解码器&lt;/strong&gt;: 接收这个潜在表示，并尝试将其&lt;strong&gt;重构&lt;/strong&gt;回原始的输入数据。&lt;/p&gt;
&lt;p&gt;自编码器的训练目标是让&lt;strong&gt;输出与输入尽可能地相同&lt;/strong&gt;。通过这个过程，模型被迫学习到数据中最具代表性的核心特征，并将其编码在潜在表示中。它的目标是&lt;strong&gt;数据重构&lt;/strong&gt;，常被用于降维、特征学习或数据去噪等任务。&lt;/p&gt;
&lt;h2 id="二seq2seq-架构详解"&gt;二、Seq2Seq 架构详解
&lt;/h2&gt;&lt;p&gt;Seq2Seq 架构借鉴了自编码器的结构，但对其核心目标进行了关键的&lt;strong&gt;泛化&lt;/strong&gt;：它不再要求解码器的输出与编码器的输入相同，而是要生成一个全新的、与输入语义相关的&lt;strong&gt;目标序列&lt;/strong&gt; 。&lt;/p&gt;
&lt;h3 id="21-整体架构"&gt;2.1 整体架构
&lt;/h3&gt;&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;核心思想&lt;/strong&gt;：借鉴人类进行翻译的过程——先完整地阅读并理解源语言的整个句子，形成一个综合的&lt;strong&gt;语义表示&lt;/strong&gt;；然后，基于这个语义表示，开始用目标语言逐词生成译文。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;目标&lt;/strong&gt;：从 &lt;code&gt;Input&lt;/code&gt; 到 &lt;code&gt;Output&lt;/code&gt; 的&lt;strong&gt;转换&lt;/strong&gt;，而非重构。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;模型同样被拆分为两个组件：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;编码器&lt;/strong&gt;：扮演“阅读和理解”的角色。它负责接收整个输入序列，并将其信息压缩成一个固定长度的&lt;strong&gt;上下文向量（Context Vector）&lt;/strong&gt; ，通常记为 $C$。这个向量就是输入序列的“语义概要”。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;解码器&lt;/strong&gt;：扮演“组织语言并生成”的角色。它接收上下文向量 $C$ 作为初始信息，然后逐个生成输出序列中的词元。&lt;/p&gt;
&lt;p&gt;在最初基于 Seq2Seq 架构的模型中，编码器和解码器通常都由 RNN 或其变体（如 LSTM、GRU）构成。图 4-1 以“I love you” -&amp;gt; “我爱你”的翻译任务为例，展示了一个基于 LSTM 的 Seq2Seq 架构如何将编码、解码与自回归机制结合在一起的完整工作流程。&lt;/p&gt;
&lt;div align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/4_1_1.svg" width="80%" alt="Seq2Seq详细工作流程" /&gt;
&lt;p&gt;图 4-1 Seq2Seq 详细工作流程&lt;/p&gt;
&lt;/div&gt;
&lt;h3 id="22-编码器-encoder"&gt;2.2 编码器 (Encoder)
&lt;/h3&gt;&lt;p&gt;编码器的任务是生成上下文向量 $C$。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;它可以是一个标准的 RNN（或 LSTM），逐个读取输入序列的词元 $x_1, x_2, \dots, x_T$。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;在每个时间步，它都会根据前一时刻的状态和当前输入来更新自身状态。对于标准 RNN，这个过程可以简化为 $h_t = f(h_{t-1}, x_t)$；而对于 LSTM，则同时更新隐藏状态和细胞状态： $(h_t, c_t) = \text{LSTM}((h_{t-1}, c_{t-1}), x_t)$。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;当处理完最后一个输入词元 $x_T$后，编码器最终的状态就被用作整个输入序列的上下文向量 $C$。对于 LSTM，上下文向量 $C$ 通常就是最后一个时间步的隐藏状态和细胞状态的元组，即 $C = (h_T, c_T)$。虽然这是最常见的做法，但上下文向量 $C$ 也可以由所有时间步的隐藏状态 ${h_1, h_2, \dots, h_T}$ 经过某种变换（如拼接后通过一个线性层、或取平均池化）得到，以期保留更全面的序列信息。在图中，编码器依次处理英文单词 “I”、“love”、“you” 的词嵌入向量，并将最终的状态打包成上下文向量（Context Vector）传递给解码器。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;由于编码器在处理时可以访问整个输入序列，因此它可以使用&lt;strong&gt;双向 RNN&lt;/strong&gt;。通过同时从正向和反向两个方向读取序列，编码器可以为每个词元生成更全面的上下文表示，从而得到一个信息更丰富的上下文向量 $C$。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="23-解码器-decoder"&gt;2.3 解码器 (Decoder)
&lt;/h3&gt;&lt;p&gt;解码器的任务是根据上下文向量 $C$ 生成输出序列 $y_1, y_2, \dots, y_{T&amp;rsquo;}$。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;解码器同样可以使用一个标准的 RNN（或 LSTM）作为核心，但它扮演的角色是&lt;strong&gt;生成器&lt;/strong&gt;而非信息压缩器，因此其工作流程与编码器有显著差异。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;初始化&lt;/strong&gt; ：解码器的初始状态直接由编码器生成的上下文向量 $C$ 初始化。对于 LSTM，这意味着初始的隐藏状态和细胞状态 $(h^{\prime}_0, c^{\prime}_0)$ 都被设置为编码器的最终状态 $C=(h_T, c_T)$。这相当于将整个输入序列的“语义概要”交给了解码器。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;自回归生成 (Auto-regressive Generation)&lt;/strong&gt; ：解码器逐个生成词元。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;在第一个时间步，它以初始状态（对 LSTM 而言是 $(h^{\prime}_0, c^{\prime}_0)$ ）和一个特殊的起始符 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt; (Start of Sentence) 作为输入，生成第一个目标词元 $y_1$。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;在第二个时间步，它将上一步的状态（ $(h^{\prime}_1, c^{\prime}_1)$ ）和 &lt;strong&gt;上一步生成的词元 $y_1$&lt;/strong&gt; 作为输入，生成第二个目标词元 $y_2$。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;这个过程不断重复，状态也随之更新。对于 LSTM，这个更新过程可以表示为 $(h^{\prime}&lt;em&gt;t, c^{\prime}&lt;em&gt;t) = \text{LSTM}((h^{\prime}&lt;/em&gt;{t-1}, c^{\prime}&lt;/em&gt;{t-1}), y_{t-1})$。这个过程将持续进行，直到生成一个特殊的终止符 &lt;code&gt;&amp;lt;EOS&amp;gt;&lt;/code&gt; (End of Sentence) 或达到预设的最大长度。图中展示的正是这个过程，解码器首先接收 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt; 符和上下文向量，生成第一个汉字“我”；接着，它将“我”作为下一步的输入，生成“爱你”；这个过程将持续进行，直到生成句子结束符 &lt;code&gt;&amp;lt;EOS&amp;gt;&lt;/code&gt; 为止。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;解码器在生成序列时，是按照从左到右的顺序逐词生成的，它在预测当前词元时不能“看到”未来的词元。为满足因果性约束，解码器通常使用单向 RNN（或采用因果掩码的解码结构）。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;在每个生成步骤中，解码器的隐藏状态 $h^{\prime}_t$ 会经过一个额外的全连接层（通常带有 Softmax 激活函数），以计算出词汇表中每个单词的概率分布。然后，模型会选择概率最高的单词作为当前时间步的输出。&lt;/p&gt;
&lt;h4 id="231-解码器一个条件语言模型"&gt;2.3.1 解码器：一个条件语言模型
&lt;/h4&gt;&lt;p&gt;从更深层次看，解码器本身就是一个强大的&lt;strong&gt;条件语言模型 (Conditional Language Model)&lt;/strong&gt; 。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;语言模型&lt;/strong&gt;：一个任务是&lt;strong&gt;预测下一个词元&lt;/strong&gt;的模型。就像日常使用的手机输入法，输入“今天天气”后，它会预测出“真好”、“不错”等可能的后续词语。一个不带任何条件的标准语言模型，可以基于上文持续生成文本。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Seq2Seq 的解码器&lt;/strong&gt;：它执行的也是“预测下一个词元”的任务，但它不是凭空预测，而是 &lt;strong&gt;以编码器生成的上下文向量 $C$ 为初始条件&lt;/strong&gt;。一旦接收了 $C$，解码器就开启了它的生成过程，将这个语义概要“翻译”成目标序列。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;因此，可以认为 &lt;strong&gt;Seq2Seq 的本质 = 编码器（用于理解和压缩信息）+ 条件语言模型（用于在特定条件下生成信息）&lt;/strong&gt;。&lt;/p&gt;
&lt;h4 id="232-解码器与大语言模型的关系"&gt;2.3.2 解码器与大语言模型的关系
&lt;/h4&gt;&lt;p&gt;如果去掉编码器，只保留解码器部分，会发生什么？&lt;/p&gt;
&lt;p&gt;此时，模型不再接收外部的上下文向量 $C$ 作为条件。它可视为仅基于自身前缀（提示词）的&lt;strong&gt;decoder-only 语言模型&lt;/strong&gt;：根据已有前缀预测下一个词元。例如，给定一个起始词元或提示前缀，它可以自回归地生成后续词元，写出完整的句子或段落。&lt;/p&gt;
&lt;p&gt;这与 &lt;strong&gt;GPT (Generative Pre-trained Transformer)&lt;/strong&gt; 等现代大语言模型的训练范式一致：它们采用解码器（“Decoder-only”）架构，以“预测下一个词”为目标在海量文本上预训练；使用时通过提示词进行条件化生成，学习并利用语言规律、事实知识与一定的推理能力。&lt;/p&gt;
&lt;h3 id="24-实现细节与考量"&gt;2.4 实现细节与考量
&lt;/h3&gt;&lt;h4 id="241-词嵌入层共享"&gt;2.4.1 词嵌入层共享
&lt;/h4&gt;&lt;p&gt;编码器和解码器都需要将输入的词元ID转换为向量，这通常由一个 &lt;code&gt;Embedding&lt;/code&gt; 层完成。这里存在一个设计选择：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;不共享&lt;/strong&gt;：编码器和解码器各自拥有独立的 &lt;code&gt;Embedding&lt;/code&gt; 层。若源语言和目标语言的词汇表彼此独立（如未采用联合子词/合并词表的英译中），通常选择不共享。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;共享&lt;/strong&gt;：编码器和解码器使用同一个 &lt;code&gt;Embedding&lt;/code&gt; 层。如果源语言和目标语言词汇表有大量重叠（如文本摘要任务，输入和输出都是中文），或者干脆将两种语言的词汇合并成一个大词汇表，那么共享 &lt;code&gt;Embedding&lt;/code&gt; 层是可行的。这样做可以减少模型参数，并可能让模型学到两种语言之间词元的潜在联系。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h4 id="242-上下文向量的传递与使用"&gt;2.4.2 上下文向量的传递与使用
&lt;/h4&gt;&lt;p&gt;理论上，编码器的最终状态 $C$ 会被传递给解码器。但在实践中，如何使用这个上下文向量 $C$ 有两种主流方式，这体现了架构设计的灵活性。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;方式一：作为解码器的初始状态&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这是最经典的做法。将编码器输出的上下文向量 $C$ 经过适配层（如全连接层、&lt;code&gt;reshape&lt;/code&gt;、&lt;code&gt;permute&lt;/code&gt;等操作）变换后，作为解码器RNN的初始隐藏状态 $h^{\prime}_0$（以及 $c^{\prime}_0$ for LSTM）。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;优点&lt;/strong&gt;：概念直观，符合“先理解全文（生成 $C$ ），再开始生成（初始化解码器）”的逻辑。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;缺点&lt;/strong&gt;：解码器的所有生成步骤都源于这“唯一一次”的初始信息输入。对于长序列，RNN自身的长距离依赖问题可能会导致初始状态的信息在多步传递后逐渐“稀释”或“遗忘”。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;方式二：作为解码器每个时间步的输入&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;另一种方式是，不改变解码器默认的零向量初始状态，而是将上下文向量 $C$ 作为解码器&lt;strong&gt;每一个时间步&lt;/strong&gt;的额外输入。具体实现上，在第 $t$ 个时间步，将常规的词元输入 $y_{t-1}$ 经过&lt;code&gt;Embedding&lt;/code&gt;层后得到的向量，与上下文向量 $C$ 进行合并。合并的方式可以是：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;拼接&lt;/strong&gt;：将两个向量拼接在一起。这会改变输入特征的维度，需要相应地调整后续RNN层的输入大小。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;相加&lt;/strong&gt;：将两个向量逐元素相加。这要求词嵌入向量的维度与上下文向量 $C$ 的维度相同。同时，为了进行广播加法，需要先将二维的 $C$ (&lt;code&gt;Batch Size, Hidden Size&lt;/code&gt;) 通过 &lt;code&gt;unsqueeze&lt;/code&gt; 操作扩展成与输入词嵌入序列匹配的三维形状 (&lt;code&gt;Batch Size, 1, Hidden Size&lt;/code&gt;)。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;优点&lt;/strong&gt;：在每个生成步骤都直接“提醒”解码器全局上下文信息是什么，理论上可以更好地对抗信息遗忘。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;缺点&lt;/strong&gt;：虽然信息在每个时间步都存在，但输入的全局信息始终是&lt;strong&gt;同一个&lt;/strong&gt;静态向量 $C$，它仍然无法解决更深层次的“对齐”问题（即，在生成某个特定词时，应该重点关注输入的哪个部分）。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;总的来说，这两种方式都无法从根本上解决信息瓶颈问题，但这也正是它们激发后续注意力机制（Attention Mechanism）诞生的重要原因。&lt;/p&gt;
&lt;h4 id="243-损失函数计算"&gt;2.4.3 损失函数计算
&lt;/h4&gt;&lt;p&gt;在训练过程中，解码器的目标是让其在每个时间步 $t$ 的输出概率分布，尽可能地接近真实的目标词元 $y_t$。具体来说：&lt;/p&gt;
&lt;p&gt;（1）解码器在时间步 $t$ 的隐藏状态 $h^{\prime}_t$ 会经过一个全连接层（分类器），并使用 Softmax 函数计算出词汇表中每个词的概率，得到一个概率分布向量 $p_t$。模型的原始输出形状通常是 &lt;code&gt;(Batch Size, Sequence Length, Vocab Size)&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（2）损失函数（通常是&lt;strong&gt;交叉熵损失 Cross-Entropy Loss&lt;/strong&gt;）会计算这个预测概率分布 $p_t$ 与真实目标 $y_t$ 之间的差异。以 PyTorch 为例，&lt;code&gt;CrossEntropyLoss&lt;/code&gt; 接受形状为 &lt;code&gt;(N, C, ...)&lt;/code&gt; 的输入，也可将 &lt;code&gt;(N, L, C)&lt;/code&gt; 展平为 &lt;code&gt;(N·L, C)&lt;/code&gt; 与目标 &lt;code&gt;(N·L)&lt;/code&gt; 计算；若使用 &lt;code&gt;(N, C, L)&lt;/code&gt; 形式，可通过 &lt;code&gt;permute&lt;/code&gt; 将 &lt;code&gt;(N, L, C)&lt;/code&gt; 交换至 &lt;code&gt;(N, C, L)&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（3）训练时通常配合 &lt;code&gt;ignore_index&lt;/code&gt; 忽略 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt; 位置的损失，从而避免填充对梯度的干扰。&lt;/p&gt;
&lt;p&gt;（4）损失函数的计算本质上是取出 $p_t$ 中对应真实词元 $y_t$ 的那个概率值，取其负对数： $Loss_t = -\log p_t(y_t)$。&lt;/p&gt;
&lt;p&gt;（5）整个序列的总损失是所有时间步损失的累加或平均： $Loss_{total} = \sum_{t=1}^{T&amp;rsquo;} Loss_t$。&lt;/p&gt;
&lt;p&gt;（6）最后，通过反向传播算法，根据这个总损失来更新模型的所有参数（包括编码器和解码器）。&lt;/p&gt;
&lt;h4 id="244-数据填充与特殊词元"&gt;2.4.4 数据填充与特殊词元
&lt;/h4&gt;&lt;p&gt;在实际处理中，一个批次（Batch）中的序列长度往往不同。为了能够进行高效的矩阵运算，需要将它们填充（Pad）到相同的长度。此外，还需要引入一些特殊的词元（Token）来辅助模型处理。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;特殊词元&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt;：填充符，用于对齐长度，在计算损失时会被忽略。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt; 或 &lt;code&gt;&amp;lt;GO&amp;gt;&lt;/code&gt;：句子起始符，作为解码器第一个时间步的输入，启动生成过程。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;&amp;lt;EOS&amp;gt;&lt;/code&gt;：句子终止符，是解码器生成的目标之一。当模型生成它时，表示句子已完整，可以停止生成。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;&amp;lt;UNK&amp;gt;&lt;/code&gt;：未知词元。用于替换在训练词汇表中未出现过的词，增强模型的鲁棒性。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;编码器输入&lt;/strong&gt;：对源语言序列进行填充，通常在末尾添加 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt;。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;解码器输入与目标&lt;/strong&gt;：解码器的输入和目标序列需要精心构造，以实现“错位”训练，即用上一个真实词元预测下一个词元。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;原始目标序列&lt;/strong&gt;: &lt;code&gt;W, X, Y, Z&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;解码器输入&lt;/strong&gt;: 在序列开头添加起始符 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt;，并移除最后一个词元。如果需要填充，则在末尾添加 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt;。
&lt;code&gt;-&amp;gt; &amp;lt;SOS&amp;gt;, W, X, Y&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;解码器目标&lt;/strong&gt;: 在序列末尾添加终止符 &lt;code&gt;&amp;lt;EOS&amp;gt;&lt;/code&gt;后，再根据需要添加 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt; 进行填充以对齐批次。计算损失时会&lt;strong&gt;忽略&lt;/strong&gt;填充位置的损失。
&lt;code&gt;-&amp;gt; W, X, Y, Z, &amp;lt;EOS&amp;gt;&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;通过这种方式，模型在每个时间步都能学到从正确的历史信息到下一个正确词元的映射关系。&lt;/p&gt;
&lt;h2 id="三训练与推理模式"&gt;三、训练与推理模式
&lt;/h2&gt;&lt;h3 id="31-教师强制"&gt;3.1 教师强制
&lt;/h3&gt;&lt;p&gt;基于 Seq2Seq 架构的模型在训练和推理（即实际生成）时，解码器的工作模式有很大差异。如果在训练时也采用推理时的自回归模式（将上一时刻的预测作为下一时刻的输入），会存在两个问题：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;收敛缓慢&lt;/strong&gt;：模型在训练初期预测不准，错误的预测会不断被喂给后续的步骤，导致误差累积，模型很难收敛。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;难以并行&lt;/strong&gt;：每个时间步的计算都依赖于上一步的结果，使得训练过程无法并行化，效率低下。&lt;/p&gt;
&lt;p&gt;为了解决这个问题，Seq2Seq 采用了一种名为 &lt;strong&gt;教师强制 (Teacher Forcing)&lt;/strong&gt; &lt;sup id="fnref:3"&gt;&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref"&gt;3&lt;/a&gt;&lt;/sup&gt; 的高效训练策略。在教师强制模式下，解码器在计算第 $t$ 步的输出时，其输入 &lt;strong&gt;不再是自己上一时刻的预测值 $y^{\prime}_{t-1}$&lt;/strong&gt; ，而是直接使用 &lt;strong&gt;数据集中真实的标签值 $y_{t-1}$&lt;/strong&gt; ——其构造方式正是在 &lt;code&gt;2.4.4&lt;/code&gt; 节中描述的“解码器输入”序列。&lt;/p&gt;
&lt;p&gt;通过这种方式，解码器的每个时间步都可以接收到正确的历史信息，避免了误差累积，显著提升了收敛稳定性与速度。需要注意，对于基于 RNN 的解码器，时间维的计算仍是串行依赖的。在 Transformer 等非递归结构中，训练时可在时间步上并行（配合适当掩码）。&lt;/p&gt;
&lt;h3 id="32-自回归"&gt;3.2 自回归
&lt;/h3&gt;&lt;p&gt;在模型训练完毕，进行实际的翻译或生成任务时，我们并没有“正确答案”可以喂给解码器。此时，模型必须工作在&lt;strong&gt;自回归模式&lt;/strong&gt;下，即“自己教自己”：&lt;/p&gt;
&lt;p&gt;（1）编码器处理输入序列，生成上下文向量 $C$。&lt;/p&gt;
&lt;p&gt;（2）解码器以 $C$ 和 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt; 为初始输入，生成第一个词元 $y_1$。&lt;/p&gt;
&lt;p&gt;（3）将生成的 $y_1$ 作为解码器下一时间步的输入，生成 $y_2$。&lt;/p&gt;
&lt;p&gt;（4）不断将上一时刻的输出作为下一时刻的输入，循环此过程。&lt;/p&gt;
&lt;p&gt;（5）当解码器生成 &lt;code&gt;&amp;lt;EOS&amp;gt;&lt;/code&gt; 标志，或达到预设的最大输出长度时，生成过程停止。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;推理效率的优化&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在朴素的自回归实现中，存在大量的重复计算。例如：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;第1步&lt;/strong&gt;：输入 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt;，RNN 内部计算 $h^{\prime}_1 = f(h^{\prime}_0, y_0)$。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;第2步&lt;/strong&gt;：输入 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt;, $y^{\prime}_1$，RNN 会&lt;strong&gt;重新&lt;/strong&gt;计算 $h^{\prime}_1 = f(h^{\prime}_0, y_0)$，然后再计算 $h^{\prime}_2 = f(h^{\prime}_1, y^{\prime}_1)$。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;第3步&lt;/strong&gt;：输入 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt;, $y^{\prime}_1$, $y^{\prime}_2$，RNN 会&lt;strong&gt;再次重新&lt;/strong&gt;计算 $h^{\prime}_1$ 和 $h^{\prime}_2$，然后再计算 $h^{\prime}_3$。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这种“从头算起”的方式效率极低，更高效的实现方式是&lt;strong&gt;缓存并利用上一个时间步的输出状态&lt;/strong&gt;。在生成第 $t$ 个词元时，只将&lt;strong&gt;第 $t-1$ 个词元&lt;/strong&gt; $y^{\prime}&lt;em&gt;{t-1}$ 和&lt;strong&gt;上一步的隐藏状态&lt;/strong&gt; $h^{\prime}&lt;/em&gt;{t-1}$ 作为 RNN 的输入，RNN 仅执行一步计算，得到新的隐藏状态 $h^{\prime}_{t}$ 和当前词元的预测 logits，这个新的状态 $h^{\prime}_t$ 会被缓存，用于下一步的计算。通过这种方式，每个时间步都只进行一次 RNN 单元的计算。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;这种在每一步都选择当前概率最高的词元作为输出的策略，被称为&lt;strong&gt;贪心搜索（Greedy Search）&lt;/strong&gt;。它简单高效，但在某些情况下可能会导致次优解。例如，如果在某一步选择了一个局部最优但在全局看来是错误的词，这个错误可能会影响后续所有词元的生成，导致整个输出序列的质量下降。这就好比下棋时只看眼前一步的最好走法，却可能导致满盘皆输。&lt;/p&gt;
&lt;p&gt;要缓解这个问题，通常有两种思路：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;提升模型能力&lt;/strong&gt;：通过使用更深、更复杂的模型架构（例如，从3层网络变成30层网络）和更大规模的训练数据，让模型本身在每一步做出正确预测的概率大大提高。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;改进解码策略&lt;/strong&gt;：使用更复杂的解码算法，如 &lt;strong&gt;束搜索（Beam Search）&lt;/strong&gt;。它在每一步都会保留多个（而不是一个）最可能的候选序列，并在最后选择整体概率最高的序列作为最终输出，从而在全局上找到更优的解，避免“一步错，步步错”的陷阱。&lt;/p&gt;
&lt;h2 id="四pytorch-代码实现与分析"&gt;四、PyTorch 代码实现与分析
&lt;/h2&gt;&lt;h3 id="41-标准的-encoder-decoder"&gt;4.1 标准的 Encoder-Decoder
&lt;/h3&gt;&lt;p&gt;首先，来构建模型的基础骨架，编码器、解码器以及将它们组合在一起的 Seq2Seq 包装器。&lt;/p&gt;
&lt;h4 id="411-编码器-encoder"&gt;4.1.1 编码器 (Encoder)
&lt;/h4&gt;&lt;p&gt;编码器的职责是读取输入序列并生成上下文向量。在这个示例实现中，将单向 LSTM 的最终隐藏状态 &lt;code&gt;hidden&lt;/code&gt; 和细胞状态 &lt;code&gt;cell&lt;/code&gt; 直接作为上下文，传递给解码器。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x shape: (batch_size, seq_length)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 返回最终的隐藏状态和细胞状态作为上下文&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（1）&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;self.embedding&lt;/code&gt;: 定义词嵌入层，将输入的词元ID（整数）映射为稠密的 &lt;code&gt;hidden_size&lt;/code&gt; 维度向量。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;self.rnn&lt;/code&gt;: 定义 LSTM 层。&lt;code&gt;input_size&lt;/code&gt; 和 &lt;code&gt;hidden_size&lt;/code&gt; 均为 &lt;code&gt;hidden_size&lt;/code&gt;，因为词嵌入向量的维度与 LSTM 隐藏状态的维度在此设计中保持一致。此处为简化演示选择单向（&lt;code&gt;bidirectional=False&lt;/code&gt;）；实际工程中编码器常使用双向 RNN 以获取更充分的上下文，需要将双向状态（如拼接/线性映射）转换为解码器的初始状态。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（2）&lt;strong&gt;&lt;code&gt;forward(self, x)&lt;/code&gt;&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;输入 &lt;code&gt;x&lt;/code&gt; 是一个形状为 &lt;code&gt;(batch_size, seq_length)&lt;/code&gt; 的张量，代表了一批句子的词元ID序列。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;embedded = self.embedding(x)&lt;/code&gt;: 输入经过词嵌入层，形状变为 &lt;code&gt;(batch_size, seq_length, hidden_size)&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;_, (hidden, cell) = self.rnn(embedded)&lt;/code&gt;: &lt;code&gt;self.rnn&lt;/code&gt; 处理整个嵌入序列后，会返回两个内容：
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;outputs&lt;/code&gt;: 包含了序列中&lt;strong&gt;每一个时间步&lt;/strong&gt;的隐藏状态。对于编码器而言，中间步骤的输出通常不被使用，因此用 &lt;code&gt;_&lt;/code&gt; 接收。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;(hidden, cell)&lt;/code&gt;: 一个元组，包含了整个序列&lt;strong&gt;最后一个时间步&lt;/strong&gt;的隐藏状态和细胞状态。这正是我们需要的、概括了整个输入序列信息的&lt;strong&gt;上下文向量&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;code&gt;return hidden, cell&lt;/code&gt;: 函数最终返回这两个状态，作为上下文传递给解码器。这种实现方式对应了 &lt;code&gt;2.2&lt;/code&gt; 节中描述的最经典的做法，即直接使用编码器最后一个时间步的状态作为上下文向量 $C$。&lt;/li&gt;
&lt;/ul&gt;
&lt;h4 id="412-解码器-decoder"&gt;4.1.2 解码器 (Decoder)
&lt;/h4&gt;&lt;p&gt;解码器在每一步接收一个词元和前一步的状态，然后输出预测和新的状态。这个实现体现了为&lt;strong&gt;高效推理&lt;/strong&gt;而设计的&lt;strong&gt;单步前向传播&lt;/strong&gt;逻辑，即 &lt;code&gt;forward&lt;/code&gt; 函数一次只处理一个时间步。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# x shape: (batch_size)，只包含当前时间步的token&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; (batch_size, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 接收上一步的状态 (hidden, cell)，计算当前步&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# -&amp;gt; (batch_size, vocab_size)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;predictions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（1）&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;self.embedding&lt;/code&gt; 和 &lt;code&gt;self.rnn&lt;/code&gt;: 与编码器中的定义类似。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;self.fc&lt;/code&gt;: 增加了一个全连接层（&lt;code&gt;Linear&lt;/code&gt;），它的作用是将 LSTM 输出的 &lt;code&gt;hidden_size&lt;/code&gt; 维度的隐藏状态，映射到 &lt;code&gt;vocab_size&lt;/code&gt; 维度的向量上。这个向量的每一个元素对应词汇表中一个词的得分（logit），后续可以通过 Softmax 函数转换为概率。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（2）&lt;strong&gt;&lt;code&gt;forward(self, x, hidden, cell)&lt;/code&gt;&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;这是一个&lt;strong&gt;单步&lt;/strong&gt;的前向传播函数，其输入 &lt;code&gt;x&lt;/code&gt; 是一个形状为 &lt;code&gt;(batch_size,)&lt;/code&gt; 的张量，仅包含当前时间步的词元ID。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;x = x.unsqueeze(1)&lt;/code&gt;: 为了适应 &lt;code&gt;nn.Embedding&lt;/code&gt; 和 &lt;code&gt;nn.LSTM&lt;/code&gt; 对输入形状（需要有序列长度维度）的要求，需要给 &lt;code&gt;x&lt;/code&gt; 增加一个长度为1的“伪序列”维度，使其形状变为 &lt;code&gt;(batch_size, 1)&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;embedded = self.embedding(x)&lt;/code&gt;: 词元经过嵌入，形状变为 &lt;code&gt;(batch_size, 1, hidden_size)&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;outputs, (hidden, cell) = self.rnn(embedded, (hidden, cell))&lt;/code&gt;: 解码器的 RNN 接收两个输入：当前步的嵌入向量 &lt;code&gt;embedded&lt;/code&gt;，以及&lt;strong&gt;上一步&lt;/strong&gt;传递过来的隐藏状态 &lt;code&gt;(hidden, cell)&lt;/code&gt;。它只进行一步计算，然后返回当前步的输出 &lt;code&gt;outputs&lt;/code&gt; 和更新后的状态 &lt;code&gt;(hidden, cell)&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;predictions = self.fc(outputs.squeeze(1))&lt;/code&gt;: RNN 的输出 &lt;code&gt;outputs&lt;/code&gt; 形状是 &lt;code&gt;(batch_size, 1, hidden_size)&lt;/code&gt;，需要用 &lt;code&gt;squeeze(1)&lt;/code&gt; 移除长度为1的序列维度，再送入全连接层，得到形状为 &lt;code&gt;(batch_size, vocab_size)&lt;/code&gt; 的最终预测。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;return predictions, hidden, cell&lt;/code&gt;: 返回当前步的预测，以及更新后的状态，用于下一步的计算。&lt;/li&gt;
&lt;/ul&gt;
&lt;h4 id="413-seq2seq-包装模块"&gt;4.1.3 Seq2Seq 包装模块
&lt;/h4&gt;&lt;p&gt;这个包装模块将编码器和解码器连接起来，并负责实现&lt;strong&gt;训练&lt;/strong&gt;时的逻辑，特别是&lt;strong&gt;教师强制&lt;/strong&gt;。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Seq2Seq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Seq2Seq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;teacher_forcing_ratio&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;out_features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 第一个输入是 &amp;lt;SOS&amp;gt;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 决定是否使用 Teacher Forcing&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;teacher_force&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;teacher_forcing_ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;top1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 如果 teacher_force，下一个输入是真实值；否则是模型的预测值&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;teacher_force&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;top1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;code&gt;forward&lt;/code&gt; 函数接收源序列 &lt;code&gt;src&lt;/code&gt; (形状 &lt;code&gt;(batch_size, src_len)&lt;/code&gt;) 和目标序列 &lt;code&gt;trg&lt;/code&gt; (形状 &lt;code&gt;(batch_size, trg_len)&lt;/code&gt;)，并模拟了训练过程中的一个批次计算：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;初始化&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;outputs = torch.zeros(...)&lt;/code&gt;: 创建一个形状为 &lt;code&gt;(batch_size, trg_len, vocab_size)&lt;/code&gt; 的全零张量，用于存储解码器在每一个时间步的输出 logits。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;hidden, cell = self.encoder(src)&lt;/code&gt;: 调用编码器处理源序列 &lt;code&gt;src&lt;/code&gt;，得到初始的上下文向量。&lt;code&gt;hidden&lt;/code&gt; 和 &lt;code&gt;cell&lt;/code&gt; 的形状均为 &lt;code&gt;(num_layers, batch_size, hidden_size)&lt;/code&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（2）&lt;strong&gt;启动解码&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;input = trg[:, 0]&lt;/code&gt;: 取出目标序列 &lt;code&gt;trg&lt;/code&gt; 的第一个词元（通常是 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt; 标志），作为解码器循环的起始输入。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（3）&lt;strong&gt;循环解码&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;for t in range(1, trg_len)&lt;/code&gt;: 循环从第二个词元（索引为1）开始，直到目标序列结束。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;output, hidden, cell = self.decoder(input, hidden, cell)&lt;/code&gt;: 调用解码器执行单步计算。它接收形状为 &lt;code&gt;(batch_size)&lt;/code&gt; 的 &lt;code&gt;input&lt;/code&gt; 和上一时刻的状态，返回当前步的预测 &lt;code&gt;output&lt;/code&gt; 和更新后的状态。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;outputs[:, t, :] = output&lt;/code&gt;: 将当前步的预测存入 &lt;code&gt;outputs&lt;/code&gt; 张量中。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（4）&lt;strong&gt;教师强制&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;teacher_force = random.random() &amp;lt; teacher_forcing_ratio&lt;/code&gt;: 以一定的概率决定是否启用教师强制。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;top1 = output.argmax(1)&lt;/code&gt;: 找出当前步预测概率最高的词元ID，得到形状为 &lt;code&gt;(batch_size)&lt;/code&gt; 的张量 &lt;code&gt;top1&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;input = trg[:, t] if teacher_force else top1&lt;/code&gt;: 这是教师强制的关键。根据 &lt;code&gt;teacher_force&lt;/code&gt; 的值，选择真实的下一个词元 &lt;code&gt;trg[:, t]&lt;/code&gt; 或模型自己的预测 &lt;code&gt;top1&lt;/code&gt; 作为下一步的输入。无论哪种情况，下一步的 &lt;code&gt;input&lt;/code&gt; 形状都将是 &lt;code&gt;(batch_size)&lt;/code&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（5）&lt;strong&gt;返回&lt;/strong&gt;: 最终返回 &lt;code&gt;outputs&lt;/code&gt; 张量，其形状为 &lt;code&gt;(batch_size, trg_len, vocab_size)&lt;/code&gt;，用于后续与真实标签计算损失。&lt;/p&gt;
&lt;h3 id="42-高效的推理实现"&gt;4.2 高效的推理实现
&lt;/h3&gt;&lt;p&gt;在推理时，模型必须以自回归模式运行。一个最直接的实现方式是在生成每个新词元时，都将&lt;strong&gt;已生成的完整序列&lt;/strong&gt;重新喂给解码器。例如，生成第3个词时，将 &lt;code&gt;&amp;lt;SOS&amp;gt;, y'_1, y'_2&lt;/code&gt; 作为解码器输入。&lt;/p&gt;
&lt;p&gt;这种方式虽然逻辑简单，但会导致严重的&lt;strong&gt;重复计算&lt;/strong&gt;。RNN在处理 &lt;code&gt;y'_2&lt;/code&gt; 时，会&lt;strong&gt;重新&lt;/strong&gt;计算 &lt;code&gt;&amp;lt;SOS&amp;gt;&lt;/code&gt; 和 &lt;code&gt;y'_1&lt;/code&gt; 对应的隐藏状态，而这些状态在上一步其实已经计算过了。随着序列变长，这种浪费会越来越严重，导致推理效率极低。&lt;/p&gt;
&lt;p&gt;正确的做法是利用 RNN 的“记忆”能力，&lt;strong&gt;缓存并传递状态&lt;/strong&gt;，避免重复计算。我们设计的 &lt;code&gt;Decoder&lt;/code&gt; 每次只处理一个时间步，正是为了支持这种高效模式。在推理时，只需将&lt;strong&gt;上一步的输出词元&lt;/strong&gt;和&lt;strong&gt;上一步的隐藏状态&lt;/strong&gt;传入解码器，进行单步计算，然后用返回的新状态覆盖旧状态即可。&lt;code&gt;Seq2Seq&lt;/code&gt; 类中的 &lt;code&gt;greedy_decode&lt;/code&gt; 方法展示了这一过程：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# ... 在 Seq2Seq 类中 ...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;greedy_decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sos_idx&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eos_idx&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;推理模式下的高效贪心解码。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;sos_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 输入只有上一个时刻的词元&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LongTensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;trg_indexes&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 解码一步，并传入上一步的状态&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trg_tensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 获取当前步的预测，并更新状态用于下一步&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pred_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pred_token&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;pred_token&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;eos_idx&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;这种方式通过&lt;strong&gt;状态的传递与更新&lt;/strong&gt;避免了重复计算：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;&lt;code&gt;hidden, cell = self.encoder(src)&lt;/code&gt;&lt;/strong&gt;: 在循环开始前，只调用一次编码器，获取初始上下文。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;循环内部&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(self.device)&lt;/code&gt;: 每次的输入&lt;strong&gt;仅仅是上一步生成的最后一个词元&lt;/strong&gt; &lt;code&gt;trg_indexes[-1]&lt;/code&gt;，而不是整个序列。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;output, hidden, cell = self.decoder(trg_tensor, hidden, cell)&lt;/code&gt;: 将这个单词元输入和&lt;strong&gt;上一步的 &lt;code&gt;hidden&lt;/code&gt;, &lt;code&gt;cell&lt;/code&gt; 状态&lt;/strong&gt;送入解码器。解码器只执行一步计算，并返回&lt;strong&gt;新的 &lt;code&gt;hidden&lt;/code&gt;, &lt;code&gt;cell&lt;/code&gt; 状态&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;这两个新状态会&lt;strong&gt;覆盖&lt;/strong&gt;旧的状态变量，并在下一次循环中被用作输入。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;通过这种方式，信息流和状态在时间步之间平稳地传递，每个时间步都只进行一次必要的计算。&lt;/p&gt;
&lt;h3 id="43-上下文向量的另一种用法"&gt;4.3 上下文向量的另一种用法
&lt;/h3&gt;&lt;p&gt;除了将上下文向量用作解码器的初始状态外，还可以将其作为解码器&lt;strong&gt;每个时间步的额外输入&lt;/strong&gt;。这种方式可以持续地为解码器提供全局信息。下面是这种变体解码器的实现。注意 &lt;code&gt;rnn&lt;/code&gt; 层的输入维度和 &lt;code&gt;forward&lt;/code&gt; 函数中的&lt;strong&gt;拼接&lt;/strong&gt;操作。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DecoderAlt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;DecoderAlt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 主要改动 1: RNN的输入维度是 词嵌入+上下文向量&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_ctx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 主要改动 2: 将上下文向量与当前输入拼接&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 这里简单地取编码器最后一层的 hidden state 作为上下文代表&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_ctx&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;rnn_input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 解码器的初始状态 hidden, cell 在第一步可设为零；之后需传递并更新上一步状态&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rnn_input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;predictions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（1）&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;:
- &lt;code&gt;self.rnn = nn.LSTM(...)&lt;/code&gt;: 这里的主要改动是 &lt;code&gt;input_size=hidden_size + hidden_size&lt;/code&gt;。因为在每个时间步，输入给 LSTM 的不再仅仅是词嵌入向量（维度 &lt;code&gt;hidden_size&lt;/code&gt;），而是&lt;strong&gt;词嵌入向量&lt;/strong&gt;与&lt;strong&gt;上下文向量&lt;/strong&gt;（维度也是 &lt;code&gt;hidden_size&lt;/code&gt;）拼接后的新向量，因此输入维度加倍。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;&lt;code&gt;forward(self, x, hidden_ctx, hidden, cell)&lt;/code&gt;&lt;/strong&gt;:
- &lt;code&gt;context = hidden_ctx[-1].unsqueeze(1).repeat(1, embedded.shape[1], 1)&lt;/code&gt;: 这一步是为了准备用于拼接的上下文向量。
- &lt;code&gt;rnn_input = torch.cat((embedded, context), dim=2)&lt;/code&gt;: 核心操作，在最后一个维度（特征维度）上，将词嵌入向量和上下文向量拼接起来，形成 RNN 的最终输入。
- &lt;code&gt;outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))&lt;/code&gt;: 将拼接后的向量送入 RNN。注意，这里传入的 &lt;code&gt;(hidden, cell)&lt;/code&gt; 是解码器自身的上一步状态（初始是零向量），而不是编码器传来的上下文 &lt;code&gt;hidden_ctx&lt;/code&gt;。上下文信息已经通过输入端注入了。&lt;/p&gt;
&lt;h3 id="完整代码"&gt;完整代码
&lt;/h3&gt;&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt; 10
&lt;/span&gt;&lt;span class="lnt"&gt; 11
&lt;/span&gt;&lt;span class="lnt"&gt; 12
&lt;/span&gt;&lt;span class="lnt"&gt; 13
&lt;/span&gt;&lt;span class="lnt"&gt; 14
&lt;/span&gt;&lt;span class="lnt"&gt; 15
&lt;/span&gt;&lt;span class="lnt"&gt; 16
&lt;/span&gt;&lt;span class="lnt"&gt; 17
&lt;/span&gt;&lt;span class="lnt"&gt; 18
&lt;/span&gt;&lt;span class="lnt"&gt; 19
&lt;/span&gt;&lt;span class="lnt"&gt; 20
&lt;/span&gt;&lt;span class="lnt"&gt; 21
&lt;/span&gt;&lt;span class="lnt"&gt; 22
&lt;/span&gt;&lt;span class="lnt"&gt; 23
&lt;/span&gt;&lt;span class="lnt"&gt; 24
&lt;/span&gt;&lt;span class="lnt"&gt; 25
&lt;/span&gt;&lt;span class="lnt"&gt; 26
&lt;/span&gt;&lt;span class="lnt"&gt; 27
&lt;/span&gt;&lt;span class="lnt"&gt; 28
&lt;/span&gt;&lt;span class="lnt"&gt; 29
&lt;/span&gt;&lt;span class="lnt"&gt; 30
&lt;/span&gt;&lt;span class="lnt"&gt; 31
&lt;/span&gt;&lt;span class="lnt"&gt; 32
&lt;/span&gt;&lt;span class="lnt"&gt; 33
&lt;/span&gt;&lt;span class="lnt"&gt; 34
&lt;/span&gt;&lt;span class="lnt"&gt; 35
&lt;/span&gt;&lt;span class="lnt"&gt; 36
&lt;/span&gt;&lt;span class="lnt"&gt; 37
&lt;/span&gt;&lt;span class="lnt"&gt; 38
&lt;/span&gt;&lt;span class="lnt"&gt; 39
&lt;/span&gt;&lt;span class="lnt"&gt; 40
&lt;/span&gt;&lt;span class="lnt"&gt; 41
&lt;/span&gt;&lt;span class="lnt"&gt; 42
&lt;/span&gt;&lt;span class="lnt"&gt; 43
&lt;/span&gt;&lt;span class="lnt"&gt; 44
&lt;/span&gt;&lt;span class="lnt"&gt; 45
&lt;/span&gt;&lt;span class="lnt"&gt; 46
&lt;/span&gt;&lt;span class="lnt"&gt; 47
&lt;/span&gt;&lt;span class="lnt"&gt; 48
&lt;/span&gt;&lt;span class="lnt"&gt; 49
&lt;/span&gt;&lt;span class="lnt"&gt; 50
&lt;/span&gt;&lt;span class="lnt"&gt; 51
&lt;/span&gt;&lt;span class="lnt"&gt; 52
&lt;/span&gt;&lt;span class="lnt"&gt; 53
&lt;/span&gt;&lt;span class="lnt"&gt; 54
&lt;/span&gt;&lt;span class="lnt"&gt; 55
&lt;/span&gt;&lt;span class="lnt"&gt; 56
&lt;/span&gt;&lt;span class="lnt"&gt; 57
&lt;/span&gt;&lt;span class="lnt"&gt; 58
&lt;/span&gt;&lt;span class="lnt"&gt; 59
&lt;/span&gt;&lt;span class="lnt"&gt; 60
&lt;/span&gt;&lt;span class="lnt"&gt; 61
&lt;/span&gt;&lt;span class="lnt"&gt; 62
&lt;/span&gt;&lt;span class="lnt"&gt; 63
&lt;/span&gt;&lt;span class="lnt"&gt; 64
&lt;/span&gt;&lt;span class="lnt"&gt; 65
&lt;/span&gt;&lt;span class="lnt"&gt; 66
&lt;/span&gt;&lt;span class="lnt"&gt; 67
&lt;/span&gt;&lt;span class="lnt"&gt; 68
&lt;/span&gt;&lt;span class="lnt"&gt; 69
&lt;/span&gt;&lt;span class="lnt"&gt; 70
&lt;/span&gt;&lt;span class="lnt"&gt; 71
&lt;/span&gt;&lt;span class="lnt"&gt; 72
&lt;/span&gt;&lt;span class="lnt"&gt; 73
&lt;/span&gt;&lt;span class="lnt"&gt; 74
&lt;/span&gt;&lt;span class="lnt"&gt; 75
&lt;/span&gt;&lt;span class="lnt"&gt; 76
&lt;/span&gt;&lt;span class="lnt"&gt; 77
&lt;/span&gt;&lt;span class="lnt"&gt; 78
&lt;/span&gt;&lt;span class="lnt"&gt; 79
&lt;/span&gt;&lt;span class="lnt"&gt; 80
&lt;/span&gt;&lt;span class="lnt"&gt; 81
&lt;/span&gt;&lt;span class="lnt"&gt; 82
&lt;/span&gt;&lt;span class="lnt"&gt; 83
&lt;/span&gt;&lt;span class="lnt"&gt; 84
&lt;/span&gt;&lt;span class="lnt"&gt; 85
&lt;/span&gt;&lt;span class="lnt"&gt; 86
&lt;/span&gt;&lt;span class="lnt"&gt; 87
&lt;/span&gt;&lt;span class="lnt"&gt; 88
&lt;/span&gt;&lt;span class="lnt"&gt; 89
&lt;/span&gt;&lt;span class="lnt"&gt; 90
&lt;/span&gt;&lt;span class="lnt"&gt; 91
&lt;/span&gt;&lt;span class="lnt"&gt; 92
&lt;/span&gt;&lt;span class="lnt"&gt; 93
&lt;/span&gt;&lt;span class="lnt"&gt; 94
&lt;/span&gt;&lt;span class="lnt"&gt; 95
&lt;/span&gt;&lt;span class="lnt"&gt; 96
&lt;/span&gt;&lt;span class="lnt"&gt; 97
&lt;/span&gt;&lt;span class="lnt"&gt; 98
&lt;/span&gt;&lt;span class="lnt"&gt; 99
&lt;/span&gt;&lt;span class="lnt"&gt;100
&lt;/span&gt;&lt;span class="lnt"&gt;101
&lt;/span&gt;&lt;span class="lnt"&gt;102
&lt;/span&gt;&lt;span class="lnt"&gt;103
&lt;/span&gt;&lt;span class="lnt"&gt;104
&lt;/span&gt;&lt;span class="lnt"&gt;105
&lt;/span&gt;&lt;span class="lnt"&gt;106
&lt;/span&gt;&lt;span class="lnt"&gt;107
&lt;/span&gt;&lt;span class="lnt"&gt;108
&lt;/span&gt;&lt;span class="lnt"&gt;109
&lt;/span&gt;&lt;span class="lnt"&gt;110
&lt;/span&gt;&lt;span class="lnt"&gt;111
&lt;/span&gt;&lt;span class="lnt"&gt;112
&lt;/span&gt;&lt;span class="lnt"&gt;113
&lt;/span&gt;&lt;span class="lnt"&gt;114
&lt;/span&gt;&lt;span class="lnt"&gt;115
&lt;/span&gt;&lt;span class="lnt"&gt;116
&lt;/span&gt;&lt;span class="lnt"&gt;117
&lt;/span&gt;&lt;span class="lnt"&gt;118
&lt;/span&gt;&lt;span class="lnt"&gt;119
&lt;/span&gt;&lt;span class="lnt"&gt;120
&lt;/span&gt;&lt;span class="lnt"&gt;121
&lt;/span&gt;&lt;span class="lnt"&gt;122
&lt;/span&gt;&lt;span class="lnt"&gt;123
&lt;/span&gt;&lt;span class="lnt"&gt;124
&lt;/span&gt;&lt;span class="lnt"&gt;125
&lt;/span&gt;&lt;span class="lnt"&gt;126
&lt;/span&gt;&lt;span class="lnt"&gt;127
&lt;/span&gt;&lt;span class="lnt"&gt;128
&lt;/span&gt;&lt;span class="lnt"&gt;129
&lt;/span&gt;&lt;span class="lnt"&gt;130
&lt;/span&gt;&lt;span class="lnt"&gt;131
&lt;/span&gt;&lt;span class="lnt"&gt;132
&lt;/span&gt;&lt;span class="lnt"&gt;133
&lt;/span&gt;&lt;span class="lnt"&gt;134
&lt;/span&gt;&lt;span class="lnt"&gt;135
&lt;/span&gt;&lt;span class="lnt"&gt;136
&lt;/span&gt;&lt;span class="lnt"&gt;137
&lt;/span&gt;&lt;span class="lnt"&gt;138
&lt;/span&gt;&lt;span class="lnt"&gt;139
&lt;/span&gt;&lt;span class="lnt"&gt;140
&lt;/span&gt;&lt;span class="lnt"&gt;141
&lt;/span&gt;&lt;span class="lnt"&gt;142
&lt;/span&gt;&lt;span class="lnt"&gt;143
&lt;/span&gt;&lt;span class="lnt"&gt;144
&lt;/span&gt;&lt;span class="lnt"&gt;145
&lt;/span&gt;&lt;span class="lnt"&gt;146
&lt;/span&gt;&lt;span class="lnt"&gt;147
&lt;/span&gt;&lt;span class="lnt"&gt;148
&lt;/span&gt;&lt;span class="lnt"&gt;149
&lt;/span&gt;&lt;span class="lnt"&gt;150
&lt;/span&gt;&lt;span class="lnt"&gt;151
&lt;/span&gt;&lt;span class="lnt"&gt;152
&lt;/span&gt;&lt;span class="lnt"&gt;153
&lt;/span&gt;&lt;span class="lnt"&gt;154
&lt;/span&gt;&lt;span class="lnt"&gt;155
&lt;/span&gt;&lt;span class="lnt"&gt;156
&lt;/span&gt;&lt;span class="lnt"&gt;157
&lt;/span&gt;&lt;span class="lnt"&gt;158
&lt;/span&gt;&lt;span class="lnt"&gt;159
&lt;/span&gt;&lt;span class="lnt"&gt;160
&lt;/span&gt;&lt;span class="lnt"&gt;161
&lt;/span&gt;&lt;span class="lnt"&gt;162
&lt;/span&gt;&lt;span class="lnt"&gt;163
&lt;/span&gt;&lt;span class="lnt"&gt;164
&lt;/span&gt;&lt;span class="lnt"&gt;165
&lt;/span&gt;&lt;span class="lnt"&gt;166
&lt;/span&gt;&lt;span class="lnt"&gt;167
&lt;/span&gt;&lt;span class="lnt"&gt;168
&lt;/span&gt;&lt;span class="lnt"&gt;169
&lt;/span&gt;&lt;span class="lnt"&gt;170
&lt;/span&gt;&lt;span class="lnt"&gt;171
&lt;/span&gt;&lt;span class="lnt"&gt;172
&lt;/span&gt;&lt;span class="lnt"&gt;173
&lt;/span&gt;&lt;span class="lnt"&gt;174
&lt;/span&gt;&lt;span class="lnt"&gt;175
&lt;/span&gt;&lt;span class="lnt"&gt;176
&lt;/span&gt;&lt;span class="lnt"&gt;177
&lt;/span&gt;&lt;span class="lnt"&gt;178
&lt;/span&gt;&lt;span class="lnt"&gt;179
&lt;/span&gt;&lt;span class="lnt"&gt;180
&lt;/span&gt;&lt;span class="lnt"&gt;181
&lt;/span&gt;&lt;span class="lnt"&gt;182
&lt;/span&gt;&lt;span class="lnt"&gt;183
&lt;/span&gt;&lt;span class="lnt"&gt;184
&lt;/span&gt;&lt;span class="lnt"&gt;185
&lt;/span&gt;&lt;span class="lnt"&gt;186
&lt;/span&gt;&lt;span class="lnt"&gt;187
&lt;/span&gt;&lt;span class="lnt"&gt;188
&lt;/span&gt;&lt;span class="lnt"&gt;189
&lt;/span&gt;&lt;span class="lnt"&gt;190
&lt;/span&gt;&lt;span class="lnt"&gt;191
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;random&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 1. 全局配置&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;src_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trg_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;src_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trg_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;120&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;sos_idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="c1"&gt;# 句子起始标记索引&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;eos_idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="c1"&gt;# 句子结束标记索引&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 2. 模型定义&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;编码器: 读取输入序列，输出上下文向量（隐藏状态和细胞状态）。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 输入特征维度&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 隐藏状态维度&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;解码器（标准实现）: 接收上一个预测的token和当前状态，单步输出预测和新状态。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;predictions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Seq2Seq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;Seq2Seq 包装模块: 管理 Encoder 和 Decoder。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Seq2Seq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;teacher_forcing_ratio&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;训练模式下的前向传播，使用 Teacher Forcing。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;out_features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;teacher_force&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;teacher_forcing_ratio&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;top1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;teacher_force&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;top1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;greedy_decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;推理模式下的高效贪心解码。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;sos_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LongTensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;trg_indexes&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trg_tensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pred_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pred_token&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;pred_token&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;eos_idx&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 3. 变体模型定义&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DecoderAlt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;解码器变体: 不用上下文初始化状态，而是在每步将其作为输入。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;DecoderAlt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 注意：这里的输入维度是两个 hidden_size 拼接而成&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_ctx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_ctx&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;rnn_input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rnn_input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;predictions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 4. 解码策略&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;alternative_greedy_decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;&amp;#34;&amp;#34;配合 DecoderAlt 的解码实现。&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_ctx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell_ctx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;sos_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 初始化解码器的&amp;#34;真实&amp;#34;状态为0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LongTensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;trg_indexes&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trg_tensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_ctx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pred_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pred_token&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;pred_token&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;eos_idx&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;trg_indexes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 5. 主流程&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="vm"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;__main__&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;cuda&amp;#34;&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cpu&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Using device: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# --- 统一初始化模型 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;encoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Encoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;decoder&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Decoder&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trg_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Seq2Seq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# --- 统一创建伪数据 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;src&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src_len&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;trg&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg_len&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# =========================================&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1: 标准训练 &amp;amp; 高效推理&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# =========================================&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;=&amp;#34;&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="mi"&gt;25&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34; 1: 标准模式 &amp;#34;&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;=&amp;#34;&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 训练过程模拟 (Teacher Forcing)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;teacher_forcing_ratio&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;训练模式输出张量形状: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 推理过程模拟 (高效的自回归)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;prediction&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;greedy_decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;高效推理的预测结果: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;prediction&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# =========================================&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2: 上下文向量的另一种用法&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# =========================================&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;=&amp;#34;&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="mi"&gt;23&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34; 2: 上下文变体用法 &amp;#34;&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;=&amp;#34;&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="mi"&gt;23&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;decoder_alt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;DecoderAlt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trg_vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;prediction_alt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alternative_greedy_decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;encoder&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decoder_alt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;src&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;变体用法预测结果: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;prediction_alt&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h2 id="五应用与泛化"&gt;五、应用与泛化
&lt;/h2&gt;&lt;p&gt;Seq2Seq 架构的成功也揭示了其背后 Encoder-Decoder 框架的强大通用性。这个框架本质上定义了一个“将一种数据形态转换为另一种数据形态”的通用范式，因此其应用远不止于文本到文本的任务。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;语音识别（Audio-to-Text）&lt;/strong&gt;：编码器可以是一个处理音频信号的模型（如基于RNN或卷积的模型），提取语音特征并生成上下文向量；解码器则基于此向量生成识别出的文本序列。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;图像描述生成（Image-to-Text）&lt;/strong&gt;：编码器也可以是一个卷积神经网络（CNN），负责“阅读”整张图片并提取其视觉特征，生成一个概括图片内容的上下文向量；解码器则根据该向量生成一段描述性的文字，实现“看图说话”。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;文本到语音（Text-to-Speech, TTS）&lt;/strong&gt;：与语音识别相反，编码器处理输入文本，解码器则生成对应的音频波形数据。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;问答系统（QA）&lt;/strong&gt;：模型可以将一篇参考文章和用户提问一起编码，然后解码生成问题的答案。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;任务范式统一&lt;/strong&gt;：甚至传统的分类任务也可以被“生成化”。例如，在文本分类任务中，可以构造一个特殊的输入（即 Prompt），引导模型直接生成类别名称。这种方式极大地统一了不同 NLP 任务的处理范式。一个具体的例子如下：
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;输入&lt;/strong&gt;: &lt;code&gt;&amp;quot;请判断以下文本的类别。可选类别列表为：[科技, 体育, 财经]。文本：中国队在世界游泳锦标赛上获得了五枚金牌。&amp;quot;&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;期望输出&lt;/strong&gt;: &lt;code&gt;&amp;quot;体育&amp;quot;&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;通过替换不同的编码器和解码器实现，Seq2Seq 架构可以灵活地应用于各种跨模态的转换任务中。&lt;/p&gt;
&lt;h2 id="六seq2seq-的局限性信息瓶颈"&gt;六、Seq2Seq 的局限性：信息瓶颈
&lt;/h2&gt;&lt;p&gt;尽管基于 Seq2Seq 架构的模型取得了巨大成功，但它也存在一个明显的缺陷——&lt;strong&gt;信息瓶颈（Information Bottleneck）&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;这个问题在概念上与前一章讨论&lt;strong&gt;长距离依赖&lt;/strong&gt;非常相似，但发生在不同的层面：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;长距离依赖&lt;/strong&gt;是 RNN &lt;strong&gt;内部&lt;/strong&gt;的问题，指信息在&lt;strong&gt;单一序列处理过程&lt;/strong&gt;中因梯度累乘而难以从序列开端传递到末端。LSTM 通过门控机制和细胞状态缓解了这个问题。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;信息瓶颈&lt;/strong&gt;则是 Encoder-Decoder &lt;strong&gt;架构层面&lt;/strong&gt;的问题。它与 RNN 内部如何传递信息无关，而在于它规定了编码器和解码器之间唯一的沟通桥梁就是一个&lt;strong&gt;固定长度&lt;/strong&gt;的上下文向量 $C$。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;编码器必须将输入序列的所有信息，无论其长短，都压缩到这个向量中。可以说，编码器自身的长距离依赖问题，进一步加剧了信息瓶颈的严重性。即便编码器使用了 LSTM，能更好地在内部传递信息，但当输入句子很长时，这个最终的上下文向量 $C$ 依然很难承载全部的语义细节，模型可能会“遗忘”掉句子开头的关键信息，导致生成质量下降。这就好比让一个人将一篇长文的所有细节都总结成&lt;strong&gt;一句话&lt;/strong&gt;，然后仅凭这一句话去复述原文，必然会丢失大量信息。&lt;/p&gt;
&lt;p&gt;可以用一个更具体的例子来理解这个问题。假设在做一个对联生成的任务，上联是“两个黄鹂鸣翠柳”。在生成下联时，期望第一个词（如“一行”）能够主要参考上联的第一个词“两个”，第二个词（如“白鹭”）主要参考“黄鹂”，以此类推，形成对仗。&lt;/p&gt;
&lt;p&gt;不过，在标准的 Seq2Seq 架构中，存在两个核心问题：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;信息稀释&lt;/strong&gt;：“两个”这个词的信息经过多步 RNN 传递后，在最终的上下文向量 $C$ 中可能已经变得非常微弱。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;信息无差别（缺乏倾向性）&lt;/strong&gt;：解码器在生成每一个词（“一行”、“白鹭”、“上青天”）时，所依赖的全局信息都是同一个、包含了整个上联概要的上下文向量 $C$。它没有一种机制去“特别关注”或“倾向于”当前生成位置所对应的输入部分。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;即使采用 &lt;code&gt;2.4.2&lt;/code&gt; 讨论的第二种方式，将 $C$ 作为解码器&lt;strong&gt;每个时间步的额外输入&lt;/strong&gt;，问题依然存在。因为每个时间步输入的都是&lt;strong&gt;同一个&lt;/strong&gt; $C$，模型仍然无法学会有选择性地、有侧重地利用输入信息，缺乏这种动态的“倾向性”。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;为了解决这个信息瓶颈和对齐问题，后续研究者们引入了&lt;strong&gt;注意力机制&lt;/strong&gt;。允许解码器在生成每个词元时，都能“回头看”并动态地计算一个权重分布，从而重点关注输入序列的不同部分，而不是仅仅依赖于单一的上下文向量。这极大地提升了长序列任务的性能，并直接催生了后来更强大的 Transformer 模型。&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id="参考文献"&gt;参考文献
&lt;/h2&gt;&lt;div class="footnotes" role="doc-endnotes"&gt;
&lt;hr&gt;
&lt;ol&gt;
&lt;li id="fn:1"&gt;
&lt;p&gt;&lt;a class="link" href="https://proceedings.neurips.cc/paper/2014/hash/a14ac55a4f27472c5d894ec1c3c743d2-Abstract.html" target="_blank" rel="noopener"
&gt;Sutskever, I., Vinyals, O., &amp;amp; Le, Q. V. (2014). &lt;em&gt;Sequence to sequence learning with neural networks&lt;/em&gt;. Advances in Neural Information Processing Systems, 27.&lt;/a&gt;&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:2"&gt;
&lt;p&gt;&lt;a class="link" href="https://aclanthology.org/D14-1179/" target="_blank" rel="noopener"
&gt;Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau,D., Bougares, F., Schwenk, H., &amp;amp; Bengio, Y. (2014). &lt;em&gt;Learning phrase representations using RNN encoder-decoder for statistical machine translation&lt;/em&gt;. Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP).&lt;/a&gt;&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li id="fn:3"&gt;
&lt;p&gt;&lt;a class="link" href="https://proceedings.neurips.cc/paper/2015/hash/e995f98d56967d946471af29d7bf9a1e-Abstract.html" target="_blank" rel="noopener"
&gt;Bengio, S., Vinyals, O., Jaitly, N., &amp;amp; Shazeer, N. (2015). &lt;em&gt;Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks&lt;/em&gt;. Advances in Neural Information Processing Systems (NeurIPS).&lt;/a&gt;&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink"&gt;&amp;#x21a9;&amp;#xfe0e;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;/div&gt;</description></item></channel></rss>