基于深度学习的图像分类或识别系统(含全套项目+PyQt5界面)

news2024/9/20 9:38:56

目录

一、项目界面

二、代码实现

1、网络代码

2、训练代码

3、评估代码

4、结果显示

三、项目代码


一、项目界面

二、代码实现

1、网络代码

该网络基于残差模型修改

import torch
import torch.nn as nn
import torchvision.models as models


class resnet18(nn.Module):
    def __init__(self, num_classes=5, pretrained=False):
        super(resnet18, self).__init__()

        # 加载ResNet-18模型
        self.model = models.resnet18(pretrained=pretrained)
        # print(self.model)

        # 更改全连接层以输出自定义类别数量
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

if __name__ == '__main__':
    # 示例用法
    num_classes = 10
    model = resnet18(num_classes=num_classes)

    # 打印模型以确认更改
    print(model)
2、训练代码
import os
import torch
import torch.nn as nn
from models.resnet18 import resnet18
from utils.utils import train_and_val,plot_acc,plot_loss,plot_lr,MyDataset
import numpy as np
from torch.utils.data import DataLoader
import glob
import pandas as pd
import config

def main(epochs,model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(config.save_results):
        os.makedirs(config.save_results)

    # ----------------------------模型加载-------------------------
    model = model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5,
                                                gamma=0.9)  # 每经过5个epoch,学习率乘以0.9
    # ------------------------------------------------------------

    # ---------------------------加载数据--------------------------
    im_train_list = glob.glob(config.train_path + "/*/*." + config.img_)
    im_val_list = glob.glob(config.val_path + "/*/*." + config.img_)

    train_dataset = MyDataset(im_train_list, config.label_names)
    val_dataset = MyDataset(im_val_list, config.label_names)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True)

    val_loader = DataLoader(val_dataset,
                             batch_size=config.batch_size,
                             shuffle=False)

    print("num of train", len(train_dataset))
    print("num of val", len(val_loader))
    # ------------------------------------------------------------

    # ---------------------------网络训练--------------------------
    history = train_and_val(epochs, model, train_loader,val_loader,loss_function, optimizer,scheduler,config.save_results,device)
    df = pd.DataFrame(history) # 转换为DataFrame
    df.to_excel(os.path.join(config.save_results,'history.xlsx'), index=False) # 保存为 Excel 文件

    plot_loss(np.arange(0,epochs),config.save_results, history)
    plot_acc(np.arange(0,epochs),config.save_results, history)
    plot_lr(np.arange(0,epochs),config.save_results, history)

if __name__ == '__main__':
    model = resnet18(num_classes=config.num_classes)
    main(config.epochs,model)
3、评估代码
from sklearn.metrics import classification_report
import torch
import os
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from models.resnet18 import resnet18
import matplotlib.pyplot as plt
from utils.utils import MyDataset,reports
from torch.utils.data import DataLoader
import seaborn as sns
import glob
import config

def main(model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----------------------------模型加载-------------------------
    model = model.to(device)
    checkpoint = torch.load(os.path.join(config.save_results,"best.pth"))
    model.load_state_dict(checkpoint, strict=True)
    model.eval()
    # ------------------------------------------------------------

    # ---------------------------加载数据--------------------------
    im_test_list = glob.glob(config.test_path + "/*/*." + config.img_)
    test_dataset = MyDataset(im_test_list, config.label_names)
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False)
    print("num of test", len(test_loader))
    # ------------------------------------------------------------

    act = nn.Softmax(dim=-1)
    y_true, y_pred = [], []
    with torch.no_grad():
        with tqdm(total=len(test_loader)) as pbar:
            for images, labels in test_loader:
                outputs = act(model(images.to(device)))
                _, predicted = torch.max(outputs, 1)
                predicted = predicted.cpu()
                y_pred.extend(predicted.numpy())
                y_true.extend(labels.cpu().numpy())
                pbar.update(1)

    oa,aa,kappa,cls,cm = reports(y_true, y_pred)
    cr = classification_report(y_true, y_pred, target_names=config.label_names.values(), output_dict=True)

    df = pd.DataFrame(cr).transpose()
    df.to_csv(os.path.join(config.save_results,"classification_report.csv"), index=True)
    print("Accuracy is :", oa)

    with open(os.path.join(config.save_results,"results.txt"), "a") as file:
        file.write('OA:{:.4f} AA:{:.4f} kappa:{:.4f}\ncls:{}\n混淆矩阵:\n{}\n'.format(oa, aa, kappa,cls,cm))

    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, xticklabels=config.label_names.values(), yticklabels=config.label_names.values(), cmap='Blues', fmt="d")

    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(config.save_results,'test_confusion_matrix.png'))
    plt.clf()

if __name__ == '__main__':
    model = resnet18()
    main(model)
4、结果显示

上述仅仅是简单演示,结果没有参考意义。

三、项目代码

本项目的代码通过以下链接下载:基于深度学习的图像分类或识别系统(含全套项目+PyQt5界面)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2140656.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】(指针系列2)指针运算+指针与数组的关系+二级指针+指针数组+《剑指offer面试题》

前言:开始之前先感谢一位大佬,清风~徐~来-CSDN博客,由于是时间久远,博主指针的系列忘的差不多了,所以有些部分借鉴了该播主的,有些地方如果解释的不到位,请翻看这位大佬的,感谢大家&…

C++ char*和char[] 可能指向的内存区域详解(附实验)

C char* 指向的内存区域详解 写在前面c内存结构简介指针常量和常量指针简介情况一:char* 指向栈区内容情况二:char* 指向堆区内容情况三:char* 指向常量区内容情况四:char* 指向静态区内容情况五:char* 指向全局区内容…

Scratch游戏-史诗忍者7免费下载

小虎鲸Scratch资源站-免费少儿编程Scratch作品源码,素材,教程分享网站! 作品描述: 在Scratch版本的《史诗忍者7》中,你需要穿越关卡,击败敌人并收集33个水果。通过灵活的操作和精准的攻击,逐步闯过重重难关。游戏中提供了丰富的技…

【GESP】C++一级练习BCQM3005,基本输出语句printf

一道基础练习题,练习基本输出语句printf。 BCQM3005 题目要求 描述 输出表达式1234∗5678的结果。 输入 无 输出 1234∗56787006652 输入样例 无 输出样例 1234 * 5678 7006652 全文详见个人独立博客:https://www.coderli.com/gesp-1-bcqm3005/ 【…

使用 SuperCraft AI 设计书橱模型的指南

在现代家居设计中,书橱不仅是存放书籍的地方,更是展示个人品味和风格的重要家具。借助 SuperCraft AI,你可以轻松设计出独一无二的书橱。以下是详细的步骤指南,帮助你从零开始设计一个理想的书橱。 1. 创建项目 首先&#xff0c…

【探索数据结构与算法】插入排序:原理、实现与分析(图文详解)

目录 一、插入排序 算法思想 二、插入排序 算法步骤 四、复杂度分析 时间复杂度:O(n^2) 空间复杂度:O(1) 稳定性:稳定算法 五、应用场景 💓 博客主页:C-SDN花园GGbond ⏩ 文章专栏:探索数据结构…

node卸载流程

步骤: 1.开始中搜素”命令提示符“,并将其以”管理员身份运行“ 在弹出的框中输入cmd,并确认进入”命令提示符“ 2.在里面通过npm config list查看node相关文件路径, 并找到config from与prefix ,后面对应的路径&…

element-plus的菜单组件el-menu

菜单是几乎是每个管理系统的软件系统中不可或缺的,element-plus提供的菜单组件可以快速完成大部分的菜单的需求开发, 该组件内置和vue-router的集成,使用起来很方便。 主要组件如下 el-menu 顶级菜单组件 主要属性 mode:决定菜单的展示模式…

visual studio给项目增加eigen库 手把手教程

Eigen是一个开源的C库,主要用来支持线性代数,矩阵和矢量运算,数值分析及其相关的算法。Eigen 除了C标准库以外,不需要任何其他的依赖包。Eigen库3.4.0版本的下载地址为: https://gitlab.com/libeigen/eigen/-/archive/…

qt-creator-10.0.2之后版本的jom.exe编译速度慢下来了

1、Qt的IDE一直在升级,qt-creator的新版本下载地址 https://download.qt.io/official_releases/qtcreator/ 2、本人一直用的是qt-creator-10.0.2版本,官网历史仓库可以下载安装包qt-creator-opensource-windows-x86_64-10.0.2.exe https://download.qt…

清华大佬自曝:接到了省烟草局的offer,我就拒掉了华为!结果华为立马给我申请了特殊涨薪,总包70w是烟草的2倍,这可如何是好?

《网安面试指南》http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247484339&idx1&sn356300f169de74e7a778b04bfbbbd0ab&chksmc0e47aeff793f3f9a5f7abcfa57695e8944e52bca2de2c7a3eb1aecb3c1e6b9cb6abe509d51f&scene21#wechat_redirect 《Java代码审…

Java设计模式—面向对象设计原则(四) ----->接口隔离原则(ISP) (完整详解,附有代码+案例)

文章目录 3.4 接口隔离原则(ISP)3.4.1 概述3.4.2 案列 3.4 接口隔离原则(ISP) 接口隔离原则:Interface Segregation Principle,简称ISP 3.4.1 概述 客户端测试类不应该被迫依赖于它不使用的方法;一个类对另一个类的依赖应该建立在最小的接…

我国常见电压等级有哪些?来试试电压版2048让你轻松牢记

作为电气工程专业的小姜同学,平时喜欢敲点代码,前一阵做了一个电气特色的2048小游戏。既能缓解学习压力,又能让大家在玩之中把我国的电压等级牢记于心。 项目预览 功能描述 游戏方法与2048相同,根据我国标准的电压等级&#xff…

基于SpringBoot的人事管理系统【附源码】

基于SpringBoot的人事管理系统(源码L文说明文档) 目录 4 系统设计 4.1 系统概述 4.2系统功能结构设计 4.3数据库设计 4.3.1数据库E-R图设计 4.3.2 数据库表结构设计 5 系统实现 5.1管理员功能介绍 5.1.1管理员登…

【Leetcode:1184. 公交站间的距离 + 模拟】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

PDF转JPG,奋斗汪的必备技能,你掌握了吗?

现在大家都用电脑手机处理文件,PDF和JPG是最常见的两种。PDF文件方便打印和分享,而JPG图片小巧清晰,适合在手机上看和发给别人。有时候,我们需要把PDF文件变成JPG图片,比如想把教材或报告变成图片,方便在手…

F28335 时钟及控制系统

1 F28335 系统时钟来源 1.1 振荡器OSC与锁相环PLL 时钟信号对于DSP来说是非常重要的,它为DSP工作提供一个稳定的机器周期从而使系统能够正常运行。时钟系统犹如人的心脏,一旦有问题整个系统就崩溃。DSP 属于数字信号处理器, 它正常工作也必须为其提供时钟信号。那么这个时钟…

树(森林)的定义和画图

目录 代码实现 “双亲表示法”顺序存储 “孩子表示法”链式存储 树的孩子表示法存储 v.s. 图的邻接表存储 v.s. 散列表的拉链法 v.s. 基数排序 “孩子兄弟表示法”链式存储 画图表示 “双亲表示法” 1.树 2.森林 “孩子表示法” 1.树 2.森林 “孩子兄弟表示法” …

SPI学习笔记

SPI SPI是一种同步串行通信接口规范,它允许一个主设备与一个或多个从设备进行全双工通信。SPI用于短距离通信,主要应用于嵌入式系统。 SPI通信过程 1.初始化:SPI主机首先将SS或CS线拉低,以选择特定的从设备并开始通信。 2.数据…

AI健身体能测试之基于paddlehub实现引体向上计数个数统计

【引体向上计数】 本项目使用PaddleHub中的骨骼检测模型human_pose_estimation_resnet50_mpii,进行人体运动分析,实现对引体向上的自动计数。 1. 项目介绍 人体运动分析是近几年许多领域研究的热点问题。在学科的交叉研究上,人体运动分析涉…