PyTorch深度学习笔记之五(使用神经网络拟合数据)

news2025/1/17 16:16:58

使用神经网络拟合数据

1. 人工神经网络

1.1 神经网络和神经元

神经网络:一种通过简单函数的组合来表示复杂函数的数学实体。
人工神经网络和生理神经网络似乎都使用模糊相似的数学策略来逼近复杂的函数,因为这类策略非常有效。

这些复杂函数的基本构件是神经元。其核心就是给输入做一个线性变换(如乘以一个权重再加上一个常数作为偏置),然后应用一个固定的非线性函数,即激活函数。
比如: o = t a n h ( w x + b ) o = tanh(wx+b) o=tanh(wx+b)

  • w和b就是要学习的参数,wx+b 就是线性变换,
  • t a n h tanh tanh (双曲正切函数) 就是激活函数,也是一个非线性函数。

所以,神经元就是一个包含在非线性函数中的线性变化。

从数学上,单个神经元可以写成 o = f ( w x + b ) o=f(wx+b) o=f(wx+b),

  • f 就是激活函数;
  • x 和 o 可以是简单的 标量(即0维张量)向量(即1维张量)
  • w, b, x 都可以是标量或向量,但它们应当一致;
  • 当w和b代表向量时,f(wx+b) 代表了一层神经元(可理解为一列),因为它通过多维权重和偏置来表示多个神经元。

1.2 组成一个多层网络

一个多层的神经网络大致如下:

x1 = f0(w0 * x + b0)
x2 = f1(w1 * x1 + b1)
x3 = f2(w2 * x2 + b2)

y = fn(wn * xn + bn)

前一层神经元的输出,被用作后一层神经元的输入。
使用向量可以让 w 承载整个神经元层,而不是单一的权重。
(注:w是向量,代表一层神经元,即一列神经元,每个神经元上的权重可能不同。)
o = t a n h ( w n ( . . . t a n h ( w 1 ( t a n h ( w 0 ∗ x + b 0 ) + b 1 ) ) . . . ) + b n ) o = tanh(wn(...tanh(w1(tanh(w0*x+b0)+b1))...)+bn) o=tanh(wn(...tanh(w1(tanh(w0x+b0)+b1))...)+bn)

1.3 激活函数

激活函数是非线性函数。它有2个作用。

  1. 在模型的内部,它允许输出函数在不同的值上有不同的斜率。这是线性函数无法做到的。通过巧妙地设置不同的斜率,神经网络可以近似任意函数。
  2. 在网络的最后一层,它可以将前面的线性运算的输出集中到给定的范围内。

下面看这第2个作用的几个具体例子。

  1. 限制输出范围
    比如,限制在 [a, b] 这样的范围。这可以调用 torch.nn.Hardtanh()的简单激活函数。注意,默认的范围是 [-1, 1]
    这里“限制”的意思是,强行设置。比如大于1的就强行设为1.

  2. 压缩输出范围
    这一类的代表有 torch.nn.Sigmoid(), torch.tanh(), 1/(1+e**-x)等。
    效果类似于,x趋于负无穷大时,y趋于0或-1; x趋于正无穷大时,y趋于1.

举个例子,识别一张图片是否是狗的图片。
结果越接近-1,越不可能是狗;越接近1,越可能是狗;0附近的不易判别。
给出车辆、狗、熊,一共3张图片。

  • 车辆在倒数第2层输出 -2.2, 那么,math.tanh(-2.2)= -0.97
  • 熊在倒数第2层输出0.1, 那么,math.tanh(0.1)= 0.09
  • 狗在倒数第2层输出2.5, 那么, math.tanh(2.5)= 0.98

关于更多的激活函数的信息,请参见以下的图片部分。
Tanh
Hardtanh
Sigmoid
ReLU
Leaky ReLU
Softplus

2. 完成一个神经网络

我们仍然使用前面讲过的温度计的例子,来完成一个神经网络。之前的神经网络是只有一层的、线性的神经网络;这里我们来做一个3层的神经网络,因为将包含激活函数。
为保证文章的完整性,再将问题描述如下:
假设有一个温度计,它本没有刻度和单位;我们给它标上刻度,然后用摄氏温度来解释这个特殊温度计测量值;即,给定特殊温度计的读数,将其翻译成摄氏度。

2.1 PyTorch 的 nn 模块

首先,PyTorch 有一个专门用于神经网络的子模块,叫做 torch.nn, 它包含创建各种神经网络结构所需的构建块。(在PyTorch中,这些构建块称为模块;在其他框架中,这样的构建块称为层)

PyTorch 模块派生自基类 nn.Module, 一个模块可以有一个或多个参数实例作为属性,这些参数实例是张量,它们的值在训练过程中得到了优化(如线性模型中的w和b)。
一个模块还可以有一个或多个子模块(nn.Module的子类)作为属性,并且它还能追踪它们的属性。子模块必须是顶级属性,而不是隐藏在列表或dict中,否则优化器无法定位子模块及它们的参数。

2.2 替换线性模型

nn.Linear是一个关于线性模型的类。它的构造函数接收3个参数:

  1. 输入特征的数量
  2. 输出特征的数量
  3. 线性模型是否包含偏置,默认为True

看一段代码:

import torch.nn as nn

linear_model = nn.Linear(1, 1)  # 输入只有1个特征,输出也只有1个特征
linear_model(t_un_val)  # t_un_val 是一个代表输入的张量

optimizer = optim.SGD(linear_model.parameters(), lr=1e-2)

解释:

  • linear_modelnn.Linear类的一个实例,其构造函数用(1,1) 来初始化,代表输入和输出都是只有一个特征;
  • linear_model(t_un_val)是调用了 __call__()方法,相当于 linear_model.__call__(t_un_val)
  • 这里想让模型跑起来,所以貌似可以调用 forward()方法,但是一般不能这么做。
    因为在__call__()方法的实现中,不仅调用了forward(),还调用了若干hook函数。所以如果我们仅仅调用forward()方法,这些hook函数就不会被调到从而引起错误。
  • 在构建SGD优化器实例的时候,将线性模型的参数传递给该优化器的构造函数的第一个参数

另外,nn包含几个常见的损失函数,其中 nn.MSELoss()就是均方误差,和我们之前定义的 loss_fn一样。因此,我们可以直接调用 MSELoss(), 而不再需要手写损失函数了。

2.3 构建神经网络

下面构建一个最终的神经网络:一个线性模块 => 一个激活函数 => 另一个线性模块
第一个线性模块+激活层通常也被称为隐藏层,因为它的输出并不能被直接观察到;
第一个线性模块的输出通常大于1;
激活层用于:不同的单元对不同范围的输入做出响应,以增加模型的容量;
最后的线性层:获取激活层的输出,并将它们进行线性组合以产生最后的输出值。

nn提供了一种通过 nn.Sequential容器来连接模型的方式.
这里指定第1层的输出张量的大小为10,这基本可以随意指定一个值;
第3层的输出张量的大小需要和第1层的输出张量的大小一致;
最后的 seq_model 就是一个3层的神经网络模型了。

seq_model = nn.Sequential(
            nn.Linear(1, 10),
            nn.Tanh(),
            nn.Linear(10, 1))

调用 seq_model.parameters()将从第1个和第2个线性模块收集权重和偏置;
将来在调用 seq_model.backward()之后,所有参数都填充了它们的梯度;
然后优化器在调用 optimizer.step()期间会更新这些参数的值。

nn.Sequential也接受 OrderedDict, 可以用它来命名每个模块。

from collections import OrderedDict

seq_model = nn.Sequential(OrderedDict([
    ('hidden_linear', nn.Linear(1, 10)),
    ('hidden_activation', nn.Tanh()),
    ('output_linear', nn.Linear(10, 1))
]))

2.4 完整的程序

完整程序如下:

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from collections import OrderedDict
from matplotlib import pyplot as plt

def training_loop(n_epochs, optimizer, model, loss_fn, t_u_train, t_u_val,
                  t_c_train, t_c_val):
    for epoch in range(1, n_epochs + 1):
        t_p_train = model(t_u_train)
        loss_train = loss_fn(t_p_train, t_c_train)

        t_p_val = model(t_u_val)  # 前向传播,计算出预测值
        loss_val = loss_fn(t_p_val, t_c_val)  # 计算损失,只是为了打印
        
        optimizer.zero_grad()  # 梯度清零
        loss_train.backward()  # 利用前向图计算梯度
        optimizer.step()       # 利用梯度,更新参数

        if epoch == 1 or epoch % 1000 == 0:
            print(f"Epoch {epoch}, Training loss {loss_train.item():.4f},"
                  f" Validation loss {loss_val.item():.4f}")

# 共11个元素的向量(1维张量)
t_c = [0.5,  14.0, 15.0, 28.0, 11.0,  8.0,  3.0, -4.0,  6.0, 13.0, 21.0]
t_u = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]

# 增加一个维度,即增加维度1,且该维度上只有1个元素,变成 torch.Size([11,1]) 的形状
t_c = torch.tensor(t_c).unsqueeze(1)
t_u = torch.tensor(t_u).unsqueeze(1)

n_samples = t_u.shape[0]   # 11个元素
n_val = int(0.2 * n_samples)   # 2, 代表评估集大小

shuffled_indices = torch.randperm(n_samples)  # 产生11个随机数,但范围是0-10
train_indices = shuffled_indices[:-n_val]     # 训练集,取前9个随机数作为index
val_indices = shuffled_indices[-n_val:]       # 评估集,取后2个随机数作为index

# 特殊用法,即按照index,取出相应的元素组成Tensor,下面是训练集
t_u_train = t_u[train_indices]
t_c_train = t_c[train_indices]

# 下面是评估集
t_u_val = t_u[val_indices]
t_c_val = t_c[val_indices]

# 归一化处理
t_un_train = 0.1 * t_u_train
t_un_val = 0.1 * t_u_val

# 神经元层的神经元数量
neuron_count = 20

# 构建3层神经网络
seq_model = nn.Sequential(OrderedDict([
    ('hidden_linear', nn.Linear(1, neuron_count)),
    ('hidden_activation', nn.Tanh()),
    ('output_linear', nn.Linear(neuron_count, 1))
]))

# 用整个模型的参数列表作为优化器构造函数的第一个参数
optimizer = optim.SGD(seq_model.parameters(), lr=1e-3)

# 训练
training_loop(
    n_epochs = 5000, 
    optimizer = optimizer,
    model = seq_model,
    loss_fn = nn.MSELoss(),
    t_u_train = t_un_train,
    t_u_val = t_un_val, 
    t_c_train = t_c_train,
    t_c_val = t_c_val)

# seq_model 已训练完毕,下面开始作图

t_range = torch.arange(20., 90.).unsqueeze(1)
fig = plt.figure(dpi=150)
plt.xlabel("Fahrenheit")
plt.ylabel("Celsius")
plt.plot(t_u.numpy(), t_c.numpy(), 'o')

# detach() returns a tensor; numpy() returns numpy.ndarray
plt.plot(t_range.numpy(), seq_model(0.1 * t_range).detach().numpy(), 'c-')
plt.plot(t_u.numpy(), seq_model(0.1 * t_u).detach().numpy(), 'kx')
plt.show()

运行的输出是:

python nn_plot.py
Epoch 1, Training loss 184.2818, Validation loss 231.0507
Epoch 1000, Training loss 3.2540, Validation loss 6.3815
Epoch 2000, Training loss 2.7581, Validation loss 8.6664
Epoch 3000, Training loss 1.9816, Validation loss 6.5488
Epoch 4000, Training loss 1.7593, Validation loss 5.7282
Epoch 5000, Training loss 1.6765, Validation loss 5.3717

显示的图片如下:
神经网络拟合

(完)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/49141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多线程轮流打印

一、背景 面试的时候,有一个高频的笔试题: 让2个线程轮流打印,a线程是打印ABCDEFGHIJ,b线程是打印1、2、3、4、5、6、7、8、9、10 二、原理 这种类型的面试题,主要是考察object的wait()方法和notify()方法的使用 …

spring整合Mybatis-P23,24,25

复习Mybatis&#xff08;都是之前的内容&#xff0c;不再解释&#xff09; 6个需要修改或创建的文件 UserMapper package com.Li.mapper;import com.Li.pojo.User;import java.util.List;public interface UserMapper {public List<User> selectUser(); }UserMapper.xm…

如何全面提升架构设计的质量

低成本 低成本本质上是对架构的一种约束&#xff0c;与高性能等架构是冲突的 手段和应用 先设计架构方案&#xff0c;再看如何降低成本 优化 引入缓存虚拟化、容器化性能调优采用高性能硬件采用开源方案 创新 NoSQL vs SQLSQL vs 倒排索引Hadoop vs MySQL 安全性 复杂…

《码出高效:Java开发手册》 四-走进JVM

前言 JVM是java中底层的知识&#xff0c;这里的内容比较复杂&#xff0c;对于一些软件编程&#xff0c;会经常使用&#xff0c;但很多业务其实碰不到这里的知识&#xff0c;下图为目录 介绍 JVM&#xff0c;java虚拟机&#xff0c;它的前身是99年的hotspot java虚拟机&…

vue 计算属性未重新计算 / computed 未触发 / computed 原理源码分析

点击可打开demo 这里在一秒后改了数组里value属性的值 虽然数据有更新&#xff0c;但打开控制台&#xff0c;可以发现computed函数只在初始化时执行了一次 按理说一秒后改变了value值&#xff0c;应该执行两次才对呀&#xff1f; 但如果computed属性这样写&#xff0c;明确写…

数据分析之大数据分析

一 什么是大数据分析 大数据是指无法在一定时间范围内用常规软件工具进行捕捉、管理和处理的数据集合&#xff0c;是需要新处理模式才能具有更强的决策力、洞察发现力和流程优化能力的海量、高增长率和多样化的信息资产。大数据的特点可以概括为5个V&#xff1a;数据量大&…

当湿度达到70蜂鸣器警报

1.编写设备树&#xff0c;添加蜂鸣器等设备 驱动代码&#xff1a; #include <linux/init.h> #include <linux/module.h> #include <linux/i2c.h> #include <linux/fs.h> #include <linux/uaccess.h> #include <linux/device.h> #include …

QCSPCChart for Java R3x0 Crack

Java 的 SPC 控制图工具 版本 3.04 QCSPCChart添加变量控制图&#xff08;X-Bar R、X-Bar Sigma、Individual Range、Median Range、EWMA、MA、MAMR、MAMS 和 CuSum 图&#xff09;、属性控制图&#xff08;p-、np-、c-、u- 和DPMO 图&#xff09;、频率直方图和 Pareto 图到…

[附源码]Python计算机毕业设计Django的旅游景点管理系统的设计与实现

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

[附源码]Python计算机毕业设计SSM老年公寓管理系统(程序+LW)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

计算机编程

文章目录计算机编程计算机编程语言计算机编程 人与人之间信息&#xff08;如想法、思想等&#xff09;的交流和传递&#xff0c;需要借助双方都能听得懂的语言。人和计算机之间实现交流也是如此&#xff0c;需要借助一种人和计算机都能理解的语言&#xff0c;这种语言称为编程…

LCHub低代码社区:旧的低代码,腾讯怎么讲出新故事

腾讯微搭的对手从来都不是钉钉。 低代码是 " 旧瓶装新酒 " 吗? 低代码风潮在国内兴盛已有两年,但也并不是已经被所有人接受,有不少开发者还保有否定、抵触的态度。 那为什么我们还认为这是一个不可逆的趋势呢? 这里先看下被否定的原因,LCHub在调研中听到的主…

怎么把PDF转换成图片?这三种转换方法都可以实现

怎么把PDF文件的内容转换成图片来使用呢&#xff1f;大家在办公或者是学习的过程中没少使用过PDF文件&#xff0c;有的文件我们翻阅起来会比较费时间&#xff0c;因为文件的内容多&#xff0c;这时候我们只需要把文件内容转成图片就可以解决这一问题&#xff0c;想要使用哪部分…

手把手带你开发你的第一个前端脚手架

开发一个简单的脚手架 1.创建 npm 项目 首先创建一个文件夹&#xff0c;然后进入到该文件夹目录下&#xff0c;执行 npm init -y 2.创建脚手架入口文件bin/index.js&#xff0c;在index.js中添加如下代码 #!/usr/bin/env nodeconsole.log(hello cli) 3.配置 package.json&a…

YOLOv5如何训练自己的数据集

目录 一、标注 1.1 标注软件下载labelimg 下载地址&#xff1a;mirrors / tzutalin / labelimg GitCode 1.2 json转txt 1.3 xml转txt 二、修改配置文件 2.1 建立文件目录 2.2 修改wzry_parameter.yaml文件 三、开始训练 3.1 2.结果 四、识别检测detect.py 1.调参找…

Jetson NX系统烧录以及CUDA、cudnn、pytorch等环境的安装

安装虚拟机和Ubuntu18.04环境 这两步比较简单&#xff0c;所以略了。虚拟机的配置需要注意硬盘空间大一点&#xff0c;至少40G。 安装sdk-manager NVIDIA SDK Manager下载地址&#xff1a;https://developer.nvidia.com/drive/sdk-manager sudo dpkg -i sdkmanager_1.9.0-…

YOLOv5和YOLOv7环境(GPU)搭建测试成功

本来是用doc写的&#xff0c;直接复制到这里很多图片加载缓慢&#xff0c;我直接把doc上传到资源里面了&#xff0c;0积分下载&#xff1a; (10条消息) YOLOv5和YOLOv7开发环境搭建和demo运行-Python文档类资源-CSDN文库 一、环境搭建 1.1 环境搭建参考链接 YOLO实践应用之…

uni-app 超详细教程(一)(从菜鸟到大佬)

一&#xff0c;uni-app 介绍 &#xff1a; 官方网页 uni-app 是一个使用 Vue.js 开发所有前端应用的框架&#xff0c;开发者编写一套代码&#xff0c;可发布到iOS、Android、Web&#xff08;响应式&#xff09;、以及各种小程序&#xff08;微信/支付宝/百度/头条/飞书/QQ/快手…

百度集团副总裁吴甜发布文心大模型最新升级,AI应用步入新阶段

11月30日&#xff0c;由深度学习技术与应用国家工程研究中心主办、百度飞桨承办的WAVE SUMMIT2022深度学习开发者峰会如期举行。百度集团副总裁、深度学习技术及应用国家工程研究中心副主任吴甜带来了文心大模型的最新升级&#xff0c;包括新增11个大模型&#xff0c;大模型总量…

PyQt5_寻找顶(底)背离并可视化

技术指标的背离是指技术指标曲线的波动方向与价格曲线的趋势方向不一致&#xff0c;是使用技术指标最为重要的一点。在股市中&#xff0c;常见的技术指标的背离分为两种常见的形式&#xff0c;即顶背离和底背离。背离是预示市场走势即将见顶或者见底的依据&#xff0c;在价格还…