用TensorFlow实现线性回归

news2025/1/11 12:52:42

说明

本文采用TensorFlow框架进行讲解,虽然之前的文章都采用mxnet,但是我发现tensorflow提供了免费的gpu可供使用,所以果断开始改为tensorflow,若要实现文章代码,可以使用colaboratory进行运行,当然,如果您已经安装了tensorflow,可以采用python直接运行。

贡献

学习时采取动手学深度学习第二版作为教材,但由于本书通过引入d2l(著者自写库)进行深度学习,我希望将d2l的影响去掉,即不使用d2l,使用tensorflow,这一点通过查询GitHub中d2l库提供的相关函数尝试进行实现。

如果本系列文章具有良好表现,将译为英文版上传至Github。

预备知识

学习本篇文章之前,您最好具有以下基础知识:

  1. 线性回归的基础知识
  2. python的基础知识

基本原理 

使用一个仿射变换,通过y=wx+b的模型来对数据进行预测(w和x均为矩阵,大小取决于输入规模),反向传播采用随机梯度下降对参数进行更新,参数包括w和b,即权重和偏差。

实现过程

生成数据集

只需要引入tensorflow即可,synthetic_data()函数将初始化X和Y,即通过真实的权重和偏差值生成数据集。

import tensorflow as tf

def synthetic_data(w, b, num_examples):
    X = tf.zeros((num_examples, w.shape[0]))
    X += tf.random.normal(shape=X.shape)
    y = tf.matmul(X, tf.reshape(w, (-1, 1))) + b
    y += tf.random.normal(shape=y.shape, stddev=0.01)
    y = tf.reshape(y, (-1, 1))
    return X, y

true_w = tf.constant([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

读取数据集

加载刚刚生成的数据集,is_train表示是否进行打乱,默认对数据进行打乱处理,使用load_array函数加载数据集。

def load_array(data_arrays, batch_size, is_train=True):
    dataset = tf.data.Dataset.from_tensor_slices(data_arrays)
    if is_train:
        dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 10
data_iter = load_array((features, labels), batch_size)

定义模型

模型使用keras API实现,keras是tensorflow中机器学习相关的库。先使用Sequential类定义承载容器,之后添加一个单神经元的全连接层。在TensorFlow中,Sequential表示容器相关的类,layer表示层相关的类。线性回归只需要通过keras中的单神经元的全连接层即可实现,神经元的值即为输出结果。

net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1))

示例的线性回归仅有一个输入X,实际在其他线性回归过程中,很有可能有多个x及其对应的w,但keras的代码均不会发生改变,因为keras的Dense类可以自动判断输入的个数。 

初始化模型参数 

stddev表示标准差,initializer生成一个标准差为1,均值为0的正态分布。在构建全连接层时,使用该正态分布进行初始化。

initializer = tf.initializers.RandomNormal(stddev=0.01)
net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1, kernel_initializer=initializer))

定义损失函数和优化算法 

损失函数使用平方损失函数进行计算,训练时使用小批量随机梯度下降SGD方法进行训练,学习率为0.03。

loss = tf.keras.losses.MeanSquaredError()
trainer = tf.keras.optimizers.SGD(learning_rate=0.03)

训练

运行以下代码可以观察训练结果。运行轮次为3轮,每一轮对所有训练集数据进行学习。计算w和b的梯度值,使用梯度下降更新权重w和偏差b。每一轮输出损失函数的值,最终显示权重和偏差的估计误差。

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        with tf.GradientTape() as tape:
            l = loss(net(X, training=True), y)
        grads = tape.gradient(l, net.trainable_variables)
        trainer.apply_gradients(zip(grads, net.trainable_variables))
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')
w = net.get_weights()[0]
print('w的估计误差:', true_w - tf.reshape(w, true_w.shape))
b = net.get_weights()[1]
print('b的估计误差:', true_b - b)

运行结果

epoch 1, loss 0.000194

epoch 2, loss 0.000091

epoch 3, loss 0.000091

w的估计误差: tf.Tensor([-0.00026917 0.00094557], shape=(2,), dtype=float32)

b的估计误差: [4.7683716e-06]

 改进尝试

  1. 更改SGD优化算法为Adam
  2. 更改MeanSquaredError为其他损失函数

对于上述改进,损失均有显著增加,表明原有方法已为最好方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2065045.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ZooKeeper入门及核心知识点整理

什么是Zookeeper Zookeeper简称zk,先从字面意思上去理解,那就是动物园管理员。其实zk是大数据领域中的一员,为整个分布式环境提供了协调服务,主要可以用于存储一些配置信息,同时也可以基于zk实现集群。它是一个apache…

RabbitMQ的基础概念介绍

MQ的三大特点:削峰、异步、解耦 1.RabblitMQ概念介绍 1.1概念 RabbitMQ是由erlang语言开发,基于AMQP(Advanced Message Queue 高级消息队列协议)协议实现的消息队列,它是一种应用程序之间的通信方法,消息…

【docker】使用docker-compose的时候如何更新镜像版本

使用docker-compose的时候如何更新镜像版本。之前总是会忘记怎么操作,每次都得百度搜,干脆记录一下。 说明 我有一个memos是用docker-compose部署的,memos更新的挺频繁的,新版本的功能也不错,更新一下没啥问题。 注…

吴恩达机器学习课后作业-03多分类、神经网络前向传播

这里写目录标题 逻辑回归解决多分类问题(逻辑回归的“一对多”(One-vs-All)策略。)绘制图像结果 神经网络前向传播数字识别 、 逻辑回归解决多分类问题(逻辑回归的“一对多”(One-vs-All)策略。…

[Algorithm][综合训练][孩子们的游戏][大数加法][拼三角]详细讲解

目录 1.孩子们的游戏1.题目链接2.算法原理详解 && 代码实现 2.大数加法1.题目链接2.算法原理详解 && 代码实现 3.拼三角1.题目链接2.算法原理详解 && 代码实现 1.孩子们的游戏 1.题目链接 孩子们的游戏 2.算法原理详解 && 代码实现 问题抽象…

LongWriter——从长文本语言模型中释放出10,000+字的生成能力

概述 当前的长上下文大型语言模型 (LLM) 可以处理多达 100,000 个词的输入,但它们很难生成超过 2,000 个词的输出。受控实验表明,该模型的有效生成长度本质上受到监督微调(SFT) 期间看到的示例的限制。换句话说,这种输出限制源于现有 SFT 数…

三维模型单体化软件:地理信息与遥感领域的精细化革命

在地理信息与遥感科学日新月异的发展浪潮中,单体化软件作为一股强大的驱动力,正引领着我们迈向空间信息处理与应用的新纪元。本文旨在深度解析单体化软件的核心价值、技术前沿、实践应用及面临的挑战,共同探讨这一技术如何塑造行业的未来。 …

【手撕OJ题】——BM8 链表中倒数最后k个结点

目录 🕒 题目⌛ 方法① - 直接遍历⌛ 方法② - 快慢指针 🕒 题目 🔎 BM8 链表中倒数最后k个结点【难度:简单🟢】 输入一个长度为 n 的链表,设链表中的元素的值为 a i a_i ai​ ,返回该链表中倒…

一款MySQL数据库实时增量同步工具,能够监听MySQL二进制日志(Binlog)的变动(附源码)

背景 作为一名CURD的程序员,少不了跟MySQL打交道,在同步数据的时候,MySQL的Binlog显得重中之重,所以处理Binlog的工具尤为重要。 其中阿里巴巴开源的canal 更是耳闻目睹,但是今天小编给大家介绍另外一款MySQL数据库实…

【C++11】常用新语法②(类的新功能 || 可变参数模板 || lambda表达式 || 包装器)

🔥个人主页: Forcible Bug Maker 🔥专栏: C 目录 🌈前言🔥类的新功能新增默认成员函数强制生成默认函数的关键字default禁止生成默认函数的关键字delete 🔥可变参数模板递归函数方式展开参数包…

论文翻译:Benchmarking Large Language Models in Retrieval-Augmented Generation

https://ojs.aaai.org/index.php/AAAI/article/view/29728 检索增强型生成中的大型语言模型基准测试 文章目录 检索增强型生成中的大型语言模型基准测试摘要1 引言2 相关工作3 检索增强型生成基准RAG所需能力数据构建评估指标 4实验设置噪声鲁棒性结果负面拒绝测试平台结果信息…

算法5:位运算

文章目录 小试牛刀进入正题 没写代码的题,其链接点开都是有代码的。开始前请思考下图: 小试牛刀 位1的个数 class Solution { public:int hammingWeight(int n) {int res 0;while (n) {n & n - 1;res;}return res;} };比特位计数 class Solution…

计算机毕业设计选题推荐-猫眼电影数据可视化分析-Python爬虫-k-means算法

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

进程和文件痕迹排查——LINUX

目录 介绍步骤 介绍 进程(Process)是计算机中的程序关于某数据集合上的一次运行活动,是系统进行资源分配和调度的基本单位,是操作系统结构的基础。 在早期面向进程设计的计算机结构中,进程是程序的基本执行实体&…

fastadmin 安装

环境要求,大家可以参考官方文档的,我这里使用的是phpstudy,很多已经集成了。 注意一点,PHP 版本:PHP 7.4 。 第二步:下载 下载地址:https://www.fastadmin.net/download.html 进入下载地址后…

IDEA:Terminal找不到npm

Terminal的命令失效通过修改cmd.exe的方式还是不生效的话,考虑是windwos11 默认idea不是通过管理员启动的,如下图修改就可以了。

前端vue 3中使用 顶象 vue3 版本

顶象 验证 的插件 不知道大家使用过没有 顶象-业务安全引领者&#xff0c;让数字世界无风险 可以防止 机器人刷接口 等 可以在任何 加密操作中使用 下面我直接 贴代码 解释 <script src"https://cdn.dingxiang-inc.com/ctu-group/captcha-ui/v5/index.js" cro…

第12章 网络 (2)

目录 12.5 网络命名空间 12.6 套接字缓冲区 12.6.1 使用 sk_buff 管理数据 12.6.2 管理套接字缓冲区数据 本专栏文章将有70篇左右&#xff0c;欢迎关注&#xff0c;查看后续文章。 12.5 网络命名空间 一个网卡可能只在某个特定命名空间可见。 struct net&#xff1a; 表…

C语言贪吃蛇之BUG满天飞

C语言贪吃蛇之BUG满天飞 今天无意间翻到了大一用C语言写的贪吃蛇&#xff0c;竟然还标注着BUG满天飞&#xff0c;留存一下做个纪念&#xff0c;可能以后就找不到了 /* 此程序 --> 贪吃蛇3.0 Sur_流沐 当前版本&#xff1a; Bug满天飞 */ #include<stdio.h> #includ…

Linux C、C++编程之线程同步

【图书推荐】《Linux C与C一线开发实践&#xff08;第2版&#xff09;》_linux c与c一线开发实践pdf-CSDN博客《Linux C与C一线开发实践&#xff08;第2版&#xff09;&#xff08;Linux技术丛书&#xff09;》(朱文伟&#xff0c;李建英)【摘要 书评 试读】- 京东图书 (jd.com…