Multi-Concept Customization of Text-to-Image Diffusion——【代码复现】

news/2024/7/10 21:45:16 标签: 图像处理, stable diffusion, 深度学习

本文是发表于CVPR 2023上的一篇论文:[2212.04488] Multi-Concept Customization of Text-to-Image Diffusion (arxiv.org)

一、引言

本文主要做的工作是对stable-diffusion的预训练模型进行微调,需要的显存相对较多,论文中测试时是在两块GPU上微调,需要30GB的显存,不过他调的batchsize=8,因为我自己的算力有限,我把复现的时候把batchsize调成了2,然后在两块3090上跑的,至于最低要求多少还没测试,不过个人认为最低也要有一张3090。

在复现前,请自行安装好Python的环境,本文就不叙述了哈哈。

二、下载相关文件及搭建环境

1.下载项目及环境搭建

adobe-research/custom-diffusion: Custom Diffusion: Multi-Concept Customization of Text-to-Image Diffusion (CVPR 2023) (github.com)

上述链接是本文代码的链接,这篇文章的代码实际上是基于Stable-diffusion构建的,所以我的建议是可以先去复现一下stable-diffusion的代码,再来学习这篇文章以及代码。stable-diffusion的复现可以看我另外一篇文章:stable-diffusion复现笔记,当然如果你想直接上手,可以按照项目中readme来构建,这里我默认已经有装过stable-diffusion了哈,因为很多文件都是相同的,如果你是直接上手,有些文件比如sd-v1-4.ckpt的下载等问题,都可以去看我这篇stable-diffusion复现笔记。

git clone https://github.com/adobe-research/custom-diffusion.git
cd custom-diffusion
git clone https://github.com/CompVis/stable-diffusion.git
cd stable-diffusion
conda env create -f environment.yaml
conda activate ldm
pip install clip-retrieval tqdm 

上述是论文给出的环境搭建代码,如果你跟我一样已经做过stable-diffusion的安装,可以直接执行最后一行 pip install clip-retrieval tqdm 。

2.下载数据集

复现的时候我用的是官方给的数据集,下载地址:https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip

三、运行

复现的过程我主要采用以生成的图像作为正则化来实现,方便起见,主要还是按照官方给的示例来复现。

1.单一概念微调——生成的图像作为正则化

第一步:这里我们可以直接执行命令文件,<pretrained-model-path>是预训练模型的路径,如:/data/disk1/sxtang/models/sd-v1-4.ckpt

bash scripts/finetune_gen.sh "cat" data/cat gen_reg/samples_cat  cat finetune_addtoken.yaml <pretrained-model-path>

这个sh文件会执行两个脚本文件:sample.py、train.py。

先执行sample.py生成用于正则化的图像,一共是200张,然后再执行train.py文件对预训练的模型进行微调,如果一切顺利,命令行最后的输入应该如下:

生成的正则化图像的目录:

 

微调所得模型目录:

 

复现过程中我所遇到的问题:

(1).我是在RTX3090上进行采样生成图片的,但是如果按照代码中默认的参数去执行,我的显存是不够的(论文毕竟是在两块A100做的),然后我的解决方法是把参数调了一下,改成:

--n_samples 5  --n_iter 40 

这里主要还是根据自己的情况去调整,如果还是爆显存的话,可以把数值都调小点,然后多执行几次sample脚本也是可以的。

(2).之前也说了,代码默认的batchsize=4,我跑不了哈哈,所以调整一下batchsize的大小。

具体的,在configs/custom-diffusion/finetune_addtoken.yaml文件中更改:

(3).TypeError: CUDACallback.on_train_epoch_end() missing 1 required positional argument: 'outputs'问题。

这里主要是pytorch-lighting的版本问题,需要把这个outputs参数删掉,具体的,在train.py文件下的on_train_epoch_end函数中:

 

(4).pytorch_lightning.utilities.exceptions.MisconfigurationException: No `test_dataloader()` method defined to run `Trainer.test`.


 这里说什么没定义这个方法,解决的方法就是在运行的时候直接加上参数--no-test即可。

第二步:更新权重

执行下面的命令即可实现,这里<folder-name> 就是你微调后的那个模型的文件夹,比如:2024-01-13T14-11-49_cat-sdv4,这一步我在执行过程中没有遇到什么问题。

## save updated model weights
python src/get_deltas.py --path logs/<folder-name> --newtoken 1

第三步:运行

## sample
python sample.py --prompt "<new1> cat playing with a ball" --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch\=000004.ckpt --ckpt <pretrained-model-path>

这个new1就是个占位符,无需更改;<folder-name>和上述的含义一样,最后这个“000004.ckpt”是你想要用的权重文件名称。 最后--ckpt <pretrained-model-path> 就是预训练的模型路径。

如果一切顺利的话,就会出图啦!

图片存放的位置以及我生成的图片如下:

 

2.多概念微调——生成的图像作为正则化

官方的readme中只给出了基于真实图像的代码,所以自己实现了一下生成图像正则化。

第一步:生成正则化图像。

上面我们已经生成的cat的正则化图像,这里还需要wooden_pot的正则化图像,所以我们需要先采样生成图像,我这里用的命令如下:

python -u sample.py \
        --n_samples 5 \
        --n_iter 40 \
        --scale 6 \
        --ddim_steps 50  \
        --ckpt  /data/disk1/sxtang/models/sd-v1-4.ckpt  \  #预训练模型的路径
        --ddim_eta 1. \
        --outdir "gen_reg/samples_wooden_pot" \   # 输出图像的路径
        --prompt "photo of a wooden_pot" 

 第二步:微调,这里我稍微改了一下那个项目中给出的基于真实图像实现的.sh文件

#!/usr/bin/env bash
#### command to run with retrieved images as regularization
# 1st arg: target caption1
# 2nd arg: path to target images1
# 3rd arg: path where retrieved images1 are saved
# 4rth arg: target caption2
# 5th arg: path to target images2
# 6th arg: path where retrieved images2 are saved
# 7th arg: name of the experiment
# 8th arg: config name
# 9th arg: pretrained model path

ARRAY=()

for i in "$@"
do
    echo $i
    ARRAY+=("${i}")
done


python -u  train.py \
        --base configs/custom-diffusion/${ARRAY[7]}  \
        -t --gpus 6,7 \
        --resume-from-checkpoint-custom  ${ARRAY[8]} \
        --caption "<new1> ${ARRAY[0]}" \
        --datapath ${ARRAY[1]} \
        --reg_datapath "${ARRAY[2]}/samples" \
        --reg_caption "${ARRAY[0]}" \
        --caption2 "<new2> ${ARRAY[3]}" \
        --datapath2 ${ARRAY[4]} \
        --reg_datapath2 "${ARRAY[5]}/samples" \
        --reg_caption2 "${ARRAY[3]}" \
        --modifier_token "<new1>+<new2>" \
        --name "${ARRAY[6]}-sdv4"

 执行命令:

bash scripts/finetune_joint_gen.sh "wooden pot" data/wooden_pot gen_reg/samples_wooden_pot \
                                    "cat" data/cat gen_reg/samples_cat  \
                                    wooden_pot+cat finetune_joint.yaml /data/disk1/sxtang/models/sd-v1-4.ckpt

注:如果需要调整如batchsize等参数,这里是在finetune_joint.yaml文件中更改。

如果一切顺利,出现如下界面,就代表着微调成功啦:

后面两步和单个概念那边一样,这里不过多叙述。

第二步:更新权重

## save updated model weights
python src/get_deltas.py --path logs/<folder-name> --newtoken 2

 第三步:运行

## sample
python sample.py --prompt "the <new2> cat sculpture in the style of a <new1> wooden pot" --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch\=000004.ckpt --ckpt <pretrained-model-path>

下面是我测试所生成的图像:

四、最后

这篇文章和Dreambooth等有着异曲同工之妙,都是为了实现个性化的图像生成,当然论文中还有比如通过diffusers实现等功能,如果感兴趣可以自己去试试。


http://www.niftyadmin.cn/n/5321634.html

相关文章

【打卡】牛客网:BM79 打家劫舍(二)

资料&#xff1a; dp.clear()会把dp的size变为0。 assign和insert的对比&#xff1a; v1.assign(v2.begin(), v2.end()); v1.insert(pos,n,elem); //在pos位置插入n个elem数据&#xff0c;无返回值。 v1.insert(pos,beg,end); //在pos位置插入[beg,end)区间的数据&#xff0…

力扣_数组30—将有序数组转换为二叉搜索数

题目 给你一个整数数组 nums &#xff0c;其中元素已经按 升序 排列&#xff0c;请你将其转换为一棵 高度平衡 二叉搜索树。&#xff08;结果不唯一&#xff09; 高度平衡 二叉树是一棵满足「每个节点的左右两个子树的高度差的绝对值不超过 1 」的二叉树。 复习 二叉搜索树…

物理机本地和集群部署Spark

一、单机本地部署 1&#xff09;官网地址&#xff1a;http://spark.apache.org/ 2&#xff09;文档查看地址&#xff1a;https://spark.apache.org/docs/3.1.3/ 3&#xff09;下载地址&#xff1a; https://spark.apache.org/downloads.html https://archive.apache.org/dist/…

2024.1.13 Kafka六大机制和Structured Streaming

目录 一 . Kafka中生产者数据分发策略 二. Kafka消费者的负载均衡机制 三 . 数据不丢失机制 生产者端是如何保证数据不丢失的呢&#xff1f; Broker端如何保证数据不丢失 消费端如何保证数据不丢失 Kafka中消费者如何对数据仅且只消费一次 四 . 启动Kafka eagle命令 数…

CSS样式学习

html超文本传输标签&#xff0c;属性等权重 outline 标签轮廓 <input type"text"> <textarea cols"30" rows"10"></textarea> outline: none; 表示无轮廓 &#xff08;开发时用的比较多&#xff09; CSS 轮廓&#xff…

git 提交符号

emojiemoji代码commit说明&#x1f3a8; (调色板):art:改进代码结构/代码格式⚡️ (闪电):zap:提升性能&#x1f40e; (赛马):racehorse:提升性能&#x1f525; (火焰):fire:移除代码或文件&#x1f41b; (bug):bug:修复 bug&#x1f691; (急救车):ambulance:重要补丁✨ (火花…

【微信小程序独立开发1】项目提出和框架搭建

前言&#xff1a;之前学习小程序开发时仿照别人的页面自己做了一个商城项目和小说项目&#xff0c;最近突发奇想&#xff0c;想从0开发一个关于《宠物日记》的小程序&#xff0c;需求和页面都由自己设计&#xff0c;将在这记录开发的全部流程和过程中遇到的难题等... 1、搭建小…

在CentOS上设置和管理静态HTTP网站的版本控制

在CentOS上设置和管理静态HTTP网站的版本控制是一项重要的任务&#xff0c;它可以帮助您跟踪和回滚对网站所做的更改&#xff0c;确保数据的一致性和完整性。以下是在CentOS上设置和管理静态HTTP网站的版本控制的步骤&#xff1a; 安装版本控制系统在CentOS上安装Git或其他版本…