搭建一个简单的深度神经网络

news2024/12/27 14:30:05

目录

一、引入所需要的库

二、制作数据集

三、搭建神经网络

四、训练网络

五、测试网络

本博客实验环境为jupyter

一、引入所需要的库

torch库是核心,其中torch.nn 提供了搭建网络所需的所有组件,nn即神经网络。matplotlib类似与matlab,其中pyplot用于进行数据可视化,如绘制图表、曲线等。%matplotlib inline: 这是IPython(Jupyter Notebook)的魔法命令,用于在Notebook中直接显示Matplotlib绘制的图表,而不是弹出一个新窗口显示。

import torch 
import torch.nn as nn 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
%matplotlib inline

# 展示高清图 
from matplotlib_inline import backend_inline #导入Matplotlib库中的backend_inline模块,用于控制图表的显示方式。
backend_inline.set_matplotlib_formats('svg') #设置Matplotlib图表的显示格式为SVG格式,SVG格式的图表在显示时具有高清晰度,适合用于展示精细的图形。

二、制作数据集

主要任务是读取数据集,划分为训练集和测试集,一定要随机划分。

读取的数据集中共760组数据,共8个输入特征,1个输出特征。

其中第一列是索引,从0开始,70%为训练集,30%为测试集。

#读取数据
df = pd.read_csv('Data.csv', index_col=0)#之前的pandas库中有介绍到,即df为读取后的对象,以第一列为索引    
arr = df.values #转化为numpy数组              
arr = arr.astype(np.float32)#转化为深度学习常用的单精度浮点类型    
ts = torch.tensor(arr)#转化为张量tensor         
ts = ts.to('cuda')#送到cuda设备上即gpu上计算             

# 划分训练集与测试集 
train_size = int(len(ts) * 0.7) #训练集的大小为百分之七十          
test_size = len(ts) - train_size #测试集的大小为百分之三十         
ts = ts[ torch.randperm( ts.size(0) ) , : ] #随机打乱数据集中样本的顺序    
train_Data = ts[ : train_size , : ] #将前百分之七十行给训练集       
test_Data = ts[ train_size : , : ]  #将百分之七十后的行给测试集        

三、搭建神经网络

主要是构建DNN类,需要对python的类定义有较为深入的理解能力。

class DNN(nn.Module): #定义了一个名为DNN的PyTorch模型类,该类继承自nn.Module类,表示这是一个神经网络模型。
 
    def __init__(self): #定义了模型的初始化方法
        ''' 搭建神经网络各层 ''' 
        super(DNN,self).__init__() #调用父类的初始化方法,确保模型的其他部分也能够被正确初始化。
        self.net = nn.Sequential(            # 按顺序搭建各层 
            nn.Linear(8, 32), nn.Sigmoid(),   # 第1层:全连接层 ,是一个包含32个神经元的全连接层,输入特征数为8(表示输入数据维度为8),并使用Sigmoid激活函数。
            nn.Linear(32, 8), nn.Sigmoid(),   # 第2层:全连接层 ,是一个包含8个神经元的全连接层,输入特征数为32,同样使用Sigmoid激活函数。
            nn.Linear(8, 4), nn.Sigmoid(),    # 第3层:全连接层 ,是一个包含4个神经元的全连接层,输入特征数为8,同样使用Sigmoid激活函数。
            nn.Linear(4, 1), nn.Sigmoid()    # 第4层:全连接层 ,是一个包含1个神经元的全连接层,输入特征数为4,同样使用Sigmoid激活函数。这是模型的输出层。
        ) 
 
    def forward(self, x): 
        ''' 前向传播 ''' 
        y = self.net(x)    # 将输入数据x通过模型定义的神经网络结构self.net进行前向传播计算,得到输出y。
        return y        # y即输出数据
 model = DNN().to('cuda:0')    # 创建子类的实例,并搬到GPU上 

这个代码可以当做模板,其中需要修改的部分为网络层的搭建,输入特征,中间层,输出特征一般都要为2的n次幂。

这就是该实例的各层。

四、训练网络

通过前向传播,反向传播等操作,本质上是不断调整权重和偏置。

loss_fn = nn.BCELoss(reduction='mean')#选择二元交叉熵损失函数作为模型的损失函数,其中reduction='mean'表示采用平均损失值作为最终的损失值。
learning_rate = 0.005    # 设置学习率 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 
# 训练网络 
epochs = 5000     #设置训练的总轮数为5000。
losses = []        # 记录损失函数变化的列表 
 
# 给训练集划分输入与输出 
X = train_Data[ : , : -1 ]                # 前8列为输入特征 
Y = train_Data[ : , -1 ].reshape((-1,1))    # 后1列为输出特征
for epoch in range(epochs):     #对于每个epoch进行循环。
    Pred = model(X)              #通过模型进行一次前向传播,得到模型的预测结果Pred。
    loss = loss_fn(Pred, Y)        #计算模型预测结果与实际标签之间的损失值。
    losses.append(loss.item())     #将当前轮次的损失值记录到losses列表中。
    optimizer.zero_grad()        #清空上一轮的梯度信息。
    loss.backward()             #进行反向传播,计算梯度。
    optimizer.step()             #根据优化算法更新模型的参数,完成一轮训练。

#绘制损失函数随训练轮次变化的图像,用于可视化训练过程中损失值的变化。
Fig = plt.figure() 
plt.plot(range(epochs), losses) 
plt.ylabel('loss') 
plt.xlabel('epoch') 
plt.show() 

生成结果为

可以发现随着训练的进行,loss开始减少。

五、测试网络

通过用训练好的模型对测试集进行测试,由于只有一个输出特征为0或者1,将大于0.5的置为1,小于0.5的置为0,可以类比成可能性从0到1。

# 测试网络 
# 给测试集划分输入与输出 
X = test_Data[ : , : -1 ]                
Y = test_Data[ : , -1 ].reshape((-1,1))    
with torch.no_grad():    #进入上下文管理器,表示接下来的计算不会被记录在计算图中,因此不会影响梯度的计算。
    Pred = model(X)     
    Pred[Pred>=0.5] = 1 
    Pred[Pred<0.5] = 0 
    correct = torch.sum( (Pred == Y).all(1) )    #统计预测正确的样本数,使用.all(1)表示在第1维度(即行)上进行比较,得到一个布尔张量,再进行求和操作。
    total = Y.size(0)   #获取试集样本总数。
    print(f'测试集精准度: {100*correct/total} %') 

一般精准度得百分之八九十才合格哦,所以精度不高很有可能是训练集或者环境的问题,所以训练前一定要做好准备工作,因为训练一个模型要花费很久时间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1814529.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Apple Intelligence全面来袭,熟悉但又不同的味道

大模型技术论文不断&#xff0c;每个月总会新增上千篇。本专栏精选论文重点解读&#xff0c;主题还是围绕着行业实践和工程量产。若在某个环节出现卡点&#xff0c;可以回到大模型必备腔调或者LLM背后的基础模型新阅读。而最新科技&#xff08;Mamba,xLSTM,KAN&#xff09;则提…

618购物节入手哪些数码好物好?年度必备好物清单大盘点

随着一年一度的618购物节的到来&#xff0c;数码市场再次掀起了热潮&#xff0c;在这个属于消费者的狂欢节里&#xff0c;各大品牌和商家纷纷推出优惠活动和新品&#xff0c;为数码爱好者们带来了无数的购物选择&#xff0c;那么在这个购物盛宴中&#xff0c;我们应该如何挑选那…

如何进行论文查重,选择合适的查重系统?

原创性是学术写作海洋中的航行灯塔&#xff0c;而论文查重&#xff08;www.check110.com&#xff09;则是保障这束光芒不被云雾遮蔽的工具。而查重系统如何对论文进行查重&#xff0c;又该如何选择论文查重系统呢&#xff1f; 一、论文查重 论文查重&#xff0c;就是检测学术…

Python基础教程(十三):file文件及相关的函数

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; &#x1f49d;&#x1f49…

idea中使用逆向工程生成数据库表的实体类

1、在idea中打开数据库视图&#xff1b; 2、点击database中的号创建数据源连接&#xff08;以MySQL为例&#xff09;&#xff1b; 填入账户密码以及数据库名&#xff1b; 点击测试连接&#xff0c;若出现爆红Server returns invalid timezone. Go to Advanced tab and set serv…

RawChat:优化AI对话体验,全面兼容GPT功能平台

文章目录 一、Rawchat简介1.1 RawChat的主要特性1.2 RawChat的技术原理简述 二、使用教程三、案例应用3.1 图片内容分析3.2 生图演示3.3 文档解析3.4 探索更多 四、小结 一、Rawchat简介 RawChat平台的诞生&#xff0c;其核心理念是降低用户访问类似ChatGPT这类先进AI服务的门…

MySQL复习题(期末考试)

MySQL复习题&#xff08;期末考试&#xff09; 1.MySQL支持的日期类型&#xff1f; DATE,DATETIME,TIMESTAMP,TIME,TEAR 2.为表添加列的语法&#xff1f; alter table 表名 add column 列名 数据类型; 3.修改表数据类型的语法是&#xff1f; alter table 表名 modify 列名 新…

文心智能体体验,打造你自己的GPTs应用

利用百度智能体搭建的《RPG冒险游戏大作战》已经发布啦&#xff01; RPG冒险游戏大作战 玩家扮演一位小小勇士女孩&#xff0c;从被巨龙毁灭的冒险小镇出发&#xff0c;一路披荆斩棘&#xff0c;集齐四件神器后&#xff0c;打败巨龙&#xff0c;夺回小镇的安宁&#xff01; 整…

python3的基本语法说明一

一. 简介 本文开始学习 python3 的基本语法。 二. python3的基本语法 1. 编码 默认情况下&#xff0c;Python 3 源码文件以 UTF-8 编码&#xff0c;所有字符串都是 unicode 字符串。 当然你也可以为源码文件指定不同的编码&#xff1a; # -*- coding: cp-1252 -*- 上述…

Unity图集

概述 相信在同学们学习过程中&#xff0c;在UI的的使用时候一定经常听说过图集的概念。 Unity有UI的组件&#xff0c;有同学们好奇&#xff0c;那为什么还要使用图集呢&#xff1f; 这就需要提到一个性能优化的问题了&#xff0c;因为过多的UI图片&#xff0c;会大幅增加Dra…

隔离式 AC-DC 反激电源设计原理分析

LinkSwitch-LP 系列旨在取代手机/无绳电话、PDA、数码相机和便携式音频播放器等应用中输出功率 < 2.5 W 的低效线频线性变压器电源。LinkSwitch-LP 还可用作白色家电等应用中的辅助电源。 LinkSwitch-LP 将高压功率 MOSFET 开关与 ON/OFF 控制器集成在一个设备中。它完全由…

Vue 路由传递参数 query、params

1、to的对象写法,绑定参数 <template> 2 <ul> 3 <li v-for"m in messlist" :key"m.id"> 4 <router-link :to"{ //使用params时&#xff0c;这个路径必须用name及别名......name: xiangqing, path: /bbb/message/deta…

自动驾驶#芯片-1

概述 汽车是芯片应用场景之一&#xff0c;汽车芯片需要具备车规级。  车规级芯片对加工工艺要求不高&#xff0c;但对质量要求高。需要经过的认证过程&#xff0c;包括质量管理标准ISO/TS 16949、可靠性标准 AEC-Q100、功能安全标准ISO26262等。  汽车内不同用途的芯片要求…

批量替换删除图片文件名称中相同数字:轻松管理文件结构新技巧大揭秘

特别是当图片文件名称中包含相同的数字时&#xff0c;想要快速找到或整理这些文件更是难上加难。今天&#xff0c;我要向大家揭秘一种轻松管理图片文件结构的新软件——文件批量改名高手。 进入“文件批量改命名高手”主页面&#xff0c;你会看到一个简洁明了的操作界面。在板…

聚焦新版综合编程能力面试考查汇总

目录 一、业务性编程和广度能力考查 &#xff08;一&#xff09;基本定义 &#xff08;二&#xff09;必要性分析 二、高频考查样题&#xff08;编程扩展问法&#xff09; 考题1: 用java 代码实现一个死锁用例&#xff0c;说说怎么解决死锁问题&#xff1f;&#xff08;高…

Python 组内序号

模仿SQL的row_number() over (partition by column order by column) import pandas as pd # 创建一个示例数据框 data { group: [A, A, A, B, B, C, C, C, C], value: [3, 1, 2, 5, 4, 6, 9, 7, 8] } df pd.DataFrame(data) # 先按group分组&#xff0c;再按val…

eclipse导入Tomcat9源码

环境准备 下载Tomcat源码 https://github.com/apache/tomcat/tagsJDK版本 Tomcat9要求JDK17以上版本 https://www.oracle.com/java/technologies/javase/jdk17-archive-downloads.htmlAnt安装 https://ant.apache.org/bindownload.cgi我这里装的是apache-ant-1.10.14版本 …

Ubuntu系统调试分析工具

文章目录 一、火焰图一、下载 FlameGraph二、安装 iperf三、使用二、Lockdep1、内核开启 Lockdep 配置2、判断 Lockdep 开启是否成功一、火焰图 一、下载 FlameGraph git clone https://github.com/brendangregg/FlameGraph.gitFlameGraph 介绍:   基本思想是将程序的函数…

便民智慧小程序源码系统 同城信息+商家联盟+生活电商 功能强大 带完整的安装代码包以及搭建部署教程

系统概述 便民智慧小程序源码系统是一个高度集成化的本地化服务平台解决方案&#xff0c;它融合了同城信息发布、商家联盟管理和生活电商平台三大核心模块&#xff0c;旨在打造一个全方位、多维度的生活服务生态系统。该系统采用先进的前后端分离架构&#xff0c;支持快速响应…

Redux 与 MVI:Android 应用的对比

Redux 与 MVI&#xff1a;Android 应用的对比 在为 Android 应用选择合适的状态管理架构时可能会感到困惑。在这个领域中&#xff0c;有两种流行的选择是 Redux 和 MVI&#xff08;Model-View-Intent&#xff09;。两者都有各自的优缺点&#xff0c;因此在深入研究之前了解它们…