Transformer的PyTorch实现之若干问题探讨(一)

news/2024/7/19 10:10:15 标签: transformer, pytorch, 深度学习

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。

1.自定义数据中enc_input、dec_input及dec_output的区别

博文中给出了两对德语翻译成英语的例子:

# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [
    # enc_input   dec_input    dec_output
    ['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]

初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。

我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。

实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。

2.Transformer的训练流程

此处给出博文中附带的非常简洁的Transformer训练代码:

from torch import optim
from model import *

model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

# 训练开始
for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        '''
        enc_inputs: [batch_size, src_len] [2,5]
        dec_inputs: [batch_size, tgt_len] [2,6]
        dec_outputs: [batch_size, tgt_len] [2,6]
        '''
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]
        outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]
        loss = criterion(outputs, dec_outputs.view(-1))  # 将dec_outputs展平成一维张量

        # 更新权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')

这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:

ich mochte ein bier

在进行训练时,当我们给出起始符S,接下来应该预测出:

I

那训练时,有了SI后,则应该预测出

want

那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。

那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。


http://www.niftyadmin.cn/n/5364631.html

相关文章

LLMs之miqu-1-70b:miqu-1-70b的简介、安装和使用方法、案例应用之详细攻略

LLMs之miqu-1-70b:miqu-1-70b的简介、安装和使用方法、案例应用之详细攻略 目录 miqu-1-70b的简介 miqu-1-70b的安装和使用方法 1、安装 2、使用方法 miqu-1-70b的案例应用 miqu-1-70b的简介 2024年1月28日,发布了miqu 70b,潜在系列中的…

TrinityCore安装记录

TrinityCore模拟魔兽世界(World of Warcraft)的开源项目,并且该项目代码广泛的优化、改善和清理代码。 前期按照官方手册按部就班的安装即可。 注意几点: 1 需要配置Ubuntu22.04版本的服务器或者Debian11 服务器。2 需要使用gi…

Android电动汽车充电服务vue+uniAPP微信小程序

本系统利用SSM和Uniapp技术进行开发电动汽车充电服务系统是未来的趋势。该系统使用的编程语言是Java,数据库采用的是MySQL数据库,基本完成了系统设定的目标,建立起了一个较为完整的系统。建立的电动汽车充电服务系统用户使用浏览器就可以对其…

【JMeter】使用技巧

在这此对新版本jmeter的学习温习的过程,发现了一些以前不知道的功能,所以,整理出来与大分享。本文内容如下。 如何使用英文界面的jmeter如何使用镜像服务器Jmeter分布式测试启动Debug 日志记录搜索功能线程之间传递变量 如何使用英文界面的…

go-基于逃逸分析来提升性能程序

go-基于逃逸分析来提升性能程序 为什么要学习逃逸分析: 为了提高程序的性能,通过逃逸分析我们能知道指标是分配到堆上还是栈上,如何是 分配到栈上,内存的分配和释放都是由编译器进行管理的,分配和释放的速度都非常的…

游戏行业需要高防护服务器的理由有哪些?

众所周知,游戏行业是最易受DDOS攻击的行业,不管游戏设置的多么精彩,一旦玩家在玩的过程中经常出现死机、卡机的状况游戏用户也会选择离开的,所以服务器是影响游戏业务是否能正常运行的重要因素之一,而游戏行业选用高防…

FFMPEG推流到B站直播

0、参考 ffmpeg安装参考小弟另外的一个博客:FFmpeg和rtsp服务器搭建视频直播流服务-CSDN博客推流参考:用ffmpeg 做24小时推流直播_哔哩哔哩_bilibili 一、获取b站直播码 点击开始直播后,会出现以下的画面 二、ffmpeg进行直播推流 ffmpeg -r…

国辰智企APS自动化排产平台:实现生产计划与其他系统无缝协同

在当今竞争激烈的制造环境中,有效的生产计划和排程对于企业的成功至关重要。APS生产计划排程平台作为一种先进的工具,正越来越受到企业的关注和应用。那么,APS生产计划排程平台有哪些类型呢?本文将为您详细介绍。 1.基于规则的APS…