yolov8实战第四天——yolov8图像分类 ResNet50图像分类(保姆式教程)

news2025/1/20 22:02:38

yolov8实战第一天——yolov8部署并训练自己的数据集(保姆式教程)_yolov8训练自己的数据集-CSDN博客在前几天,我们使用yolov8进行了部署,并在目标检测方向上进行自己数据集的训练与测试,今天我们训练下yolov8的图像分类,看看效果如何,同时使用resnet50也训练一个分类模型,看看哪个效果好!

图像分类是指将输入的图像自动分类为不同的类别。它是计算机视觉领域的一个重要应用,可以用于人脸识别、物体识别、场景分类等任务。

通常情况下,图像分类的流程如下:

  1. 收集和准备数据集:收集与任务相关的图像数据,并将其打上标签。
  2. 定义模型:选择一种适合于你的任务的深度学习模型,例如卷积神经网络(CNN)。
  3. 训练模型:使用收集到的数据集对模型进行训练,通过反向传播算法来更新模型参数,使其可以根据输入图像进行正确的分类。
  4. 评估模型性能:使用测试集对已经训练好的模型进行评估,比较模型预测结果与真实标签之间的差异,从而评估模型的性能。
  5. 使用模型进行预测:使用已经训练好的模型对新的图像进行分类预测。

在实际应用中,可以使用各种深度学习框架(例如 TensorFlow、PyTorch、Keras 等)来构建图像分类模型,并使用各种数据增强技术(例如旋转、缩放、裁剪等)来增加数据集的多样性和数量。

如果你想学习如何使用深度学习框架来构建图像分类模型,可以参考一些在线教程、书籍或者 MOOC。

一、yolov8图像分类

1.模型选型

下载yolov8分类模型。

分别使用模型进行测试:

yolov8n-cls效果:

yolov8m-cls效果:

总结:n效果不咋地,还是得使用m进行后续训练工作。 

2.数据集准备

皮肤癌检测_数据集-飞桨AI Studio星河社区

同目标检测,还是放在datasets下。

直接改成这个,省去分数据集操作。 

 3.训练

yolo classify train data=./datasets/skin-cancer-detection model=yolov8n-cls.pt epochs=100

测试:

yolo classify predict model=runs/classify/train4/weights/best.pt source='./datasets/skin-cancer-detection/train/nevus'

  

label: 

 pred:

总结:数据集比较小,yolov8效果不太好。

、resnet50图像分类

Resnet50 网络中包含了 49 个卷积层、一个全连接层。如图下图所示,Resnet50网络结构可以分成七个部分,第一部分不包含残差块,主要对输入进行卷积、正则化、激活函数、最大池化的计算。第二、三、四、五部分结构都包含了残差块,图 中的绿色图块不会改变残差块的尺寸,只用于改变残差块的维度。在 Resnet50 网 络 结 构 中 , 残 差 块 都 有 三 层 卷 积 , 那 网 络 总 共 有1+3×(3+4+6+3)=49个卷积层,加上最后的全连接层总共是 50 层,这也是Resnet50 名称的由来。网络的输入为 224×224×3,经过前五部分的卷积计算,输出为 7×7×2048,池化层会将其转化成一个特征向量,最后分类器会对这个特征向量进行计算并输出类别概率。

运行train.py即可。

train.py

import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import time
 
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

 
# 一、建立数据集
# animals-6
#   --train
#       |--dog
#       |--cat
#       ...
#   --valid
#       |--dog
#       |--cat
#       ...
#   --test
#       |--dog
#       |--cat
#       ...
# 我的数据集中 train 中每个类别60张图片,valid 中每个类别 10 张图片,test 中每个类别几张到几十张不等,一共 6 个类别。
 
# 二、数据增强
# 建好的数据集在输入网络之前先进行数据增强,包括随机 resize 裁剪到 256 x 256,随机旋转,随机水平翻转,中心裁剪到 224 x 224,转化成 Tensor,正规化等。
image_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}
 
# 三、加载数据
# torchvision.transforms包DataLoader是 Pytorch 重要的特性,它们使得数据增加和加载数据变得非常简单。
# 使用 DataLoader 加载数据的时候就会将之前定义的数据 transform 就会应用的数据上了。
dataset = 'skin-cancer-detection'
train_directory = './skin-cancer-detection/train'
valid_directory = './skin-cancer-detection/val'
 
batch_size = 32
num_classes = 9 #分类种类数
print(train_directory)
data = {
    'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
}
print("训练集图片类别及其对应编号(种类名:编号):",data['train'].class_to_idx)
print("测试集图片类别及其对应编号:",data['valid'].class_to_idx)
 
 
train_data_size = len(data['train'])
valid_data_size = len(data['valid'])
 
train_data = DataLoader(data['train'], batch_size=batch_size, shuffle=True, num_workers=0)
valid_data = DataLoader(data['valid'], batch_size=batch_size, shuffle=True, num_workers=0)
 
print("训练集图片数量:",train_data_size, "测试集图片数量:",valid_data_size)
 
# 四、迁移学习
# 这里使用ResNet-50的预训练模型。
#resnet50 = models.resnet50(pretrained=True)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

 
 
# 在PyTorch中加载模型时,所有参数的‘requires_grad’字段默认设置为true。这意味着对参数值的每一次更改都将被存储,以便在用于训练的反向传播图中使用。
# 这增加了内存需求。由于预训练的模型中的大多数参数已经训练好了,因此将requires_grad字段重置为false。
for param in resnet50.parameters():
    param.requires_grad = False
 
# 为了适应自己的数据集,将ResNet-50的最后一层替换为,将原来最后一个全连接层的输入喂给一个有256个输出单元的线性层,接着再连接ReLU层和Dropout层,然后是256 x 6的线性层,输出为6通道的softmax层。
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes),
    nn.LogSoftmax(dim=1)
)
 
# 用GPU进行训练。
resnet50 = resnet50.to('cuda:0')
 
# 定义损失函数和优化器。
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())
 
# 五、训练
def train_and_valid(model, loss_function, optimizer, epochs=25):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    history = []
    best_acc = 0.0
    best_epoch = 0
 
    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))
 
        model.train()
 
        train_loss = 0.0
        train_acc = 0.0
        valid_loss = 0.0
        valid_acc = 0.0
 
        for i, (inputs, labels) in enumerate(tqdm(train_data)):
            inputs = inputs.to(device)
            labels = labels.to(device)
 
            #因为这里梯度是累加的,所以每次记得清零
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            print("标签值:",labels)
            print("输出值:",outputs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            acc = torch.mean(correct_counts.type(torch.FloatTensor))
            train_acc += acc.item() * inputs.size(0)
 
        with torch.no_grad():
            model.eval()
 
            for j, (inputs, labels) in enumerate(tqdm(valid_data)):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                valid_loss += loss.item() * inputs.size(0)
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))
                acc = torch.mean(correct_counts.type(torch.FloatTensor))
                valid_acc += acc.item() * inputs.size(0)
 
        avg_train_loss = train_loss/train_data_size
        avg_train_acc = train_acc/train_data_size
 
        avg_valid_loss = valid_loss/valid_data_size
        avg_valid_acc = valid_acc/valid_data_size
 
        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])
 
        if best_acc < avg_valid_acc:
            best_acc = avg_valid_acc
            best_epoch = epoch + 1
 
        epoch_end = time.time()
 
        print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(
            epoch+1, avg_valid_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start
        ))
        print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))
 
        torch.save(model, 'models/'+dataset+'_model_'+str(epoch+1)+'.pt')
    return model, history
 
num_epochs = 100 #训练周期数
trained_model, history = train_and_valid(resnet50, loss_func, optimizer, num_epochs)
torch.save(history, 'models/'+dataset+'_history.pt')
 
history = np.array(history)
plt.plot(history[:, 0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0, 1)
plt.savefig(dataset+'_loss_curve.png')
plt.show()
 
plt.plot(history[:, 2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.savefig(dataset+'_accuracy_curve.png')
plt.show()

测试:图片名改下即可。

import torch
from torchvision import  models, transforms
import torch.nn as nn
import cv2
classes = ["1","2","3","4","5","6","7","8","9"] #识别种类名称(顺序要与训练时的数据导入编号顺序对应,可以使用datasets.ImageFolder().class_to_idx来查看)
 
transf = transforms.ToTensor()
device = torch.device('cuda:0')
num_classes = 2
model_path = "models/skin-cancer-detection_model_3.pt"
image_input = cv2.imread("ISIC_0000019.jpg")
image_input = transf(image_input)
image_input = torch.unsqueeze(image_input,dim=0).cuda()
#搭建模型
resnet50 = models.resnet50(pretrained=True)
for param in resnet50.parameters():
    param.requires_grad = False
 
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes),
    nn.LogSoftmax(dim=1)
)
resnet50 = torch.load(model_path)
 
outputs = resnet50(image_input)
value,id =torch.max(outputs,1)
print(outputs,"\n","结果是:",classes[id])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1344608.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PHP序列化总结1--序列化和反序列化的基础知识

序列化和反序列化的作用 1.序列化&#xff1a;将对象转化成数组或者字符串的形式 2.反序列化&#xff1a;将数组或字符串的形式转化为对象 为什么要进行序列化 这种数据形式中间会有很多空格&#xff0c;不同人有不同的书写情况&#xff0c;可能还会出现换行的情况 为此为了…

C# Image Caption

目录 介绍 效果 模型 decoder_fc_nsc.onnx encoder.onnx 项目 代码 下载 C# Image Caption 介绍 地址&#xff1a;https://github.com/ruotianluo/ImageCaptioning.pytorch I decide to sync up this repo and self-critical.pytorch. (The old master is in old ma…

07-项目打包 React Hooks

项目打包 项目打包是为了把整个项目都打包成最纯粹的js&#xff0c;让浏览器可以直接执行 打包命令已经在package.json里面定义好了 运行命令&#xff1a;npm run build&#xff0c;执行时间取决于第三方插件的数量以及电脑配置 打包完之后再build文件夹下&#xff0c;这个…

Self-attention学习笔记(Self Attention、multi-head self attention)

李宏毅机器学习Transformer Self Attention学习笔记记录一下几个方面的内容 1、Self Attention解决了什么问题2、Self Attention 的实现方法以及网络结构Multi-head Self Attentionpositional encoding 3、Self Attention 方法的应用4、Self Attention 与CNN以及RNN对比 1、Se…

STM32移植LVGL图形库

1、问题1&#xff1a;中文字符keil编译错误 解决方法&#xff1a;在KEIL中Options for Target Flash -> C/C -> Misc Controls添加“--localeenglish”。 问题2&#xff1a;LVGL中显示中文字符 使用 LVGL 官方的在线字体转换工具&#xff1a; Online font converter -…

Python面向对象编程 —— 类和异常处理

​ &#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 &#x1f4ab;个人格言:"没有罗马,那就自己创造罗马~" 目录 1. 类 1.1 类的定义 1.2 类变量和实例变量 1.3 类的继承 2. 异常处理 2.1类型异常 2.…

【集合竞价】

文章目录 集合竞价时间集合竞价量柱的关系a、亮度顶部是未匹配的成交量b、量柱底部表示虚拟撮合的成交量 集合竞价常用的几种形态1、低开集合竞价2.平开集合竞价3、高开集合竞价4、看集合竞价的关键时间点&#xff1a;5、需要注意的几点 庄家使用集合竞价进行的操作1、利用集合…

【EI会议征稿通知】2024年通信技术与软件工程国际学术会议 (CTSE 2024)

2024年通信技术与软件工程国际学术会议 (CTSE 2024) 2024 International Conference on Communication Technology and Software Engineering (CTSE 2024) 2024年通信技术与软件工程国际学术会议 (CTSE 2024)将于2024年03月15-17日在中国长沙举行。会议专注于通信技术与软件工…

3D视觉-结构光测量-线结构光测量

概述 线结构光测量中&#xff0c;由激光器射出的激光光束透过柱面透镜扩束&#xff0c;再经过准直&#xff0c;产生一束片状光。这片光束像刀刃一样横切在待测物体表面&#xff0c;因此线结构光法又被成为光切法。线结构光测量常采用二维面阵 CCD 作为接受器件&#xff0c;因此…

【代码随想录】刷题笔记Day42

前言 这两天机器狗终于搞定了&#xff0c;一个控制ROS大佬&#xff0c;一个计院编程大佬&#xff0c;竟然真把创新点这个弄出来了&#xff0c;牛牛牛牛&#xff08;菜鸡我只能负责在旁边喊加油&#xff09;。下午翘了自辩课来刷题&#xff0c;这次应该是元旦前最后一刷了&…

自然语言处理3——玩转文本分类 - Python NLP高级应用

目录 写在开头1. 文本分类的背后原理和应用场景1.1 文本分类的原理1.2 文本分类的应用场景 2. 使用机器学习模型进行文本分类&#xff08;朴素贝叶斯、支持向量机等&#xff09;2.1 朴素贝叶斯2.1.1 基本原理2.1.2 数学公式2.1.3 一般步骤2.1.4 简单python代码实现 2.2 支持向量…

JSON 详解

文章目录 JSON 的由来JSON 的基本语法JSON 的序列化简单使用stringify 方法之 replacerstringify 方法之 replacer 参数传入回调函数stringify 方法之 spacestringify 方法之 toJSONparse 方法之 reviver 利用 stringify 和 parse 实现深拷贝 json 相信大家一定耳熟能详&#x…

2023-12-30 AIGC-LangChain介绍

摘要: 2023-12-30 AIGC-LangChain介绍 LangChain介绍 1. https://youtu.be/Ix9WIZpArm0?t353 2. https://www.freecodecamp.org/news/langchain-how-to-create-custom-knowledge-chatbots/ 3. https://www.pinecone.io/learn/langchain-conversational-memory/ 4. https://de…

AJAX: 整理2:学习原生的AJAX,这边借助express框架

1. npm install express 终端直接安装 2. 测试案例&#xff1a;Hello World&#xff01; 新建一个express.js的文件&#xff0c;写入下方的内容 // 1. 引入express const express require(express)// 2. 创建服务器 const app express()// 3.创建路由规则 // request 是对请…

迪杰斯特拉(Dijkstra)算法详解

【专栏】数据结构复习之路 这篇文章来自上述专栏中的一篇文章的节选&#xff1a; 【数据结构复习之路】图&#xff08;严蔚敏版&#xff09;两万余字&超详细讲解 想了解更多图论的知识&#xff0c;可以去看看本专栏 Dijkstra 算法讲解&#xff1a; 迪杰斯特拉算法(Di…

如何理解李克特量表?选项距离相等+题目权重相等!

在学术研究中&#xff0c;通过开展问卷调查获取数据时&#xff0c;调查问卷分为量表题和非量表题。量表题就是测试受访者的态度或者看法的题目&#xff0c;大多采用李克特量表。 李克特量表是一种评分加总式态度量表&#xff08;attitude scale&#xff09;&#xff0c;由美国…

TCP/IP的五层网络模型

目录 封装&#xff08;打包快递&#xff09; 6.1应用层 6.2传输层 6.3网络层 6.4数据链路层 6.5物理层 分用&#xff08;拆快递&#xff09; 6.5物理层 6.4数据链路层 6.3网络层 6.2传输层 6.1应用层 封装&#xff08;打包快递&#xff09; 6.1应用层 此时做的数据…

【数字图像处理】比特平面分割,对比空间平滑滤波器的尺寸对滤波效果,对比均值滤波器和中值滤波器对图像的平滑效果

实验目的 对比均值滤波器和中值滤波器对图像的平滑效果&#xff1b;编程对比空间平滑滤波器的尺寸对滤波效果的影响&#xff1b;编程对比均值滤波器和中值滤波器对图像的平滑效果。 实验方法 比特平面分割 比特平面分层&#xff0c;代替突出灰度级范围&#xff0c;突出特定…

Unity坦克大战开发全流程——开始场景——设置界面

开始场景——设置界面 step1&#xff1a;设置面板的背景图 照着这个来设置就行了 step2&#xff1a;写代码 关联的按钮控件 监听事件函数 注意&#xff1a;要在start函数中再写一行HideMe函数&#xff0c;以便该面板能在一开始就能隐藏自己。 再在BeginPanel脚本中调用该函数即…

基于大语言模型LangChain框架:知识库问答系统实践

ChatGPT 所取得的巨大成功&#xff0c;使得越来越多的开发者希望利用 OpenAI 提供的 API 或私有化模型开发基于大语言模型的应用程序。然而&#xff0c;即使大语言模型的调用相对简单&#xff0c;仍需要完成大量的定制开发工作&#xff0c;包括 API 集成、交互逻辑、数据存储等…