大模型平民化技术之LORA

1. 引言

在这篇博文中, 我将向大家介绍LoRA技术背后的核心原理以及相应的代码实现。

LoRA Low-Rank AdaptationLow-Rank Adaptors 的首字母缩写词,它提供了一种高效且轻量级的方法,用于微调预先训练好的的大语言模型。这包括 BERT RoBERTa 等掩码语言模型,以及 GPTLlamaMistral 等因果推断模型。

闲话少说,我们直接开始吧!

2. 优势分析

LoRA的主要优点之一是他们的效率。通过使用更少的参数,LoRA显著降低了模型训练过程中计算复杂性和显存使用量。这可以让我们在消费级的GPU上来训练大模型,并且可以便利地将我们训练好的LoRA权重(以兆为单位)分发给其他人。

此外,LoRA可以提升模型的泛化性。通过限制模型的复杂度,可以有助于防止在训练数据有限场景下的过拟合现象;由于LoRA至少保留了初始模型的能力,在处理一些新的,未见过的数据时更具有弹性。

最后,LoRA可以无缝地集成到现有的神经网络架构中。这种集成允许以最小的额外训练成本对预训练模型进行微调和调整,使其非常适合迁移学习应用。

3. 工作原理

LoRA 的基本思想是保持预训练矩阵(即原始模型的参数)冻结(即处于固定状态),并且只在原始矩阵中添加一个小的增量,其参数量比原始矩阵少很多。

例如,考虑矩阵 W,它可以是全连接层的参数,也可以是来Transformer中计算自注意力机制的矩阵之一:
在这里插入图片描述

显然,如果 Worig 的维数为 n×m,而假如我们只是初始化一个具有相同维数的新的增量矩阵进行微调,虽然我们也实现类似的功能,但是我们的参数量将会加倍。 LoRA使用的Trick就是通过训练低维矩阵 B A ,通过矩阵乘法来构造 ΔW ,来使 ΔW 的参数量低于原始矩阵。
在这里插入图片描述
这里我们不妨定义秩 r,它明显小于基本矩阵维度 r≪nr≪m。则矩阵 B n×r,矩阵 A r×m。将它们相乘会得到一个维度为 nxmW 矩阵,但构建的参数量减小了很多。

此外,我们希望我们的增量ΔW在训练开始时为零,这样微调就会从原始模型一样开始。因此,B 通常初始化为全零,而 A 初始化为随机值(通常呈正态分布)。

4. 举个栗子

我们不妨来看个直观的栗子,如下图所示:
在这里插入图片描述
想象一下,我们的基本维数是 1024,我们选择了 LoRA 的秩r为 4,则对于上述过程:
● 权重W的参数量为1024X1024≈1M
AB的参数量一致,均为rX1024≈4K,这样二者之和为8K
● 这样使用LoRA技术,在上述例子中我们仅仅需要训练0.8%的参数就可以更新我们的参数矩阵

5. LoRA指令速查

主要可查阅微软的官方文档Github: 戳我 , 由于封装的很好,目前该库页整合至 HuggingFace Parameter-Efficient Fine-Tuning (PEFT) 。

  • 如果模型要将特定层替换成 LoRA,需要调整模型的结构,但调用很简单:
# ===== Before =====
# layer = nn.Linear(in_features, out_features)
# ===== After ======
import loralib as lora
# Add a pair of low-rank adaptation matrices with rank r=16
layer = lora.Linear(in_features, out_features, r=16)
  • 在训练之前要把原本的LLM模型 Freeze 住,并且设定只有 LoRA 的参数是可训练的
import loralib as lora
model = BigModel()
# This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)
# Training loop
for batch in dataloader:
  • 保存模型时也可以只储存LoRA所训练的权重,这特性将方便大家分享自己的权重
# ===== Before =====
# torch.save(model.state_dict(), checkpoint_path)
# ===== After =====
torch.save(lora.lora_state_dict(model), checkpoint_path)
  • 推理时读取 LoRA 或是原本 LLM 的权重时,要将 strict 设定为 False
# Load the pretrained checkpoint first
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
# Then load the LoRA checkpoint
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)

6. SD-LoRA应用

近年来生成式 AI DALLE 再到 Stable-diffusion,都显示了现在的 AI 可以生成高质量以及高分辨率的图片,但是让人诟病的还是需要大量的运算资源才能够训练得了这种高分辨率的模型,因为要训练一个高分辨率的扩散模型是需要相当多内存的,即便 Stable-diffusion 将原本的 Pixel-level Diffusion 变成 Latent Diffusion Model 已经大幅降低训练的内存,但仍无法在单一张 11 GB 的 GPU 上训练,但现在不一样了,有人将 LoRA 技术整合到 Stable-diffusion,推出了 Stable Diffusion LoRA
在这里插入图片描述
整合 LoRAStable-diffusion 直接带来了以下的好处:

  • 训练快很多
  • 可在 11GB 显卡上直接进行训练
  • LoRA 权重的保存只有 3MB~200MB,易于分享

7. SD-LoRA更多资源

LoRA这项技术上的突破也使得Stable Diffusion 的社区多了许多生成模型,甚至可将模型上传至网站 CivitAI,可以看到上面有许多模型是使用 LoRA 进行训练的:

在这里插入图片描述
当然,网络上也有许多资源是使用 Colab 或是在个人 PC 上面生成/训练模型,最近 Stable diffusion 的社群已经开源相当多项目,并提供 GUI 界面,甚至不需要懂程序代码就可以训练好生成式 AI。

stable-diffusion-webui-colab

Kohya’s GUI, Support Windows

如果只是想来看看 Stable-diffusion 的人建议使用 WebUI,不仅能使用官方释出的模型,也可以直接登陆到 CivitAI,直接下载别人的训练好的生成模型:

在这里插入图片描述
在这里插入图片描述

8. 总结

共享大型的LLM模型是未来的趋势,如果要适应到某个具体任务上,只要训练LoRA模组即可,而这项技术也带来方便的替换性,未来大家只要分享LoRA的模型权重,就可以快速切换至不同的任务。

此外,LoRA通过大量降低训练参数,来大幅降低了硬体的训练门槛,并且与完全 Fine-tuning 的模型相比,推论速度的增加是相当少的。

9. 附录

本文重点参考链接如下:

[1] LoRA论文:戳我
[2] LoRA tutorial:戳我
[3] PEFT tutorial: 戳我
[4] Stable Diffusion Webui: 戳我
[5] ai-drawing: 戳我


http://www.niftyadmin.cn/n/5390017.html

相关文章

利用R语言进行因子分析实战(数据+代码+可视化+详细分析)

🍉CSDN小墨&晓末:https://blog.csdn.net/jd1813346972 个人介绍: 研一|统计学|干货分享          擅长Python、Matlab、R等主流编程软件          累计十余项国家级比赛奖项,参与研究经费10w、40w级横向 文…

Go 中如何高效遍历目录?探索几种方法

嗨,大家好!我是波罗学。本文是系列文章 Go 技巧第十八篇,系列文章查看:Go 语言技巧。 目录遍历是一个很常见的操作,它的使用场景有如文件目录查看(最典型的应用如 ls 命令)、文件系统清理、日志…

【力扣经典面试题】238. 除自身以外数组的乘积

目录 一、题目描述 二、题解分析 思路: 算法步骤: 代码(C版): 三、总结 一、题目描述 给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证…

STM32自学☞输入捕获测频率和占空比案例

本文是通过PA0口输出PWM波,然后通过PA6口捕获PWM波的频率和占空比,最终在oled屏上显示我们自己设置的频率和占空比。由于和前面的pwm呼吸灯代码有重合部分所以本文中的代码由前者修改而来,对于文件命名不要在意。 pwm_led.c文件 /* 编写步…

分布式ID生成方案详解

目录 引言 一. UUID(Universally Unique Identifier) UUID版本 版本1 UUID 版本4 UUID UUID用途 二、数据库自增ID 三. 基于Redis的方案 四. Twitter的snowflake算法 五、百度UidGenerator 结语 引言 在分布式系统中,生成唯一标识…

Nginx跳转模块之rewrite

一.location与rewrite模块的区别 rewrite:对访问的域名或者域名内的URL路径地址重写 location:对访问的路径做访问控制或者代理转发 二.rewrite模块基本内容 1.功能 通过正则表达式的匹配来改变URI,可以同时存在一个或多个指令&#xff0c…

第三百六十六回

文章目录 1. 概念介绍2. 使用方法2.1 List2.2 Map2.3 Set 3. 示例代码4. 内容总结 我们在上一章回中介绍了"convert包"相关的内容,本章回中将介绍collection.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 我们在本章回中介绍的内容是col…

蜂窝物联网咖WiFi认证解决方案

项目背景 随着目前网咖模式越来越流行,给网吧部署一套无缝漫游的WIFI网络势在必行。同时,网吧无线准入的验证码在客户机上面进行更新,以防周边的人员进行蹭网,损失网吧的外网带宽。 01 需求分析 1. 网吧服务区域全部覆盖无盲区…