📚 PyTorch实战:用LSTM实现文本风格分类(余华vs路遥)
📚 PyTorch实战:用LSTM实现文本风格分类(余华vs路遥)
📝 前言
本周我们将之前学习的序列建模基石:RNN-LSTM-GRU理论落地,用PyTorch实现一个基于LSTM的文本风格分类器——通过书名判断作者是余华还是路遥。
这篇博客完全对应之前的知识点:从文本预处理(词表构建、序列填充),到LSTM的门控机制与细胞状态,再到二分类交叉熵损失与模型训练,完整打通“理论→代码”的逻辑链条。
一、数据流向图
1 | 一次前向传播的数据流向图 |
二、核心参数解析
| 参数名 | 出现位置 | 含义 | 示例值 | 作用说明 |
|---|---|---|---|---|
| batch | 全程 | 批次大小,一次同时处理多少条样本 | 2 | 同时处理两条句子,GPU并行加速,梯度更新更稳定 |
| seq_len | 填充后 | 序列长度,每条样本被统一到的词数 | 5 | 所有句子补齐到相同长度,短句用PAD填充 |
| vocab_size | Embedding | 词表大小,一共多少个不同的字 | 取决于语料 | 决定Embedding层的查表范围 |
| embedding_dim | Embedding | 词向量维度,每个字用几个浮点数表示 | 16 | 维度越大表达能力越强,但训练越慢 |
| hidden_dim | LSTM, Linear | 隐藏层维度,LSTM记忆向量的宽度 | 32 | 贯穿LSTM输出和Linear输入,是模型的”脑容量” |
| num_layers | LSTM | LSTM堆叠层数 | 1 | 单层适合小数据,多层适合复杂任务 |
| output_dim | Linear | 输出维度 | 1 | 二分类只输出一个分数,多分类则输出类别数 |
| padding_idx | Embedding | 填充符索引,指定哪个索引不做梯度更新 | 0 | 让PAD符号不参与训练,避免噪声干扰 |
关键传递关系:
embedding_dim 是 Embedding 的输出维度,也是 LSTM 的输入维度
hidden_dim 是 LSTM 的输出维度,也是 Linear 的输入维度,必须一致
vocab_size 决定 Embedding 的查表范围
num_layers 决定 hidden 和 cell 的第0维大小
三、完整代码及注释
1 | import torch |
四、模型调优:可调参数与影响分析
4.1 结构参数(影响模型容量)
| 参数 | 当前值 | 增大影响 | 减小影响 | 调整信号 |
|---|---|---|---|---|
| embedding_dim | 16 | 表达能力增强,训练变慢,可能过拟合 | 训练加快,表达能力减弱 | 训练集准确率高而测试集低 → 减小;训练集一直上不去 → 增大 |
| hidden_dim | 32 | 记忆容量增大,可处理更复杂句子 | 训练加快,显存占用减少 | 长句判断差 → 增大;显存不足/训练太慢 → 减小 |
| num_layers | 1 | 模型变深,能学更抽象特征 | 模型变浅,训练加速 | 数据量小 → 保持1层;数据量大且欠拟合 → 增加到2-3层 |
4.2 训练参数(影响收敛过程)
| 参数 | 当前值 | 增大影响 | 减小影响 | 调整信号 |
|---|---|---|---|---|
| lr (学习率) | 0.01 | 学习加快,但可能不收敛或震荡 | 学习减慢,收敛更稳但训练更久 | loss忽大忽小 → 减小;loss下降极慢 → 增大;loss变NaN → 立刻减小 |
| epochs | 300 | 训练更充分 | 训练更短 | loss仍在下降 → 增大;loss已平稳或反弹 → 减小或加早停 |
| batch_size | 6(全量) | 梯度更准,训练稳定 | 梯度噪声大,泛化可能更好 | 显存不足 → 减小;训练不稳定 → 增大 |
4.3 典型问题诊断表
| 训练现象 | 可能原因 | 调整方案 |
|---|---|---|
| 训练集准确率100%,测试集却很差 | 数据太少+模型容量过大,过拟合 | 减小embedding_dim/hidden_dim;增加数据量 |
| loss一直不降,准确率在50%附近波动 | 模型容量不足或学习率不合适 | 增大hidden_dim;调整lr |
| loss变成NaN | 学习率过大导致梯度爆炸 | 将lr减小一个数量级 |
| 训练很快达到高准确率但loss仍高 | 模型开始死记硬背 | 减小epochs;增加dropout(此处未加) |
4.4 调优操作流程
先固定epochs=100,用当前参数跑一次,记录loss曲线和最终准确率
观察曲线特征,对照4.3诊断表定位问题类型
每次只调一个参数,改后重跑对比
确认当前参数最优后,再增大epochs做最终训练
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 河岳日星的博客!
评论