基于Pytorch框架的深度学习ConvNext神经网络宠物猫识别分类系统源码

news2025/2/27 13:18:10

 第一步:准备数据

12种宠物猫类数据:self.class_indict = ["阿比西尼猫", "豹猫", "伯曼猫", "孟买猫", "英国短毛猫", "埃及猫", "缅因猫", "波斯猫", "布偶猫", "克拉特猫", "泰国暹罗猫", "加拿大无毛猫"]

,总共有2160张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个ConvNext网络,其原理介绍如下:

ConvNext (Convolutional Network Net Generation), 即下一代卷积神经网络, 是近些年来 CV 领域的一个重要发展. ConvNext 由 Facebook AI Research 提出, 仅仅通过卷积结构就达到了与 Transformer 结构相媲美的 ImageNet Top-1 准确率, 这在近年来以 Transformer 为主导的视觉问题解决趋势中显得尤为突出.

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import convnext_tiny as create_model
from utils import read_split_data, create_lr_scheduler, get_params_groups, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"using {device} device.")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}


    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # 删除有关分类类别的权重
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    # pg = [p for p in model.parameters() if p.requires_grad]
    pg = get_params_groups(model, weight_decay=args.wd)
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=1)

    best_acc = 0.
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch,
                                                lr_scheduler=lr_scheduler)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if best_acc < val_acc:
            torch.save(model.state_dict(), "./weights/best_model.pth")
            best_acc = val_acc


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=12)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--wd', type=float, default=5e-2)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\cat_data_sets_models\cat_12_train")

    # 预训练权重路径,如果不想载入就设置为空字符
    # 链接: https://pan.baidu.com/s/1aNqQW4n_RrUlWUBNlaJRHA  密码: i83t
    parser.add_argument('--weights', type=str, default='./convnext_tiny_1k_224_ema.pth',
                        help='initial weights path')
    # 是否冻结head以外所有权重
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习ConvNext神经网络宠物猫识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1823532.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ARM单片机使用CAN总线部署BootLoader

1.引言 1.1.单片机开发BootLoader意义 单片机开发BootLoader的原因主要与其在嵌入式系统中的关键作用有关。BootLoader是硬件启动的引导程序&#xff0c;它在操作系统内核或用户应用程序运行之前执行。以下是单片机开发BootLoader的主要原因&#xff1a; 初始化硬件设备&…

2024年春季学期《算法分析与设计》练习13

A:菱形图案 题目描述 KiKi学习了循环&#xff0c;BoBo老师给他出了一系列打印图案的练习&#xff0c;该任务是打印用“*”组成的菱形图案。 输入 多组输入&#xff0c;一个整数&#xff08;2~20&#xff09;。 输出 针对每行输入&#xff0c;输出用“*”组成的菱形&#xff0c;…

Java 18新特性概览与解读

随着技术的不断进步&#xff0c;Java作为最流行的编程语言之一&#xff0c;也在持续地进行版本更新&#xff0c;为开发人员提供更强大、更高效的工具和特性。Java 18作为最新的稳定版本&#xff0c;引入了一系列引人注目的新特性和改进。以下是对Java 18中一些主要新特性的详细…

Petalinux由于网络原因产生的编译错误(3)-qemu-xilinx-system-native 失败

1 获取qemu-xilinx-system-native 失败 编译时遇到qemu-xilinx-system-native 包获取失败&#xff0c;如下图所示&#xff1a; 解决这种错误方法如下&#xff1a; 进入Petalinux 工程&#xff0c;编辑工程下的 project-spec/meta-user/conf/petalinuxbsp.conf 文件&#xff0…

什么是DMZ?路由器上如何使用DMZ?

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 DMZ 📒🚀 DMZ的应用场景💡 路由器设置DMZ🎈 注意事项 🎈⚓️ 相关链接 ⚓️📖 介绍 📖 在网络管理中,DMZ(Demilitarized Zone,隔离区)是一个特殊的网络区域,常用于将公共访问和内部网络隔离开来。DMZ功能允许…

关联规则延伸之协同过滤

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 一、协同过滤1、含义2、策略 二、基于用户的协同过滤1、寻找相似偏好的用户2、欧式距离及系数3、皮尔逊系数4、其他系数5、算法步骤6、局限性 三、基于物品的协同过滤1、…

php实现一个简单的MySQL分页

一、案例演示: 二、php 代码 <?php $servername = "localhost"; // MySQL服务器名称或IP地址 $username = "root"; // MySQL用户名 $password = "123456"; // MySQL密码 $dbname = "test"; // 要连接的数据…

外盘黄金期货需要注意什么?

为大家整理了关于黄金做单的五大原则&#xff0c;相信对于新手投资者来说肯定会产生一定的帮助。  1、看多空&#xff1a;主要有两种方法&#xff0c;基本面判断和技术面判断&#xff0c;基本面判断&#xff0c;主要是借助基本信息面&#xff0c;如政策。供需&#xff0c;产量…

文字不换行了

单行文字不换行 添加... .line1Text {overflow: hidden;text-overflow: ellipsis;white-space: nowrap;cursor: pointer; } 双行文字换行添加... .line2Text {overflow: hidden;display: -webkit-box;-webkit-box-orient: vertical;-webkit-line-clamp: 2;text-overflow: e…

向量化在人工智能领域的深度实践:技术革新与效率提升

在人工智能&#xff08;AI&#xff09;的飞速发展中&#xff0c;向量化技术作为一种基础且关键的数据处理手段&#xff0c;正日益受到广泛关注。向量化是将文本、图像、声音等数据转换为数值向量的过程&#xff0c;这些向量能够表示原始数据的特征和语义信息&#xff0c;为深度…

【gtest】 C++ 的测试框架之使用 gtest 编写单元测试

目录 &#x1f30a;前言 &#x1f30a;使用 cmake 启动并运行 gtest &#x1f30d;1. 设置项目 &#x1f30d;2. 创建并运行二进制文件 &#x1f30a;1. gtest 入门 &#x1f30d;1.1 断言&#xff08;assertions&#xff09; &#x1f30d;1.2 简单测试 &#x1f30d;…

进程(Processes)

在 Elixir 中&#xff0c;所有代码都在进程内运行。进程彼此隔离&#xff0c;彼此并发运行并通过消息传递进行通信。进程不仅是 Elixir 中并发的基础&#xff0c;而且还提供了构建分布式和容错程序的方法。 Elixir 的进程不应与操作系统进程混淆。Elixir 中的进程在内存和 CPU…

如何使用CCS9.3打开CCS3.0工程

如何使用CCS9.3打开CCS3.0工程 点菜单栏上的project&#xff0c;选择Import Legacy CCSv3.3 Porjects…&#xff0c;弹出对话框&#xff0c;通过Browse…按钮导入一个3.3版本的工程项目&#xff1b; 选择.pjt文件&#xff0c;选择Copy projects into worlkspace 右击选择P…

二分查找总结:算法原理,适用题型,经典题单

二分查找 感谢灵神的题单 题单&#xff1a;分享丨【题单】二分算法&#xff08;二分答案/最小化最大值/最大化最小值/第K小&#xff09; - 力扣&#xff08;LeetCode&#xff09; 每天四道题&#xff0c;大概用时一个月刷完&#xff0c;如果没有时间的同学可以学习我总结的算…

中信所:中国科学技术信息研究所-国家科技图书文献中心

文章目录 1. Intro2. History3. Staffing level4. Facility resources5. Scientific achievementsReference国家工程技术数字图书馆 National Engineering and Technology Digital Library 1. Intro 中国科学技术信息研究所(以下简称中信所)是在周恩来总理、聂荣臻元帅等党和…

【CS.AL】算法复杂度分析 —— 渐进符号表示法

文章目录 1 概述2 渐进符号详解2.1 大O符号&#xff08;O&#xff09;2.2 Ω符号&#xff08;Ω&#xff09;2.3 Θ符号&#xff08;Θ&#xff09;2.4 o符号&#xff08;o&#xff09;2.5 ω符号&#xff08;ω&#xff09; 3 具体例子3.1 插入排序&#xff08;Insertion Sort…

GitHub强制双重验证、二次验证之下载微软Authenticator

Download Microsoft Authenticator 如上图&#xff0c;安卓手机在国内&#xff0c;基本没有下载使用的可能。 下载 Microsoft Authenticator 如上图&#xff0c;找到了国内下载渠道&#xff0c;但仅联想商店的新版本适合使用&#xff0c;下载安装即可。

C++11初始化列表打包器initializer_list

有时我们无法提前知道应该向函数传递几个实参。为了编写能处理不同数量实参的函数我们使用initializer_list Cplusplus中的定义&#xff1a; 其里面有三个成员函数 也就是说他是支持迭代器的&#xff0c;支持迭代器就支持范围for 图像理解 函数类型 void Test1_initializer_li…

一手洞悉巴西slot游戏包投放本土网盟CPI广告优势

一手洞悉巴西slot游戏包投放本土网盟CPI广告优势 在巴西这片热土上&#xff0c;slot游戏包的投放本土网盟CPI广告是一项既充满挑战又富有机遇的任务。CPI&#xff08;Cost Per Install&#xff09;广告模式&#xff0c;即按安装付费&#xff0c;已经成为许多游戏开发商推广产品…

ios18新功能:设专属“咒语”动动嘴巴即可操作iphone

苹果 iOS / iPadOS 18 系统引入了“人声快捷指令”&#xff08;Vocal Shortcuts&#xff09;功能&#xff0c;即便iPhone、iPad 处于锁屏状态下&#xff0c;也能响应你的语音命令。 苹果官方对“人声快捷指令”的介绍如下&#xff1a;iPhone 和 iPad 用户可以通过人声快捷指令…