0 简单的图像分类

news2024/11/27 5:29:16

本文主要针对交通标识图片进行分类,包含62类,这个就是当前科大讯飞比赛,目前准确率在0.94左右,难点如下:

1 类别不均衡,有得种类图片2百多,有个只有10个不到;

2 像素大小不同,导致有的图片很清晰,有的很模糊;

直接上代码:

import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split

from torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings
import pandas as pd
from torch.utils.data import random_split

warnings.filterwarnings("ignore")

# 检测能否使用GPU
print(#labels
torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
)

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
n_classes = 62  # 几种分类的
preteain = False  # 是否下载使用训练参数 有网true 没网false
epoches = 10  # 训练的轮次
traindataset = datasets.ImageFolder(root='../all/data/train_set/', transform=transforms.Compose([
    transforms.Resize((224,224)),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
]))


# 分割比例:比如80%的数据用于训练,20%用于验证
train_val_ratio = 0.8
train_size = int(len(traindataset) * train_val_ratio)
val_size = len(traindataset) - train_size
train_dataset, val_dataset = random_split(traindataset, [train_size, val_size])


classes = traindataset.classes
print(classes)
 
model = models.resnext50_32x4d(pretrained=preteain)
#model = models.resnet34(pretrained=preteain)

if preteain == True:
    for param in model.parameters():
        param.requires_grad = False
        
model.fc = nn.Linear(in_features=2048, out_features=n_classes, bias=True)
model = model.to(device)
 
 
def train_model(model, train_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = outputs.argmax(dim=1)
        total_corrects += torch.sum(preds.eq(labels))
        total_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
    total_loss = total_loss / total
    acc = 100 * total_corrects / total
    print("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, total_loss, acc))
    return total_loss, acc
 
 
def test_model(model, test_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            total_loss += loss.item() * inputs.size(0)
            total_corrects += torch.sum(preds.eq(labels))
 
        loss = total_loss / total
        accuracy = 100 * total_corrects / total
        print("轮次:%4d|测试集损失:%.5f|测试集准确率:%6.2f%%" % (epoch + 1, loss, accuracy))
        return loss, accuracy
 
 
loss_fn = nn.CrossEntropyLoss().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
for epoch in range(0, epoches):
    loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)
    loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)

模型预测:

sub = pd.read_csv("../all/data/example.csv")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model.eval()
for path in os.listdir("../all/data/test_set/"):
    try:
        img = Image.open("../all/data/test_set/"+path)
        img_p = transform(img).unsqueeze(0).to(device)
        output = model(img_p)
        pred = output.argmax(dim=1).item()
        if img.size[0] * img.size[1]<2000:
            plt.imshow(img)
            plt.show()
        p = 100 * nn.Softmax(dim=1)(output).detach().cpu().numpy()[0]
        sub.loc[sub['ImageID'] == path,'label'] = classes[pred]
        print(f'{path} size = {img.size}, 该图像预测类别为:', classes[pred])
    except:
        print(f'error {path}')
sub.loc[sub['ImageID']=='e57471de-6527-4b9b-90a8-4f1d93909216.png','label'] = 'Under Construction'
sub.loc[sub['ImageID']=='ff38d59e-9a11-41e4-901b-67097bb0e960.png','label'] = 'Keep Left'
sub.columns = ['ImageID','Sign Name']
label_map = pd.read_excel("../all/data/label_map.xlsx")
sub_all = pd.merge(left=sub,right=label_map,on='Sign Name',how='left')
#sub_all[['ImageID','label']].to_csv('./sub_resnet34_add_img_ratio_drop_dire.csv',index=False)

个人的心得:

1 如何进行图片增强,图片增强应该注意什么(方向问题);总结一些transforms,数据增强的方式 - 代码天地

2 模型大小如何进行选择;

更新:

6.17: 昨天我尝试使用更多的增强技术,直接上到0.98;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1845325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【论文笔记】Prefix-Tuning: Optimizing Continuous Prompts for Generation

题目:Prefix-Tuning: Optimizing Continuous Prompts for Generation 来源: ACL 2021 模型名称: Prefix-Tuning 论文链接: https://aclanthology.org/2021.acl-long.353/ 项目链接: https://github.com/XiangLi1999/PrefixTuning 感觉与prompt的想法很相近&#xff0c;那么问题…

php基础语法_面向对象

PHP php代码标记 多种标记来区分php脚本 ASP标记&#xff1a;<% php代码 %> 短标记&#xff1a; 脚本标记: 标准标记&#xff08;常用&#xff09;&#xff1a; 简写风格&#xff1a; ASP风格&#xff1a;<% php代码 %> 注意&#xff1a;简写风格和ASP风格…

【PHP项目实战训练】——使用thinkphp框架对数据进行增删改查功能

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;开发者-曼亿点 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 曼亿点 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a…

【Effective Web】常见的css布局方式--三栏布局

常见的css居中方式–三栏布局 第一种实现&#xff1a;table布局&#xff08;不推荐&#xff09; 缺点&#xff1a;在table加载前&#xff0c;整个table都是空白的&#xff0c;且修改布局排版都十分困难 <table class"container"><td class"left"…

《广州化工》是什么级别的期刊?是正规期刊吗?能评职称吗?

​问题解答 问&#xff1a;《广州化工》是不是核心期刊&#xff1f; 答&#xff1a;不是&#xff0c;是知网收录的正规学术期刊 问&#xff1a;《广州化工》级别&#xff1f; 答&#xff1a;省级。主办单位&#xff1a;广州化工集团有限公司 主管单位&#xff1a;广州化工…

学生护眼大路灯应该怎么选?五款护眼大路灯对比推荐

我们都知道光线无处不在&#xff0c;想要减少近视隐患&#xff0c;就不得不提一下护眼灯了&#xff0c;特别是经常坐在电脑前码字的上班族以及深夜还在学习的学生党这一类人群&#xff0c;经常用眼光线不好不仅影响视力健康&#xff0c;还会影响效率。而一款护眼灯能够提供柔和…

环境配置02:CUDA安装

1. CUDA安装 Nvidia官网下载对应版本CUDA Toolkit CUDA Toolkit 12.1 Downloads | NVIDIA Developer CUDA Toolkit 12.5 Downloads | NVIDIA Developer 安装配置步骤参考&#xff1a;配置显卡cuda与配置pytorch - 知乎 (zhihu.com) 2. 根据CUDA版本&#xff0c;安装cudnn …

【node】启动本地打包文件的方式

前言 … 目标 1 初始化node文件 2 将打包文件通过node发布到本地 3 系列文件 【node】创建本地接口 一 node方式 1 在新建一个空的文件夹node 进入空文件夹在,文件夹的地址栏输入cmd回车,会自动跳转到命令行工具里 2 配置初始化文件 在命令行输入命令npm init,生成pac…

git 上拉下来的新项目web文件夹没有被idea管理,导致启动不了

让idea识别web项目&#xff0c;操作步骤&#xff1a; 1. 打开idea -- 文件 -- 项目结构&#xff1b; 2. 选择 模块 --- 添加 --- web -- 应用 --- 确定&#xff0c;就好了。 3. 文件夹中间出现个圆圈就是被识别到了。

vscode插件开发之 - TestController

TesController概要介绍 TestController 组件是用于实现自定义测试框架和集成测试结果的。它允许开发者定义自己的测试运行器&#xff0c;以支持在VSCode中运行和展示测试。以下是一些使用 TestController 组件的主要场景&#xff1a; 自定义测试框架&#xff1a;如果你正在开发…

如何通过自定义模块DIY出专属个性化的CSDN主页?一招教你搞定!

个人主页&#xff1a;学习前端的小z 个人专栏&#xff1a;HTML5和CSS3悦读 本专栏旨在分享记录每日学习的前端知识和学习笔记的归纳总结&#xff0c;欢迎大家在评论区交流讨论&#xff01; 文章目录 &#x1f4af;如何通过HTMLCSS自定义模板diy出自己的个性化csdn主页&#x…

学分制系统 GetCalendarContentById SQL注入致RCE漏洞复现

0x01 产品简介 学分制系统由上海鹏达计算机系统开发有限公司研发,是基于对职业教育特点和需求的深入理解,结合教育部相关文件精神,并广泛吸纳专家、学者意见而开发的一款综合性管理系统。系统采用模块化的设计方法,方便学校根据自身教学改革特点、信息化建设进程情况选择、…

使用 Java 构建和消费 RESTful 服务的基本方法

REST&#xff08;Representational State Transfer&#xff09;是一种架构风格&#xff0c;它基于Web标准和HTTP协议&#xff0c;常用于构建网络服务。使用Java构建和消费RESTful服务需要掌握一些基本概念和技术。 一、RESTful服务的基本概念 1. REST架构风格 REST架构风格的…

Linux系统资源监控nmon工具下载及使用介绍

一、资源下载 夸克网盘链接&#xff1a;https://pan.quark.cn/s/2684089bc34d 里面包含了各种分享的实用工具&#xff0c;nmon在 Linux服务器监控nmon工具 文件夹内 文件说明&#xff1a; nmon16p_binaries.tar.gz 为最新的nmon官方工具包&#xff0c;支持linux全平台 nmo…

酷开科技丨引领家庭娱乐新潮流,酷开系统带你开启多彩生活新篇章

在繁忙的都市生活节奏中&#xff0c;人们对生活品质的追求从未停歇。家庭娱乐作为提升生活质量的重要部分&#xff0c;随着科技进步和个性化需求的增长&#xff0c;已经发生了翻天覆地的变化。多样化的娱乐方式不仅为家庭生活增添了色彩&#xff0c;也为家庭成员间的相聚带来了…

selenium框架学习

概念 WEB自动化框架 三大组件: selenium IDE 浏览器插件,实现脚本录制WebDriver 实现对浏览器的各种操作(API包)Grid 实现同时对多个用例进行执行,用例在多个浏览器同步执行环境搭建 1、安装selenium: pip install selenium2、安装浏览器 3、安装浏览器驱动(对应的驱…

戴尔外星人原厂系统美版改国行正确识别本机SN,支持F12 Support Assist OS Recevory恢复重置识别SN服务编码

1.重新部署可以永久正确识别My Alienware和Support Assist服务编码 原厂系统远程恢复安装&#xff1a;https://pan.baidu.com/s/166gtt2okmMmuPUL1Fo3Gpg?pwdm64f 提取码:m64f 2.安装有两个软件和官网主页会识别原机的SN码&#xff0c;就是本机服务编码&#xff08;my Alie…

【开发工具】git服务器端安装部署+客户端配置

自己安装一个轻量级的git服务端&#xff0c;仅仅作为代码维护&#xff0c;尤其适合个人代码管理。毕竟代码的版本管理是很有必要的。 这里把git服务端部署在centos系统里&#xff0c;部署完成后可以通过命令行推拉代码&#xff0c;进行版本和用户管理。 一、服务端安装配置 …

开源低代码平台,JeecgBoot v3.7.0 里程碑版本发布

项目介绍 JeecgBoot是一款企业级的低代码平台&#xff01;前后端分离架构 SpringBoot2.x&#xff0c;SpringCloud&#xff0c;Ant Design&Vue3&#xff0c;Mybatis-plus&#xff0c;Shiro&#xff0c;JWT 支持微服务。强大的代码生成器让前后端代码一键生成! JeecgBoot引领…

RocketMQ快速入门:集成spring, springboot实现各类消息消费(七)附带源码

0. 引言 rocketmq支持两种消费模式&#xff1a;pull和push&#xff0c;在实际开发中这两种模式分别是如何实现的呢&#xff0c;在spring框架和springboot框架中集成有什么差异&#xff1f;今天我们一起来探究这两个问题。 1. java client实现消息消费 1、添加依赖 <depen…