Pytorch-CNN-CIFAR10

news2025/2/27 1:53:41

文章目录

  • model.py
  • main.py
  • 运行图

model.py

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class CNN_cls(nn.Module):
    def __init__(self,in_dim):
        super(CNN_cls,self).__init__()
        self.conv1 = nn.Conv2d(in_dim,32,1,1)
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(32,64,1,1)
        self.pool2 = nn.MaxPool2d(2,2)
        self.conv3 = nn.Conv2d(64,128,1,1)
        self.lin1 = nn.Linear(128*8*8,512)
        self.lin2 = nn.Linear(512,64)
        self.lin3 = nn.Linear(64,10)
        self.relu = nn.ReLU()

    def forward(self,x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = x.view(-1,128*8*8)
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import CNN_cls


seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = CNN_cls(3)

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    run_loss = 0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = net(data)
        _,pred = torch.max(out,dim=1)
        optimizer.zero_grad()
        loss = criterion(out,target)
        loss.backward()
        run_loss += loss
        optimizer.step()
        correct_num  += torch.sum(pred==target)
    print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))



print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):
    out = net(data)
    _,pred = torch.max(out,dim=1)
    test_loss += criterion(out,target)
    test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

运行图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1030774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云原生微服务治理 第五章 Spring Cloud Netflix 之 Ribbon

系列文章目录 第一章 Java线程池技术应用 第二章 CountDownLatch和Semaphone的应用 第三章 Spring Cloud 简介 第四章 Spring Cloud Netflix 之 Eureka 第四章 Spring Cloud Netflix 之 Ribbon 文章目录 系列文章目录[TOC](文章目录) 前言1、负载均衡1.1、服务端负载均衡1.2、…

JavaScript系列从入门到精通系列第五篇:JavaScript中的强制类型转换包含强制类型转换之Number,包含强制类型转换之String

文章目录 前言 一:强制类型转换 1:强制类型转换为String (一):方式一:调用被转换类型的toString()方法 (二):方式二:调用String函数 2:强制类型转换为Number (一):方式一&…

【Proteus仿真】【STM32单片机】大棚远程监测控制

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 系统运行后,LCD1604显示传感器检测的环境温湿度、土壤湿度、光照强度、CO2浓度和阈值;可通过上位机远程观察传感器采集的数据显示;可通过K3键进入阈值设置模式&#xf…

如何使用大型语言模型LLMs作为历史课程的教学工具?#提示工程技巧

Mixlab从2018就开始分享过一些关于教育的内容: GPT-4等对教育的未来意味着什么?2023-05-05 学习的目的是什么?我喜欢的教育产品应该是这样的 2019-07-08 你是 Infinite Learner 吗?2018-05-27 今天继续教育的话题,我们…

基于微信小程序的超市售货管理平台设计与实现(源码+lw+部署文档+讲解等)

前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 👇🏻…

通过内网穿透,在Windows 10系统下搭建个人《我的世界》服务器公网联机

文章目录 1. Java环境搭建2.安装我的世界Minecraft服务3. 启动我的世界服务4.局域网测试连接我的世界服务器5. 安装cpolar内网穿透6. 创建隧道映射内网端口7. 测试公网远程联机8. 配置固定TCP端口地址8.1 保留一个固定tcp地址8.2 配置固定tcp地址 9. 使用固定公网地址远程联机 …

浅谈电气防火保护器在地下商场的应用 安科瑞 缪阳扬

摘 要:近年来,我国城市发展速度加速。很多城市大力建造地下建筑设施,比如地铁、地下停车场和地下商场等。地下商场属于人员密集型建筑,其防火设计一直令相关的专家头疼。由于人员密集,防火处理不好将酿成灾难性的后果。…

软件定制APP开发步骤分析|小程序

软件定制APP开发步骤分析|小程序 软件定制开发步骤: 1.需求分析: 这是软件定制开发的第一步,也是最关键的一步。在这个阶段,软件开发团队需要与客户进行沟通,了解客户的具体需求和期望。通过讨论和交流,确…

【C++】左值和右值

基本概念左值和右值左值引用和右值引用 右值引用使用场景和意义左值引用的使用场景左值引用的短板右值引用和移动语义编译器优化移动赋值move右值引用引用左值右值引用的其他使用场景 完美转发万能引用forward 模板函数 基本概念 左值和右值 左值 左值(lvalue&…

tp5连接多个数据库

一、如果你的主数据库配置文件都在config.php里 直接在config.php中中定义db2&#xff1a; 控制器中打印一下&#xff1a; <?php namespace app\index\controller; use think\Controller; use think\Db; use think\Request; class Index extends Controller {public fun…

DEM格式转换:转换NSDTF-DEM国标数据格式为通用格式,使用ArcGIS工具转换NSDTF-DEM国标.dem文件为通用.tif格式。

DEM格式转换&#xff1a;转换NSDTF-DEM国标数据格式为通用格式&#xff0c;使用ArcGIS工具转换NSDTF-DEM国标.dem文件为通用.tif格式。 *.dem是一种比较常见的DEM数据格式&#xff0c;其有两种文件组织方式&#xff0c;即NSDTF-DEM和USGS-DEM。 &#xff08;1&#xff09;NSDT…

【Linux基础】第26讲 Linux 查找和过滤命令(一)——find命令

find命令是根据文件属性进行查找的&#xff0c;如文件名&#xff0c;文件大小&#xff0c;所有者&#xff0c;所有组&#xff0c;是否为空&#xff0c;访问时间&#xff0c;修改时间等。基本格式&#xff1a; find path [options] 先定位到etc 目录下 cd /etc1.按照文件名查找 …

C-Lodop 在域名下使用跨域问题

Access to script at http://localhost:18000/CLodopfuncs.js?priority0 from origin http://xxxxxx has been blocked by CORS policy: The request client is not a secure context and the resource is in more-private address space local. 解决&#xff1a; 浏览器输入…

记一次 Java Testcontainers CPU 100% 问题排查过程

以为代码进入了死循环&#xff0c;结果并没有&#xff01; 文章目录 背景与问题排查过程代码路经确认内存分析咨询 okio 社区等等&#xff0c;好像并没有死循环能否从内存快照发现其他问题&#xff1f; 背景与问题 本问题来源于 ShardingSphere issue: Integration tests occa…

使用applescript自动化trilium的数学公式环境

众所周知&#xff0c;trilium什么都好&#xff0c;就是对数学公式的支持以及markdown格式的导入导出功能太拉了&#xff0c;而最拉的时刻当属把这两个功能结合起来的时候&#xff1a;导入markdown文件之后&#xff0c;原来的数学公式全没了&#xff0c;需要一个一个手动用ctrlm…

解密list的底层奥秘

&#x1f388;个人主页:&#x1f388; :✨✨✨初阶牛✨✨✨ &#x1f43b;强烈推荐优质专栏: &#x1f354;&#x1f35f;&#x1f32f;C的世界(持续更新中) &#x1f43b;推荐专栏1: &#x1f354;&#x1f35f;&#x1f32f;C语言初阶 &#x1f43b;推荐专栏2: &#x1f354;…

AS中部署NCNN

参考链接 http://681314.com/A/Clzr6Q2OBO https://blog.csdn.net/xs1997/article/details/131747372 一、文章背景&#xff1a;公司再进行一个项目时&#xff0c;使用PyTorch框架&#xff0c;python语言及opencv工具进行神经网络深度学习算法进行训练。生成ONNX模型&#xff…

RocketMQ 消费者分类与分组

文章目录 消费者分类PushConsumerPushConsumer 内部原理使用注意事项 SimpleConsumerinvisibleDuration 消息不可见时间 消费者分组&#xff08;消费者负载均衡&#xff09;广播消费和共享消费负载均衡策略多个消费者消费顺序消息多消费者消费顺序消息示例 消费者分组管理关闭自…

八股文死记硬背打脸记

背景 我们都知道&#xff0c;再编程领域数据结构的重要性&#xff0c;常见的数据结构包括 List、Set、Map、Queue、Tree、Graph、Stack 等&#xff0c;其中 List、Set、Map、Queue 可以从广义上统称为集合类数据结构。而Java也提供了很多的集合数据结构以供开发者开箱即用&…

左神高级提升班2 约瑟夫环结构

目录 【案例1】 【题目描述】 【输入描述&#xff1a;】 【输出描述&#xff1a;】 【输入】 【输出】 【思路解析】 【代码实现】 【案例1】 【题目描述】 某公司招聘&#xff0c;有n个人入围&#xff0c;HR在黑板上依次写下m个正整数A1、A2、……、Am&#xff0c;然后…