动手学深度学习(Pytorch版)代码实践 -计算机视觉-37微调

news2025/1/15 23:23:38

37微调

在这里插入图片描述

import os
import torch
import torchvision
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt
from d2l import torch as d2l

# 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
#Downloading ../data\hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

# 分别读取训练和测试数据集中的所有图像文件
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
# ImageFolder 会递归地读取指定目录下的所有图像文件。
# print(train_imgs.classes)#一个类名列表 # ['hotdog', 'not-hotdog']
# print(train_imgs.class_to_idx) # 一个字典,类名映射到类索引 # {'hotdog': 0, 'not-hotdog': 1}
# print(train_imgs.imgs) # 一个包含所有图像路径和对应类索引的列表
# 例如:[('../data\\hotdog\\train\\hotdog\\0.png', 0), ('../data\\hotdog\\train\\hotdog\\1.png', 0)
#       , ('../data\\hotdog\\train\\not-hotdog\\999.png', 1)]
# 显示了前8个正类样本图片和最后8张负类样本图片

# hotdogs = [train_imgs[i][0] for i in range(8)] #train_imgs[i] 返回一个元组 (image, label),
# # 其中 image 是图像张量,label 是对应的标签。[0] 只提取图像张量。

# not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)] # 索引从 -1 到 -8

# d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
# plt.show() # 显示图片

# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = torchvision.transforms.Compose([
    #从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为224 * 224
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224), # 裁剪中央224 * 224
    torchvision.transforms.ToTensor(),
    normalize])

# 定义和初始化模型
# 使用在ImageNet数据集上预训练的ResNet-18作为源模型
pretrained_net = torchvision.models.resnet18(pretrained=True)

# 源模型实例包含许多特征层和一个输出层fc
print(pretrained_net.fc)
# Linear(in_features=512, out_features=1000, bias=True)

finetune_net = pretrained_net
# 改变输出层fc
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
# 参数初始化
nn.init.xavier_uniform_(finetune_net.fc.weight)


def train_batch_ch13(net, X, y, loss, trainer, devices):
    """使用多GPU训练一个小批量数据。
    参数:
    net: 神经网络模型。
    X: 输入数据,张量或张量列表。
    y: 标签数据。
    loss: 损失函数。
    trainer: 优化器。
    devices: GPU设备列表。
    返回:
    train_loss_sum: 当前批次的训练损失和。
    train_acc_sum: 当前批次的训练准确度和。
    """
    # 如果输入数据X是列表类型
    if isinstance(X, list):
        # 将列表中的每个张量移动到第一个GPU设备
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])# 如果X不是列表,直接将X移动到第一个GPU设备
    y = y.to(devices[0])# 将标签数据y移动到第一个GPU设备
    net.train() # 设置网络为训练模式
    trainer.zero_grad()# 梯度清零
    pred = net(X) # 前向传播,计算预测值
    l = loss(pred, y) # 计算损失
    l.sum().backward()# 反向传播,计算梯度
    trainer.step() # 更新模型参数
    train_loss_sum = l.sum()# 计算当前批次的总损失
    train_acc_sum = d2l.accuracy(pred, y)# 计算当前批次的总准确度
    return train_loss_sum, train_acc_sum# 返回训练损失和与准确度和


def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """训练模型在多GPU
    参数:
    net: 神经网络模型。
    train_iter: 训练数据集的迭代器。
    test_iter: 测试数据集的迭代器。
    loss: 损失函数。
    trainer: 优化器。
    num_epochs: 训练的轮数。
    devices: GPU设备列表,默认使用所有可用的GPU。
    """
    # 初始化计时器和训练批次数
    timer, num_batches = d2l.Timer(), len(train_iter)
    # 初始化动画器,用于实时绘制训练和测试指标
    animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                           legend=['train loss', 'train acc', 'test acc'])
    # 将模型封装成 DataParallel 模式以支持多GPU训练,并将其移动到第一个GPU设备
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # 训练循环,遍历每个epoch
    for epoch in range(num_epochs):
        # 初始化指标累加器,metric[0]表示总损失,metric[1]表示总准确度,
        # metric[2]表示样本数量,metric[3]表示标签数量
        metric = lp.Accumulator(4)
        # 遍历训练数据集
        for i, (features, labels) in enumerate(train_iter):
            timer.start()  # 开始计时
            # 训练一个小批量数据,并获取损失和准确度
            l, acc = train_batch_ch13(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())   # 更新指标累加器
            timer.stop()  # 停止计时
            # 每训练完五分之一的批次或者是最后一个批次时,更新动画器
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3], None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter) # 在测试数据集上评估模型准确度
        animator.add(epoch + 1, (None, None, test_acc))# 更新动画器
    # 打印最终的训练损失、训练准确度和测试准确度
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    # 打印每秒处理的样本数和使用的GPU设备信息
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')


def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):
    """
    参数:
    net: 神经网络模型。
    learning_rate: 学习率。
    batch_size: 每个小批量的大小,默认为128。
    num_epochs: 训练的轮数,默认为5。
    param_group: 是否对不同层使用不同的学习率,默认为True。
    """
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)  # 创建训练数据集的迭代器
    
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)  # 创建测试数据集的迭代器

    devices = d2l.try_all_gpus()  # 获取所有可用的GPU设备
    loss = nn.CrossEntropyLoss(reduction="none")   # 定义损失函数
    # 如果使用参数组
    if param_group:
        # 获取除最后全连接层外的所有参数
        # 列表params_1x,包含除最后一层全连接层外的所有参数。
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ["fc.weight", "fc.bias"]]
        # 定义优化器,分别为不同的参数组设置不同的学习率
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                  lr=learning_rate, weight_decay=0.001)
    else:
        # 如果不使用参数组,为所有参数设置相同的学习率
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    # 调用训练函数,开始训练
    train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
    
train_fine_tuning(finetune_net, 5e-5)
# loss 0.211, train acc 0.927, test acc 0.938
# 456.7 examples/sec on [device(type='cuda', index=0)]


"""
为了进行比较,我们定义了一个相同的模型,但是将其所有模型参数初始化为随机值。
由于整个模型需要从头开始训练,因此我们需要使用更大的学习率。
"""
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)
# loss 0.338, train acc 0.842, test acc 0.859
# 457.7 examples/sec on [device(type='cuda', index=0)]

plt.show() #显示图片 

预训练resnet18模型运行效果:

在这里插入图片描述

初始化resnet18模型运行效果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1853782.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

setInterval 定时任务执行时间不准验证

一般在处理定时任务的时候都使用setInterval间隔定时调用任务。 setInterval(() > {console.log("interval"); }, 2 * 1000);我们定义的是两秒执行一次,但是浏览器实际执行的间隔时间只多不少。这是由于浏览器执行 JS 是单线程模式,使用se…

二进制炸弹的fp是什么?

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

Go日常分享 - error类型是指针类型吗?

背景 这个问题的产生来源于小泉在开发rpc接口时返回error遇到的问题,开发时想在defer里对err进行最终的统一处理赋值,发现外层接收一直都未生效。问题可以简化为成下面的小demo。 func returnError() error {var err errordefer func() {//err errors…

PMBOK® 第六版 管理项目知识

目录 读后感—PMBOK第六版 目录 在前面的文章中,输入环节都可以看见有事业环境因素、组织过程资产;工具与技术都有专家判断。都是说明知识的重要性。 虽然项目具有其独特的、唯一性,但项目相关的经验却能如同家族传承般,被持续地…

【Python】已解决:安装python-Levenshtein包时遇到的subprocess-exited-with-error问题

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例及解决方案五、注意事项 已解决:安装python-Levenshtein包时遇到的subprocess-exited-with-error问题 一、分析问题背景 在安装python-Levenshtein这个Python包时,有时会…

基于Java的火车订票管理系统【附源码】

火车订票管理登录 摘要:随着我国铁路交通的不断发展,简单的窗口售票模式已经不能满足方便人们出行的目的。采用先进的网络技术开发出方便快捷的火车票订票系统是现代客运业务发展的必然需求。本次设计的火车票订票系统通过访问主页,可以实现…

196.每日一题:检测大写字母(力扣)

代码解决 class Solution { public:bool detectCapitalUse(string word) {int capitalCount 0;int n word.size();// 统计大写字母的数量for (char c : word) {if (isupper(c)) {capitalCount;}}// 检查是否满足三种情况之一if (capitalCount n) {// 全部字母都是大写return…

[最全]设计模式实战(一)UML六大原则

UML类图 UML类图是学习设计模式的基础,学习设计模式,主要关注六种关系。即:继承、实现、组合、聚合、依赖和关联。 UML类图基本用法 继承关系用空心三角形+实线来表示。实现接口用空心三角形+虚线来表示。eg:大雁是最能飞的,它实现了飞翔接口。 关联关系用实线箭头来表示…

Python武器库开发-武器库篇之ThinkPHP 5.0.23-RCE 漏洞复现(六十四)

Python武器库开发-武器库篇之ThinkPHP 5.0.23-RCE 漏洞复现(六十四) 漏洞环境搭建 这里我们使用Kali虚拟机安装docker并搭建vulhub靶场来进行ThinkPHP漏洞环境的安装,我们进入 ThinkPHP漏洞环境,可以 cd ThinkPHP,然…

c#使用自带库对字符串进行AES加密、解密

文章目录 1 代码1.1 定义Aes加密类块1.2 在主函数中调用 2 获取Key和IV2.1 基本方法2.2 自定义Key2.3 技术方面的原理 参考文章: C#软件加密实例? 。 参考官文: Aes 类。 在使用C#的自带的System.Security.Cryptography.Aes模块进行加密和解…

mediasoup 源码分析 (八)分析PlainTransport

mediasoup 源码分析 (六)分析PlainTransport 一、接收裸RTP流二、mediasoup 中udp建立过程 tips 一、接收裸RTP流 PlainTransport 可以接收裸RTP流,也可以接收AES加密的RTP流。源码中提供了一个通过ffmpeg发送裸RTP流到mediasoup的脚本&…

基于PyTorch设计的全景图合成系统【文末完整工程源码下载】

前言 本项目实现基于PyTorch将多张图片合成为一张全景图。(图像存储路径为/images/1)。 作者:阿齐Archie(联系我微信公众号:阿齐Archie) 使用的图片为: 合成后为: 这个全景图项目主…

eNSP启动设备失败,错误代码40,网卡配置正常,虚拟机导致的错误解决过程

安装eNSP后出现以下错误。 按照帮助文档,查看了相关软件,尤其是vitualbox的版本以及网卡问题。网卡设置正常,vitualbox也匹配成功。 附:vitualbox各个版本的下载地址: 关于网卡名称的修改方法,参照博客 …

python实现技术指标(简单移动平均,加权移动平均线,指数移动平均线)

移动平均线是最常见的技术指标,它能够去除时间序列的短期波动,使得数据变得平滑,从而可以方便看出序列的趋势特征。常见的移动平均线有简单移动平均线,加权移动平均线,指数移动平均线。 一. 简单移动平均(SMA) 简单移…

2.超声波测距模块

1.简介 2.超声波的时序图 3.基于51单片机实现的代码 #include "reg52.h" #include "intrins.h" sbit led1P3^7;//小于10,led1亮,led2灭 sbit led2P3^6;//否则,led1灭,led2亮 sbit trigP1^5; sbit echo…

电容的命名规则

给如下参数给采购,就可以获取 还有一些参数需要重视 容值随着环境温度而保持的程度 常规应用时是可以不用看材质,但是如果使用在新能源汽车和极端环境下的电子产品,就需要关注材质,曾有供应商把可用级电容供应车企,导致…

动手学深度学习(Pytorch版)代码实践 -计算机视觉-36图像增广

6 图片增广 import matplotlib.pyplot as plt import numpy as np import torch import torchvision from d2l import torch as d2l from torch import nn from PIL import Image import liliPytorch as lp from torch.utils.data import Dataset, DataLoaderplt.figure(cat)…

8.DELL R730服务器对RAID5进行扩容

如果服务器的空间不足了,如何进行扩容?我基本上按照如何重新配置虚拟磁盘或添加其他硬盘来进行操作。我的机器上已经有三块硬盘了,组了Raid5,现在再添加一块硬盘。 先把要添加的硬盘插入服务器,无论是在IDRAC还是管理…

基于S7-200PLC的全自动洗衣机控制系统设计

wx供重浩:创享日记 那边对话框发送:plc洗衣 获取完整无水印设计说明报告(含程序梯形图) 1.自动洗衣机PLC控制的控制要求 1.1全自动洗衣机的基本结构、工作流程和工作原理 1.自动洗衣机的基本结构 2.自动洗衣机的工作流程 自动洗…

RepVGG论文阅读笔记

目录 RepVGG: Making VGG-style ConvNets Great Again摘要INTRODUCTION—简介RepVGG BlockModel Re-parameterization -- 模型重参数化融合Conv2d和BN,将三个分支上的卷积算子和BN算子都转化为卷积算子(包括卷积核和偏置)多分支融合&#xff…