pytorch03:transforms常见数据增强操作

news2025/1/23 17:28:19

目录

  • 一、数据增强
  • 二、transforms--Crop裁剪
    • 2.1 transforms.CenterCrop
    • 2.2 transforms.RandomCrop
    • 2.3 RandomResizedCrop
    • 2.4 FiveCrop和TenCrop
  • 三、transforms—Flip翻转、旋转
    • 3.1RandomHorizontalFlip和RandomVerticalFlip
    • 3.2 RandomRotation
  • 四、transforms —图像变换
    • 4.1 transforms.Pad
    • 4.2 transforms.ColorJitter
    • 4.3 Grayscale和RandomGrayscale
    • 4.4 RandomAffine
    • 4.5 RandomErasing
  • 五、transforms的操作
    • 5.1 transforms.RandomChoice
    • 5.2 transforms.RandomApply
    • 5.3 transforms.RandomOrder
  • 六、自定义transforms
    • 6.1 自定义transforms要素
    • 6.2 通过类实现多参数传入
    • 6.3 椒盐噪声
    • 6.4 自定义transforms代码实现
  • 七、数据增强策略
    • 数据增强代码实现

一、数据增强

   数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。如下是对一张图片常见的增强操作例如:旋转、裁剪、像素抖动。
在这里插入图片描述

二、transforms–Crop裁剪

2.1 transforms.CenterCrop

功能:从图像中心裁剪图片
• size:所需裁剪图片尺寸

2.2 transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为size的图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• padding:设置填充大小
  当为a时,上下左右均填充a个像素,
  当为(a, b)时,上下填充b个像素,左右填充a个像素,
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• pad_if_need:若图像小于设定size,则填充
• padding_mode:填充模式,有4种模式
  1、constant:像素值由fill设定
  2、edge:像素值由图像边缘像素决定
  3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
  4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]
• fill:constant时,设置填充的像素值

2.3 RandomResizedCrop

功能:随机大小、长宽比裁剪图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• scale:随机裁剪面积比例, 默认(0.08, 1)
• ratio:随机长宽比,默认(3/4, 4/3)
• interpolation:插值方法
PIL.Image.NEAREST
PIL.Image.BILINEAR
PIL.Image.BICUBIC

2.4 FiveCrop和TenCrop

  功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• vertical_flip:是否垂直翻转

三、transforms—Flip翻转、旋转

3.1RandomHorizontalFlip和RandomVerticalFlip

在这里插入图片描述

功能:依概率水平(左右)或垂直(上下)翻转图片
• p:翻转概率

3.2 RandomRotation

功能:随机旋转图片
在这里插入图片描述
在这里插入图片描述

• degrees:旋转角度
  当为a时,在(-a,a)之间选择旋转角度
  当为(a, b)时,在(a, b)之间选择旋转角度
• resample:重采样方法
• expand:是否扩大图片,以保持原图

四、transforms —图像变换

4.1 transforms.Pad

功能:对图片边缘进行填充
在这里插入图片描述
• padding:设置填充大小
  当为a时,上下左右均填充a个像素
  当为(a, b)时,上下填充b个像素,左右填充a个像素
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• padding_mode:填充模式,有4种模式,constant、edge、reflect和symmetric
• fill:constant时,设置填充的像素值,(R, G, B) or (Gray)

4.2 transforms.ColorJitter

功能:调整亮度、对比度、饱和度和色相
在这里插入图片描述

• brightness:亮度调整因子
  当为a时,从[max(0, 1-a), 1+a]中随机选择
  当为(a, b)时,从[a, b]中
• contrast:对比度参数,同brightness
• saturation:饱和度参数,同brightness
• hue:色相参数,当为a时,从[-a, a]中选择参数,注: 0<= a <= 0.5
        当为(a, b)时,从[a, b]中选择参数,注:-0.5 <= a <= b <= 0.5

4.3 Grayscale和RandomGrayscale

功能:依概率将图片转换为灰度图
在这里插入图片描述
• num_ouput_channels:输出通道数只能设1或3
• p:概率值,图像被转换为灰度图的概率

4.4 RandomAffine

功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
在这里插入图片描述
在这里插入图片描述
• degrees:旋转角度设置
• translate:平移区间设置,如(a, b), a设置宽(width),b设置高(height)
    图像在宽维度平移的区间为 -img_width * a < dx < img_width * a
• scale:缩放比例(以面积为单位)
• fill_color:填充颜色设置

4.5 RandomErasing

功能:对图像进行随机遮挡
在这里插入图片描述

• p:概率值,执行该操作的概率
• scale:遮挡区域的面积
• ratio:遮挡区域长宽比
• value:设置遮挡区域的像素值,(R, G, B) or (Gray)

五、transforms的操作

5.1 transforms.RandomChoice

功能:从一系列transforms方法中随机挑选一个

transforms.RandomChoice([transforms1, transforms2, transforms3])

5.2 transforms.RandomApply

功能:依据概率执行一组transforms操作

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

5.3 transforms.RandomOrder

功能:对一组transforms操作打乱顺序

transforms.RandomOrder([transforms1, transforms2, transforms3])

六、自定义transforms

6.1 自定义transforms要素

1.仅接收一个参数,返回一个参数
2.注意上下游的输出与输入
当前transforms的输入是上一个transforms的输出,所以要保证数据类型匹配:
在这里插入图片描述

6.2 通过类实现多参数传入

在这里插入图片描述

在Python中,__call__是一个特殊的方法,用于使一个对象可以像函数一样被调用。如果一个类定义了__call__方法,那么实例化的对象就可以被当作函数一样调用,而调用的实际上是__call__方法。

class CallableClass:
    def __init__(self):
        print("Initializing the CallableClass")

    def __call__(self, *args, **kwargs):
        print("Calling the CallableClass with arguments:", args, kwargs)

# 实例化对象
obj = CallableClass()

# 调用对象,实际上调用了__call__方法
obj(1, 2, 3, keyword_arg="hello")

上面的例子中,CallableClass定义了__call__方法,这意味着实例obj可以像函数一样被调用。当你调用obj(1, 2, 3, keyword_arg=“hello”)时,实际上是在调用obj.call(1, 2, 3, keyword_arg=“hello”)。

6.3 椒盐噪声

椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点, 白点称为盐噪声,黑色为椒噪声
信噪比(Signal-Noise Rate, SNR)是衡量噪声的比例,图像中为图像像素的占比,从下图可以看出,信噪比越小,图片丢失的像素越多。
在这里插入图片描述

6.4 自定义transforms代码实现

class AddPepperNoise(object):
    """增加椒盐噪声

    Args:
        snr (float): Signal Noise Rate 信噪比
        p (float): 概率值,依概率执行该操作

    Attributes:
        snr (float): 信噪比
        p (float): 操作执行的概率
    """

    def __init__(self, snr, p=0.9):
        # 确保传入的snr和p是float类型
        assert isinstance(snr, float) and isinstance(p, float)
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """
        对图像应用椒盐噪声操作。

        Args:
            img (PIL Image): PIL Image对象

        Returns:
            PIL Image: 处理后的PIL Image对象
        """
        # 根据概率决定是否执行噪声操作
        if random.uniform(0, 1) < self.p:
            img_ = np.array(img).copy()
            h, w, c = img_.shape
            signal_pct = self.snr
            noise_pct = (1 - self.snr)
            
            # 生成噪声掩码,表示每个像素是原始图像、盐噪声还是椒噪声
            mask = np.random.choice((0, 1, 2), size=(h, w, 1),
                                    p=[signal_pct, noise_pct / 2., noise_pct / 2.])
            mask = np.repeat(mask, c, axis=2)
            
            # 根据噪声类型修改图像像素值
            img_[mask == 1] = 255  # 盐噪声
            img_[mask == 2] = 0    # 椒噪声
            
            # 将NumPy数组转换回PIL Image对象,并确保数据类型为uint8,颜色通道为RGB
            return Image.fromarray(img_.astype('uint8')).convert('RGB')
        else:
            return img

在这里插入图片描述

七、数据增强策略

原则:让训练集与测试集更接近可以使用下面这些方法
• 空间位置:平移
• 色彩:灰度图,色彩抖动
• 形状:仿射变换
• 上下文场景:遮挡,填充

例如我们训练集白猫比较多,可以改变白猫色彩,让白猫的颜色接近黑猫。
在这里插入图片描述

数据增强代码实现

要求:使用第四套RMB进行训练,要求能对第5套RMB识别正确。

我们只进行普通的图片处理训练好的模型,发现将第五套100元都识别成一元,因为第四套人民币的1元和第五套人民的100元颜色相近,所以会导致识别错误:
在这里插入图片描述
解决方法,将所有训练集颜色都进行灰度处理,代码修改如下:

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.9),  #图片灰度化
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

修改后的预测结果如下:
在这里插入图片描述
训练完整代码如下:

# -*- coding: utf-8 -*-

import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from lenet import LeNet
from my_dataset import RMBDataset
from common_tools import transform_invert


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.9),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])


valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss_val)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

# ============================ inference ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100

    img_tensor = inputs[0, ...]  # C H W
    img = transform_invert(img_tensor, train_transform)
    plt.imshow(img)
    plt.title("LeNet got {} Yuan".format(rmb))
    plt.show()
    plt.pause(0.5)
    plt.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1347265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaScript】浮点数精度问题

✨ 专栏介绍 在现代Web开发中&#xff0c;JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性&#xff0c;还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言&#xff0c;JavaScript具有广泛的应用场景&#x…

【MySQL表的增删查改】

文章目录 前言1 Create1.1 单行数据 全列插入1.2 多行数据 指定列插入1.3 插入否则更新1.4 替换 2 Retrieve2.1 SELECT 列2.1.1 全列查询2.1.2 指定列查询2.1.3 查询字段为表达式2.1.4 为查询结果指定别名2.1.5 结果去重 2.2 WHERE 条件2.2.1 英语不及格的同学及英语成绩 ( &…

vue 基础学习 一

1. vue 使用快速入门三步走 (1) 新建 HTML 页面&#xff0c;引入 Vue.js文件 1 2 3 4 5 6 7 <!DOCTYPE html> <html> <head> <meta charset"UTF-8"> <title>Vue.js 入门示例</title> <script src"https://cdn.j…

设计模式:抽象工厂模式(讲故事易懂)

抽象工厂模式 定义&#xff1a;将有关联关系的系列产品放到一个工厂里&#xff0c;通过该工厂生产一系列产品。 设计模式有三大分类&#xff1a;创建型模式、结构型模式、行为型模式 抽象工厂模式属于创建型模式 上篇 工厂方法模式 提到工厂方法模式中每个工厂只生产一种特定…

12.29_黑马数据结构与算法笔记Java

目录 305 旅行商问题 动态规划 实现2 306 旅行商问题 动态规划 实现3 307 分治 概述 308 快速选择算法 分治 309 快速选择算法 数组第k大数 Leetcode215 310 快速选择算法 数组中位数 311 快速幂 分治 312 快速幂 Leetcode50 313 平方根整数部分 Leetcode69-1 314 平方…

从实际工作情况,介绍嵌入式(MCU)软件开发常用(通用)工具

目录 前言 1、代码阅读及编辑工具&#xff08;VSCode、Understand&#xff09; 2、代码对比工具&#xff08;Beyond Compare&#xff09; 3、代码仓库相关工具&#xff08;Git、SVN、Tortoise&#xff09; 4、文本编辑器&#xff08;Notepad&#xff09; 5、电脑文件搜索工…

arkts状态管理使用(@State、@Prop、@Link、@Provide、@Consume、@objectLink和@observed)

一、状态管理 1.在声明式UI中&#xff0c;是以状态驱动视图更新&#xff1a; ①状态&#xff08;State&#xff09;:指驱动视图更新的数据&#xff08;被装饰器标记的变量&#xff09; ②视图&#xff08;View&#xff09;:基于UI描述渲染得到用户界面 注意&#xff1a; ①…

Radar System Pro - Plug Play Solution

Radar System Pro是一款功能多样且可定制的资源,旨在通过功能齐全且易于使用的雷达系统增强您的Unity项目。无论您是在开发第一人称射击游戏、策略游戏还是太空探索模拟器,我们的雷达系统都将为您提供所需的工具,以创建引人入胜且身临其境的体验。 雷达系统是一个模块化资产…

mysql间隙锁demo分析

概述 通常用的mysql都是innodb引擎&#xff1b; 一般在update的时候用id都会认为是给行记录加锁&#xff1b; 在使用非唯一索引更新时&#xff0c;会遇到临键锁&#xff08;范围锁&#xff09;&#xff1b; 临键锁和表中的数据有关&#xff1b; mysq版本:8 隔离级别&#xf…

雨课堂作业整理

第一次作业 1.下列序列是图序列的是&#xff08; &#xff09; A.1&#xff0c;2&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;4&#xff0c;5 B.1&#xff0c;1&#xff0c;2&#xff0c;2&#xff0c;4&#xff0c;6&#xff0c;6 C.0&#xff0c;0&#xff0c;2&am…

解决基于VectorGrid的矢量瓦片Y轴偏移的问题

目录 前言 一、GeoServer的瓦片 1、GeoWebcache缓存配置 2、矢量瓦片本地缓存 3、瓦片访问 二、VectorGrid加载本地瓦片 1、加载关键代码 2、默认模式的问题 3、问题分析 4、tms参数修改 总结 前言 在前面的博文介绍中&#xff0c;在线连接如下&#xff1a;浅谈前端自定义…

AI与数字化映像:颜值开端,功能至上_光点科技

在人工智能的浪潮中&#xff0c;AI数字人的兴起正成为一个不可忽视的现象。随着ChatGPT等生成式AI算法的进步&#xff0c;AIGC&#xff08;人工智能生成内容&#xff09;的应用呈现出爆发性增长&#xff0c;不仅在技术圈引起广泛关注&#xff0c;也为元宇宙及其相关产业链带来了…

C++每日一练(8):图像相似度

题目描述 给出两幅相同大小的黑白图像&#xff08;用0-1矩阵&#xff09;表示&#xff0c;求它们的相似度。 说明&#xff1a;若两幅图像在相同位置上的像素点颜色相同&#xff0c;则称它们在该位置具有相同的像素点。两幅图像的相似度定义为相同像素点数占总像素点数的百分比。…

非科班,培训出身,怎么进大厂?

今天分享一下我是怎么进大厂的经历&#xff0c;希望能给大家带来一点点启发&#xff01; 阿七毕业于上海一所大学的管理学院&#xff0c;在读期间没写过一行 Java 代码。毕业之后二战考研失利。 回过头来看&#xff0c;也很庆幸这次考研失利&#xff0c;因为这个时候对社会一…

精品Nodejs实现的在线菜谱食谱美食学习系统的设计与实现

《[含文档PPT源码等]精品Nodejs实现的在线菜谱学习系统的设计与实现[包运行成功]》该项目含有源码、文档、PPT、配套开发软件、软件安装教程、项目发布教程、包运行成功&#xff01; 软件开发环境及开发工具&#xff1a; 操作系统&#xff1a;Windows 10、Windows 7、Windows…

2023第三届中国高校大数据挑战赛B题代码

任务已完成&#xff0c;聚类效果很好&#xff08;主要在于数据的处理以及特征工程&#xff09;, 需代码si&#xff0c;yuer有限先到先得。

模型 KANO卡诺模型

本系列文章 主要是 分享 思维模型&#xff0c;涉及各个领域&#xff0c;重在提升认知。需求分析。 1 卡诺模型的应用 1.1 餐厅需求分析故事 假设你经营一家餐厅&#xff0c;你想了解客户对你的服务质量的满意度。你可以使用卡诺模型来收集客户的反馈&#xff0c;并分析客户的…

Android Matrix画布Canvas旋转Rotate,Kotlin

Android Matrix画布Canvas旋转Rotate&#xff0c;Kotlin private fun f1() {val originBmp BitmapFactory.decodeResource(resources, R.mipmap.pic).copy(Bitmap.Config.ARGB_8888, true)val newBmp Bitmap.createBitmap(originBmp.width, originBmp.height, Bitmap.Config.…

基于价值认同的需求侧电能共享分布式交易策略(matlab完全复现)

目录 1 主要内容 2 部分程序 3 程序结果 4 下载链接 1 主要内容 该程序完全复现《基于价值认同的需求侧电能共享分布式交易策略》&#xff0c;针对电能共享市场的交易机制进行研究&#xff0c;提出了基于价值认同的需求侧电能共享分布式交易策略&#xff0c;旨在降低电力市…

学习笔记15——前端和http协议

学习笔记系列开头惯例发布一些寻亲消息&#xff0c;感谢关注&#xff01; 链接&#xff1a;https://baobeihuijia.com/bbhj/ 关系 客户端&#xff1a;对连接访问到的前端代码进行解析和渲染&#xff0c;就是浏览器的内核服务器端&#xff1a;按照规则编写前端界面代码 解析标准…