神经网络框架的基本设计

news2024/9/24 17:13:03

一、神经网络框架设计的基本流程

确定网络结构、激活函数、损失函数、优化算法,模型的训练与验证,模型的评估和优化,模型的部署。

二、网络结构与激活函数

1、网络架构

这里我们使用的是多层感知机模型MLP(multilayer prrceptron):

MLP一般分为三层:输入层、隐藏层和输出层。
输入层:接收输入数据。
隐藏层:负责处理数据,可以有一个或多个隐藏层,每个隐藏层包含若干个神经元。
输出层:输出最终结果,通常是一个softmax层,用于多分类任务,或者是一个sigmoid层,用于二分类任务。
MLP的核心思想是通过增加神经元的数量和层次,提高模型的表达能力。由于有多个层,参数需要在这些层之间传递。

每个隐藏层神经元中计算过程如下:
将输入数据传递给第一个隐藏层的神经元。
对于每个神经元,计算其加权和,即将输入与对应的权重相乘并求和,再加上偏置项。
将加权和输入到激活函数中,得到激活值,作为该神经元的输出。
将每个神经元的输出传递到下一层的神经元,直至输出层。
在这个过程中,数据和权重是前向传播的主要传播内容。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28,312),
            nn.ReLU(),
            nn.Linear(312, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)
        return logits

2、激活函数:

用于引入非线性因素,使得神经网络具有更强的表达能力。常见的激活函数有Sigmoid、ReLU、Tanh等。这里我们暂时略过

3、损失函数:

损失函数是用来量化模型预测值与真实值差距的方式(这么说可能过于抽象,我们初中数学应该学过标准差和方差的概念,方差和标准差是量化数组各元素与平均值的差距的。)而能反映二者间差距的值有很多,常见的有均方损失、交叉熵损失、铰链损失、绝对值损失。

具体什么含义可以自行学习。

损失函数的作用是用来评估模型的准确度的,

在上次实验中,我们使用的是均方损失,这一次我们使用交叉熵损失函数:

loss_fu = torch.nn.CrossEntropyLoss()
loss = loss_fu(pred, label_batch)

torch.nn模块:是构建神经网络的基石,提供了各种类型的层,包括卷积层、池化层、激活函数、循环层和全连接层

4、优化算法

本次还是采用:自适应学习率的优化算法,adam优化器。

三、模型的训练

1、数据集调整

此次的数据集依然是mnist数据集,下载与处理可以看Hello World!-CSDN博客。不过对于idx3-udyte文件,如果我们每次都需要这么读一下,那多少有些麻烦,可以将idx3-udyte文件转换成npy文件以便长期存储。

npy文件是NumPy数组的一种存储格式,用于将Python中的NumPy数组保存到磁盘上,并可以在以后需要时加载回来。它不仅存储了数组的数据,还存储了数组的形状、数据类型等信息。这使得.npy文件非常紧凑且占用空间小,并且可以快速地读写。由于它是二进制格式,所以不能直接用文本编辑器打开查看内容。非常适合在需要频繁保存和加载数组数据的场景下使用,例如机器学习中的模型参数保存、数据集的存储等。

我们回到上期的读取文件的部分

​
def read_data3(self, roadurl):
        with open(roadurl, 'rb') as f:
            content = f.read()
        # print(content)
 
        fmt_header = '>iiii'  # 网络字节序
        offset = 0
 
        magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, content, offset)
        print('幻数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))
        # 定义一张图片需要的数据个数(每个像素一个字节,共需要行*列个字节的数据)
        img_size = num_rows * num_cols
        # struct.calcsize(fmt)用来计算fmt格式所描述的结构的大小
        offset += struct.calcsize(fmt_header)
        # '>784B'是指用大端法读取784个unsigned byte
        fmt_image = '>' + str(img_size) + 'B'
        # 定义了一个三维数组,这个数组共有num_images个 num_rows*num_cols尺寸的矩阵。
        images = np.empty((num_images, num_rows, num_cols))
 
        for i in range(num_images):
            images[i] = np.array(struct.unpack_from(fmt_image, content, offset)).reshape((num_rows, num_cols))
            offset += struct.calcsize(fmt_image)
 
        return images

​

我们已经接收到了images这个numpy数组,接下来只需要增加一点代码

# 保存到
np.save('文件路径/文件名.npy', array)
# 读取文件
data1 = np.load('文件路径/文件名.npy')

2、模型搭建

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 指定GPU编
import torch
import numpy as np
from tqdm import tqdm

batch_size = 320  # 设定每次训练的批次数
epochs = 1024  # 设定训练次数

# device = "cpu"                         # Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"  # 在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式


# 设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )

    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits


model = NeuralNetwork()
model = model.to(device)  # 将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
# model = torch.compile(model)            # Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()       # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)  # 设定优化函数

# 载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train) // batch_size

# 开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred, label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # 记录每个批次的损失值

    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("epoch:", epoch, "train_loss:", round(train_loss, 2), "accuracy:", round(accuracy, 2))
torch.save(model, './model.pth')
print("模型已更新")

3、训练完成

四、可视化操作

1、代码

import torch
import torch.nn as nn

if __name__ == '__main__':
    model = NeuralNetwork()
    #print(model)

    params = list(model.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

2、使用netron软件进行可视化(推荐)

在github上下载:https://github.com/lutzroeder/netron

 五、模型的部署(暂无)

从前边可知,所谓模型就是一个由无数参数组成的.pth文件,而.pth文件可以通过python指令读取。

torch.load()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1360170.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iOS苹果和Android安卓测试APP应用程序的区别差异

在移动应用开发中,测试是一个至关重要的环节。无论是iOS苹果还是Android安卓,测试APP应用程序都需要注意一些差异和细节。本文将详细介绍iOS和Android的测试差异,包括操作系统版本、设备适配、测试工具和测试策略,并回答一些新手容…

Hive实战:分科汇总求月考平均分

文章目录 一、实战概述二、提出任务三、完成任务(一)准备数据1、在虚拟机上创建文本文件2、上传文件到HDFS指定目录 (二)实现步骤1、启动Hive Metastore服务2、启动Hive客户端3、创建分区的学生成绩表4、按分区加载数据5、查看分区…

C语言编译器(C语言编程软件)完全攻略(第二十部分:Code::Blocks下载地址和安装教程(图解))

介绍常用C语言编译器的安装、配置和使用。 二十、Code::Blocks下载地址和安装教程(图解) Code::Blocks 是一款免费开源的 C/C IDE,支持 GCC、MSVC 等多种编译器,还可以导入 Dev-C 的项目。Code::Blocks 的优点是:跨…

支持向量机(Support Vector Machines,SVM)

什么是机器学习 支持向量机(Support Vector Machines,SVM)是一种强大的机器学习算法,可用于解决分类和回归问题。SVM的目标是找到一个最优的超平面,以在特征空间中有效地划分不同类别的样本。 基本原理 超平面 在二…

YOLOv8改进 | 损失篇 | VarifocalLoss密集目标检测专用损失函数 (VFLoss,原论文一比一复现)

一、本文介绍 本文给大家带来的是损失函数改进VFLoss损失函数,VFL是一种为密集目标检测器训练预测IoU-aware Classification Scores(IACS)的损失函数,我经过官方的版本将其集成在我们的YOLOv8的损失函数使用上,其中有很多使用的小细节(否则按照官方的版本使用根本拟合不了…

代码随想录刷题第三十九天| 62.不同路径 ● 63. 不同路径 II

代码随想录刷题第三十九天 不同路径 (LC 62) 题目思路: 代码实现: class Solution:def uniquePaths(self, m: int, n: int) -> int:dp [[0 for _ in range(n1)] for _ in range(m1)]dp[0][1] 1for i in range(1,m1):for j in range(1, n1):dp[i]…

Qt6入门教程 2:Qt6下载与安装

Qt6不提供离线安装包,下载和安装实际上是一体的了。 关于Qt简介,详见:Qt6入门教程1:Qt简介 一.下载在线安装器 Qt官网 地址:https://download.qt.io/ 在线下载器地址:https://download.qt.io/archive/on…

PHP运行环境之宝塔软件安装及Web站点部署流程

PHP运行环境之宝塔软件安装及Web站点部署流程 1.1安装宝塔软件 官网:https://www.bt.cn/new/index.html 自行注册账号,稍后有用 下载安装页面:宝塔面板下载,免费全能的服务器运维软件 1.1.1Linux 安装 如图所示,宝…

使用STM32微控制器驱动LCD1602显示器

驱动LCD1602显示器是嵌入式系统常见的任务之一,而STM32微控制器因其灵活性和丰富的外设而成为了广泛采用的解决方案。在这篇文章中,我们将探讨如何使用STM32微控制器来驱动LCD1602显示器。我们将从STM32的GPIO配置、延时函数以及LCD1602的初始化和写入数…

深度学习中的自动化标签转换:对数据集所有标签做映射转换

在机器学习中,特别是在涉及图像识别或分类的项目中,标签数据的组织和准确性至关重要。本文探讨了一个旨在高效转换标签数据的 Python 脚本。该脚本在需要更新或更改类标签的场景中特别有用,这是正在进行的机器学习项目中的常见任务。我们将逐…

Windows BAT脚本 | 定时关机程序

使用说明:输入数字,实现一定时间后自动关机。 单位小时,用后缀 h 或 H。示例 1h 单位分钟,用后缀 m 或 M 或 min。示例 30min 单位秒。用后缀 s 或不用后缀。示例 100s 源码 及 配置方法 桌面新建文本文件,输入下面…

Jmeter相关概念

Jmeter相关概念 jmeter性能指标 Aggregate Report 是 JMeter 常用的一个 Listener,中文被翻译为“聚合报告”。今天再次有同行问到这个报告中的各项数据表示什么意思,顺便在这里公布一下,以备大家查阅。 如果大家都是做Web应用的性能测试&a…

实现并解决微服务间OpenFeign转发文件格式MultipartFile

场景 使用openfeign转发MultipartFile类型的文件时出现了下面的错误。 PostMapping(value "/upload", consumes MediaType.MULTIPART_FORM_DATA_VALUE) ApiOperation(value "导入") public ResponseJson<String> uploadFiles(RequestParam(&quo…

uniapp微信小程序投票系统实战 (SpringBoot2+vue3.2+element plus ) -小程序首页实现

锋哥原创的uniapp微信小程序投票系统实战&#xff1a; uniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )_哔哩哔哩_bilibiliuniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )共计21条视频…

ant-design-vue 使用本地iconfont.js

createFromIconfontCN只能使用【在线资源】&#xff0c;但是在线资源存在不稳定的风险 有人提了issue&#xff0c;不过目前也没有解决&#xff0c;但是有人提出了一种新的的解决方案 参考链接&#xff1a; https://github.com/ant-design/ant-design/issues/16480 main.js im…

package-info.java delete

package-info.java delete

Spring见解2

3.基于注解的IOC配置 学习基于注解的IOC配置&#xff0c;大家脑海里首先得有一个认知&#xff0c;即注解配置和xml配置要实现的功能都是一样的&#xff0c;都是要降低程序间的耦合。只是配置的形式不一样。4 3.1.创建工程 3.1.1.pom.xml <?xml version"1.0" e…

uniapp vue2 车牌号输入组件记录

uniapp vue2 车牌号输入案例记录 组件如图 直接上代码 1.html <template><view><view class"plate" :class"{show: show}"><view class"itemFirst flex-d"><view class"item item1" click"handl…

ubuntu 22 virt-manger(kvm)安装winxp

安装 、启动 virt-manager sudo apt install virt-manager sudo systemctl start libvirtdsudo virt-manager安装windowsXP 安装过程截图如下 要点1 启用 “包括寿终正寝的操作系统” win_xp.iso 安装过程 &#xff1a; 从winXp.iso启动, 执行完自己重启从硬盘重启&#xff0c…

八个LOGO素材网站推荐分享

即时设计资源广场 在UI界面设计中&#xff0c;为了找到合适的图标icon&#xff0c;你有没有尝试过翻遍整个网络&#xff0c;找到自己想要的&#xff0c;却无法下载或收费使用&#xff1f;最后&#xff0c;只收集图标icon材料需要半天时间。专业设计师使用的图标icon设计材料“…