PyTorch深度学习实战——图像着色

news2024/11/26 21:23:33

PyTorch深度学习实战——图像着色

    • 0. 前言
    • 1. 模型与数据集分析
      • 1.1 数据集介绍
      • 1.2 模型策略
    • 2. 实现图像着色
    • 相关链接

0. 前言

图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完成图像着色操作,不但耗时且需要专业知识,而基于深度学习的方法能够实现自动着色,极大的提高了效率。在训练图着色模型时,我们可以将原始图像转换为黑白图像作为网络输入,原始彩色图像作为输出。

1. 模型与数据集分析

在本节中,我们将利用 CIFAR-10 数据集执行图像着色。

1.1 数据集介绍

CIFAR-10 数据集是一个广泛应用于计算机视觉领域的图像分类数据集。它由 10 个不同类别的彩色图像组成,每个类别包含 600032 x 32 像素的图像。该数据集涵盖了各种不同的对象类别,包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。与一些只包含灰度图像的数据集相比,CIFAR-10 数据集的图像是彩色的,但由于图像分辨率相对较低,图像中的细节和特征相对较少。
CIFAR-10 数据集在计算机视觉领域的研究和开发中得到了广泛的应用,许多图像分类算法和深度学习模型都在 CIFAR-10 上进行了测试和验证。它提供了一个标准化的基准,用于比较不同算法的性能。

1.2 模型策略

了解了所用数据集后,本节中,我们继续介绍图像着色模型策略:

  1. 获取训练数据集中的原始彩色图像,将其转换为灰度图像,构造输入(灰度)-输出(原始彩色图像)对
  2. 执行归一化输入和输出图像
  3. 构建 U-Net 架构
  4. 训练模型

2. 实现图像着色

接下来,使用 PyTorch 实现以上策略,构建图像着色模型。

(1) 导入所需库:

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
import numpy as np
import torchvision
from matplotlib import pyplot as plt

(2) 下载数据集,并定义训练、验证数据集和数据加载器。

下载数据集:

data_folder = 'cifar10/cifar/' 
datasets.CIFAR10(data_folder, download=True)

定义训练、验证数据集和数据加载器:

class Colorize(torchvision.datasets.CIFAR10):
    def __init__(self, root, train):
        super().__init__(root, train)
        
    def __getitem__(self, ix):
        im, _ = super().__getitem__(ix)
        bw = im.convert('L').convert('RGB')
        bw, im = np.array(bw)/255., np.array(im)/255.
        bw, im = [torch.tensor(i).permute(2,0,1).to(device).float() for i in [bw,im]]
        return bw, im

trn_ds = Colorize('cifar10/cifar/', train=True)
val_ds = Colorize('cifar10/cifar/', train=False)

trn_dl = DataLoader(trn_ds, batch_size=256, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=256, shuffle=False)

输入和输出图像的样本如下:

a,b = trn_ds[0]
plt.subplot(121)
plt.imshow(a.permute(1,2,0).cpu(), cmap='gray')
plt.subplot(122)
plt.imshow(b.permute(1,2,0).cpu())
plt.show()

样本示例
(3) 定义网络架构:

class Identity(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x

class DownConv(nn.Module):
    def __init__(self, ni, no, maxpool=True):
        super().__init__()
        self.model = nn.Sequential(
            nn.MaxPool2d(2) if maxpool else Identity(),
            nn.Conv2d(ni, no, 3, padding=1),
            nn.BatchNorm2d(no),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(no, no, 3, padding=1),
            nn.BatchNorm2d(no),
            nn.LeakyReLU(0.2, inplace=True),
        )
    def forward(self, x):
        return self.model(x)

class UpConv(nn.Module):
    def __init__(self, ni, no, maxpool=True):
        super().__init__()
        self.convtranspose = nn.ConvTranspose2d(ni, no, 2, stride=2)
        self.convlayers = nn.Sequential(
            nn.Conv2d(no+no, no, 3, padding=1),
            nn.BatchNorm2d(no),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(no, no, 3, padding=1),
            nn.BatchNorm2d(no),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
    def forward(self, x, y):
        x = self.convtranspose(x)
        x = torch.cat([x,y], axis=1)
        x = self.convlayers(x)
        return x

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.d1 = DownConv( 3, 64, maxpool=False)
        self.d2 = DownConv( 64, 128)
        self.d3 = DownConv( 128, 256)
        self.d4 = DownConv( 256, 512)
        self.d5 = DownConv( 512, 1024)
        self.u5 = UpConv (1024, 512)
        self.u4 = UpConv ( 512, 256)
        self.u3 = UpConv ( 256, 128)
        self.u2 = UpConv ( 128, 64)
        self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride=1)

    def forward(self, x):
        x0 = self.d1( x) # 32
        x1 = self.d2(x0) # 16
        x2 = self.d3(x1) # 8
        x3 = self.d4(x2) # 4
        x4 = self.d5(x3) # 2
        X4 = self.u5(x4, x3)# 4
        X3 = self.u4(X4, x2)# 8
        X2 = self.u3(X3, x1)# 16
        X1 = self.u2(X2, x0)# 32
        X0 = self.u1(X1) # 3
        return X0

(4) 定义模型、优化器和损失函数:

def get_model():
    model = UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    return model, optimizer, loss_fn

(5) 定义模型在批数据进行训练和验证的函数:

def train_batch(model, data, optimizer, criterion):
    model.train()
    x, y = data
    _y = model(x)
    optimizer.zero_grad()
    loss = criterion(_y, y)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    model.eval()
    x, y = data
    _y = model(x)
    loss = criterion(_y, y)
    return loss.item()

(6) 训练模型:

model, optimizer, criterion = get_model()
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

_val_dl = DataLoader(val_ds, batch_size=1, shuffle=True)

n_epochs = 100
train_loss_epochs = []
val_loss_epochs = []

for ex in range(n_epochs):
    N = len(trn_dl)
    trn_loss = []
    val_loss = []
    for bx, data in enumerate(trn_dl):
        loss = train_batch(model, data, optimizer, criterion)
        pos = (ex + (bx+1)/N)
        trn_loss.append(loss)
    train_loss_epochs.append(np.average(trn_loss))

    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss = validate_batch(model, data, criterion)
        pos = (ex + (bx+1)/N)
        val_loss.append(loss)
    val_loss_epochs.append(np.average(val_loss))
        
    exp_lr_scheduler.step()
    if (ex+1)%10 == 0:
        for _ in range(5):
            a,b = next(iter(_val_dl))
            _b = model(a)
            plt.subplot(131)
            plt.imshow(a[0].permute(1,2,0).cpu(), cmap='gray')
            plt.subplot(132)
            plt.imshow(b[0].permute(1,2,0).cpu())
            plt.subplot(133)
            plt.imshow(_b[0].permute(1,2,0).detach().cpu().numpy())
            plt.show()
epochs = np.arange(n_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

着色结果

从前面的输出中,可以看到模型能够很好地为灰度图像着色。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1175668.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

提升你的C#技能:掌握PrintDocument实现打印操作的秘诀

前言: 我们用C#在开发应用的时候,经常需要打印操作,比如你需要打印某些记录,或者是某些图像都需要用到打印的操作,比如我需要打印报警记录,按照指定的格式打印出来,我需要PrintDocument类&…

项目管理之如何识别并应对项目风险

项目风险管理是项目管理中不可忽视的环节,如何识别并应对项目的风险对于项目的成功实施至关重要。本文将介绍风险管理的流程、风险分解结构、定性及定量风险评估方法,以及消极和积极的风险应对策略,旨在帮助读者更好地理解和应对项目风险。 …

(1)(1.12) LeddarTech LeddarVu8

文章目录 前言 1 连接到自动驾驶仪 2 参数说明 前言 LeddarTech LeddarVu8 是一款长距离(185m)激光雷达,可在 16 度至 99 度视场范围内提供 8 个单独的距离,具体取决于所使用的型号。ArduPilot 始终使用所提供的 8 个距离中最…

VSCode设置中文语言界面(VScode设置其他语言界面)

一、下载中文插件 二、修改配置 1、使用快捷键 CtrlShiftP 显示出搜索框 2、然后输入 configure display language 3、点击 (中文简体) 需要修改的语言配置 三、重启 四、可能出现的问题 1、如果configure display language已经是中文配置,界面仍是英文 解决&a…

优化C++资源利用:探索高效内存管理技巧

W...Y的主页 😊 代码仓库分享💕 🍔前言: 我们之前在C语言中学习过动态内存开辟,使用malloc、calloc与realloc进行开辟,使用free进行堆上内存的释放。进入C后对于动态内存开辟我们又有了新的内容new与dele…

【C++】一篇文章搞懂auto关键字及其相关用法!

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃个人主页 :阿然成长日记 …

C语言基础篇1:数据类型、常量、变量

1 C语言基础 1.1 关键字 在C语言中,关键字是指被赋予特定意义的一些单词,不能把这些单词作为标识符来使用.C语言一共有32个关键字,如下图。在后面的学习中会逐渐接触到这些关键字的具体使用用法。 1.2 标识符 标识符可以简单的理解为一个名字…

第四届辽宁省大学生程序设计竞赛(正式赛)(12/13)

AC情况 赛中通过赛后通过暂未通过A√B√C√D○E○F√G√H√I○J√K—L√M√ 整体体验 easy:ABFHL mid:MJGC hard:IDKE 心得 感觉出了一堆典题,少数题还有些意思,E题确实神仙 题解 A. 欢迎来到辽宁省赛&#x…

bff层解决了什么痛点

bff层 -- 服务于前端的后端 什么是bff? Backend For Frontend(服务于前端的后端),也就是服务器设计API的时候会考虑前端的使用,并在服务端直接进行业务逻辑的处理,又称为用户体验适配器。BFF只是一种逻辑…

基于爬行动物算法的无人机航迹规划-附代码

基于爬行动物算法的无人机航迹规划 文章目录 基于爬行动物算法的无人机航迹规划1.爬行动物搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要:本文主要介绍利用爬行动物算法来优化无人机航迹规划。 …

Python|OpenCV-图像的添加和混合操作(8)

前言 本文是该专栏的第8篇,后面将持续分享OpenCV计算机视觉的干货知识,记得关注。 在使用OpenCV库对图像操作的时候,有时需要对图像进行运算操作,类似于加法,减法,位操作等处理。而本文,笔者将针对OpenCV对图像的添加,混合以及位操作进行详细的介绍说明和使用。 下面,…

03、SpringBoot + 微信支付 ---- 创建订单、保存二维码url、显示订单列表

目录 Native 下单1、创建课程订单保存到数据库1-1:需求:1-2:代码:1-3:测试结果: 2、保存支付二维码的url2-1:需求:2-2:代码:2-3:测试:…

python 之 sorted 函数

文章目录 sorted() 函数的语法返回值使用示例:示例 1:基本使用示例 2:指定降序排序示例 3:使用 key 参数进行自定义排序 注意事项: sorted() 是 Python 中的一个内置函数,用于对可迭代对象进行排序&#xf…

jquery之checkbox全选反选提交参数

实现效果 <!DOCTYPE html> <html> <head><meta charset"UTF-8"><title>Checkbox操作示例</title><script src"https://code.jquery.com/jquery-3.5.1.min.js"></script><script>$(document).ready(…

FPGA高端项目:图像缩放+GTP+UDP架构,高速接口以太网视频传输,提供2套工程源码加QT上位机源码和技术支持

目录 1、前言免责声明本项目特点 2、相关方案推荐我这里已有的 GT 高速接口解决方案我这里已有的以太网方案我这里已有的图像处理方案 3、设计思路框架设计框图视频源选择ADV7611 解码芯片配置及采集动态彩条跨时钟FIFO图像缩放模块详解设计框图代码框图2种插值算法的整合与选择…

【数据结构与算法】JavaScript实现哈希表

文章目录 一、哈希表简介1.1.认识哈希表1.2.哈希化的方式1.3.解决冲突的方法1.4.寻找空白单元格的方式线性探测二次探测再哈希化 1.5.不同探测方式性能的比较1.6.优秀的哈希函数快速计算均匀分布 二、初步封装哈希表2.1.哈希函数的简单实现2.2.创建哈希表2.3.put(key,value)2.4…

时间序列预测模型实战案例(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

论文地址->TPA-LSTM论文地址 项目地址-> TPA-LSTM时间序列预测实战案例 本文介绍 本文通过实战案例讲解TPA-LSTM实现多元时间序列预测&#xff0c;在本文中所提到的TPA和LSTM分别是注意力机制和深度学习模型,通过将其结合到一起实现时间序列的预测&#xff0c;本文利用…

Google发布移动终端对象检测模型——mediapipe,无GPU依然飞快

对象检测模型最出名的当选YOLO系列,其YOLO系列已经更新到V8系列,但是现有的YOLO模型面临限制,如量化支持不足和准确性延迟权衡不足。 YOLO-NAS模型在包括COCO、Objects365和Roboflow 100在内的知名数据集上进行了预训练,使其非常适合生产环境中的下游对象检测任务。YOLO-NA…

unity【动画】脚本_角色动画控制器 c#

首先创建一个代码文件夹Scripts 从人物角色Player的基类开始 创建IPlayer类 首先我们考虑到如果不挂载MonoBehaviour需要将角色设置成预制体实例化到场景上十分麻烦&#xff0c; 所以我们采用继承MonoBehaviour类的角色基类方法写代码 也就是说这个脚本直接绑定在角色物体…

Quartz之JDBC-JobStoreTX配置

一、前言 上篇 《Quartz介绍》中使用的是RAMJobStored存储调度信息&#xff0c;当进程终止调度信息会丢失&#xff0c;本篇我们介绍使用JDBCJobStore来存储调度信息&#xff08;jobs、Triggers和日历&#xff09;。 二、Quartz 表结构 可以从官网&#xff08;http://www.qua…