基于深度学习的图片上色(Opencv,Pytorch,CNN)

news2024/11/11 6:08:35

文章目录

    • 1. 前言
    • 2.图像格式(RGB,HSV,Lab)
      • 2.1 RGB
      • 2.2 hsv
      • 2.3 Lab
    • 3. 生成对抗网络(GAN)
      • 3.1 生成网络(Unet)
      • 3.2 判别网络(resnet18)
    • 4. 数据集
    • 5. 模型训练与预测流程图
      • 5.1 训练流程图
      • 5.2 预测流程图
    • 6. 模型预测效果
    • 7. GUI界面制作
    • 8.代码下载

1. 前言

最近做了一个图像着色的项目,基于pytorch和opencv使用生成对抗网络对灰度图像自动上色,然后可以对上色后的图片手动调节亮度对比度等信息,最后可以保存上色后的图像,闲话少说,先看一下效果,文章最后附有全部代码及数据集下载链接。

灰度图自动上色

b站视频地址:b站视频地址

2.图像格式(RGB,HSV,Lab)

2.1 RGB

想要对灰度图片上色,首先要了解图像的格式,对于一副普通的图像通常为RGB格式的,即红、绿、蓝三个通道,可以使用opencv分离图像的三个通道,代码如下所示:

import cv2

img=cv2.imread('pic/7.jpg')
B,G,R=cv2.split(img)
cv2.imshow('img',img)
cv2.imshow('B',B)
cv2.imshow('G',G)
cv2.imshow('R',R)
cv2.waitKey(0)

代码运行结果如下所示。
在这里插入图片描述

2.2 hsv

hsv是图像的另一种格式,其中h代表图像的色调,s代表饱和度,v代表图像亮度,可以通过调节h、s、v的值来改变图像的色调、饱和度、亮度等信息。
同样可以使用opencv将图像从RGB格式转换成hsv格式。然后可以分离h、s、v三个通道并显示图像代码如下所示:

import cv2

img=cv2.imread('pic/7.jpg')
hsv=cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
h,s,v=cv2.split(hsv)
cv2.imshow('hsv',hsv)
cv2.imshow('h',h)
cv2.imshow('s',s)
cv2.imshow('v',v)
cv2.waitKey(0)

运行结果如下所示:
在这里插入图片描述

2.3 Lab

Lab是图像的另一种格式,也是本文使用的格式,其中L代表灰度图像,a、b代表颜色通道,本文使用L通道灰度图作为输入,ab两个颜色通道作为输出,训练生成对抗网络,将图像由RGB格式转换成Lab格式的代码如下所示:

import cv2

img=cv2.imread('pic/7.jpg')
Lab=cv2.cvtColor(img,cv2.COLOR_BGR2Lab)
L,a,b=cv2.split(Lab)
cv2.imshow('Lab',Lab)
cv2.imshow('L',L)
cv2.imshow('a',a)
cv2.imshow('b',b)
cv2.waitKey(0)

在这里插入图片描述

3. 生成对抗网络(GAN)

生成对抗网络主要包含两部分,分别是生成网络和判别网络。
生成网络负责生成图像,判别网络负责鉴定生成图像的好坏,二者相辅相成,相互博弈。
本文使用U-net作为生成网络,使用ResNet18作为判别网络。U-net网络的结构图如下所示:

3.1 生成网络(Unet)

在这里插入图片描述

pytorch构建unet网络的代码如下所示:

class DownsampleLayer(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(DownsampleLayer, self).__init__()
        self.Conv_BN_ReLU_2=nn.Sequential(
            nn.Conv2d(in_channels=in_ch,out_channels=out_ch,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.downsample=nn.Sequential(
            nn.Conv2d(in_channels=out_ch,out_channels=out_ch,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self,x):
        """
        :param x:
        :return: out输出到深层,out_2输入到下一层,
        """
        out=self.Conv_BN_ReLU_2(x)
        out_2=self.downsample(out)
        return out,out_2
class UpSampleLayer(nn.Module):
	def __init__(self,in_ch,out_ch):
	   # 512-1024-512
	   # 1024-512-256
	   # 512-256-128
	   # 256-128-64
	   super(UpSampleLayer, self).__init__()
   self.Conv_BN_ReLU_2 = nn.Sequential(
       nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU(),
       nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
       nn.BatchNorm2d(out_ch*2),
       nn.ReLU()
   )
   self.upsample=nn.Sequential(
       nn.ConvTranspose2d(in_channels=out_ch*2,out_channels=out_ch,kernel_size=3,stride=2,padding=1,output_padding=1),
       nn.BatchNorm2d(out_ch),
       nn.ReLU()
   )

	def forward(self,x,out):
	   '''
	   :param x: 输入卷积层
	   :param out:与上采样层进行cat
	   :return:
	   '''
	   x_out=self.Conv_BN_ReLU_2(x)
	   x_out=self.upsample(x_out)
	   cat_out=torch.cat((x_out,out),dim=1)
	   return cat_out
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        out_channels=[2**(i+6) for i in range(5)] #[64, 128, 256, 512, 1024]
        #下采样
        self.d1=DownsampleLayer(3,out_channels[0])#3-64
        self.d2=DownsampleLayer(out_channels[0],out_channels[1])#64-128
        self.d3=DownsampleLayer(out_channels[1],out_channels[2])#128-256
        self.d4=DownsampleLayer(out_channels[2],out_channels[3])#256-512
        #上采样
        self.u1=UpSampleLayer(out_channels[3],out_channels[3])#512-1024-512
        self.u2=UpSampleLayer(out_channels[4],out_channels[2])#1024-512-256
        self.u3=UpSampleLayer(out_channels[3],out_channels[1])#512-256-128
        self.u4=UpSampleLayer(out_channels[2],out_channels[0])#256-128-64
        #输出
        self.o=nn.Sequential(
            nn.Conv2d(out_channels[1],out_channels[0],kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0],3,3,1,1),
            nn.Sigmoid(),
            # BCELoss
        )
    def forward(self,x):
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)
        out=self.o(out8)
        return out


3.2 判别网络(resnet18)

resnet18的结构图如下所示:
在这里插入图片描述
在pytorch内部自带resnet18模型,只需一行代码即可构建resnet18模型,然后还需要去除网络最后的全连接层,代码如下所示:

from torchvision import models

resnet18=models.resnet18(pretrained=False)
del resnet18.fc

print(resnet18)

4. 数据集

本文使用的是自然风景类的数据图片,在网站上爬取了大概1000多张数据图片,部分图片如下所示
在这里插入图片描述

5. 模型训练与预测流程图

5.1 训练流程图

如下图所示,首先将RGB图像转换成Lab图像,然后将L通道作为生成网络输入,生成网络的输出为新的ab两通道,然后将图像原始的ab通道,与生成网络生成的ab通道输入判别网络中。
在这里插入图片描述

5.2 预测流程图

下图为模型的预测过程,在预测过程中判别网络已经没有作用了,首先将RGB图像转换成,Lab图像,接着将L灰度图输入生成网络可以得到新的ab通道图像,接着将L通道图像与生成的ab通道图像进行拼接(concate),拼接以后可以得到一张新的Lab图像,然后再将其转换成RGB格式,此时图像即为上色以后的图像。
在这里插入图片描述

6. 模型预测效果

下图为模型的预测效果。左侧的为灰度图像,中间的为原始的彩色图像,右侧的是模型上色以后的图像。整体上看,网络的上色效果还不错。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

7. GUI界面制作

为了更加方便使用模型,本文使用pyqt5制作操作界面,其界面如下图所示:首先可以从电脑中加载图像,还可以切换上一张或者下一张,可以将图像灰度化显示。可以对其上色,然后可以调整上色后图像的H、S、V信息,最后支持图像导出,可以将上色后的图像保存到本地中。
在这里插入图片描述
在这里插入图片描述

8.代码下载

链接中包含了训练代码,测试代码,以及界面代码。此外还包含1000多张数据集,直接运行main.py程序即可弹出操作界面。
代码及数据集下载链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/453573.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OSCP-Medjed(重置用户密码、mysql写webshell、可写文件替换提权)

目录 扫描 FTP WEB 提权 扫描 FTP 尝试登录到FTP服务器,该服务器位于端口30021 使用Filezilla,并能够浏览文件。那里有一些配置文件,但找不到任何值得注意的东西,不能写入目录。

成长之路---C语言笔记(构造类型之字符数组及字符串函数)

决不要停止自学,也不要忘记,不管你已经学到了多少东西,已经知道了多少东西,知识和学问是没有止境的一鲁巴金 字符数组 字符数组就是用于存放字符型数据的数组。在C语言中,字符串是作为字符数组来处理的,没有…

redis设计与实现读书笔记(2)

今天看的是关于单机数据库,RDB持久化以及AOF持久化的内容。 关于单机数据库 1.默认数据库数量 redis的服务器默认是会创建16个数据库,每个客户端访问的时候都要指定自己的目标数据库。 select可以切换目标数据库。 注意事项 到目前为止&#xff0c…

部署YUM仓库

文章目录 1. YUM仓库服务1.1 YUM概述1.1 准备安装源 2.制作YUM源2.1制作ftp源2.2 国内在线yum源2.3 本地源与在线源同时使用 3.yum软件包的下载方式4.yum的常用操作命令 1. YUM仓库服务 1.1 YUM概述 yum是一个基于RPM包(是Red-Hat Package Manager红帽软件包管理器…

【Android入门到项目实战-- 6.1】—— 如何申请用户权限

目录 一、申请权限 1、布局文件 2、MainActivity类 3、AndroidManifest文件 你在使用安卓APP时可能经历过以下场景:使用APP的拍照功能时需要你授权使用相机。那么APP是如何完成申请权限功能的? 访问:https://developer.android.google.cn/…

ERTEC200P-2 PROFINET设备完全开发手册(9-2)

9.2 运行AC1/AC4参考代码 修改源代码usrapp_cfg.h的宏为 #define EXAMPL_DEV_CONFIG_VERSION 44 编译后下载到评估板运行AC4示例程序 在TIA中导入GSDML-V2.35-Siemens-ERTEC200pApp44-20210623.xml。新建项目,添加PLC和Devkit设备。 按照如下图所示配置模块&am…

2023零基础学软考网络工程师能过吗?

网络工程师是在计算机及其相关领域中拥有一定专业技能和知识的人员,可以为企业或个人设计、建设和维护计算机网络系统。网络工程师的职业前景非常广阔,尤其是随着信息化和互联网的迅速发展,网络工程师的需求也越来越大。软考是国家计算机技术…

Java多线程- synchronized关键字总结

目录 多线程锁的概要 Synchronized关键字 synchronized加锁过程 synchronized锁优化 锁消除 锁粗化 多线程锁的概要 首先对于锁的条件和要点进行一个总结: 锁使用来保护代码片段的, 以保证多线程的安全性, 一次只允许一个线程执行被保护的代码.锁可以管理视图进入被保护代…

malloc的一些知识

这是一个叫malloc的家伙,一直勤勤恳恳帮你为所欲为的玩转系统内存。可是长路漫漫,唯malloc作伴,我却不懂它。走近malloc,多了解一下总没错。 可能对我们来讲,malloc就是void* malloc (size_t len),调用就是…

AcWing算法提高课-2.1.2城堡问题

宣传一下算法提高课整理 <— CSDN个人主页&#xff1a;更好的阅读体验 <— 题目传送门点这里 题目描述 1 2 3 4 5 6 7 #############################1 # | # | # | | ######---#####---#---#####---#2 # # | # # # # ##---#…

微信小程序的【运行机制】解读

文章目录 导语1.微信小程序的运行流程1.1 微信小程序的启动模式1.2 前台与后台的概念1.3 挂起1.4 微信小程序的销毁 微信小程序冷启动的页面从新启动策略 3.微信小程序热启动页面4. 退出状态注意点补充总结 导语 前面我们有章节给大家讲到了&#xff0c;微信小程序的生命周期钩…

Socket网络编程练习(C#)

Socket编程&#xff1a;两个窗口通信 本文章代码来自b站视频&#xff1a;【.Net零基础入门 (老赵主讲)-哔哩哔哩】 https://b23.tv/YI5VWaj 原视频发布者为传智播客&#xff0c;本人根据自己的学习进度对代码做了少许优化 一、网络编程前置知识 1.1 什么是网络编程 网络编程…

【Leetcode每日一刷】动态规划:509. 斐波那契数、322. 零钱兑换、300. 最长递增子序列

博主简介&#xff1a;努力学习的22级计科生博主主页&#xff1a; 是瑶瑶子啦所属专栏: LeetCode每日一题–进击大厂 前言&#xff1a;动规五部曲 以下是《代码随想录》作者总结的动规五部曲 确定dp数组&#xff08;dp table&#xff09;以及下标的含义确定递推公式&#xff0…

什么是异步,同步,并行,串行,单工,半双工,全双工通信

目录 1 如何理解“BUS总线” 2 通信方式的分类 2.1 串行通信Serial communication 2.1.1 异步传输Asynchronous serial communication 2.1.2 同步传输Synchronous serial communication 2.1.3 单工通信Simplex communication 2.1.4 半双工通信Half-duplex communication…

Unity API详解——Matrix4x4类

在脚本中通常用Vector3、QUaternion、Transform等类的属性和方法来对物体进行变换&#xff0c;Matrix4x4类通常在一些比较特殊的地方&#xff0c;如对摄像机的非标准投影变换等。本博客主要介绍Matrix4x4类的一些实例和静态方法。 文章目录 一、Matrix4x4类实例方法1、Multply…

机器学习实战教程(九):模型泛化

泛化能力 模型泛化是指机器学习模型对新的、未见过的数据的适应能力。在机器学习中&#xff0c;我们通常会将已有的数据集划分为训练集和测试集&#xff0c;使用训练集训练模型&#xff0c;然后使用测试集来评估模型的性能。模型在训练集上表现得好&#xff0c;并不一定能在测…

Redis源码分析(基于Redis7,对比Redis6)

PS: redis7.0.9版本的 1.Redis 源代码分类 1.1Redis 基本的数据结构 基础 Redis对象object.c字符串t_string.c列表t_list.c字典t_hash.c集合及有序集合t_set.c和z_set.c数据流 t_stream 底层实现结构 listpack.c 和 rax.c 简单动态字符串 sds.c整数集合 intset.c压缩列表 z…

【dp】最长递增子序列

文章目录 方法一&#xff1a;动态规划方法二&#xff1a;贪心 二分查找构造最长递增子序列 方法一&#xff1a;动态规划 dp[i]&#xff1a;末尾元素为arr[i]的最长子序列的长度 从0遍历到i - 1&#xff0c;若遍历到的元素小于当前值arr[i]&#xff0c;表示当前值arr[i]可以和…

国考省考行测:词句理解,词的对象指代,就近原则,主语一致法,语意语境分析上下文找出指代含义

国考省考行测&#xff1a;词句理解&#xff0c;词的对象指代&#xff0c;就近原则&#xff0c;主语一致法&#xff0c;语意语境分析上下文找出指代含义 2022找工作是学历、能力和运气的超强结合体! 公务员特招重点就是专业技能&#xff0c;附带行测和申论&#xff0c;而常规国…

3年100亿!苏宁易购与倍科达成重磅战略合作

紧抓消费复苏机遇&#xff0c;家电行业迎来重磅合作。4月20日&#xff0c;苏宁易购与国际知名家电品牌倍科在南京召开战略合作发布会&#xff0c;共同宣布升级战略合作伙伴关系。双方将围绕3年100亿战略合作目标开展独家品牌授权、发起“BIS”计划、打造生态开放平台、升级用户…