全国免费咨询:

13245491521

VR图标白色 VR图标黑色
X

中高端软件定制开发服务商

与我们取得联系

13245491521     13245491521

2024-01-14_LoRA原理与实现--PyTorch自己搭建LoRA模型

您的位置:首页 >> 新闻 >> 行业资讯

LoRA原理与实现--PyTorch自己搭建LoRA模型 一、前言在AIGC领域频繁出现着一个特殊名词“LoRA”,听上去有点像人名,但是这是一种模型训练的方法。LoRA全称Low-Rank Adaptation of Large Language Models,中文叫做大语言模型的低阶适应。如今在stable diffusion中用地非常频繁。 由于大语言模型的参数量巨大,许多大公司都需要训练数月,由此提出了各种资源消耗较小的训练方法,LoRA就是其中一种。 本文将详细介绍LoRA的原理,并使用PyTorch实现小模型的LoRA训练。 二、模型训练现在大多数模型训练都是采用梯度下降算法。梯度下降算法可以分为下面4个步骤: 正向传播计算损失值反向传播计算梯度利用梯度更新参数重复1、2、3的步骤,直到获取较小的损失以线性模型为例,模型参数为W,输入输出为x、y,损失函数以均方误差为例。那么各个步骤的计算如下,首先是正向传播,对于线性模型来说就是做一个矩阵乘法: $在求出损失后,可以计算L对W的梯度,得到dW: $dW是一个矩阵,它会指向L上升最快的方向,但是我们的目的是让L下降,因此让W减去dW。为了调整更新的步伐,还会乘上一个学习率η,计算如下: $最后一直重复即刻。上述三个步骤的伪代码如下: #4、重复1、2、3 foriinrange(10000): #1、正向传播计算损失 L=MSE(Wx,y) #2、反向传播计算梯度 dW=gradient(L,W) #3、利用梯度更新参数 W-=lr*dW 在更新完成后,得到新的参数W'。此时我们使用模型预测时,计算如下: $三、引入LoRA我们可以来思考一下W和W'之间的关系。W通常指基础模型的参数,而W'是在基础模型的基础上,经过几次矩阵加减得到的。假设在训练的过程中更新了10次,每次的dW分别为dW1、dW2、....、dW10,那么完整的更新过程可以写为一次运算: $其中dW是一个形状与W'一致的矩阵。我们把-ηdW写成矩阵R,那么更新后的参数就是: $此时训练的过程就被简化为原矩阵加上另一个矩阵R。但是求解矩阵R并没有更简单,而且也没有节约资源,此时就引出LoRA了这一思想。 一个训练充分的矩阵,通常是满秩或者基本满足秩的,即矩阵中没有一列是多余的。在论文《Scaling Laws for Neural Language Model》中提出了数据集与参数大小之间的关系,满足该关系且训练良好,得到的模型是基本满秩的。在微调模型时,我们会选取一个底模,该底模就是基本满秩的。而更新矩阵R秩的情况是如何的呢? 我们假定R矩阵是一个低秩矩阵,低秩矩阵有许多重复的列,因此可以分解为两个更小的矩阵。假如W的形状为m×n,那么A的形状也是m×n,我们把矩阵R分解为AB(其中A形状为m×r,B形状为r×N),r通常会选取一个远小于m、n的值,如图所示: image.png将低秩矩阵分解为两个矩阵几点好处,首先是参数量明显减少。假设R矩阵的形状为100×100,那么R的参数量为10000。当我们选取秩为10时,此时矩阵A的形状为100×10,矩阵B的形状为10×100,此时参数量为2000,比R矩阵少了80%。 而且由于R是低秩矩阵,所以在训练充分的情况下,A和B矩阵可以达到R的效果。这里的矩阵AB就是我们常说的LoRA模型。 在引入LoRA后,我们的预测需要将x分别输入W和AB,此时预测的计算为: $在预测时会比原始模型稍慢,但是在大模型中基本感觉不到差异。 四、实战为了把握各个细节,这里不使用大模型作为lora的实战,而是选择使用vgg19这种小型网络来训练lora模型。导入需要用到的模块: importos importtorch fromtorchimportoptim,nn fromPILimportImage fromtorch.utilsimportdata fromtorchvisionimportmodels fromtorchvision.transformsimporttransforms 4.1 数据集准备 这里使用vgg19在imagenet上的预训练权重作为底模,因此需要准备分类数据集。为了方便,这里只准备了一个类别,且只准备了5张图片,图片在项目下的data/goldfish下: image.png在imagenet中包含了goldfish类别,但是这里选取的是插画版的goldfish,经过测试,预训练模型不能将上述图片正确分类。我们的目的就是训练LoRA,让模型正确分类。 我们创建一个LoraDataset: transform=transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]), ]) classLoraDataset(data.Dataset): def__init__(self,data_path="datas"): categories=models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"] self.files=[] self.labels=[] fordirinos.listdir(data_path): dirname=os.path.join(data_path,dir) forfileinos.listdir(dirname): self.files.append(os.path.join(dirname,file)) self.labels.append(categories.index(dir)) def__getitem__(self,item): image=Image.open(self.files[item]).convert("RGB") label=torch.zeros(1000,dtype=torch.float64) label[self.labels[item]]=1. returntransform(image),label def__len__(self): returnlen(self.files) 4.2 创建LoRA模型 我们把LoRA封装成一个层,LoRA中只有两个需要训练的矩阵,LoRA的代码如下: classLora(nn.Module): def__init__(self,m,n,rank=10): super().__init__() self.m=m self.A=nn.Parameter(torch.randn(m,rank)) self.B=nn.Parameter(torch.zeros(rank,n)) defforward(self,inputs): inputs=inputs.view(-1,self.m) returntorch.mm(torch.mm(inputs,self.A),self.B) 其中m是输入的大小,n是输出的大小,rank是秩的大小,我们可以设置一个较小的值。 在权重初始化时,我们把A用高斯噪声初始化,而B用0矩阵初始化,这样的目的是保证从底模开始训练。因为AB是0矩阵,所以初始状态下,LoRA不起作用。 4.3 设置超参数并训练 接下来就是训练了,这里和PyTorch常规训练代码基本一致,先看代码: #加载底模和lora vgg19=models.vgg19(models.VGG19_Weights.IMAGENET1K_V1) forparamsinvgg19.parameters(): params.requires_grad=False vgg19.eval() lora=Lora(224*224*3,1000) #加载数据 lora_loader=data.DataLoader(LoraDataset(),batch_size=batch_size,shuffle=True) #加载优化器 optimizer=optim.Adam(lora.parameters(),lr=lr) #定义损失 loss_fn=nn.CrossEntropyLoss() #训练 forepochinrange(epochs): forimage,labelinlora_loader: #正向传播 pred=vgg19(image)+lora(image) loss=loss_fn(pred,label) #反向传播 loss.backward() #更新参数 optimizer.step() optimizer.zero_grad() print(f"loss:{loss.item()}") 这里有两点需要注意,第一点是我们把vgg19的权重设置为不可训练,这和迁移学习很像,但其实是不一样的。 第二点则是正向传播时,我们使用了下面代码: pred=vgg19(image)+lora(image) 4.4 测试 下面来简单测试一下: #测试 forimage,_inlora_loader: pred=vgg19(image)+lora(image) idx=torch.argmax(pred,dim=1).item() category=models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"][idx] print(category) torch.save(lora.state_dict(),'lora.pth') 输出结果如下: goldfish goldfish goldfish goldfish goldfish 基本预测正确了,不过这个测试结果并不能说明什么。最后我们保存了一个5M的LoRA模型,相比vgg19的几十M算是非常小了。 五、总结LoRA是针对大模型的一种高效的训练方法,而本文则将LoRA使用在小型的分类网络中,旨在让读者更清晰认识LoRA的详细实现(同时也因为跑不动大模型)。限于数据量,对LoRA的精度效率等问题没有详细讨论,读者可以参考相关资料深入了解。 阅读原文

上一篇:2023-09-13_如何禁止别人调试自己的前端页面代码? | 文末福利 下一篇:2019-09-03_云服务提供者应当采取何种「必要措施」? ——从首例云服务器侵权案谈起

TAG标签:

20
网站开发网络凭借多年的网站建设经验,坚持以“帮助中小企业实现网络营销化”为宗旨,累计为4000多家客户提供品质建站服务,得到了客户的一致好评。如果您有网站建设网站改版域名注册主机空间手机网站建设网站备案等方面的需求...
请立即点击咨询我们或拨打咨询热线:13245491521 13245491521 ,我们会详细为你一一解答你心中的疑难。
项目经理在线

相关阅读 更多>>

猜您喜欢更多>>

我们已经准备好了,你呢?
2022我们与您携手共赢,为您的企业营销保驾护航!

不达标就退款

高性价比建站

免费网站代备案

1对1原创设计服务

7×24小时售后支持

 

全国免费咨询:

13245491521

业务咨询:13245491521 / 13245491521

节假值班:13245491521()

联系地址:

Copyright © 2019-2025      ICP备案:沪ICP备19027192号-6 法律顾问:律师XXX支持

在线
客服

技术在线服务时间:9:00-20:00

在网站开发,您对接的直接是技术员,而非客服传话!

电话
咨询

13245491521
7*24小时客服热线

13245491521
项目经理手机

微信
咨询

加微信获取报价