基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

news2024/11/18 2:37:04

 第一步:准备数据

25种宝石数据,总共800张:

{ "0": "Alexandrite","1": "Almandine","2": "Benitoite","3": "Beryl Golden","4": "Carnelian", "5": "Cats Eye","6": "Danburite", "7": "Diamond","8": "Emerald","9": "Fluorite","10": "Garnet Red","11": "Hessonite","12": "Iolite","13": "Jade","14": "Kunzite","15": "Labradorite","16": "Malachite","17": "Onyx Black","18": "Pearl","19": "Quartz Beer","20": "Rhodochrosite","21": "Sapphire Blue","22": "Tanzanite","23": "Variscite","24": "Zircon"}

第二步:搭建模型

本文选择一个RegNet网络,其原理介绍如下:

该论文提出了一个新的网络设计范式,并不是专注于设计单个网络实例,而是设计了一个网络设计空间network design space。整个过程类似于经典的手工网络设计,但被提升到了设计空间的水平。使用本文的方法,作者探索了网络设计的结构方面,并得到了一个由简单、规则的网络构成了低维设计空间并称之为RegNet。RegNet设计空间提供了各个范围flop下简单、快速的网络。在类似的训练设置和flops下,RegNet的效果超过了EfficientNet同时在GPU上快了5倍

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import create_regnet
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # 如果存在预训练权重则载入
    model = create_regnet(model_name=args.model_name,
                          num_classes=args.num_classes).to(device)
    # print(model)

    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后的全连接层外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("train {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)

        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=25)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.01)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\gemstone\archive_train")
    parser.add_argument('--model-name', default='RegNetY_400MF', help='create model name')

    # 预训练权重下载地址
    # 链接: https://pan.baidu.com/s/1XTo3walj9ai7ZhWz7jh-YA  密码: 8lmu
    parser.add_argument('--weights', type=str, default='regnety_400mf.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1708194.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ctfshow web入门 黑盒测试

web380 这里文章看的我好有感触 但是影响做题 扫描一下 访问flag.php啥也没有再访问page.php page.php?idflagweb381 扫出来page.php但是没啥用哇,查看源代码 这些文件挨个试发现啥也没,最后仔细对比发现其实都是layui,然后尝试着访问…

神经网络的工程基础(零)——PyTorch基础

相关说明 这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。 本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/gradient_descent.ipynb 本文将介绍PyTorch的基础。…

Centos安装,window、ubuntus双系统基础上安装Centos安装

文章目录 前言一、准备工作二、开始安装1、2、首先选择DATE&TIME2、选择最小安装3、 选择安装位置 总结 前言 因工作需要,我需要在工控机上额外装Centos7系统,不过我是装在机械硬盘上了不知道对性能是否有影响,若有影响,后面…

MyBatis系统学习篇 - MyBatis逆向工程

MyBatis的逆向工程是指根据数据库表结构自动生成对应的Java实体类、Mapper接口和XML映射文件的过程。逆向工程可以帮助开发人员快速生成与数据库表对应的代码,减少手动编写重复代码的工作量。 我们在MyBatis中通过逆向工具来帮我简化繁琐的搭建框架,减少…

macOS上用Qt creator编译并跑shotcut

1 简介 Shotcut是一个开源的跨平台的视频编辑软件,支持WIN/MACOS/LINUX等平台,由于该项目的编译较为麻烦,踩坑几许,因此写此文章记录完整编译构建过程,后续按此法编译,可减少走弯路,提高生产力。…

vue+antd实践:在输入框光标处插入内容

今天来看一个很简单的需求。 需求描述:在输入框光标处,插入指定的内容。 效果如下: 实现思路:刚开始还在想怎么获取光标的位置,但是发现所做的项目是基于vue3antd组件,那么不简单了嘛,只要调…

SpringBoot Redis 扩展高级功能

环境:SpringBoot2.7.16 Redis6.2.1 1. Redis消息发布订阅 Spring Data 为 Redis 提供了专用的消息传递集成,其功能和命名与 Spring Framework 中的 JMS 集成类似。Redis 消息传递大致可分为两个功能区域: 信息发布 信息订阅 这是一个通常…

[XYCTF新生赛]-Reverse:你是真的大学生吗?解析(汇编异或逆向)

无壳 查看ida 没有办法反汇编,只能直接看汇编了。 这里提示有输入,输入到2F地址后,然后从后往前异或,其中先最后一个字符与第一个字符异或。这里其实也有字符串的长度,推测应该是cx自身异或之后传给了cx 完整exp&am…

三分钟轻松搞定内容,2024视频号最新AI自动生成影视解说,,百分之百过原创, 月入1万+

在这个数字时代,我们有幸见证了AI技术对创新的推动。现如今,一个崭新的平台出现了,它能让你用AI软件在短短3分钟内制作完成一段影视解说,而且由于这个平台尚属于新兴,竞争者稀少,提供了一个广阔的机遇天地。…

如何将云服务器上操作系统由centos切换为ubuntu

本文将介绍如何将我们购买的云服务器上之前装的centos切换为ubuntu,云服务器以华为云为例,要切换的ubuntu版本为ubuntu20.04。 参考官方文档:切换操作系统_弹性云服务器 ECS (huaweicloud.com) 首先打开华为云官网,登录后点击右…

多模态MLLM都是怎么实现的(9)-时序LLM是怎么个事儿?

时序预测这东西大家一般不陌生,随便举几个例子 1- 金融,比如预测股票(股市有风险,入市需谨慎),纯用K线做,我个人不太推荐 2- 天气,比如预测云图,天气预报啥的 3- 交通,早晚高峰,堵车啥的,车啥时候加油,啥时候充电之类的 4- 医疗,看你病史和喝酒的剂量建模,看你会…

斯坦福大学ALOHA家务机器人团队发布了最新研究成果—YAY Robot语言交互式操作系统

ALOHA YAY 演示视频-智能佳 斯坦福的ALOHA家务机器人团队,发布了最新研究成果—Yell At Your Robot(简称YAY),有了它,机器人的“翻车”动作,只要喊句话就能纠正了! 标ALOHA2协作平台题 而且机器…

Shell脚本基本命令

文件名后缀.sh 编写shell脚本一定要说明一下在#!/bin/bash在进行编写。命令选项空格隔开。Shell脚本是解释的语言,bash 文件名即可打印出编写的脚本。chmod给权限命令。如 chmod 0777 文件名意思是给最高权限。 注意:count赋值不能加空格。取消变量可在变…

基于Spring 框架中的@Async 注解实现异步任务

Async 是 Spring 框架中的一个注解,用于实现方法级别的异步执行。使用 Async 可以让你的代码在非当前线程中执行,从而提高应用的并发性能。 1、 启用异步支持 在 Spring 应用的主配置类或任何其他配置类上添加 EnableAsync 注解来开启异步任务的支持 …

13.Redis之数据库管理redis客户端JAVA客户端

1.数据库管理 mysql 中有一个重要的概念,database 1个 mysql 服务器上可以有很多个 database1个 database 上可以有很多个 表mysql 上可以随心所欲的 创建/删除 数据库~~ Redis 提供了⼏个⾯向 Redis 数据库的操作,分别是 dbsize、select、flushdb、flushall 命令…

2024年中国金融行业网络安全市场全景图

网络安全一直是国家安全的核心组成部分,特别是在金融行业,金融机构拥有大量的敏感数据,包括个人信息、交易记录、财务报告等,这些数据的安全直接关系到消费者的利益和金融市场的稳定,因此金融行业在网络安全建设领域一…

Java—选择排序

选择排序是一种简单但高效的排序算法。它的基本思想是从未排序的部分中选择最小(或最大)的元素,并将其放置在已排序部分的末尾。 实现步骤 具体实现选择排序的步骤如下: 遍历数组:从数组的第一个元素开始&#xff0…

cesium绘制区域编辑

npm 安装也是可以的 #默认安装最新的 yarn add cesium#卸载插件 yarn remove cesium#安装指定版本的 yarn add cesium1.96.0#安装指定版本到测试环境 yarn add cesium1.96.0 -D yarn install turf/turf <template><div id"cesiumContainer"></div&…

从零开始学React--环境搭建

React官网 快速入门 – React 中文文档 1.搭建环境 下载nodejs,双击安装 nodejs下载地址 更新npm npm install -g npm 设置npm源&#xff0c;加快下载速度 npm config set registry https://registry.npmmirror.com 创建一个react应用 npx create-react-app react-ba…

QT 自定义协议TCP传输文件

后面附带实例的下载地址 一、将文件看做是由:文件头+文件内容组成,其中文件头包含文件的一些信息:文件名称、文件大小等。 二、文件头单独发送,文件内容切块发送。 三、每次发送信息格式:发送内容大小、发送内容类型(文件头或是文件块内容)、文件块内容。 四、效果展…