神经网络分类和回归任务实战

news2025/1/13 7:40:01

学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程

直接上案例,先来跑,遇到什么解决什么

数据集Minist 数据集

做简单的任务 Minist 分类任务

总体代码(可以跑通)

from pathlib import Path
import requests
import pickle
import gzip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
bs=64
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
                ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
from matplotlib import pyplot
import numpy as np
print(x_train.shape)

# pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#获取训练数据和测试数据
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
#设置模型结构
class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
net = Mnist_NN()
# print(net)
# for name, parameter in net.named_parameters():
#     print(name, parameter,parameter.size())
#设置数据集和数据集加载器
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64 * 2)
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

loss_func = F.cross_entropy
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)
#训练参数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)
corret=0
total=0
for xb,yb in valid_dl:
    outputs=model(xb)
    _,predicted=torch.max(outputs.data,1)
    total+=yb.size(0)
    corret+=(predicted==yb).sum().item()
print('准确率是:%d %%'%(100*corret/total))

1.首先我们从最终实现的fit 函数开始看,

 在fit h函数之前有一个get_model 函数 得到model和优化器

model, opt = get_model()

得到模型的优化器以后

需要把训练轮数 模型 损失函数 训练数据 测试数据传入fit 训练函数

fit(20, model, loss_func, opt, train_dl, valid_dl)

fit 函数

def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))

xb是从dataloader 中取64个训练数据图片 也就是64*784维度,784代表28*28的手写数字图片的展平成一维向量

yb是64个图片对应的数字值

loss_batch(model, loss_func, xb, yb, opt)

我们看一下loss_batch 函数

loss 反向传播——更新优化器——优化器梯度归0

loss 是一个带有梯度的tensor    .item()返回的是loss 的值 len(xb )是为了求精度的时候算

输入是模型损失函数xb ,yb 和优化器

loss_func = F.cross_entropy 损失函数是交叉熵损失函数

将xb 经过model 得到输出后 和xb 求损失函数  model 是定义的一个简单的模型有两个隐藏层一个输出层 784——128——256——10

再回到fit 函数的验证部分model.evl()

先来一个 

with torch.no_grad()

不去计算梯度

zip(*的意思是解压缩 分别得到losses 和nums)

loss 和num 鲜橙 求每64个batch 的总loss 再将datalosder的所有batch 相加除以总数得到训练损失

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1573141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于vue+node.js导师选择分配管理系统

开发语言 node.js 框架:Express 前端:Vue.js 数据库:mysql 数据库工具:Navicat 开发软件:VScode .设计一套导师选择管理系统,帮助学校进行导师选择管理等繁琐又重复的工作,提高工作效率的同时&#xff0c…

三、Jenkins相关操作

Jenkins操作 一、插件管理1.修改公共插件源2.下载中文汉化插件2.1 安装插件2.2 重启2.3 设置为中文 3.远程部署插件 二、用户权限管理1.安装权限插件2.开启权限3.创建角色3.1 Global roles3.2 Item roles 4.创建用户5.给用户分配角色 三、凭证管理四、Git管理1.账号密码方式1.1…

【智能排班系统】雪花算法生成分布式ID

文章目录 雪花算法介绍起源与命名基本原理与结构优势与特点应用场景 代码实现代码结构自定义机器标识RandomWorkIdChooseLocalRedisWorkIdChooselua脚本 实体类SnowflakeIdInfoWorkCenterInfo 雪花算法类配置类雪花算法工具类 说明 雪花算法介绍 在复杂而庞大的分布式系统中&a…

力控机器人原理及力控制实现

力控机器人原理及力控制实现 力控机器人是一种能够感知力量并具有实时控制能力的机器人系统。它们可以在与人类进行精准协作和合作时,将力传感技术(Force Sensing Technology)和控制算法(Control Algorithm)结合起来&a…

场景文本检测识别学习 day01(传统OCR的流程、常见的损失函数)

传统OCR的流程 传统OCR:传统光学字符识别常见的的模型主要包括以下几个步骤来识别文本 预处理:预处理是指对输入的图像进行处理,以提高文字识别的准确率。这可能包括调整图像大小、转换为灰度图像、二值化(将图像转换为黑白两色&…

4.1.k8s的pod-创建,数据持久化,网络暴露,env环境变量

目录 一、Pod介绍 二、指令创建和管理Pod 三、资源清单创建pod 1.挂载hostPath存储卷 2.NFS存储卷 所有节点安装nfs k8s3编辑NFS配置文件 k8s1,k8s2节点开机挂载 编辑pod资源清单,挂载nfs 四、pod网络暴露 1.hostNetwork使用宿主机的网络 2.…

Struts2的入门:新建项目——》导入jar包——》jsp,action,struts.xml,web.xml——》在项目运行

文章目录 配置环境tomcat 新建项目导入jar包新建jsp界面新建action类新建struts.xml,用来配置action文件配置Struts2的核心过滤器:web.xml 启动测试给一个返回界面在struts.xml中配置以实现页面的跳转:result再写个success.jsp最后在项目运行 配置环境 …

实现点击用户头像或者id与其用户进行聊天(vue+springboot+WebSocket)

用户点击id直接与另一位用户聊天 前端如此&#xff1a; <template><!-- 消息盒子 --><div class"content-box" :style"contentWidth"><!-- 头像&#xff0c;用户名 --><div class"content-box-top box--flex">&l…

sqlmap(四)案例

一、注入DB2 http://124.70.71.251:49431/new_list.php?id1 这是墨者学院里的靶机&#xff0c;地址&#xff1a;https://www.mozhe.cn/ 1.1 测试数据库类型 python sqlmap.py -u "http://124.70.71.251:49431/new_list.php?id1" 1.2 测试用户权限类型 查询选…

【linux】ubuntu ib网卡驱动如何适配

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…

晚枫|2024年3月总结,人生中第一次「车祸」

这个3月&#xff0c;真正做到了&#xff1a;过得紧张又刺激&#xff01; 1、第一次「车祸」 整个3月&#xff0c;凌晨1点之前没睡过觉&#xff0c;大部分时间都是2点以后睡。 从去年10月和朋友一起办Python中国的活动开始&#xff0c;就逐渐进入这种状态了。 除了本职工作因…

c# wpf LiveCharts 绑定 简单试验

1.概要 c# wpf LiveCharts 绑定 简单试验 2.代码 <Window x:Class"WpfApp3.Window2"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schem…

Python3 Ubuntu

一、安装中文输入法 1.sudo apt install ibus-sunpinyin 2.点击右上角输入法&#xff0c;然后点击加号&#xff0c;输入yin添加进来&#xff0c;最后选中输入法即可 二、安装截屏软件 1.sudo apt install gnome-screenshot 三、安装opencv-python 1.pip3 install --upgrade…

CleanMyMac ----MacBook卡顿处理小妙招

MacBook 是苹果公司的一款非常流行的笔记本电脑&#xff0c;但使用时 间久了之后很容易出现卡顿的问题。那么出现卡顿的原因有哪些呢&#xff1f; MacBook卡顿怎么处理?先找出卡顿的一般原因&#xff1a; 我们需要排除 MacBook 的硬件问题。可以通过访问 Mac 系统自带的硬件检…

【C++】红黑树讲解及实现

前言&#xff1a; AVL树与红黑树相似&#xff0c;都是一种平衡二叉搜索树&#xff0c;但是AVL树的平衡要求太严格&#xff0c;如果要对AVL树做一些结构修改的操作性能会非常低下&#xff0c;比如&#xff1a;插入时要维护其绝对平衡&#xff0c;旋转的次数比较多&#xff0c;更…

pyside6怎么使用Qt Designer设计自定义组件

第一步&#xff0c;新建一个自定义组件的python文件 from PySide6.QtWidgets import QPlainTextEdit from PySide6.QtCore import Signal,Qtclass CustomPlainTextEdit(QPlainTextEdit):enterPressed Signal(str)def __init__(self, parentNone):super().__init__(parent)def…

EChart简单入门

echart的安装就细不讲了&#xff0c;直接去官网下&#xff0c;实在不会的直接用cdn,省的一番口舌。 cdn.staticfile.net/echarts/4.3.0/echarts.min.js 正入话题哈 什么是EChart&#xff1f; EChart 是一个使用 JavaScript 实现的开源可视化库&#xff0c;Echart支持多种常…

【环境变量】常见的环境变量 | 相关指令 | 环境变量系统程序的结合理解

目录 常见的环境变量 HOME PWD SHELL HISTSIZE 环境变量相关的指令 echo&env export unset 本地变量 环境变量整体理解 程序现象_代码查看环境变量 整体理解 环境变量表 环境变量表的传递 环境变量表的查看 测试验证 少说废话&#x1f197; 每个用户…

前端canvas项目实战——在线图文编辑器(八):复制、删除、锁定、层叠顺序

目录 前言一、效果展示二、实现步骤1. 复制2. 删除3. 锁定4. 层叠顺序 三、实现过程中发现的bug1. clone方法不复制自定义属性2. 复制「锁定」状态的对象&#xff0c;得到的新对象也是「锁定」状态 四、Show u the code后记 前言 上一篇博文中&#xff0c;我们细致的讲解了实现…

高精度端到端在线校准环视相机和LIDAR(精度0.2度内!无需训练数据)

高精度端到端在线校准环视相机和LIDAR&#xff08;精度0.2度内&#xff01;无需训练数据&#xff09; 附赠自动驾驶学习资料和量产经验&#xff1a;链接 写在前面 在自动驾驶车辆的使用寿命内&#xff0c;传感器外参校准会因振动、温度和碰撞等环境因素而发生变化。即使是看似…