peft模型微调--Prompt Tuning

news/2024/7/19 9:45:13 标签: prompt, peft, transformer, 模型微调

模型微调(Model Fine-Tuning)是指在预训练模型的基础上,针对特定任务进行进一步的训练以优化模型性能的过程。预训练模型通常是在大规模数据集上通过无监督或自监督学习方法预先训练好的,具有捕捉语言或数据特征的强大能力。

PEFT(Parameter-Efficient Fine-Tuning)是一种针对大模型微调的技术,其核心思想是在保持大部分预训练模型参数不变的基础上,仅对一小部分额外参数进行微调,以实现高效的资源利用和性能优化。这种方法对于那些计算资源有限、但又需要针对特定任务调整大型语言模型(如LLM:Large Language Models)的行为时特别有用。

在应用PEFT技术进行模型微调时,通常采用以下策略之一或组合:

Adapter Layers: 在模型的各个层中插入适配器模块,这些适配器模块通常具有较低的维度,并且仅对这部分新增的参数进行微调,而不改变原模型主体的参数。

Prefix Tuning / Prompt Tuning: 通过在输入序列前添加可学习的“提示”向量(即prefix或prompt),来影响模型的输出结果,从而达到微调的目的,而无需更改模型原有权重。

LoRA (Low-Rank Adaptation): 使用低秩矩阵更新原始模型权重,这样可以大大减少要训练的参数数量,同时保持模型的表达能力。

P-Tuning V1/V2: 清华大学提出的一种方法,它通过学习一个连续的prompt嵌入向量来指导模型生成特定任务相关的输出。

冻结(Freezing)大部分模型参数: 只对模型的部分层或头部(如分类器层)进行微调,其余部分则保持预训练时的状态不变。

下面简单介绍一个通过peft使用Prompt Tuning对模型进行微调训练的简单流程。

# 基于peft使用prompt tuning对生成式对话模型进行微调 
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
# 数据加载
ds = Dataset.load_from_disk("/alpaca_data_zh")
print(ds[:3])
# 数据处理
tokenizer = AutoTokenizer.from_pretrained("../models/bloom-1b4-zh")
# 数据处理函数
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

# 数据处理
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
print(tokenized_ds)
# 模型创建
model = AutoModelForCausalLM.from_pretrained("../models/bloom-1b4-zh", low_cpu_mem_usage=True)
# 套用peft对模型进行参数微调
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit

# 1、配置文件参数
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,
                            prompt_tuning_init=PromptTuningInit.TEXT,
                            prompt_tuning_init_text="下面是一段人与机器人的对话。",
                            num_virtual_tokens=len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),
                            tokenizer_name_or_path="../models/bloom-1b4-zh")

# 2、创建模型
model = get_peft_model(model, config)
# 查看模型的训练参数
model.print_trainable_parameters()
# 配置训练参数
args = TrainingArguments(
    output_dir="./peft_model",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

# 创建训练器
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
# 模型训练
trainer.train()
# 模型推理
peft_model = model.cuda()
ipt = tokenizer("Human: {}\n{}".format("周末去重庆怎么玩?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
print(tokenizer.decode(peft_model.generate(**ipt, max_length=256, do_sample=True)[0], skip_special_tokens=True))

http://www.niftyadmin.cn/n/5430412.html

相关文章

C++初阶 | [九] list 及 其模拟实现

摘要:介绍 list 容器,list 模拟实现 list(带头双向循环列表) 导入:list 的成员函数基本上与 vector 类似,具体内容可以查看相关文档(cplusplus.com/reference/list/list/),这里不多赘述。以下对…

【析】一类动态车辆路径问题模型和两阶段算法

一类动态车辆路径问题模型和两阶段算法 摘要 针对一类动态车辆路径问题,分析4种主要类型动态信息对传统车辆路径问题的本质影响,将动态车辆路径问题(Dynamic Vehicle Routing Problem, DVRP)转化为多个静态的多车型开放式车辆路径问题(The Fleet Size a…

计算机视觉+人工智能碰撞出新的火花

计算机视觉(CV)技术的优势是其能够处理大量的图像和视频数据,并快速准确地提取出有用的信息。 1. 自动化:CV技术可以自动化地执行各种图像处理任务,例如目标检测、图像分类和图像分割。这样可以提高工作效率并降低人工…

Exam in MAC [容斥]

题意 思路 正难则反 反过来需要考虑的是: (1) 所有满条件一的(x,y)有多少对: x 0 时,有c1对 x 1 时,有c对 ...... x c 时,有1对 以此类推 一共有 (c2)(c1)/2 对 (2) 符合 x y ∈ S的有多少对&#xff1a…

二叉树算法

递归序 每个节点都能回到3次! 相当于2执行完然后返回了代码会往下走,来到3节点 小总结: 也就是4节点先来到自己一次,不会执行if,先调用自己左边的那个函数,但是是null,直接返回。 这个函数执行完了,就会回到自己,调用自己右边的那个函数,结果又是空,又返回,回到…

医疗手持智能终端读取条码二维码的难点有哪些?

在医疗科技行业信息化的大潮中,医疗手持式智能终端的应用越发普及,医疗手持式智能终端对条码二维码技术应用显得尤为关键,作为信息朔源载体的条码二维码读取方面,在实际应用中却面临着诸多问题,我们该如何应对&#xf…

Linux创建新用户,属主相关的题目以及讲解

3月13日 属主相关 题目 一、文件与目录管理 在根目录下创建temp目录,然后在temp目录下创建文件夹dir1和dir2,在dir1下创建dir11,在dir2下创建文件test2.txt,并删除dir2。 [实现步骤]:cd / 进入根目录 mkdir temp 创…

知识蒸馏Matching logits与RocketQAv2

知识蒸馏Matching logits 公式推导 刚开始的怎么来,可以转看下面证明梯度等于输出值-标签y C是一个交叉熵,我们要求解的是这个交叉熵对的这个梯度。就是你可以理解成第个类别的得分。就是student model,被蒸馏的模型,它所输出的…