在PyTorch里面利用transformers的Trainer微调预训练大模型

news/2024/7/19 11:58:35 标签: pytorch, 人工智能, 自然语言处理, transformer, Train

背景

transformers提供了非常便捷的api来进行大模型的微调,下面就讲一讲利用Trainer来微调大模型的步骤

第一步:加载预训练的大模型

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

第二步:设置训练超参

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="path/to/save/folder/",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
)

比如这个里面设置了epoch等于2

第三步:获取分词器tokenizer

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

第四步:加载数据集

from datasets import load_dataset

dataset = load_dataset("rotten_tomatoes")  # doctest: +IGNORE_RESULT

第五步:创建一个分词函数,指定数据集需要进行分词的字段:

def tokenize_dataset(dataset):
    return tokenizer(dataset["text"])

第六步:调用map()来将该分词函数应用于整个数据集

dataset = dataset.map(tokenize_dataset, batched=True)

第七步:使用DataCollatorWithPadding来批量填充数据,加速填充过程:

from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

第八步:初始化Trainer

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)  # doctest: +SKIP

第九步:开始训练

trainer.train()

总结:

利用Trainer提供的api,只需要简简单单的九步,十几行代码就能进行大模型的微调,你要不要动手试一试?


http://www.niftyadmin.cn/n/5038832.html

相关文章

【2023年11月第四版教材】第14章《沟通管理》(第一部分)

第14章《沟通管理》(第一部分) 1 章节说明2 管理基础2.1 沟通具体形式包括2.2 沟通模型:★★★ (17下41) (18下43)2.3 沟通模型包含5种状态2.4 沟通分类 3 管理过程3.1 管理的过程★★★ &#…

/usr/bin/ld: cannot find -lmysqlcllient

文章目录 1. question: /usr/bin/ld: cannot find -lmysqlcllient2. solution 1. question: /usr/bin/ld: cannot find -lmysqlcllient 2. solution 在 使用编译命令 -lmysqlclient时,如果提示这个信息。 先确认一下 有没有安装mysql-devel 执行如下命令 yum inst…

基于微信小程序的个人健康管理系统的设计与实现(源码+lw+部署文档+讲解等)

前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 👇🏻…

C++入门及简单例子_5

示例 1: 模板类和模板函数 #include <iostream> // 包含输入输出流库的头文件template<typename T> // 定义模板类&#xff0c;模板参数为类型T class Pair { // 定义名为Pair的类模板 private:T first; // 类模板中的成员变量T second;public:Pair(T f, T s)…

suricata学习记录

pcre正则表达式shell语法automakeautogen.shautogen.sh和autoreconf单模匹配、多模匹配AC算法radix tree(IP查找)

centost7下安装oracle11g 总结踩坑

1.参考文章 【总结】CentoS下Oracle静默安装流程_正在启动oracle universal installer..._仲冬二三的博客-CSDN博客 https://blog.csdn.net/Liqiong_0412/article/details/126153857? unset DISPLAY 可以跳过图形化检查 这边也卡了很久 [oraclewangmengyuan database]$ .…

Vue中的路由介绍以及Node.js的使用

&#x1f3c5;我是默&#xff0c;一个在CSDN分享笔记的博主。&#x1f4da;&#x1f4da; &#x1f31f;在这里&#xff0c;我要推荐给大家我的专栏《Vue》。&#x1f3af;&#x1f3af; &#x1f680;无论你是编程小白&#xff0c;还是有一定基础的程序员&#xff0c;这个专栏…

javascript使用正则表达式去除字符串中括号的方法

如下面的例子&#xff1a; (fb6d4f10-79ed-4aff-a915-4ce29dc9c7e1,39996f34-013c-4fc6-b1b3-0c1036c47119,39996f34-013c-4fc6-b1b3-0c1036c47169,39996f34-013c-4fc6-b1b3-0c1036c47111,2430bf64-fd56-460c-8b75-da0a1d1cd74c,39996f34-013c-4fc6-b1b3-0c1036c47112) 上面是前…