基于Pytorch深度学习图像处理基础流程框架(以ResNetGenerator为例)

news2024/9/22 3:58:52

文章目录

  • - 模型搭建
    • 1. 搭建ResNetGenerator
    • 2. 网络实例化
    • 3.加载预训练模型权重文件
    • 4. 神经网络设置为评估模式
  • 预测处理
    • 1. 定义图片的预处理方法
    • 2. 导入图片
    • 3. 预处理图片
    • 4. 调用模型
    • 5. 输出结果


- 模型搭建

1. 搭建ResNetGenerator

import torch
import torch.nn as nn

class ResNetBlock(nn.Module): # <1>

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]
        
        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x) # <2>
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> 

        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input): # <3>
        return self.model(input)

2. 网络实例化

netG = ResNetGenerator()

3.加载预训练模型权重文件

model_path = '../data/p1ch2/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

在这里插入图片描述


4. 神经网络设置为评估模式

netG.eval()

netG.eval() 是 PyTorch 中的一个方法,用于将神经网络模型设置为评估(evaluation)模式。

  1. 关闭 Dropout 和 Batch Normalization

    • 在训练过程中,Dropout 层会随机丢弃一些神经元,以防止过拟合。Batch Normalization 层会根据每个批次的数据计算均值和方差,以稳定训练过程。
    • 在评估模式下,Dropout 层会关闭,所有神经元都会参与计算。Batch Normalization 层会使用训练过程中计算的均值和方差,而不是当前批次的数据。
  2. 确保一致性

    • 在评估模式下,模型的行为会更加一致和可预测,因为不会受到随机丢弃神经元或批次数据统计特性的影响。
  3. 推理和测试

    • 在进行模型推理或测试时,应该始终将模型设置为评估模式,以确保得到准确和稳定的结果。

预测处理

1. 定义图片的预处理方法

from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.Resize((262, 461)),  # 调整图像大小
    transforms.ToTensor(),          # 转换为张量
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化
])

2. 导入图片

img = Image.open("../data/p1ch2/horse.jpg")

在这里插入图片描述

3. 预处理图片

# 确保图像有3个通道
if img.mode != 'RGB':
    img = img.convert('RGB')

img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)

4. 调用模型

out_t = (batch_out.data.squeeze() + 1.0) /2

5. 输出结果


out_t = (batch_out.data.squeeze() + 1.0) /2
out_img = transforms.ToPILImage()(out_t)
# out_img.save('../data/p1ch2/zebra.jpg')
out_img

在这里插入图片描述

【注*:该模型的作用是将图片中的马,生成为斑马】


(完)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2042257.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

go 调用C语言函数或者库

1.查看cgo是否开启 go env | grep CGO_ENABLED CGO_ENABLED1 2. go程序中加入 import "C" 通过 import “C” 语句启用 CGO 特性后&#xff0c;CGO 会将上一行代码所处注释块的内容视为 C 代码块 单行注释使用// 多行注释使用/* */ 3. go 与C 类型转换 在g…

HSL模型和HSB模型,和懒人配色的Color Hunt

色彩不仅仅是视觉上的享受&#xff0c;它在数据可视化中也扮演着关键角色。通过合理运用色彩模型&#xff0c;我们可以使数据更具可读性和解释性。在这篇文章将探讨HSL&#xff08;Hue, Saturation, Lightness&#xff09;和HSB&#xff08;Hue, Saturation, Brightness&#x…

【机器学习】深度学习实践

欢迎来到 破晓的历程的 博客 ⛺️不负时光&#xff0c;不负己✈️ 文章目录 引言一、深度学习基础二、图像分类示例三、拓展思考结语 引言 在当今人工智能的浪潮中&#xff0c;深度学习作为其核心驱动力之一&#xff0c;正以前所未有的速度改变着我们的世界。从图像识别、语音…

c语言第18天笔记

构造类型 结构体类型 结构体数组 案例&#xff1a; 需求&#xff1a;对候选人得票的统计程序。设有3个候选人&#xff0c;每次输入一个得票的候选人的名字&#xff0c;要求最后输出 各人得票结果。 ​ /** * 结构体数组案例&#xff1a;对候选人得票的统计程序。设有3个候…

主机组装笔记

参考资源&#xff1a;B站【装机教程】全网最好的装机教程&#xff0c;没有之一&#xff0c;仅供探讨学习 9大部件一览 其中得到固态和机械&#xff0c;是硬盘&#xff0c;存储空间&#xff0c;可以只选固态 CPU&#xff0c;主要有 AMD 和 Intel (AMD&#xff0c;基板的背面布…

力扣 58. 最后一个单词的长度

题目描述 思路 下意识想到先以空格作为分割符对字符串进行分割得到若干个子字符串&#xff0c;然后用字符串长度计算函数计算最后一个子字符串的长度。 该思路代码如下&#xff1a; class Solution:def lengthOfLastWord(self, s: str) -> int:s_array s.split()last_le…

全新在线客服系统源码(pc+h5+uniapp+公众号小程序+抖音)附搭建接入教程

全新在线客服系统源码介绍 一、系统概述与优势 本系统是一款基于PHP的开源在线客服系统&#xff0c;支持PC端、移动端(小程序)、H5页面以及Uniapp多端接入。系统利用网络技术和人工智能技术&#xff0c;实现用户与客服人员的即时聊天沟通&#xff0c;有效提升服务质量和用户满意…

Python+Selenium+Pytest+POM自动化测试框架封装详解

1、测试框架简介 1&#xff09;测试框架的优点 代码复用率高&#xff0c;如果不使用框架的话&#xff0c;代码会显得很冗余。可以组装日志、报告、邮件等一些高级功能。提高元素等数据的可维护性&#xff0c;元素发生变化时&#xff0c;只需要更新一下配置文件。使用更灵活的…

透明加密软件排行榜前十名(2024年10大好用的透明加密软件推荐)

在当今数字化的时代&#xff0c;数据的安全性和保密性已经成为了企业和个人最为关注的问题之一。随着信息技术的飞速发展&#xff0c;各种数据泄露事件层出不穷&#xff0c;给企业和个人带来了巨大的损失。在这样的背景下&#xff0c;透明加密软件应运而生&#xff0c;成为了保…

商家转账到零钱申请必过教程2024

在微信作为重要的营销场景的当下&#xff0c;微信支付的商家转账到零钱功能对于众多企业来说具有重要意义。要顺利开通该接口&#xff0c;需要注意以下几个要点。 首先&#xff0c;需要公司主体资质。申请主体必须是公司&#xff0c;个体工商户暂无法申请。同时&#xff0c;要确…

8.15成都市计量院面试问答

&#x1f416; Q&#xff1a;为什么要选择计量检定测试院&#xff1f; A&#xff1a;市计量院具备多项资质认定和计量认证项目&#xff0c;选择成都市计量检定测试院&#xff0c;意味着接触前沿技术&#xff0c;积累丰富经验&#xff0c;服务社会公益&#xff0c;参与创新研发&…

spring揭秘01-spring容器启动过程分析

文章目录 【README】【1】Spring容器根据配置元素组装可用系统的过程【2】BeanFactoryPostProcessor-Bean工厂后置处理器【2.1】属性占位符配置器使用场景代码【2.2】CustomerEditorConfigurer-自定义编辑器配置器【2.3】自定义编属性编辑器案例代码 【README】 本文总结自《s…

为什么electron占用空间大,而Tauri占用小,他们不都是封装Chromium吗

Electron 和 Tauri&#xff08;使用 WebView&#xff09;的确都涉及嵌入浏览器引擎来渲染 HTML、CSS 和 JavaScript&#xff0c;但它们的架构和设计有显著不同&#xff0c;这导致了它们在应用程序体积和资源占用上的差异。以下是一些关键的原因&#xff1a; 1. 嵌入的浏览器引…

【中等】 猿人学web第一届 第6题 js混淆-回溯

文章目录 请求流程请求参数 加密参数定位r() 方法z() 方法 加密参数还原JJENCOde js代码加密环境检测_n("jsencrypt")12345 计算全部中奖的总金额请求代码注意 请求流程 请求参数 打开 调试工具&#xff0c;查看数据接口 https://match.yuanrenxue.cn/api/match/6 请…

MySQL运维-分库分表

介绍 问题分析 拆分策略 垂直拆分 水平拆分 实现技术 Mycat概述 介绍 概念介绍 Mycat配置 schema.xml schema标签 schema标签&#xff08;table&#xff09; datanode标签 datahost标签 rule.xml sever.xml system标签 user标签 Mycat分片 分片规则-范围 分片规则-取模 分…

linux部署elasticserch单节点

简介 Elasticsearch概述&#xff1a;Elasticsearch是一个建立在Apache Lucene之上的分布式、实时文档存储搜索引擎&#xff0c;它能够胜任上百个服务节点的扩展&#xff0c;并支持PB级别的结构化或非结构化数据。 Lucene与Elasticsearch&#xff1a;简要介绍Lucene作为搜索引擎…

分布式中的CAP理论是什么?BASE理论是什么?看完你就彻底懂啦

CAP 理论 CAP理论是分布式系统理论中的一个重要概念&#xff0c;它阐述了在分布式计算环境中&#xff0c;一致性&#xff08;Consistency&#xff09;、可用性&#xff08;Availability&#xff09;和分区容错性&#xff08;Partition Tolerance&#xff09;三者之间的权衡关系…

深化解析:企业内耗的解决之道

在缺乏有效的冲突解决机制下&#xff0c;企业内部冲突难以得到妥善处理&#xff0c;这加剧了内耗&#xff0c;破坏了团队协作的和谐氛围。而当工作环境充满紧张和压力时&#xff0c;员工容易陷入焦虑和疲劳的困境&#xff0c;进而影响工作效率和团队士气。 当员工感受不到应有…

拯救中医 刻不容缓

文&#xff5c;琥珀食酒社 作者 | 积溪 “真是毫无底线” “这是挂羊头卖狗肉” “这钱也赚&#xff0c;华为太让人失望了” 看到网上的这些评论 我的拳头都硬了 华为进军中医药产业 给咱中医药打造独一无二的人工智能大模型 让中医彻底摆脱西方的围剿 这是多好的事情…

如何提取PDF其中的一页或多页?推荐4种方法!

工作中&#xff0c;我们经常需要用到PDF文件&#xff0c;如果需要提取PDF文件中的其中一页或多页内容&#xff0c;要如何做呢&#xff1f;下面小编分享4种方法&#xff0c;看下哪种适合你&#xff01; 方法1&#xff1a;使用复制粘贴 如果PDF文件中需要提取的内容是纯文字&am…