PyTorch入门(六)使用Transformer模型进行中文文本分类
创始人
2025-05-31 11:15:21
0

  在文章PyTorch入门(五)使用CNN模型进行中文文本分类中,笔者介绍了如何在PyTorch中使用CNN模型进行中文文本分类。本文将会使用Transformer模型实现中文文本分类。
  本文将会使用相同的数据集。文本预处理已经在文章PyTorch入门(五)使用CNN模型进行中文文本分类中介绍,本文使用Transformer模型的Encoder部分,Transformer模型如图:
Transformer模型图
使用Transformer的Encoder部分建立文本分类模型,Python代码如下:

# -*- coding: utf-8 -*-
# @Time : 2023/3/16 14:28
# @Author : Jclian91
# @File : model.py
# @Place : Minghang, Shanghai
import math
import torch
import torch.nn as nnfrom params import NUM_WORDS, EMBEDDING_SIZE# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):def __init__(self, d_model, vocab_size=5000, dropout=0.1):super().__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(vocab_size, d_model)position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float()* (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer("pe", pe)def forward(self, x):x = x + self.pe[:, : x.size(1), :]return self.dropout(x)# Text classifier based on a pytorch TransformerEncoder.
class TextClassifier(nn.Module):def __init__(self,nhead=8,dim_feedforward=2048,num_layers=6,dropout=0.1,activation="relu",classifier_dropout=0.1):super().__init__()vocab_size = NUM_WORDS + 2d_model = EMBEDDING_SIZE# vocab_size, d_model = embeddings.size()assert d_model % nhead == 0, "nheads must divide evenly into d_model"# Embedding layer definitionself.emb = nn.Embedding(vocab_size, d_model, padding_idx=0)self.pos_encoder = PositionalEncoding(d_model=d_model,dropout=dropout,vocab_size=vocab_size)encoder_layer = nn.TransformerEncoderLayer(d_model=d_model,nhead=nhead,dim_feedforward=dim_feedforward,dropout=dropout)self.transformer_encoder = nn.TransformerEncoder(encoder_layer,num_layers=num_layers)self.classifier = nn.Linear(d_model, 5)self.d_model = d_modeldef forward(self, x):x = self.emb(x) * math.sqrt(self.d_model)x = self.pos_encoder(x)x = self.transformer_encoder(x)x = x.mean(dim=1)x = self.classifier(x)return x

需要注意的是,Encoder部分的位置编码(PositionalEncoding类)需要自己实现,因为PyTorch中没有实现。
  设置模型参数如下:

  • 文字总数为5500
  • 文本长度(SENT_LENGTH)为200
  • 词向量维度(EMBEDDING_SIZE)为128
  • Transformer的Encoder层数(num_layers)为1
  • 学习率(learning rate)为0.01
  • 训练轮数(epoch)为10
  • 批量大小(batch size)为32

进行模型训练,得到在验证集上的结果为:accuracy=0.9010, precision=0.9051, recall=0.9010, f1-score=0.9018, 混淆矩阵为:

在验证集上的混淆矩阵

参数影响

  我们考察模型参数对模型在验证集上的表现的影响。

  • 考察句子长度对模型表现的影响

保持其它参数不变,设置文本长度(SENT_LENGTH)分别为200,256,300,结果如下:

文本长度accuracyprecisionrecallf1-score
2000.90100.90510.90100.9018
2560.89900.90190.89900.8977
3000.87880.88240.87880.8774
  • 考察词向量维度对模型表现的影响

设置文本长度(SENT_LENGTH)为200,保持其它参数不变,设置词向量维度为32, 64, 128,结果如下:

词向量维度accuracyprecisionrecallf1-score
320.68690.74020.68690.6738
640.75760.76290.75760.7518
1280.90100.90510.90100.9018
2560.92120.92380.92120.9213

从中,我们可以发现,文本长度对模型表现的影响不如词向量维度对模型表现的影响大,当然,这是相对而言,因为文本长度一直保持在200以上,如果文本长度下降很多的话,模型表现会迅速下降。

总结

  本文介绍了如何使用Transformer模型进行中文文本分类,并考察了各重要参数对模型表现的影响。
  本项目已上传至Github,访问网址为:https://github.com/percent4/pytorch_transformer_chinese_text_classification

参考文献

  1. Language Modeling with nn.Transformer and TorchText: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
  2. The Annotated Transformer: http://nlp.seas.harvard.edu/annotated-transformer/

相关内容

热门资讯

136. 只出现一次的数字 总结 异或位运算方法 给你一个非空整数 nums ,除了某个元素只出现一次以外&#x...
C++笔记——第七篇 stac... 目录 一、stack 1.介绍 2.使用  二、queue 1.介绍 2.使用 三、priority...
Java多线程之Executo... 文章目录1 ExecutorCompletionService1.1 简介1.2 原理1.3 Dem...
2023跨境市场洞察:金矿在哪... 就全球市场而言,跨境电商的高速增长时代已成过去时,但就意味电商金矿被挖空...
Scala中Array常用的方...         在scala中,Array有大量的方法。定义一个数组arr后ÿ...
C++基础学习笔记(四)——核... 参考链接:https://www.bilibili.com/video/BV1et41...
超详细-安装vCenterv ... 目录 介绍: 第一阶段安装: 第二阶段安装: 最近在玩虚拟...
第14届蓝桥杯STEMA测评真... [导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,...
ChatGPT助力校招----... 1 ChatGPT每日一题:简述SPI通信协议 问题:简述SPI通信协议...
新版PMP考试难不难? 1.新版考试题量和答题时间的变化? 总题量从200道减少到180道,所以...
HBase客户端、服务器端、列... HBase客户端、服务器端、列簇设计、HDFS相关优化,HBase写性能优化切入点&#...
linux 全局环境变量删除后... linux 全局环境变量删除后 还有 仍然存在1、编辑 /etc/profile2、设置REDISC...
网站流量飙升背后:外贸企业谷歌... 自从我涉足外贸行业,我逐渐认识到谷歌SEO优化在提升网站流量和吸引潜在客户方面的重要性...
一、trino406系列 之 ... 文章目录前言Trino不是什么?Trino是什么?概览服务类型Coord...
基于Java+SpringBo...  博主介绍:专注于Java技术领域和毕业项目实战 🍅文末获取源码联系&...
财经时评|破除“内卷式”竞争 ... 作者 中国汽车工程学会理事长张进华“十四五”以来,我国智能网联新能源汽车产业坚持以科技创新引领和推动...
二十六、对象的实例化内存布局与... 一、对象的实例化 1.判断对象对用的类是否加载、链接、初始化。 2.为对象分配内存。 3.处理并发...
C语言简单工厂模式和工程创建 一,设计模式概念引入① 什么是设计模式设计模式通常被面向对象的软件开发人员所采用&#x...
新势力车企5月销量:零跑汽车再... 红星资本局6月1日消息,今日,新势力车企陆续公布5月销量数据。零跑汽车(09863.HK)再创历史新...
150.网络安全渗透测试—[C... 我认为,无论是学习安全还是从事安全的人多多少少都会有些许的情怀和使命感!...