基于torchvision的CV迁移学习

news2024/12/23 14:25:53

前面我们用过了cifar10,这里因为我们模型的体量更大,他能够理解更加复杂的数据集,所以这里我们就使用更加复杂的数据集叫做cifar100,顾名思义就是它是一个100分类的图像数据集,分类数据更多,复杂度更多。

定义数据集

import torchvision
import torch


#定义数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self, train):

        #在线加载数据集
        #更多数据集:https://pytorch.org/vision/stable/datasets.html
        self.data = torchvision.datasets.CIFAR100(root='data',
                                                  train=train,
                                                  download=True)

        #更多数据增强:https://pytorch.org/vision/stable/transforms.html
        self.compose = torchvision.transforms.Compose([

            #原本是32*32的,缩放到300*300,这是为了适应预训练模型的习惯,便于它抽取图像特征
            torchvision.transforms.Resize(300),

            #随机左右翻转,这是一种图像增强,很显然,左右翻转不影响图像的分类结果
            torchvision.transforms.RandomHorizontalFlip(p=0.5),

            #图像转矩阵数据,值域是0-1之间
            torchvision.transforms.ToTensor(),

            #让图像的3个通道的数据分别服从3个正态分布,这3分数据是从一个大的数据集上统计得出的
            #投影也是为了适应预训练模型的习惯
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        #取数据
        x, y = self.data[i]

        #应用compose,图像转数据
        x = self.compose(x)

        return x, y


dataset = Dataset(train=True)

x, y = dataset[0]

print(len(dataset), x.shape, y)

这里我们使用了torchvision的加载数据集的方式,它能够在线的加载数据集,那么更多的数据集可以通过上面注释中的连接获得。root='data',是指将下载的数据集保存到本地磁盘的路径,也就是数据缓存的位置,下面这个参数train它的取值是一个布尔值,指的是要下载的数据集的训练的部分还是测试的部分,这里用到是一个变量因为两部分训练集我们都需要。

compose这个变量是torchvision提供的另外一个功能,就是图像的数据增强,具体的方法也可以通过注释当中的链接查到,下面演示的是常用的几个,第一个是resize也就是图像的缩放,原本是32x32,这里统一缩放到300x300,这是为了适应预训练模型的习惯,便于预训练模型抽取图像的特征,因为我们的预训练模型它训练的时候都是使用300x300的图像来训练的。第二个应用的图像增强就是随机的左右翻转,翻转的概率设置为0.5,但对于cifar100这个数据集来说,左右翻转是不影响图像的分类结果,加上这个数据增强是让我们的数据集更加的丰富。使用ToTensor这个工具类将图像转为矩阵,值域在0到1之间。最后我们对数据进行一个normalize,也就是让我们的图像数据的三个通道分别的服从三个正态分布,三个正态分布它的均值和标准差都写在上面了。

然后就是len和getitem,getitem这个函数每次取一批数据后,然后对这个图像应用我们的compose,这样对我们的图像增强,然后把图像转换为了数据。

定义loader

#每次从loader获取一批数据时回回调,可以在这里做一些数据整理的工作
#这里写的只是个例子,事实上这个回调函数什么也没干..
def collate_fn(data):
    #取数据
    x = [i[0] for i in data]
    y = [i[1] for i in data]

    #比如可以手动转换数据格式
    x = torch.stack(x)
    y = torch.LongTensor(y)

    return x, y


#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     shuffle=True,
                                     drop_last=True,
                                     collate_fn=collate_fn)

x, y = next(iter(loader))

print(len(loader), x.shape, y)

loader的代码没有什么可讲的,要提到的只有collate_fn,这个函数是每次从loader取一批数据的时候都会回调的,所以可以在这一个函数里面做一些数据整理的工作。

(6250, torch.Size([8, 3, 300, 300]), tensor([50, 54, 98, 51, 77, 96, 72, 81]))

很显然,x就是8张图像,y就是8个整数,取值是在0到100之间。

迁移学习

一般模型的第一部都是将数据读进去,然后一层层的抽取特征,最后把数据抽取成一个向量后,放到一个全连接的神经网络当中取进行分类,那么对于一个训练好的神经网络模型来说,其中的很多层其实是可以复用的。比方这里的一个模型它是一个回归的结果,然后我又不想要回归了怎么办。很简单,我将最后一层剪掉,然后重新接上三层新的,然后在这三层当中想做分类还是回归就有我自己决定了。也就是说前面的这些层我是不对它进行训练的,或者说这些层基本上训练好了,即使我对它进行重新的训练,难度也会想对的较小。

这种就是迁移学习,它的核心思想就是复用以前训练好的模型,它其中的一些层的参数,尤其是浅层的,因为这些层它是负责图像数据的特征抽取的,在我新的模型当中是可以复用的,因为数据特征抽取这个工作我一样是要做的。

定义模型

按照之前所说的,我们需要一个预训练好的模型,使用torchvision来完成这项工作,在这里面它提供了很多的预训练模型,更多的选择可以去链接中查找,这样通过torchvision加载后, 重新组装模型,而我们也只需要这里面的feature部分,后面给它接一个全连接的输出层,那么我要做的是分类还是回归就可以自己决定了。

class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

        #加载预训练模型
        #更多模型:https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights
        pretrained = torchvision.models.efficientnet_v2_s(
            weights=torchvision.models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)

        #重新组装模型,只要特征抽取部分
        pretrained = torch.nn.Sequential(
            pretrained.features,
            pretrained.avgpool,
            torch.nn.Flatten(start_dim=1),
        )

        #锁定参数,不训练
        for param in pretrained.parameters():
            param.requires_grad_(False)

        pretrained.eval()
        self.pretrained = pretrained

        #线性输出层,这部分是要重新训练的
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(1280, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 100),
        )

    def forward(self, x):
        #调用预训练模型抽取参数,因为预训练模型是不训练的,所以这里不需要计算梯度
        with torch.no_grad():
            #[8, 3, 300, 300] -> [8, 1280]
            x = self.pretrained(x)

        #计算线性输出
        #[8, 1280] -> [8, 100]
        return self.fc(x)


model = Model()

x = torch.randn(8, 3, 300, 300)

print(model.pretrained(x).shape, model(x).shape)

模型训练

#训练
def train():
    #注意这里的参数列表,只包括要训练的参数即可
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
    loss_fun = torch.nn.CrossEntropyLoss()
    model.fc.train()

    #定义计算设备,优先使用gpu
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    print('device=', device)

    for i, (x, y) in enumerate(loader):
        #如果使用gpu,数据要搬运到显存里
        x = x.to(device)
        y = y.to(device)

        out = model(x)
        loss = loss_fun(out, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 500 == 0:
            acc = (out.argmax(dim=1) == y).sum().item() / len(y)
            print(i, loss.item(), acc)

    #保存模型,只保存训练的部分即可
    torch.save(model.fc.to('cpu'), 'model/8.model')

定义设备是我们普遍会写的,如果有GPU那么就使用GPU进行运算

测试

@torch.no_grad()
def test():

    #加载保存的模型
    model.fc = torch.load('model/8.model')
    model.fc.eval()

    #加载测试数据集,共10000条数据
    loader_test = torch.utils.data.DataLoader(dataset=Dataset(train=False),
                                              batch_size=8,
                                              shuffle=True,
                                              drop_last=True)

    correct = 0
    total = 0
    for i in range(100):
        x, y = next(iter(loader_test))

        #这里因为数据量不大,使用cpu计算就可以了
        out = model(x).argmax(dim=1)

        correct += (out == y).sum().item()
        total += len(y)

    print(correct / total)

这里加载的是测试数据集,注意这里的train是False,最后得出的正确率为70%,还是比较的高的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/757871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

哈希表(hashtable)的数据插入、查找和遍历

文章目录 前言一、哈希二、哈希的具体实现2.1 准备工作2.2 插入数据2.3 输出哈希表2.4 在哈希表中寻找数据2.5 销毁哈希表 三、 哈希表的调用总结 前言 本期主要和大家介绍一下哈希算法,这里主要给出哈希算法的实现方法; 一、哈希 hash是一种算法: 哈希…

Linux进程理解【环境变量】

Linux进程理解【环境变量】 提到环境变量,大家可能有些陌生,如果编写过Java就知道,编写Java需要提前安装JDK,这个操作就是配置Java的编码环境,在Linux中当然也少不了环境变量,下面我们就一起来看看 文章目…

SpringBoot 统一功能的处理

SpringBoot 统一功能的处理 文章目录 SpringBoot 统一功能的处理1. 用户登录权限校验1.1 最初用户登录验证1.2 Spring AOP 统一用户登录验证的问题1.3 SpringAOP 拦截器1.3.1 实现自定义拦截器1.3.2 将自定义拦截器加入到系统配置 1.4 拦截器实现原理1.4.1 实现流程图1.4.2 实现…

LeetCode:3. 无重复字符的最长子串

🍎道阻且长,行则将至。🍓 🌻算法,不如说它是一种思考方式🍀 算法专栏: 👉🏻123 题解目录 一、🌱[3. 无重复字符的最长子串](https://leetcode.cn/problems/l…

分享维修一例DELL R540服务器黄灯无法开机故障

DELL PowerEdge R540服务器故障维修案例:(看到文章就是缘分) 客户名称:东莞市某街道管理中心 故障机型:DELL R540服务器 故障问题:DELL R540服务器无法开机,前面板亮黄灯,工程师通过…

私有GitLab仓库 - 本地搭建GitLab私有代码仓库并随时远程访问

文章目录 前言1. 下载Gitlab2. 安装Gitlab3. 启动Gitlab4. 安装cpolar内网穿透5. 创建隧道配置访问地址6. 固定GitLab访问地址6.1 保留二级子域名6.2 配置二级子域名 7. 测试访问二级子域名 前言 GitLab 是一个用于仓库管理系统的开源项目,使用Git作为代码管理工具…

javaee jstl表达式

jstl是el表达式的扩展 使用jstl需要添加jar包 package com.test.servlet;import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map;import javax.servlet.ServletException; import javax.servlet…

【C++】面向对象三大特性之继承

【C】面向对象三大特性之继承 继承的概念继承基类成员访问方式的变化子类到父类对象之间赋值兼容转换继承中的作用域子类的默认成员函数继承和友元、静态成员的关系菱形继承和菱形的虚拟继承虚拟继承解决二义性和数据冗余 继承的概念 继承:是面向对象程序设计使代码…

解析Android VNDK/VSDK Snapshot编译框架

1.背景 背景一: 为解决Android版本碎片化问题,引入Treble架构,它提供了稳定的新SoC供应商接口,引入HAL 接口定义语言(HIDL/Stable AIDL,技术栈依然是Binder),它指定了 vendor HAL 和system fr…

动态规划01背包之416分割等和子集(第10道)

题目: 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。 示例: 解法: 先复习一下01背包问题: dp[i][j]的含义:从下标为[0-i]的物品里…

【Spring Boot】Spring Boot的系统配置 — 日志配置

日志配置 日志对于系统监控、故障定位非常重要,比如当生产系统发生问题时,完整清晰的日志记录有助于快速定位问题。接下来介绍Spring Boot对日志的支持。 1.Spring Boot日志简介 Spring Boot自带spring-boot-starter-logging库实现系统日志功能&#…

基于linux下的高并发服务器开发(第一章)- 目录操作函数

09 / 目录操作函数 &#xff08;1&#xff09;int mkdir(const char* pathname,mode_t mode); #include <sys/stat.h> #include <sys/types.h>int mkdir(const char *pathname, mode_t mode); 作用&#xff1a;创建一个目录 参数&#xff1a; pat…

固态硬盘SSD选型测试大纲

一&#xff0c;前言 目前不仅仅是家用电脑系统盘很多都采用了固态硬盘&#xff0c;很多工业产品也选用固态硬盘作为存储介质&#xff0c;这主要得益于固态硬盘相对于机械硬盘的优势。 固态硬盘(Solid State Disk)都是由主控芯片和闪存芯片组成&#xff0c;简单来说就是用固态电…

Python编程从入门到实践_5-10 检查用户名_答案

#《Python编程从入门到实践》&#xff0c;动手试一试&#xff0c;5-10检查用户名&#xff0c;答案。2023-07-15,by qs。 current_users [AaA,bBb,CcC,DdD,EeE] new_users [AAA,bbb,abc,def,hij] for new_user in new_users:current_users_1 []for current_user in current_u…

安达发|汽车零部件行业追溯系统的应用

汽车行业正处于一个蓬勃发展的阶段&#xff0c;随着客户需求的不断变化&#xff0c;生产厂商推出新款商品的速度也越来越快&#xff0c;新项目和变更的不断出现&#xff0c;就可能导致在交付的产品质量方面遇到各种各样的问题。如果这些质量问题得不到及时有效地追溯和控制&…

华为模拟器eNSP过程中所遇问题(40错误)与解决办法

1. 版本 2.打开ensp开启AR2204&#xff0c;报错40 3.弹出文档&#xff0c;挨着试一遍先 安装eNSP的PC上是否存在名为“VirtualBox Host-Only Network”的虚拟网卡 需要启用。虚拟网卡的设置是否符合以下要求&#xff1a;IP地址为192.168.56.1&#xff0c;子网掩码为255.255.2…

Typora设置Gitee图床,自动上传图片

之前写了一篇同类型文章&#xff1a;如何将Typora中图片上传到csdn 实现了Typora本地编辑的内容中的图片&#xff0c;可以直接复制到csdn上进行发布。但是在使用过程中发现sm.ms这个图床站不是很稳定&#xff0c;即使用了翻墙也不稳定。 这篇文章推荐使用Gitee作为图床&#xf…

C++教程(六)——数组

1 数组 1.1 概述 所谓数组&#xff0c;就是一个集合&#xff0c;里面存放了相同类型的数据元素 **特点1&#xff1a;**数组中的每个数据元素都是相同的数据类型 **特点2&#xff1a;**数组是由连续的内存位置组成的 12 一维数组 1.2.1 一维数组定义方式 一维数组定义的三…

FL Studio21编曲软件中文版如何下载更新?

国人习惯称之为水果&#xff0c;也是我个人现在在用的软件。FL Studio是一款比较全面的编曲软件&#xff0c;其通道机架可以使用户添加各种音频采样&#xff0c;快捷编辑节奏型。对于音频的剪辑、拼接、效果处理也非常优秀。非常适合电子音乐编曲以及一些Hiphop。但是其录音、以…

组合数学相关知识总结(目前主要总结了卡特兰数)

全排列 例子&#xff1a; n n n 个数取 m m m 个数有序排放 通项公式&#xff1a; A n m ( P n m ) n ∗ ( n − 1 ) ∗ ( n − 2 ) ∗ ⋅ ⋅ ⋅ ∗ ( n − m 1 ) n ! ( n − m ) ! A_n^m(P_n^m)n*(n-1)*(n-2)**(n-m1) \frac{n!}{(n-m)!} Anm​(Pnm​)n∗(n−1)∗(n−2)∗…