基于神将网络方式进行数据回归拟合实例

news2024/11/18 7:50:43

前言

本篇博客主要以神经网络拟合数据这个简单例子讲起,然后介绍网络的保存与读取,以及快速新建网络的方法。

一、神经网络对数据进行拟合

import torch
from matplotlib import pyplot as plt
import torch.nn.functional as F


# 自定义一个Net类,继承于torch.nn.Module类
# 这个神经网络的设计是只有一层隐含层,隐含层神经元个数可随意指定
class Net(torch.nn.Module):
    # Net类的初始化函数
    def __init__(self, n_feature, n_hidden, n_output):
        # 继承父类的初始化函数
        super(Net, self).__init__()
        # 网络的隐藏层创建,名称可以随便起
        self.hidden_layer = torch.nn.Linear(n_feature, n_hidden)
        # 输出层(预测层)创建,接收来自隐含层的数据
        self.predict_layer = torch.nn.Linear(n_hidden, n_output)

    # 网络的前向传播函数,构造计算图
    def forward(self, x):
        # 用relu函数处理隐含层输出的结果并传给输出层
        hidden_result = self.hidden_layer(x)
        relu_result = F.relu(hidden_result)
        predict_result = self.predict_layer(relu_result)
        return predict_result


# 训练次数
TRAIN_TIMES = 300
# 输入输出的数据维度,这里都是1维
INPUT_FEATURE_DIM = 1
OUTPUT_FEATURE_DIM = 1
# 隐含层中神经元的个数
NEURON_NUM = 32
# 学习率,越大学的越快,但也容易造成不稳定,准确率上下波动的情况
LEARNING_RATE = 0.1

# 数据构造
# 这里x_data、y_data都是tensor格式,在PyTorch0.4版本以后,也能进行反向传播
# 所以不需要再转成Variable格式了
# linspace函数用于生成一系列数据
# unsqueeze函数可以将一维数据变成二维数据,在torch中只能处理二维数据
x_data = torch.unsqueeze(torch.linspace(-4, 4, 80), dim=1)
# randn函数用于生成服从正态分布的随机数
y_data = x_data.pow(3) + 3 * torch.randn(x_data.size())
y_data_real = x_data.pow(3)

# 建立网络
net = Net(n_feature=INPUT_FEATURE_DIM, n_hidden=NEURON_NUM, n_output=OUTPUT_FEATURE_DIM)
print(net)

# 训练网络
# 这里也可以使用其它的优化方法
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
# 定义一个误差计算方法
loss_func = torch.nn.MSELoss()

for i in range(TRAIN_TIMES):
    # 输入数据进行预测
    prediction = net(x_data)
    # 计算预测值与真值误差,注意参数顺序问题
    # 第一个参数为预测值,第二个为真值
    loss = loss_func(prediction, y_data)

    # 开始优化步骤
    # 每次开始优化前将梯度置为0
    optimizer.zero_grad()
    # 误差反向传播
    loss.backward()
    # 按照最小loss优化参数
    optimizer.step()

    # 可视化训练结果
    if i % 2 == 0:
        # 清空上一次显示结果
        plt.cla()
        # 无误差真值曲线
        plt.plot(x_data.numpy(), y_data_real.numpy(), c='blue', lw='3')
        # 有误差散点
        plt.scatter(x_data.numpy(), y_data.numpy(), c='orange')
        # 实时预测的曲线
        plt.plot(x_data.numpy(), prediction.data.numpy(), c='red', lw='2')
        plt.text(-0.5, -65, 'Time=%d Loss=%.4f' % (i, loss.data.numpy()), fontdict={'size': 15, 'color': 'red'})
        plt.pause(0.1)

训练300次的可视化效果如下。

二、模型的保存与读取

模型保存比较简单,直接调用torch.save()函数即可。 有两种方式可以保存网络,一种是直接保存整个网络,另一种则是只保存网络中节点的参数。代码如下。

# 保存整个网络
torch.save(net,'net.pkl')
# 只保存网络中节点的参数
torch.save(net.state_dict(),'net_params.pkl')

保存好网络后就是载入网络,相对也比较简单。对于第一种保存整个网络的方式而言,直接torch.load()即可。 对于第二种方式保存的网络,则需要先建立一个和之前结构一模一样的网络,然后再将保存的参数载入进来。代码如下。

# 直接装载网络
net_restore=torch.load('net.pkl')
# 先新建个一模一样的网络,再载入参数
net_rebuild=Net(n_feature=INPUT_FEATURE_DIM,n_hidden=NEURON_NUM,n_output=OUTPUT_FEATURE_DIM)
net_rebuild.load_state_dict(torch.load('net_params.pkl'))

两种方法第二种保存、恢复效率会更高一些,尤其是网络很大很复杂的时候。

三、模型快速搭建

主要利用torch.nn.Sequntial()函数实现而无需新建一个class,示例代码如下。

net2=torch.nn.Sequential(
    torch.nn.Linear(INPUT_FEATURE_DIM,NEURON_NUM),
    torch.nn.ReLU(),
    torch.nn.Linear(NEURON_NUM,OUTPUT_FEATURE_DIM))
print(net2)

利用上述代码便可以搭建一个和第一部分一样的网络。唯一有所区别的是这里每一层是没有名字的,以序号标出。 而在之前我们定义层的时候,便指定了每一层的名字。除了这点区别外,其它没有区别了。 但是相比于第一种方法却简单很多了,不用定义类,也不用写初始化和前向传播函数,十分方便。

四、总计

以上便是本篇博客的主要内容,介绍了神经网络拟合数据的例子,以及网络的保存、恢复。 最后介绍了模型的快速新建方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/148384.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Diffusion model(二): 训练推导详解

接上文 Diffusion的训练推导 1. 最小化负对数似然与变分下界 在弄懂diffusion model前向和反向过程之后,最后我们需要了解其训练推导过程,即用什么loss以及为什么。在diffusion的反向过程中,根据(3)(3)(3)式我们需要预测μθ(xt,t),Σθ(x…

【Linux】进程状态和进程优先级

文章目录1. 进程状态2. Linux的进程状态3. 僵尸进程4. 孤儿进程5. 进程优先级1. 进程状态 为了更深入地了解进程,我们需要知道进程的不同状态。 不同的操作系统,对于进程状态有着不同的说法,如:运行、阻塞、挂起、新建、就绪、等…

SIoU Loss

1、论文 题目:《SIoU Loss: More Powerful Learning for Bounding Box Regression》 参考博客: https://blog.csdn.net/qq_56749449/article/details/125753992 2、原理 有关IoU损失函数,像GIoU、DIoU、CIoU都没有考虑真实框与预测框之间的…

关于zookeeper和kafka不得不说的秘密

zookeeper简介1. zookeeper的概述ZooKeeper是一个分布式的,开放源码的分布式应用程序协调服务,是Google的Chubby一个开源的实现,是Hadoop和Hbase的重要组件。它是一个为分布式应用提供一致性服务的软件,提供的功能包括&#xff1a…

【四】Netty 分隔符和定长解码器的应用

Netty 分隔符和定长解码器的应用理论说明LineBasedFrameDecoder 开发大概流程代码展示netty 依赖EchoServer 服务端启动类EchoServerHandlerEchoClientEchoClientHandler结果打印客户端打印服务端打印FixedLengthFrameDecoder 开发代码展示EchoServer 服务端启动类EchoFixServe…

【云原生】k8s之pod控制器

内容预知 前言 1.pod控制器的相关知识 1.1 pod控制器的作用 1.2 pod控制器的多种类型 1.3 pod容器中的有状态和无状态的对比 (1)有状态实例 (2)无状态实例 2.Deployment控制器 2.1 SatefulSet 控制器的运用 2.1 Sateful…

从0到1完成一个Vue后台管理项目(六、404页)

往期 从0到1完成一个Vue后台管理项目(一、创建项目) 从0到1完成一个Vue后台管理项目(二、使用element-ui) 从0到1完成一个Vue后台管理项目(三、使用SCSS/LESS,安装图标库) 从0到1完成一个Vu…

[LeetCode周赛复盘] 第 95 场周赛20230107

[LeetCode周赛复盘] 第 95 场周赛20230107 一、本周周赛总结二、 [Easy] 2525. 根据规则将箱子分类1. 题目描述2. 思路分析3. 代码实现三、[Medium] 2526. 找到数据流中的连续整数![在这里插入图片描述](https://img-blog.csdnimg.cn/237210adb20e457aaf2671e6e8f9e43b.png)2. …

Linux系统中C++多态和数据封装的基本方法

大家好,今天主要和大家分享一下,多态,数据封装的使用方法。 目录 第一:C中的多态 第二:C中数据封装方法 第一:C中的多态 C多态意味着调用成员函数时,会根据调用函数的对象的类型来执行不同的函…

将内核加载到内存

文章目录前言前置知识代码实验操作前言 本博客记录《操作系统真象还原》第五章第3个实验的操作~ 实验环境:ubuntu18.04VMware , Bochs下载安装 实验内容:将内核载入内存,初始化内核代码 实验原理 编写内核程序。将内核程序用dd命令复制到…

Odoo 16 企业版手册 - 库存管理之存储类别

存储类别 Odoo中的存储类别功能将允许您将许多存储位置分组到一个类别下。您可以在Odoo 库存管理模块中创建许多此类类别,这将有助于执行更智能的放置操作。在配置存储类别之前,您必须配置库存中可用的存储位置。然后,您可以将它们分组到一个…

LeetCode刷题模版:31 - 40

目录 简介31. 下一个排列32. 最长有效括号33. 搜索旋转排序数组34. 在排序数组中查找元素的第一个和最后一个位置35. 搜索插入位置36. 有效的数独37. 解数独38. 外观数列39. 组合总和40. 组合总和 II结语简介 Hello! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指…

电影解说开头怎么写吸引人?

电影解说开头怎么写吸引人?很多电影解说创作者文采不够好,开头不知道怎么写?毕竟想留住用户继续观看视频,开头是至关重要的,今天笔者就分享电影解说文案万能公式模板,让大家创作更简单!一个好的…

feature engnineering 特征工程

特征工程数值型变量standardizationlog_transformation(使其符合正态分布)polynomial features分类型变量orinigalencoderonehot encoder分类创造下的数值以下代码根据Abhishek Thakur在kaggle上的机器学习30天 (b站) (kaggle)可惜的是,我没有…

Oracle 19c VLDB and Partitioning Guide 第5章:管理和维护基于时间的信息 读书笔记

本文为Oracle 19c VLDB and Partitioning Guide第5章Managing and Maintaining Time-Based Information的读书笔记。 Oracle 数据库提供了基于时间管理和维护数据的策略。 本章讨论 Oracle 数据库中的组件,这些组件可以构建基于时间管理和维护数据的策略。 尽管大…

计算机网络复习之网络层

文章目录数据报与虚电路服务的对比IP 协议IP数据报格式IP地址NAT(网络地址转换)子网划分和子网掩码在支持子网划分的因特网中,路由器如何转发IP数据报无分类编制CIDR构成超网RIP协议OSPF协议ARP协议ICMP协议Ping和Traceroute参考路由选择是网…

Eclipse安装教程

Eclipse安装教程 目录一. 概述二. 下载eclipse三. 安装eclipse四. 使用eclipse。一. 概述 eclipse是针对java编程的集成开发环境,其设计思想是“一切皆插件”。就其本身而言,eclipse只是一个框架…

Hive表的创建,删除,修改

TBLPROPERTIES的主要作用是按键-值对的格式为表增加额外的文档说明。Hive会自动增加两个表属性:一个是last_modified_by,其保存着最后修改这个表的用户的用户名﹔另一个是 last_modified_time,其保存着最后一次修改的新纪元时间秒。用户还可以拷贝一张已…

数据的合并和分组聚合

一:字符串离散化的案列 对于这一组电影数据,如果我们希望统计电影分类(genre)的情况,应该如何处理数据? 思路:重新构造一个全为0的数组,列名为分类,如果某一条数据中分类出现过,就让…

Java之class类

Class类 1.类图 2.Class类对象 系统创建 该class对象是通过类加载器ClassLoader的loadClass()方法生成对应类对应的class对象 通过debug可以追到该方法 3.对于某个类的class类对象 只加载一次 因为类值加载一次 类加载的时机 //1.创建对象实例的时候(new&#xf…