Pytorch学习---基于经典网络架构ResNet训练花卉图像分类模型

news2024/9/21 0:13:29

基于经典网络架构训练图像分类模型

导包


import copy
import json
import time
import torch
from torch import nn
import torch.optim as optim
import torchvision
import os
from torchvision import transforms, models, datasets
import numpy as np
import matplotlib.pyplot as plt
import ssl

冻结中间层的所有参数,只训练最后输出全连接层

def set_parameter_requeires_grad(model,feature_extracting):
    """
    set_parameter_requires_grad 函数的作用是根据 feature_extracting 参数的值来决定是否冻结模型的参数。当用于特征提取时,它会阻止预训练模型的参数在训练过程中被更新,从而保留预训练模型的特征提取能力。当用于微调时,它不会修改参数的 requires_grad 属性,从而允许所有参数被更新。
    """
    # 该函数会遍历模型的所有参数,并将它们的 requires_grad 属性设置为 False
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

是否用gpu进行训练

def get_device() -> torch.device:
    """
    确定并返回用于训练的设备(CPU 或 GPU)
    """
    train_on_gpu = torch.cuda.is_available()
    if not train_on_gpu:
        print('你的gpu不可用,尝试在cpu训练')
    else:
        print('gpu可用,训练中在gpu中进行')
    #torch.device() 是一个用于创建设备对象的构造函数,它可以指定张量和模型运行在 CPU 还是 GPU 上。
    device = torch.device('cuda:0'if torch.cuda.is_available() else "cpu")
    return device

选择迁移模型,这里选择resnet残差神经网络,不同模型的初始化方法稍微有点不同

def initialize_model(model_name, num_classes,feature_extract,use_pretrained = True):
    """
    用于初始化一个特定的深度学习模型( ResNet-152),并将它用于图像分类任务。
    :param model_name:要学习的模型名称
    :param num_classes:指定分类任务的目标类别数量
    :param feature_extrace:是否进行特征提取(冻结训练层)
    :param use_pretrained:是否使用预训练的模型,如果为TRUE,则使用ImageNet预训练的权重初始化模型,false则随机初始化权重
    :return:
    """
    if model_name == 'resnet':
        model_ft = models.resnet152(pretrained=use_pretrained)
        set_parameter_requeires_grad(model_ft,feature_extract)
        # model_ft.fc 是指模型的最后一个全连接层(分类层)。.in_features 是一个属性,它表示这个全连接层的输入特征数量。
        num_ftrs = model_ft.fc.in_features
        # 将原模型的最后一层替换为一个适合当前任务的分类层,输出节点数量为 num_classes。新分类层由一个线性层和一个 LogSoftmax 层组成,用于输出分类概率。
        model_ft.fc = nn.Sequential(nn.Linear(num_ftrs,num_classes),nn.LogSoftmax(dim=1))
        input_size = 224
    else:
        print('无效模型,不存在!')
        exit()
    return model_ft, input_size

用于启动花卉分类任务的准备工作

def flower_start():
    # 模型初始化,获取设备,将模型放置到设备上进行训练
    model_ft, input_size = initialize_model("resnet", 102, feature_extract=True, use_pretrained=True)
    device = get_device()
    model_ft = model_ft.to(device)
    # 获取需要更新的参数,并提取
    params = model_ft.named_parameters()
    print('需要学习的参数有:')
    params_need_update = []
    for param_name, param in params:
        if param.requires_grad:
            params_need_update.append(param)
            print(param_name)
    # 数据路径
    data_dir = './flower_data'
    train_dir = data_dir+'/train'
    valid_dir = data_dir+'/valid'
    # 将分类后的编号和对应的名字找到
    with open('cat_to_name.json')as f:
        cat_to_name = json.load(f)
    # 数据增强变换
    """torchvision.transforms 提供了一系列用于图像预处理的功能,包括图像增强、转换和标准化等操作。这些变换可以应用于图像数据,以增强模型的泛化能力或改善训练效果。
    transforms.Compose 是一个容器类,用于将多个变换组合在一起形成一个变换序列。这样可以方便地定义一系列变换操作,并按照顺序依次应用到图像数据上。"""
    data_transforms = {
        'train': transforms.Compose([transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选
                                     transforms.CenterCrop(224),  # 从中心开始裁剪,只得到一张图片
                                     transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 概率为0.5
                                     transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
                                     transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
                                     # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
                                     transforms.RandomGrayscale(p=0.025),  # 概率转换成灰度率,3通道就是R=G=B
                                     transforms.ToTensor(),
                                     # 迁移学习,用别人的均值和标准差
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 均值,标准差
                                     ]),
        'valid': transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     # 预处理必须和训练集一致
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ]),
    }
    batch_size = 8
    # 将变换后的图片用字典保存
    """os.path.join(data_dir, x):将 data_dir(数据目录)与 'train' 或 'valid' 字符串拼接起来,形成训练集和验证集的完整路径。
    data_transforms[x]:根据 'train' 或 'valid' 选择相应的数据变换。
    datasets.ImageFolder:PyTorch 中的一个类,用于加载文件夹结构中的图像数据集。该类会自动根据文件夹结构生成类标签,并应用指定的变换。"""
    image_datasets = {x:datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x]) for x in ['train', 'valid']}
    # print(image_datasets)
    # 批量处理
    dataloaders = {x:torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True) for x in ['train', 'valid']}
    dataset_sizes = {x:len(image_datasets[x]) for x in ['train','valid']}
    print(dataset_sizes)
    # 样本数据的标签
    class_names = image_datasets['train'].classes
    print(class_names)
    # 画出预处理好的图像
    fig = plt.figure(figsize=(20,12))
    columns, rows = 4,2
    dataiter = iter(dataloaders['valid'])
    inputs, classes = next(dataiter)
    for idx in range(columns * rows):
        ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
        # classes为索引,class_name里为实际label,再去拿到对应的花名
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
        img = transforms.ToPILImage()(inputs[idx])
        plt.imshow(img)
    plt.show()
    # 优化器设置
    optimizer_ft = optim.Adam(params_need_update,lr=1e-2)
    # 设置学习率衰减,每7个训练过程衰减为原来的1/10
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=1/10)
    # 这里不再使用交叉熵损失函数,因为模型中最后一层是logsoftmax(),已经是对数了,所有直接用nllloss输入已经经过 LogSoftmax() 处理的对数概率分布
    criterion = nn.NLLLoss()
    filename = 'wz.pth'
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = wz_model_train(model_ft,
                                                                                                   dataloaders,
                                                                                                   criterion,
                                                                                                   optimizer_ft,
                                                                                                   scheduler,
                                                                                                   filename,
                                                                                                   device)
    for param in model_ft.parameters():
        param.requires_grad = True



训练模型函数

def wz_model_train(model,dataloaders,criterion,optimizer,scheduler,filename:str,device:torch.device,num_epochs=2,is_inception=False):
    """
    训练和验证模型:通过迭代数据集来进行训练和验证。
    保存最佳模型:记录并保存验证集上表现最好的模型。
    记录训练历史:记录每个epoch的训练和验证损失及准确率。
    学习率更新:使用学习率调度器更新学习率。
    性能报告:打印每个epoch的训练时间和性能指标。
    :param model:训练的模型
    :param dataloaders:数据加载器
    :param criterion:损失函数
    :param optimizer:优化器
    :param scheduler:学习率调度器
    :param filename:保存模型的文件名
    :param device:使用的设备
    :param num_epochs:训练轮数
    :param is_inception:是否使用inception网络
    :return:
    """
    start_time = time.time() # 记录训练开始时间
    best_acc = 0  # 记录训练最好准确率
    best_model_weights = copy.deepcopy(model.state_dict())  # 记录最好训练的模型参数
    model.to(device)
    # 保存损失和准确率数据
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    # 记录每个epoch的学习率
    LRs = [optimizer.param_groups[0]['lr']]
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('----------------------------')
        # 训练和验证
        for phase in ['train', 'valid']:
            if phase=='train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0 # 累计损失
            running_corrects = 0 # 累计正确预测的数量
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad() # 梯度清零
                # 在训练模式下开启梯度计算,在评估模式下关闭梯度计算。
                with torch.set_grad_enabled(phase=='train'):
                    # inception网络有一个辅助输出,和主输出加权取损失值,这样可以增加稳定性
                    """辅助输出是指在网络中间某一层产生的额外预测结果。这种设计主要用于提高模型的训练稳定性,并在一定程度上防止过拟合。辅助输出可以为网络提供额外的监督信号,帮助模型更快地收敛。
                    在 Inception 网络中,辅助分类器(Auxiliary Classifier)通常位于网络的中间部分。这些辅助分类器可以提供额外的监督信号,帮助网络更好地学习特征。辅助分类器通常包括全局平均池化、全连接层和 Softmax 层,以便产生分类结果。
                    """
                    if is_inception and phase=='train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1+0.4*loss2
                    else:  # 这里的训练不需要开启inception
                        # print('没有开启inception')
                        outputs = model(inputs)
                        loss = criterion(outputs,labels)
                    # 不要概率最大值本身,要的是他的标签
                    _, preds = torch.max(outputs, 1)
                    # 训练阶段更新权重
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # 计算批量的loss和正确预测数量
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds==labels.data)
            # 计算平均损失和损失率
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            # 打印一个epoch的时间,这个时间可以是训练阶段的,也可以是验证阶段的
            time_elapsed = time.time()-start_time
            print('本次epoch模型已经跑了{:.0f}分 {:.0f}秒'.format(time_elapsed//60,time_elapsed%60))
            print('{}的损失loss是:{:.4f},准确率是{:.4f}'.format(phase,epoch_loss,epoch_acc))
            # 得到最好的那次模型
            if phase=='valid' and epoch_acc>best_acc:
                best_acc = epoch_acc
                best_model_weights = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict':model.state_dict(),
                    'best_acc':best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state,filename)
            if phase=='valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step()  # 根据验证集来调整学习率
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
    time_elapsed = time.time()-start_time
    print('训练在{:.0f}分{:.0f}秒完成'.format(time_elapsed//60,time_elapsed%60))
    print('最好的精确值:{:4f}'.format(best_acc))
    # 将最好的训练一次当最终值
    model.load_state_dict(best_model_weights)
    return model, val_acc_history, train_acc_history,valid_losses,train_losses, LRs



ssl._create_default_https_context = ssl._create_unverified_context
flower_start()
gpu可用,训练中在gpu中进行
需要学习的参数有:
fc.0.weight
fc.0.bias
{'train': 6552, 'valid': 818}
['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Epoch 0/1
----------------------------
本次epoch模型已经跑了3分 39秒
train的损失loss是:10.3695,准确率是0.3219
本次epoch模型已经跑了3分 60秒
valid的损失loss是:9.9343,准确率是0.4364
Optimizer learning rate : 0.0100000
Epoch 1/1
----------------------------
本次epoch模型已经跑了7分 46秒
train的损失loss是:8.1607,准确率是0.4899
本次epoch模型已经跑了8分 8秒
valid的损失loss是:15.6906,准确率是0.3619
Optimizer learning rate : 0.0100000
训练在8分8秒完成
最好的精确值:0.436430

.4364
Optimizer learning rate : 0.0100000
Epoch 1/1
----------------------------
本次epoch模型已经跑了7分 46秒
train的损失loss是:8.1607,准确率是0.4899
本次epoch模型已经跑了8分 8秒
valid的损失loss是:15.6906,准确率是0.3619
Optimizer learning rate : 0.0100000
训练在8分8秒完成
最好的精确值:0.436430


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2150534.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【使用Hey对vllm接口压测】模型并发能力

使用Hey对vllm进行模型并发压测 docker run --rm --networkknowledge_network \registry.cn-shanghai.aliyuncs.com/zhph-server/hey:latest \-n 200 -c 200 -m POST -H "Content-Type: application/json" \-H "Authorization: xxx" \-d {"model"…

【类型黑市】指针

大家好我是#Y清墨,今天我要介绍的是指针。 意义 指针就是存放内存地址的变量。 分类 因为变量本身是分类型的,我们学过的变量类型有 int, long long, char, double, string, 甚至还有结构体变量。 同样,指针也分类型,如果指针指向…

云韧性,现代云服务不可或缺的组成部分

韧性,一个物理学概念,表示材料在变形或者破裂过程中吸收能量的能力。韧性越好,则发生脆性断裂的可能性越小。 如今,韧性也延伸到企业特质、产品特征等之中,用于形容企业、产品乃至服务的优劣。同样,随着云…

3. Internet 协议的安全性

3. Internet 协议的安全性 (1) 常用网络协议的功能、使用的端口及安全性 HTTP协议 功能:用于从服务器传输超文本到本地浏览器。端口:默认是80端口。安全性:不提供数据加密,存在数据泄露和中间人攻击风险。使用HTTPS协议(443端口)可以增强安全性。FTP协议 功能:实现文件的…

电脑录课软件哪个好用,提高教学效率?电脑微课录屏软件推荐

在当今这个数字化时代,教育领域也迎来了翻天覆地的变化。随着远程教学和在线学习的普及,教师们开始寻求更高效、更便捷的教学工具来提升教学质量和学生的学习体验。电脑录课软件,作为现代教育技术的重要组成部分,能够帮助教师轻松…

【CPP】类与继承

14 类与继承 在前面我们提到过继承的一些概念,现在我们来回顾一下 打个比方:在CS2中我们把玩家定义为一个类 class 玩家: 血量:100阵营(未分配)服饰(未分配)位置(未分配)武器(未分配)是否允许携带C4(未分配)是否拥有C4(未分配) 当对局创建时,会新生成两个类,这两个类继承自&qu…

【Linux庖丁解牛】—Linux基本指令(上)!

🌈个人主页:秋风起,再归来~🔥系列专栏: Linux庖丁解牛 🔖克心守己,律己则安 目录 1、 pwd命令 2、ls 指令 3、cd 指令 4、Linux下的根目录 5、touch指令 6、 stat指令 7、mkdi…

LabVIEW提高开发效率技巧----采用并行任务提高性能

在复杂的LabVIEW开发项目中,合理利用并行任务可以显著提高系统的整体性能和响应速度。并行编程是一种强大的技术手段,尤其适用于实时控制、数据采集以及多任务处理等场景。LabVIEW的数据流编程模型天然支持并行任务的执行,结合多核处理器的硬…

OrCAD使用,快捷键,全选更改封装,导出PCB网表

1 模块名称 2 快捷键使用 H: 镜像水平 V:镜像垂直 R: 旋转 I: 放大 O: 放小 P:放置元器件 W: 步线 B: 总线(无电气属性) E: 总线连接符(和BUS一起用&#xff09…

【网络通信基础与实践第四讲】用户数据报协议UDP和传输控制协议TCP

一、UDP的主要特点 1、UDP是无连接的,减少了开销和发送数据之前的时延 2、UDP使用尽最大努力交付,但是不保证可靠交付 3、UDP是面向报文的。从应用层到运输层再到IP层都只是添加一个相应的首部即可 4、UDP没有拥塞机制,源主机以恒定的速率…

基于JAVA+SpringBoot+Vue的学生干部管理系统

基于JAVASpringBootVue的学生干部管理系统 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末附源码下载链接🍅 哈…

力扣题解2376

大家好,欢迎来到无限大的频道。 今日继续给大家带来力扣题解。 题目描述(困难): 统计特殊整数 如果一个正整数每一个数位都是 互不相同 的,我们称它是 特殊整数 。 给你一个 正 整数 n ,请你返回区间 …

【Python报错已解决】SyntaxError invalid syntax

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 专栏介绍 在软件开发和日常使用中,BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

锐尔15注册机 锐尔文档扫描影像处理系统15功能介绍

锐尔文档扫描影像处理系统是一款全中文操作界面的文件、档案扫描及影像优化处理软件,是目前国内档案数字化行业里专业且优秀的影像优化处理软件。 无论是从纸质文件制作高质量的影像文件,或是检查已经制作好的影像文件,锐尔文档扫描影像处理…

Generative Models from the perspective of Continual Learning【小白读论文】

摘要: 本文在持续学习情况下评估各种生成模型。 本文研究了几种模型如何学习和遗忘,并考虑了各种持续学习策略:回放、正则化、生成重放和微调。 我们使用两个定量指标来估计生成质量和记忆能力。 我们在三个常用的持续学习基准(MN…

RabbitMQ08_保证消息可靠性

保证消息可靠性 一、生产者可靠性1、生产者重连机制(防止网络波动)2、生产者确认机制Publisher Return 确认机制Publisher Confirm 确认机制 二、MQ 可靠性1、数据持久化交换机、队列持久化消息持久化 2、Lazy Queue 惰性队列 三、消费者可靠性1、消费者…

新媒体运营

一、新媒体运营的概念 1.新媒体 2.新媒体运营的五大方向 用户运营 产品运营 。。。 二、新媒体的岗位职责及要求 三、新媒体平台

【redis-01】redis基本数据类型和使用场景

redis系列整体栏目 内容链接地址【一】redis基本数据类型和使用场景https://zhenghuisheng.blog.csdn.net/article/details/142406325 redis基本数据类型和使用场景 一,redis基本数据类型和使用场景1,String数据类型2,Hash数据类型3&#xff…

嵌入式linux系统中rk3588芯片引脚基本操作

第一:开发板中linux系统对应设备节点 进入用户 LED 设备文件夹: 1cd /sys/class/leds/usr_led该目录下的文件分别为 brightness、device、max_brightness、power、subsystem、trigger 和 uevent,需要注意的是 brightness、max_brightness 以及 trigger 文件,这三个文件都是…

共享单车轨迹数据分析:以厦门市共享单车数据为例(五)

先来聊聊啥是共享单车电子围栏? 共享单车电子围栏是一种基于地理位置技术的虚拟边界,用于管理和规范共享单车的停放和使用。这种技术通过在地图上划定特定区域,帮助用户了解哪些地方可以停车,哪些地方不能停车,从而减…