Transformer应用之构建聊天机器人(二)

四、模型训练解析

在PyTorch提供的“Chatbot Tutorial”中,关于训练提到了2个小技巧:

  • 使用”teacher forcing”模式,通过设置参数“teacher_forcing_ratio”来决定是否需要使用当前标签词汇来作为decoder的下一个输入,而不是把decoder当前预测出来的词汇当做decoder的下一个输入,这是因为存在这样的情况,如果当前预测出来的词汇跟输入词汇从语义上来讲没有多大关联时,如果继续使用预测出来的词汇来训练模型,有可能就会造成比较大的预测偏差,从而导致模型训练后的预测效果很差,如果改为直接使用输入词汇对应的目标词汇(标签)来作为decoder的下一个输入,相当于进行强制纠偏,使decoder训练时输出与输入之间不至于出现偏差很大的情况。
  • 第2个小技巧是使用梯度裁剪(Gradient Clipping),这是一种常用的防止梯度爆炸的技术。在深度学习训练过程中,因为网络层数较多,梯度可能会非常大,导致模型无法收敛。梯度裁剪的目的就是限制梯度的大小,使其不超过一个预设的阈值,从而避免梯度爆炸的问题。

训练过程如下:

  1. 输入语句正向传播通过encoder
  2. 使用SOS token作为decoder的初始输入,使用encoder的final hidden state来初始化decoder的hidden state
  3. Decoder端根据输入单步执行产生输出
  4. 如果执行”teacher forcing”模式,则把当前对应的目标词汇(标签)作为decoder的下一个输入,否则使用当前decoder的输出词汇作为decoder的下一个输入
  5. 计算并累加损失
  6. 执行反向传播
  7. 执行梯度裁剪
  8. 更新decoder和encoder的模型参数

以下是代码示例:

以下是Transformer模型训练代码示例,

  • 首先把输入sequence(对话输入),输出sequence(对话输出),以及各自的mask传入模型做正向传播
  • 计算预测结果与标签的损失,然后反向传播更新模型参数
  • 训练时可以使用验证集(dev dataset)对训练效果进行评估

五、模型预测(推理)过程解析

下面这个图描述了Transformer的预测推理过程:

  • 假设使用两个encoder和两个decoder来构成这个Transformer模型,首先把输入语句转为embedding词向量,并加入位置编码信息
  • 正向传播通过encoder1,它的输出再通过encoder2,期间会使用多头注意力机制对输入序列中的每个词向量并行地进行注意力Q,K,V的计算
  • Decoder1使用<START> token进行初始化,并使用带掩码多头注意力机制进行计算,并且需要根据前面encoder2的输出进行注意力的计算,然后输出预测得到的词汇
  • Decoder1输出的词汇作为decoder2的输入,同样decoder2在进行多头注意力计算时也需要使用encoder2的注意力计算输出结果
  • Decoder2的输出传入线性层,之后使用Softmax函数转为0到1之间的概率,然后可以使用greedy search(贪心解码)算法得到概率最高的词汇作为预测结果

下面是预测相关代码的示例:

再来看下PyTorch提供的聊天机器人样例的预测操作:

  • 用户输入正向传播通过encoder模型
  • 把encoder的final hidden layer作为decoder模型的first hidden input
  • 使用SOS_token作为decoder的第一个输入来初始化模型
  • decoder根据encoder的输出(上篇文章提到的“Luong attention”注意力机制计算),以及当前decoder的输入,hidden state来输出预测得到的词汇(迭代操作)
  • 使用Softmax计算概率并根据概率获取最有可能出现的词汇
  • 把当前预测得到的词汇作为下一个decoder的输入
  • 收集所有预测得到的词汇

以下是预测相关代码的示例:

六、聊天机器人对话效果解析

基于Transformer的聊天机器人和PyTorch提供的聊天机器人都使用同样的训练语料(“Cornell Movie-Dialogs Corpus.”)进行训练,基于Transformer的聊天机器人模型训练了20个epochs,输入语句最大长度设置为60,PyTorch提供的聊天机器人训练配置如下:

clip = 50.0

teacher_forcing_ratio = 1.0

learning_rate = 0.0001

decoder_learning_ratio = 5.0

n_iteration = 4000

print_every = 1

save_every = 500

使用同样的测试对话语料分别对两个模型进行测试,基于Transformer模型的对话测试结果如下:

PyTorch提供的聊天机器人对话测试结果如下:


http://www.niftyadmin.cn/n/363801.html

相关文章

【已更新】2023电工杯A题完整思路代码图表结果--电采暖负荷参与电力系统功率调节的技术经济分析

运行结果图完整内容可见&#xff1a;https://mbd.pub/o/bread/mbd-ZJmXmpxu 典型住户电采暖负荷用电行为分析&#xff1a; a) 分析典型房间温变过程微分方程稳态解的性态&#xff0c;包括制热功率、室内温度和墙体温度的变化特点&#xff0c;并分析模型参数对稳态解变化规律的…

Cesium教程(六):加载地图服务

目录 1、使用地图服务 2、定义地图服务 2.1 Cesium支持的影像服务: 2.2 编写代码 1、使用地图

Cesium教程(七):加载自定义影像数据

GIS开发中经常需要调用本地或供应方发布的影像数据,加载独立的场景,此时可以借助GeoServer发 布自定义影像数据。 geoserver下载地址:geoserver下载 1、geoserver安装 1.1 安装方式1(推荐) 要求已安装tomcat:下载 Web Archive 版本的GeoServer,下载完毕解压,目 录如…

[Daimayuan] pSort(C++,强连通分量)

题目描述 有一个由 n n n 个元素组成的序列 a 1 , a 2 , … , a n a_1,a_2,…,a_n a1​,a2​,…,an​&#xff1b;最初&#xff0c;序列中的每个元素满足 a i i a_ii ai​i。 对于每次操作&#xff0c;你可以交换序列中第 i i i 个元素和第 j j j 个元素当且仅当满足 …

使用本地的chatGLM

打开终端 wsl -d Ubuntu conda activate chatglm cd cd ChatGLM-6B python3 cli_demo.py 依次输入以上命令。

你还不会AVL树吗?

AVL树 AVL树概念AVL树的插入结点定义插入流程左单旋右单旋左右双旋右左双旋 验证AVL树 AVL树概念 &#x1f680;AVL树是一颗平衡的二叉搜索树&#xff0c;所谓平衡是指左右子树的高度差的绝对值不超过1。所以一颗AVL树&#xff08;如果不是空树&#xff09;有以下性质&#xf…

【python】快速使用AnacondaAnaconda安装及使用教程

Anaconda安装及使用教程 一、什么是Anaconda&#xff1f;二、Anaconda的安装步骤三、Anaconda使用教程1.管理conda2.管理环境3.管理包 一、什么是Anaconda&#xff1f; 1.介绍 Anaconda&#xff08;官方网站&#xff09;就是可以便捷获取包且对包能够进行管理&#xff0c;…

深度学习训练营之船类识别

深度学习训练营之船类识别 原文链接环境介绍前言收获前置工作设置GPU导入图片数据预处理 数据可视化配置数据集数据显示 构建模型模型训练编译训练模型 结果可视化(模型评估)损失值可视化混淆矩阵各项指标评估 原文链接 &#x1f368; 本文为&#x1f517;365天深度学习训练营 …