基于安卓的虫害识别软件设计--(1)模型训练与可视化

news2025/1/20 1:51:16

引言

  • 简介:使用pytorch框架,从模型训练、模型部署完整地实现了一个基础的图像识别项目
  • 计算资源:使用的是Kaggle(每周免费30h的GPU)

1.创建名为“utils_1”的模块

模块中包含:训练和验证的加载器函数训练函数验证函数

import os
import sys

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_train_loader(image_path):
    train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform = train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,
                                              shuffle=True, num_workers= 0)
    return train_loader

def get_val_loader(image_path):
    val_transform = transforms.Compose([transforms.Resize((224,224)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "validation"),
                                       transform = val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=32,
                                             shuffle = False, num_workers = 0)
    return val_loader

def train(train_loader,net):
    net.train()
    train_correct = 0.0
    train_loss = 0.0  # 初始化训练损失
    train_bar = tqdm(train_loader, file=sys.stdout)
    loss_function = nn.CrossEntropyLoss()
    loss_function = loss_function.to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    for step, data in enumerate(train_bar):
        images, labels = data
        images, labels = images.to(device),labels.to(device)
        # 梯度清零
        optimizer.zero_grad()
        # 训练
        outputs = net(images)
        # 计算损失
        loss = loss_function(outputs, labels)
        # 反向传播
        loss.backward()
        # 更新权重
        optimizer.step()
        # 统计
        _, preds = outputs.max(1)
        correct = preds.eq(labels).sum()
        train_correct += correct
        train_loss += loss.item()  # 累加损失值
        train_bar.desc = 'Training Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(
                loss.item(),
                (100. * correct) / len(outputs),
                trained_samples=step * train_loader.batch_size + len(images),
                total_samples=len(train_loader.dataset))
    train_correct = (100. * train_correct) / len(train_loader.dataset)
    train_loss /= len(train_loader)  # 计算平均损失值
    return train_correct, train_loss  # 返回训练正确率和平均损失值

def val(val_loader,net):
    net.eval()
    val_correct = 0.0
    val_loss = 0.0  # 初始化验证损失
    loss_function = nn.CrossEntropyLoss()
    loss_function = loss_function.to(device)

    val_bar = tqdm(val_loader, file=sys.stdout)
    for step, data in enumerate(val_bar):
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        with torch.no_grad():
            # 验证
            outputs = net(images)
            # 计算损失
            loss = loss_function(outputs, labels)
            # 统计
            _, preds = outputs.max(1)
            correct = preds.eq(labels).sum()
            val_correct += correct
            val_loss += loss.item()  # 累加损失值
            val_bar.desc = 'Valing Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(
                loss.item(),
                (100. * correct) / len(outputs),
                trained_samples=step * val_loader.batch_size + len(images),
                total_samples=len(val_loader.dataset))
    val_correct = (100. * val_correct) / len(val_loader.dataset)
    val_loss /= len(val_loader)  # 计算平均损失值
    return val_correct , val_loss  # 返回验证正确率和平均损失值

注意:若使用Kaggle,想要导入该模块,需要添加以下代码

import sys
sys.path.append(r'/kaggle/input/mycode2')

其中,模块路径如下图


2.主函数 

主函数包含:使用模型函数训练主函数画图代码

2.1使用模型函数 

【若使用其他模型,可chatgpt创建其函数】

(1)resnet101 

def get_resnet101(class_num):
    net_name = "resnet101"
    net = torchvision.models.resnet101(pretrained=True)
    net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input features
    net = net.to(device)
    return net_name, net

(2)resnet34 

def get_resnet34(class_num):
    net_name = "resnet34"
    net = torchvision.models.resnet34(pretrained=True)
    net.fc = Linear(in_features=512, out_features=class_num, bias=True)
    net = net.to(device)
    return net_name,net

(3)mobilenetv2

def get_mobilenet_v2(class_num):
    net_name = "mobilenet_v2"
    net = torchvision.models.mobilenet_v2(pretrained=True)
    net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
    net = net.to(device)
    return net_name,net

 2.2画图代码 

    save_path="/kaggle/working/"  
  
    plt.figure(figsize=(12, 4))
    # loss
    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')
    plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # acc
    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')
    plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Acc')
    plt.legend()
    plt.savefig(os.path.join(save_path, 'result.png')) # 保存
    plt.show()

2.3完整代码 

import torch
import torchvision.models
from matplotlib import pyplot as plt
from torch.nn import Linear
import os

# 导入自己创建的模块
from utils_1 import get_train_loader, train, val, get_val_loader

# 模型选择
def get_resnet101(class_num):
    net_name = "resnet101"
    net = torchvision.models.resnet101(pretrained=True)
    net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input features
    net = net.to(device)
    return net_name, net

# def get_resnet34(class_num):
#     net_name = "resnet34"
#     net = torchvision.models.resnet34(pretrained=True)
#     net.fc = Linear(in_features=512, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net

# def get_mobilenet_v2(class_num):
#     net_name = "mobilenet_v2"
#     net = torchvision.models.mobilenet_v2(pretrained=True)
#     net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net

# 训练主函数
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #1 加载数据
    image_path = r"/kaggle/input/fruits3"
    train_loader = get_train_loader(image_path)
    val_loader = get_val_loader(image_path)
    #2 加载模型
    net_name,net = get_resnet34(class_num=5)
    #3 训练
    epochs = 5
    best_acc = 0
    
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    
    for epoch in range(epochs):
        train_acc,train_loss = train(train_loader, net)
        val_acc,val_loss = val(val_loader, net)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc.item())
        val_accs.append(val_acc.item())
        
        if best_acc<val_acc:
            best_acc = val_acc
            torch.save(net, os.path.join("/kaggle/working/", net_name + ".pt"))
    
    # 画图
    save_path="/kaggle/working/" # 图片保存路径
    
    plt.figure(figsize=(12, 4))
    # loss
    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')
    plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # acc
    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')
    plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Acc')
    plt.legend()
    plt.savefig(os.path.join(save_path, 'result.png')) # 保存
    plt.show()

2.4训练效果与模型文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1719623.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Uniapp发布流程存档

发布成小程序 配置微信小程序的appid 配置小程序的域名 修改静态资源路径为线上路径 发布成H5 配置H5 发行 运行 发布成安卓 基础配置

JS-Lodash工具库

文档&#xff1a;Lodash Documentation orderBy函数&#xff1a;根据条件进行排序 注&#xff1a;第一个是要排序的数组&#xff0c;第二个是根据什么字段进行排序&#xff0c;第三个是排序的方式&#xff08;desc倒序&#xff09; 安装方式&#xff1a;Lodash npm i lodash…

Presto 从提交SQL到获取结果 源码详解(3)

物理执行计划 回到SqlQueryExecution.startExecution() &#xff0c;执行计划划分以后&#xff0c; // 初始化连接&#xff0c;获取Connect 元数据&#xff0c;添加会话&#xff0c;初始ConnectId metadata.beginQuery(getSession(), plan.getConnectors()); // 构建物理执行…

关于MD5

首先还是介绍一下关于md5的基本信息&#xff1a; MD5&#xff08;Message Digest Algorithm 5&#xff09;是一种常用的哈希函数&#xff0c;用于产生128位&#xff08;16字节&#xff09;的哈希值&#xff0c;通常以32个十六进制数字表示。MD5广泛用于计算文件或文本数据的校…

LeetCode-131 分割回文串

LeetCode-131 分割回文串 题目描述解题思路C 代码 题目描述 给你一个字符串 s&#xff0c;请你将 s 分割成一些子串&#xff0c;使每个子串都是 回文串。返回 s 所有可能的分割方案。 示例 1&#xff1a; 输入&#xff1a;s “aab” 输出&#xff1a;[[“a”,“a”,“b”],…

Zynq学习笔记--AXI4-Stream 图像数据从仿真输出到图像文件

目录 1. 简介 2. 构建工程 2.1 Vivado 工程 2.2 TestBench 代码 2.3 关键代码分析 3. VPG Background Pattern ID (0x0020) Register 4. 总结 1. 简介 使用 SystemVerilog 将 AXI4-Stream 图像数据从仿真输出到图像文件 (PPM)。 用到的函数包括 $fopen、$fwrite 和 $f…

vmware esxi虚拟化数据迁移

1、启用esxi的ssh 登录esxi的web界面&#xff0c;选择主机-》操作——》服务——》启动ssh 2.xshell登录esxi 3、找到虚拟机所在目录 blog.csdnimg.cn/direct/d57372536a4145f2bcc1189d02cc7da8.png)#### 3在传输数据前需关闭防火墙服务 查看防火墙状态&#xff1a;esxcli …

Android MediaCodec 简明教程(九):使用 MediaCodec 解码到纹理,使用 OpenGL ES 进行处理,并编码为 MP4 文件

系列文章目录 Android MediaCodec 简明教程&#xff08;一&#xff09;&#xff1a;使用 MediaCodecList 查询 Codec 信息&#xff0c;并创建 MediaCodec 编解码器Android MediaCodec 简明教程&#xff08;二&#xff09;&#xff1a;使用 MediaCodecInfo.CodecCapabilities 查…

【传知代码】双深度学习模型实现结直肠癌检测(论文复现)

前言&#xff1a;在医学领域&#xff0c;科技的进步一直是改变人类生活的关键驱动力之一。随着深度学习技术的不断发展&#xff0c;其在医学影像诊断领域的应用正日益受到关注。结直肠癌是一种常见但危害极大的恶性肿瘤&#xff0c;在早期发现和及时治疗方面具有重要意义。然而…

【VSCode】快捷方式log去掉分号

文章目录 一、引入二、解决办法 一、引入 我们使用 log 快速生成的 console.log() 都是带分号的 但是我们的编程习惯都是不带分号&#xff0c;每次自动生成后还需要手动删掉分号&#xff0c;太麻烦了&#xff01; 那有没有办法能够生成的时候就不带分号呢&#xff1f;自然是有…

C++ 特殊运算符

一 赋值运算符 二 等号作用 三 优先级和结合顺序 四 左值和右值 五 字节数运算符 条件运算符 使用条件运算符注意 逗号运算符 优先级和结合顺序 总结

如何修改开源项目中发现的bug?

如何修改开源项目中发现的bug&#xff1f; 目录 如何修改开源项目中发现的bug&#xff1f;第一步&#xff1a;找到开源项目并建立分支第二步&#xff1a;克隆分支到本地仓库第三步&#xff1a;在本地对项目进行修改第四步&#xff1a;依次使用命令行进行操作注意&#xff1a;Gi…

平衡二叉树的应用举例

AVL 是一种自平衡二叉搜索树&#xff0c;其中任何节点的左右子树的高度之差不能超过 1。 AVL树的特点&#xff1a; 1、它遵循二叉搜索树的一般属性。 2、树的每个子树都是平衡的&#xff0c;即左右子树的高度之差最多为1。 3、当插入新节点时&#xff0c;树会自我平衡。因此…

生信服务器配置选择说明

阿小云整理关于生信云服务器的配置选择攻略&#xff0c;生物信息服务器需要强大的计算能力和大容量存储&#xff0c;超高计算能力可以应对生物数据分析计算&#xff0c;如大规模基因序列比对等&#xff0c;大容量存储可以用来存储各种基因组、蛋白质组等数据。 生信服务器配置选…

Superset二次开发之更新 SECRET_KEY

SECRET_KEY 的作用 加密和签名:SECRET_KEY用于对敏感数据(如会话、cookie、CSRF令牌)进行加密和签名,防止数据被篡改。安全性:确保应用的安全性,防止跨站请求伪造(CSRF)攻击和会话劫持等安全问题。如何生成 SECRET_KEY openssl rand -base64 42 配置 SECRET_KEY 在sup…

VisualSVN Server/TortoiseSVN更改端口号

文章目录 概述VisualSVN Server端更改端口号TortoiseSVN客户端更改远程仓库地址 概述 Subversion&#xff08;SVN&#xff09;是常用的版本管理系统之一。部署在服务器上的SVN Server端通常会在端口号80&#xff0c;或者端口号443上提供服务。其中80是HTTP访问方式的默认端口。…

虚拟现实环境下的远程教育和智能评估系统(三)

本周继续进行开发工具的选择与学习&#xff0c;基本了解了以下技术栈的部署应用&#xff1b; 一、Seata&#xff1a; Seata&#xff08;Simple Extensible Autonomous Transaction Architecture&#xff09;是一款开源的分布式事务解决方案&#xff0c;旨在提供高性能和简单易…

创新实训2024.05.29日志:评测数据集与baseline测试

1. 评测工作 在大模型微调和RAG工作都在进行的同时&#xff0c;我们搭建了一套评测数据集。这套数据集有山东大学周易研究中心背书。主要考察大模型对于易学基本概念与常识的理解与掌握能力。 1.1. 构建评测集 在周易研究中心的指导下&#xff0c;我们构建出了一套用以考察大…

【并查集】专题练习

题目列表 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 模板 836. 合并集合 - AcWing题库 #include<bits/stdc.h> using lllong long; //#define int ll const int N1e510,mod1e97; int n,m; int p[N],sz[N]; int find(int a) {if(p[a]!a) p[a]find(p[a]);return p[a…

数据结构:希尔排序

文章目录 前言一、排序的概念及其运用二、常见排序算法的实现 1.插入排序2.希尔排序总结 前言 排序在生活中有许多实际的运用。以下是一些例子&#xff1a; 购物清单&#xff1a;当我们去超市购物时&#xff0c;通常会列出一份购物清单。将购物清单按照需要购买的顺序排序&…