《pytorch深度学习实战》学习笔记第2章

news2025/1/11 14:23:38

第2章 预训练网络

讨论3种常用的预训练模型:

        1、根据内容对图像进行标记(识别)

        2、从真实图像中生成新图像(GAN)

        3、使用正确的英语句子来描述图像内容(自然语言)

2.1 获取一个预训练好的网络用于图像识别

ImageNet数据集,用于大规模视觉识别挑战赛。

所有预训练好的模型都在TorchVision中。

2.1.1 导入已有的模型

所有模型都在torchvison的models中。导入并查看。

from torchvision import models
dir(models)

输出的是所有torchvison里面集成的模型框架。其中首字母大写的是一些流行的模型小写的名字是快捷函数,返回实例化模型函数

1.1.1 AlexNet模型

实例化AlexNet。

alexnet=models.AlexNet()
alexnet

可以像函数一样调用它。给alexnet输入数据,就会通过正向传播(forward pass)得到输出。比如output=alexnet(input)。由于网络没有初始化,没有经过训练。所以一般先要将模型从头训练或者加载训练好的网络。然后再调用。

1.1.2 Resnet模型

(1)加载在ImageNet数据集上训练好的权重,来实例化ResNet101
resnet=models.resnet101(pretrained=True)
resnet

然后就开始下载,下载完成后查看resnet101的结构。

神经网络由许多模块构成,包含过滤器和非线性函数,fc层结束,输出每个类的分数。

预训练好的模型可以跟函数一样调用,并输入图片实现预测。

(2)定义预处理函数:
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])

预处理包括:图像缩放到256*256像素,围绕中心裁剪到224*224像素,转为张量,归一化处理,使用定义的均值和标准差。

(3)导入图片并进行预处理

导入一张狗的照片并显示

from PIL import Image
img = Image.open('bobby.jpg')
img

待用预处理函数对图片进行预处理。

img_t=preprocess(img)
img_t.shape

输出为一个3维的张量。

给张量前面再增加一个维度。

import torch
batch_t = torch.unsqueeze(img_t,0)#加一个维度,数字0代表增加在第0维前面,如果为1就代表维度1前面
batch_t.shape

输出在第0维前面增加了一个1.

(4)运行模型

在新数据上运行训练过的模型的过程被称为推理(inference),为了推理需要先将网络放到eval模式。执行代码:

resnet.eval()

进行推理:

out=resnet(batch_t)
out

产生了一个1000分类的向量,每个ImageNet对应一个分数。

(5)查看预测结果

加载定义好的ImageNet标签。

with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

需要找出out输出在labels标签中的索引。可以利用max()函数输出张量中最大值以及最大值的索引。代码如下:

_,index=torch.max(out,1)
index

输出的索引不是一个数字,而是一个一维张量。

使用index[0]获得实际的数字作为标签列表的索引,用torch.nn.functional.softmax()将输出归一化到[0,1],然后除以总和。可以求出模型在预测中的置信度。

代码:

percentage = torch.nn.functional.softmax(out,dim=1)[0]*100
labels[index[0]],percentage[index[0]].item()

输出:('golden retriever', 96.29335021972656)

分类结果维金毛犬,置信度为96%。

也可以对预测结果的其它值进行排序输出。比如输出前5个。

_,indices = torch.sort(out,descending=True)
[(labels[idx],percentage[idx].item()) for idx in indices[0][:5]]

输出:

2.2 一个足以以假乱真的预训练模型

GAN是生成式对抗网络(generative adversarial network)的缩写。

cycleGAN是循环生成式对抗网络的缩写,可以将一个领域的图像转换为另一个领域的图像。

2.2.1 将马变为斑马的网络

CycleGAN从ImageNet数据集中提取的马和斑马的数据集进行训练。该网络学习获取一匹或多匹马的图像,并将它们全部变成斑马,图像的其余部分尽可能不被修改。

使用预训练好的CycleGAN将使我们有机会更进一步了解网络是如何实现的,对于本例就是生成器。

(1)以ResNet为例,定义一个ResNetGenerator类。

import torch
import torch.nn as nn

class ResNetBlock(nn.Module): # <1>

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x) # <2>
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> 

        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input): # <3>
        return self.model(input)

(2)实例化

netG = ResNetGenerator()

权重为随机权重。

(3)将预训练好的权重添加到ReNet Generator中。

model_path='horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

执行后,netG就获得了训练中需要的所有知识。

(4)推理

netG.eval()

输出:

程序是将一匹或多匹马逐像素修改。

导入随机马的图像进行测试。

导入需要的库:

from PIL import Image
from torchvision import transforms

定义预处理函数:

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor()
])

导入马的图片:

img = Image.open('horse.jpg')
img

对图片预处理:

img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t,0)

将变量传递给模型:

batch_out = netG(batch_t)

将生成器的输出转换为图像。

out_t = (batch_out.data.squeeze()+1.0)/2.0
out_img = transforms.ToPILImage()(out_t)
out_img

2.6 练习题

1.将金毛猎犬的图像输入马-斑马模型中。

参考资料:

1. 预训练网络 · 深度学习与PyTorch(中文版) (paper2fox.github.io)

4. PyTorch深度学习 Deep Learning with PyTorch ch.2, p2_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1568331.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

el-upload上传图片图片、el-load默认图片重新上传、el-upload初始化图片、el-upload编辑时回显图片

问题 我用el-upload上传图片&#xff0c;再上一篇文章已经解决了&#xff0c;el-upload上传图片给SpringBoot后端,但是又发现了新的问题&#xff0c;果然bug是一个个的冒出来的。新的问题是el-upload编辑时回显图片的保存。 问题描述&#xff1a;回显图片需要将默认的 file-lis…

A53 cache的架构解读

快速链接: 【精选】ARMv8/ARMv9架构入门到精通-[目录] &#x1f448;&#x1f448;&#x1f448; 引流关键词:缓存,高速缓存,cache, CCI,CMN,CCI-550,CCI-500,DSU,SCU,L1,L2,L3,system cache, Non-cacheable,Cacheable, non-shareable,inner-shareable,outer-shareable, optee、…

精准扶贫管理系统|基于Springboot的精准扶贫管理系统设计与实现(源码+数据库+文档)

精准扶贫管理系统目录 目录 基于Springboot的精准扶贫管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员模块的实现 &#xff08;1&#xff09;用户信息管理 &#xff08;2&#xff09;贫困户信息管理 &#xff08;3&#xff09;新闻类型管理 &a…

【Java EE】关于Maven

文章目录 &#x1f38d;什么是Maven&#x1f334;为什么要学Maven&#x1f332;创建⼀个Maven项目&#x1f333;Maven核心功能&#x1f338;项目构建&#x1f338;依赖管理 &#x1f340;Maven Help插件&#x1f384;Maven 仓库&#x1f338;本地仓库&#x1f338;私服 ⭕总结 …

Unity类银河恶魔城学习记录12-3 p125 Limit Inventory Slots源代码

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释&#xff0c;可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili Inventory.cs using Newtonsoft.Json.Linq; using System.Collections; us…

Linux 线程:线程互斥、互斥量、可重入与线程安全

目录 一、线程互斥 1、回顾相关概念 2、抢票场景分析代码 多个线程同时操作全局变量 产生原因 如何解决 二、互斥量 1、概念 2、初始化互斥量&#xff1a; 方法1&#xff1a;静态分配 方法2&#xff1a;动态分配 3、销毁互斥量&#xff1a; 4、加锁和解锁 示例抢…

多忽悠几次AI全招了!Anthropic警告:长上下文成越狱突破口,GPT羊驼Claude无一幸免

ChatGPT狂飙160天&#xff0c;世界已经不是之前的样子。 新建了免费的人工智能中文站https://ai.weoknow.com 新建了收费的人工智能中文站https://ai.hzytsoft.cn/ 更多资源欢迎关注 大模型厂商在上下文长度上卷的不可开交之际&#xff0c;一项最新研究泼来了一盆冷水—— Cl…

计算机网络-HTTP相关知识-HTTP的发展

HTTP/1.1 特点&#xff1a; 简单&#xff1a;HTTP/1.1的报文格式包括头部和主体&#xff0c;头部信息是键值对的形式&#xff0c;使得其易于理解和使用。灵活和易于扩展&#xff1a;HTTP/1.1的请求方法、URL、状态码、头字段等都可以自定义和扩展&#xff0c;使得其具有很高的…

Golang 内存管理和垃圾回收底层原理(一)

一、这篇文章我们来聊聊Golang内存管理和垃圾回收&#xff0c;主要注重基本底层原理讲解&#xff0c;进一步实战待后续文章 1、这篇我们来讨论一下Golang的内存管理 先上结构图 从图我们来讲Golang的基本内存结构&#xff0c;内存结构可以分为&#xff1a;协程缓存、中央缓存…

Redis的高可用(主从复制、哨兵模式、集群)的概述及部署

目录 一、Redis主从复制 1、Redis的主从复制的概念 2、Redis主从复制的作用 ①数据冗余&#xff1a; ②故障恢复&#xff1a; ③负载均衡&#xff1a; ④高可用基石&#xff1a; 3、Redis主从复制的流程 4、Redis主从复制的搭建 4.1、配置环境以及安装包 4.2所有主机…

chabot项目介绍

项目介绍 整体的目录如下所示&#xff1a; 上述的项目结构中出了model是必须的外&#xff0c;其他的都可以根据训练的代码参数传入进行调整&#xff0c;有些不需要一定存在data train.pkl:对原始训练语料进行tokenize之后的文件,存储一个list对象&#xff0c;list的每条数据表…

解密AI人工智能的整体分层架构:探索智能科技的未来之路

随着人工智能技术的迅猛发展&#xff0c;AI已经渗透到我们生活的方方面面。而支撑AI人工智能系统运作的核心是其整体分层架构。本文将深入探讨AI人工智能的整体分层架构&#xff0c;揭示其中的奥秘&#xff0c;探索智能科技的未来之路。 ### AI人工智能整体分层架构的重要性 …

Linux上安装DM8(达梦数据库),SpringBoot集成达梦

1.达梦数据库在Linux上的安装 官方手册:https://eco.dameng.com/document/dm/zh-cn/start/install-dm-linux-prepare.html 1.1下载安装包 官网:https://www.dameng.com/list_103.html 点击”服务与合作”--> “下载中心” 这里选择对应的cpu和操作系统(举个例子:windows版本…

Java基础 - 代码练习

第一题&#xff1a;集合的运用&#xff08;幸存者&#xff09; public class demo1 {public static void main(String[] args) {ArrayList<Integer> array new ArrayList<>(); //一百个囚犯存放在array集合中Random r new Random();for (int i 0; i < 100; …

JimuReport 积木报表

一款免费的数据可视化报表&#xff0c;含报表和大屏设计&#xff0c;像搭建积木一样在线设计报表&#xff01;功能涵盖&#xff0c;数据报表、打印设计、图表报表、大屏设计等&#xff01; Web 版报表设计器&#xff0c;类似于 excel 操作风格&#xff0c;通过拖拽完成报表设计…

从零到百万富翁:ChatGPT + Pinterest

原文&#xff1a;Zero to Millionaire Online: ChatGPT Pinterest 译者&#xff1a;飞龙 协议&#xff1a;CC BY-NC-SA 4.0 在社交媒体上赚取百万美元 - 逐步指南&#xff0c;如何在线赚钱版权 献给&#xff1a; 我将这本书&#xff0c;“从零到百万富翁在线&#xff1a;Chat…

【御控物联】JavaScript JSON结构转换(18):数组To对象——多层属性重组

文章目录 一、JSON结构转换是什么&#xff1f;二、案例之《JSON数组 To JSON对象》三、代码实现四、在线转换工具五、技术资料 一、JSON结构转换是什么&#xff1f; JSON结构转换指的是将一个JSON对象或JSON数组按照一定规则进行重组、筛选、映射或转换&#xff0c;生成新的JS…

基于PHP的校园招聘管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的校园招聘管理系统 一 介绍 此校园招聘管理系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为个人用户&#xff0c;企业和管理员三种。 技术栈&#xff1a;phpmysqlbootstrapphpstudyvscode 二…

JS——判断节假日(假日包括周末,不包括调休上班的周末)

思路&#xff1a;创建两个数组&#xff0c;数组1为节假日数组&#xff0c;数组2为是周末上班日期数组。如果当前日期&#xff08;或某日期&#xff09;同时满足2个条件&#xff08;1.在节假日数组内或在周末。2.不在周末上班日期数组&#xff09;即为节假日&#xff0c;否则即为…

Mybatis-Plus05(分页插件)

分页插件 MyBatis Plus自带分页插件&#xff0c;只要简单的配置即可实现分页功能 1. 添加配置类 Configuration MapperScan("com.atguigu.mybatisplus.mapper") //可以将主类中的注解移到此处 public class MybatisPlusConfig {Bean public MybatisPlusIntercepto…