使用Resnet进行图像分类训练

news2024/11/24 5:29:38

本文仅给出最基础的baseline进行图像分类训练,后续可在此代码基础上对模型结构进行修改。

一、图像分类数据集

现有一份图像类别数据集,类别为Y和N,数据目录如下:

/datasets/data/
|-- train/
|   |-- Y/
|   |-- N/

划分训练集和测试集
split_dataset.py

import os
import random
import shutil
import argparse

# 创建命令行参数解析器
parser = argparse.ArgumentParser(description='移动验证集样本到验证集文件夹')
parser.add_argument('--name', required=True, help='数据集名称')
parser.add_argument('--val_ratio', type=float, default=0.2, help='验证集比例')
args = parser.parse_args()

# 数据集路径
dataset_name = args.name
dataset_path = f'/datasets/{dataset_name}'
train_path = os.path.join(dataset_path, 'train')
val_path = os.path.join(dataset_path, 'val')

# 创建验证集文件夹
os.makedirs(val_path, exist_ok=True)
os.makedirs(os.path.join(val_path, 'Y'), exist_ok=True)
os.makedirs(os.path.join(val_path, 'N'), exist_ok=True)

# 计算验证集的数量
val_ratio = args.val_ratio  # 验证集比例
val_size_Y = int(len(os.listdir(os.path.join(train_path, 'Y'))) * val_ratio)
val_size_N = int(len(os.listdir(os.path.join(train_path, 'N'))) * val_ratio)

# 随机选择验证集样本
random.seed(42)
val_samples_Y = random.sample(os.listdir(os.path.join(train_path, 'Y')), val_size_Y)
val_samples_N = random.sample(os.listdir(os.path.join(train_path, 'N')), val_size_N)

# 将验证集样本移动到验证集文件夹
for sample in val_samples_Y:
    src_path = os.path.join(train_path, 'Y', sample)
    dst_path = os.path.join(val_path, 'Y', sample)
    shutil.move(src_path, dst_path)

for sample in val_samples_N:
    src_path = os.path.join(train_path, 'N', sample)
    dst_path = os.path.join(val_path, 'N', sample)
    shutil.move(src_path, dst_path)

调用方式:

# 按9:1划分训练集和验证集,并整理好数据目录
python split_dataset.py --name data --val_ratio 0.1

切分后的样式

/datasets/data/
|-- train/
|   |-- Y/
|   |-- N/
|-- val/
|   |-- Y/
|   |-- N/

二、模型构建

调用Resnet模型

class ClassifyModel(nn.Module):
    def __init__(self, num_classes):
        super(ClassifyModel, self).__init__()

        # Load the pre-trained ResNet model
        self.model = models.resnet50(pretrained=True)
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)  # Replace the final fully connected layer

    def forward(self, x):
        return self.model(x)

三、模型训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# Set device
device = torch.device('cuda')

# Define data transformations
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally for data augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
}

# Set the path to the train directory
data_dir = './data/train'

# Create train dataset
train_data = datasets.ImageFolder(data_dir, transform=data_transforms['train'])

# Set the batch size and create a data loader
batch_size = 16
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)

# Load the pre-trained model
model = ClassifyModel(num_classes=2).to(device)  # Instantiate the model
model.train()  # Set the model to training mode

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print training progress
    train_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), './model/Resnet-ClassifyModel.pth')

训练结果展示如下:
在这里插入图片描述

四、模型验证

# Set device
device = torch.device('cuda')

# Define data transformations
data_transforms = {
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
}

# Set the path to the validation directory
val_dir = './data/val'

# Create validation dataset
val_data = datasets.ImageFolder(val_dir, transform=data_transforms['val'])

# Set the batch size and create a data loader
batch_size = 16
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=0)

# Load the pre-trained ResNet model
model = ClassifyModel(num_classes=2)
model = model.to(device)

# Load the trained model
model.load_state_dict(torch.load('./model/Resnet-ClassifyModel.pth'))
model.eval()  # Set the model to evaluation mode

# Define class names
class_names = val_data.classes

# Validation loop
correct_predictions = 0
total_images = 0
true_positives = 0
false_negatives = 0

with torch.no_grad():
    # Iterate through the validation images
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

        # Process each image in the batch
        for i in range(images.size(0)):
            image_path = val_data.imgs[total_images + i][0]
            predicted_label = predicted[i].item()
            true_label = labels[i].item()

            # Calculate accuracy and recall
            if predicted_label == true_label:
                correct_predictions += 1
                if true_label == class_names.index('Y'):
                    true_positives += 1
            elif true_label == class_names.index('Y'):
                false_negatives += 1

            # Print the predicted label with the complete image path
            print(f"Image: {image_path}, Predicted Label: {class_names[predicted_label]}")

        total_images += images.size(0)

# Calculate accuracy and recall
accuracy = correct_predictions / total_images
recall = true_positives / (true_positives + false_negatives)

# Print accuracy and recall
print(f"Accuracy: {accuracy:.4f}")
print(f"Recall: {recall:.4f}")

验证结果展示:
在这里插入图片描述在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1068715.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

超自动化加速落地,助力运营效率和用户体验显著提升|爱分析报告

RPA、iPaaS、AI、低代码、BPM、流程挖掘等在帮助企业实现自动化的同时,也在构建一座座“自动化烟囱”。自动化工具尚未融为一体,协同价值没有得到释放。Gartner于2019年提出超自动化(Hyperautomation)概念,主要从技术组…

法律战爆发:“币安退出俄罗斯引发冲击波“

币安是全球最大的加密货币交易所之一,经历了几个月的艰难时期,面临着各种法律挑战,最近将其俄罗斯分公司的所有资产出售给了一家几天前才成立的公司。 这家主要交易所的麻烦始于 6 月份,当时美国证券交易委员会 (SEC)起…

PyTorch Lightning - LightningModule 训练逻辑 (training_step) 异常处理 try-except

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/133673820 在使用 LightningModule 框架训练模型时,因数据导致的训练错误,严重影响训练稳定性,因此需要使用 t…

消费者的力量:跨境电商如何满足新一代的需求

当代跨境电商行业正处于高速发展的阶段,而新一代消费者正在塑造这一行业的未来。他们的需求和消费行为发生了巨大变化,对于跨境电商来说,满足这一新一代消费者的需求至关重要。本文将探讨新一代消费者的需求以及跨境电商如何满足这些需求的方…

【Bond随你温故Azure Architecture】之HADR篇

上次复盘数据保护策略还是在《数据需要找回怎么办?我们如何选择正确的恢复/退回方式?》探讨了在application&DB层面上,不同level的数据保护有不同策略。而它也恰好是今天HA&DR版图的一角(RDBMS部分)&#xff0…

【机器学习】svm

参考 sklearn中SVC中的参数说明与常用函数_sklearn svc参数-CSDN博客https://blog.csdn.net/transformed/article/details/90437821 参考PYthon 教你怎么选择SVM的核函数kernel及案例分析_clfsvm.svc(kernel)-CSDN博客https://blog.csdn.net/c1z2w3456789/article/details/10…

【Python_PySide2学习笔记(十六)】多行文本框QPlainTextEdit类的的基本用法

多行文本框QPlainTextEdit类的的基本用法 前言正文1、创建多行文本框2、多行文本框获取文本3、多行文本框获取选中文本4、多行文本框设置提示5、多行文本框设置文本6、多行文本框在末尾添加文本7、多行文本框在光标处插入文本8、多行文本框清空文本9、多行文本框拷贝文本到剪贴…

什么是EJB以及和Spring Framework的区别

👔 前言 EJB,对于新生代程序员来说,是一个既熟悉又陌生的名词,EJB,大家都听说过,但是不一定都了解过,EJB是一种开发规范,而不是像Spring Framework一样是一个开源框架,E…

卫星/RedCap/高算力/解决方案/创新金奖……移远通信为IOTE 2023再添新活力

9月20日,IOTE 2023第二十届国际物联网展深圳场震撼来袭。 作为IOTE多年的“老朋友”,移远通信在参展当天,不仅有5G RedCap、卫星通信、高算力、车载等高性能产品及终端展出,还携智慧出行、智慧生活、智慧能源、工业互联网等多领域…

redis集群是符合cap中的ap还是cp

近期忽然间考虑到了这个问题。 cap 理论 cap是实现分布式系统的思想。 由3个元素组成。 Consistency(一致性) 在任何对等 server 上读取的数据都是最新版,不会读取出旧数据。比如 zookeeper 集群,从任何一台节点读取出来的数据…

SpringBoot 如何配置 OAuth2 认证

在Spring Boot中配置OAuth2认证 OAuth2是一种用于授权的开放标准,允许应用程序安全地访问用户的资源。Spring Boot提供了强大的支持,使得在应用程序中配置OAuth2认证变得非常容易。本文将介绍如何在Spring Boot中配置OAuth2认证,以便您可以在…

ThreeJS-3D教学六-物体位移旋转

之前文章其实也有涉及到这方面的内容,比如在ThreeJS-3D教学三:平移缩放物体沿轨迹运动这篇中,通过获取轨迹点物体动起来,其它几篇文章也有旋转的效果,本篇我们来详细看下,另外加了tween.js知识点&#xff0…

基于SpringBoot的靓车汽车销售网站

目录 前言 一、技术栈 二、系统功能介绍 用户信息管理 车辆展示管理 车辆品牌管理 用户交流管理 购物车 用户交流 我的订单管理 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的…

除静电离子风刀的工作原理及应用

除静电离子风刀是一种能够产生高速气流并带有离子的设备,主要用于去除物体表面的静电。它的工作原理是通过离子产生器产生大量负离子,并通过高压电场将离子加速,使其成为一股高速气流,从而将静电荷从物体表面中除去。 除静电离子…

阿里云 linux tomcat 无法访问方法

1、阿里云放行tomcat端口 例如7077端口号 2、linux 命令行防火墙 设置端口打开 以下命令查看是否开启指定端口 firewall-cmd --list-ports以下命令添加指定端口让防火墙放行 firewall-cmd --zonepublic --add-port3306/tcp --permanent以下命令重新启动防火墙 systemctl re…

聊一下读完“优势成长”这本书后感

(优势成长上) (优势成长下) 最近读完了一本个人觉得还可以的书,这本书是一位新东方老师,帅键翔老师写的 整本书概括起来,最重要一点就是找到自己的优势,然后利用自己的优势,去挖掘自己的潜力,发现新大陆 能适应时代变化的,是“新木桶原理”&a…

JAVA中解析package、import、class、this关键字

一、前言 代码写的多了有时候我们就慢慢忽视了最简单,最基本的东西。比如一个类中最常见出现的package、import、class、this关键字。我们平时很少追究它的含义或者从来不会深究为什么需要这些关键字。不需要这些关键字,又会怎样。这边博文就简单介绍一下…

设计模式 - 观察者模式

目录 一. 前言 二. 实现 三. 优缺点 一. 前言 观察者模式属于行为型模式。在程序设计中,观察者模式通常由两个对象组成:观察者和被观察者。当被观察者状态发生改变时,它会通知所有的观察者对象,使他们能够及时做出响应&#xf…

攻防世界 Web_python_template_injection SSTI printer方法

这题挺简单的 就是记录一下不同方法的rce python_template_injection ssti了 {{.__class__.__mro__[2].__subclasses__()}} 然后用脚本跑可以知道是 71 {{.__class__.__mro__[2].__subclasses__()[71]}} 然后直接 init {{.__class__.__mro__[2].__subclasses__()[71].__i…

18373-2013 印制板用E玻璃纤维布 知识梳理

声明 本文是学习GB-T 18373-2013 印制板用E玻璃纤维布.pdf而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了印制板用E 玻璃纤维布的定义、代号与规格、要求、试验方法、检验规则、标志、包装、 运输和贮存。 本标准适用于以E 玻璃…