深度学习基于Resnet18的图像多分类--训练自己的数据集(超详细 含源码)

news2025/1/25 9:00:49

1.ResNet18原理

2.文件存储 一个样本存放的文件夹为dataset 下两个文件夹 train和test文件(训练和预测)

3.训练和测试的文件要相同。下面都分别放了 crane (鹤)、elephant(大象)、leopard(豹子)

4.编写预测的Python文件:code.py 跟dataset是同级路径。

5.code.py 训练模型源代码

6. 测试代码使用豹子的图像进行分类,复制自己的绝对路径。

7.预测结果达到了99.95%

1.ResNet18原理

        ResNet18是一个经典的深度卷积神经网络模型,由微软亚洲研究院提出,用于参加2015年的ImageNet图像分类比赛。ResNet18的名称来源于网络中包含的18个卷积层。

ResNet18的基本结构如下:

    输入层:接收大小为224x224的RGB图像。
    卷积层:共4个卷积层,每个卷积层使用3x3的卷积核和ReLU激活函数,提取图像的局部特征。
    残差块:共8个残差块,每个残差块由两个卷积层和一条跳跃连接构成,用于解决深度卷积神经网络中梯度消失和梯度爆炸问题。
    全局平均池化层:对特征图进行全局平均池化,将特征图转化为一维

    向量。
    全连接层:包含一个大小为1000的全连接层,用于分类输出。
    输出层:使用softmax激活函数,生成1000个类别的概率分布。

2.文件存储 一个样本存放的文件夹为dataset 下两个文件夹 train和test文件(训练和预测)

3.训练和测试的文件要相同。下面都分别放了 crane (鹤)、elephant(大象)、leopard(豹子)

4.编写预测的Python文件:code.py 跟dataset是同级路径。

5.code.py 训练模型源代码

import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
n_classes = 3  # 几种分类的
preteain = False  # 是否下载使用训练参数 有网true 没网false
epoches = 4  # 训练的轮次
traindataset = datasets.ImageFolder(root='./dataset/train/', transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

]))

testdataset = datasets.ImageFolder(root='./dataset/test/', transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

]))

classes = testdataset.classes
print(classes)

model = models.resnet18(pretrained=preteain)
if preteain == True:
    for param in model.parameters():
        param.requires_grad = False
model.fc = nn.Linear(in_features=512, out_features=n_classes, bias=True)
model = model.to(device)


def train_model(model, train_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = outputs.argmax(dim=1)
        total_corrects += torch.sum(preds.eq(labels))
        total_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
    total_loss = total_loss / total
    acc = 100 * total_corrects / total
    print("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, total_loss, acc))
    return total_loss, acc


def test_model(model, test_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            total_loss += loss.item() * inputs.size(0)
            total_corrects += torch.sum(preds.eq(labels))

        loss = total_loss / total
        accuracy = 100 * total_corrects / total
        print("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, loss, accuracy))
        return loss, accuracy


loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(traindataset, batch_size=32, shuffle=True)
test_loader = DataLoader(testdataset, batch_size=32, shuffle=True)
for epoch in range(0, epoches):
    loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)
    loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)

classes = testdataset.classes
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

path = r'C:\Users\Administrator\Desktop\动物\dataset\test\leopard\img_test_850.jpg'  # 测试图片路径
model.eval()
img = Image.open(path)
img_p = transform(img).unsqueeze(0).to(device)
output = model(img_p)
pred = output.argmax(dim=1).item()
plt.imshow(img)
plt.show()
p = 100 * nn.Softmax(dim=1)(output).detach().cpu().numpy()[0]
print('该图像预测类别为:', classes[pred])

# 三分类
print('类别{}的概率为{:.2f}%,类别{}的概率为{:.2f}%,类别{}的概率为{:.2f}%'.format(classes[0], p[0], classes[1], p[1], classes[2], p[2]))

 6. 测试代码使用豹子的图像进行分类,复制自己的绝对路径。

7.预测结果达到了99.95%

如果自己的分类不止三种那么需要修改。我这里训练的是三种图像。根据自己实际情况填写。

还需修改,如果是四分类的话 多加一个 “类别{}的概率为{:.2f}%” 和 classes[3], p[3] ,因为索引是从0开始的 所以四分类的下标就为3。

# 三分类
print('类别{}的概率为{:.2f}%,类别{}的概率为{:.2f}%,类别{}的概率为{:.2f}%'.format(classes[0], p[0], classes[1], p[1], classes[2], p[2]))

制作不易 希望大家多多支持!

有问题可以联系我。添加备注 CSDN图像分类谢谢。QQ

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/673103.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ROS从入门到精通2-7:Gazebo仿真之动态生成障碍物

目录 0 专栏介绍1 动态生成障碍应用场景2 基于Gazebo动态生成障碍2.1 spawn_model服务2.2 动态构造障碍物URDF2.3 请求服务与动态生成 3 实测演示 0 专栏介绍 本专栏旨在通过对ROS的系统学习,掌握ROS底层基本分布式原理,并具有机器人建模和应用ROS进行实…

CSS | 解决html中img标签图片底部存在空白缝隙的问题

目录 问题描述 原因分析 解决方案 写在最后 问题描述 在学习CSS的过程中&#xff0c;我们经常会遇到图片底侧存在空白缝隙的问题。 代码示例&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" />&l…

SpringCloudAlibaba之Sentinel源码分析--protoc-3.17.3-win64

Sentinel源码分析 文章目录 Sentinel源码分析1.Sentinel的基本概念1.1.ProcessorSlotChain1.2.Node1.3.Entry1.3.1.自定义资源1.3.2.基于注解标记资源 1.4.Context1.4.1.什么是Context1.4.2.Context的初始化1.4.2.1.自动装配1.4.2.2.AbstractSentinelInterceptor1.4.2.3.Contex…

【C++初阶】string类常见题目详解(一)—— 仅仅反转字母、字符串中的第一个唯一字母、字符串最后一个单词的长度、验证回文串、字符串相加

​ ​&#x1f4dd;个人主页&#xff1a;Sherry的成长之路 &#x1f3e0;学习社区&#xff1a;Sherry的成长之路&#xff08;个人社区&#xff09; &#x1f4d6;专栏链接&#xff1a;C初阶 &#x1f3af;长路漫漫浩浩&#xff0c;万事皆有期待 上一篇博客&#xff1a;【C初阶】…

【Python 基础篇】Python 集合及集合常用函数

文章目录 导言一、集合的创建和访问二、集合的常用函数len()add()remove()union()intersection()difference()issubset()issuperset()clear() 总结 导言 在Python中&#xff0c;集合&#xff08;Set&#xff09;是一种无序、不重复的数据类型&#xff0c;用于存储多个唯一的元…

HCIP网络笔记分享——广域网协议及BGP协议

第二部分 HCIA回顾一、广域网技术1、HDLC2、PPP3、PAP4、CHAP5、GRE6、运行路由协议 二、动态路由协议1、OSPF2、重发布3、路由策略3.1 抓流量3.2 具体过程 4、BGP 三、BGP边界网关协议1、BGP的数据包2、BGP的状态机3、BGP的工作过程4、BGP的路由黑洞问题5、BGP的防环问题6、BG…

Studio One6.1.1免费中文版电子音乐、摇滚乐制作软件

Studio One6是一款专业的音乐制作软件&#xff0c;该软件提供了全面的音频编辑和混音功能&#xff0c;包括录制、编曲、合成、采样等多种工具&#xff0c;可用于制作各种类型的音乐&#xff0c;如流行音乐、电子音乐、摇滚乐等。 Studio One6.1的主要特点包括&#xff1a; 1. …

深入理解什么是端口(port)

每当看到有人的简历上写着熟悉 tcp/ip, http 等协议时, 我就忍不住问问他们: 你给我说说, 端口是啥吧! 可惜, 很少有人能说得让人满意... 所以这次就来谈谈端口(port), 这个熟悉的陌生人. 在此过程中, 还会谈谈间接层, naming service 等概念, IoC, 依赖倒置等原则以及 TCP 协议…

JavaEE的学习(Spring +Spring MVC + MyBatis)

一、Spring入门 Spring是一个轻量级的控制反转 (IoC-Inversion of Control)和面向切面 (AOP-Aspect Oriented Programming)的容器&#xff08;框架&#xff09;。它采用分层架构&#xff0c;由大约20个模块组成&#xff0c;这些模块分为Core Container、Data Access/Integrati…

什么是计算机蠕虫?

计算机蠕虫诞生的背景 计算机蠕虫的诞生与计算机网络的发展密切相关。20世纪60年代末和70年代初&#xff0c;互联网还处于早期阶段&#xff0c;存在着相对较少的计算机和网络连接。然而&#xff0c;随着计算机技术的进步和互联网的普及&#xff0c;计算机网络得以迅速扩张&…

TC8:SOMEIPSRV_FORMAT_09-10

SOMEIPSRV_FORMAT_09: Undefined bits in the Flag field 目的 Flag字段中的未定义位应静态设置为0 测试步骤 DUT CONFIGURE:启动具有下列信息的服务Service ID:SERVICE-ID-1Instance数量:1Tester:客户端-1监听在网卡上DUT:发送SOME/IP Notification消息Tester:验证接收…

Flutter应用开发,系统样式改不了?SystemChrome 状态栏、导航栏、屏幕方向……想改就改

文章目录 开发场景SystemChrome 介绍SystemChrome的使用导入 SystemChrome 包隐藏状态栏说明 改变状态栏的样式注意事项其他样式说明 锁定屏幕方向锁定屏幕方向实例注意事项 开发场景 开发APP时&#xff0c;我们经常要客制化状态栏、导航栏栏等的样式和风格&#xff0c;Flutte…

网络之网络基础入门

文章目录 前言一、局域网和广域网1.局域网LAN2.广域网WAN3.城域网和校园网4.如何区分广域网和局域网 二、协议1.概念2.理解3.协议分层4.数据传输的条件 三、OSI七层模型&#xff08;了解即可&#xff09;1.概念2.OSI七层模型 四、TCP/IP五层&#xff08;四层&#xff09;模型1.…

TC8:TCP_BASICS_11-17

TCP_BASICS_11: [finwait-2 -> time_wait] delay(2*MSL) -> [closed] 目的 TCP从FINWAIT-2状态到TIME-WAIT状态后,等待2MSL时间后,移动到CLOSED状态 关于为什么要等待2MSL时间,我的文章中讲过太多次了,这里就不提了 测试步骤 Tester:让DUT移动到FINWAIT-2状态Test…

使用Python批量进行数据分析

案例01 批量升序排序一个工作簿中的所有工作表——产品销售统计表.xlsx import xlwings as xw import pandas as pd app xw.App(visible False, add_book False) workbook app.books.open(产品销售统计表.xlsx) worksheet workbook.sheets # 列出工作簿中的所有工作表 fo…

SpringBoot 如何使用 ApplicationEventPublisher 发布事件

SpringBoot 如何使用 ApplicationEventPublisher 发布事件 在 SpringBoot 应用程序中&#xff0c;我们可以使用 ApplicationEventPublisher 接口来发布事件。事件可以是任何对象&#xff0c;当该对象被发布时&#xff0c;所有监听该事件的监听器都会收到通知。 下面是一个简单…

[Leetcode] 0733. 图像渲染

733. 图像渲染 点击上方&#xff0c;跳转至leetcode 题目描述 有一幅以 m x n 的二维整数数组表示的图画 image &#xff0c;其中 image[i][j] 表示该图画的像素值大小。 你也被给予三个整数 sr , sc 和 newColor 。你应该从像素 image[sr][sc] 开始对图像进行 上色填充 。 为…

第八章 MobileNetv3网络详解

系列文章目录 第一章 AlexNet网络详解 第二章 VGG网络详解 第三章 GoogLeNet网络详解 第四章 ResNet网络详解 第五章 ResNeXt网络详解 第六章 MobileNetv1网络详解 第七章 MobileNetv2网络详解 第八章 MobileNetv3网络详解 第九章 ShuffleNetv1网络详解 第十章…

1.RocketMQ的安装与集群架构

RocketMQ快速入门 RocketMQ是阿里巴巴2016年MQ中间件&#xff0c;使用Java语言开发&#xff0c;在阿里内部&#xff0c;RocketMQ承接了例如“双11”等高并发场景的消息流转&#xff0c;能够处理万亿级别的消息。 2.1 准备工作 2.1.1 下载RocketMQ RocketMQ最新版本&#xff1a;…

Redis缓存与数据库如何保证一致性?同步删除+延时双删+异步监听+多重保障方案

导航&#xff1a; 【Java笔记踩坑汇总】Java基础进阶JavaWebSSMSpringBoot瑞吉外卖SpringCloud黑马旅游谷粒商城学成在线MySQL高级篇设计模式常见面试题源码 目录 一、四种基础同步策略 1.1 同步策略 1.2 更新缓存还是删除缓存&#xff1f; 1.2.1 更新缓存的优缺点 1.2.2 …