生成AI(三)—创建自己的MidJorney

news2025/2/14 0:06:34

背景:MidJorney是面向互联网的图像AIGC产品,在政企内部,存在大量需求训练内部的知识作为自己的AIGC工具。基本需求是信息安全考虑,合规考虑。

目标:通过自准备的数据训练MidJorney同类模型,成为私有化部署的AIGC产品。

Diffusion原理:
在这里插入图片描述

训练:

扩散过程:从X0到XT,采用高斯函数“扩散”,无需学习算法可以直接推导。

反扩散过程:从XT到X0,根据时间t等相关进行无监督学习,得到模型参数。

推理:

通过文本向量转为输入向量,反扩散给定轮次,得到最终图片X0

技术方案:

最新的方式是文本到图像的AIGC生成,但要求的GPU配置笔者手边电脑无法达到。

本例:预训练蝴蝶数据集,预初始随机点生成新的蝴蝶图片,无需任何输入。

模型:本例采用无条件的Diffusion图片生成(有条件的例如,文本到图像、图像到图像等)

数据:Smithsonian Butterflies

时长:GPU3小时,CPU78小时

用法:直接命令生成蝴蝶图像

一、环境准备

1、安装Anaconda

查看Conda环境

conda env list 

创建新环境

conda create -n mydiffusion python=3.8
activate mydiffusion

2、安装依赖包

#如果是CPU
pip install torch
#如果是GPU,首先查看本机的GPU版本
#去官网查看pytorch与CUDA对应,https://pytorch.org/get-started/previous-versions/
#以下是11.6cuda安装pytorch例子
nvcc -V
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install diffusers[training]

#如果有提示其它包安装,看错误提示

pip install xxx

二、训练代码

train.py文件

from dataclasses import dataclass
from datasets import load_dataset
import matplotlib.pyplot as plt
from torchvision import transforms
import torch
from diffusers import UNet2DModel
import torch
from PIL import Image
from diffusers import DDPMScheduler
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDPMPipeline
import math
import os
from accelerate import Accelerator
from huggingface_hub import HfFolder, Repository, whoami
from tqdm.auto import tqdm
from pathlib import Path

from accelerate import notebook_launcher
import glob

#训练参数
@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0
#gpu or cpu  
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    device = torch.device("cuda")          
    print("There are %d GPU(s) available." % torch.cuda.device_count())
    print("We will use the GPU:", torch.cuda.get_device_name(0))
else:
    print("No GPU available, using the CPU instead.")
    device = torch.device("cpu")
config = TrainingConfig()

#加载数据集
config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")

#查看数据集
# fig, axs = plt.subplots(1, 4, figsize=(16, 4))
# for i, image in enumerate(dataset[:4]["image"]):
#     axs[i].imshow(image)
#     axs[i].set_axis_off()
# fig.show()
# plt.pause(3000)

#预处理尺寸归一化
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

#预处理函数
def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}
dataset.set_transform(transform)
#加载数据
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

#创建模型
model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
#验证形状
sample_image = dataset[0]["images"].unsqueeze(0)
print("Input shape:", sample_image.shape)
print("Output shape:", model(sample_image, timestep=0).sample.shape)
#创建执行计划
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
#创建损失函数
noise_pred = model(noisy_image, timesteps).sample
loss = F.mse_loss(noise_pred, noise)
#优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
#调度器
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)

#保存为网格
def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid

#DDPMPipeline 
def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    ).images

    # Make a grid out of the images
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")



def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
     # 初始化加速器和张量板日志记录
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        # log_with="tensorboard",
        # logging_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.push_to_hub:
            repo_name = get_full_repo_name(Path(config.output_dir).name)
            repo = Repository(config.output_dir, clone_from=repo_name)
        elif config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    # 训练模型
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]
            # 添加到图像的样本噪声
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bs = clean_images.shape[0]

            # 为每个图像采样一个随机时间步长
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
            ).long()

            # 根据每个时间步的噪声幅度给干净图像添加噪声
            # (这是前向扩散过程)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # 预测噪声残差
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # 在每个 epoch 之后,您可以选择使用 evaluate() 对一些演示图像进行采样并保存模型
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                if config.push_to_hub:
                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
                else:
                    pipeline.save_pretrained(config.output_dir)

#将训练循环、所有训练参数和进程数(您可以将此值更改为可用的 GPU 数)传递给函数以用于训练
# args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
#准备好使用 Accelerate 的notebook_launcher功能启动训练
# notebook_launcher(train_loop, args, num_processes=1)
#查看生成结果
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1])

训练代码中的TrainingConfig,比较关键的几个配置:

名称解释
图片大小决定生成图像的大小
batch_size如果在较小的GPU下训练,请适当调低该参数,比如8、4

在Terminal中运行:

python train.py

查看gpu,CMD输入以下命令

nvidia-smi

会提示请求连接关闭,多尝试几次,笔者尝试20多次才正常下载模型548M

三、推理代码

test.py

from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("ddpm-butterflies-128")
generator.to("cuda")
image = generator().images[0]
image.save("generated_image.png")

在命令行中运行:

python test.py

四、输出

每次都是反扩散1000次的结果,预计50秒生成一张图。

以下是多次运行自动生成的蝴蝶:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

五、总结

本例中仅使用蝴蝶进行训练,在数据集中生成新的蝴蝶符合预期,但部分图像有杂点。

在现实生产环境中,图像来源方方面面,我们应该注意在准备数据集的时候尽量平均不同类型的图片数量与质量,以防止模型的主观。

目前的性能离实时生成,性能上有一定的距离。

该模型源自Diffusion,2022年推广且比较新的模型,期待抛转引玉。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/645408.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【基于容器的部署、扩展和管理】3.9 云原生容器的安全性和合规性

往期回顾: 第一章:【云原生概念和技术】 第二章:【容器化应用程序设计和开发】 第三章:【3.1 容器编排系统和Kubernetes集群的构建】 第三章:【3.2 基于容器的应用程序部署和升级】 第三章:【3.3 自动…

关于Android的帧动画,补间动画,属性动画的使用和总结。(附源码)

说明&#xff1a;内容有点多&#xff0c;可以分块阅读&#xff0c;后续可能会拆分为三讲 一. Android的动画总结 一 . 帧动画 帧动画其实就是通过连续播放图片来模拟动画效果 以下是俩种实现方式&#xff1a; 1. xml文件的方式 首先在drawable下建立animation_lufi.xml <?…

视频剪辑需要学哪些软件 视频剪辑在哪里学

视频剪辑涉及到素材准备、视频的编辑与生成、格式的转换等方面&#xff0c;因此需要学习到的软件类型也不少。比如在准备素材时&#xff0c;可能会涉及到音频、图片等处理&#xff0c;以及特效的应用等。接下来&#xff0c;就让我们详细了解下视频剪辑需要学哪些软件&#xff0…

如何进行安全技术交底

安全技术交底是一项非常重要的安全管理工作&#xff0c;对于保障施工现场的安全和人员的生命安全具有不可替代的作用。那么作为公司管理层如何实时了解安全技术交底的执行情况&#xff0c;工作人员是否做到认真、安全、彻底执行&#xff1f; 有没有更好、更简便低成本的方法来做…

利好再现!股、债携手上涨将继续?

政策支持继续&#xff0c;6月13日&#xff0c;国家发改委等部门印发《关于做好2023年降成本重点工作的通知》&#xff0c;助力经济运行整体好转。当日&#xff0c;央行公开市场操作中7天逆回购中标利率也下调10个基点&#xff0c;市场对于6月降息预期越来越浓了。带动A股市场震…

Spark SQL数据源:JSON数据集

文章目录 一、读取JSON文件简介二、读取JSON文件案例演示&#xff08;一&#xff09;创建JSON文件并上传到HDFS&#xff08;二&#xff09;读取JSON文件&#xff0c;创建临时表&#xff0c;进行关联查询1、读取user.json文件&#xff0c;创建临时表t_user2、读取score.json文件…

高速视觉筛选机PCI Express实时运动控制卡XPCIE1028

产品导读 正运动技术的PCI Express总线运动控制卡XPCIE1028&#xff0c;具备位置锁存、多维高速硬件位置比较输出PSO、同步跟随、精准触发的运动控制和I/O控制功能。 配合正运动技术MotionRT7实时内核使用&#xff0c;可高度满足高速视觉筛选机应用所需的运动控制需求。 XPC…

png转jpg,直接改后缀?

通过把.png改为.jpg可以改变图片的格式么&#xff1f; 将PNG文件扩展名改为JPEG的扩展名&#xff08;.jpg或.jpeg&#xff09;不会更改图像的格式。它只是更改了文件扩展名&#xff0c;这可能导致一些图像查看器和编辑器无法正确识别和处理该文件。 PNG和JPEG是两种不同的图像文…

揭秘水文覆盖变化!使用 R 语言轻松处理 GRACE.nc 文件

一、引言 在今天越来越严重的气候变化条件下&#xff0c;水文覆盖成为了越来越多研究者重视的话题。水文覆盖指的是地表或植被表面被水覆盖的面积&#xff0c;包括河流、洼地、湖泊、蓄水池等。它反应了一个地区的水资源分布、水域利用等情况&#xff0c;对于水资源管理和自然…

centos7中docker安装单机版本及对应的分布式应用中心【亲测可用】

第一部分&#xff1a;安装docker篇 1.安装docker&#xff0c;sudo为以管理员身份运行,如当前登录为root用户&#xff0c;加上也不影响 sudo yum remove docker \ docker-client \ docker-client-latest \ docker-common \ docker-latest \ docker-latest-logrotate \ docker-…

在不安装ghostscript软件情况下,Windows中将ghostscript DLL(gsdll64.dll)库提供给python,并将资源打包进exe

1. 先安装ghostscript软件&#xff0c;将安装后的文件夹复制到项目文件夹下 2. 安装ghostscript&#xff0c;修改代码调用gsdll64.dll文件 pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple ghostscript 将ghostscript 库安装的文件夹复制到项目文件夹下&#xff…

信贷产品的贷前获客营销策略搭建

在竞争激烈的信贷市场中&#xff0c;有效的贷前获客营销策略对于吸引潜在借款人、提高转化率以及保持客户忠诚度至关重要。本文将分享一些关于信贷产品贷前获客营销策略搭建的基本框架和经验分享&#xff0c;希望能对大家有所启发。 1、市场调研和目标客户定义 在制定贷前获客…

20230614使用360安全卫士的断网急救箱解决不能上网的问题

20230614使用360安全卫士的断网急救箱解决不能上网的问题 2023/6/14 12:29 未连接到互联网 网络连接错误&#xff0c;请检查您的网络设置 刷新 无法访问此网站youtube.com 的响应时间过长。 请试试以下办法&#xff1a; 检查网络连接 检查代理服务器和防火墙 运行 Windows 网…

小程序步骤条实现

步骤条实现 <template><view class"contractInfo"><view class"contractInfo_center" style"overflow-y: auto; display: flex; overflow-y: hidden"><view class"contractInfo_center_block" v-for"(ite…

AI推文三天百万播放项目拆解

小说推文是之前操作的第一个入局的项目&#xff0c;很快就跑通了0-1&#xff0c;但是实践三个月后我决定从入⻔到放弃&#xff0c;但是大家可以借鉴一下这个项目操作经验&#xff0c;网上报了两个推文项目的陪跑199299,分享一下这个经验&#xff0c;大家可以提提意⻅。 为什么…

自动驾驶专题介绍 ———— 激光雷达标定

文章目录 介绍激光雷达与激光雷达之间的外参标定激光雷达与摄像头的标定 介绍 激光雷达在感知、定位方面发挥着重要作用。跟摄像头一样&#xff0c;激光雷达也是需要进行内外参数标定的。内参标定是指内部激光发射器坐标系与雷达自身坐标系的转换关系&#xff0c;在出厂之前就已…

管理类联考——逻辑——知识篇——第二章 模态命题(考1题)(以性质命题为基础)

第二章 模态命题&#xff08;考1题&#xff09;&#xff08;以性质命题为基础&#xff09; 一、模态命题 模态命题多指包含有“必然&#xff08;一定&#xff09;”或“可能”这两个模态词的狭义模态命题&#xff1a;必然命题或可能命题。 二、模态考点 联考中模态的考点比…

uniapp小程序中的相关设置

要让uniapp中的背景图片全屏&#xff0c;可以在<style>标签中添加以下样式&#xff1a; page { background-image: url(/static/bg.jpg); background-size: cover; background-repeat: no-repeat; background-position: center center; } 在这个样式中&…

终于让我找到支持任意经纬度生活指数查询API 了

引言 未来7天生活指数API 支持通过输入任意经纬度查询&#xff0c;提供丰富包括晨练、洗车、穿衣、感冒、运动、旅游、舒适度、紫外线、钓鱼、晾晒、过敏、啤酒等多个方面的指数&#xff0c;为用户提供了更加全面的天气信息和建议。 在本文中&#xff0c;我们将深入了解未来7…

华为OD机试真题 Java 实现【非严格递增连续数字序列】【2022Q4 100分】

一、题目描述 输入一个字符串仅包含大小写字母和数字&#xff0c;求字符串中包含的最长的非严格递增连续数字序列的长度&#xff0c;比如122889属于非严格递增连续数字序列。 二、输入描述 输入一个字符串仅包含大小写字母和数字&#xff0c;输入的字符串最大不超过255个字符…