当前位置: 首页 > news >正文

NLP —— 迁移学习 FastText

一、FastText 介绍

官网:https://fasttext.cc/

1.作用:作为NLP工程领域常用的工具包,FastText有两大作用

① 进行文本分类

② 训练词向量

模型领域中起着承上启下的作用

上:transformer

下:训练模型 || 大模型

2.FastText工具包的优势

① fastText 工具包中内含的fastText模型 十分简单的网络结构

② 使用fastText模型训练词向量时使用层次softmax结构,来提升超多类别下的模型性能。

③ 由于fasttext模型过于简单无法捕捉词序特征, 因此会进行n-gram特征提取以弥补模型缺陷提升精度

3.FastText安装

优先用:pip install fasttext

报错就用:pip install fasttext-wheel

二、FastText 模型架构

FastText 模型架构和 Word2Vec 中的 CBOW 模型很类似, 不同之处在于, FastText 预测标签, 而 CBOW 模型预测中间词.

FastText的模型分为三层架构:

  • 输入层: 是对文档embedding之后的向量, 包含N-gram特征

  • 隐藏层: 是对输入数据的求和平均

  • 输出层: 是文档对应的label

层次softmax(hierarchical softmax)

  • 为了提高效率, 在fastText中计算分类标签概率的时候, 不再使用传统的softmax来进行多分类的计算, 而是使用哈夫曼树, 使用层次化的softmax来进行概率的计算.

二、FastText 文本分类

文本分类的过程¶

  • 第一步: 获取数据

  • 第二步: 训练集与验证集的划分

  • 第三步: 训练模型

  • 第四步: 使用模型进行预测并评估

  • 第五步: 模型调优

  • 第六步: 模型保存与重加载

API使用代码

# fasttext文本分类API的使用 import fasttext # 1- 模型训练和预测 def demo01_base(): # 1- 模型训练 model = fasttext.train_supervised(input="data/cooking_train.txt") # 2- 使用训练好的模型对数据进行预测 result1 = model.predict("Which baking dish is best to bake a banana bread ?") print(result1) result2 = model.predict("Why not put knives in the dishwasher?") print(result2) # 3- 模型测试 result = model.test(path="data/cooking_valid.txt") print(result) # (3000, 0.15566666666666668, 0.06732016721925904) 样本条数 精确率 召回率 # 2- 数据基本处理:统一成大小写、标点符号前面加空格。。。 def demo02_preprocessing(): # 1- 模型训练 model = fasttext.train_supervised(input="data/cooking.pre.train") # 2- 模型测试 result = model.test(path="data/cooking.pre.valid") print(result) # 3- 增加训练轮次 def demo03_epoch(): # 1- 模型训练 model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20) # 2- 模型测试 result = model.test(path="data/cooking.pre.valid") print(result) # 4- 调整学习率 def demo04_lr(): # 1- 模型训练 model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1) # 2- 模型测试 result = model.test(path="data/cooking.pre.valid") print(result) # 5- 设置n-gram参数 def demo05_n_gram(): # 1- 模型训练 model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1,wordNgrams=2) # 2- 模型测试 result = model.test(path="data/cooking.pre.valid") print(result) # 6- 调整损失函数 def demo06_loss(): # 1- 模型训练 model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1,wordNgrams=2,loss="hs") # 2- 模型测试 result = model.test(path="data/cooking.pre.valid") print(result) # 7- 自动超参数调优 def demo07_auto(): # 1- 模型训练 """ 参数解释: autotuneDuration:查找最优超参数组合的时间,最终找到的参数组合不一定是最优的 autotuneValidationFile:找到最优超参数组合,使用验证集数据对参数效果进行验证 """ model = fasttext.train_supervised( input="data/cooking.pre.train", autotuneValidationFile="data/cooking.pre.valid", autotuneDuration=60*2 ) # 2- 模型测试 result = model.test(path="data/cooking.pre.valid") print(result) # 8- 多标签多分类问题:将问题拆解为单标签多分类问题,loss损失函数需要设置为ova->one vs all def demo08_ova(): # 1- 模型训练 # 注意:lr的学习率不要过大,如果过大会出现梯度消失/爆炸的情况 model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=0.1,wordNgrams=2,loss="ova") # 2- 预测 """ k:表示预测结果中目标值最多展示多少个。如果值为-1,那么是尽可能将所有的目标值全部都展示 threshold:阈值。如果预测的目标值的概率超过了threshold,那么才有可能显示出来 result1中输出的概率值是经过sigmoid计算后的结果,计算是__label__baking标签的概率以及不是__label__baking标签的概率。 多标签中每个标签的概率计算是相互独立的 """ result1 = model.predict("Which baking dish is best to bake a banana bread ?",k=3,threshold=0.5) print(result1) result2 = model.predict("Which baking dish is best to bake a banana bread ?",k=-1) print(result2) # 3- 模型测试 result = model.test(path="data/cooking.pre.valid") print(result) # 9- 保存模型和重新加载模型 def demo09_savemodel(): # 1- 模型训练 model = fasttext.train_supervised(input="data/cooking.pre.train",epoch=20,lr=1,wordNgrams=2,loss="hs") # 2- 保存训练好的模型 model.save_model("model/cooking_model.pkl") # 3- 加载训练好的模型 model2 = fasttext.load_model("model/cooking_model.pkl") result = model2.test(path="data/cooking.pre.valid") print(result) if __name__ == '__main__': # 1- 模型训练和预测 # demo01_base() # (3000, 0.15566666666666668, 0.06732016721925904) # 2- 数据基本处理 # demo02_preprocessing() # (3000, 0.172, 0.07438373936860314) # 3- 增加训练轮次 # demo03_epoch() # (3000, 0.48733333333333334, 0.21075392821104225) # 4- 调整学习率 # demo04_lr() # (3000, 0.5976666666666667, 0.2584690788525299) # 5- 设置n-gram参数 # demo05_n_gram() # (3000, 0.596, 0.2577483061842295) # 6- 调整损失函数 # demo06_loss() # (3000, 0.5946666666666667, 0.2571716880495892) # 7- 自动超参数调优 # demo07_auto() # (3000, 0.536, 0.23180049012541445) # 8- 多标签多分类问题 # demo08_ova() # (3000, 0.532, 0.23007063572149344) # 9- 保存模型和重新加载模型 demo09_savemodel()
http://www.cnnetsun.cn/news/2634308.html

相关文章:

  • 职业倦怠的识别与应对:从个人能量管理到组织健康构建
  • UE5静态网格体也能玩变形?手把手教你用Morph Targets实现动态环境交互(材质顶点偏移实战)
  • 微信聊天记录数据备份:3步学会用WeChatExporter安全导出你的珍贵回忆
  • 手把手教你学 Simulink—— 基于滑模观测器(SMO)的电动汽车电机无位置传感器控制仿真
  • 从1080P到8K视频:FPGA的BANK设计如何影响你的LVDS接口性能?以Xilinx 7系列为例
  • Claude Code / Codex 一键安装器 (附带C#源码,MIT开源)
  • 厌倦了在编辑器、终端和浏览器之间频繁切换?试试这个基于无限画布(类Figma风格)的下一代开源桌面开发环境“Cate”
  • TVA凭什么成为具身机器人的“类人智眼“(3)
  • 费米悖论五层拆解:从德雷克方程到大过滤器,探寻宇宙寂静之谜
  • SketchUp STL插件终极指南:5步掌握3D打印模型导入导出
  • 免费开源AMD Ryzen调试工具:SMUDebugTool完全指南
  • 【Mysql】B+树索引
  • 强化基准精度管理,优化传动设备全生命周期成本
  • 别再乱卸载补丁了!Win10/11共享打印机报错0x0000011b,试试这个注册表一键修复法
  • PPO算法里的GAE到底怎么算?一个PyTorch逆向遍历代码带你彻底搞懂优势估计
  • 别再死磕有限元了!用Python和PyTorch快速上手PINN,搞定偏微分方程反问题
  • 神经形态计算与氧化物界面器件的存算一体技术
  • 信号处理避坑指南:你的Savitzky-Golay滤波器用对了吗?详解阶数、窗长与延迟那些事儿
  • ARMv7-M架构LDM/STM指令中断机制解析
  • 别再只盯着LOF了!盘点5种更高效的异常检测算法(附Python代码与适用场景指南)
  • 别再死记硬背了!用‘悬崖行走’游戏带你直观理解Model-based和Model-free的区别
  • 如何彻底解放你的QQ音乐:qmcdump终极音频解密指南
  • RePKG:解锁Wallpaper Engine壁纸资源的钥匙
  • GIS数据工程师的私藏技巧:用FME的StringSearcher和AttributeCreator玩转OSGB批量重命名与格式转换
  • 从零构建320万参数微型语言模型:拆解Transformer与自注意力机制
  • 用Arduino和5个舵机,我复刻了一台能抓牛奶的并联机械臂(附完整代码与3D文件)
  • 不止于切换:深入龙讯HDMI 2.0矩阵芯片LT86404UX,玩转串口指令与通道管理逻辑
  • ChatGPT时代:从内容通胀到信任重构的思维范式转变
  • 终极游戏手柄兼容性解决方案:ViGEmBus驱动完整指南
  • 别急着重装!NextCloud登录失败的三个隐蔽配置项检查(附Nginx反向代理避坑指南)