<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>文本分类 on 酒中仙</title><link>https://hanguangwu.github.io/blog/tags/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB/</link><description>Recent content in 文本分类 on 酒中仙</description><generator>Hugo -- gohugo.io</generator><language>zh-cn</language><copyright>hanguangwu</copyright><lastBuildDate>Wed, 25 Mar 2026 18:34:25 -0800</lastBuildDate><atom:link href="https://hanguangwu.github.io/blog/tags/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB/index.xml" rel="self" type="application/rss+xml"/><item><title>微调 BERT 模型进行文本分类</title><link>https://hanguangwu.github.io/blog/p/%E5%BE%AE%E8%B0%83-bert-%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB/</link><pubDate>Wed, 25 Mar 2026 18:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/%E5%BE%AE%E8%B0%83-bert-%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB/</guid><description>&lt;h1 id="微调-bert-模型进行文本分类"&gt;微调 BERT 模型进行文本分类
&lt;/h1&gt;&lt;p&gt;回顾前两节的内容，我们依次实现了一个基于全连接网络的“词袋”模型和一个基于 LSTM 的序列模型。这两次的运行结果揭示了一个有趣的现象，对于当前的新闻分类任务，结构更复杂的 LSTM 模型在性能上并未超越更简单的全连接模型。这说明，对于这个特定任务，捕捉“关键词”比分析“词序”更关键。不过，这两种模型都是从零开始训练的，它们对语言的理解完全依赖于我们提供的小规模 &lt;code&gt;20 Newsgroups&lt;/code&gt; 数据集。那么我们能否利用在更大、更通用语料库上预先学到的知识，来帮助模型更好地理解文本，从而提升分类性能呢？&lt;/p&gt;
&lt;p&gt;这就是&lt;strong&gt;预训练语言模型&lt;/strong&gt;，特别是 BERT，所要解决的问题。在&lt;strong&gt;第五章第一节&lt;/strong&gt;中，我们已经学习了 BERT 的原理。它通过在海量原始文本上以&lt;strong&gt;自监督&lt;/strong&gt;的方式构造“掩码语言模型”和“下一句预测”等训练任务（无需人工标注标签），学习到了丰富的语言学知识和世界知识。本节是文本分类系列实战的最后一站。我们将把模型架构迁移为 BERT，探索从“从零训练”到“微调”这一范式转变所带来的性能提升。&lt;/p&gt;
&lt;h2 id="一从序列建模到预训练微调"&gt;一、从“序列建模”到“预训练微调”
&lt;/h2&gt;&lt;p&gt;回顾前两个模型，它们的核心都是在特定任务数据上从随机初始化的词向量开始学习如何进行分类。而基于 BERT 的微调则采用了一种完全不同的范式，通常包含以下三个步骤：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;加载预训练权重&lt;/strong&gt;：我们不再随机初始化模型，而是加载一个已经在海量数据（如维基百科、书籍）上训练好的 BERT 模型。这个模型已经是一个通用的“语言理解专家”。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;附加任务相关的“头”&lt;/strong&gt;：在 BERT 模型的主体结构之上，我们添加一个简单的、未经训练的分类层（通常就是一个全连接层）。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;在下游任务上“微调”&lt;/strong&gt;：使用我们的新闻分类数据，对整个模型（或者仅仅是顶部的分类层）进行训练。由于 BERT 部分已经具备了强大的语言理解能力，整个模型可以很快地适应新的分类任务，并且通常只需要很少的训练轮次和较小的学习率。&lt;/p&gt;
&lt;p&gt;这个“预训练-微调”的范式是现代 NLP 领域最主流、最有效的方法之一。它大大降低了对特定任务标注数据的依赖，并显著提升了模型性能的上限。&lt;/p&gt;
&lt;h2 id="二代码修改实践"&gt;二、代码修改实践
&lt;/h2&gt;&lt;p&gt;将 LSTM 模型改造为 BERT 模型，同样遵循之前的思路，主要修改涉及数据处理和模型结构，同时也要相应地调整训练超参数。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/datawhalechina/base-nlp/blob/main/code/C7/03_bert_text_classification.ipynb" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="21-替换为-berttokenizer"&gt;2.1 替换为 &lt;code&gt;BertTokenizer&lt;/code&gt;
&lt;/h3&gt;&lt;p&gt;现在我们不需要手动构建词典。&lt;code&gt;transformers&lt;/code&gt; 库为每个预训练模型都提供了配套的 &lt;code&gt;Tokenizer&lt;/code&gt;。对于英文 &lt;code&gt;20 Newsgroups&lt;/code&gt; 数据集，选择 &lt;code&gt;bert-base-uncased&lt;/code&gt; 模型及其对应的分词器。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;bert_model_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;bert-base-uncased&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;bert_model_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 查看特殊token&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;UNK token: &amp;#39;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unk_token&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#39;, ID: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unk_token_id&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;PAD token: &amp;#39;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#39;, ID: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token_id&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;CLS token: &amp;#39;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls_token&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#39;, ID: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls_token_id&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;SEP token: &amp;#39;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sep_token&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#39;, ID: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sep_token_id&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Vocab size: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;UNK token: &lt;span class="s1"&gt;&amp;#39;[UNK]&amp;#39;&lt;/span&gt;, ID: &lt;span class="m"&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;PAD token: &lt;span class="s1"&gt;&amp;#39;[PAD]&amp;#39;&lt;/span&gt;, ID: &lt;span class="m"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;CLS token: &lt;span class="s1"&gt;&amp;#39;[CLS]&amp;#39;&lt;/span&gt;, ID: &lt;span class="m"&gt;101&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;SEP token: &lt;span class="s1"&gt;&amp;#39;[SEP]&amp;#39;&lt;/span&gt;, ID: &lt;span class="m"&gt;102&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Vocab size: &lt;span class="m"&gt;30522&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;code&gt;BertTokenizer&lt;/code&gt; 会自动处理文本的预处理（如小写转换、标点分割），并为文本添加特殊的 &lt;code&gt;[CLS]&lt;/code&gt; 和 &lt;code&gt;[SEP]&lt;/code&gt; 标记。其中，&lt;strong&gt;&lt;code&gt;[CLS]&lt;/code&gt;&lt;/strong&gt; 位于序列开头，它在 BERT 输出中对应的向量通常被用作整个序列的聚合表示，非常适合用于分类任务；&lt;strong&gt;&lt;code&gt;[SEP]&lt;/code&gt;&lt;/strong&gt; 用于分隔两个句子，在单句分类任务中标志着句子的结束。&lt;/p&gt;
&lt;h3 id="22-改造-dataset-与-collate_fn"&gt;2.2 改造 &lt;code&gt;Dataset&lt;/code&gt; 与 &lt;code&gt;collate_fn&lt;/code&gt;
&lt;/h3&gt;&lt;p&gt;为了适配 BERT，数据处理流程需要进行如下调整：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;&lt;code&gt;Dataset&lt;/code&gt;&lt;/strong&gt;: &lt;code&gt;BertTextClassificationDataset&lt;/code&gt; 现在直接调用 &lt;code&gt;BertTokenizer&lt;/code&gt; 来进行分词和ID转换。处理长文本的滑窗分割逻辑保持不变。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;BertTextClassificationDataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;total&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;desc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Processing Dataset&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 直接使用BertTokenizer进行编码&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;encoding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;add_special_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;encoding&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;input_ids&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 滑窗分割逻辑保持不变&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;input_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;stride&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stride&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;chunk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;input_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;chunk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（2）&lt;strong&gt;&lt;code&gt;collate_fn&lt;/code&gt;&lt;/strong&gt;: BERT 的一个重要输入是 &lt;code&gt;attention_mask&lt;/code&gt;（注意力掩码）。它是一个与 &lt;code&gt;input_ids&lt;/code&gt; 形状相同的张量，用 &lt;code&gt;1&lt;/code&gt; 标记真实 Token，用 &lt;code&gt;0&lt;/code&gt; 标记填充（Padding）的 Token。模型会根据这个掩码，在计算注意力时忽略填充部分。所以，我们需要修改 &lt;code&gt;collate_fn&lt;/code&gt; 以生成并返回 &lt;code&gt;attention_mask&lt;/code&gt;。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;bert_collate_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_batch_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;input_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_attention_masks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;input_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_batch_len&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padded_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token_id&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 新增：生成 attention_mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_input_ids&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;padded_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_attention_masks&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attention_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_labels&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;input_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 新增：返回 attention_mask&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;attention_mask&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_attention_masks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;labels&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="23-构建-textclassifierbert-模型"&gt;2.3 构建 &lt;code&gt;TextClassifierBERT&lt;/code&gt; 模型
&lt;/h3&gt;&lt;p&gt;得益于 &lt;code&gt;transformers&lt;/code&gt; 库的高度封装，从代码实现的角度来看，新的模型结构非常简洁，一个预训练的 BERT 主干网络 + 一个线性分类头。尽管 BERT 模型内部结构极其复杂，但我们只需几行代码便可调用。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;BertModel&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TextClassifierBERT&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;freeze_bert&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;TextClassifierBERT&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 加载预训练的BERT模型&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bert&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertModel&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. 定义分类头&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bert&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. (可选) 冻结BERT参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;freeze_bert&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;param&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bert&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;param&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="kc"&gt;False&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attention_mask&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将输入传入BERT&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bert&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attention_mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;attention_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 使用[CLS] token的输出(pooler_output)进行分类&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pooled_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pooler_output&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 传入分类头得到logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pooled_output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;在 &lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt; 方法中，&lt;code&gt;BertModel.from_pretrained(model_name)&lt;/code&gt; 会自动下载并加载指定名称的预训练模型权重。分类头的输入维度直接从 &lt;code&gt;self.bert.config.hidden_size&lt;/code&gt; 获取，这是一个非常好的实践，避免了硬编码。此外提供了 &lt;code&gt;freeze_bert&lt;/code&gt; 选项，如果为 &lt;code&gt;True&lt;/code&gt;，则 BERT 部分的参数不会在训练中更新。这被称为“特征提取”模式，训练速度更快，但效果通常不如全量微调。实战中还可以更细粒度地“冻结”部分层（例如仅冻结 Embedding 和前几层 Transformer Block，或按层号前缀选择参数，将其 &lt;code&gt;requires_grad=False&lt;/code&gt;），在&lt;strong&gt;训练速度 / 显存占用&lt;/strong&gt;与&lt;strong&gt;微调效果&lt;/strong&gt;之间做折中，这里为了示例清晰，仅展示了“全部冻结 BERT 主干”这一简单形式。而在 &lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt; 函数中，现在接收 &lt;code&gt;input_ids&lt;/code&gt; 和 &lt;code&gt;attention_mask&lt;/code&gt;。BERT 模型的输出 &lt;code&gt;outputs&lt;/code&gt; 中，&lt;code&gt;outputs.pooler_output&lt;/code&gt; 是 &lt;code&gt;[CLS]&lt;/code&gt; Token 对应的隐藏状态经过进一步处理后得到的向量，专门用于句子级别的任务，我们直接取用这个向量送入分类层即可。&lt;/p&gt;
&lt;h3 id="24-调整-trainer-与-predictor"&gt;2.4 调整 &lt;code&gt;Trainer&lt;/code&gt; 与 &lt;code&gt;Predictor&lt;/code&gt;
&lt;/h3&gt;&lt;p&gt;&lt;code&gt;Trainer&lt;/code&gt; 的 &lt;code&gt;_run_epoch&lt;/code&gt; 和 &lt;code&gt;_evaluate&lt;/code&gt; 方法需要修改，以将 &lt;code&gt;attention_mask&lt;/code&gt; 传递给模型。同时，保存模型的逻辑也应更新为 &lt;code&gt;transformers&lt;/code&gt; 推荐的方式。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 在 Trainer._run_epoch 方法中&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;input_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;attention_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;attention_mask&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 新增&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;labels&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attention_mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;attention_mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 修改&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 在 Trainer._save_checkpoint 方法中&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 对于transformers模型，推荐使用save_pretrained来保存&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bert&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;save_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 单独保存分类头&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;classifier.pth&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;save&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;classifier_path&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;另外，别忘了在训练脚本中将 &lt;code&gt;tokenizer&lt;/code&gt; 一并保存到同一个目录，方便推理阶段直接从该目录恢复分词器配置与词表，例如：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 训练脚本中&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;save_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;code&gt;Predictor&lt;/code&gt; 的逻辑与 LSTM 版本非常相似，同样采用&lt;strong&gt;分块+投票&lt;/strong&gt;的策略。主要区别在于，现在需要为每个 &lt;code&gt;chunk&lt;/code&gt; 创建对应的 &lt;code&gt;attention_mask&lt;/code&gt;。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 在 Predictor.predict 方法中&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# (分块逻辑不变)...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;padded_chunks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attention_masks&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;chunk&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;chunks&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_chunk_len&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padded_chunks&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token_id&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;attention_masks&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 新增&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;input_ids_tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;padded_chunks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;attention_mask_tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attention_masks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 新增&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 将input_ids和attention_mask都传入&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;input_ids_tensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;attention_mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;attention_mask_tensor&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# (投票逻辑不变)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="25-更新训练超参数"&gt;2.5 更新训练超参数
&lt;/h3&gt;&lt;p&gt;微调 BERT 时，超参数的选择与从零训练有很大不同。&lt;strong&gt;学习率&lt;/strong&gt;通常设置得非常小，例如 &lt;code&gt;2e-5&lt;/code&gt; 到 &lt;code&gt;5e-5&lt;/code&gt; 之间，这是因为我们希望在预训练学到的知识基础上做“微小”的调整，过大的学习率会破坏这些知识。至于&lt;strong&gt;训练轮次&lt;/strong&gt;通常只需要 3-5 个轮次就足以收敛。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;span class="lnt"&gt;9
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;hparams&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;model_name&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;bert-base-uncased&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;num_classes&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;freeze_bert&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;epochs&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 减少轮次&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;learning_rate&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;2e-5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 降低学习率&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cuda&amp;#34;&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cpu&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;output_bert&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="26-推理阶段资源加载"&gt;2.6 推理阶段资源加载
&lt;/h3&gt;&lt;p&gt;推理阶段的整体流程与 LSTM 版本保持一致，但在&lt;strong&gt;加载推理所需资源&lt;/strong&gt;时有几个容易忽略的细节。训练时我们使用 &lt;code&gt;save_pretrained&lt;/code&gt; 将 BERT 主干和 tokenizer 一并保存到 &lt;code&gt;output_bert&lt;/code&gt; 目录，所以推理阶段不需要手动构建或加载词表，可以应该直接从该目录恢复。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;labels_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;label_map.json&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;labels_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;utf-8&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;label_map_loaded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;json&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;inference_tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;inference_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassifierBERT&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model_name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;label_map_loaded&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;classifier.pth&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;inference_model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;classifier_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;map_location&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;这里有两个关键点：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;通过 &lt;code&gt;BertTokenizer.from_pretrained&lt;/code&gt; 从输出目录一次性恢复完整的分词器配置与词表，无需手动加载独立的 vocab 文件；&lt;/li&gt;
&lt;li&gt;&lt;code&gt;TextClassifierBERT&lt;/code&gt; 的 &lt;code&gt;model_name&lt;/code&gt; 也改为输出目录，从而加载微调后的 BERT 权重。&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id="三实验结果与分析"&gt;三、实验结果与分析
&lt;/h2&gt;&lt;p&gt;完成所有改造后，我们启动训练。由于 BERT 模型参数量远大于之前的模型（&lt;code&gt;bert-base-uncased&lt;/code&gt; 约有1.1亿参数），每个 &lt;code&gt;epoch&lt;/code&gt; 的训练时间会更长，对计算资源（特别是 GPU 显存）的要求也更高。下面是本次实验的训练日志：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;1&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 625/625 &lt;span class="o"&gt;[&lt;/span&gt;01:53&amp;lt;00:00, 5.53it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;1&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 467/467 &lt;span class="o"&gt;[&lt;/span&gt;00:29&amp;lt;00:00, 16.02it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 1/5 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.4214 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8738
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新最佳模型已保存! Epoch: 1, 验证集准确率: 0.8738
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;2&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 625/625 &lt;span class="o"&gt;[&lt;/span&gt;02:14&amp;lt;00:00, 4.64it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;2&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 467/467 &lt;span class="o"&gt;[&lt;/span&gt;00:31&amp;lt;00:00, 14.84it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 2/5 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.1495 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8827
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新最佳模型已保存! Epoch: 2, 验证集准确率: 0.8827
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;3&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 625/625 &lt;span class="o"&gt;[&lt;/span&gt;02:18&amp;lt;00:00, 4.51it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;3&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 467/467 &lt;span class="o"&gt;[&lt;/span&gt;00:31&amp;lt;00:00, 14.66it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 3/5 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0698 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8915
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新最佳模型已保存! Epoch: 3, 验证集准确率: 0.8915
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;4&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 625/625 &lt;span class="o"&gt;[&lt;/span&gt;02:20&amp;lt;00:00, 4.46it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;4&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 467/467 &lt;span class="o"&gt;[&lt;/span&gt;00:32&amp;lt;00:00, 14.44it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 4/5 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0404 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8946
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新最佳模型已保存! Epoch: 4, 验证集准确率: 0.8946
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;5&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 625/625 &lt;span class="o"&gt;[&lt;/span&gt;02:22&amp;lt;00:00, 4.37it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;5&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 467/467 &lt;span class="o"&gt;[&lt;/span&gt;00:32&amp;lt;00:00, 14.48it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 5/5 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0251 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.9032
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新最佳模型已保存! Epoch: 5, 验证集准确率: 0.9032
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;训练完成！
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Tokenizer 和标签映射 &lt;span class="o"&gt;(&lt;/span&gt;output_bert&lt;span class="se"&gt;\l&lt;/span&gt;abel_map.json&lt;span class="o"&gt;)&lt;/span&gt; 已保存。
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/7_3_1.png" width="70%" alt="BERT 模型训练过程指标" /&gt;
&lt;br /&gt;
&lt;em&gt;图 7-6 BERT 模型训练损失与验证集准确率变化曲线&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;我们将三个模型的最佳性能进行对比：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;全连接模型 (基线)&lt;/strong&gt;：最佳验证集准确率 &lt;strong&gt;~0.8469&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;LSTM 模型 (正则化后)&lt;/strong&gt;：最佳验证集准确率 &lt;strong&gt;~0.8415&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;BERT 微调模型&lt;/strong&gt;：最佳验证集准确率 &lt;strong&gt;~0.9032&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;结果分析&lt;/strong&gt;:&lt;/p&gt;
&lt;p&gt;通过日志可以看出，BERT 模型的性能远超前两个从零开始训练的模型。具体来看，BERT的优势体现在以下几点：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;性能上限更高&lt;/strong&gt;：BERT 微调模型最终达到了约 &lt;strong&gt;90.32%&lt;/strong&gt; 的准确率，比之前两个模型高出&lt;strong&gt;超过 5 个百分点&lt;/strong&gt;，这是一个显著的提升。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;收敛速度快&lt;/strong&gt;：仅仅在第一个 &lt;code&gt;epoch&lt;/code&gt; 结束后，BERT 模型的准确率（87.38%）就已经超过了前两个模型经过 20 个 &lt;code&gt;epoch&lt;/code&gt; 充分训练后的最佳水平。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;强大的上下文理解能力&lt;/strong&gt;：BERT 的核心是 Transformer 的自注意力机制，它能够捕捉句子中任意两个词之间的依赖关系，无论它们相隔多远。使得 BERT 能够生成真正“上下文相关”的词向量，深刻理解词语在不同语境下的含义。&lt;/p&gt;
&lt;p&gt;（4）&lt;strong&gt;海量预训练知识的迁移&lt;/strong&gt;：BERT 在预训练阶段已经学习了丰富的语法、语义和世界知识。在微调时，这些知识被有效地迁移到了下游的新闻分类任务中。模型不再是一个“新生儿”，而是一个知识渊博的“专家”，只需要少量数据就能学会如何应用已有知识来完成新任务。&lt;/p&gt;
&lt;p&gt;（5）&lt;strong&gt;成熟的范式&lt;/strong&gt;：相比于需要精心设计网络结构、调整正则化策略的从零训练，BERT 的“预训练-微调”范式更加成熟和标准化。它为各种 NLP 任务提供了一个更高的起点，通过这种方式我们能够用相对少的代码和调试，就达到了出色的效果。&lt;/p&gt;
&lt;p&gt;这个结果初看似乎与上一节“如无必要，勿增实体”的结论有所矛盾。但这并不意味着奥卡姆剃刀原理失效了，而是提醒我们要在正确的维度上应用它。当然，这并不绝对否定简单模型（如全连接或 LSTM）通过更精细的特征工程、算法优化和超参数调优，&lt;strong&gt;有可能&lt;/strong&gt;在特定任务上接近甚至超越 BERT 的效果。但是，那通常需要耗费巨大的精力。相比之下，BERT 的成功揭示了“预训练-微调”范式的巨大优势。一方面它具备&lt;strong&gt;强大的预训练知识&lt;/strong&gt;，BERT 不是从零学习，而是将从海量文本中学到的通用语言知识&lt;strong&gt;迁移&lt;/strong&gt;到了我们的任务中，所以它对词汇和语境的理解深度远超任何从头训练的模型。另一方面它是一条&lt;strong&gt;更便捷的路径&lt;/strong&gt;，我们不再需要为特定任务从头设计复杂的网络或特征，而是可以方便地在一个强大的通用模型基础上进行微调，用更少的努力达到更高的性能上限。&lt;/p&gt;
&lt;p&gt;所以，这里的结论并非“模型越复杂越好”，而是“&lt;strong&gt;利用高质量的预训练模型进行微调，往往是一种在下游任务中以更少的开发精力和数据量达成更好性能的推荐方式&lt;/strong&gt;”。与此同时，我们也必须认识到，微调所带来的便捷和高效，是建立在 BERT 等大模型在预训练阶段已经消耗了巨大计算资源和时间的基础之上的。&lt;/p&gt;
&lt;h2 id="本章小结"&gt;本章小结
&lt;/h2&gt;&lt;p&gt;综合三节的实践，我们完成了一次 NLP 文本分类任务的探索之旅。整个过程为我们提供了宝贵的实践经验，并最终指向一个现代 NLP 项目中进行模型选择与迭代的常用流程：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;从一个简单、快速的基线模型开始&lt;/strong&gt;（如第一节的全连接模型）。建立基线有助于我们评估任务的难度，并为后续的优化提供一个比较标准。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;审慎地增加模型的复杂性&lt;/strong&gt;。第二节的实验证明，对于特定任务，更复杂的结构（从零训练的LSTM）未必能带来性能提升，其结果恰好印证了“奥卡姆剃刀原理”。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;优先考虑利用高质量的预训练模型进行微调&lt;/strong&gt;。当基线模型无法满足需求时，与其从零开始构建更复杂的模型，不如优先采用“预训练-微调”的范式。第三节的实验清晰地展示了，采用此范式通常是通往较优性能的高效路径。&lt;/p&gt;
&lt;p&gt;整个探索过程也侧面反映了 NLP 技术的发展脉络，也提供了一个更全面、辩证的实践准则，帮助我们在未来的项目中做出更明智的技术选型。&lt;/p&gt;
&lt;h2 id="附录使用bert实现中文文本情感分类"&gt;附录——使用BERT实现中文文本情感分类
&lt;/h2&gt;&lt;p&gt;参考资料：&lt;/p&gt;
&lt;p&gt;&lt;a class="link" href="https://blog.csdn.net/l01011_/article/details/145895561" target="_blank" rel="noopener"
&gt;微调BERT模型实现文本分类&lt;/a&gt;&lt;/p&gt;
&lt;h3 id="bert文本处理tokenizer"&gt;BERT文本处理——Tokenizer
&lt;/h3&gt;&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;bert-base-chinese&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_encode_plus&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_text_or_text_pairs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;今天我要努力学习自然语言处理&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;明天我要认真学习金融学&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;truncation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;max_length&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_length&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;pt&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;input_ids&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;input_ids&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;input_ids torch.Size([2, 15])
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;token_type_ids torch.Size([2, 15])
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;attention_mask torch.Size([2, 15])
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[CLS] 今 天 我 要 努 力 学 习 自 然 语 言 处 [SEP]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[CLS] 明 天 我 要 认 真 学 习 金 融 学 [SEP] [PAD] [PAD]
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="未微调时使用-text-classification-pipeline"&gt;未微调时使用 text-classification pipeline
&lt;/h3&gt;&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;BertForSequenceClassification&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pipeline&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertForSequenceClassification&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;bert-base-chinese&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;bert-base-chinese&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pipeline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;text-classification&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;我今天心情很好&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;你好，我是AI助手&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;我今天很生气&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: [&amp;#39;classifier.bias&amp;#39;, &amp;#39;classifier.weight&amp;#39;]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Device set to use cuda:0
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[{&amp;#39;label&amp;#39;: &amp;#39;LABEL_2&amp;#39;, &amp;#39;score&amp;#39;: 0.5192456841468811}]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[{&amp;#39;label&amp;#39;: &amp;#39;LABEL_2&amp;#39;, &amp;#39;score&amp;#39;: 0.502333402633667}]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[{&amp;#39;label&amp;#39;: &amp;#39;LABEL_2&amp;#39;, &amp;#39;score&amp;#39;: 0.47641533613204956}]
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="微调后使用-text-classification-pipeline"&gt;微调后使用 text-classification pipeline
&lt;/h3&gt;&lt;p&gt;首先我们需要通过同样的方法来构建基础模型。然后通过语料样本来进行微调。这里我们使用lansinuote/ChnSentiCorp数据集。&lt;/p&gt;
&lt;p&gt;lansinuote/ChnSentiCorp数据集是一个用于中文情感分析的数据集。该数据集汇集了来自网络平台的多样化评论数据，主要覆盖三大领域：酒店住宿体验、笔记本电脑使用评价以及书籍阅读感受。数据集分为训练集、验证集和测试集。其中，训练集包含约 9600 条数据，验证集和测试集各包含约 1200 条数据。每条数据包含一段评论文本和对应的情感标签，情感标签通常为二分类（如好评、差评），部分版本可能包含中性标签。&lt;/p&gt;
&lt;p&gt;将lansinuote/ChnSentiCorp数据集下载之后，使用模型的分词器对其进行处理，将处理之后的数据放入模型进行训练，我们仅训练1轮看看效果。训练完之后再测试集上进行预测查看训练效果。并将模型保存。实现代码如下。&lt;/p&gt;
&lt;p&gt;模型预训练过程：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;BertForSequenceClassification&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Trainer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;TrainingArguments&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;load_dataset&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;re&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;bert-base-chinese&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;mode&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertForSequenceClassification&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;bert-base-chinese&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_labels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;load_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;lansinuote/ChnSentiCorp&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;clean_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sub&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;[^\w\s]+&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39; &amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;strip&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;text&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;clean_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;text&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;label&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;label&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;tokenize_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;examples&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;examples&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;text&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;padding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;max_length&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_length&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;encoded_dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenize_function&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batched&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;training_args&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TrainingArguments&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;./results&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_train_epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;per_device_train_batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;per_device_eval_batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;eval_strategy&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;epoch&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logging_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;./logs&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trainer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Trainer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mode&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;training_args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;train_dataset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;encoded_dataset&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;train&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;eval_dataset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;encoded_dataset&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;validation&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trainer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trainer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;encoded_dataset&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;test&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;metric_key_prefix&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;eval&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;mode&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;save_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;./sentiment_model&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;save_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;./sentiment_model&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;保存模型后进行情感分类：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;BertForSequenceClassification&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pipeline&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;mode_dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;./sentiment_model&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertForSequenceClassification&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mode_dir&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;BertTokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mode_dir&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;classifier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pipeline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;text-classification&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;我今天心情很好&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;你好，我是AI助手&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;我今天很生气&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出结果如下所示：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Device set to use cuda:0
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[{&amp;#39;label&amp;#39;: &amp;#39;LABEL_1&amp;#39;, &amp;#39;score&amp;#39;: 0.7787742614746094}]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[{&amp;#39;label&amp;#39;: &amp;#39;LABEL_0&amp;#39;, &amp;#39;score&amp;#39;: 0.5039134621620178}]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;[{&amp;#39;label&amp;#39;: &amp;#39;LABEL_0&amp;#39;, &amp;#39;score&amp;#39;: 0.9373114109039307}]
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;效果明显好了很多。&lt;/p&gt;</description></item><item><title>基于 LSTM 的文本分类</title><link>https://hanguangwu.github.io/blog/p/%E5%9F%BA%E4%BA%8E-lstm-%E7%9A%84%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB/</link><pubDate>Wed, 25 Mar 2026 17:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/%E5%9F%BA%E4%BA%8E-lstm-%E7%9A%84%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB/</guid><description>&lt;h1 id="基于-lstm-的文本分类"&gt;基于 LSTM 的文本分类
&lt;/h1&gt;&lt;p&gt;在上一节，我们实现了一个基于全连接层的文本分类模型。该模型虽然简单有效，但它的核心是将所有词元的特征向量进行平均池化，这本质上是一种“词袋”模型。这种方法的&lt;strong&gt;一个显著局限是它忽略了文本中词语的顺序&lt;/strong&gt;，而语序在多数 NLP 任务中是很重要的。那么，对于文本分类任务，捕捉序列信息是否总能带来性能提升呢？为了验证这一点，我们自然会想到循环神经网络（RNN）及其变体，如LSTM。在&lt;strong&gt;第三章第二节&lt;/strong&gt;中我们已经学习了 LSTM 的原理。理论上，它能够通过处理序列信息来捕捉更丰富的语义。本节将进行一次实验，我们将上一节的全连接模型改造为基于LSTM的模型，来&lt;strong&gt;探索&lt;/strong&gt;在本新闻分类任务上，序列建模是否会比简单的词袋模型更有效。&lt;/p&gt;
&lt;h2 id="一从词袋到序列建模"&gt;一、从“词袋”到序列建模
&lt;/h2&gt;&lt;p&gt;先回顾一下基线模型的主要操作：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;词嵌入&lt;/strong&gt;：将输入的 &lt;code&gt;token_ids&lt;/code&gt; (&lt;code&gt;[batch_size, seq_len]&lt;/code&gt;) 转换为词向量 &lt;code&gt;embedded&lt;/code&gt; (&lt;code&gt;[batch_size, seq_len, embed_dim]&lt;/code&gt;)。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;特征提取&lt;/strong&gt;：通过几层全连接网络，将每个词向量独立地映射到更高维的特征空间，得到 &lt;code&gt;token_features&lt;/code&gt; (&lt;code&gt;[batch_size, seq_len, hidden_dim]&lt;/code&gt;)。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;掩码平均池化&lt;/strong&gt;：为了处理变长序列，将所有 &lt;code&gt;token_features&lt;/code&gt; 沿 &lt;code&gt;seq_len&lt;/code&gt; 维度进行求和，再除以真实长度，得到一个代表整句话的向量 &lt;code&gt;pooled_features&lt;/code&gt; (&lt;code&gt;[batch_size, hidden_dim]&lt;/code&gt;)。&lt;/p&gt;
&lt;p&gt;（4）&lt;strong&gt;分类&lt;/strong&gt;：将 &lt;code&gt;pooled_features&lt;/code&gt; 输入最后的分类层，得到最终预测。&lt;/p&gt;
&lt;p&gt;这个流程的瓶颈在第三步。&lt;strong&gt;平均池化&lt;/strong&gt;操作将序列信息压缩成一个向量，这可能导致词序信息的丢失。&lt;/p&gt;
&lt;p&gt;与之相对，LSTM 网络通过其内部的循环结构和门控机制，能够逐个处理序列中的词元，并持续更新一个内部状态（记忆）。这个状态在每个时间步都会编码从序列开始到当前位置的所有信息。因此，当 LSTM 处理完整个序列后，它最终的隐藏状态理论上包含了对整个句子序列更丰富的语义表示，这&lt;strong&gt;有可能&lt;/strong&gt;比简单的词向量平均更能捕捉句子的深层含义。&lt;/p&gt;
&lt;h2 id="二代码修改实践"&gt;二、代码修改实践
&lt;/h2&gt;&lt;p&gt;将基线模型改造为 LSTM 模型，主要涉及这三个部分的修改：数据处理、模型结构和推理逻辑。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/datawhalechina/base-nlp/blob/main/code/C7/02_lstm_text_classification.ipynb" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="21-改造-collate_fn-以提供序列长度"&gt;2.1 改造 &lt;code&gt;collate_fn&lt;/code&gt; 以提供序列长度
&lt;/h3&gt;&lt;p&gt;为了让 LSTM 能够高效地处理被填充（Padding）过的变长序列，需要使用 &lt;code&gt;torch.nn.utils.rnn.pack_padded_sequence&lt;/code&gt; 函数。该函数要求在输入批次中明确提供每个样本在填充前的&lt;strong&gt;真实长度&lt;/strong&gt;。所以，我们应该修改 &lt;code&gt;collate_fn&lt;/code&gt; 函数，让它在返回 &lt;code&gt;token_ids&lt;/code&gt; 和 &lt;code&gt;labels&lt;/code&gt; 的同时，也返回一个包含该批次中每个序列真实长度的张量 &lt;code&gt;lengths&lt;/code&gt;。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;collate_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_batch_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_token_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_lengths&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 新增：记录真实长度&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;lengths&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_batch_len&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;lengths&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padded_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_token_ids&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;padded_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_labels&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 新增：将长度加入列表&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_lengths&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lengths&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_token_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;labels&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 新增：返回长度张量&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;lengths&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_lengths&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="22-构建-textclassifierlstm-模型"&gt;2.2 构建 &lt;code&gt;TextClassifierLSTM&lt;/code&gt; 模型
&lt;/h3&gt;&lt;p&gt;这是本次优化的主要内容。我们将原来的 &lt;code&gt;TextClassifier&lt;/code&gt; 替换为一个新的 &lt;code&gt;TextClassifierLSTM&lt;/code&gt; 模型。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TextClassifierLSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;TextClassifierLSTM&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;padding_idx&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt; &lt;span class="c1"&gt;# 关键参数：输入和输出张量的维度为 (batch, seq, feature)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_directions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;bidirectional&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;num_directions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lengths&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 1. 打包序列&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;packed_embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;utils&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pack_padded_sequence&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;lengths&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cpu&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="c1"&gt;# 长度必须在CPU上&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;enforce_sorted&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 2. LSTM 前向传播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# hidden 和 cell 的形状: [n_layers * num_directions, batch_size, hidden_dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;packed_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;packed_embedded&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 3. 提取最终隐藏状态用于分类&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 拼接最后一个时间步的前向和后向的隐藏状态&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# hidden[-2,:,:] 是前向的最后一个隐藏状态&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# hidden[-1,:,:] 是后向的最后一个隐藏状态&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,:,:],&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,:,:]),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 只取最后一层的最后一个隐藏状态&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,:,:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 4. 分类&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;&lt;strong&gt;模型解析&lt;/strong&gt;:&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;&lt;code&gt;__init__&lt;/code&gt;&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;除了词嵌入层 &lt;code&gt;nn.Embedding&lt;/code&gt; 和分类层 &lt;code&gt;nn.Linear&lt;/code&gt;，核心是一个 &lt;code&gt;nn.LSTM&lt;/code&gt; 层。&lt;/li&gt;
&lt;li&gt;增加几个 LSTM 相关的超参数：&lt;code&gt;n_layers&lt;/code&gt; (LSTM层数), &lt;code&gt;dropout&lt;/code&gt; (层间丢弃率), &lt;code&gt;bidirectional&lt;/code&gt; (是否使用双向LSTM)。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;batch_first=True&lt;/code&gt; 是一个重要的设置，它让 LSTM 接受 &lt;code&gt;[batch_size, seq_len, feature_dim]&lt;/code&gt; 形状的输入，与 &lt;code&gt;DataLoader&lt;/code&gt; 的输出保持一致，简化了代码。&lt;/li&gt;
&lt;li&gt;分类层的输入维度需要根据 &lt;code&gt;bidirectional&lt;/code&gt; 的值来动态确定。如果是双向的，隐藏层维度会加倍。&lt;/li&gt;
&lt;li&gt;在 PyTorch 的 &lt;code&gt;nn.LSTM&lt;/code&gt; 中，&lt;code&gt;dropout&lt;/code&gt; 只在 &lt;code&gt;n_layers &amp;gt; 1&lt;/code&gt; 时于层间生效；当仅 1 层时该参数不会起作用。若使用单层 LSTM，可将 &lt;code&gt;dropout&lt;/code&gt; 设为 &lt;code&gt;0.0&lt;/code&gt;（或保留任意值，效果一致），避免造成误解。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;（2）&lt;strong&gt;&lt;code&gt;forward&lt;/code&gt;&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;forward&lt;/code&gt; 函数现在额外接收 &lt;code&gt;lengths&lt;/code&gt; 参数。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;打包 (Packing)&lt;/strong&gt;：&lt;code&gt;pack_padded_sequence&lt;/code&gt; 是处理填充序列的关键。它会将一个填充过的批次数据（例如，多个句子被填充到相同长度）压缩成一个更紧凑的表示，LSTM 只需对真实的、非填充部分进行计算，大大提高了效率和准确性。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;最终状态提取&lt;/strong&gt;：LSTM 的输出 &lt;code&gt;hidden&lt;/code&gt; 张量包含了所有层在最后一个时间步的隐藏状态。我们通常取最后一层（对于单向 LSTM 是 &lt;code&gt;hidden[-1,:,:]&lt;/code&gt;）作为整个序列的语义表示。如果是双向 LSTM，则需要拼接前向和后向的最终隐藏状态。&lt;/li&gt;
&lt;li&gt;最后，将这个代表序列的 &lt;code&gt;hidden&lt;/code&gt; 向量送入分类器。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="23-调整-trainer-和-predictor"&gt;2.3 调整 &lt;code&gt;Trainer&lt;/code&gt; 和 &lt;code&gt;Predictor&lt;/code&gt;
&lt;/h3&gt;&lt;p&gt;由于模型 &lt;code&gt;forward&lt;/code&gt; 函数的输入签名发生了变化，我们需要对 &lt;code&gt;Trainer&lt;/code&gt; 和 &lt;code&gt;Predictor&lt;/code&gt; 进行微调，以确保 &lt;code&gt;lengths&lt;/code&gt; 张量被正确传递。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;1. &lt;code&gt;Trainer&lt;/code&gt; 修改&lt;/strong&gt;:
在 &lt;code&gt;_run_epoch&lt;/code&gt; 和 &lt;code&gt;_evaluate&lt;/code&gt; 方法中，从 &lt;code&gt;batch&lt;/code&gt; 字典中取出 &lt;code&gt;lengths&lt;/code&gt;，并将其传递给 &lt;code&gt;self.model&lt;/code&gt;。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 在 Trainer._run_epoch 方法中&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;labels&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;lengths&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;lengths&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lengths&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;（&lt;code&gt;_evaluate&lt;/code&gt; 方法同理）&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;2. &lt;code&gt;Predictor&lt;/code&gt; 修改&lt;/strong&gt;:
&lt;code&gt;Predictor&lt;/code&gt; 在处理单个文本时，也需要模拟批处理的逻辑：对文本分块后，手动计算每个块的长度，并进行填充，然后将 &lt;code&gt;chunk_tensors&lt;/code&gt; 和 &lt;code&gt;length_tensors&lt;/code&gt; 一同传入模型。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 在 Predictor.predict 方法中&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# (文本分块逻辑不变) ...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;chunks&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 手动计算长度并进行填充&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;chunk_lengths&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;chunks&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;max_chunk_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk_lengths&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;chunk_lengths&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;padded_chunks&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;chunk&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;chunks&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_chunk_len&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padded_chunks&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;padded_chunks&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;无法预测（文本过短）&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;chunk_tensors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;padded_chunks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;length_tensors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk_lengths&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 长度在CPU上&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk_tensors&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;length_tensors&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="24-更新训练入口代码"&gt;2.4 更新训练入口代码
&lt;/h3&gt;&lt;p&gt;最后一步，更新用于启动训练的单元格。我们需要：&lt;/p&gt;
&lt;p&gt;（1）为 LSTM 添加新的超参数（&lt;code&gt;n_layers&lt;/code&gt;, &lt;code&gt;dropout&lt;/code&gt;, &lt;code&gt;bidirectional&lt;/code&gt;）。&lt;/p&gt;
&lt;p&gt;（2）实例化新的 &lt;code&gt;TextClassifierLSTM&lt;/code&gt; 模型。&lt;/p&gt;
&lt;p&gt;（3）（可选）为新的模型实验设置一个独立的输出目录，如 &lt;code&gt;&amp;quot;output_lstm&amp;quot;&lt;/code&gt;。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;hparams&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;vocab_size&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;embed_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;hidden_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;256&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;num_classes&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;n_layers&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 新增&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;dropout&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 新增：此处显式设为 0，当前不启用 Dropout&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;bidirectional&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 新增&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;epochs&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;learning_rate&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.001&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cuda&amp;#34;&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cpu&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;output_lstm&amp;#34;&lt;/span&gt; &lt;span class="c1"&gt;# 修改输出目录&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 实例化新模型&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassifierLSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;vocab_size&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;embed_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;hidden_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;num_classes&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;n_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;n_layers&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;dropout&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;bidirectional&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;bidirectional&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# (后续代码不变)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;完成以上修改后，重新运行整个 Notebook，即可训练一个能够处理序列信息的 LSTM 模型。接下来，我们来对比它与基线模型的性能，并分析序列建模在本次任务中的实际效果。&lt;/p&gt;
&lt;h3 id="25-实验结果与分析"&gt;2.5 实验结果与分析
&lt;/h3&gt;&lt;p&gt;在分别运行了基线的全连接模型和我们新构建的LSTM模型后（均未加正则化策略），我们得到了如下的性能数据：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;全连接模型 (基线)&lt;/strong&gt;：最终验证集最佳准确率约为 &lt;strong&gt;0.8469&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;LSTM 模型&lt;/strong&gt;：最终验证集最佳准确率约为 &lt;strong&gt;0.8143&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;p align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/7_2_1.png" width="70%" alt="LSTM 模型训练过程指标" /&gt;
&lt;br /&gt;
&lt;em&gt;图 7-4 LSTM 模型训练损失与验证集准确率变化曲线&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;结果分析&lt;/strong&gt;:&lt;/p&gt;
&lt;p&gt;显然结果并不符合我们的预期，理论上更能捕捉序列信息的 LSTM 模型，在本次新闻分类任务上的表现反而&lt;strong&gt;劣于&lt;/strong&gt;简单的全连接模型。这个发现说明&lt;strong&gt;模型的复杂性与任务的实际需求应该匹配&lt;/strong&gt;。值得一提的是，本次对比实验并未严格控制固定随机数种子，所以每次运行的结果会存在细微的差别。然而，一个稳定的现象是，引入序列建模的 LSTM 并未带来性能提升，其结果反而总是比简单的全连接模型低 &lt;strong&gt;~2%&lt;/strong&gt; 左右，足以让我们得出以下结论。&lt;/p&gt;
&lt;p&gt;出现这种结果的可能原因有两点：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;任务对语序相对不敏感&lt;/strong&gt;：在目前的数据规模、从零开始训练模型的前提下，这个新闻分类任务在很大程度上依赖于&lt;strong&gt;关键词&lt;/strong&gt;。例如，看到 &amp;ldquo;Jesus&amp;rdquo;、&amp;ldquo;God&amp;rdquo; 很可能属于宗教类；看到 &amp;ldquo;Graphics&amp;rdquo;、&amp;ldquo;Monitor&amp;rdquo; 很可能属于计算机图形类。全连接模型本质上是一个高效的“词袋”模型，非常擅长捕捉这类强特征词的存在与否。对于这个特定实验设置来说，&lt;strong&gt;“有哪些词”远比“这些词的顺序”更重要&lt;/strong&gt;。LSTM 为学习语序付出的额外努力，在这里并没有转化为实际的性能优势。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;模型复杂性与过拟合&lt;/strong&gt;：LSTM 模型比简单的全连接网络复杂得多，拥有更多的参数。虽然它理论上能学习到更复杂的模式，但也&lt;strong&gt;更容易在数据量不够大的情况下陷入过拟合&lt;/strong&gt;。从训练日志中可以看到，普通 LSTM 的训练损失已经非常低，但验证集准确率却不高，这是过拟合症状。模型过于“记住”了训练集中的特定句子结构，而没有学到普适的规律。&lt;/p&gt;
&lt;h2 id="三过拟合解决方案与效果对比"&gt;三、过拟合解决方案与效果对比
&lt;/h2&gt;&lt;p&gt;基础 LSTM 模型效果不佳的一个可能原因是 &lt;strong&gt;过拟合&lt;/strong&gt;。在第一节的末尾，我们介绍了三种简单有效的正则化方法分别是&lt;strong&gt;提前停止&lt;/strong&gt;、&lt;strong&gt;随机Token遮盖&lt;/strong&gt;和&lt;strong&gt;Dropout&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;现在，我们将这三种方法组合应用到新的 LSTM 模型上，观察它们的综合效果。&lt;/p&gt;
&lt;h3 id="31-随机token遮盖"&gt;3.1 随机Token遮盖
&lt;/h3&gt;&lt;p&gt;这是一种数据增强技术。我们在 &lt;code&gt;TextClassificationDataset&lt;/code&gt; 的基础上创建一个子类，在 &lt;code&gt;__getitem__&lt;/code&gt; 方法中，对训练样本的 &lt;code&gt;token_ids&lt;/code&gt; 进行随机替换。具体来说，以一定概率（例如10%）将部分词元替换为 &lt;code&gt;&amp;lt;UNK&amp;gt;&lt;/code&gt; 对应的 ID。使得模型不能过度依赖个别特征词，而是要从更广的上下文中学习语义，从而增强泛化能力。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;random&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TextClassificationDatasetWithMasking&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;TextClassificationDataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;is_train&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;is_train&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mask_prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mask_prob&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unk_token_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;token_to_id&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;get&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&amp;lt;UNK&amp;gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__getitem__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 关键：创建副本，避免修改原始数据&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;item&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__getitem__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_train&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;token_ids&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;masked_token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;token_id&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 不遮盖PAD (ID=0)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;token_id&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mask_prob&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;masked_token_ids&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unk_token_id&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;masked_token_ids&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_id&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;token_ids&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;masked_token_ids&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;blockquote&gt;
&lt;p&gt;在 &lt;code&gt;TextClassificationDatasetWithMasking&lt;/code&gt; 的 &lt;code&gt;__getitem__&lt;/code&gt; 方法中，有一个非常关键的细节，&lt;code&gt;item = super().__getitem__(idx).copy()&lt;/code&gt;。必须使用 &lt;code&gt;.copy()&lt;/code&gt; 方法来创建数据的副本。&lt;/p&gt;
&lt;p&gt;如果没有 &lt;code&gt;.copy()&lt;/code&gt;，&lt;code&gt;__getitem__&lt;/code&gt; 中的修改将会永久地改变原始数据集。这会导致在第二个训练周期（Epoch）时，模型看到的是已经被第一次随机遮盖过的数据，并在此基础上进行二次遮盖，如此循环往复，最终导致有效信息完全丢失。数据增强必须保证每一轮都是在干净的原始数据上进行的独立操作。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="32-提前停止-early-stopping"&gt;3.2 提前停止 (Early Stopping)
&lt;/h3&gt;&lt;p&gt;提前停止是一种简单而高效的正则化策略。其核心思想是在训练过程中持续监控模型在验证集上的性能。如果验证集准确率（或损失）连续 &lt;code&gt;N&lt;/code&gt; 个轮次（&lt;code&gt;N&lt;/code&gt; 称为“耐心值” &lt;code&gt;patience&lt;/code&gt;）没有超过历史最佳水平，就认为模型已经达到了最佳点或开始过拟合，此时应提前终止训练。我们在 &lt;code&gt;Trainer&lt;/code&gt; 类的基础上创建一个子类，重写 &lt;code&gt;train&lt;/code&gt; 方法以实现该逻辑。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;os&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;json&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TrainerWithEarlyStopping&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Trainer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;valid_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;.&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;patience&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;valid_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;patience&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;patience&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epochs_no_improve&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label_map&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;avg_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_run_epoch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;val_accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; | 训练损失: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;avg_loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.4f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; | 验证集准确率: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;val_accuracy&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.4f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;current_best&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_accuracy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_save_checkpoint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_accuracy&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_accuracy&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;current_best&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epochs_no_improve&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epochs_no_improve&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epochs_no_improve&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;patience&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;提前停止于 Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;，因为验证集准确率连续 &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;patience&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; 轮未提升。&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s2"&gt;训练完成！&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# ... (保存词典和标签映射)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="33-实验与对比"&gt;3.3 实验与对比
&lt;/h3&gt;&lt;p&gt;最后，我们将所有正则化策略整合起来，实例化相应的数据集、模型和训练器，并启动训练。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 1. 创建应用了随机遮盖的数据集&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;train_dataset_reg&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassificationDatasetWithMasking&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;is_train&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask_prob&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;train_loader_reg&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset_reg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 继续使用前面改造过的 collate_fn，以便返回 lengths&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 2. 实例化模型，并启用 Dropout&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model_reg&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassifierLSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# 启用 Dropout&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 3. 使用带提前停止功能的训练器&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trainer_reg&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TrainerWithEarlyStopping&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model_reg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_lstm_regularized&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;patience&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 启动训练&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trainer_reg&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;完成训练后，可以通过比较两个实验的输出日志，来分析正则化带来的效果：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;训练是否提前停止？&lt;/strong&gt; 如果是，说明模型可能在更早的阶段就已收敛。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;最终验证集准确率&lt;/strong&gt;：对比 &lt;code&gt;output_lstm&lt;/code&gt; 和 &lt;code&gt;output_lstm_regularized&lt;/code&gt; 中 &lt;code&gt;best_model.pth&lt;/code&gt; 对应的验证集准确率，正则化版本是否取得了更好的泛化性能？&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;训练损失与验证准确率曲线&lt;/strong&gt;：观察两个实验的日志，正则化版本的验证集准确率曲线是否更平滑，或者与训练损失的差距是否更小？这些都是过拟合得到缓解的迹象。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="34-最终效果分析"&gt;3.4 最终效果分析
&lt;/h3&gt;&lt;p&gt;在应用了这三种策略后，我们的 LSTM 模型取得了约 &lt;strong&gt;0.8415&lt;/strong&gt; 的最佳验证集准确率。从下面的训练日志中可以看到，模型在第16轮达到了最佳性能，并在第19轮成功触发了“提前停止”策略，避免了不必要的训练和潜在的过拟合。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;16&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 224/224 &lt;span class="o"&gt;[&lt;/span&gt;00:06&amp;lt;00:00, 33.60it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;16&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 169/169 &lt;span class="o"&gt;[&lt;/span&gt;00:01&amp;lt;00:00, 97.13it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 16/20 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0206 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8415
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;新最佳模型已保存! Epoch: 16, 验证集准确率: 0.8415
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;17&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 224/224 &lt;span class="o"&gt;[&lt;/span&gt;00:06&amp;lt;00:00, 32.80it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;17&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 169/169 &lt;span class="o"&gt;[&lt;/span&gt;00:01&amp;lt;00:00, 94.98it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 17/20 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0176 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8243
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;18&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 224/224 &lt;span class="o"&gt;[&lt;/span&gt;00:07&amp;lt;00:00, 31.27it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;18&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 169/169 &lt;span class="o"&gt;[&lt;/span&gt;00:01&amp;lt;00:00, 96.61it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 18/20 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0172 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8175
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;19&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 224/224 &lt;span class="o"&gt;[&lt;/span&gt;00:06&amp;lt;00:00, 33.10it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;19&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 169/169 &lt;span class="o"&gt;[&lt;/span&gt;00:01&amp;lt;00:00, 95.57it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 19/20 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0136 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8066
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;提前停止于 Epoch 19，因为验证集准确率连续 &lt;span class="m"&gt;3&lt;/span&gt; 轮未提升。
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;训练完成！
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;词典 &lt;span class="o"&gt;(&lt;/span&gt;output_lstm_regularized&lt;span class="se"&gt;\v&lt;/span&gt;ocab.json&lt;span class="o"&gt;)&lt;/span&gt; 和标签映射 &lt;span class="o"&gt;(&lt;/span&gt;output_lstm_regularized&lt;span class="se"&gt;\l&lt;/span&gt;abel_map.json&lt;span class="o"&gt;)&lt;/span&gt; 已保存。
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/7_2_2.png" width="70%" alt="正则化LSTM模型训练过程指标" /&gt;
&lt;br /&gt;
&lt;em&gt;图 7-5 正则化 LSTM 模型训练损失与验证集准确率变化曲线&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;这个结果展示了正则化策略的价值：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;相比于无正则化的LSTM（~0.8143）&lt;/strong&gt;：性能得到了&lt;strong&gt;明显提升&lt;/strong&gt;。这证明我们之前的判断是正确的——基础 LSTM 模型的一个主要问题就是过拟合。通过数据增强、提前停止和层间Dropout的组合，有效地抑制了模型对训练数据的“死记硬背”，使它学习到了更具泛化能力的模式。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;随机 Token 遮盖&lt;/strong&gt;强迫模型不能过度依赖训练集中少数几个“明星”关键词（例如特定作者），而是要学会识别更广泛、更多样化的关键词组合来做出判断，从而提升模型的健壮性和泛化能力。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;提前停止&lt;/strong&gt;则像一个“安全阀”，在模型性能达到巅峰并即将开始下滑（过拟合）的时刻及时终止了训练，锁定了最佳的模型状态。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Dropout&lt;/strong&gt;在多层 LSTM 中，会对除最后一层外各层的输出施加随机丢弃（dropout），相当于在层与层之间随机“关闭”部分神经元连接，破坏可能形成的“共适应”关系，从而增强模型的独立特征学习能力。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;相比于全连接模型（~0.8469）&lt;/strong&gt;：经过正则化后，LSTM模型的性能已经非常接近，但仍然略逊于更简单的基线模型。再次证明了对于这个特定的、以关键词为驱动的新闻分类任务，一个高效的“词袋”模型已经足够强大。试图用更复杂的序列模型来捕捉此处并不关键的语序信息，即使在组合了多种正则化策略后，也难以带来超越性的优势。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个系列的实验也印证了著名的“奥卡姆剃刀原理”——&lt;strong&gt;如无必要，勿增实体&lt;/strong&gt;。在模型选择上，我们应该&lt;strong&gt;从一个基线开始，逐步增加复杂性，并通过实验去验证每一步改动是否真的带来了收益&lt;/strong&gt;。&lt;/p&gt;</description></item><item><title>文本分类简单实现</title><link>https://hanguangwu.github.io/blog/p/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E7%AE%80%E5%8D%95%E5%AE%9E%E7%8E%B0/</link><pubDate>Wed, 25 Mar 2026 16:34:25 -0800</pubDate><guid>https://hanguangwu.github.io/blog/p/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E7%AE%80%E5%8D%95%E5%AE%9E%E7%8E%B0/</guid><description>&lt;h1 id="文本分类简单实现"&gt;文本分类简单实现
&lt;/h1&gt;&lt;h2 id="一文本分类任务概述"&gt;一、文本分类任务概述
&lt;/h2&gt;&lt;p&gt;文本分类是 NLP 中常见的任务之一，它的目标是将给定的文本自动分配到一个或多个预定义的类别中。这项技术的实际应用广泛，例如&lt;strong&gt;情感分析&lt;/strong&gt;可以判断商品评价或电影评论的情感倾向是正面、负面还是中性；&lt;strong&gt;新闻分类&lt;/strong&gt;能够将新闻报道自动归入体育、财经、科技或娱乐等不同频道；在智能客服或语音助手中，&lt;strong&gt;意图识别&lt;/strong&gt;技术用于判断用户输入的指令属于查询天气还是播放音乐等特定意图；而&lt;strong&gt;垃圾邮件过滤&lt;/strong&gt;则能自动识别并拦截收件箱中的垃圾邮件，净化沟通环境。&lt;/p&gt;
&lt;p&gt;在理论篇的&lt;strong&gt;第二章&lt;/strong&gt;中，我们已经学习了如何将文本进行分词，并通过词向量技术将其转换为模型可以理解的数值形式。本节将在此基础上，以一个经典的新闻分类任务为例，详细讲解如何从零开始，一步步构建、训练和评估一个用于文本分类的深度学习模型。这个过程将涵盖数据处理、模型设计、训练循环、推理预测等所有核心环节。&lt;/p&gt;
&lt;h2 id="二nlp-项目通用流程"&gt;二、NLP 项目通用流程
&lt;/h2&gt;&lt;p&gt;无论是文本分类，还是其他更复杂的 NLP 任务，深度学习的解决方案通常遵循一个标准化的项目流程。可以概括为以下几个核心模块：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;graph LR
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; A[1. 数据准备] --&amp;gt; B[2. 模型构建]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; B --&amp;gt; C[3. 定义损失与优化器]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; subgraph &amp;#34;训练循环 (Training Loop)&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; D[4. 迭代训练] --&amp;gt; E[5. 模型评估]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; E --&amp;gt; F{6. 是否更优?}
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; F -- 是/Yes --&amp;gt; G[7. 保存模型]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; F -- 否/No --&amp;gt; D
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; G --&amp;gt; D
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; end
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; C --&amp;gt; D
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;这个流程是搭建深度学习应用的通用范式，是一套标准化、可复用的模板。理解并掌握这套流程，比单纯实现某个模型更为重要。在接下来的内容中，我们将按照这个流程，将各个模块封装成独立的类，构建一个更规范、更易于维护和扩展的项目。&lt;/p&gt;
&lt;h2 id="三新闻文本分类代码实践"&gt;三、新闻文本分类代码实践
&lt;/h2&gt;&lt;blockquote&gt;
&lt;p&gt;&lt;a class="link" href="https://github.com/datawhalechina/base-nlp/blob/main/code/C7/01_text_classification.ipynb" target="_blank" rel="noopener"
&gt;本节完整代码&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;本节将使用 &lt;code&gt;scikit-learn&lt;/code&gt; 库中的 &lt;code&gt;20 Newsgroups&lt;/code&gt; 数据集，这是一个包含约20000篇新闻文档、近似均衡分布在20个不同新闻组（类别）的集合。&lt;/p&gt;
&lt;h3 id="31-模块化设计思路"&gt;3.1 模块化设计思路
&lt;/h3&gt;&lt;p&gt;在开始编写具体代码之前，更重要的步骤是“设计”。一个原则是，要先想清楚每个模块的输入和输出是什么。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;数据模块的输出是什么？&lt;/strong&gt; -&amp;gt; 模型需要的“词元ID序列” (&lt;code&gt;token_ids&lt;/code&gt;) 张量和“标签ID” (&lt;code&gt;label_ids&lt;/code&gt;) 张量。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型的输入是什么？&lt;/strong&gt; -&amp;gt; 数据模块的输出。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型的输出是什么？&lt;/strong&gt; -&amp;gt; 每个类别的置信度。&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;如果对数据处理感到困惑，不妨先从模型定义开始&lt;/strong&gt;。一旦我们清晰地定义了模型 &lt;code&gt;forward&lt;/code&gt; 函数需要的输入（例如，ID序列），数据处理阶段的目标就变得很明确了，只需要把原始文本处理成模型所需的格式。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id="32-步骤一数据解析与加载"&gt;3.2 步骤一：数据解析与加载
&lt;/h3&gt;&lt;h4 id="321-数据加载"&gt;3.2.1 数据加载
&lt;/h4&gt;&lt;p&gt;首先，加载&lt;code&gt;scikit-learn&lt;/code&gt;提供的原始数据集。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;sklearn.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;fetch_20newsgroups&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 为了方便演示，只选择4个类别&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;categories&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;alt.atheism&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;soc.religion.christian&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;comp.graphics&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;sci.med&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;train_dataset_raw&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;fetch_20newsgroups&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;subset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;train&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;categories&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;categories&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;test_dataset_raw&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;fetch_20newsgroups&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;subset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;test&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;categories&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;categories&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;sample&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;text_preview&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][:&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;sample&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;text_preview&amp;#39;&lt;/span&gt;: &lt;span class="s1"&gt;&amp;#39;From: sd345@city.ac.uk (Michael Collier)\nSubject: Converting images to HP LaserJet III?\nNntp-Posting-Host: hampton\nOrganization: The City University\nLines: 14\n\nDoes anyone know of a good way (standard&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s1"&gt;&amp;#39;label&amp;#39;&lt;/span&gt;: &lt;span class="s1"&gt;&amp;#39;comp.graphics&amp;#39;&lt;/span&gt;&lt;span class="o"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h4 id="322-数据探索与可视化"&gt;3.2.2 数据探索与可视化
&lt;/h4&gt;&lt;p&gt;在进行任何复杂的预处理之前，对数据进行探索性分析是很重要且必要的。这有助于我们理解数据特性，从而做出更合理的设计决策。&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;文本长度分布：&lt;/strong&gt;&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;re&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 为了进行探索，先定义一个简单的分词函数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;basic_tokenize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lower&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sub&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;[^a-z0-9(),.!?\&amp;#39;`]&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34; &amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sub&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;([,.!?\&amp;#39;`])&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34; \1 &amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tokens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;strip&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tokens&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 计算每篇文档的词元数量&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;train_text_lengths&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;basic_tokenize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_text_lengths&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;blue&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Distribution of Text Lengths in Training Data&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Number of Tokens&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Frequency&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/7_1_1.png" width="70%" alt="训练集文本长度分布" /&gt;
&lt;br /&gt;
&lt;em&gt;图 7-1 训练集文本长度分布&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;词频分布：&lt;/strong&gt;&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;collections&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Counter&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 计算所有词元的频率&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;word_counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;basic_tokenize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 获取频率并按降序排序&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;frequencies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;reverse&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 生成排名&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;ranks&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;frequencies&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 绘制对数坐标图&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loglog&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ranks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;frequencies&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Rank vs. Frequency (Log-Log Scale)&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Rank (Log)&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Frequency (Log)&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/7_1_2.png" width="70%" alt="词频-排名对数图" /&gt;
&lt;br /&gt;
&lt;em&gt;图 7-2 词频-排名对数图&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;通过&lt;strong&gt;数据分析&lt;/strong&gt;可以发现，图 7-1 的文本长度分布直方图显示大部分文本的长度集中在较短的区间，但仍存在少量长度非常长的“异常值”，说明简单的直接截断策略可能会丢失过多信息。除此之外，如图 7-2 的对数坐标图所示，词频分布呈现出自然语言中典型的齐夫定律（Zipf&amp;rsquo;s Law）现象，即少数高频词占据了绝大多数的出现次数，而大量词汇构成了长长的“尾巴”，其出现频率极低。&lt;/p&gt;
&lt;h4 id="323-tokenizer-封装"&gt;3.2.3 Tokenizer 封装
&lt;/h4&gt;&lt;p&gt;接下来，我们创建一个 &lt;code&gt;Tokenizer&lt;/code&gt;（分词器）类来负责所有与分词、词典构建和 ID 转换相关的任务，它封装了与数据探索时相同的分词逻辑并增加了 ID 转换等功能。其中 &lt;code&gt;_tokenize_text&lt;/code&gt; 方法实现了一套基于正则表达式的分词策略，先将文本转为小写，通过 &lt;code&gt;re.sub&lt;/code&gt; 移除非字母、数字和基本标点之外的字符，为了确保标点符号能被作为独立的词元，在它们周围添加空格，最后按空格切分文本得到词元列表。在词典构建方面，通过遍历所有训练文本统计词频，并过滤掉出现次数过少的低频词以减少词典规模和噪声，同时词典初始化时会预设两个特殊的 Token，即用于填充的 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt;（ID 为 0）和用于表示未登录词的 &lt;code&gt;&amp;lt;UNK&amp;gt;&lt;/code&gt;（ID 为 1）。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Tokenizer&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;token_to_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nd"&gt;@staticmethod&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_tokenize_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lower&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sub&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;[^a-z0-9(),.!?&lt;/span&gt;&lt;span class="se"&gt;\\&lt;/span&gt;&lt;span class="s2"&gt;&amp;#39;`]&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34; &amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sub&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;([,.!?&lt;/span&gt;&lt;span class="se"&gt;\\&lt;/span&gt;&lt;span class="s2"&gt;&amp;#39;`])&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34; &lt;/span&gt;&lt;span class="se"&gt;\\&lt;/span&gt;&lt;span class="s2"&gt;1 &amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;tokens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;strip&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tokens&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;convert_tokens_to_ids&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokens&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;token_to_id&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;get&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&amp;lt;UNK&amp;gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;token&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tokens&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;tokenize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_tokenize_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__len__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h4 id="324-tokenizer-与词典构建"&gt;3.2.4 Tokenizer 与词典构建
&lt;/h4&gt;&lt;p&gt;基于前面对数据的分析，现在可以正式构建词典和 &lt;code&gt;Tokenizer&lt;/code&gt;。词典将只包含在训练集中出现超过 &lt;code&gt;min_freq&lt;/code&gt; 次的词元。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;build_vocab_from_counts&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;min_freq&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&amp;lt;PAD&amp;gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;&amp;lt;UNK&amp;gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;count&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;count&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;=&lt;/span&gt; &lt;span class="n"&gt;min_freq&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;vocab&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 使用上一步计算出的word_counts来构建词典&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;build_vocab_from_counts&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;min_freq&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;vocab_size&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;)}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;vocab_size&amp;#39;&lt;/span&gt;: 10983&lt;span class="o"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h4 id="325-如何处理长文本"&gt;3.2.5 如何处理长文本？
&lt;/h4&gt;&lt;p&gt;在数据探索中能够发现，&lt;code&gt;20 Newsgroups&lt;/code&gt; 数据集中存在大量超长文本，有的甚至超过1万个词元。而大部分深度学习模型（尤其是非 Transformer 模型）都难以处理过长的序列，直接输入会导致内存溢出和计算效率低下。而简单的截断会丢失大量文本末尾的信息，可能会导致关键信息丢失。&lt;/p&gt;
&lt;p&gt;一个更好的方法是将一篇长文档切分成多个固定长度、且有部分重叠的“文本块”（Chunks）。例如，一篇 1000 词的文档若按 &lt;code&gt;max_len=128&lt;/code&gt;、&lt;code&gt;overlap=26&lt;/code&gt; 的方式进行切分，此时第一个块会包含 &lt;code&gt;words[0:128]&lt;/code&gt;，第二个块则顺延为 &lt;code&gt;words[102:230]&lt;/code&gt;（&lt;code&gt;128-26=102&lt;/code&gt;），并以此类推完成整个文档的切分。这样做有两大好处，一方面通过&lt;strong&gt;信息保全&lt;/strong&gt;完整地利用了整篇文章的信息；另一方面则带来了&lt;strong&gt;数据增强&lt;/strong&gt;的效果，将一篇长文档变成了多条训练样本，增加了训练数据量。&lt;/p&gt;
&lt;h4 id="326-封装-dataset-和-dataloader"&gt;3.2.6 封装 &lt;code&gt;Dataset&lt;/code&gt; 和 &lt;code&gt;DataLoader&lt;/code&gt;
&lt;/h4&gt;&lt;p&gt;&lt;code&gt;TextClassificationDataset&lt;/code&gt; 负责的核心逻辑是接收原始文本，调用 &lt;code&gt;tokenizer&lt;/code&gt; 进行 ID 化，并应用 &lt;strong&gt;滑窗分割&lt;/strong&gt; 策略处理长文本。如果文本超过 &lt;code&gt;max_len&lt;/code&gt;，则会进行切分。代码中的 &lt;code&gt;stride&lt;/code&gt; 被设置为 &lt;code&gt;max_len&lt;/code&gt; 的 80%，意味着每个文本块之间有20%的重叠，这有助于保持上下文信息的连续性。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nn"&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;torch.utils.data&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Dataset&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;tqdm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TextClassificationDataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;total&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;convert_tokens_to_ids&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 滑窗分割逻辑&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;stride&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stride&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;chunk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;chunk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__len__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__getitem__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;processed_data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;接着，定义 &lt;code&gt;collate_fn&lt;/code&gt; 函数，它负责将一个批次内长短不一的样本，通过 &lt;strong&gt;填充&lt;/strong&gt; 操作（使用 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt; 对应的ID &lt;code&gt;0&lt;/code&gt;），打包成形状规整的张量，以便模型进行批处理。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;collate_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;max_batch_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_token_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_batch_len&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padded_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_token_ids&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;padded_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;batch_labels&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;label&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_token_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;labels&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_labels&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;使用我们创建的 &lt;code&gt;Dataset&lt;/code&gt; 和 &lt;code&gt;collate_fn&lt;/code&gt; 来实例化训练和验证数据加载器 &lt;code&gt;DataLoader&lt;/code&gt;：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;span class="lnt"&gt;3
&lt;/span&gt;&lt;span class="lnt"&gt;4
&lt;/span&gt;&lt;span class="lnt"&gt;5
&lt;/span&gt;&lt;span class="lnt"&gt;6
&lt;/span&gt;&lt;span class="lnt"&gt;7
&lt;/span&gt;&lt;span class="lnt"&gt;8
&lt;/span&gt;&lt;span class="lnt"&gt;9
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="nn"&gt;torch.utils.data&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;train_dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassificationDataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;train_loader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;valid_dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassificationDataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;test_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;test_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;valid_loader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;valid_dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;train_samples&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;valid_samples&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;valid_dataset&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;batch_size&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;train_samples&amp;#39;&lt;/span&gt;: 7142, &lt;span class="s1"&gt;&amp;#39;valid_samples&amp;#39;&lt;/span&gt;: 5408, &lt;span class="s1"&gt;&amp;#39;batch_size&amp;#39;&lt;/span&gt;: 32&lt;span class="o"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="33-步骤二模型构建"&gt;3.3 步骤二：模型构建
&lt;/h3&gt;&lt;h4 id="331-模型结构设计"&gt;3.3.1 模型结构设计
&lt;/h4&gt;&lt;p&gt;在编写模型代码前，先梳理清楚数据的“变形记”，也就是张量形状在网络中如何变化：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-fallback" data-lang="fallback"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Input:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; token_ids (词元ID序列): [batch_size, seq_len]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;nn.Embedding(padding_idx=0)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; embedded: [batch_size, seq_len, embed_dim]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;nn.Linear(embed_dim, hidden_dim*2) -&amp;gt; nn.ReLU -&amp;gt; nn.Linear(hidden_dim*2, hidden_dim*4) -&amp;gt; nn.ReLU
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; token_features: [batch_size, seq_len, hidden_dim*4]
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Masked Average Pooling (关键操作)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; pooled_features: [batch_size, hidden_dim*4] &amp;lt;-- seq_len维度被聚合掉了
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;nn.Linear (分类层)
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; |
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; V
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Output:
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; logits: [batch_size, num_classes]
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h4 id="332-掩码平均池化"&gt;3.3.2 掩码平均池化
&lt;/h4&gt;&lt;p&gt;池化（Pooling）的目的是将一个序列的特征（&lt;code&gt;[seq_len, hidden_dim]&lt;/code&gt;）聚合成一个代表整条序列的向量（&lt;code&gt;[hidden_dim]&lt;/code&gt;），但简单的平均池化会受到填充 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt; 的影响从而导致语义偏差。举例来说，假设一个批次有 2 个句子且最大长度为 4，其中句子 A 的真实长度为 4（表示为 &lt;code&gt;[v_I, v_love, v_NLP, v_too]&lt;/code&gt;），而句子 B 的真实长度为 2（表示为 &lt;code&gt;[v_NLP, v_rocks, v_PAD, v_PAD]&lt;/code&gt;）。掩码池化的计算过程如下：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;创建掩码&lt;/strong&gt;：&lt;code&gt;mask = [[1, 1, 1, 1], [1, 1, 0, 0]]&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;向量置零&lt;/strong&gt;：将句子 B 中 &lt;code&gt;&amp;lt;PAD&amp;gt;&lt;/code&gt; 对应的向量 &lt;code&gt;v_PAD&lt;/code&gt; 乘以 0，使其变为零向量。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;向量求和&lt;/strong&gt;：句子 A 求和得到 &lt;code&gt;sum_A = v_I + v_love + v_NLP + v_too&lt;/code&gt;；句子 B 求和得到 &lt;code&gt;sum_B = v_NLP + v_rocks + 0 + 0&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;（4）&lt;strong&gt;除以真实长度&lt;/strong&gt;：句子 A 除以 4 得到 &lt;code&gt;pool_A = sum_A / 4&lt;/code&gt;；句子 B 除以 2 得到 &lt;code&gt;pool_B = sum_B / 2&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;通过这种方式就得到了不受填充影响的、精确的句子平均向量。而在 &lt;code&gt;forward&lt;/code&gt; 方法中，这个过程大致包含四个步骤。首先是&lt;strong&gt;创建掩码&lt;/strong&gt;，即根据输入的词元 ID 序列（&lt;code&gt;token_ids&lt;/code&gt;）中不等于 &lt;code&gt;padding_idx&lt;/code&gt; 的位置，生成一个值为 0 或 1 的掩码张量；紧接着进行&lt;strong&gt;向量置零&lt;/strong&gt;，利用广播机制将特征向量与掩码相乘，使所有填充位置的特征向量都会变为零向量；随后&lt;strong&gt;向量求和&lt;/strong&gt;，沿序列长度维度对特征向量进行求和；最后&lt;strong&gt;除以真实长度&lt;/strong&gt;，将求和结果除以每个样本的真实长度（即掩码中 1 的数量），得到最终的池化向量。&lt;/p&gt;
&lt;h4 id="333-模型代码"&gt;3.3.3 模型代码
&lt;/h4&gt;&lt;p&gt;根据上述分析，下面是 &lt;code&gt;TextClassifier&lt;/code&gt; 模型的完整实现：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TextClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;TextClassifier&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;padding_idx&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feature_extractor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;embedded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feature_extractor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embedded&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# shapes:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# token_ids: [batch_size, seq_len]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# embedded: [batch_size, seq_len, embed_dim]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# token_features: [batch_size, seq_len, hidden_dim * 4]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# padding_mask: [batch_size, seq_len]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# masked_features: [batch_size, seq_len, hidden_dim * 4]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# summed_features: [batch_size, hidden_dim * 4]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# pooled_features: [batch_size, hidden_dim * 4]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# logits: [batch_size, num_classes]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# --- 掩码平均池化 ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;padding_mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;padding_idx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;masked_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;token_features&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;padding_mask&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;summed_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;masked_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;real_lengths&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;padding_mask&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;pooled_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;summed_features&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;clamp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;real_lengths&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;min&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;classifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pooled_features&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="34-步骤三训练与评估"&gt;3.4 步骤三：训练与评估
&lt;/h3&gt;&lt;p&gt;将所有与训练、评估、优化和模型保存相关的逻辑都封装到一个&lt;code&gt;Trainer&lt;/code&gt;类中。这个类负责协调模型、数据和优化器，完成整个训练流程。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;span class="lnt"&gt;51
&lt;/span&gt;&lt;span class="lnt"&gt;52
&lt;/span&gt;&lt;span class="lnt"&gt;53
&lt;/span&gt;&lt;span class="lnt"&gt;54
&lt;/span&gt;&lt;span class="lnt"&gt;55
&lt;/span&gt;&lt;span class="lnt"&gt;56
&lt;/span&gt;&lt;span class="lnt"&gt;57
&lt;/span&gt;&lt;span class="lnt"&gt;58
&lt;/span&gt;&lt;span class="lnt"&gt;59
&lt;/span&gt;&lt;span class="lnt"&gt;60
&lt;/span&gt;&lt;span class="lnt"&gt;61
&lt;/span&gt;&lt;span class="lnt"&gt;62
&lt;/span&gt;&lt;span class="lnt"&gt;63
&lt;/span&gt;&lt;span class="lnt"&gt;64
&lt;/span&gt;&lt;span class="lnt"&gt;65
&lt;/span&gt;&lt;span class="lnt"&gt;66
&lt;/span&gt;&lt;span class="lnt"&gt;67
&lt;/span&gt;&lt;span class="lnt"&gt;68
&lt;/span&gt;&lt;span class="lnt"&gt;69
&lt;/span&gt;&lt;span class="lnt"&gt;70
&lt;/span&gt;&lt;span class="lnt"&gt;71
&lt;/span&gt;&lt;span class="lnt"&gt;72
&lt;/span&gt;&lt;span class="lnt"&gt;73
&lt;/span&gt;&lt;span class="lnt"&gt;74
&lt;/span&gt;&lt;span class="lnt"&gt;75
&lt;/span&gt;&lt;span class="lnt"&gt;76
&lt;/span&gt;&lt;span class="lnt"&gt;77
&lt;/span&gt;&lt;span class="lnt"&gt;78
&lt;/span&gt;&lt;span class="lnt"&gt;79
&lt;/span&gt;&lt;span class="lnt"&gt;80
&lt;/span&gt;&lt;span class="lnt"&gt;81
&lt;/span&gt;&lt;span class="lnt"&gt;82
&lt;/span&gt;&lt;span class="lnt"&gt;83
&lt;/span&gt;&lt;span class="lnt"&gt;84
&lt;/span&gt;&lt;span class="lnt"&gt;85
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;os&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="nn"&gt;json&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Trainer&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;valid_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;.&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;criterion&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;criterion&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_loader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_loader&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;valid_loader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;valid_loader&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output_dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;output_dir&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;makedirs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;exist_ok&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 用于记录历史数据&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;val_accuracies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_run_epoch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;desc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; [训练中]&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;labels&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_loader&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;correct_preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;total_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;valid_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;desc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; [评估中]&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;token_ids&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;labels&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;total_samples&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;correct_preds&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;correct_preds&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;total_samples&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_save_checkpoint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_accuracy&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;val_accuracy&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_accuracy&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;val_accuracy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;save_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;best_model.pth&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;save&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;save_path&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;新最佳模型已保存! Epoch: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;, 验证集准确率: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;val_accuracy&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.4f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label_map&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;val_accuracies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;avg_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_run_epoch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;val_accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;avg_loss&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;val_accuracies&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;val_accuracy&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; | 训练损失: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;avg_loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.4f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt; | 验证集准确率: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;val_accuracy&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="s2"&gt;.4f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_save_checkpoint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_accuracy&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;训练完成！&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 训练结束后，保存最终的词典和标签映射&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;vocab_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;vocab.json&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;w&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;utf-8&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;json&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dump&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ensure_ascii&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;indent&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;labels_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;label_map.json&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;labels_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;w&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;utf-8&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;json&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dump&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;label_map&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ensure_ascii&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;indent&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;词典 (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;vocab_path&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;) 和标签映射 (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;labels_path&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;) 已保存。&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;val_accuracies&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h3 id="35-步骤四执行训练"&gt;3.5 步骤四：执行训练
&lt;/h3&gt;&lt;p&gt;通过前面的精心封装，现在执行训练的入口代码变得非常直观和简洁。我们先定义一个超参数字典 &lt;code&gt;hparams&lt;/code&gt; 来集中管理所有配置，这是一种良好的工程实践。然后，只需实例化所有需要的“零件”，并将它们交给“训练总管” &lt;code&gt;Trainer&lt;/code&gt; 即可。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 超参数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;hparams&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;vocab_size&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;embed_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;hidden_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;256&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;num_classes&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;epochs&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;learning_rate&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.001&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cuda&amp;#34;&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;cpu&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;output&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 实例化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;vocab_size&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;embed_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;hidden_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;num_classes&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;criterion&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Adam&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;learning_rate&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;然后，我们使用这些超参数来实例化模型、损失函数、优化器，并将它们全部交给 &lt;code&gt;Trainer&lt;/code&gt; 类进行管理。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;trainer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Trainer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;train_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;valid_loader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;output_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 创建 标签名 -&amp;gt; ID 的映射，并传入 trainer 以便保存&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;label_map&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataset_raw&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;)}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 开始训练，并接收返回的历史数据&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_accuracies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trainer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;epochs&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label_map&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;label_map&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;...
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;14&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 224/224 &lt;span class="o"&gt;[&lt;/span&gt;00:00&amp;lt;00:00, 314.05it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;14&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 169/169 &lt;span class="o"&gt;[&lt;/span&gt;00:00&amp;lt;00:00, 788.40it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 14/20 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0003 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8469
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;...
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;19&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 224/224 &lt;span class="o"&gt;[&lt;/span&gt;00:00&amp;lt;00:00, 326.07it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;19&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 169/169 &lt;span class="o"&gt;[&lt;/span&gt;00:00&amp;lt;00:00, 786.61it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 19/20 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0001 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8450
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;20&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;训练中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 224/224 &lt;span class="o"&gt;[&lt;/span&gt;00:00&amp;lt;00:00, 324.01it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch &lt;span class="m"&gt;20&lt;/span&gt; &lt;span class="o"&gt;[&lt;/span&gt;评估中&lt;span class="o"&gt;]&lt;/span&gt;: 100%&lt;span class="p"&gt;|&lt;/span&gt;██████████&lt;span class="p"&gt;|&lt;/span&gt; 169/169 &lt;span class="o"&gt;[&lt;/span&gt;00:00&amp;lt;00:00, 792.40it/s&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;Epoch 20/20 &lt;span class="p"&gt;|&lt;/span&gt; 训练损失: 0.0001 &lt;span class="p"&gt;|&lt;/span&gt; 验证集准确率: 0.8443
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;训练完成！
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;词典 &lt;span class="o"&gt;(&lt;/span&gt;output&lt;span class="se"&gt;\v&lt;/span&gt;ocab.json&lt;span class="o"&gt;)&lt;/span&gt; 和标签映射 &lt;span class="o"&gt;(&lt;/span&gt;output&lt;span class="se"&gt;\l&lt;/span&gt;abel_map.json&lt;span class="o"&gt;)&lt;/span&gt; 已保存。
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;在本次训练中，模型于第14个轮次（Epoch）达到了最佳性能，验证集准确率最高为 &lt;strong&gt;84.69%&lt;/strong&gt;。由于代码并未固定随机数种子，模型初始权重和数据加载顺序在每次运行时都会有所不同，所以每次运行时得到的结果可能会有细微差异。&lt;/p&gt;
&lt;h4 id="351-训练过程可视化"&gt;3.5.1 训练过程可视化
&lt;/h4&gt;&lt;p&gt;为了更直观地分析模型的训练过程，例如判断是否收敛、是否存在过拟合等，可以将每个周期的训练损失和验证集准确率绘制成图表。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;plot_history&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_accuracies&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;title_prefix&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;epochs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 绘制训练损失曲线&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;bo-&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Training Loss&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;title_prefix&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s1"&gt; Training Loss&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Epochs&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Loss&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="c1"&gt;# 绘制验证集准确率曲线&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_accuracies&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;ro-&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Validation Accuracy&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;title_prefix&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s1"&gt; Validation Accuracy&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Epochs&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;Accuracy&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;suptitle&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;title_prefix&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s1"&gt; Training and Validation Metrics&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;fontsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 调用绘图函数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;plot_history&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_accuracies&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;title_prefix&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;Feed-Forward Network&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p align="center"&gt;
&lt;img src="https://cdn.jsdelivr.net/gh/Hanguangwu/MyImageBed01/img/7_1_3.png" width="70%" alt="训练过程指标" /&gt;
&lt;br /&gt;
&lt;em&gt;图 7-3 训练损失与验证集准确率变化曲线&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;从图 7-3 中能够看出：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;训练损失&lt;/strong&gt;：随着训练的进行，损失稳步下降并趋于平缓，说明模型在训练数据上得到了有效的学习。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;验证集准确率&lt;/strong&gt;：准确率在前几个轮次（Epochs）中迅速提升，随后在达到一个较高水平后出现小幅波动并趋于饱和。这表明模型在训练早期就快速收敛，并在后续训练中将性能稳定在最佳水平附近。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="36-步骤五模型推理"&gt;3.6 步骤五：模型推理
&lt;/h3&gt;&lt;p&gt;训练完成后，最终的目的是使用模型对全新的、未见过的数据进行预测。一个健壮的推理流程必须确保使用与训练时完全相同的预处理配置（特别是词典）和模型权重。&lt;/p&gt;
&lt;h4 id="361-长文本推理的聚合策略"&gt;3.6.1 长文本推理的聚合策略
&lt;/h4&gt;&lt;p&gt;由于我们对长文本进行了滑窗分割，一篇原始文档在推理时会得到多个文本块的预测结果。那么如何将这些结果聚合成一个最终预测呢？常见的策略主要有两种，第一个是&lt;strong&gt;多数投票法&lt;/strong&gt;，也是最直观的方法，具体做法是分别查看每个文本块被预测成的类别，然后选择得票最多的那个类别作为最终结果，若出现平票则可选择置信度总和最高的类别。第二个是&lt;strong&gt;概率累乘/平均法&lt;/strong&gt;，该方法会计算每个类别在所有文本块上的平均置信度或概率，然后选择平均置信度最高的类别。虽然累乘也是一种选择，但在实践中容易因小概率值导致数值下溢，因此取对数后再求和（等价于累乘）或直接平均更为常用。&lt;/p&gt;
&lt;p&gt;下面的 &lt;code&gt;Predictor&lt;/code&gt; 类将封装完整的推理流程，并实现了“多数投票法”作为聚合策略。&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt; 1
&lt;/span&gt;&lt;span class="lnt"&gt; 2
&lt;/span&gt;&lt;span class="lnt"&gt; 3
&lt;/span&gt;&lt;span class="lnt"&gt; 4
&lt;/span&gt;&lt;span class="lnt"&gt; 5
&lt;/span&gt;&lt;span class="lnt"&gt; 6
&lt;/span&gt;&lt;span class="lnt"&gt; 7
&lt;/span&gt;&lt;span class="lnt"&gt; 8
&lt;/span&gt;&lt;span class="lnt"&gt; 9
&lt;/span&gt;&lt;span class="lnt"&gt;10
&lt;/span&gt;&lt;span class="lnt"&gt;11
&lt;/span&gt;&lt;span class="lnt"&gt;12
&lt;/span&gt;&lt;span class="lnt"&gt;13
&lt;/span&gt;&lt;span class="lnt"&gt;14
&lt;/span&gt;&lt;span class="lnt"&gt;15
&lt;/span&gt;&lt;span class="lnt"&gt;16
&lt;/span&gt;&lt;span class="lnt"&gt;17
&lt;/span&gt;&lt;span class="lnt"&gt;18
&lt;/span&gt;&lt;span class="lnt"&gt;19
&lt;/span&gt;&lt;span class="lnt"&gt;20
&lt;/span&gt;&lt;span class="lnt"&gt;21
&lt;/span&gt;&lt;span class="lnt"&gt;22
&lt;/span&gt;&lt;span class="lnt"&gt;23
&lt;/span&gt;&lt;span class="lnt"&gt;24
&lt;/span&gt;&lt;span class="lnt"&gt;25
&lt;/span&gt;&lt;span class="lnt"&gt;26
&lt;/span&gt;&lt;span class="lnt"&gt;27
&lt;/span&gt;&lt;span class="lnt"&gt;28
&lt;/span&gt;&lt;span class="lnt"&gt;29
&lt;/span&gt;&lt;span class="lnt"&gt;30
&lt;/span&gt;&lt;span class="lnt"&gt;31
&lt;/span&gt;&lt;span class="lnt"&gt;32
&lt;/span&gt;&lt;span class="lnt"&gt;33
&lt;/span&gt;&lt;span class="lnt"&gt;34
&lt;/span&gt;&lt;span class="lnt"&gt;35
&lt;/span&gt;&lt;span class="lnt"&gt;36
&lt;/span&gt;&lt;span class="lnt"&gt;37
&lt;/span&gt;&lt;span class="lnt"&gt;38
&lt;/span&gt;&lt;span class="lnt"&gt;39
&lt;/span&gt;&lt;span class="lnt"&gt;40
&lt;/span&gt;&lt;span class="lnt"&gt;41
&lt;/span&gt;&lt;span class="lnt"&gt;42
&lt;/span&gt;&lt;span class="lnt"&gt;43
&lt;/span&gt;&lt;span class="lnt"&gt;44
&lt;/span&gt;&lt;span class="lnt"&gt;45
&lt;/span&gt;&lt;span class="lnt"&gt;46
&lt;/span&gt;&lt;span class="lnt"&gt;47
&lt;/span&gt;&lt;span class="lnt"&gt;48
&lt;/span&gt;&lt;span class="lnt"&gt;49
&lt;/span&gt;&lt;span class="lnt"&gt;50
&lt;/span&gt;&lt;span class="lnt"&gt;51
&lt;/span&gt;&lt;span class="lnt"&gt;52
&lt;/span&gt;&lt;span class="lnt"&gt;53
&lt;/span&gt;&lt;span class="lnt"&gt;54
&lt;/span&gt;&lt;span class="lnt"&gt;55
&lt;/span&gt;&lt;span class="lnt"&gt;56
&lt;/span&gt;&lt;span class="lnt"&gt;57
&lt;/span&gt;&lt;span class="lnt"&gt;58
&lt;/span&gt;&lt;span class="lnt"&gt;59
&lt;/span&gt;&lt;span class="lnt"&gt;60
&lt;/span&gt;&lt;span class="lnt"&gt;61
&lt;/span&gt;&lt;span class="lnt"&gt;62
&lt;/span&gt;&lt;span class="lnt"&gt;63
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Predictor&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label_map&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;label_map&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;label_map&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;id_to_label&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;label_map&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_len&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;token_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;convert_tokens_to_ids&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;chunks&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;chunks&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;stride&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stride&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;chunks&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_ids&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_len&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;chunk_tensors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunks&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chunk_tensors&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_pred_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bincount&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;preds&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;final_pred_label&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;id_to_label&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;final_pred_id&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_pred_label&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 加载资源&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;vocab_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;vocab.json&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;r&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;utf-8&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;loaded_vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;json&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;labels_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;label_map.json&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;labels_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;r&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;utf-8&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;label_map_loaded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;json&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 实例化推理组件&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;inference_tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;loaded_vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;inference_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;TextClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inference_tokenizer&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;embed_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;hidden_dim&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="nb"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;label_map_loaded&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;model_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;output_dir&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;best_model.pth&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;inference_model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;map_location&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;predictor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Predictor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inference_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;inference_tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;label_map_loaded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="n"&gt;hparams&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;device&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="c1"&gt;# 预测&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;new_text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;The doctor prescribed a new medicine for the patient&amp;#39;s illness, focusing on its gpu accelerated healing properties.&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="n"&gt;predicted_class&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;predictor&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_text&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="s2"&gt;&amp;#34;text&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;new_text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;#34;pred&amp;#34;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;predicted_class&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;p&gt;输出如下：&lt;/p&gt;
&lt;div class="highlight"&gt;&lt;div class="chroma"&gt;
&lt;table class="lntable"&gt;&lt;tr&gt;&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code&gt;&lt;span class="lnt"&gt;1
&lt;/span&gt;&lt;span class="lnt"&gt;2
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;
&lt;td class="lntd"&gt;
&lt;pre tabindex="0" class="chroma"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt;&lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;text&amp;#39;&lt;/span&gt;: &lt;span class="s2"&gt;&amp;#34;The doctor prescribed a new medicine for the patient&amp;#39;s illness, focusing on its gpu accelerated healing properties.&amp;#34;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span class="line"&gt;&lt;span class="cl"&gt; &lt;span class="s1"&gt;&amp;#39;pred&amp;#39;&lt;/span&gt;: &lt;span class="s1"&gt;&amp;#39;sci.med&amp;#39;&lt;/span&gt;&lt;span class="o"&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;
&lt;/div&gt;
&lt;/div&gt;&lt;h2 id="四过拟合问题"&gt;四、过拟合问题
&lt;/h2&gt;&lt;p&gt;刚刚构建的模型并没有考虑&lt;strong&gt;过拟合（Overfitting）&lt;/strong&gt; 的问题，即模型在训练集上表现优异，但在未见过的验证集或测试集上表现不佳。下面简单介绍三个常用的方案：&lt;/p&gt;
&lt;p&gt;（1）&lt;strong&gt;提前停止（早停）&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这种方法是在&lt;code&gt;Trainer&lt;/code&gt;的&lt;code&gt;train&lt;/code&gt;方法中，持续监控验证集的准确率（或损失）。如果发现验证集准确率连续N个轮次（N被称为“耐心值”，Patience）都没有超过历史最佳值，就提前终止训练。这可以在&lt;code&gt;Trainer&lt;/code&gt;中增加一个&lt;code&gt;patience&lt;/code&gt;参数和一个计数器来实现此逻辑。&lt;/p&gt;
&lt;p&gt;（2）&lt;strong&gt;随机 Token 遮盖&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;这是一种数据增强方法，具体操作是在&lt;strong&gt;训练过程&lt;/strong&gt;中，随机地将文本中的一部分词元（例如15%）替换为&lt;code&gt;&amp;lt;UNK&amp;gt;&lt;/code&gt;。使得模型不能过度依赖个别“明星词汇”，而是要学习更全面的上下文语义来进行判断，继而提升模型的泛化能力。这个修改可以在&lt;code&gt;TextClassificationDataset&lt;/code&gt;类的&lt;code&gt;__getitem__&lt;/code&gt;方法中，在返回数据前增加一个随机替换的步骤。不过要注意，这个操作只应在训练时进行。&lt;/p&gt;
&lt;p&gt;（3）&lt;strong&gt;Dropout&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;它的核心是在训练过程中，以一定的概率p随机地将神经网络中某些神经元的输出置为零。可以防止神经元之间形成过于复杂的共适应关系，迫使网络学习到更鲁棒、更泛化的特征。可以在 &lt;code&gt;TextClassifier&lt;/code&gt; 模型的 &lt;code&gt;feature_extractor&lt;/code&gt; 模块中，于 &lt;code&gt;nn.Linear&lt;/code&gt; 层和 &lt;code&gt;nn.ReLU&lt;/code&gt; 激活函数之后加入 &lt;code&gt;nn.Dropout(p)&lt;/code&gt; 层。&lt;/p&gt;</description></item></channel></rss>