Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类

news2025/1/10 16:49:07

目录

1.ResNet残差网络

1.1 ResNet定义

 1.2 ResNet 几种网络配置

 1.3 ResNet50网络结构

1.3.1 前几层卷积和池化

1.3.2 残差块:构建深度残差网络

1.3.3 ResNet主体:堆叠多个残差块

1.4 迁移学习猫狗二分类实战

1.4.1 迁移学习

1.4.2 模型训练

1.4.3 模型预测


1.ResNet残差网络

1.1 ResNet定义

深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。

为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。

下图是一个基本残差块。它的操作是把某层输入跳跃连接到下一层乃至更深层的激活层之前,同本层输出一起经过激活函数输出。
 

24353e89d9c84a17babbbf4ebe90630b.png

 1.2 ResNet 几种网络配置

如下图:

 1.3 ResNet50网络结构

ResNet-50是一个具有50个卷积层的深度残差网络。它的网络结构非常复杂,但我们可以将其分为以下几个模块:

1.3.1 前几层卷积和池化

import torch
import torch.nn as nn

class ResNet50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet50, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

1.3.2 残差块:构建深度残差网络

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

1.3.3 ResNet主体:堆叠多个残差块

在ResNet-50中,我们堆叠了多个残差块来构建整个网络。每个残差块会将输入的特征图进行处理,并输出更加丰富的特征图。堆叠多个残差块允许网络在深度方向上进行信息的层层提取,从而获得更高级的语义信息。代码如下:

class ResNet50(nn.Module):
    def __init__(self, num_classes=1000):
        # ... 前几层代码 ...

        # 4个残差块的block1
        self.layer1 = self._make_layer(ResidualBlock, 64, 3, stride=1)
        # 4个残差块的block2
        self.layer2 = self._make_layer(ResidualBlock, 128, 4, stride=2)
        # 4个残差块的block3
        self.layer3 = self._make_layer(ResidualBlock, 256, 6, stride=2)
        # 4个残差块的block4
        self.layer4 = self._make_layer(ResidualBlock, 512, 3, stride=2)

 利用make_layer函数实现对基本残差块Bottleneck的堆叠。代码如下:

def _make_layer(self, block, channel, block_num, stride=1):
    """
        block: 堆叠的基本模块
        channel: 每个stage中堆叠模块的第一个卷积的卷积核个数,对resnet50分别是:64,128,256,512
        block_num: 当期stage堆叠block个数
        stride: 默认卷积步长
    """
        downsample = None   # 用于控制shorcut路的
        if stride != 1 or self.in_channel != channel*block.expansion:   # 对resnet50:conv2中特征图尺寸H,W不需要下采样/2,但是通道数x4,因此shortcut通道数也需要x4。对其余conv3,4,5,既要特征图尺寸H,W/2,又要shortcut维度x4
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=channel*block.expansion, kernel_size=1, stride=stride, bias=False), # out_channels决定输出通道数x4,stride决定特征图尺寸H,W/2
                nn.BatchNorm2d(num_features=channel*block.expansion))

        layers = []  # 每一个convi_x的结构保存在一个layers列表中,i={2,3,4,5}
        layers.append(block(in_channel=self.in_channel, out_channel=channel, downsample=downsample, stride=stride)) # 定义convi_x中的第一个残差块,只有第一个需要设置downsample和stride
        self.in_channel = channel*block.expansion   # 在下一次调用_make_layer函数的时候,self.in_channel已经x4

        for _ in range(1, block_num):  # 通过循环堆叠其余残差块(堆叠了剩余的block_num-1个)
            layers.append(block(in_channel=self.in_channel, out_channel=channel))

        return nn.Sequential(*layers)   # '*'的作用是将list转换为非关键字参数传入

1.4 迁移学习猫狗二分类实战

1.4.1 迁移学习

迁移学习(Transfer Learning)是一种机器学习和深度学习技术,它允许我们将一个任务学到的知识或特征迁移到另一个相关的任务中,从而加速模型的训练和提高性能。在迁移学习中,我们通常利用已经在大规模数据集上预训练好的模型(称为源任务模型),将其权重用于新的任务(称为目标任务),而不是从头开始训练一个全新的模型。

迁移学习的核心思想是:在解决一个新任务之前,我们可以先从已经学习过的任务中获取一些通用的特征或知识,并将这些特征或知识迁移到新任务中。这样做的好处在于,源任务模型已经在大规模数据集上进行了充分训练,学到了很多通用的特征,例如边缘检测、纹理等,这些特征对于许多任务都是有用的。

1.4.2 模型训练

首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。

在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。然后我们使用ResNet50网络模型,在我们的计算机上使用GPU进行训练并保存我们的模型,训练完成后在测试集上验证模型预测的正确率。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50

# 设置随机种子
torch.manual_seed(42)

# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# 加载预训练的ResNet-50模型
model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # 替换最后一层全连接层,以适应二分类问题

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/c.pth')
# 测试模型
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        print(outputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        break

    print(f"Accuracy on test images: {(correct / total) * 100}%")

1.4.3 模型预测

首先加载我们保存的模型,这里我们进行单张图片的预测,并把预测结果打印日志。

import cv2 as cv
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torchvision.transforms as transforms
import  torch
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=torch.load('model/c.pth')
print(model)
model.to(device)

test_image_path = 'test/dogs/dog.4001.jpg'  # Replace with your test image path
image = Image.open(test_image_path)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
input_tensor = transform(image).unsqueeze(0).to(device)  # Add a batch dimension and move to GPU

# Set the model to evaluation mode
model.eval()


with torch.no_grad():
    outputs = model(input_tensor)
    _, predicted = torch.max(outputs, 1)
    predicted_label = predicted.item()


label=['猫','狗']
print(label[predicted_label])
plt.axis('off')
plt.imshow(image)
plt.show()

运行截图

至此这篇文章到此结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/778745.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3基础+进阶(二、vue3常用组合式api基本使用)

目录 第二章、组合式API 2.1 入口:setup 2.1.1 setup选项的写法和执行时机 2.1.2 setup中写代码的特点 2.1.3 script setup语法糖 2.1.4 setup中this的指向 2.2 生成响应式数据:reactive和ref函数 2.2.1 reactive函数 2.2.2 ref函数 2.2.3 rea…

Cesium态势标绘专题-入口

本专题没有废话,只有代码,撸! 标绘主类MilitaryPlotting.ts /** 态势标绘主类* @Author: Wang jianLei* @Date: 2023-01-13 14:47:20* @Last Modified by: jianlei wang* @Last Modified time: 2023-05-31 09:55:34*/ import * as Creator from ./create/index; import Cre…

S32K324双核的核间通信使用示例

文章目录 前言修改ld文件核0的ld文件核1的ld文件 定义共享数据使用共享数据编译共享数据文件总结 前言 最近项目用S32K324开发,暂时只用了MCAL,没有Autosar上层的模块,最开始用官方给的demo工程双核可以正常跑起来,但实际开发时都…

使用nginx和ffmpeg搭建HTTP FLV流媒体服务器(摄像头RTSP视频流->RTMP->http-flv)

名词解释 RTSP (Real-Time Streaming Protocol) 是一种网络协议,用于控制实时流媒体的传输。它是一种应用层协议,通常用于在客户端和流媒体服务器之间建立和控制媒体流的传输。RTSP允许客户端向服务器发送请求,如…

数据分析工具与技术

数据分析工具与技术 数据分析技术 数据分析工具 备选方案分析 一种对已识别的可选方案进行评估的技术,用来决定选择哪种方案 或使用何种方法来执行项目工作。 其他风险参数评估 为了方便未来分析和行动,在对单个项目风险进行优先级排序时&#xff0…

GO内存模型(同步机制)

文章目录 概念1. 先行发生 编译器重排同步机制init函数协程的创建channelsync 包1. sync.mutex2. sync.rwmutex3. sync.once atomic 参考文献 概念 1. 先行发生 The happens before relation is defined as the transitive closure of the union of the sequenced before and …

超详细-Vivado配置Sublime+Sublime实现Verilog语法实时检查

目录 一、前言 二、准备工作 三、Vivado配置Sublime 3.1 Vivado配置Sublime 3.2 环境变量添加 3.3 环境变量验证 3.4 Vivado设置 3.5 配置验证 3.6 解决Vivado配置失败问题 四、Sublime配置 4.1 Sublime安装Package Control 4.2 Sublime安装Verilog插件 4.3 安装语…

centos7中MySQL备份还原策略

目录 一、直接拷贝数据库文件 1.1在shangke主机上停止服务并且打包压缩数据库文件 1.2 在shangke主机上把数据库文件传输到localhost主机上(ip为192.168.33.157) 1.3在localhost主机上停止服务,解压数据库文件 1.4 在localhost主机上开启服务 1.5 测试 二、m…

利用@Excel实现复杂表头导入

EasyPoi导入 <a-upload name"file" :showUploadList"false" :multiple"false" :headers"tokenHeader" :action"importExcelUrl"change"handleImportExcel"><a-button type"primary" icon&quo…

【软件测试】如何选择回归用例

目录 如何在原始用例集中挑选测试用例 具体实践 总结 本文讨论一下在回归测试活动中&#xff0c;如何选择测试用例集。 回归测试用例集包括基本测试用例集&#xff08;原始用例&#xff09;迭代新增测试用例集&#xff08;修复故障引入的用例和新增功能引入的用例集&#xf…

洛必达法则和分部积分的应用之计算数学期望EX--概率论浙大版填坑记

如下图所示&#xff0c;概率论与数理统计浙大第四版有如下例题&#xff1a; 简单说就是&#xff1a;已知两个相互独立工作电子装置寿命的概率密度函数&#xff0c;将二者串联成整机&#xff0c;求整机寿命的数学期望。 这个题目解答中的微积分部分可谓是相当的坑爹&#xff0c;…

【1++的C++初阶】之适配器

&#x1f44d;作者主页&#xff1a;进击的1 &#x1f929; 专栏链接&#xff1a;【1的C初阶】 文章目录 一&#xff0c;什么是适配器二&#xff0c;栈与队列模拟实现三&#xff0c;优先级队列四&#xff0c;reverse_iterator 一&#xff0c;什么是适配器 适配器作为STL的六大组…

【高阶数据结构】跳表

文章目录 一、什么是跳表二、跳表的效率如何保证&#xff1f;三、skiplist的实现四、skiplist跟平衡搜索树和哈希表的对比 一、什么是跳表 skiplist本质上也是一种查找结构&#xff0c;用于解决算法中的查找问题&#xff0c;跟平衡搜索树和哈希表的价值是 一样的&#xff0c;可…

Windows环境Docker安装

目录 安装Docker Desktop的步骤 Docker Desktop 更新WSL WSL 的手动安装步骤 Windows PowerShell 拉取&#xff08;Pull&#xff09;镜像 查看已下载的镜像 输出"Hello Docker!" Docker Desktop是Docker官方提供的用于Windows的图形化桌面应用程序&#xff0c…

区间预测 | MATLAB实现QRBiLSTM双向长短期记忆神经网络分位数回归多输入单输出区间预测

区间预测 | MATLAB实现QRBiLSTM双向长短期记忆神经网络分位数回归多输入单输出区间预测 目录 区间预测 | MATLAB实现QRBiLSTM双向长短期记忆神经网络分位数回归多输入单输出区间预测效果一览基本介绍模型描述程序设计参考资料 效果一览 基本介绍 区间预测 | MATLAB实现QRBiLSTM…

odoo16 用好计量单位中的激活功能

odoo16 用好计量单位中的激活功能 根据国内常用&#xff0c;把不常用的单位去除&#xff0c;删除不了&#xff0c;提示已用&#xff0c;其实不用删除&#xff0c;每个单位后有个激活功能&#xff0c;选一下就可以了&#xff0c;显示成整洁的界面了 第一次用时&#xff0c;小伙伴…

解决spring cloud 中使用spring security全局异常处理器失效

写auth认证模块实现忘记密码与注册功能时&#xff0c;用异常抛出&#xff0c;全局异常处理器无法捕获。 无法进行异常捕捉 解决方案&#xff1a;使用WebSecurityConfigurerAdapter.configure中http实现自定义异常&#xff1a; EnableWebSecurity EnableGlobalMethodSecurity(…

87、springcloud核心组件及其作用

spring Eureka: 服务注册与发现 注册:&#xff1a;每个服务都向Eureka登记自己提供服务的元数据&#xff0c;包括服务的ip地址、端口号、版本号、通信协议等 eureka将各个服务维护在了一个服务清单中 (双层Map&#xff0c;第一层key是服务名&#xff0c;第二层key是实例名&…

macOS 源码编译 qpress

╰─➤ git clone https://github.com/PierreLvx/qpress.git ╰─➤ cd qpress ╰─➤ make g -O3 -o qpress -x c quicklz.c -x c qpress.cpp aio.cpp utilities.cpp -lpthread -Wall -Wextra -Werror ╰─➤ sudo make install …

Vue--》打造个性化医疗服务的医院预约系统(三)

今天开始使用 vue3 + ts 搭建一个医院预约系统的前台页面,因为文章会将项目的每一个地方代码的书写都会讲解到,所以本项目会分成好几篇文章进行讲解,我会在最后一篇文章中会将项目代码开源到我的GithHub上,大家可以自行去进行下载运行,希望本文章对有帮助的朋友们能多多关…