pytorch实现水果2分类(蓝莓,苹果)

news2024/9/22 5:31:36

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.path

import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class DtataAndLabel(Dataset):
    def __init__(self,path='fruits',is_train=True):
        self.tran=transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size=(88,88))
        ])
        is_train='train' if True else 'test'
        self.data=[]
        path=os.path.join(path,is_train)
        print('path=',path)
        print(os.path.join(path, '*', '*'))
        img_paths=glob.glob(os.path.join(path,'*','*'))
        for img_path in img_paths:
            label=0 if img_path.split('\\')[-2]=='blueberry' else 1
            self.data.append((img_path,label))
    def __getitem__(self, idx):
        #每一张图片返回一个img_tensor,one_hot
        img_path,label =self.data[idx]
        img=cv2.imread(img_path)
        # img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
        img_tensor=self.tran(img)
        img_tensor=img_tensor/255
        img_tensor=torch.flatten(img_tensor)
        one_hot=torch.zeros(2)
        one_hot[label]=1
        return img_tensor,one_hot
    def __len__(self):
        return len(self.data)

if __name__ == '__main__':
    # 测试
    data=DtataAndLabel()
    print(data[1][0].shape)
    print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(88*88*3, 800),
            nn.ReLU(),
            nn.Linear(800, 500),
            nn.ReLU(),
            nn.Linear(500, 800),
            nn.ReLU(),
            nn.Linear(800, 200),
            nn.ReLU(),
            nn.Linear(200, 2),

        )
        self.softmax=nn.Softmax(dim=1)
    def forward(self,x):
        x=self.model(x)
        x=self.softmax(x)
        return x
if __name__ == '__main__':
    net=Net()
    #测试一下
    x=torch.randn(1,100*100)
    out=net(x)
    print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():
    def __init__(self):
        self.writer = SummaryWriter("logs")
        self.train_data=DtataAndLabel(is_train=True)
        self.test_data=DtataAndLabel(is_train=False)
        #使用加载器分批加载
        self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)
        self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)
        #初始化网络
        #损失函数
        #优化器

        net=Net()
        self.net=net
        self.loss=nn.MSELoss()
        self.opt=torch.optim.Adam(net.parameters(),lr=0.001)
        self.min_loss=100.0
        self.weight_path='weight/best.pt'

    def train(self,epoch):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.train_loader)
        avg_acc = sum_acc / len(self.train_loader)
        print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
        self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)
        self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)

    def test(self,epoch):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)
            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.test_loader)
        avg_acc = sum_acc / len(self.test_loader)
        print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
        self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)
        self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)
        if avg_loss<self.min_loss:
            self.min_loss=min(self.min_loss,avg_loss)
            torch.save(self.net.state_dict(), self.weight_path)
    def run(self):
        for epo in range(100):
            self.train(epo)
            self.test(epo)

if __name__ == '__main__':
    trainer=TrainAndTest()
    trainer.run()



精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1916061.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何分辨AI生成的内容?AI生成内容检测工具对比实验

检测人工智能生成的文本对各个领域的组织都提出了挑战&#xff0c;包括学术界和新闻界等。生成式AI与大语言模型根据短描述来进行内容生成的能力&#xff0c;产生了一个问题&#xff1a;这篇文章/内容/作业/图像到底是由人类创作的&#xff0c;还是AI创作的&#xff1f;虽然 LL…

【数据库】Redis主从复制、哨兵模式、集群

目录 一、Redis的主从复制 1.1 主从复制的架构 1.2 主从复制的作用 1.3 注意事项 1.4 主从复制用到的命令 1.5 主从复制流程 1.6 主从复制实现 1.7 结束主从复制 1.8 主从复制优化配置 二、哨兵模式 2.1 哨兵模式原理 2.2 哨兵的三个定时任务 2.3 哨兵的结构 2.4 哨…

MT3047 区间最大值

思路&#xff1a; 使用哈希表map和set&#xff08;去重&#xff09;维护序列 代码&#xff1a; #include <bits/stdc.h> using namespace std; const int N 1e5 10; int n, k, A[N]; map<int, int> mp; // 元素出现的次数 set<int> s; // 维护出现…

Android平台GB28181记录仪在电网巡检抢修中的应用和技术实现

技术背景 在探讨Android平台GB28181设备接入端在电网巡检抢修优势之前&#xff0c;我们已经在执法记录仪、智能安全帽、智能监控、智慧零售、智慧教育、远程办公、明厨亮灶、智慧交通、智慧工地、雪亮工程、平安乡村、生产运输、车载终端等场景有了丰富的经验积累&#xff0c;…

linux创建定时任务

crontab方式 先查看是否有cron systemctl status crond 没有的话就安装 yum install cronie 打开你的crontab文件进行编辑。使用以下命令打开当前用户的crontab文件&#xff1a; crontab -e * * * * * /export/test.sh >> /export/test.log 2>&1/export/test.s…

什么是量化机器人?它能来作些什么?一篇文章带你了解!

在科技日新月异的今天&#xff0c;我们经常会听到一些听起来高大上的词汇&#xff0c;比如“人工智能”、“大数据”和“量化交易”。而在这其中&#xff0c;“量化机器人”更是一个让人既好奇又略感神秘的存在。今天&#xff0c;我们就用通俗易懂的语言&#xff0c;一起来揭开…

通知notification

通知 权限&#xff1a;manifest.xml&#xff0c;可以不提前写&#xff0c;后面写代码时显示缺少点击添加即可。 <uses-permission android:name"android.permission.VIBRATE"/>//振动权限 <uses-permission android:name"android.permission.POST_NOT…

洛谷 7.10 数数

Vanya and Books - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) ac代码 #include<bits/stdc.h> typedef long long ll;#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0) const ll N1e3; using namespace std;int main() {IOS;ll x;cin>>x;ll ans0,px…

阿一课代表随堂分享:红队反向代理之使用frp搭建反向代理

frp反向代理 frp简介 frp 是一个开源、简洁易用、高性能的内网穿透和反向代理软件&#xff0c;支持 tcp, udp, http, https等协议。 frp 是一个可用于内网穿透的高性能的反向代理应用&#xff0c;分为服务端frps和客户端frpc&#xff0c;支持 tcp, udp, http, https 协议。详…

明白这两大关键点,轻松脱单不再是难题!

很多未婚男女都渴望找到心仪的伴侣&#xff0c;建立稳定的情感关系&#xff0c;但往往在脱单的过程中跌跌撞撞。平时与同学、同事之间相处得很融洽&#xff0c;一旦遇到心仪的异性&#xff0c;情商直接掉线&#xff0c;难道情商也会选择性地发挥作用吗&#xff1f;其实&#xf…

怎么将3张照片合并成一张?这几种拼接方法很实用!

怎么将3张照片合并成一张&#xff1f;在我们丰富多彩的日常生活里&#xff0c;是否总爱捕捉那些稍纵即逝的美好瞬间&#xff0c;将它们定格为一张张珍贵的图片&#xff1f;然而&#xff0c;随着时间的推移&#xff0c;这些满载回忆的宝藏却可能逐渐演变成一项管理挑战&#xff…

wifi模组Ai-M62-32S的IO映射和UDP透传测试

wifi模组Ai-M62-32S的IO映射和UDP透传测试 基本IO 映射配网示例开启UDP透传示例复位AT查询wifi是否在线配置DHCP静态IP连接wifi连接UDP开启透传 基本IO 映射 对于wifi模组Ai-62-32S来说其模组 IO 引脚&#xff08;从模组左上角逆时针排序&#xff0c;引脚序号从 1 开始&#x…

10个JavaScript One-Liners让初学者看起来很专业

原文链接&#xff1a;https://pinjarirehan.medium.com/10-javascript-one-liners-for-beginner-developers-to-look-pro-b9548353330a 原文作者&#xff1a;Rehan Pinjari 翻译&#xff1a;小圆 你是不是在辛苦码字时&#xff0c;看到别人轻松甩出一行 JavaScript 就搞定难题…

GitLab和Git

GitLab保姆级教程 文章目录 GitLab保姆级教程一、GitLab安装二、添加组和用户三、新增项目四、Git上传项目说明五、命令行指引 根据以下说明从计算机中上传现有文件&#xff1a;六、创建与合并分支七、GitLab回滚到特定版本八、数据备份与恢复九、docker中创建gitlab GIT 常用命…

maven 依赖冲突

依赖冲突 1、对于 Maven 而言&#xff0c;同一个 groupId 同一个 artifactId 下&#xff0c;只能使用一个 version。 <!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 --><dependency><groupId>org.apache.commons</groupId&…

MVC 生成验证码

在mvc 出现之前 生成验证码思路 在一个html页面上&#xff0c;生成一个验证码&#xff0c;在把这个页面嵌入到需要验证码的页面中。 JS生成验证码 <script type"text/javascript">jQuery(function ($) {/**生成一个随机数**/function randomNum(min, max) {…

从0-1搭建一个web项目(路由目录分析)详解

本章分析vue路由目录文件详解 ObJack-Admin一款基于 Vue3.3、TypeScript、Vite3、Pinia、Element-Plus 开源的后台管理框架。在一定程度上节省您的开发效率。另外本项目还封装了一些常用组件、hooks、指令、动态路由、按钮级别权限控制等功能。感兴趣的小伙伴可以访问源码点个赞…

【计算机网络03】不花钱怎么搭建一个网络实验室

使用GNS3和虚拟机搭建网络实验室 1、安装抓包工具分析数据包2、定义和使用抓包筛选器3、安装和配置GNS34、配置路由器和VPCS5、使用WireShark捕获GNS3网络数据包6、VMware创建虚拟机7、使用思科PacketTracer 1、安装抓包工具分析数据包 官网安装wireshark&#xff1a;https://…

前端面试题26(vue3中响应式实现原理)

Vue 3 中响应式系统的实现主要依赖于 ES6 的 Proxy 对象&#xff0c;这与 Vue 2 中使用 Object.defineProperty 的方式有着本质的区别。Proxy 提供了一种更为强大且灵活的方法来拦截和定制对象的操作&#xff0c;例如获取、设置属性值等。下面是对 Vue 3 响应式系统实现方式的详…

鸿蒙语言基础类库:【@ohos.util.TreeSet (非线性容器TreeSet)】

非线性容器TreeSet 说明&#xff1a; 本模块首批接口从API version 8开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。开发前请熟悉鸿蒙开发指导文档&#xff1a;gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 T…