Pytorch-MLP-CIFAR10

news2024/11/25 11:36:25

文章目录

  • model.py
  • main.py
  • 参数设置
  • 注意事项
  • 运行图

model.py

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class MLP_cls(nn.Module):
    def __init__(self,in_dim=3*32*32):
        super(MLP_cls,self).__init__()
        self.lin1 = nn.Linear(in_dim,128)
        self.lin2 = nn.Linear(128,64)
        self.lin3 = nn.Linear(64,10)
        self.relu = nn.ReLU()
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

    def forward(self,x):
        x = x.view(-1,3*32*32)
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        x = self.relu(x)
        return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls,CNN_cls


seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = MLP_cls()

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    run_loss = 0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = net(data)
        _,pred = torch.max(out,dim=1)
        optimizer.zero_grad()
        loss = criterion(out,target)
        loss.backward()
        run_loss += loss
        optimizer.step()
        correct_num  += torch.sum(pred==target)
    print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))



print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):
    out = net(data)
    _,pred = torch.max(out,dim=1)
    test_loss += criterion(out,target)
    test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

参数设置

'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test  = 64
epochs = 10

optim --> SGD
learning_rate = 0.01
momentum = 0.5

注意事项

CIFAR10是彩色图像,单个大小为3*32*32。所以view的时候后面展平。

运行图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1020429.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RFID自动识别技术在数控工具系统的应用

RFID是一种自动识别技术,最早是应用在二战中进行敌我侦察机的识别,但是随着民用通信技术的放开,近年来网络通信技术以及信息安全技术都取得了重大的发展,RFID技术也逐渐在民用领域应用。 RFID自动识别技术在数控工具系统的应用 1、…

浅谈PDM与MES系统集成

摘要: 目前MES在制造行业变得炙手可热,然而很多企业都忽视了数据的源头,MES作为生产执行的信息化系统,我们该如何让其在企业中成功的实施,发挥更大的作用,这还需要PDM系统的支撑。本文就PDM与MES集成进行简…

css前端面试题(三)

文章目录 1、可继承属性和不可继承属性字体系列属性文本系列属性元素可见性列表布局属性光标属性 2、link和import的区别3、css优化4、 CSS预处理器/后处理器是什么?为什么要使用它们?5、单行、多行文本溢出隐藏6、实现一个扇形7、实现一个自适应的正方形…

【Axure高保真原型】人物卡片多条件搜索案例

今天和大家分享人物卡片多条件搜索的原型模板,我们可以输入姓名或者选择部门、岗位来快速筛选出对应的人物信息卡片。那这个模板是用中继器制作的,所以使用也很方便,只需要在中继器表格导入图片和填写对应内容,即可自动生成交互效…

1600*A. LCM Challenge(数论 || 找规律)

解析&#xff1a; n<3&#xff0c;特判 n为奇数&#xff0c;则n、n-1、n-2必定互质&#xff0c;所以结果即为三者之和。 n为偶数&#xff0c; 不会严格证明原因&#xff0c;但是找找规律&#xff0c;是这样的...... #include<bits/stdc.h> using namespace std; #de…

ros----发布者和订阅者模型

话题模型&#xff1a; 如何自定义话题消息 1.定义msg文件 2.在package.xml中添加功能包依赖 <build_depend>message_generation</build_depend> <exec_depend>message_runtime</exec_depend>3.在CMakeList.txt文件中添加编译选项 4.编译生成语言的相…

网络工程师是干什么的?常见岗位有哪些?

网络工程师是做什么工作&#xff1f; 网络工程师能够从事计算机信息系统的设计、建设、运行和维护工作。一般来说&#xff0c;分硬件网络工程师和软件网络工程师两大类&#xff0c;硬件网络工程师以负责网络硬件等物理设备的维护和通信&#xff1b;软件网络工程师负责系统软件…

LeetCode(力扣)509. 斐波那契数Python

LeetCode509. 斐波那契数 题目链接代码 题目链接 https://leetcode.cn/problems/fibonacci-number/ 代码 class Solution:def fib(self, n: int) -> int:if n 0:return 0dp [0] * (n 1)dp[0] 0dp[1] 1for i in range(2, n 1):dp[i] dp[i - 1] dp[i - 2]return d…

IM即时通讯系统[SpringBoot+Netty]——梳理(总)

文章目录 一、为什么要自研一套即时通讯系统1、实现一个即时通讯系统有哪些方式1.1、使用开源产品做二次开发或直接使用1.2、使用付费的云服务商1.3、自研 2、如何自研一套即时通讯系统2.1、早期即时通讯系统是如何实现2.2、一套即时通讯系统的基本组成2.3、当下的即时通讯系统…

【milkv】st7735驱动

前言 本文介绍milkv-duo加载st7735的lcd屏幕&#xff0c;以及屏幕显示log。 参考文章&#xff1a; 记录为Linux配置spi屏幕&#xff08;st7735s&#xff09; https://community.milkv.io/t/milk-v-duo-spi-st7789/131 一、电路图 1.1 pin设置 打开spi2的引脚 duo-buildroot…

@Transactional失效场景/原因

文章目录 1.Transactional注解在非public方法上2.Transactional使用propagation设置错误&#xff08;有3种会失效&#xff09;3.Transactional使用rollbackFor设置错误4.A方法没有使用Transactional调用了B&#xff08;有被注解&#xff09;方法5.try catch了异常6.数据库引擎不…

AI绘制流程图

1、工具&#xff1a; 使用https://chat.openai.com/c/45a81a53-cced-43f7-be3e-e5e80f1e994fFlowchart Maker & Online Diagram Software 2、使用plantuml的过程&#xff1a; 复制代码&#xff0c;打开diagram.net&#xff0c;点击加号→高级→ plantUML&#xff0c;替换掉…

已解决 Go Error: cannot use str (type string) as type int in assignment

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页: &#x1f405;&#x1f43e;猫头虎的博客&#x1f390;《面试题大全专栏》 &#x1f995; 文章图文并茂&#x1f996…

【openKylin】OpenKylin1.0 x86_64 VMWare安装手册

&#x1f341; 博主 "开着拖拉机回家"带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——&#x1f390;开着拖拉机回家_大数据运维-CSDN博客 &#x1f390;✨&#x1f341; &#x1fa81;&#x1f341; 希望本文能够给您带来一定的帮助&#x1f338;文…

华为云ROMA Connect亮相Gartner®全球应用创新及商业解决方案峰会,助力企业应用集成和数字化转型

9月13日-9月14日 Gartner全球应用创新及商业解决方案峰会在伦敦举行 本届峰会以“重塑软件交付&#xff0c;驱动业务价值”为主题&#xff0c;全球1000多位业内专家交流最新的企业应用、软件工程、解决方案架构、集成与自动化、API等企业IT战略和新兴技术热门话题。 9月13日…

【结构体类型——详细讲解】

结构体 1.结构体类型声明 1.1结构体的概念 结构体是⼀些值的集合&#xff0c;这些值称为成员变量。结构体的每个成员可以是不同类型的变量。 1.2 结构的声明 struct tag { member-list; }variable-list;例如描述⼀个学⽣&#xff1a; struct Stu { char name[20]; //名字 i…

337.打家劫舍III

337. 打家劫舍 III - 力扣&#xff08;LeetCode&#xff09; 小偷又发现了一个新的可行窃的地区。这个地区只有一个入口&#xff0c;我们称之为 root 。 除了 root 之外&#xff0c;每栋房子有且只有一个“父“房子与之相连。一番侦察之后&#xff0c;聪明的小偷意识到“这个…

SpringSecurity 入门

文章目录 Spring Security概念快速入门案例环境准备Spring配置文件SpringMVC配置文件log4j配置文件web.xmlTomcat插件 整合SpringSecurity 认证操作自定义登录页面关闭CSRF拦截数据库认证加密认证状态记住我授权注解使用标签使用 Spring Security概念 Spring Security是Spring…

竞赛选题 基于深度学习的人脸性别年龄识别 - 图像识别 opencv

文章目录 0 前言1 课题描述2 实现效果3 算法实现原理3.1 数据集3.2 深度学习识别算法3.3 特征提取主干网络3.4 总体实现流程 4 具体实现4.1 预训练数据格式4.2 部分实现代码 5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 毕业设计…

常见七大排序算法

目录 前言 冒泡排序 选择排序 插入排序 希尔排序&#xff08;shell&#xff09; 快速排序 归并排序 计数排序 前言 在前面我发布了常见的七大排序算法的相关博客&#xff0c;今天这一篇文章是做一个排序算法的小总结&#xff0c;把前面的博客集中分类到一起&#xff0c;…