将图像增广应用于Mnist数据集

news2025/1/10 10:30:48

将图像增广应用于Mnist数据集

不用到cifar-10的原因是要下载好久。。我就直接用在Mnist上了,先学会用

首先我们得了解一下图像增广的基本内容,这是我的一张猫图片,以下为先导入需要的包和展示图片

import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
d2l.set_figsize()
img = Image.open('cat.png')
d2l.plt.imshow(img)

在这里插入图片描述
之后呢,我们先定义几个函数,以后方便调用,第一个函数show_images,他是用来展示多张图片的

def show_images(imgs, num_rows, num_cols, scale=2):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize = figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j])
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    return axes

然后将图像展示函数和图像增广函数结合起来展示,也用一个函数来集成

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale)

接下来,就可以开始我们的图像增广之路啦

左右翻转

torchvision.transforms.RandomHorizontalFlip()这个函数有百分之五十的概率实现左右翻转

apply(img, torchvision.transforms.RandomHorizontalFlip()) # torchvision.transforms.RandomHorizontalFlip() 百分之50的概率左右翻转

在这里插入图片描述

上下翻转

torchvision.transforms.RandomVerticalFlip() 百分之50的概率上下翻转
在这里插入图片描述

随机裁剪

随机裁剪出一块面积为原面积10%100%的区域,且该区域的宽和高之比随机取自0.52,然后将该区域的宽高缩放到200像素

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

在这里插入图片描述
自然,我们也可以变换颜色,有亮度(brightness),对比度(contrast),饱和度(saturation),色调(hue)
我就直接一起写了,也可以只变单个
0.5的意思是比如对于亮度来说,他会在50%的范围内随机选择,即亮度为原来的0.5~1.5

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.5, saturation=0.5, contrast=0.5) 
apply(img, color_aug)

在这里插入图片描述
那么当然,我们也可以把上述的那些进行叠加
用到torchvision.transforms.Compose

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])
apply(img, augs)

在这里插入图片描述
之后呢,就可以用增广后的图像进行训练啦,这里给大家一个例子用Resnet18进行训练Mnist数据集,Resnet18就不带着大家写了,直接调用别人写好的函数,写网络并不是本节的重点,如果以后有时间或者大家有需要我可以再来写~
(为什么是Mnist数据集,其实他在Mnist数据集上的效果并没有很明显,比较比较简单,最好是在cifar上,但是cifar要下太久了,懒,大家可以在cifar上测一下)
先写两个augs,训练集我就将他随机翻转,测试集就不动了

flip_aug = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()           # 记得转换成tensor 以便训练
])
no_aug = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

然后是load_mnist,就加一个transform就好啦,集成到一个函数里

def load_mnist(is_train, augs, batch_size, root="~/Datasets/MNIST"):
    dataset = torchvision.datasets.MNIST(train=is_train, root=root, transform=augs, download=True)
    return DataLoader(dataset, batch_size = batch_size, shuffle=is_train)

再之后就是模型的训练了,这个大家应该都写腻了,我也不多说什么了,反正就是模型前向传播+反向传播,然后再记录点值

def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    batch_count = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = d2l.evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

再最后,就定义一个函数,把前面的都用上啦!

def train_with_data_aug(train_augs, test_augs, lr=0.001):
    batch_size, net = 256, d2l.resnet18(output=10, in_channels=1)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = torch.nn.CrossEntropyLoss()
    train_iter = load_mnist(True, train_augs, batch_size)
    test_iter = load_mnist(False, test_augs, batch_size)
    train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)

值得注意的是,这边调用别人的d2l.resnet18,要注意in_channels=1记得写,他默认是3通道的,改成1通道对于我们的mnist,如果你要是cifar-10就不用变了,把in_channel=1给删掉就好~,至此,调用我们的函数就行
在这里插入图片描述

训练还是很快滴

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1292499.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

javaTCP协议实现一对一聊天

我们首先要完成服务端,不然出错,运行也要先运行服务端,如果不先连接服务端,就不监听,那客户端不知道连接谁 服务端 import java.awt.BorderLayout; import java.awt.event.ActionEvent; import java.awt.event.Actio…

超越GPT-4!谷歌发布最强多模态大模型—Gemini

12月7日凌晨,谷歌在官网发布了全新最强多模态大模型——Gemini。 据悉,Gemini有Ultra、Pro、Nano三个版本,可自动生成文本、代码、总结内容等,并能理解图片、音频和视频内容。在MMLU、DROP 、HellaSwag、GSM8K等主流评测中&#…

JVM虚拟机(已整理,已废弃)

# JVM组成 ## 简述程序计数器 线程私有,内部保存class字节码的行号。用于记录正在执行的字节码指令的地址。 线程私有-每个线程都有自己的程序计数器PC,用于记录当前线程执行哪个行号 ## 简述堆 ## 简述虚拟机栈 ## 简述堆栈区别 ## 方法内局部变量是…

【前端架构】清洁前端架构

探索前端架构:概述与干净的前端架构相关的一些原则(SOLID、KISS、DRY、DDD等)。 在我之前的一篇帖子中,我谈到了Signals和仍然缺少的内容[1]。现在,我想谈谈一个更通用的主题,即Clean Frontend Architectu…

python+paddleocr 进行图像识别、找到文字在屏幕中的位置

目录 前言 1、安装paddleocr 2、安装PIL 3、安装numpy 4、 安装pyautogui 5、进行文本识别 6、识别结果 7、获取文字在图片/屏幕中的位置 8、pyautoguipaddleocr鼠标操作 9、完整代码 前言 最近在做自动化测试,因为是处理过的界面,所以使用pyw…

Vue3项目调用腾讯地图服务(地址解析 地址转坐标)及使用axios的跨域问题

一,需求 根据传入的文本地址 将其转换为坐标 显示地图点位在腾讯地图上 二,使用axios发送请求 import axios from axios; //引入axiosaxios({url:https://apis.map.qq.com/ws/geocoder/v1,method:get//参数 地址和key值}).then((data)>{console.log(data)});但是使用完报跨…

猫咪瘦弱的原因是什么?适合给消瘦猫咪长肉吃的猫罐头分享

很多小猫咪吃得很多,但是还是很瘦,这让很多猫主人感到困惑,猫咪瘦弱的原因是什么呢?铲屎那么多年,还是有点子养猫知识在身上的。那么,小猫咪瘦弱的原因是什么呢?让我们看看是不是这些原因导致的…

为什么有些程序员宁愿在国内 35 岁被辞退,也不愿意去国外工作?

我发现IT圈和电竞圈有一个共性:菜是原罪。 为什么有些程序员35岁就会被辞退?因为菜。 为什么有些程序员不愿意去国外工作?因为菜。 当然,我这里指的菜不是烂泥扶不上墙的那种菜,而是不够拔尖。那么这个问题也就分为了三…

【项目日记(一)】高并发内存池项目介绍

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:项目日记-高并发内存池⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习C   🔝🔝 项目日记 1. 前言2. 什么是高并发内存池…

基于ssm vue个人需求和地域特色的外卖推荐系统源码和论文

首先,论文一开始便是清楚的论述了系统的研究内容。其次,剖析系统需求分析,弄明白“做什么”,分析包括业务分析和业务流程的分析以及用例分析,更进一步明确系统的需求。然后在明白了系统的需求基础上需要进一步地设计系统,主要包罗软件架构模式、整体功能模块、数据库设计。本项…

9种伪原创工具推荐,快速提升创作效率

如何让自己的文章在海量信息中脱颖而出,成为一个备受关注的焦点,成为许多创作者迫切思考的问题。在这篇文章中,我将向大家介绍9种伪原创工具,这些工具可以让你的文章轻松升级,更具创意和吸引力。 1.Spinbot&#xff08…

simulink中 Data store memory、write和read模块及案例介绍

目录 1.Data store memory模块 2.data store write模块 3.data store read模块 4.仿真分析 4.1简单使用三个模块 4.2 模块间的调用顺序剖析 1.Data store memory模块 向右拖拉得到Data store read模块,向左拉得到Data write模块 理解:可视为定义变量…

C++ 函数详解

目录 函数概述 函数的分类 函数的参数 函数的调用 函数的嵌套调用 函数的链式访问 函数声明和定义 函数递归 函数概述 函数——具有某种功能的代码块。 一个程序中我们经常会用到某种功能,如两数相加,如果每次都在需要用到时实现,那…

矩阵学习相关——(待完善)

线性代数基础知识之–矩阵(Matrix) 矩阵概念————(基础知识) 矩阵理论基础知识 矩阵理论基础知识 矩阵入门 写给有编程基础的人 初学讲义之高中数学二十七:矩阵和行列式 直观理解!你一定要读…

C++多态(详解)

一、多态的概念 1.1、多态的概念 多态:多种形态,具体点就是去完成某个行为,当不同的对象去完成时会产生出不同的状态。 举个例子:比如买票这个行为,当普通人买票时,是全价买票;学生买票时&am…

JavaScript实现手写签名,可触屏手写,支持移动端与PC端双端保存

目录 1.HTML模板 2.获取DOM元素和定义变量 3.创建两个canvas元素,并设置它们的宽度和高度 4.绑定触摸事件:touchstart, touchmove, touchend和click 5.实现触摸事件回调函数:startDrawing, draw和stopDrawing 6.实现绘制线段的函数&…

C# WPF上位机开发(带配置文件的倒计时软件)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们写了倒计时软件,但是不知道大家有没有发现,这个软件起始有一个缺点,那就是倒计时的起始时间都是硬编码…

stl库之map与例题

map是一种关联式容器&#xff0c;它允许将键&#xff08;key&#xff09;映射到值&#xff08;value&#xff09;&#xff0c;所以我们习惯称map为映射 每个元素都是一个键值对&#xff0c;其中键是唯一的 创建map map<key类型, value类型> 变量名; 创建一个键为int&…

11.7QT界面制作

#include "widget.h"Widget::Widget(QWidget *parent): QWidget(parent) {this->resize(881,550);this->setStyleSheet("backgroud-color:rgb(33,35,40)");this->setWindowFlag(Qt::FramelessWindowHint);//标签类QLabel *l1 new QLabel(this);/…

深入理解Flexbox:构建灵活的布局系统

由于篇幅限制&#xff0c;我将提供一个详细的文章大纲和部分内容。您可以根据这个大纲扩展文章内容&#xff0c;以满足3000字的要求。 深入理解Flexbox&#xff1a;构建灵活的布局系统 引言 在现代web设计中&#xff0c;创建灵活且响应式的布局是非常重要的。Flexbox&#xf…