DNN代码实战

news2024/11/25 5:55:26

DNN的原理

神经网络通过学习大量样本的输入与输出特征之间的关系,以拟合出输入与输出之间的方程,学习完成后,只给它输入特征,它便会可以给出输出特征。神经网络可以分为这么几步:划分数据集、训练网络、测试网络、使用网络。

划分数据集

数据集里每个样本必须包含输入与输出,将数据集按一定的比例划分为训练集与测试集,分别用于训练网络与测试网络

# 生成数据集
X1 = torch.rand(10000, 1)
X2 = torch.rand(10000, 1)
X3 = torch.rand(10000, 1)
Y1 = ((X1 + X2 + X3) < 1).float()
Y2 = ((1 < (X1 + X2 + X3)) & ((X1 + X2 + X3) < 2))
Y3 = ((X1 + X2 + X3) > 2).float()
# 整合数据集
Data = torch.cat([X1, X2, X3, Y1, Y1, Y2, Y3], axis=1)
# Data = Data.to('cuda: 0 ')
# 划分训练集和测试集
train_size = int(len(Data) * 0.7)
test_size = len(Data) - train_size
Data = Data[torch.randperm(Data.size(0)), :]
train_Data = Data[:train_size, :]
test_Data = Data[train_size:, :]

训练网络

神经网络的训练过程,就是经过很多次前向传播与反向传播的轮回,最终不断调整其内部参数(权重 ω 与偏置 b),以拟合任意复杂函数的过程。内部参数一开始是随机的(如 Xavier 初始值、He 初始值),最终会不断优化到最佳。还有一些训练网络前就要设好的外部参数:网络的层数、每个隐藏层的节点数、每个节点的激活函数类型、学习率、轮回次数、每次轮回的样本数等等。
业界习惯把内部参数称为参数,外部参数称为超参数。

# 定义DNN类
class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(3, 5), nn.ReLU(),
            nn.Linear(5, 5), nn.ReLU(),
            nn.Linear(5, 5), nn.ReLU(),
            nn.Linear(5, 3)
        )
    def forward(self, x):
        y = self.net(x)
        return y
# 创建子类的实例
model = DNN()
# 损失函数
loss_fn = nn.MSELoss()
# 优化算法
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 训练网络
epochs = 100
losses = []
# 给训练集划分输入和输出
X = train_Data[:, :3]
Y = train_Data[:, -3:]
for epoch in range(epochs):
    Pred = model(X)
    loss = loss_fn(Pred, Y)
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Fig = plt.figure()
plt.plot(range(epochs),losses)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

测试网络

为了防止训练的网络过拟合,因此需要拿出少量的样本进行测试。过拟合的意思是:网络优化好的内部参数只能对训练样本有效,换成其它就寄。当网络训练好后,拿出测试集的输入,进行 1 次前向传播后,将预测的输出与测试集的真实输出进行对比,查看准确率。

# 测试网络
X = test_Data[:, :3]
Y = test_Data[:, -3:]
with torch.no_grad():
    Pred = model(X)
    Pred[:, torch.argmax(Pred, axis=1)] = 1
    Pred[Pred != 1] = 0
    correct = torch.sum((Pred == Y).all(1))
    total = Y.size(0)
    print(f'测试集准确度:{100*correct/total}%')

使用网络

真正使用网络进行预测时,样本只知输入,不知输出。直接将样本的输入进行 1 次前向传播,即可得到预测的输出。

# 保存网络
torch.save(model, 'DNN.path')
new_model = torch.load('DNN.path')

完整代码

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 生成数据集
X1 = torch.rand(10000, 1)
X2 = torch.rand(10000, 1)
X3 = torch.rand(10000, 1)
Y1 = ((X1 + X2 + X3) < 1).float()
Y2 = ((1 < (X1 + X2 + X3)) & ((X1 + X2 + X3) < 2))
Y3 = ((X1 + X2 + X3) > 2).float()
# 整合数据集
Data = torch.cat([X1, X2, X3, Y1, Y1, Y2, Y3], axis=1)
# Data = Data.to('cuda: 0 ')
# 划分训练集和测试集
train_size = int(len(Data) * 0.7)
test_size = len(Data) - train_size
Data = Data[torch.randperm(Data.size(0)), :]
train_Data = Data[:train_size, :]
test_Data = Data[train_size:, :]

# 定义DNN类
class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(3, 5), nn.ReLU(),
            nn.Linear(5, 5), nn.ReLU(),
            nn.Linear(5, 5), nn.ReLU(),
            nn.Linear(5, 3)
        )
    def forward(self, x):
        y = self.net(x)
        return y
# 创建子类的实例
model = DNN()
# 损失函数
loss_fn = nn.MSELoss()
# 优化算法
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 训练网络
epochs = 100
losses = []
# 给训练集划分输入和输出
X = train_Data[:, :3]
Y = train_Data[:, -3:]
for epoch in range(epochs):
    Pred = model(X)
    loss = loss_fn(Pred, Y)
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Fig = plt.figure()
plt.plot(range(epochs),losses)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

# 测试网络
X = test_Data[:, :3]
Y = test_Data[:, -3:]
with torch.no_grad():
    Pred = model(X)
    Pred[:, torch.argmax(Pred, axis=1)] = 1
    Pred[Pred != 1] = 0
    correct = torch.sum((Pred == Y).all(1))
    total = Y.size(0)
    print(f'测试集准确度:{100*correct/total}%')

# 保存网络
torch.save(model, 'DNN.path')
new_model = torch.load('DNN.path')

运行截图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2045714.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++_2_nullptr关键字(3/3)

本节内容有C的NULL在前面打头阵&#xff0c;学起来犹如探囊取物。 先来分析一段代码&#xff0c;本段代码恰好也结合了上节的宏。 #include<iostream> using namespace std; void f(int x) { cout << "f(int x)" << endl; } void f(int* ptr) { …

Android Settings 跳转流程

我们知道在Settings中&#xff0c;各模块之间的Fragment基本都继承了DashboardFragment&#xff0c;当有点击事件时&#xff0c;就会回调DashboardFragment中的onPerferenceTreeClick()方法。 在onPreferenceTreeClick()方法中可以根据preference的key做事件拦截&#xff0c;如…

Linux线程实用场景

文章目录 前言生产者消费者模型1.基于阻塞队列特点实现使用 2.基于环形队列和信号量实现使用 读者写者模型实现思想 线程池实现 前言 生产者消费者模型和读者写者模型这些模型是用于在线程间协调和管理资源访问的模式, 我们在之前已经理解了线程的概念以及同步与互斥, 现在我们…

无人机之消费级和工业级,两者区别分析

消费级无人机和工业级无人机在多个方面存在显著差异&#xff0c;这些差异主要体现在搭载设备、应用领域、针对用户、使用条件、性能要求、营销模式以及价格等方面。以下是对两者区别的详细分析&#xff1a; 1. 搭载设备 消费级无人机&#xff1a;主要搭载相机&#xff0c;并配…

C++ | Leetcode C++题解之第337题打家劫舍III

题目&#xff1a; 题解&#xff1a; struct SubtreeStatus {int selected;int notSelected; };class Solution { public:SubtreeStatus dfs(TreeNode* node) {if (!node) {return {0, 0};}auto l dfs(node->left);auto r dfs(node->right);int selected node->val…

Windows禁止应用联网

转自两种方法阻止电脑上的软件彻底联网&#xff01; - 知乎 (zhihu.com) 但为了稳妥&#xff0c;自己还是稍微记录一下 1、创建bat脚本文件 创建文本-将下面的代码填入-保存为.bat文件 Echo Off SetLocal:beginecho: echo ****** 禁止文件夹联网 ****** echo:set /p folder…

Qt报“libpng warning: iCCP: known incorrect sRGB profile”问题解决方法

Qt开发应用程序&#xff0c;界面加载图片或按钮加载图标时&#xff0c;会遇到编译器报“libpng warning: iCCP: known incorrect sRGB profile”问题&#xff0c;原因为色彩配置问题&#xff0c;需要修正图像的ICC配置文件&#xff0c;将其转换成sRGB类型。不同操作系统解决方法…

停车场拓扑(parking lot topology)中的 bbr 与 aimd

bbr 讨论组有个有趣的问题&#xff1a;[bbr-dev] Parking lot topology 我此前也意识到这个问题(参见 pacing 之对错)&#xff0c;但几乎所有 cc 的建模都基于 dumbbell topology&#xff0c;parking lot topology 因其太 “不理想”&#xff0c;“不规则” 而无人讨论&#x…

11.2.软件系统分析与设计-数据库分析与设计

文章目录 数据库分析与设计步骤ER图和关系模型 需求分析阶段概念结构设计逻辑结构设计物理结构设计数据库实施与运维 数据库分析与设计 数据库设计属于系统设计的范畴。通常把使用数据库系统的系统统称为数据库应用系统&#xff0c;把对数据库应用系统的设计简称为数据库设计。…

轻松拿捏自动添加好友

释放双手&#xff0c;一键导入数据&#xff01; 通过好友后可以自动备注 轻松自动添加好友&#xff0c;更可以个性化设置验证信息 手动点击“开始”&#xff0c;后台可以看到数据使用情况和添加情况&#xff0c;频繁了会自动停止

【STM32】ADC模拟数字转换(规则组多通道)+ DMA数据转运(外设到存储器)

本篇博客重点在于标准库函数的理解与使用&#xff0c;搭建一个框架便于快速开发 目录 前言 ADC规则组扫描模式DMA 定义变量 规则组配置 ADC初始化 连续模式 扫描模式 规则组通道个数 ADC初始化框架 DMA初始化 ADC和DMA使能 软件触发转运 代码框架 ADC扫描转换与DM…

一眼心动的HAProxy高级功能配置

目录 一.haproxy-基于cookie的会话保持 二.七层IP透传 三.四层IP透传 四.访问控制列表ACL 五.acl做动静分离访问控制 六.基于自定义的错误页面文件 七.HAProxy 四层负载 八.HAProxy https 实现 九.让文件编写更简单的方法 一.haproxy-基于cookie的会话保持 cookie va…

C语言程序设计(初识C语言后部分)

1024M1GB&#xff0c;1GB1级棒。关爱一级棒的程序员们&#xff0c;宠TA没商量&#xff01; 5&#xff09;函数的嵌套调用和链式访问 函数和函数之间可以根据实际的需求进行组合的&#xff0c;也就是相互调用的。 1.嵌套调用 函数可以嵌套调用&#xff0c;但不可以嵌套定义&a…

【网络】UDP和TCP之间的差别和回显服务器

文章目录 UDP 和 TCP 之间的差别有连接/无连接可靠传输/不可靠传输面向字节流/面向数据报全双工/半双工 UDP/TCP API 的使用UDP APIDatagramSocket构造方法方法 DatagramPacket构造方法方法 回显服务器&#xff08;Echo Server&#xff09;1. 接收请求2. 根据请求计算响应3. 将…

html+css+js网页制作 纳尔多珠宝40个页面

htmlcssjs网页制作 纳尔多珠宝40个页面 网页作品代码简单&#xff0c;可使用任意HTML编辑软件&#xff08;如&#xff1a;Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改编辑等操作&#xff09;。 获取源码 1&#…

用python制作88键赛博钢琴(能用鼠标键盘进行弹奏)

用python制作88键赛博钢琴 前言 恭喜这位博主终于想起了自己的账号密码&#xff01; 时光荏苒&#xff0c;转眼间已逾一年未曾在此留下墨香。尽管这一年间&#xff0c;博主投身于无尽的忙碌与挑战之中&#xff0c;但令人欣慰的是&#xff0c;那份初心与热情似乎并未因岁月的流…

谷歌浏览器网页底图设置为全黑

输入网址&#xff1a;chrome://flags/ 搜索dark&#xff0c;选择Enabled&#xff0c;重启浏览器即可

C#使用SharGL实现PUMA560机械臂

1、四轴机械臂 下载链接&#xff1a;https://download.csdn.net/download/panjinliang066333/89645225 关键代码 public void DrawRobot1(ref OpenGL gl,float[] angle,float[] yLength,bool isPuma560_Six){//坐标系说明&#xff1a;//①X轴正向&#xff1a;屏幕朝右//②Y轴…

【运维系列】windows虚拟机作为服务器,将服务启动作为脚本设置为开机自启,服务中断、手动操作的烦恼通通滚蛋!

文章目录 前言一、开机启动文件夹&#xff08;StartUp&#xff09;是否可行&#xff1f;二、任务计划程序1.编写脚本2.打开任务计划程序3.创建任务4.配置常规选项5.配置触发器选项6. 配置操作选项7.配置条件选项8.配置设置选项 总结 前言 在实际应用过程中&#xff0c;我们难免…

有没有电脑桌面监控软件|大佬都在用的7大电脑屏幕监控软件!

当谈到电脑桌面监控软件时&#xff0c;确实有许多受欢迎且功能强大的选项。 这些软件在企业管理、远程办公、家庭监控等多个领域都有广泛应用。 以下是大佬常用的7大电脑屏幕监控软件推荐&#xff1a; 1. Teramind 特点&#xff1a;它是一款功能强大的企业级监控软件&#x…