「MobileNet V3」70 个犬种的图片分类

news2024/11/17 7:33:44

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 数据集与 Notebook
    • 环境准备
    • 数据集
    • 可视化
    • 模型
    • 预测
    • Loss 与评价指标


数据集与 Notebook

数据集:70 Dog Breeds-Image Data Set
Notebook:「MobileNet V3」70 Dog Breeds-Image Classification


环境准备

import warnings
warnings.filterwarnings('ignore')

禁用警告,防止干扰。

!pip install lightning --quiet

安装 PyTorch Lightning。

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})

导入常用的库,设置绘图风格。

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

导入 PyTorch 相关的库。

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

导入 PyTorch Lightning 相关的库。

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
pl.seed_everything(seed, workers=True)

设置随机种子。


数据集

batch_size = 64

设置批次大小。

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

设置数据集的预处理。

train_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/train", transform=train_transform)
val_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/valid", transform=test_transform)
test_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/test", transform=test_transform)

读取数据集。

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

加载数据集。


可视化

class_names = train_dataset.classes
class_count = [train_dataset.targets.count(i) for i in range(len(class_names))]
df = pd.DataFrame({"Class": class_names, "Count": class_count})

plt.figure(figsize=(12, 20), dpi=100)
sns.barplot(x="Count", y="Class", data=df)
plt.tight_layout()
plt.show()

绘制训练集的类别分布。

训练集的类别分布

plt.figure(figsize=(12, 20), dpi=100)
images, labels = next(iter(val_loader))
for i in range(8):
    ax = plt.subplot(8, 4, i + 1)
    plt.imshow(images[i].permute(1, 2, 0).numpy())
    plt.title(class_names[labels[i]])
    plt.axis("off")
plt.tight_layout()
plt.show()

绘制训练集的样本。

训练集的样本


模型

class LitModel(pl.LightningModule):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.model = models.mobilenet_v3_large(weights="IMAGENET1K_V2")
        # for param in self.model.parameters():
        #     param.requires_grad = False
        self.model.classifier[3] = nn.Linear(self.model.classifier[3].in_features, num_classes)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = torchmetrics.Precision(task="multiclass", average="macro", num_classes=num_classes)
        self.recall = torchmetrics.Recall(task="multiclass", average="macro", num_classes=num_classes)
        self.f1score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(), lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        self.log_dict(
            {
                "train_acc": self.accuracy(y_hat, y),
                "train_prec": self.precision(y_hat, y),
                "train_recall": self.recall(y_hat, y),
                "train_f1score": self.f1score(y_hat, y),
            },
            on_step=True,
            on_epoch=False,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)
        self.log_dict(
            {
                "val_acc": self.accuracy(y_hat, y),
                "val_prec": self.precision(y_hat, y),
                "val_recall": self.recall(y_hat, y),
                "val_f1score": self.f1score(y_hat, y),
            },
            on_step=False,
            on_epoch=True,
            logger=True,
        )

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log_dict(
            {
                "test_acc": self.accuracy(y_hat, y),
                "test_prec": self.precision(y_hat, y),
                "test_recall": self.recall(y_hat, y),
                "test_f1score": self.f1score(y_hat, y),
            }
        )

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        y_hat = self(x)
        preds = torch.argmax(y_hat, dim=1)
        return preds

定义模型。

num_classes = len(class_names)
model = LitModel(num_classes=num_classes)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
)
trainer = pl.Trainer(
    max_epochs=20,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[early_stop_callback],
    deterministic=True,
)
trainer.fit(model, train_loader, val_loader)

训练模型。

trainer.test(model, val_loader)

测试模型。


预测

pred = trainer.predict(model, test_loader)
pred = torch.cat(pred, dim=0)
pred = pd.DataFrame(pred.numpy(), columns=["Class"])
pred["Class"] = pred["Class"].apply(lambda x: class_names[x])

plt.figure(figsize=(12, 20), dpi=100)
sns.countplot(y="Class", data=pred)
plt.tight_layout()
plt.show()

绘制预测结果的类别分布。

预测结果的类别分布


Loss 与评价指标

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"

plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="train_loss", data=metrics, label="Train Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_loss", data=metrics, label="Valid Loss", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()


plt.figure(figsize=(14, 12), dpi=100)

plt.subplot(2,2,1)
sns.lineplot(x=x_name, y="train_acc", data=metrics, label="Train Accuracy", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_acc", data=metrics, label="Valid Accuracy", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")

plt.subplot(2,2,2)
sns.lineplot(x=x_name, y="train_prec", data=metrics, label="Train Precision", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_prec", data=metrics, label="Valid Precision", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Precision")

plt.subplot(2,2,3)
sns.lineplot(x=x_name, y="train_recall", data=metrics, label="Train Recall", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_recall", data=metrics, label="Valid Recall", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Recall")

plt.subplot(2,2,4)
sns.lineplot(x=x_name, y="train_f1score", data=metrics, label="Train F1-Score", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_f1score", data=metrics, label="Valid F1-Score", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("F1-Score")

plt.tight_layout()
plt.show()

绘制 Loss 与评价指标的变化。

Loss

评价指标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1234093.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Softing mobiLink助力过程自动化——兼容HART、FF、PA的多协议接口工具

由于全球人口增加和气候变化等因素,“水”比以往任何时候都更具有价值。与此同时,环境法规和水处理标准也变得愈加严格。在这一大环境下,自来水公司不得不应对一些新的挑战,例如,更好地提高能源效率、最大程度地减少资…

Linux | 从虚拟地址到物理地址

前言 本章主要讲解虚拟地址是怎么转化成物理地址的,以及页表相关知识;本文环境默认为32位机器下;如果你连什么是虚拟地址都不知道可以先看看下面这篇文章; Linux | 进程地址空间-CSDN博客 一、概念补充 页表:是一种数据…

使用Arrays.asList与不使用的区别

在写算法的时候,遇到了有的题解使用的是Arrays.asList,也有的是直接新建一个List集合将元素加进去的。 看了一下算法的时间,两者居然相差了9秒。 算法原地址: 力扣(LeetCode)官网 - 全球极客挚爱的技术成长…

【发明专利】天洑软件再度收获六项国家发明专利授权

近日,南京天洑软件有限公司再度收获行业内六项国家发明专利授权,专利名称为:一种发电机绕组温度预警方法及装置(专利号:ZL 2022 1 1525605.3),一种CSTR系统的控制方法及装置(专利号&…

卷?中学生开始学习人工智能和大模型,附课件!

卷?中学生开始学习人工智能和大模型,附课件! 大家好,我是老章 发现一个面向11-14岁人群的AI课程,还附加了大模型内容,浏览了一遍它们的课件(还有面向教师的资源),感觉非…

ProtoBuf的使用

目录 1.创建.proto文件 1.1文件规范 1.2添加注释 1.3指定proto3语法 1.4package声明符 1.5定义消息(message) 1.6定义消息字段 2.编译contacts.proto文件 3.序列化与反序列化的使用 1.创建.proto文件 1.1文件规范 • 创建.proto文件时,⽂件命名应该使用全…

活动回顾 | 数字外贸私享会【上海站】成功举办

11月17日,由箱讯科技主办的数字外贸高端定制私享会【上海站】成功举办!本次会议的主题为“新模式、新商机、新政策”,外贸行业的老板、企业家们齐聚一堂,凝聚共识,共话数字外贸的新趋势和新机遇。 近年来,数…

webpack external 详解

作用:打包时将依赖独立出来,在运行时(runtime)再从外部获取这些扩展依赖,目的时解决打包文件过大的问题。 使用方法: 附上代码块 config.set(externals, {vue: Vue,vue-router: VueRouter,axios: axios,an…

C语言基本算法之选择排序

目录 概要: 代码如下 运行结果如下 概要: 它和冒泡排序一样,都是把数组元素按顺序排列,但是方法不同,冒泡排序是把较小值一个一个往后面移,选择排序则是直接找出最小值,可以这个说&#xff…

1、数仓模型概述

1、问:什么是数据模型? 数仓领域中的模型指的是数据模型,要和商业分析中的模型不同 数据模型就是数据组织和存储方法,它强调从业务、数据存取和使用的角度合理的存储数据 2、问:模型和表的区别? 表是数据物…

SpringBoot-Docker容器化部署发布

在生产环境都是怎么部署 Spring Boot? 打成 jar 直接一键运行打成 war 扔到 Tomcat 容器中运行容器化部署 一、准备Docker 在 CentOS7 上安装好 Docker 修改 Docker 配置,开启允许远程访问 Docker 的功能,开启方式很简单,修改 /usr/lib/s…

redis--高可用之持久化

redis高可用相关知识 在web服务器中,高可用是指服务器可以正常访问的时间,衡量的标准是在多长时间内可以提供正常服务(99.9%、99.99%、99.999%等等)。 但是在Redis语境中,高可用的含义似乎要宽泛一些,除了保证提供正常服务( 如主…

Openlayer【三】—— 绘制多边形GeoJson边界绘制

1.1、绘制多边形 在绘制多边形和前面绘制线有异曲同工之妙,多边形本质上就是由多个点组成的线然后连接组成的面,这个面就是最终的结果,那么这里使用到的是Polygon对象,而传给这个对象的值也是多个坐标,坐标会一个个的…

华为ac+fit漫游配置案例

Ap漫游配置: 其它配置上面一样,ap管理dhcp和业务dhcp全在汇聚交换机 R1: interface GigabitEthernet0/0/0 ip address 11.1.1.1 255.255.255.0 ip route-static 12.2.2.0 255.255.255.0 11.1.1.2 ip route-static 192.168.0.0 255.255.0.0 11.1.1.2 lsw1: vlan batch 100 200…

《栈和队列》的模拟实现(顺序栈) (链队列)

目录 前言: 栈和队列: 栈: 队列: 模拟实现《栈》: 1.typedef数据类型 2.初始化栈 3.销毁栈 4.入栈 5.出栈 6.取栈顶元素 7.判断栈是否为空 8.栈的大小 9.打印栈 模拟实现《队列》 : 1.type…

基于C#实现KMP算法

一、BF 算法 如果让你写字符串的模式匹配,你可能会很快的写出朴素的 bf 算法,至少问题是解决了,我想大家很清楚的知道它的时间复杂度为 O(MN),原因很简单,主串和模式串失配的时候,我…

做黄金代理可以代理什么品种?

近几年,黄金代理这个职业发展的比较迅猛,主要是受金融环境越来越稳定、金融投资越来越发达的大势所推动。那些有意想做黄金代理的朋友就会有疑问,做了黄金代理可以代理什么品种的? 其实广义上来说,黄金代理有很多种&am…

【Pytorch】Visualization of Feature Maps(1)

学习参考来自 CNN可视化Convolutional Featureshttps://github.com/wmn7/ML_Practice/blob/master/2019_05_27/filter_visualizer.ipynb 文章目录 filter 的激活值 filter 的激活值 原理:找一张图片,使得某个 layer 的 filter 的激活值最大&#xff0c…

C#核心笔记——(二)C#语言基础

一、C#程序 1.1 基础程序 using System; //引入命名空间namespace CsharpTest //将以下类定义在CsharpTest命名空间中 {internal class TestProgram //定义TestProgram类{public void Test() { }//定义Test方法} }方法是C#中的诸多种类的函数之一。另一种函数*,还…

机器学习介绍与分类

随着科学技术的不断发展,机器学习作为人工智能领域的重要分支,正逐渐引起广泛的关注和应用。本文将介绍机器学习的基本概念、原理和分类方法,帮助读者更好地理解和应用机器学习技术。 一、机器学习的基本概念 机器学习是一种通过从数据中学…