pytorch实战-图像分类(一)(数据预处理)

news2025/1/24 2:27:07

目录

1.导入各种库

2.数据预处理

2.1数据读取

2.2图像增强

3.构建数据网络 

3.1网络构建

3.2读取标签对应的名字

4.展示数据

4.1数据转换

4.2画图

5.模型训练


1.导入各种库

上代码:

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.数据预处理

2.1数据读取

先看以下训练集和验证集存放的位置

 上代码

data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

2.2图像增强

目的:我们所收集准备训练的数据都是很可贵的,数据越多成本也就越高,所以希望将有限的数据集最大化利用,这就时图像增强的目的。

定义:如下图小灰猫,进行翻转操作,小黄猫,进行不同角度的旋转操作,这样实现了一图多用的效果,在原数据的基础上,将数据集翻了几倍。比方说你现在有一个1w的数据集,经过数据增强,可以完成10w的数据集。

 上代码

data_transforms = {
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(224),#从中心开始裁剪(224×224),因为训练集收集的图大小可能不同,但神经网络需要同样大小的输入.
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率,p=0.5就是说,有50%概率执行该操作。
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(), #将数据转化成tensor格式输入
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#因为本例是要用别人的模型训练,所以要参考别人例子中提供的均值,标准差,对自己的的训练集进行标准化操作。
    ]),
    'valid': transforms.Compose([transforms.Resize(256), #验证集不需要做数据增强,其他处理方法和train一样。
        transforms.CenterCrop(224), #验证集数据裁剪成和训练集一样,才能对比
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

3.构建数据网络 

3.1网络构建

batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']} # 构建分类任务数据集,注意不同任务数据集构建方式不同。
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']} # 按照batch_size = 8大小加载数据。
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} # 看一下数据的数量,该例'train': 6552, 'valid': 818
class_names = image_datasets['train'].classes

3.2读取标签对应的名字

网络最后的输出是一个代表类别的数值,比方说1,2,3,但我们希望看到这个数值对应的类别,所以json存这些信息,比方说{'1': 'pink primrose'}。

with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f) 

4.展示数据

4.1数据转换

注意:进行训练时需要tensor格式的数据,所以展示的时候tensor的数据需要转换成numpy的格式,而且还需要还原回标准化的结果。

def im_convert(tensor): #im_convert转化函数
    """ 展示数据"""
    
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image

4.2画图

fig=plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

5.模型训练

下接该文:pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/831947.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pytorch深度强化学习1-4:策略改进定理与贝尔曼最优方程详细推导

目录 0 专栏介绍1 贝尔曼最优方程2 贪心策略与策略改进3 策略迭代与价值迭代4 算法流程 0 专栏介绍 本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各…

PHP正则绕过解析

正则绕过 正则表达式PHP正则回溯PHP中的NULL和false回溯案例案例1案例2 正则表达式 在正则中有许多特殊的字符,不能直接使用,需要使用转义符\。如:$,(,),*,,.,?,[,,^,{。 这里大家会有疑问:为啥小括号(),这个就需要两个来转义&a…

C++ 对象数组

**数组元素不仅可以是基本数据类型,也可以是自定义类型。**例如,要存储和处理某单位全体雇员的信息,就可以建立一个雇员类的对象数组。对象数组的元素是对象,不仅具有数据成员,而且还有函数成员。 因此,和基…

iframe跨域解决方案

在 Web 开发中,跨域是指在一个域(例如,https://www.example.com)的页面中请求了另一个域(例如,https://api.example.com)的资源,浏览器出于安全考虑会阻止这样的请求。为了解决 ifra…

C#实现旋转图片验证码

开发环境:C#,VS2019,.NET Core 3.1,ASP.NET Core 1、建立一个验证码控制器 新建两个方法Create和Check,Create用于创建验证码(返回1张图片和令牌),Check用于验证(验证图…

Json文件编辑功能

1 Json格式 JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式。它基于 ECMAScript(European Computer Manufacturers Association, 欧洲计算机协会制定的js规范)的一个子集,采用完全独立于编程语言的文本格式来存储和表示数据。…

Curve深陷安全事件,OKLink如何破局

出品|欧科云链研究院 作者|Matthew Lee 7月31号,Curve 在平台表示 Vyper 0.2.15 的稳定币池由于编译器的漏洞所以遭到攻击。具体因为重入锁功能的失效,所以黑客可以轻易发动重入攻击,即允许攻击者在单次交易中执行某…

【运维】在阿里云上搭建自己的图床,配合PicGo和Typora使用

本文将详细介绍如何在阿里云上搭建自己的图床,包括购买OSS服务、配置域名解析、创建OSS存储桶和设置图片上传规则等步骤。希望对您有所帮助! 一、购买OSS服务 首先,我们需要在阿里云官网购买OSS(Object Storage Service)服务。OSS是阿里云提…

【Linux命令200例】cp用于复制文件和目录(常用)

🏆作者简介,黑夜开发者,全栈领域新星创作者✌,阿里云社区专家博主,2023年6月csdn上海赛道top4。 🏆本文已收录于专栏:Linux命令大全。 🏆本专栏我们会通过具体的系统的命令讲解加上鲜…

《golang设计模式》第一部分·创建型模式-05-工厂方法模式(Factory Method)

文章目录 1 概述2.1 角色2.2 类图 2 代码示例2. 1 设计2.2 代码2.3 类图 3. 简单工厂3.1 角色3.2 类图3.3 代码示例3.3.1 设计3.3.2 代码3.3.3 类图 1 概述 工厂方法类定义产品对象创建接口,但由子类实现具体产品对象的创建。 2.1 角色 Product(抽象产…

opencv-38 形态学操作-闭运算(先膨胀,后腐蚀)cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)

闭运算是先膨胀、后腐蚀的运算,它有助于关闭前景物体内部的小孔,或去除物体上的小黑点,还可以将不同的前景图像进行连接。 例如,在图 8-17 中,通过先膨胀后腐蚀的闭运算去除了原始图像内部的小孔(内部闭合的…

MacBook截取网页长图

第一步:⌘Command Option I 第二步:⌘Command Shift P 第三步: 红框内输入Capture full size screenshot,回车,长图会自动下载。

软考高项(六)项目管理概述 ★重点集萃★

👑 个人主页 👑 :😜😜😜Fish_Vast😜😜😜 🐝 个人格言 🐝 :🧐🧐🧐说到做到,言出必行&am…

VR实景导航——开启3D可视化实景导航新体验

数字化时代,我们大家出门在外都是离不开各种导航软件,人们对导航的需求也越来越高,而传统的导航软件由于精度不够,无法满足人们对真实场景的需求,这个时候就需要VR实景导航为我们实景指引目的地的所在。 VR实景导航以其…

新一代开源流数据湖平台Apache Paimon入门实操-上

文章目录 概述定义核心功能适用场景架构原理总体架构统一存储基本概念文件布局 部署环境准备环境部署 实战Catalog文件系统Hive Catalog 创建表创建Catalog管理表查询创建表(CTAS)创建外部表创建临时表 修改表修改表修改列修改水印 概述 定义 Apache Pa…

【每日一题】—— C. Challenging Cliffs(Codeforces Round 726 (Div. 2))

🌏博客主页:PH_modest的博客主页 🚩当前专栏:每日一题 💌其他专栏: 🔴 每日反刍 🟡 C跬步积累 🟢 C语言跬步积累 🌈座右铭:广积粮,缓称…

SpringBoot 项目创建与运行

一、Spring Boot 1、什么是Spring Boot?为什么要学 Spring Boot Spring 的诞生是为了简化 Java 程序的开发的,而 Spring Boot 的诞生是为了简化 Spring 程序开发的。 Spring Boot 翻译一下就是 Spring 脚手架 盖房子的这个架子就是脚手架,…

【Linux】网络编程套接字

1 预备知识 1.1 IP地址 IP协议有两个版本,分别是IPv4和IPv6。没有特殊说明,默认都是IPv4对于IPv4,IP地址是一个四个字节32为的整数;对于IPv6来说,IP地址是128位的整数 我们通常也使用 “点分十进制” 的字符串表示IP…

C 语言高级1-内存分区,多级指针,位运算

目录 1. 内存分区 1.1 数据类型 1.1.1 数据类型概念 1.1.2 数据类型别名 1.1.3 void数据类型 1.1.4 sizeof操作符 1.1.5 数据类型总结 1.2 变量 1.1.1 变量的概念 3.1.2 变量名的本质 1.3 程序的内存分区模型 1.3.1 内存分区 1.3.1.1 运行之前 1.3.1.2运行之后 1…

无涯教程-Perl - 循环语句

在某些情况下,您需要多次执行一个代码块。通常,语句是按顺序执行的:函数中的第一个语句首先执行,然后第二个执行,依此类推。 Perl编程语言提供了以下类型的循环来处理循环需求。 Sr.No.Loop Type & 描述1 while loop在给定条…