【pytorch】使用mixup技术扩充数据集进行训练

news2025/1/11 23:51:26

目录

  • 1.mixup技术简介
  • 2.pytorch实现代码,以图片分类为例

1.mixup技术简介

在这里插入图片描述

mixup是一种数据增强技术,它可以通过将多组不同数据集的样本进行线性组合,生成新的样本,从而扩充数据集。mixup的核心原理是将两个不同的图片按照一定的比例进行线性组合,生成新的样本,新样本的标签也是进行线性组合得到。比如,对于两个样本x1和x2,它们的标签分别为y1和y2,那么mixup生成的新样本x’和标签y’如下:

x’ = λx1 + (1-λ)x2
y’ = λy1 + (1-λ)y2

其中,λ为0到1之间的一个随机数,它表示x1和x2在新样本中的权重。

本文中,使用mixup扩充数据集后的损失函数,为:
loss = λ * criterion(outputs, targets_a) + (1 - λ) * criterion(outputs, targets_b)

即由两张图片融合后新图片的损失为,分别计算原先两图片与各自标签的损失值之和。

mixup也可以增加数据集的多样性,从而降低模型的方差,提高模型的鲁棒性。
总之,mixup是一种非常实用的数据增强技术,它可以用于各种机器学习任务中,可以有效地防止过拟合,并且可以提高模型的泛化能力。在实际应用中,mixup可以帮助我们更好地利用有限的数据集,并且提高模型的性能。

2.pytorch实现代码,以图片分类为例

import matplotlib.pyplot as plt
import torchvision
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from PIL import Image
import os
import cv2
# 加载resnet18模型
resnet18 = models.resnet18(pretrained=False)

# 获取resnet18最后一层输出,输出为512维,最后一层本来是用作 分类的,原始网络分为1000类
# 用 softmax函数或者 fully connected 函数,但是用 nn.identtiy() 函数把最后一层替换掉,相当于得到分类之前的特征!
#Identity模块,它将输入直接传递给输出,而不会对输入进行任何变换。
resnet18.fc = nn.Identity()
# 构建新的网络,将resnet18的输出作为输入
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.resnet18 = resnet18
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        self.fc4 = nn.Linear(10, 2)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        x = self.resnet18(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.softmax(x)
        x=x.view(-1,2)
        return x
#使用mixup时的数据集融合器,输入数据集的输入(一批图片),以及对应标签,返回线性相加后的图片
def mixup_data(x, y, alpha=1.0):
    #随机生成一个 beta 分布的参数 lam,用于生成随机的线性组合,以实现 mixup 数据扩充。
    lam = np.random.beta(alpha, alpha)
    #生成一个随机的序列,用于将输入数据进行 shuffle。
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    #得到混合后的新图片
    mixed_x = lam * x + (1 - lam) * x[index, :]
    #得到混图对应的两类标签
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
    
# 实例化网络
net = Net()
# 将模型放入GPU
net = net.cuda()

# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器,添加l2正则化
optimizer = torch.optim.Adam(net.parameters(), lr=0.0005,weight_decay=0)

# 加载数据集
# 创建一个transform对象
def rgb2bgr(image):
    image = np.array(image)[:, :, ::-1]
    image=Image.fromarray(np.uint8(image))
    return image

transform = transforms.Compose([
    transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.5),
    # rgb转bgr
    torchvision.transforms.Lambda(rgb2bgr),
    torchvision.transforms.Resize(112),
    # 入的图片为PIL image 或者 numpy.nadrry格式的图片,其shape为(HxWxC)数值范围在[0,255],转换之后shape为(CxHxw),数值范围在[0,1]
    transforms.ToTensor(),
    # 进行归一化和标准化,Imagenet数据集的均值和方差为:mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225),
    # 因为这是在百万张图像上计算而得的,所以我们通常见到在训练过程中使用它们做标准化。
    transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]),
    # #这行代码表示使用transforms.RandomErasing函数,以概率p=1,在图像上随机选择一个尺寸为scale=(0.02, 0.33),长宽比为ratio=(1, 1)的区域,
    # #进行随机像素值的遮盖,只能对tensor操作:
    transforms.RandomErasing(p=0.1, scale=(0.02, 0.2), ratio=(1, 1), value='random')
])

train_dataset = torchvision.datasets.ImageFolder(r'D:\eyeDataSet\train', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataset = torchvision.datasets.ImageFolder(r'D:\eyeDataSet\test', transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=True)

## 绘制训练及测试集迭代曲线
# 记录训练集准确率
train_acc = []
# 记录测试集准确率
test_acc = []
for epoch in range(50):
    running_loss = 0.0
    #[(0, data1), (1, data2), (2, data3), ...]
    for i, data in enumerate(train_loader, 0):
        # 获取输入
        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()
        # mixup扩充数据集
        inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=1.0)
        # 梯度清零
        optimizer.zero_grad()
        # forward + backward
        outputs = net(inputs)
        #这里对应调整了使用mixup后的数据集的loss
        loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
        loss.backward()
        # 更新参数
        optimizer.step()
        # 打印log信息
        # loss 是一个scalar,需要使用loss.item()来获取数值,不能使用loss[0]
        running_loss += loss.item()
        if i % 10 == 9: # 每200个batch打印一下训练状态
            print('[%d, %5d] loss: %.3f' \
                  % (epoch+1, i+1, running_loss / 2000))
            running_loss = 0.0
            # 在每次训练完成后,使用测试集进行测试
            correct = 0
            total = 0
            with torch.no_grad():
                for i2,data2 in enumerate(test_loader):
                    #控制测试集的数量
                    if i2>5:
                       break
                    images, labels = data2
                    images, labels = images.cuda(), labels.cuda()
                    outputs = net(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            test_acc.append(100 * correct / total)
            print('Accuracy of the network on the test images: %.3f,now max acc is %.3f' % (
                    100 * correct / total,max(test_acc)))
            # 保存测试集上准确率最高的模型
            if 100 * correct / total == max(test_acc):
                if not os.path.exists(r'./result'):
                    os.makedirs(r'./result')
                if max(test_acc)>93:
                    savename="./result/bestmodel"+"%.3f"%max(test_acc)+".pth"
                    torch.save(net.state_dict(), savename)

print("最大准确度:",max(test_acc))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/391131.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【嵌入式开发】iperf

iperf一级目录用法help文档iperf参数功能iperf测试实例测试网口上行速率测试网口下行速率perf 是一个网络性能测试工具。Iperf可以测试最大TCP和UDP带宽性能,具有多种参数和UDP特性,可以根据需要调整,可以报告带宽、延迟抖动和数据包丢失。一…

小程序项目在hbuilder里面给它打包成app

小程序项目临时有些登录需求,需要把(小程序某些功能通过条件编译让它显示到app上)小程序打包成app的话就必须需要一个打包的证书,证书的话就要去重新生成,苹果电脑可以去自动生成证书,平时是用windows进行开…

Java的四种引用强软弱虚及其使用场景

一.强引用 回收时机:在内存不足时也不会被回收。 使用方式:String str new String("str"); 使用场景:是平常用的最多的引用 二.软引用 回收时机: 在内存不足时会被回收。 使用方式:SoftRefere…

阿里云全量物理备份.xb备份文件本地恢复

一、下载备份文件并上传到本地 二、安装转码软件 1、yum install https://repo.percona.com/yum/percona-release-latest.noarch.rpm2、yum install percona-xtrabackup-24三、解包并解压 1、先解包 cat test1_qp.xb | xbstream -x -v -C /home/mysql/data 2、然后解压&…

【历史上的今天】3 月 6 日:Unix 版权争夺战;豆瓣网上线;谷歌推出了 Google Play

整理 | 王启隆 透过「历史上的今天」,从过去看未来,从现在亦可以改变未来。 今天是 2023 年 3 月 6 日,在 1475 年的今天,大艺术家米开朗琪罗诞辰。米开朗琪罗是意大利文艺复兴时期的雕塑家、画家,他活到 89 岁&#…

Web服务器基础介绍与Apache的简单介绍(LAMP架构与搭建论坛)

目录 Web服务器基础介绍 一.HTML是什么? 二.静态网页和动态网页 1.静态网页 2.动态网页 3.动态网页语言 三.HTTP协议 1.HTTP协议是什么? 2.HTTP方法 3.HTTP状态码 4.HTTP请求流程分析 4.1 请求报文 4.2 响应报文 Apache的简单介绍 一.Apa…

如何找回回收站删除的视频?这三种方法可以试试

在使用电脑过程中,我们可能会误删重要的文件,特别是影音文件。在这样的情况下,我们可以从计算机的回收站中找回已经被删除的视频。但是有时候,我们可能会不小心清空回收站,这时候就需要一些技巧来恢复回收站删除的视频…

走进JVM

JVM的位置 在操作系统之上,可以想象成一个软件,Java程序都运行在上面 JVM结构图 JVM调优的位置 99%的调优在堆中,极少数在方法区中 很多第三方插件都是在执行引擎那块地方做出修改而来,比如Lombook在程序运行时动态生成get/s…

VSYNC研究

Vsync信号是SurfaceFlinger进程中核心的一块逻辑,我们主要从以下几个方面着手讲解。软件Vsync是怎么实现的,它是如何保持有效性的?systrace中看到的VSYNC信号如何解读,这些脉冲信号是在哪里打印的?为什么VSYNC-sf / VS…

YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析

前言 前面简单介绍了YOLOv5的网络结构和创新点(直通车:【YOLO系列】YOLOv5超详细解读(网络详解)) 在接下来我们会进入到YOLOv5更深一步的学习,首先从源码解读开始。 因为我是纯小白,刚开始下…

【看表情包学Linux】进程创建 | 进程终止 | 分叉函数 fork | 写时拷贝 | 内核数据结构缓冲池 | slab 分派器

爆笑教程《看表情包学Linux》👈 猛戳订阅!​​​​​💭 写在前面:本章我们主要讲解进程的创建与终止。首先讲解进程创建,fork 函数是我们早在讲解 "进程的概念" 章节就提到过的一个函数,在上个章…

gma 地理空间绘图:(1) 绘制简单的世界地图-3.设置地图框

内容回顾 gma 地理空间绘图:(1) 绘制简单的世界地图-1.地图绘制与细节调整 gma 地理空间绘图:(1) 绘制简单的世界地图-2.设置经纬网 方法 SetFrame(FrameColor ‘black’, FrameWidth 0.6, ShowFrame True, ShowLeft True, ShowBottom True, Sho…

Golang alpine Dockerfile 最小打包

最近在ubantu 上进行了 iris项目的alpine 版本打包&#xff0c;过程遇到了一些问题&#xff0c;记录一下。 golang版本 &#xff1a;1.18 系统&#xff1a;ubantu 代码结构 Dockfile内容 FROM alpine:latest MAINTAINER Si Wei<3320376695qq.com> ENV VERSION 1.1 ENV G…

格密码学习笔记(二):连续极小、覆盖半径和平滑参数

文章目录最短距离和连续极小值距离函数和覆盖半径格的平滑参数致谢最短距离和连续极小值 除了行列式&#xff0c;格的另一个基本量是格上最短非零向量的长度&#xff0c;即格中最短距离&#xff0c;其定义为 λ1min⁡x,y∈L,x≠y∥x−y∥min⁡z∈L,z≠0∥z∥.\begin{aligned} …

一起来学ASM字节码插桩:从分析class文件结构开始

文章目录Class字节码class字节码构成类型描述符基本类型描述符非数组的引用类型数组引用类型方法描述符OpCode 操作码Class字节码 Java 能做到 一次编译&#xff0c;到处运行&#xff0c;主要就是靠 class字节码 文件&#xff0c;也就是 java 文件经过编译之后 .java -> .c…

【C语言】刷题|链表|双指针|指针|多指针|数据结构

主页&#xff1a;114514的代码大冒 qq:2188956112&#xff08;欢迎小伙伴呀hi✿(。◕ᴗ◕。)✿ &#xff09; Gitee&#xff1a;庄嘉豪 (zhuang-jiahaoxxx) - Gitee.com 文章目录 目录 文章目录 前言 一、移除链表元素 二、反转链表 三&#xff0c;链表的中间结点 四&…

springBoot 事务基本原理

springBoot事务基本原理是基于spring的BeanPostProcessor&#xff0c;在springBoot中事务使用方式为&#xff1a; 一、在启动类上添加注解&#xff1a;EnableTransactionManagement 二、在需要事务的接口上添加注解&#xff1a;Transactional 基本原理&#xff1a; 注解&am…

GB/T28181-2022图像抓拍规范解读及技术实现

规范解读GB28181-2022相对2016&#xff0c;增加了设备软件升级、图像抓拍信令流程和协议接口。我们先回顾下规范说明&#xff1a;图像抓拍基本要求源设备向目标设备发送图像抓拍配置命令,携带传输路径、会话ID等信息。目标设备完成图像传输后,发送图像抓拍传输完成通知命令,采用…

最短距离(dijkstra)

蓝桥杯集训每日一题 acwing1488 有 N 个村庄&#xff0c;编号 1 到 N。 村庄之间有 M 条无向道路&#xff0c;第 i 条道路连接村庄 ai 和村庄 bi&#xff0c;长度是 ci。 所有村庄都是连通的。 共有 K 个村庄有商店&#xff0c;第 j 个有商店的村庄编号是 xj。 然后给出 Q…

8.装饰者模式

目录 简介 角色组成 实现步骤 1. 新建 Log.class&#xff0c;添加如下代码 2. 新建 Log4j.class&#xff0c;继承 Log.class&#xff0c;并实现 record() 方法 3. 新建 Decorator.class&#xff0c;继承 Log.class 4. 新建 Log4jDecorator.class&#xff0c;继承 Decorat…