pytorch利用简单CNN实现葡萄病虫害图片识别

news2024/9/22 17:31:55

1 前言

之前我开发了一个葡萄病虫害的可视化系统,最近就想给这个系统增加2个功能,一个是对接一个AI助手,可以进行葡萄病虫害的咨询,直接对接千问大模型,这个在之前的博文里已经介绍过对接方法了,第二个是做一个根据图片识别病虫害(分类)的功能。

2 实现思路

实现思路是想通过pytorch做一个CNN模型的训练,然后根据给出的图片进行类型的预测。

3 数据集

我没有数据集,仅有的一些图片是之前委托我做程序的bro给的,所以我们训练的时候图片并不多,不过这个没关系,数据集可以后期扩充,目前先实现功能部分

4 安装依赖

该功能由python语言实现,使用pip 安装如下依赖

torch
torchvision
matplotlib

5 数据位置

在这里插入图片描述
数据类似这样去组织,一种类型建一个文件夹,然后同一类型的图片放一起。

6 训练模型

import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 调整图片大小
    transforms.ToTensor(),            # 转换为 Tensor
])

# 加载数据集
data_dir = 'dataset'
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# 获取类别标签
class_names = dataset.classes
num_classes = len(class_names)

# 构建简单的 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)  # 128 = (128/2)*(128/2)*(32/2)*(32/2)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 32 * 32)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型
model = SimpleCNN(num_classes)

# 训练配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(torch.cuda.is_available())
print(f'Using device: {device}')

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练循环
num_epochs = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}")

print("Training finished.")

# 保存模型
torch.save(model.state_dict(), 'plant_disease_model.pth')

执行代码之后得到模型文件:
在这里插入图片描述

7 预测模型

然后我们随便去找些病虫害图片,来做预测

import torch
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import torch.nn.functional as F

# 定义简单的 CNN 模型结构
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 32 * 32)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 预测函数
def predict(image_path, model, class_names):
    # 定义图像预处理
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # 统一大小
        transforms.ToTensor(),
    ])

    # 加载和预处理图像
    image = Image.open(image_path)
    image = transform(image).unsqueeze(0)  # 增加批次维度

    # 将图像输入模型进行预测
    model.eval()  # 设置模型为评估模式
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)

    # 返回预测的类别
    return class_names[predicted.item()]

if __name__ == "__main__":
    # 加载训练好的模型
    num_classes = 2  # 根据你的数据集类别数量修改
    model = SimpleCNN(num_classes)
    model.load_state_dict(torch.load('plant_disease_model.pth'))
    model.eval()

    # 类别名称(根据你的数据集修改)
    class_names = ['disease1', 'disease2']  # 替换为实际类别名称

    # 测试预测
    test_image_path = '1.jpg'  # 替换为测试图像的路径
    predicted_class = predict(test_image_path, model, class_names)
    print(f'Predicted class: {predicted_class}')

8 结果

给出的图片和图片预测结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2102713.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ChatGPT与R语言融合技术在生态环境数据统计分析、绘图、模型中的实践与进阶应用

自2022年GPT(Generative Pre-trained Transformer)大语言模型的发布以来,它以其卓越的自然语言处理能力和广泛的应用潜力,在学术界和工业界掀起了一场革命。在短短一年多的时间里,GPT已经在多个领域展现出其独特的价值…

【JavaWeb】JDBCDruidTomcat入门使用

本章使用技术版本: Tomcatv10.1.25 关于javaweb相关的其他技术,比如tomcat和maven,在我的主页记录了笔记,ajax我用的是本地笔记以后再考虑上传,前端三板斧我用的菜鸟教程文档 JDBC 初识 JDBC概念 JDBC 就是使用Jav…

MatLab基础学习01

MatLab基础学习01 1.基础入门2.MatLab的数据类型2.1数字2.2字符串2.3矩阵2.4.元胞数组2.5结构体 3.MatLab的矩阵的操作3.1矩阵定义与构造3.2矩阵的下标取值 4.MatLab的逻辑流程4. For循环结构4.2 While循环,当条件成立的时候进行循环4.3 IF end 1.基础入门 matlba必…

1.3 SQL注入之MYSQL系统库

一.系统库释义 提供了访问数据库元数据的方式 元数据是关于数据库的数据,如数据库名和表名,列的数据类型或访问权限。 1.information_schema 库:是信息数据库,其中保存着关于MySQL服务器所维护的所有其他数据库的信息&#xff1…

公园智能厕所引导大屏,清楚显示厕位有无人状态

在科技飞速发展的今天,公园的设施也在不断与时俱进。其中,公园智能厕所引导大屏的出现,为游客带来了全新的如厕体验。 走进公园的智能厕所区域,首先映入眼帘的便是那醒目的引导大屏。屏幕上清晰地显示着各个厕位的有无人状态&…

如何使用Cheerio与jsdom解析复杂的HTML结构进行数据提取

背景介绍 在现代网页开发中,HTML结构往往非常复杂,包含大量嵌套的标签和动态内容。这给爬虫技术带来了不小的挑战,尤其是在需要精确提取特定数据的场景下。传统的解析库可能无法有效处理这些复杂的结构,而JavaScript环境下的Chee…

机器学习(五) -- 监督学习(8) --神经网络2

机器学习系列文章目录及序言深度学习系列文章目录及序言 上篇:机器学习(五) -- 监督学习(8) --神经网络1 下篇: 前言 tips:标题前有“***”的内容为补充内容,是给好奇心重的宝宝看…

Fast Vision Transformers with HiLo Attention

总结 提出了 HiLo Attention 机制: 该机制将自注意力层分为两部分:Hi-Fi(高频注意力) 和 Lo-Fi(低频注意力)。Hi-Fi 捕捉局部细节,通过在局部窗口内应用自注意力,减少了计算复杂度…

图文解析保姆级教程:Tomcat下载、安装、卸载、启动、关闭,解决窗口闪退问题、端口号冲突问题

文章目录 1. 下载2. 安装与卸载3. 启动与关闭4. 常见问题问题1:Tomcat启动时,窗口一闪而过问题2:端口号冲突(Tomcat使用的端口被占用) 此教程摘选自我的笔记:黑马JavaWeb开发笔记14——Tomcat(介…

linux进程处理

1.测试这样没意义,要向后加 wait等待进程结束 1. 2.测试 发送异常结束的信号,通过kill 二、子进程的回收 对于子进程的结束而言,都希望父进程能够知道并作出一定的反应,通过 wait、waitpid 函数可以知道子进程是如何结束的…

人工智能训练师边缘计算实训室解决方案

一、引言 随着物联网(IoT)、大数据、人工智能(AI)等技术的飞速发展,计算需求日益复杂和多样化。传统的云计算模式虽在一定程度上满足了这些需求,但在处理海量数据、保障实时性与安全性、提升计算效率等方面…

使用VM创建centos7环境

1、安装VMware Workstation 1.1安装VMware Workstation pro 16 修改自己的安装位置 一直下一步到 1.2激活VMware Workstation pro 16 点击许可证 解压这个压缩包,密码是ai95 之后找到下面文件打开 将生成的许可证码输入到安装VMware Workstation pro 16完成安…

gitk无法打开

1、电脑重装,重新安装git工具后,发现无法打开现有的仓库,报错如下: 搜索网上的信息,显示是目录下没有.git文件夹,但是在xshell查看文件夹是存在的。 然后进行测试git log指令发现也无法进行显示。 然后按…

OJ习题 篇2

🚀个人主页:奋斗的小羊 🚀所属专栏:C 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 💥1、删除有序数组中的重复项💥2、数组中出现次数超过一半的数字💥3、最…

python 怎么样反向输出字符串

python如何反向输出字符串?下面给大家介绍两种方法: 方法一:采用列表reversed函数 class Solution(object):def reverse_string(self, s):if len(s) > 1:reversed_s .join(reversed(s))return reversed_s return s 方法二:采…

爬取图片保存为pdf

本文章想借着爬虫给大家介绍一下图片转pdf,有需要的友友们可以看看参考参考,有帮助到友友的可以收藏+关注。下面以爬取初中7年级数学上册为例给大家演示一下。网址是这个 https://mp.weixin.qq.com/s?__bizMzAxOTE4NjI1Mw&mid2650214000&idx…

10-1 注意力提示

感谢读者对本书的关注,因为读者的注意力是一种稀缺的资源: 此刻读者正在阅读本书(而忽略了其他的书), 因此读者的注意力是用机会成本(与金钱类似)来支付的。 为了确保读者现在投入的注意力是值得…

巨魔商店2.1正式更新,最高支持iOS17.6.1

巨魔商店2.1,来了 不得不说,我此刻的心情,确实有点振奋。一天之内,巨魔连续传来2个大动作。 一个是iOS17.0有了刷巨魔的正式方法,iPhone 15点燃了巨魔的火种。 另一个就是巨魔商店的开发者opa334,突然发…

Postgresql碎片整理

创建pgstattuple 扩展 CREATE EXTENSION pgstattuple 获取表的元组(行)信息,包括空闲空间的比例和行的平均宽度 SELECT * FROM pgstattuple(表名); 查看表和索引大小 SELECT pg_relation_size(表名), pg_relation_size(索引名称); 清理碎片方…

【鸿蒙HarmonyOS NEXT】List组件的使用

【鸿蒙HarmonyOS NEXT】List组件的使用 一、环境说明二、List组件及其使用三、示例代码如下 一、环境说明 DevEco Studio 版本:DevEco Studio NEXT Developer Beta5 Build #DS-233.14475.28.36.503700 Build Version: 5.0.3.700, built on August 19, 2024 Runtime…