机器学习入门【经典的CIFAR10分类】

news2024/9/23 23:29:51

模型

神经网络采用下图
在这里插入图片描述

我使用之后发现迭代多了之后一直最高是正确率65%左右,然后我自己添加了一些Relu激活函数和正则化,现在正确率可以有80%左右。

模型代码

import torch
from torch import nn


class YmModel(nn.Module):
    def __init__(self):
        super(YmModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.model(x)


训练

有一点要说明的是,数据集中并没有验证集,你可以从训练集扣个1w张出来

import torch
import torchvision
from torchvision import transforms

from models.YMModel import YmModel
from torch.utils.data import DataLoader


transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
print(len(train_loader), len(test_loader))

print(len(train_dataset), len(test_dataset))

model = YmModel()
#迭代次数
train_epochs = 300
#优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 损失函数
loss_fn = torch.nn.CrossEntropyLoss()

train_epochs_step = 0
best_accuracy = 0.

for epoch in range(train_epochs):
    model.train()
    print(f'Epoch is {epoch}')
    for images, labels in train_loader:
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if train_epochs_step % 100 == 0:
            print(f'Train_Epoch is {train_epochs_step}\t Loss is {loss.item()}')
        train_epochs_step += 1
    train_epochs_step = 0

    with torch.no_grad():
        loss_running_total = 0.
        acc_running_total = 0.
        for images, labels in test_loader:
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss_running_total += loss.item()
            acc_running_total += (outputs.argmax(1) == labels).sum().item()
    acc_running_total /= len(test_dataset)
    if acc_running_total > best_accuracy:
        best_accuracy = acc_running_total
        torch.save(model.state_dict(), './best_model.pth')
    print('accuracy is {}'.format(acc_running_total))
    print('total loss is {}'.format(loss_running_total))
    print('best accuracy is {}'.format(best_accuracy))


验证

import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from models.TestColor import TextColor
from models.YMModel import YmModel

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')
model = YmModel()

model.load_state_dict(torch.load('best_model.pth'))


model.eval()
with torch.no_grad():
    correct = 0.
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
    print('Accuracy : {}'.format(100 * correct / len(test_dataset)))
folder_path = './images'
files_names = os.listdir(folder_path)
transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

for file_name in files_names:
    image_path = os.path.join(folder_path, file_name)
    image = Image.open(image_path)
    image = transform_test(image)
    image = np.reshape(image, [1, 3, 32, 32])
    output = model(image)
    _, predicted = torch.max(output, 1)
    source_name = os.path.splitext(file_name)[0]
    predicted_class = classes[predicted.item()]
    colors = TextColor.GREEN if predicted_class == source_name else TextColor.RED
    print(f"Source is {TextColor.BLUE}{source_name}{TextColor.RESET}, and predicted is {colors}{predicted_class}{TextColor.RESET}")

结果

TextColor是自定义字体颜色的类,image中就是自己的图片。
结果如下:测试集的正确率有82.7%

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1933870.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【香橙派】Orange pi AIpro开发板评测,与树莓派的横向对比以及实机性能测试

一、前言 在人工智能领域飞速发展的时代,国产厂商们也是紧随时代的步伐,迅龙公司联合华为推出了一款全新的开发板 Orange pi AIpro 作为一款建设人工智能新生态的开发板,它可广泛适用于AI边缘计算、深度视觉学习及视频流AI分析、视频图像分析…

ssh远程登录另一台linux电脑

大部分的博客内容所说的安装好ssh服务后,terminal输入 ssh -p port_number clientnameserver_ip 之后输入密码等等就可以登上别人的电脑 但是这是有一个前提的,就是这两台电脑要在同一个局域网下面。 如果很远呢? 远到不在同一个网下面怎么办…

【智能算法应用】粒子群算法求解带出入点车间布局设计问题

目录 1.算法原理2.数学模型3.结果展示4.参考文献5.代码获取 1.算法原理 【智能算法】粒子群算法(PSO)原理及实现 设施布局问题(Facility Layout Problem, FLP),主要目的是在给定的区域内有效地放置不同设备或部件&am…

大模型学习笔记十一:视觉大模型

一、判别式模型和生成式模型 1)判别式模型Discriminative ①给某一个样本,判断属于某个类别的概率,擅长分类任务,计算量少。(学习策略函数Y f(X)或者条件概率P(YIX)) ②不能反映训练数据本身的特性 ③学习…

JavaScript学习笔记(九)

56、JavaScript 类 56.1 JavaScript 类的语法 请使用关键字 class 创建一个类。 请始终添加一个名为 constructor() 的方法。 JavaScript 类不是对象。 它是 JavaScript 对象的模板。 语法: class ClassName {constructor() { ... } }示例:例子创…

【无人值守】对数据中心电力分配系统发展的影响

数据中心在现代信息发展中承载着巨量数据的计算、存储、挖掘、分析和应用等多个方面的功能,是国计民生各行业的多样化的信息化的资产。对稳定的运行与安全运维是基本需求也是重要的保障。 数据中心属于高能耗产业,对用电负荷大且要求极度稳定。除了对电力…

一文-深入了解Ansible常见模块、安装和部署

1 Ansible 介绍 Ansible是一个配置管理系统configuration management system, python 语言是运维人员必须会的语言, ansible 是一个基于python 开发的(集合了众多运维工具 puppet、cfengine、chef、func、fabric的优点)自动化运维工具, 其功能实现基于ss…

HarmonyOS介绍

一、什么是HarmonyOS HarmonyOS是新一代的智能终端操作系统,为不同设备的智能化、互联与协同提供了统一的语言,为用户带来简捷、流畅、连续、安全可靠的全场景交互体验。 二、HarmonyOS的核心理念 1、一次开发 多端部署 指的是一个工程&#xf…

题解|2023暑期杭电多校05

【原文链接】 (补发)题解|2023暑期杭电多校05 1001.Typhoon 计算几何 题目大意 依次给定 n n n 个坐标 P P P ,预测的台风路线为按顺序两两连接给定坐标所得的折线 现在有 m m m 个庇护所的坐标 S S S ,求每个庇护所到台风…

基于AT89C51单片机的多功能自行车测速计程器(含文档、源码与proteus仿真,以及系统详细介绍)

本篇文章论述的是基于AT89C51单片机的多功能自行车测速计程器的详情介绍,如果对您有帮助的话,还请关注一下哦,如果有资源方面的需要可以联系我。 目录 选题背景 原理图 PCB图 仿真图 代码 系统论文 资源下载 选题背景 美丽的夜晚&…

c++树(一)定义,遍历

目录 树的定义 树的基本术语 树的初始起点:我们定义为根 树的层次: 树的定义: 树的性质 性质1: 性质2: 树形结构存储的两种思路 树的遍历模板 树上信息统计方式1-自顶向下统计 树上信息统计方式2-自底向上统…

【漏洞复现】泛微E-Cology WorkflowServiceXml SQL注入漏洞

0x01 产品简介 泛微e-cology是一款由泛微网络科技开发的协同管理平台,支持人力资源、财务、行政等多功能管理和移动办公。 0x02 漏洞概述 泛微OAE-Cology 接口/services/WorkflowServiceXml 存在SQL注入漏洞,可获取数据库权限,导致数据泄露…

Purple Pi OH在Android11下测试WiFi和LAN的TCP和UDP传输速率

本文适用于在Purple Pi OH在Andriod11下如何测试WiFi和LAN的TCP和UDP传输速率。触觉智能的Purple Pi OH鸿蒙开源主板,是华为Laval官方社区主荐的一款鸿蒙开发主板。 该主板主要针对学生党,极客,工程师,极大降低了开源鸿蒙开发者的…

C语言 ——— 在控制台上打印动态变化的菱形

目录 代码要求 代码实现 代码要求 输入 整数line &#xff0c;菱形的上半部分的长度就为line&#xff08;动态变化的菱形&#xff09; 菱形由 "*" 号构成 代码实现 #include<stdio.h> int main() {// 上半长int line 0;scanf("%d", &line)…

mysql常用函数五大类

mysql常用函数 1. 第一类&#xff1a;数值函数1.1 圆周率pi的值1.2 求绝对值1.3 返回数字的符号1.4 开平方&#xff0c;根号1.5 求两个数的余数1.6 截取正数部分1.7 向上取整数1.8 向下取整数1.9 四舍五入函数1.10 随机数函数1.11 数值左边补位函数1.12 数值右边补位函数1.13 次…

【网络工具】Charles 介绍及环境配置

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/iAmAo &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会整理一些工作或学习中用到的工具介绍给大家~ &#x1f4d8;Charles 系列其它文章&#xff1a;【网络…

PySide在Qt Designer中使用QTableView 显示表格数据

在 PySide6 中&#xff0c;可以使用 Qt Model View 架构中的 QTableView 部件来显示和编辑表格数据。 1、创建ui文件 在Qt Designer中新建QMainWindow&#xff0c;命名为csvShow.ui。QMainWindow上有两个部件&#xff1a;tableview和btn_exit。 2、使用pyuic工具将ui文件转换为…

路由上传一个ui_control参数(uint32类型)控制页面UI显隐

前言&#xff1a;传一个uint32类型的值&#xff0c;通过 按位或操作符&#xff08;|&#xff09;来设置ui_control的值&#xff0c;通过按位与操作符&#xff08;&&#xff09;来检测是否显示或隐藏 简单介绍一下两个概念&#xff1a; 按位与操作符和按位或操作符都是二进…

LeetCode-随机链表的复制

. - 力扣&#xff08;LeetCode&#xff09; 本题思路&#xff1a; 首先注意到随机链表含有random的指针&#xff0c;这个random指针指向是随机的&#xff1b;先一个一个节点的拷贝&#xff0c;并且把拷贝的节点放在拷贝对象的后面&#xff0c;再让拷贝节点的next指向原链表拷贝…

申贷时,被大数据风控拒贷有哪些原因呢?

很多人特别是从事过金融行业的人来说&#xff0c;大数据风控相信都不陌生&#xff0c;因为现在的银行和机构对申贷人的大数据信用看的越来越重要&#xff0c;已然成看贷前审查的重要依据&#xff0c;那申贷时&#xff0c;被大数据风控拒贷有哪些原因呢?本文就与大家一起探讨一…