Generative AI 新世界 | 文生图领域动手实践:预训练模型的微调

news2025/1/11 3:00:07

在上期文章,我们探讨了预训练模型的部署和推理,包括运行环境准备、角色权限配置、支持的主要推理参数、图像的压缩输出、提示工程 (Prompt Engineering)、反向提示 (Negative Prompting) 等内容。

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术,观点,和项目,并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏,看到这里请一定不要匆匆划过,点这里让它成为你的技术宝库!

本期文章,我们将探讨如何在自定义数据集上来微调(fine-tuned)模型,该模型可以针对任何图像数据集进行微调。即使你手上只有几张自定义的图像提供做训练,模型也能输出比较理想的结果。

首先,让我们通过一篇论文的概括解读,来了解这种文生图模型的微调 (fine-tuned),背后的工作原理和理论基础知识。

DreamBooth 论文概述

这种文生图模型的微调(fine-tuned)理论基础来自于 DreamBooth 论文,如下图所示:

image.png

DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-DrivenGeneration

https://arxiv.org/pdf/2208.12242.pdf?trk=cndc-detail

在论文的开头,作者提出一个挑战性的问题:

虽然当时的文生图模型已经可以根据给定的 **prompt **生成高质量的图片,但是这些模型并不能模仿给定参考图片中的物体要素,在不同情景中来生成新的图片。

举个例子。

我家里有一只叫做“小花”的可爱加菲猫,如下图:

image.png

我想让加菲猫“小花”带上一顶礼帽,如下图:

image.png

或者带上一副很酷炫的墨镜,如下图:

image.png

甚至想象下她刷牙的魔幻景象,如下图:

image.png

事实上,上面的这些加菲猫“小花”的照片(戴礼帽、戴墨镜、刷牙),都是大模型使用 DreamBooth 做微调后生成的。很有趣吧?在文末会提供生成这些魔幻照片的全部代码。

我们先看下 DreamBooth 论文阐述的背后原理。

DreamBooth 论文提出一个新颖的方法:将输入图片中的物体与一个特殊标识符绑定在一起,即用这个特殊标记符来表示输入图片中的物体。因此论文为微调模型设计了一种 prompt 格式:a [identifier] [class noun],即将所有输入图片的 prompt 都设置成这种形式,其中 identifier 是一个与输入图片中物体相关联的特殊标记符,class noun 是对物体的类别描述。

这里之所以在 prompt 中加入类别,是因为想利用预训练模型中关于该类别物品的先验知识,并将先验知识与特殊标记符相关信息进行融合,这样就可以在不同场景下生成不同姿势的目标物体。

简单来说就是:不要学了新的知识,就忘了旧的知识

论文提出的方法,大致如下图所示,即仅仅通过 3 到 5 张图片去微调文生图模型,使得模型能将输入图片中特定的物品和 prompt 中的特殊标记符关联起来了。

image.png

Source: https://dreambooth.github.io\?trk=cndc-detail

关于特殊标记符的选择,论文提出通过在词表中选择罕见词来作为特殊标记符,这样避免了预训练模型对特殊标记符有很强烈的先验知识。

DreamBooth 论文提出一个新的微调方法:**通过预先生成的一些图像,来保留先验损失权重;以此来解决过拟合与语言漂移问题。**用模型自己生成的样本来监督模型,以便在 few-shot(小样本)微调开始后保留先验知识,如以下论文中提供的解释图所示:

image.png

Source: https://dreambooth.github.io/?trk=cndc-detail

给定大约 3-5 张拍摄对象的图像,我们分两步微调文本到图像的扩散:

  1. 使用输入图像与包含唯一标识符和主题所属类名称(例如:“A photo of a [T] dog”)的文本提示配对;同时,我们应用特定于类的预先保存损失,它利用了模型之前的语义通过在文本提示中注入类名,来鼓励它生成属于受试者类的各种实例提示(例如:“A photo of a dog”)。
  2. 使用从我们的输入图像集中拍摄的低分辨率和高分辨率图像,对超分辨率组件进行微调,这使我们能够保持对拍摄对象小细节的高保真度。

引入了先验损失的 loss 公式,如下所示:

image.png

通过这种 DreamBooth 方法,使得:输入训练集 + 提示词 [v] dog,然后还有用模型本身自己生成的 dog 图像,训练完成后得到了一个特殊标记符:[v]。通过这个特殊标记符 [v],就把这次训练的 dog 和其他本身学过的 dog 分开了。

最后得到惊艳的结果,比如给一只小熊带上太阳镜,如下图所示:

image.png

Source: https://dreambooth.github.io/?trk=cndc-detail

接下来,我们将完整用代码演示,如何给我家的加菲猫“小花”带上眼镜和礼帽。

Fine-tune 预训练模型在自有数据集上的微调

我们使用 Amazon SageMaker Studio 来实现在自有数据上的模型微调。

我首先将为我家的加菲猫“小花”拍摄几张照片,然后用这几张照片来微调模型;完成模型微调后,我们将使用 “a picture of Garfield cat with glasses” 这样的提示词,来直接为我家的加菲猫“小花”带上眼镜。

1 实例和环境准备

这个 Notebook 在带有 Python 3(Data Science)内核的 SageMaker Studio 中,使用 ml.t3.medium 实例上进行了测试。要对数据集的模型进行微调,您需要在账户中提供 ml.g4dn.2xlarge 实例类型。

完整的示例代码,可参考以下 GitHub 文档链接,从 “Fine-tune the pre-trained model on a custom dataset” 这一部分开始阅读代码:

https://github.com/aws/studio-lab-examples/blob/main/generative-deep-learning/stable-diffusion-finetune/Amazon_JumpStart_Text_To_Image.ipynb?trk=cndc-detail

你存放自定义照片的 s3 路径,应该看起来像这样:s3://bucket_name/input_directory/

请注意,后面的“/”为必填项。

以下是训练数据的示例格式:

input_directory
    |---instance_image_1.png
    |---instance_image_2.png
    |---instance_image_3.png
    |---instance_image_4.png
    |---instance_image_5.png
    |---dataset_info.json
    |---class_data_dir
        |---class_image_1.png
        |---class_image_2.png
        |---class_image_3.png
        |---class_image_4.png

 

预先保存、实例提示和类提示(Prior preservation, instance prompt and class prompt):预先保存是一种使用我们正在尝试训练的同一个类的其他图像的技术。例如,如果训练数据由特定狗的图像组成,并事先保存,则我们会合并普通犬的类别图像。它试图通过在为特定狗训练时显示不同狗的图像来避免过度拟合。类提示中缺少表示实例提示中存在的特定狗的标签。

例如,实例提示可能是 “A photo of a Garfield cat”,类提示可能是 “A photo of a cat”。

您可以通过将超参数设置为 _prior_preservation = True 来启用预先保存。

以下为使用我家加菲猫“小花”的照片的 dataset_info.json 的文件示例:

$ cat dataset_info.json
{
  "instance_prompt": "A photo of a Garfield cat",
  "class_prompt": "A photo of a cat"
}

 

以下是我为了微调模型,而拍摄的我家加菲猫“小花”的照片。我只用了下面这六张照片,就实现了模型的微调。

image.png

我存放照片(即为微调模型提供的自定义训练图片)的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/

其中 “sagemaker-us-east-1-xxxxxxxxxxxx” 需要更新为你自己定义的桶名。

最终完成微调后,模型存放的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output

其中 “sagemaker-us-east-1-xxxxxxxxxxxx” 需要更新为你自己定义的桶名。

2 检索训练数据的 Artifacts

在这里,我们检索训练 docker 容器、训练算法源和预先训练的基础模型。

请注意,model_version= “*” 获取的是最新的模型版本号。以下代码选择了 Stable Diffusion V2.1 Base 的文生图大模型。

# Select a model 
train_model_id, train_model_version, train_scope = (
    "model-txt2img-stabilityai-stable-diffusion-v2-1-base",
    "*",
    "training",
)

以下代码选择了微调模型的实例是 ml.g4dn.2xlarge:

training_instance_type = "ml.g4dn.2xlarge"

以下代码获取 Docker Image:

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type,
)

 

以下代码获取训练脚本:

# Retrieve the training script. This contains all the necessary files including data processing, model training etc.
train_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)

以下代码获取预训练模型的 tarball 包,用于之后的微调工作:

# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)

3 设置训练参数

现在我们已经完成了所有需要的设置,我们已经准备好微调 Stable Diffusion 模型了。首先,让我们创建一个 sageMaker.estimator.Estimator 对象。该 Estimator 将启动训练作业。

模型的微调训练需要设置两种参数。

第一组参数是训练作业的参数。其中包括:

  1. 训练数据路径,这是存储输入数据的 S3  路径。即之前我们准备的 “s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/” 这个路径;
  2.  输出路径,这是存储微调模型训练的输出 s3 路径。即之前我们准备的“s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output” 这个路径;
  3. 训练实例类型,这表示运行模型微调训练的机器类型。我们在上面定义了训练实例类型,以获取正确的 train_image_uri。

第二组参数是特定于算法的训练超参数。对于算法特定的超参数,我们首先获取算法接受的训练超参数的 python 字典及其默认值,然后可以将其改写为自定义值。示例代码如下所示:

from sagemaker import hyperparameters

# Retrieve the default hyper-parameters for fine-tuning the model
hyperparameters = hyperparameters.retrieve_default(
    model_id=train_model_id, model_version=train_model_version
)

# [Optional] Override default hyperparameters with custom values
hyperparameters["max_steps"] = "400"
print(hyperparameters)

4 启动模型微调训练

我们首先使用所有必需的 assets 创建 estimator 对象,然后启动训练作业。

from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.tuner import HyperparameterTuner

training_job_name = name_from_base(f"jumpstart-example-{train_model_id}-transfer-learning")

# Create SageMaker Estimator instance
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",  # Entry-point file in source_dir and present in train_source_uri.
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
)

# Launch a SageMaker Training job by passing s3 path of the training data
sd_estimator.fit({"training": training_dataset_s3_path}, logs=True)

模型训练开始后,如果观察 SageMaker 的控制台,会发现:

  1. 训练任务的状态,从 “InProgress” 逐渐变成 “Completed”;
  2. 超参调优的状态,从 “InProgress” 逐渐变成 “Completed”。

如下图所示:

image.png

image.png

image.png

经过那六张照片作为新的输入数据,微调后的模型重新训练完成后,就可以进入以下的模型部署阶段了。

5 微调后模型的部署

我们将遵循上一篇中介绍的模型部署的相同步骤,在训练好的模型上运行推理。我们首先检索用于部署端点的 jumpstart 工件。但是,我们部署的是经过微调的 sd_estimator 估算器,而不是上一篇中使用的 base_predictor 估算器。

inference_instance_type = "ml.g4dn.2xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

endpoint_name = name_from_base(f"jumpstart-example-FT-{train_model_id}-")

# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = sd_estimator.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)

在等待新模型部署的过程中,可以回到 SageMaker 的控制台,在 Endpoints 项中刷新检查模型部署的情况。当 Status 从 “Creating” 变成 “Completed”,就表示微调后的新模型已经部署完成可以开始进行推理了。如下图所示:

image.png

6 微调后模型的推理

下面进入激动人心的时刻,我们在微调后的模型上进行推理。

我输入的提示词是:“a photo of a Garfield cat with a hat”(一只带帽子的加菲猫)。

text = " a photo of a Garfield cat with a hat"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

模型的魔幻输出如下图所示。我们成功地给加菲猫“小花”带上礼帽了!

image.png

接着我们给加菲猫“小花”带上眼镜,我输入的提示词是:“a picture of Garfield cat with glasses”:

text = " a picture of Garfield cat with glasses"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

模型的输出如下:

image.png

最后让加菲猫“小花”像人类一样去刷牙,我输入的提示词是:“a picture of Garfield cat brushing her teeth”:

text = " a picture of Garfield cat brushing her teeth"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

image.png

神奇吧?加菲猫“小花”会自己刷牙了!

7 计算资源删除和清理

和以前一样,实验完成后别忘记清除相关的 endpoint 资源,以避免产生不必要的费用:

# Delete the SageMaker endpoint
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()

总结

本文我们学习了如何使用 Amazon SageMaker JumpStart 方便地微调文生图的 Stable Diffusion 模型。

Amazon SageMaker JumpStart 为预训练的模型提供了微调功能,本文的例子中,你只需使用六张训练图像即可根据自己的用例调整模型。这在创建个性化艺术品、独特的徽标、企业的 LOGO、或者其他需要自定义设计的场景时非常有用。

下一期的文章,我们将重新回到文本生成的大模型场景,探讨如何在 Amazon SageMaker JumpStart 上部署当今炙手可热的开源大语言模型。我们将以 Falcon 40B 开源大模型为例,逐行代码轻松部署高达 400 亿参数的这个大型语言模型。敬请期待。

请持续关注 Build On Cloud 专栏,了解更多面向开发者的技术分享和云开发动态!

 

作者 黄浩文

亚马逊云科技资深开发者布道师,专注于 AI/ML、Data Science 等。拥有 20 多年电信、移动互联网以及云计算等行业架构设计、技术及创业管理等丰富经验,曾就职于 Microsoft、Sun Microsystems、中国电信等企业,专注为游戏、电商、媒体和广告等企业客户提供 AI/ML、数据分析和企业数字化转型等解决方案咨询服务。

文章来源:https://dev.amazoncloud.cn/column/article/64cb87265306fa4a7fa3a3c9?sc_medium=regulartraffic&sc_campaign=crossplatform&sc_channel=CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1077283.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

掌握C语言:开启编程世界的大门

掌握C语言:开启编程世界的大门 C语言编写的程序通常更高效,代码行数更少,适用于需要高性能的场景。掌握C语言还为你打开了学习其他高级编程语言的大门。C语言拥有庞大的开源社区和丰富的现成代码库,为你快速开发算法和函数提供了…

Nginx配置ssl证书(https证书)

Nginx配置ssl证书(https证书) 安装nginxNginx 的 SSL 模块安装下载Nginx 服务证书配置nginx.conf 安装nginx 搭建服务器,安装docker-compose https://blog.csdn.net/qq_33240556/article/details/124789530 安装docker-compose nginx https://blog.csdn.net/qq_33240556/artic…

做运维有前途吗?

不管男生女生,都不建议做运维!!就一个原因,性价比太低!需要会的东西多,没有一个统一的运维标准!你心目中的运维和别人心目中的运维,不是一个运维!也不建议做测试&#xf…

【软件测试】博客系统项目测试报告(ssm项目)

文章目录 一. 报告概要二. 引言三. 测试环境四. 测试执行概况及功能测试1. 手工测试1.1 编写测试用例1.2 执行部分测试用例 2. 自动化测试Selenium2.1 编写测试用例2.2自动化测试代码1. 自动化测试工具类2. 博客登录页测试3. 博客注册页4. 博客详情页5. 博客编辑页6. 博客列表页…

易点易动:解决纸质固定资产审批痛点,助您高效自定义审批流程

固定资产审批是企业日常管理中不可或缺的环节,然而,传统的纸质审批流程常常面临繁琐、低效的问题。易点易动作为一款先进的固定资产管理系统,以其自定义设置流程的特点,为企业打破审批瓶颈,实现高效审批提供了理想解决…

轻量级虚拟化技术草稿

Support Tech ST.1 virtiofs ST.1.1 fuse framework 引用wiki中关于fuse的定义: Filesystem in Userspace (FUSE) is a software interface for Unix and Unix-like computer operating systems that lets non-privileged users create their own file systems w…

Python 编程基础概念

目录 1 Python程序的构成1.1 使用\行连接符 2 对象3 引用4 标识符4.1 Python标识符命名规则 5 变量和简单赋值语句5.1 变量的声明和赋值5.2 删除变量和垃圾回收机制5.3 链式赋值5.4 系列解包赋值5.5 常量 6 最基本内置数据类型和运算符6.1 基本运算符6.2 整数6.3 浮点数6.4 类型…

华为数通方向HCIP-DataCom H12-831题库(多选题:241-259)

第241题 设备产生的信息可以向多个方向输出信息,为了便于各个方向信息的输出控制,信息中心定义了10条信息通道,使通道之间独立输出,缺省情况下,以下哪些通道对应的输出方向可以接收Trap信息? A、console通道 B、logbuffer通道 C、snmpagent通道 D、trapbuffer通道 答案:…

山西电力市场日前价格预测【2023-10-11】

日前价格预测 预测说明: 如上图所示,预测明日(2023-10-11)山西电力市场全天平均日前电价为507.37元/MWh。其中,最高日前电价为873.70元/MWh,预计出现在18: 45。最低日前电价为313.23元/MWh,预计…

互联网从业者如何调节压力

互联网从业者面临着多种压力源,如工作性质、竞争、项目失败、对失败的恐惧等,这些压力会影响身心健康以及工作效率。因此,采取有效的压力调节方法是必要的,接下来我们从三个方向探讨下互联网从业者有关压力来源、应对压力的方法及…

leetcode每日一练-第977题-有序数组的平方

一、思路 双指针 二、 解题方法 i指向起始位置&#xff0c;j指向终止位置。 定义一个新数组result&#xff0c;和A数组一样的大小&#xff0c;让k指向result数组终止位置。 如果A[i] * A[i] < A[j] * A[j] 那么result[k--] A[j] * A[j]; 。 如果A[i] * A[i] > A[j…

Echarts使用感受

目录 数据处理 遇到的问题 更换echart主题 Y轴数字后添加百分比号 eCharts饼图显示百分比 echarts自定义主题的手把手教学 查看UI图 点击下方链接页面的定制主题按钮 点击下载主题 点击主题下载的JSON版本&#xff0c;点击复制 ​编辑 新建js文件&#xff0c;把复制的…

【Java】什么是API

API (Application Programming Interface,应用程序编程接口) Java中的API 指的就是 JDK 中提供的各种功能的 Java类&#xff0c;这些类将底层封装起来&#xff0c;我们不需要关心这些类是如何实现的&#xff0c;只需要学习这些类如何使用即可&#xff0c;我们可以通过帮助文档…

二、监控搭建-Prometheus-采集端部署

二、监控搭建-Prometheus-采集端部署 1、背景2、目标3、传承4、操作 1、背景 在上一篇中我们搭建了Prometheus平台&#xff0c;平台的搭建跟Linux系统上面安装了vim软件一样&#xff0c;给的只是一个很好的铸剑玄铁&#xff0c;具体的使用需要打磨和配件的运用。 2、目标 使…

XGBoost 2.0:对基于树的方法进行了重大更新

XGBoost是处理不同类型表格数据的最著名的算法&#xff0c;LightGBM 和Catboost也是为了修改他的缺陷而发布的。9月12日XGBoost发布了新的2.0版&#xff0c;本文除了介绍让XGBoost的完整历史以外&#xff0c;还将介绍新机制和更新。 这是一篇很长的文章&#xff0c;因为我们首…

转守为攻,亚马逊云换帅背后的战略转向

点击关注 文&#xff5c;刘雨琦 一则人事任命&#xff0c;揭开了亚马逊云在大中华区反击战的序幕。 10月9日&#xff0c;亚马逊云科技全球销售、市场和服务高级副总裁 Matt Garman 宣布了大中华区领导人变更任命&#xff0c;储瑞松将接替张文翊担任该职位&#xff0c;继续带领…

2023年网络安全岗位有哪些?金九银十别错过秋招!

网络安全有哪些岗位&#xff1f; 1. 安全服务工程师 7-10k 网络安全工程师、安全项目经理&#xff1a;主要负责甲方设备安全调试工作。需精通服务器、网络技术以及安全设备原理与配置。 2. 安全运维工程师 7-10k 安全运维工程师&#xff0c;主要对己方安全防御体系的运维和应急…

如何在 Spring Boot 中提高应用程序的安全性

如何在 Spring Boot 中提高应用程序的安全性 Spring Boot是一种流行的Java开发框架&#xff0c;用于构建Web应用程序和微服务。在构建应用程序时&#xff0c;安全性是至关重要的因素。不论您的应用程序是面向公众用户还是企业内部使用&#xff0c;都需要采取适当的措施来确保数…

(java)(python)以代理IP的方式进行请求数据

文章目录 前言(java)(python)以代理IP的方式进行请求数据1. python2. java 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;收藏一键三连啊&#xff0c;写作不易啊^ _ ^。   而且听说点赞的人每天的运气都不会太差&#xff0c;实在白嫖的话…

洛谷100题DAY7

31.P1636 Einstein学画画 此题为欧拉通路&#xff0c;必须要满足奇点的个数为0或2个 奇点&#xff1a;度数&#xff08;入度出度&#xff09;为奇数的点 如果奇点为2个或者0个就可以直接一笔化成 eg. 我们发现奇数点个数每增加2个就多一笔 #include<bits/stdc.h> us…