pytorch-01

news2024/10/5 12:55:00

加载mnist数据集

one-hot编码实现

import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)  # 获取前5个样本的标签数据
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一个参数张量x,10为类别数
print(y)

对于拥有6000个样本的MNIST数据集来说,标签就是一个6000\times 10大小的矩阵张量。

多层感知机模型

#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()  # 拉平图像矩阵
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),   # 输入大小为28*28,输出大小为312维的线性变换层
            torch.nn.ReLU(),   # 激活函数层
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)  # 最终输出大小为10,对应one-hot标签维度
        )
    def forward(self, input):   # 构建网络
        x = self.flatten(input)  #拉平矩阵为1维
        logits = self.linear_relu_stack(x) # 多层感知机

        return logits

损失函数

优化函数

model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵损失函数,内置了softmax函数,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

loss = loss_fu(pred,label_batch)  # 计算损失

完整模型

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编
import torch
import numpy as np


batch_size = 320                        #设定每次训练的批次数
epochs = 1024                           #设定训练次数

#device = "cpu"                         #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"                         #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式


#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits

model = NeuralNetwork()
model = model.to(device)                #将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
#model = torch.compile(model)            #Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train)//batch_size

#开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred,label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # 记录每个批次的损失值

    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

可视化模型结构和参数

model = NeuralNetwork()
print(model)

是对模型具体使用的函数及其对应的参数进行打印。

格式化显示:

param = list(model.parameters())
k=0
for i in param:
    l = 1
    print('该层结构:'+str(list(i.size())))
    for j in i.size():
        l*=j
    print('该层参数和:'+str(l))
    k = k+l
print("总参数量:"+str(k))

模型保存

model = NeuralNetwork()
torch.save(model, './model.pth')

netron可视化

安装:pip install netron

运行:命令行输入netron

打开:通过网址http://localhost:8080打开

打开保存的模型文件model.pth:

 

 点击颜色块,可以显示详细信息:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1877002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【单片机毕业设计11-基于stm32c8t6的智能水质检测】

【单片机毕业设计11-基于stm32c8t6的智能水质检测】 前言一、功能介绍二、硬件部分三、软件部分总结 前言 🔥这里是小殷学长,单片机毕业设计篇11基于stm32的智能水质检测系统 🧿创作不易,拒绝白嫖可私 一、功能介绍 -------------…

基于VMware的linux操作系统安装(附安装包)

目录 一、linux操作系统下载链接 二、开始导入镜像源 注:若是还没安装VMware请转到高效实现虚拟机(VMware)安装教程(附安装包)-CSDN博客 一、linux操作系统下载链接 1.官网链接下载 ubuntu:ubuntu官网…

连环计 | 第6集 | 百姓有倒悬之危,君臣有累卵之急 | 貂蝉 | 三国演义 | 逐鹿群雄

🙋大家好!我是毛毛张! 🌈个人首页: 神马都会亿点点的毛毛张 📌这篇博客分享的是《三国演义》文学剧本第Ⅰ部分《群雄逐鹿》的第6️⃣集《连环计》的经典语句和文学剧本全集台词 文章目录 1.经典语句2.文学剧本台词 …

【Spring Boot】Java 的数据库连接模板:JDBCTemplate

Java 的数据库连接模板:JDBCTemplate 1.JDBCTemplate 初识1.1 JDBC1.2 JDBCTemplate 2.JDBCTemplate 实现数据的增加、删除、修改和查询2.1 配置基础依赖2.2 新建实体类2.3 操作数据2.3.1 创建数据表2.3.2 添加数据2.3.3 查询数据2.3.4 查询所有记录2.3.5 修改数据2…

AXI接口简介

AXI接口,全称为Advanced eXtensible Interface,是ARM公司推出的一种高性能、低成本、可扩展的高速总线接口。AXI接口是ARM公司提出的AMBA(Advanced Microcontroller Bus Architecture)高级微控制器总线架构的一部分。2003年发布了…

List接口, ArrayList Vector LinkedList

Collection接口的子接口 子类Vector,ArrayList,LinkedList 1.元素的添加顺序和取出顺序一致,且可重复 2.每个元素都有其对应的顺序索引 方法 在index 1 的位置插入一个对象,list.add(1,list2)获取指定index位置的元素&#…

Lr、LrC软件下载安装 Adobe Lightroom专业摄影后期处理软件安装包分享

Adobe Lightroom它不仅为摄影师们提供了一个强大的照片管理平台,更以其出色的后期处理功能,成为了摄影爱好者们争相追捧的必备工具。 在这款软件中,摄影师们可以轻松地管理自己的照片库,无论是按拍摄日期、主题还是其他自定义标签…

LONGAGENT:优化大模型处理长文本

现有的大模型(LLMs),尽管在语言理解和复杂推理任务上取得了显著进展,但在处理这些超长文本时却常常力不从心。它们在面对超过10万令牌的文本输入时,常常会出现性能严重下降的问题,这被称为“中间丢失”现象…

安全与加密常识(0)安全与加密概述

文章目录 一、信息安全的基本概念二、加密技术概述三、常见的安全协议和实践四、加密的挑战与应对 在数字时代,信息安全和加密已成为保护个人和企业数据不受侵犯的关键技术。本文将探讨信息安全的基础、加密的基本原理,以及实用的保护措施,以…

Installed Build Tools revision xxx is corrupted. Remove and install again 解决

1.在buildTools文件下找到对应的sdk版本,首先将版本对应目录下的d8.bat改名为dx.bat。 2.在lib文件下将d8.jar改名为dx.jar。 3.重新编译工程即可

响应式宠物商店网站pbootcms模板

模板介绍 这是一款源码下载响应式宠物商店网站pbootcms模板。该模板采用响应式自适应设计,非常适合宠物行业的任何服务项目或在线商店或宠物网站,下载即用,组织代码优秀。 模板截图 源码下载 响应式宠物商店网站pbootcms模板

Python 算法交易实验74 QTV200第二步(改): 数据清洗并写入Mongo

说明 之前第二步是打算进入Clickhouse的,实测下来有一些bug 可以看到有一些分钟数据重复了。简单分析原因: 1 起异步任务时,还是会有两个任务重复的问题,这个在同步情况下是不会出现的2 数据库没有upsert模式。clickhouse是最近…

代码随想录:链表

文章目录 代码随想录---链表链表基础(创建以及增删查改)设计链表 链表的反转[206. 反转链表](https://leetcode.cn/problems/reverse-linked-list/)递归法迭代法 删除链表倒数第N个结点[19. 删除链表的倒数第 N 个结点](https://leetcode.cn/problems/remove-nth-node-from-end…

3ds Max导出fbx贴图问题简单记录

1.前言 工作中发现3ds Max导出的fbx在其它软件(Autodesk viewer,blender,navisworks,FBXReview等)中丢失了部分贴图,但导出的fbx用3ds Max打开却正常显示。 fbx格式使用范围较广,很多常见的三…

基于MDEV的PCI设备虚拟化DEMO实现

利用周末时间做了一个MDEV虚拟化PCI设备的小试验&#xff0c;简单记录一下&#xff1a; DEMO架构&#xff0c;此图参考了内核文档&#xff1a;Documentation/driver-api/vfio-mediated-device.rst host kernel watchdog pci driver: #include <linux/init.h> #include …

【Java】面试必问之Java常见线上故障排查方案详解

一、问题解析 在软件开发过程中&#xff0c;排查和修复产线问题是每⼀位⼯程师都需要掌握的基本技能。但是在⽣产环境中&#xff0c; 程序代码、硬件、⽹络、协作软件等任⼀因素&#xff0c;都会引发意想不到的问题&#xff0c;所以排查产线问题⽐较困 难&#xff0c;所以问…

关于数据库的ACID几点

首先的话就是关于ACID&#xff0c;最重要的就是原子性了&#xff0c;这是基础。 原子性是指事务包含的所有操作&#xff0c;要么全部完成&#xff0c;要么全部不完成。如果不能保证原子性&#xff0c;可能会出现以下问题&#xff1a; 数据不一致&#xff1a;事务中的部分操作…

QT事件处理及实例(鼠标事件、键盘事件、事件过滤)

这篇文章通过鼠标事件、键盘事件和事件过滤的三个实例介绍事件处理的实现。 鼠标事件及实例 鼠标事件包括鼠标的移动、按下、松开、单击和双击等。 创建一个MouseEvent项目&#xff0c;通过项目介绍如何获得和处理鼠标事件。程序效果如下图所示。 界面布局代码如下&#xff…

【算法训练记录——Day36】

Day36——贪心Ⅳ 1.leetcode_452用最少数量的箭引爆气球2.leetcode_435无重叠区间3.leetcode_763划分字母区间4.leetcode_ 1.leetcode_452用最少数量的箭引爆气球 思路&#xff1a;看了眼题解&#xff0c;局部最优&#xff1a;当气球出现重叠&#xff0c;一起射&#xff0c;所用…

【工具推荐】Nuclei

文章目录 NucleiLinux安装方式Kali安装Windows安装 Nuclei Nuclei 是一款注重于可配置性、可扩展性和易用性的基于模板的快速漏洞验证工具。它使用 Go 语言开发&#xff0c;具有强大的可配置性、可扩展性&#xff0c;并且易于使用。Nuclei 的核心是利用模板&#xff08;表示为简…