【Python机器学习】神经网络中全连接层与线性回归的讲解及实战(Tensorflow、MindSpore平台 附源码)

news2024/11/24 18:32:05

需要全部代码请点赞关注收藏后评论区留言私信~~~

全连接层与线性回归

神经网络模型也是参数学习模型,因为对它的学习只是得到神经网络参数的最优值,而神经网络的结构必须事先设计好。如果确实不能通过改进学习过程来达到理想效果,则要重新设计神经网络的结构。

层状神经网络的隐层和输出层具有处理信息的能力,它们又可细分为全连接层、卷积层、池化层、LSTM层等等,通过适当排列可以组合成适应不同任务的网络。

全连接层是层状神经网络最基本的层,本小节从线性回归模型入手,深入讨论全连接层。

线性回归模型改写为:

 

神经元模型

 

可以将线性回归看成是神经元模型,其阈值θ=w^(0),其激励函数为等值函数f(x)=x,即该神经元是没有激励函数的特殊神经元。

先定义一个二维平面上的线性目标函数并用它来生成训练样本,再定义一个代表线性回归模型的神经网络,然后用训练样本对该网络进行训练,并在训练的过程中动态显示线性模型的拟合过程。

效果如下 

 

 代码如下

### 定义训练样本生成函数
import numpy as np
np.random.seed(1101) # 指定随机数种子,产生相同的随机数,便于观察试验结果
 
def f(x, w=3.0, b=1.0): # 目标函数
    return x * w + b
 
def get_data(num):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 3)
        y = f(x) + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
### 生成训练样本并增强
from mindspore import dataset as ds
import matplotlib.pyplot as plt
 
data_number = 80 # 样本总数
batch_size = 16 # 每批训练样本数(批梯度下降法)
repeat_size = 1
 
train_data = list(get_data(data_number))
X, y = zip(*train_data)
plt.scatter(X, y, color="black", s=10)
xx = np.arange(-10.0, 10, 1)
yy = f(xx)
plt.plot(xx, yy, color="red", linewidth=1, linestyle='-')
plt.show()

按上述方法构建的层被称为全连接层(fully connected layers),它是层状神经网络最基本的层。

全连接层的每一个节点都与上一层的所有节点相连。设前一层的输出为X=(x_1, x_2,…,x_i,…,x_m),本层的输出为Y=(y_1, y_2,…,y_j,…,y_n),其中:

 

定义连接系数矩阵:

 

和阈值系数向量:

 

全连接层的计算可以写成矩阵形式:

 

在全连接层中,连接系数和阈值系数是要训练的参数,它们一共有m×n+n个。

动态拟合过程,当训练到第十轮左右的时候模型拟合度已经十分高了 

 

代码如下

import numpy as np
np.random.seed(1101) # 指定随机数种子,产生相同的随机数,便于观察试验结果

def f(x, w=3.0, b=1.0): # 目标函数
    return x * w + b

def get_data(num):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 3)
        y = f(x) + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
from mindspore import dataset as ds
import matplotlib.pyplot as plt

data_number = 80 # 样本总数
batch_size = 16 # 每批训练样本数(批梯度下降法)
repeat_size = 1

train_data = list(get_data(data_number))
X, y = zip(*train_data)
plt.scatter(X, y, color="black", s=10)
xx = np.arange(-10.0, 10, 1)
yy = f(xx)
plt.plot(xx, yy, color="red", linewidth=1, linestyle='-')
plt.show()
import time
from mindspore import Tensor

def plot_model_and_datasets(net, train_data):
    weight = net.trainable_params()[0]
    bias = net.trainable_params()[1]
    x = np.arange(-10, 10, 1)
    y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
    x1, y1 = zip(*train_data)
    x_target = x
    y_target = f(x_target)

    plt.axis([-11, 11, -20, 25])
    plt.scatter(x1, y1, color="black", s=10)
    plt.plot(x, y, color="blue", linestyle=':', linewidth=2)
    plt.plot(x_target, y_target, color="red")
    plt.show()
    time.sleep(0.02)
    
from IPython import display
from mindspore.train.callback import Callback

class ImageShowCallback(Callback): # 回调类
    def __init__(self, net, train_data):
        self.net = net
        self.train_data = train_data

    def step_end(self, run_context):
        plot_model_and_datasets(self.net, self.train_data)
        display.clear_output(wait=True)

 创作不易 觉得有帮助请点赞关注收藏~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/101251.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux基础学习-用户权限相关命令

用户权限相关命令 用户和权限的基本概念 基本概念 用户是linux系统工作中的重要的一环,用户管理包括 用户 和 组 管理在linux系统下,不论是由本机还是远程登录系统,每个系统都必须有一个账号,并且对于不同的系统资源拥有不同的使用…

银河系中心黑洞的首张照片

说到黑洞,那就不得不提起我们的家园银河系中心的大黑洞,在今天这张照片出来之前,所有关于银河系黑洞的描述都是推测、理论,而今天成为了现实! 2019年,同一团队拍摄了梅西耶87星系(M87&#xff0…

文本分类优化方法

文本分类优化方法 文本分类是NLP的基础工作之一,也是文本机器学习中最常见的监督学习任务之一,情感分类,新闻分类,相似度判断、问答匹配、意图识别、推断等等领域都使用到了文本分类的相关知识或技术。文本分类技术在机器学习的发…

用 Markdown 快速生成漂亮的 Latex 伪代码

参考:在 Markdown 中书写伪代码 文章目录配置 VSCode编写 Latex 源码生成 pseudocode配置 VSCode 组合: VSCode Markdown Preview Enhanced pseudocode.js 安装好 VSCode 和 Markdown Preview Enhanced 插件 按下快捷键 Ctrl Shift P,打…

定时器/计数器的基本概念

80C51单片机中有两个计数器,即T0和T1。 单片机内有一个定时器/计数器T0,可以用编程的方法将它设为计数器。当用作计数器时,它是一个16位计数器,它的最大计数值为65536。 定时器/计数器T0和T1分别是由TH0、TL0和TH1、TL1两个8位计数…

Vue3与Vue2生命周期不同点

一、前言 随着Vue3发布了两年多的时间,越来越多的小伙伴已经将老项目中的Vue2版本进行升级或者在新项目中使用到了Vue3.x的版本,今天就来总结以下Vue3相较于Vue2升级的生命周期不同点在哪。 二、生命周期 下面是生命周期对比图: Vue2Vue3…

JMeter基础入门

目录:导读 一、概述 二、Jmeter目录文件讲解 结语 一、概述 JMeter是Apache下一款在国外非常流行和受欢迎的开源性能测试工具,JMeter可用于模拟大量负载来测试一台服务器,网络或者对象的健壮性或者分析不同负载下的整体性能。 1、压测不同…

ValidateCode验证码的使用详解(初学看完都会用)

✅作者简介:热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏:Java案例分…

Android 进阶——性能优化之电量优化全攻略及实战小结(二)

文章大纲引言一、在低电耗模式和应用待机模式下进行测试1、在低电耗模式下测试您的应用2、在应用待机模式下测试您的应用3、列入白名单的可接受用例4、确定当前充电状态5、监控充电状态变化6、确定当前电池电量7、监控显著的电池电量变化二、Wakelock 机制1、WakeLock分类2、申…

Linux系统x86-64架构下,从零实现一个系统调用。Ubuntu22.04LTS

名称版本OSUbuntu 22.04 LTSCurrent Kernel5.15.0-56-genericDestination Kernel5.16.60首先要会编译linux内核的源码,这块在我的另外一片文章里面。 https://blog.csdn.net/jl19861101/article/details/128327069 打开linux内核源码目录/arch/x86/entry/syscalls/syscall_64.t…

前端面试比较好的回答

介绍一下Connection:keep-alive 什么是keep-alive 我们知道HTTP协议采用“请求-应答”模式,当使用普通模式,即非KeepAlive模式时,每个请求/应答客户和服务器都要新建一个连接,完成 之后立即断开连接(HTTP协议为无连接…

【笔记】canvas 绘制足球 —— 第一步 画个球体

文章目录一、球体分析二、足球结构分析三、canvas常用API四、画个球体1.初始化2.代码五、加上足球的皮肤一、球体分析 先上两张图 球坐标转直角坐标 xρsin(φ)cos(θ)x \rho \times sin(\varphi) \times cos(\theta) xρsin(φ)cos(θ) yρsin(φ)sin(θ)y \rho \times si…

柴油,光伏模块,风力涡轮机,电池和水力抽水蓄能组成的混合隔离微电网的设计(Matlab实现)

目录 0 引言 1 概述 2 HYMOD 软件操作 2.1 设计的三个阶段 3 HYMOD 软件架构 4 系统和元件的可靠性 5 微电网设计示例 6 Matlab代码与结论 7 操作指南 7.1 概述 7.2 操作 7.3 软件具体操作 0 引言 本文介绍了混合微电网优化设计 (Hymod) 软件。该软件具有最先进…

Redis发布和订阅

Redis发布和订阅1.什么是发布和订阅2.Redis命令演示发布订阅1.什么是发布和订阅 Redis发布订阅(pub/sub)是一种消息通信模式:发布者(pub)发送消息,订阅者(sub)接收消息。 Redis客户端可以订阅任意数量的频道。 2.Redis命令演示发布订阅 打开两个终端 终…

python教程二十 输入和输出

输出格式美化 Python两种输出值的方式: 表达式语句和 print() 函数。 第三种方式是使用文件对象的 write() 方法,标准输出文件可以用 sys.stdout 引用。 如果你希望输出的形式更加多样,可以使用 str.format() 函数来格式化输出值。 如果你希望将输出…

记录在苹果mac os系统上使用51单片机仿真软件Proteus

目录 1.安装Wineskin shell 指令 2.安装Wrapper 点击update ​​​​​​​ 1.安装Wineskin 首先我们需要安装一个程序: 可以将在Windows系统上才能运行exe文件打包为mac系统可执行的文件。 shell 指令 brew install --no-quarantine gcenx/wine/unofficial…

气体在线监测仪——排水管井内的有害气体监测

一、产品概述 气体在线监测仪内部采用模块化设计,可对雨污水管井内的有害气体进行在线监测,设备采用高精度、高分辨率的原装进口气体传感器,具有体积小、重量轻、设计简洁、高性价比、多参数高集成、安装方便等特点。 气体在线监测仪广泛应…

Java IO

目录 一、File 类 二、RandomAccessFile 三、流类 四、字节流 4.1 、InputStream 4.2、OutputStream 五、字符流 5.1、Reader 5.2、Writer 六、管道流 七、ByteArrayInputStream 和 ByteArrayOutputStream 八、System.out 和 System.in 九、打印流 十、DataOutp…

【Leetcode】单值二叉树、 相同的树、对称二叉树、另一颗树的子树、二叉树遍历、二叉树的前序遍历

文章目录OJ链接单值二叉树相同的树对称二叉树另一颗树的子树二叉树遍历二叉树的前序遍历OJ链接 1、【单值二叉树】OJ链接 2、【相同的树】OJ链接 3、【对称二叉树】OJ链接 4、【另一棵树的子树】OJ链接 5、【二叉树遍历】OJ链接 6、【二叉树的前序遍历】OJ链接 单值二叉树 >…

R语言用线性模型进行臭氧预测: 加权泊松回归,普通最小二乘,加权负二项式模型,多重插补缺失值

最近我们被客户要求撰写关于线性模型的研究报告,包括一些图形和统计输出。在这篇文章中,我将从一个基本的线性模型开始,然后尝试找到一个更合适的线性模型。 数据预处理 由于空气质量数据集包含一些缺失值,因此我们将在开始拟合…