基于python pyotrch开发的垃圾分类程序,含数据集,基于深度学习的垃圾分类程序

news2025/1/16 1:13:06

基于python的垃圾分类程序,提供数据集(pytorch开发)

完整代码下载地址:基于python pyotrch开发的垃圾分类程序,含数据集
image-20210305103139860

垃圾分类是目前社会的一个热点,分类的任务是计算机视觉任务中的基础任务,相对来说比较简单,只要找到合适的数据集,垃圾分类的模型构建并不难,这里我找到一份关于垃圾分类的数据集,一共有四个大类和245个小类,大类分别是厨余垃圾、可回收物、其他垃圾和有害垃圾,小类主要是垃圾的具体类别,果皮、纸箱等。

为了方便大家使用,我已经提前将数据集进行了处理,按照8比1比1的比例将原始数据集划分成了训练集、验证集和测试集,大家可以从下面的链接自取。

链接:https://pan.baidu.com/s/1BkDlOmJwN37TVhfig4llow
提取码:9avi
复制这段内容后打开百度网盘手机App,操作更方便哦–来自百度网盘超级会员V4的分享

代码结构

trash1.0
├─ .idea idea配置文件
├─ imgs 图片文件
├─ main_window.py 图形界面代码
├─ models
│    └─ mobilenet_trashv1_2.pt
├─ old 一些废弃的代码
├─ readme.md 你现在看到的
├─ test.py 测试文件
├─ test4dataset.py  测试所有的数据集
├─ test4singleimg.py 测试单一的图片
├─ train_245_class.py 训练代码
└─ utils.py 工具类,用于划分数据集

训练

训练前请执行命令按照好项目所需的依赖库,关于如何在python中使用conda和pip对项目包管理可以看这篇文章或者是看我b站的这个视频,里面有详细的讲解。

csdn文章:Windows下GPU深度学习环境的配置(pytorch和tensorflow)_ECHOSON的博客-CSDN博客

conda create -n torch1.6 python==3.6.10
conda activate torch1.6
conda install pytorch torchvision cudatoolkit=10.2 # GPU(可选)
conda install pytorch torchvision cpuonly
pip install opencv-python
pip install matplotlib

首先需要把数据集下载之后进行解压,记住解压的路径,并在train.py的18行将数据集路径修改为你本地的数据集路径,修改之后执行运行train.py文件即可开始模型训练,训练之后的模型将会保存在models目录下。

模型训练部分则选用了大名鼎鼎的mobilenet,mobilenet是比较轻量的网络,在cpu上也可以运行的很快,训练的代码如下,首先通过pytorch的dataloader加载数据集,并加载预训练的mobilenet进行微调。

# coding:utf-8
# TODO 添加一个图形化界面
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
from old.train_based_torchvision import Net

names = []


class MainWindow(QTabWidget):
    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('imgs/面性铅笔.png'))
        self.setWindowTitle('垃圾识别')
        # 加载网络
        self.net = torch.load("models/mobilenet_trashv1_2.pt", map_location=lambda storage, loc: storage)
        self.transform = transforms.Compose(
            # 这里只对其中的一个通道进行归一化的操作
            [transforms.Resize([224, 224]),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

        self.resize(800, 600)
        self.initUI()

    def initUI(self):
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("测试样本")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        self.predict_img_path = "imgs/img111.jpeg"
        img_init = cv2.imread(self.predict_img_path)
        img_init = cv2.resize(img_init, (400, 400))
        cv2.imwrite('imgs/target.png', img_init)
        self.img_label.setPixmap(QPixmap('imgs/target.png'))
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" 上传垃圾图像 ")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" 识别垃圾种类 ")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)

        label_result = QLabel(' 识 别 结 果 ')
        self.result = QLabel("待识别")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)

        # 关于页面
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用智能垃圾识别系统')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('imgs/logoxx.png'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel()
        label_super.setText("<a href='https://blog.csdn.net/ECHOSON'>我的个人主页</a>")
        label_super.setFont(QFont('楷体', 12))
        label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)
        # git_img = QMovie('images/')
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)
        self.addTab(main_widget, '主页面')
        self.addTab(about_widget, '关于')
        self.setTabIcon(0, QIcon('imgs/面性计算器.png'))
        self.setTabIcon(1, QIcon('imgs/面性本子vg.png'))

    def change_img(self):
        openfile_name = QFileDialog.getOpenFileName(self, '选择文件', '', 'Image files(*.jpg , *.png, *.jpeg)')
        print(openfile_name)
        img_name = openfile_name[0]
        if img_name == '':
            pass
        else:
            self.predict_img_path = img_name
            img_init = cv2.imread(self.predict_img_path)
            img_init = cv2.resize(img_init, (400, 400))
            cv2.imwrite('imgs/target.png', img_init)
            self.img_label.setPixmap(QPixmap('imgs/target.png'))

    def predict_img(self):
        # 预测图片
        # 开始预测
        # img = Image.open()
        transform = transforms.Compose(
            [transforms.Resize([224, 224]),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

        img = Image.open(self.predict_img_path)
        RGB_img = img.convert('RGB')
        img_torch = transform(RGB_img)
        img_torch = img_torch.view(-1, 3, 224, 224)
        outputs = self.net(img_torch)
        _, predicted = torch.max(outputs, 1)
        result = str(names[predicted[0].numpy()])

        self.result.setText(result)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    x = MainWindow()
    x.show()
    sys.exit(app.exec_())

测试

模型训练好之后就可以进行模型的测试了,其中test4dataset.py文件主要是对数据集进行测试,也就是解压之后的test目录下的所有文件进行测试,那么test4singleimg.py文件主要是对单一的图片进行测试。

考虑到大家可能想省去训练的过程,所以我在models目录下放了我训练好的模型,你可以直接使用我训练好的模型进行测试,目前在测试集上的准确率大概在80%左右,不是很高,但是也足够使用。

另外,处理基本的测试之外,还有分类别的测试以及heatmap形式的演示,这部分的代码写的比较乱,暂时放在了abandon目录下,如果项目的star超过100的话,我会再更新这部分的内容。以下就是部分测试的代码;

# from train import load_data
from PIL import ImageFile
import torch
import os
from torchvision import transforms, datasets
import numpy as np
from torch.utils.data import Dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# 有些图片信息不全,不能读取,跳过这些图片
ImageFile.LOAD_TRUNCATED_IMAGES = True
np.set_printoptions(suppress=True)

# todo
def load_test_data(data_dir="E:/遥感目标检测数据集/垃圾分类数据集/trash_real_split"):
    data_transforms = {
        'val': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['val', 'test']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                                  shuffle=True, num_workers=0)
                   for x in ['val', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['val', 'test']}
    class_names = image_datasets['test'].classes
    return dataloaders, dataset_sizes, class_names


def test_test_dataset(model_path="models/mobilenet_trashv1_2.pt"):
    # 加载模型
    net = torch.load(model_path, map_location=lambda storage, loc: storage)
    dataloaders, dataset_sizes, class_names = load_test_data()
    testloader = dataloaders['test']
    test_size = dataset_sizes['test']
    net.to(device)
    net.eval()
    # 测试全部的准确率
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += torch.sum(predicted == labels.data)
    correct = correct.cpu().numpy()
    print('Accuracy of the network on the %d test images: %d %%' % (test_size,
                                                                    100 * correct / total))


def test_test_dataset_by_classes(model_path="models/mobilenet_trashv1_2.pt"):
    # 加载模型
    net = torch.load(model_path, map_location=lambda storage, loc: storage)
    dataloaders, dataset_sizes, class_names = load_test_data()
    testloader = dataloaders['test']
    test_size = dataset_sizes['test']
    net.to(device)
    net.eval()
    classes = class_names
    # 测试每一类的准确率
    class_correct = list(0. for i in range(len(class_names)))
    class_total = list(0. for i in range(len(class_names)))
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    for i in range(len(class_names)):
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))


if __name__ == '__main__':
    print('模型在整个数据集上的表现:')
    test_test_dataset()
    print('模型在每一类上的表现:')
    test_test_dataset_by_classes()

测试结果如下图:

image-20210305140104656

图形化界面

图形化界面主要通过Pyqt5来进行开发,主要是完成一些上传图片,对图片进行识别并把识别结果进行输出的功能,俺的审美不是很好,所以设计的界面可能不是很好看,大家后面可以根据自己的需要修改界面。

image-20210305142518950

image-20210305142537660

image-20210305142602138

完整代码下载地址:基于python pyotrch开发的垃圾分类程序,含数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/145344.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt进度条详解以及format显示格式

进度条的步进值 设置好进度条的最大值和最小值&#xff0c;进度条将会显示完成的步进值占总的步进值的百分比&#xff0c;百分比的计算公式为&#xff1a;百分比 (value() - minimum()) / (maximum() - minimum()) 部分函数含义 QProgressBar&#xff1a;横向或纵向显示进度的…

前端必备:五大css自动化生成网站(稀有级别!)

粉丝朋友们大家好&#xff0c;我是你们的 csdn的博主&#xff1a;lqj_本人 哔哩哔哩&#xff1a;小淼前端 另外&#xff0c;大家也可以关注我的哔哩哔哩账号&#xff0c;我会不定时的发布一些有关于全栈云开发以及前端开发的详解视频源码 1.微信小程序腾讯云开发之学生端收集数…

8.3K Star,这才是我们苦苦寻找的PDF阅读器。。。

程序员宝藏库&#xff1a;https://gitee.com/sharetech_lee/CS-Books-Store 无论是在大学期间&#xff0c;还是工作之后都很难绕开PDF软件。 比如看个论文、课件、演示文档…经常会用到PDF。 大学期间我是一个特别爱折腾各种各样电子产品、数码、软件、操作系统&#xff0c;曾…

囿于数据少?泛化性差?PaddleDetection少样本迁移学习助你一键突围!

目标检测是非常基础和重要的计算机视觉任务&#xff0c;在各行业有非常广泛的应用。然而&#xff0c;在很多领域的实际落地过程中&#xff0c;由于样本稀缺、标注成本高或业务冷启动等困难&#xff0c;难以训练出可靠的模型。 在目标检测这类较为复杂的学习任务上&#xff0c;样…

2023年跨境电商依然是风口,如何做好跨境电商

2023年1月1日&#xff0c;《区域全面经济伙伴关系协定》(RCEP)正式签署生效一周年&#xff0c;(rcep)于2023年1月2日起&#xff0c;RCEP对印度尼西亚正式生效&#xff0c;至此&#xff0c;我国已与其他14个rcep成员中的13个相互实施协定。这预示着&#xff0c;东南亚市场必将成…

下拉控件无法选中

本文迁移自本人网易博客&#xff0c;写于2012年1月9日&#xff0c;二维多段线绘制 - lysygyy的日志 - 网易博客 (163.com)做符号化过程中&#xff0c;一开始发现控件下拉后导致死机&#xff0c;原来是资源切换的问题&#xff0c;使用CAcModuleResourceOverride resOverride;即可…

Cadence PCB仿真使用Allegro PCB SI配置电压地网络电压的方法图文教程

⏪《上一篇》   🏡《总目录》   ⏩《下一篇》 目录 1,概述2,配置方法3,总结1,概述 本文简单介绍使用Allegro PCB SI软件配置电压地网络电压的方法。 2,配置方法 第1步:打开待仿真的PCB文件,并确认软件为Allegro PCB SI 如果,打开软件不是Allegro PCB SI则可这样…

C++连接mysql数据库并读取数据

1、需要包含mysql API的头文件 如果需要连接都本地的mysql数据库&#xff0c;前提是本地要已经安装了mysql数据库。这里要用到一些mysql的API&#xff0c;比如连接数据库、执行查询语句等操作&#xff0c;这些接口都包含在下面的头文件中&#xff1a; #include <mysql/mys…

kubernetes部署nacos集群(防坑)

kubernetes部署nacos集群&#xff08;防坑&#xff09; 官方nacos集群yaml文档参考&#xff1a; https://github.com/nacos-group/nacos-k8s.git 一、nacos 概览 Nacos 致力于帮助您发现、配置和管理微服务。Nacos 提供了一组简单易用的特性集&#xff0c;帮助您快速实现动态…

VUE3中,使用Axios

axios是前后端数据交互的重要桥梁&#xff0c;理论和概念这里不再叙述了。可以看看官网。 axios中文文档|axios中文网 | axios 本例子&#xff0c;从简单到难。 目录 一、简单的使用 二、查询数据时出现等待窗体 一、简单的使用 1.废话少说&#xff0c;先使用HBuilder X建…

【测试】软件测试基本概念

努力经营当下&#xff0c;直至未来明朗&#xff01; 文章目录1. 什么是需求2.什么是测试用例&#xff1f;3. 软件错误Bug的概念:sparkles:小结普通小孩也要热爱生活&#xff01; 1. 什么是需求 【注】一旦提及“区别”&#xff0c;一定要回答 相同点不同点。 在企业中&#x…

算法刷题打卡第57天:合并两个有序链表

合并两个有序链表 难度&#xff1a;简单 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1&#xff1a; 输入&#xff1a;l1 [1,2,4], l2 [1,3,4] 输出&#xff1a;[1,1,2,3,4,4]示例 2&#xff1a; 输入&#x…

liunx中搭建python3.7环境和安装pycharm并搭建Django

首先第一步我们先安装python3.7的环境在liunx中&#xff0c;先去下面这个网站然后找到Gzipped source tarball https://www.python.org/downloads/release/python-377/ 下拉到最底下选择它然后下载 如果python中已经安装就跳过这一步 用python --version 检查后如果已经装好pyt…

【2022年终总结】期待下一个365天

2022结束啦&#xff01;&#xff01;&#xff01;日出万物生&#xff0c;日落满天星&#xff0c;期待下一个365天&#xff01;&#xff01;&#xff01; 不知不觉在CSDN断断续续写文章已经四个月了&#xff0c;回想这段日子&#xff0c;还是很有必要纪念一下的&#xff0c;本期…

ArcGIS基础实验操作100例--实验61数据框投影变换

本实验专栏参考自汤国安教授《地理信息系统基础实验操作100例》一书 实验平台&#xff1a;ArcGIS 10.6 实验数据&#xff1a;请访问实验1&#xff08;传送门&#xff09; 高级编辑篇--实验61 数据框投影变换 目录 一、实验背景 二、实验数据 三、实验步骤 &#xff08;1&am…

JavaScript 模块化 —— 从概念到原理

走过路过发现 bug 请指出&#xff0c;拯救一个辣鸡&#xff08;但很帅&#xff09;的少年就靠您啦&#xff01;&#xff01;&#xff01; 1. 为什么需要 Javascipt 模块化&#xff1f; 1.解决命名冲突。将所有变量都挂载在到全局 global 会引用命名冲突的问题。模块化可以把变…

人工智能与python

人工智能的话题在近几年可谓是相当火热&#xff0c;前几天看快本时其中有一个环节就是关于人工智能的&#xff0c;智能家电、智能机器人、智能工具等等&#xff0c;在我的印象里&#xff0c;提到人工智能就会出现 Python&#xff0c;然后我便在网上查找了相关信息&#xff0c;并…

(第三章)OpenGL超级宝典学习:认识渲染管线

OpGL超级宝典学习&#xff1a;认识渲染管线 前言 本章作为OpenGL学习的第三章节 在本章节我们将认识OpenGL的渲染管线 对管线内各个过程有一个初步的认识 ★提高阅读体验★ &#x1f449; ♠一级标题 &#x1f448; &#x1f449; ♥二级标题 &#x1f448; &#x1…

【KG】TransE 及其实现

原文&#xff1a;https://yubincloud.github.io/notebook/pages/paper/kg/TransE/ TransE 及其实现 1. What is TransE? TransE (Translating Embedding), an energy-based model for learning low-dimensional embeddings of entities. 核心思想&#xff1a;将 relationship …

基于R的Bilibili视频数据建模及分析——建模-因子分析篇

基于R的Bilibili视频数据建模及分析——建模-因子分析篇 文章目录基于R的Bilibili视频数据建模及分析——建模-因子分析篇0、写在前面1、数据分析1.1 建模-因子分析1.2 对数线性模型1.3 主成分分析1.4 因子分析1.5 多维标度法2、参考资料0、写在前面 实验环境 Python版本&#…