Pytorch:利用torchvision调用各种网络的预训练模型,完成CIFAR10数据集的各种分类任务

news2024/10/7 12:20:15

2023.7.19

cifar10百科:

[ 数据集 ] CIFAR-10 数据集介绍_cifar10_Horizon Max的博客-CSDN博客

torchvision各种预训练模型的调用方法:

pytorch最全预训练模型下载与调用_pytorch预训练模型下载_Jorbol的博客-CSDN博客

CIFAR10数据集下载并转换为图片:

文件结构:

import torchvision
from torch.utils.data import DataLoader
import os
import numpy as np
import imageio  # 引入imageio包

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 路径将测试集作为验证集
train_path = './dataset/cifar-10-batches-py/train'
test_path = './dataset/cifar-10-batches-py/val'

for i in range(10):
    file_name = train_path + '/'+ str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

for i in range(10):
    file_name = test_path + '/' + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

# 解压 返回解压后的字典
def unpickle(file):
    import pickle as pk
    fo = open(file, 'rb')
    dict = pk.load(fo, encoding='iso-8859-1')
    fo.close()
    return dict


# begin unpickle
root_dir = "./dataset/cifar-10-batches-py"
# 生成训练集图片
print('loading_train_data_')
for j in range(1, 6):
    dataName = root_dir + "/data_batch_" + str(j)  # 读取当前目录下的data_batch1~5文件。
    Xtr = unpickle(dataName)
    print(dataName + " is loading...")

    for i in range(0, 10000):
        img = np.reshape(Xtr['data'][i], (3, 32, 32))  # Xtr['data']为图片二进制数据
        img = img.transpose(1, 2, 0)  # 读取image
        picName = root_dir + '/train/' + str(Xtr['labels'][i]) + '/' + str(i + (j - 1) * 10000) + '.jpg'
        imageio.imsave(picName, img)  # 使用的imageio的imsave类
    print(dataName + " loaded.")

# 生成测试集图片(将测试集作为验证集)
print('loading_val_data_')
testXtr = unpickle(root_dir + "/test_batch")
for i in range(0, 10000):
    img = np.reshape(testXtr['data'][i], (3, 32, 32))
    img = img.transpose(1, 2, 0)
    picName = root_dir + '/val/' + str(testXtr['labels'][i]) + '/' + str(i) + '.jpg'
    imageio.imsave(picName, img)

 训练代码:

                                                           AlexNet 结构

1,查询需要的模型网站并填入

# 预训练模型官网
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}

2, 加载预训练模型(在10分类的情况下,AlexNet的效果比ResNet18的效果稍好)

先print(Model),,看fc层的参数,设置好fc层

3,以AlexNet10分类任务来看 

  

# Model = models.densenet161(pretrained=True)

Model = models.resnet18(pretrained=True)
for param in Model.parameters():
param.requires_grad = True
# print(Model)
# Model.fc = nn.Linear(2208, class_num)
Model.fc = nn.Linear(512, class_num)

 此时,我们需要将在导入AlexNet模型下载的官网:

# 预训练模型官网:
model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',}

# 加载模型:

Model = models.alexnet(pretrained=True)
for param in Model.parameters():
param.requires_grad = True
print(Model)

   

# 根据上面Classifier层的最后(6)Linear的输入通道 in_features = 4096 更改 model的fc层

Model.fc = nn.Linear(4096, class_num)

from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]
                               )


class MyDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.names_list = []

        for dirs in os.listdir(self.root_dir):
            dir_path = self.root_dir + '/' + dirs
            for imgs in os.listdir(dir_path):
                img_path = dir_path + '/' + imgs
                self.names_list.append((img_path, dirs))

    def __len__(self):
        return len(self.names_list)

    def __getitem__(self, index):
        image_path, label = self.names_list[index]
        if not os.path.isfile(image_path):
            print(image_path + '不存在该路径')
            return None
        image = Image.open(image_path).convert('RGB')

        label = np.array(label).astype(int)
        label = torch.from_numpy(label)


        if self.transform:
            image = self.transform(image)

        return image, label


if __name__ == '__main__':
    # 准备数据集
    train_data_path = './dataset/cifar-10-batches-py/train'
    val_data_path = './dataset/cifar-10-batches-py/val'

    # 数据长度
    train_data_length = len(train_data_path)
    val_data_length = len(val_data_path)

    # 分类的类别
    class_num = 10

    # 迭代次数
    epoch = 30

    # 学习率
    learning_rate = 0.00001

    # 批处理大小
    batch_size = 128

    # 数据加载器
    train_dataset = MyDataset(train_data_path, transform)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_dataset = MyDataset(val_data_path, transform)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_data_size = len(train_dataset)
    val_data_size = len(val_dataset)

    # 预训练模型官网
    model_urls = {'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',}

    # 调用预训练调整全连接层:搭建网络
    # Model = models.densenet161(pretrained=True)
    Model = models.alexnet(pretrained=True)
    for param in Model.parameters():
        param.requires_grad = True
    print(Model)
    # Model.fc = nn.Linear(2208, class_num)
    Model.fc = nn.Linear(4096, class_num)

    # 创建网络模型
    ModelOutput = Model.cuda()  # DenseNet161 ResNet18

    # 采用多GPU训练
    if torch.cuda.device_count() > 1:
        print("使用", torch.cuda.device_count(), "个GPUs进行训练")
        ModelOutput = nn.DataParallel(ModelOutput)
    else:
        ModelOutput = Model.to(device)  # .Cuda()数据是指放到GPU上
        print("使用", torch.cuda.device_count(), "个GPUs进行训练")

    # 定义损失函数
    loss_fn = nn.CrossEntropyLoss().cuda()  # 交叉熵函数

    # 定义优化器
    optimizer = optim.Adam(ModelOutput.parameters(), lr=learning_rate)

    # 记录验证的次数
    total_train_step = 0
    total_val_step = 0

    # 训练
    acc_list = np.zeros(epoch)
    print("{0:-^27}".format('Train_Model'))
    for i in range(epoch):
        print("----------epoch={}----------".format(i + 1))
        ModelOutput.train()
        for data in train_dataloader:  # data 是batch大小
            image_train_data, t_labels = data
            image_train_data = image_train_data.cuda()
            t_labels = t_labels.cuda()
            output = ModelOutput(image_train_data)
            loss = loss_fn(output, t_labels.long())

            # 优化器优化模型
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播
            optimizer.step()  # 优化更新参数

            total_train_step = total_train_step + 1
            print("train_times:{},Loss:{}".format(total_train_step, loss.item()))

        # 验证步骤开始
        ModelOutput.eval()
        total_val_loss = 0
        total_accuracy = 0
        with torch.no_grad():  # 测试的时候不需要对梯度进行调整,所以梯度设置不调整
            for data in val_dataloader:
                image_val_data, v_labels = data
                image_val_data = image_val_data.cuda()
                v_labels = v_labels.cuda()
                outputs = ModelOutput(image_val_data)
                loss = loss_fn(outputs, v_labels.long())
                total_val_loss = total_val_loss + loss.item()  # 计算损失值的和
                accuracy = 0

                for j in v_labels:  # 计算精确度的和

                    if outputs.argmax(1)[j] == v_labels[j]:
                        accuracy = accuracy + 1

                # accuracy = (outputs.argmax(1) == v_labels).sum()  # 计算一个数据的精确度
                total_accuracy = total_accuracy + accuracy

        val_acc = float(total_accuracy / val_data_size) * 100
        acc_list[i] = val_acc  # 记录验证集的正确率
        print('the_classification_is_correct :', total_accuracy, val_data_length)
        print("val_Loss:{}".format(total_val_loss))
        print("val_acc:{}".format(val_acc), '%')
        total_val_step += 1
        torch.save(ModelOutput, "Model_{}.pth".format(i + 1))
        # torch.save(ModelOutput.module.state_dict(), "Model_{}.pth".format(i + 1))
        print("{0:-^24}".format('Model_Saved'), '\n')
        print('val_max=', max(acc_list), '%', '\n')  # 验证集的最高正确率

测试代码:

import torch
from torchvision import transforms

import os
from PIL import Image

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 判断是否有GPU

model = torch.load('Model_8.pth')  # 加载模型

path = "./dataset/cifar-10-batches-py/test/"  # 测试集

imgs = os.listdir(path)

test_num = len(imgs)
print(f"test_dataset_quantity={test_num}")

for img_name in imgs:
    img = Image.open(path + img_name)

    test_transform = transforms.Compose([transforms.Resize((224, 224)),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]
                                        )

    img = test_transform(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    outputs = model(img)  # 将图片输入到模型中
    _, predicted = outputs.max(1)

    pred_type = predicted.item()
    print(img_name, 'pred_type:', pred_type)

在使用标签为9的卡车图像进行预测:

AlexNet:

 ResNet18:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/774898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

gitlab配置公钥

1、打开本地git bash,使用如下命令生成ssh公钥和私钥对 ssh-keygen -t rsa -C yourEmailgitlab.com2、然后打开~/.ssh/id_rsa.pub文件,复制里面的内容 cd ~/.ssh ls cat ./id_rsa.pub3、打开gitlab,找到Profile Settings–>SSH Keys—>Add SSH Key,并把上一…

【多线程】(六)Java并发编程深度解析:常见锁策略、CAS、synchronized原理、线程安全集合类和死锁详解

文章目录 一、常见锁策略1.1 乐观锁和悲观锁1.2 读写锁1.3 重量级锁和轻量级锁1.4 自旋锁1.5 公平锁和非公平锁1.6 可重入锁和不可重入锁 二、CAS2.1 什么是CAS2.2 CAS的实现原理2.3 CAS应用2.4 ABA问题 三、synchronized原理3.1 synchronized锁的特点3.2 加锁工作过程3.3 锁消…

VM(CentOS7安装和Linux连接工具以及换源)

目录 一、Linux意义 二、安装VMWare 三、centos7安装 1、正式安装CentOS7: 2、安装不了的解决方案 2.1常见问题——虚拟机开机就黑屏的完美解决办法 3、查看、设置IP地址 ① 查看ip地址:ip addr 或者 ifconfig, 注意与windows环境的区别…

labelme+sam在windows上使用指南

其实官网讲的很清楚了,这里做一个笔记,方便自己后面直接看。 首先,贴一下官方的链接,作者老哥很强,respect! 使用流程: https://github.com/wkentaro/labelme#installation 资源: ht…

数据库锁机制

锁机制 1. 概述2. 并发事务的不同场景2.1 读-读情况2.2 写-写情况2.3 读-写或写-读情况2.3.1 方案一:读事务使用MVCC(多版本并发控制),写事务加锁2.3.2 方案二:读、写事务均加锁 3. 锁分类3.1 从数据操作类型&#xff…

1.Docker概念

文章目录 Docker概念Docker容器与虚拟机的区别内核中的2个重要技术Linux Namespace的6大类型docker三个重要概念部署Dockeryum安装二进制安装 Docker 概念 docker是一个开源的应用容器引擎,基于go语言开发并遵循了apache2.0协议开源。docker可以让开发者打包他们的…

【PostgreSQL内核学习(三)—— 查询重写】

查询重写 查询重写系统规则视图和规则系统ASLO型规则的查询重写规则系统与触发器的区别 查询重写的处理操作定义重写规则删除重写规则对查询树进行重写 声明:本文的部分内容参考了他人的文章。在编写过程中,我们尊重他人的知识产权和学术成果&#xff0c…

王道计算机网络学习笔记(4)——网络层

前言 文章中的内容来自B站王道考研计算机网络课程,想要完整学习的可以到B站官方看完整版。 四:网络层 ​​​​​​​​​​​​​​在计算机网络中,每一层传输的数据都有不同的名称。 物理层:传输的数据称为比特(Bi…

宝塔的Redis绑定IP

宝塔安装Redis 软件商店搜索Redis 连接宝塔面板的redis服务器失败的解决办法 检查Linux是否放行6379端口修改Redis绑定IP检查阿里云/腾讯云的防火墙策略是否放行6379端口 1.bind 127.0.0.1 修改为 bind 0.0.0.0 127.0.0.1 表示只允许本地访问,无法远程连接 0.0.0.0 表…

基于Python的用户登录和密码强度等级测试|Python小应用

前言 那么这里博主先安利一些干货满满的专栏了! 这两个都是博主在学习Linux操作系统过程中的记录,希望对大家的学习有帮助! 操作系统Operating Syshttps://blog.csdn.net/yu_cblog/category_12165502.html?spm1001.2014.3001.5482Linux S…

Micro-app vue3+vite+ts用法

前言: 微前端的概念是由ThoughtWorks在2016年提出的,它借鉴了微服务的架构理念,核心在于将一个庞大的前端应用拆分成多个独立灵活的小型应用,每个应用都可以独立开发、独立运行、独立部署,再将这些小型应用融合为一个…

文库小程序在线阅读下载文档模板流量主小程序

一、什么是文库小程序? 文库小程序连接流量主,具体流程是粉丝进入小程序下载文档模板,下载前需要看广告,阅读后可以免费下载文档模板。具体的小程序演示请参见抖音云云文库 二、文库小程序的应用范围 小程序主要实现文档共享功能…

Redis Linux安装

Redis版本下载,版本地址http://download.redis.io/releases/ 点击跳转 新建文件夹 mkdir /usr/local/redis 上传压缩包,并使用命令解压tar -zxvf redis-6.2.8.tar.gz (redis-6.2.8.tar.gz为安装包) 安装依赖 yum install gcc-c 编译 make 安装 make install 修改配置 …

概率论和随机过程的学习和整理20:条件概率我知道,但什么是条件期望?可用来解决递归问题

目录 1 目标问题: 什么是条件期望? 条件期望有什么用? 2 条件期望,全期望公式 3 条件期望,全期望公式 和 条件概率,全概率公式的区别和联系 3.1 公式如下 3.2 区别和联系 3.3 概率和随机过程 4 有什…

Zabbix“专家坐诊”第200期问答汇总

问题一 Q:想请问下大佬们,我们zabbix最近有误告警的情况,这个怎么排查呢? 用了proxy,我看了proxy和server的日志,除了有慢查的日志,其它没有异常日志输出。 A:看下这个unreachable的…

首次与电商平台战略签约 第一三共与阿里健康达成战略合作

7月18日,阿里健康与第一三共在杭州正式签署战略合作协议。双方宣布将在此前合作基础上,全面深化心脑血管、风湿骨外科等疾病领域的合作深度,探索以患者为中心、以数字化为驱动力的创新型医药健康服务模式。据悉,此次合作是第一三共…

椒图--分析中心和后台管理中心

护网的时候我们要把右边的开关开启。开启就会对系统全量的记录,包含有网络行为日志,就会检测我们服务器里面的链接,端口箭头,内内网暴露的链接;进程操作日志,就可以看我们系统创建了哪些进程,就…

融云出海:不止假发出口和四卡四待手机,「非洲市场」的参差与机遇

↑ 点击预约“融云北极星”直播↑ 点击预约“实时社区”直播 比白皮书更精炼省流,比图谱更实用有效。 融云《社交泛娱乐出海作战地图》,被多位大咖标记为出海人必备工作手册。针对地图的核心模块,我们推出了系列解读文章,更详尽…

Redis数据持久化的两种方式

说明:Redis数据是存储在内存中的,Redis服务被关闭,数据是会被清除的。但Redis有数据持久化机制,在默认情况下,停止Redis服务会触发数据持久化机制,将数据保存下来,在下次启动时再读取出来。 Re…

解决spring security No AuthenticationProvider found for com.问题

No AuthenticationProvider found for com.xxx.xx 原因 当你验证过,后记得这个这里返回true。不然,就会出现既没有异常,又没验证返回通过的中间尴尬状态,security会当做没有验证通过来处理。 修改