【自然语言处理八-transformer实现翻译任务-一(输入)】

news/2024/7/19 11:11:19 标签: 自然语言处理, transformer, 人工智能

自然语言处理八-transformer实现翻译任务-一(输入)

  • transformer架构
  • 数据处理部分
    • 模型的输入数据(图中inputs outputs outputs_probilities对应的label)
      • 以处理英中翻译数据集为例的代码
    • positional encoding 位置嵌入
      • 代码

鉴于transfomer的重要性,在两篇介绍过transfomer模型理论的基础上,我们将分几篇文章,用pytorch代码实现一个完整的transfomer模型。
下面是之前介绍模型的文章:
自然语言处理六-最重要的模型-transformer-上
自然语言处理六-最重要的模型-transformer-下

transformer_7">transformer架构

在这里重新给出transfomer架构图以及用中文翻译后的对照图:
在这里插入图片描述

从架构图可以看出实现transformer架构,encoder和decoder大部分相同。因此也规划几篇内容,介绍以下几大块:

  1. 输入输出部分
    处理数据,源和目标以及输出,以及位置编码

  2. 注意力部分
    多头注意力和掩蔽多头注意力

  3. 前馈网络等
    加和归一化,以及逐位前馈网路

  4. 训练和测试

本篇作为开始,先介绍处理数据部分

数据处理部分

模型的输入数据(图中inputs outputs outputs_probilities对应的label)

这部分用来处理transformer模型需要输入的数据,这部分其实和seq2seq架构是相同的,以自然语言的翻译为例:
比如transformer需要将 ich mochte ein bier 翻译成 i want a beer,那如需要输入的数据格式是这样的(假设句子最大长度5):
[ich mochte ein bier , i want a beer, i want a beer ]

那么上面那部分输入的用途都是什么呢?
encoder输入需要翻译的句子 ich mochte ein bier
decoder输入是 i want a beer
decoder的label是i want a beer
分别对应于图中inputs outputs outputs_probilities相对应的标签

其中是填充字符,用来填充到一个我们超参数中我们设定的sequence的长度
代表句子的开始, 代表句子结束
当然模型真正要处理还是需要根据词汇表转成数字格式的形式,才能被模型处理

以处理英中翻译数据集为例的代码

# -*- coding: utf-8 -*-

"""
加载源数据,并处理成data set和data loader
"""
import os
import zipfile
import torch
import torch.utils.data as Data
from src.configs import config
from src.utility.utils import Utils
from src.vocabulay.vocabulary import Vocabulary


class DataSetLoader:
    """
    数据封装成dateset和datelaoder
    """
    def __init__(self, is_train=True, numbers=config.num_examples):
        """
        init parameters
        """
        self.raw_file_path = config.raw_file_path
        self.raw_zip_path = config.raw_zip_path
        self.batch_size = config.batch_size
        self.num_steps = config.num_steps
        self.num_examples = numbers
        self.is_train = is_train

    def load_array(self, data_arrays):
        """
        构造一个数据迭代器
        :param data_arrays:  输入数据的列表
        :param is_train:     训练/测试数据集
        :return:             dataloader
        """
        dataset = Data.TensorDataset(*data_arrays)
        return Data.DataLoader(dataset, self.batch_size, shuffle=self.is_train)

    def load_data_nmt(self):
        """
        返回翻译数据集的迭代器和词表
        :return: 迭代器和词表
        """
        source, src_vocab, target, tgt_vocab = build_vocabu()
        src_array, src_valid_len = self.build_array_nmt(source, src_vocab)
        tgt_array, tgt_valid_len = self.build_array_nmt(target, tgt_vocab)
        data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
        data_iter = self.load_array(data_arrays)
        return data_iter, src_vocab, tgt_vocab

    def build_array_nmt(self, lines, vocab):
        """将机器翻译的文本序列转换成小批量
        source[['hello', 'world'],..]  target[['你'], [好]]
        source[[]]
        """
        lines = [vocab[l] for l in lines]
        lines = [l + [vocab['<eos>']] for l in lines]
        array = torch.tensor([Utils.truncate_pad(
            l, self.num_steps, vocab['<pad>']) for l in lines])
        valid_len = Utils.reduce_sum(
            Utils.astype(array != vocab['<pad>'], torch.int32), 1)
        return array, valid_len


def extract_content():
    """
    提取raw text中内容
    :return: raw text
    """
    if not os.path.exists(config.raw_file_path):
        with zipfile.ZipFile(config.raw_zip_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.dirname(config.raw_file_path))

    print('语料解压缩完成')
    with open(config.raw_file_path, 'r', encoding='UTF-8') as f:
        content = f.read()
    return content


def build_vocabu():
    """
    创建词表
    :return: 词表
    """
    text = Utils.preprocess_nmt(extract_content())
    source, target = Utils.tokenize_nmt(text, config.num_examples)
    src_vocab = Vocabulary(source, min_freq=2)
    tgt_vocab = Vocabulary(target, min_freq=2)
    return source, src_vocab, target, tgt_vocab

positional encoding 位置嵌入

为了添加位置信息,transformer的位置嵌入,每个位置的512维的数据用sin/cos做了处理
在这里插入图片描述
其中pos是在句子中位置,i是维度信息

代码

def get_sinusoid_encoding_table(n_position, d_model):
    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_model)
    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_model)]

其他处理会在后续章节继续实现


http://www.niftyadmin.cn/n/5480104.html

相关文章

【IC前端虚拟项目】时序面积优化与综合代码出版本交付

【IC前端虚拟项目】数据搬运指令处理模块前端实现虚拟项目说明-CSDN博客 到目前为止,我们完成了第一版综合,那么就可以打开报告看一下了,一看就会发现在1GHz时钟下时序真的很差(毕竟虚拟项目里使用的工艺库还是比较旧的,如果用12nm、7mn会好很多): Timing Path Group cl…

2024春算法训练4——函数与递归题解

一、前言 感觉这次的题目都很好&#xff0c;但是E题....&#xff08;我太菜了想不到&#xff09;&#xff0c;别人的题解都上百行了&#xff0c;晕&#xff1b; 二、题解 A-[NOIP2010]数字统计_2024春算法训练4——函数与递归 (nowcoder.com) 这种题目有两种做法&#xff1a;…

shell命令行中脚本特殊注释指定脚本解释器

在Linux系统中&#xff0c;#!/usr/bin 是一个特殊的注释&#xff0c;通常称为"shebang" 或 “hashbang”。用于指定脚本的解释器。 即它的目的是告诉操作系统应该使用哪个解释器来执行脚本。 通过在脚本的第一行使用#!&#xff0c;后面跟着解释器的路径&#xff0c;…

蓝桥杯 每日两题 day3

碎碎念&#xff1a;断更了&#xff0c;&#xff0c;&#xff0c;悲惨滴去写小组作业&#xff0c;悲惨滴去搞泡泡堂。 1.直线 6.直线 - 蓝桥云课 (lanqiao.cn) from itertools import combinationsx [i for i in range(20)] y [i for i in range(21)] dots [] # 坐标 f…

蓝桥杯第十三届蓝桥杯大赛软件赛省赛C/C++ 大学 A 组题解

1.裁纸刀 题目链接&#xff1a;0裁纸刀 - 蓝桥云课 (lanqiao.cn) 思路&#xff1a;简单的推导一下公式 #include <iostream> using namespace std; int main() {// 请在此输入您的代码cout<<41921*20<<endl;return 0; } 2.灭鼠先锋 题目链接&#xff1a…

python实现pdf的页面替换

利用第三方库PyPDF2&#xff0c;下面例子中进行的是将 origin.pdf 的第17页替换为 s17.pdf 的第1页&#xff1a; import PyPDF2def replace_pages(original_pdf_path, replacement_pages):with open(original_pdf_path, rb) as original_file:original_pdf PyPDF2.PdfReader(…

CSS设置网页背景

目录 概述&#xff1a; 1.background-color: 2.background-image&#xff1a; 3.background-repeat&#xff1a; 4.background-position&#xff1a; 5.background-attachment&#xff1a; 6.background-size&#xff1a; 7.background-origin&#xff1a; 8.background-…

每日三道面试题之 Java并发编程 (三)

1.什么是上下文切换? 在Java线程知识中&#xff0c;上下文切换是指操作系统在多任务环境下&#xff0c;为了实现多任务的并行执行&#xff0c;需要在运行一个任务&#xff08;如一个线程或进程&#xff09;时切换到另一个任务运行的过程。上下文切换是多任务操作系统的核心特…