【深度学习】实验06 使用TensorFlow完成线性回归

news2025/1/22 21:46:44

文章目录

  • 使用TensorFlow完成线性回归
    • 1. 导入TensorFlow库
    • 2. 构造数据集
    • 3. 定义基本模型
    • 4. 训练模型
    • 5. 线性回归图

使用TensorFlow完成线性回归

1. 导入TensorFlow库

# 导入相关库
%matplotlib inline
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

2. 构造数据集

# 产出样本点个数
n_observations = 100
# 产出-3~3之间的样本点
xs = np.linspace(-3, 3, n_observations) 
# sin扰动
ys = np.sin(xs) + np.random.uniform(-0.5, 0.5, n_observations) 
xs

array([-3. , -2.93939394, -2.87878788, -2.81818182, -2.75757576,
-2.6969697 , -2.63636364, -2.57575758, -2.51515152, -2.45454545,
-2.39393939, -2.33333333, -2.27272727, -2.21212121, -2.15151515,
-2.09090909, -2.03030303, -1.96969697, -1.90909091, -1.84848485,
-1.78787879, -1.72727273, -1.66666667, -1.60606061, -1.54545455,
-1.48484848, -1.42424242, -1.36363636, -1.3030303 , -1.24242424,
-1.18181818, -1.12121212, -1.06060606, -1. , -0.93939394,
-0.87878788, -0.81818182, -0.75757576, -0.6969697 , -0.63636364,
-0.57575758, -0.51515152, -0.45454545, -0.39393939, -0.33333333,
-0.27272727, -0.21212121, -0.15151515, -0.09090909, -0.03030303,
0.03030303, 0.09090909, 0.15151515, 0.21212121, 0.27272727,
0.33333333, 0.39393939, 0.45454545, 0.51515152, 0.57575758,
0.63636364, 0.6969697 , 0.75757576, 0.81818182, 0.87878788,
0.93939394, 1. , 1.06060606, 1.12121212, 1.18181818,
1.24242424, 1.3030303 , 1.36363636, 1.42424242, 1.48484848,
1.54545455, 1.60606061, 1.66666667, 1.72727273, 1.78787879,
1.84848485, 1.90909091, 1.96969697, 2.03030303, 2.09090909,
2.15151515, 2.21212121, 2.27272727, 2.33333333, 2.39393939,
2.45454545, 2.51515152, 2.57575758, 2.63636364, 2.6969697 ,
2.75757576, 2.81818182, 2.87878788, 2.93939394, 3. ])

ys

array([-0.62568008, 0.01486274, -0.29232541, -0.05271084, -0.53407957,
-0.37199581, -0.40235236, -0.80005504, -0.2280913 , -0.96111433,
-0.58732159, -0.71310851, -1.19817878, -0.93036437, -1.02682804,
-1.33669261, -1.36873043, -0.44500172, -1.38769079, -0.52899793,
-0.78090929, -1.1470421 , -0.79274726, -0.95139505, -1.3536293 ,
-1.15097615, -1.04909201, -0.89071026, -0.81181765, -0.70292996,
-0.49732344, -1.22800179, -1.21280414, -0.59583172, -1.05027515,
-0.56369191, -0.68680323, -0.20454038, -0.32429566, -0.84640122,
-0.08175012, -0.76910728, -0.59206189, -0.09984673, -0.52465978,
-0.30498277, 0.08593627, -0.29488864, 0.24698113, -0.07324925,
0.12773032, 0.55508531, 0.14794648, 0.40155342, 0.31717698,
0.63213964, 0.35736413, 0.05264068, 0.39858619, 1.00710311,
0.73844747, 1.12858026, 0.59779567, 1.22131999, 0.80849061,
0.72796849, 1.0990044 , 0.45447096, 1.15217952, 1.31846002,
1.27140258, 0.65264777, 1.15205186, 0.90705463, 0.82489198,
0.50572125, 1.47115594, 0.98209434, 0.95763951, 0.50225094,
1.40415029, 0.74618984, 0.90620692, 0.40593222, 0.62737999,
1.05236579, 1.20041249, 1.14784273, 0.54798933, 0.18167682,
0.50830766, 0.92498585, 0.9778136 , 0.42331405, 0.88163729,
0.67235809, -0.00539421, -0.06219493, 0.26436412, 0.51978602])

# 可视化图长和宽
plt.rcParams["figure.figsize"] = (6,4)
# 绘制散点图
plt.scatter(xs, ys) 
plt.show()

1

3. 定义基本模型

# 占位
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')
# 随机采样出变量
W = tf.Variable(tf.random_normal([1]), name='weight') 
b = tf.Variable(tf.random_normal([1]), name='bias')
# 手写y = wx+b
Y_pred = tf.add(tf.multiply(X, W), b) 
# 定义损失函数mse
loss = tf.square(Y - Y_pred, name='loss') 
# 学习率
learning_rate = 0.01
# 优化器,就是tensorflow中梯度下降的策略
# 定义梯度下降,申明学习率和针对那个loss求最小化
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss) 

4. 训练模型

# 去样本数量
n_samples = xs.shape[0]
init = tf.global_variables_initializer()
with tf.Session() as sess:
    # 记得初始化所有变量
    sess.run(init) 
    writer = tf.summary.FileWriter('../graphs/linear_reg', sess.graph)
    # 训练模型
    for i in range(50):
        #初始化损失函数
        total_loss = 0
        for x, y in zip(xs, ys):
            # 通过feed_dic把数据灌进去
            _, l = sess.run([optimizer, loss], feed_dict={X: x, Y:y}) #_是optimizer的返回,在这没有用就省略
            total_loss += l #统计每轮样本的损失
        print('Epoch {0}: {1}'.format(i, total_loss/n_samples)) #求损失平均

    # 关闭writer
    writer.close() 
    # 取出w和b的值
    W, b = sess.run([W, b]) 
Epoch 0: [0.48447946]
Epoch 1: [0.20947962]
Epoch 2: [0.19649307]
Epoch 3: [0.19527708]
Epoch 4: [0.19514856]
Epoch 5: [0.19513479]
Epoch 6: [0.19513334]
Epoch 7: [0.19513316]
Epoch 8: [0.19513315]
Epoch 9: [0.19513315]
Epoch 10: [0.19513315]
Epoch 11: [0.19513315]
Epoch 12: [0.19513315]
Epoch 13: [0.19513315]
Epoch 14: [0.19513315]
Epoch 15: [0.19513315]
Epoch 16: [0.19513315]
Epoch 17: [0.19513315]
Epoch 18: [0.19513315]
Epoch 19: [0.19513315]
Epoch 20: [0.19513315]
Epoch 21: [0.19513315]
Epoch 22: [0.19513315]
Epoch 23: [0.19513315]
Epoch 24: [0.19513315]
Epoch 25: [0.19513315]
Epoch 26: [0.19513315]
Epoch 27: [0.19513315]
Epoch 28: [0.19513315]
Epoch 29: [0.19513315]
Epoch 30: [0.19513315]
Epoch 31: [0.19513315]
Epoch 32: [0.19513315]
Epoch 33: [0.19513315]
Epoch 34: [0.19513315]
Epoch 35: [0.19513315]
Epoch 36: [0.19513315]
Epoch 37: [0.19513315]
Epoch 38: [0.19513315]
Epoch 39: [0.19513315]
Epoch 40: [0.19513315]
Epoch 41: [0.19513315]
Epoch 42: [0.19513315]
Epoch 43: [0.19513315]
Epoch 44: [0.19513315]
Epoch 45: [0.19513315]
Epoch 46: [0.19513315]
Epoch 47: [0.19513315]
Epoch 48: [0.19513315]
Epoch 49: [0.19513315]
print(W,b)
print("W:"+str(W[0]))
print("b:"+str(b[0]))
[0.23069778] [-0.12590201]
W:0.23069778
b:-0.12590201

5. 线性回归图

# 线性回归图
plt.plot(xs, ys, 'bo', label='Real data')
plt.plot(xs, xs * W + b, 'r', label='Predicted data')
plt.legend()
plt.show()

2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/966173.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTML+JavaScript+CSS DIY 分隔条splitter

一、需求分析 现在电脑的屏幕越来越大,为了利用好宽屏,我们在设计系统UI时喜欢在左侧放个菜单或选项面板,在右边显示与菜单或选项对应的内容,两者之间用分隔条splitter来间隔,并可以通过拖动分隔条splitter来动态调研…

TuyaOS开发学习笔记(1)——NB-IoT开发搭建环境、编译烧写(MT2625)

一、搭建环境 1.1 官方资料 TuyaOS 1.2 安装VMware 官网下载:https://customerconnect.vmware.com/en/downloads/info/slug/desktop_end_user_computing/vmware_workstation_pro/16_0 百度网盘:https://pan.baidu.com/s/1oN7H81GV0g6cD9zsydg6vg 提取…

【Redis从头学-完结】Redis全景思维导图一览!耗时半个月专为Redis初学者打造!

🧑‍💻作者名称:DaenCode 🎤作者简介:CSDN实力新星,后端开发两年经验,曾担任甲方技术代表,业余独自创办智源恩创网络科技工作室。会点点Java相关技术栈、帆软报表、低代码平台快速开…

C语言入门 Day_12 一维数组0

目录 前言 1.创建一维数组 2.使用一维数组 3.易错点 4.思维导图 前言 存储一个数据的时候我们可以使用变量, 比如这里我们定义一个记录语文考试分数的变量chinese_score,并给它赋值一个浮点数(float)。 float chinese_scoe…

2023应届生java面试紧张失误之一:CAS口误说成开心锁-笑坏面试官

源于:XX网,如果冒犯,表示歉意 面试官:什么是CAS 我:这个简单,开心锁 面试官:WTF? 我:一脸自信,对,就是这个 面试官:哈哈大笑&#xff…

YOLOv7框架解析

YOLOv7概念 YOLOv7是基于YOLO系列的目标检测算法,由Ultra-Light-Fast-Detection(ULFD)和Scaled-YOLOv4两种算法结合而来。它是一种高效、准确的目标检测算法,具有以下特点: 1. 高效:YOLOv7在保持准确率的…

一个 MySQL 数据库死锁的案例和解决方案

本文介绍了一个 MySQL 数据库死锁的案例和解决方案。 场景 生产环境出了一个偶现的数据库死锁问题,导致少部分业务处理失败。 分析特征之后,发现是多个线程并发执行同一个方法,更新关联的数据时可能会出现,把场景简化概括一下&…

MySQL基础篇:数据库概述和部署

SQL 概述 SQL,一般发音为sequel,SQL的全称Structured Query Language),SQL用来和数据库打交道,完成和数据库的通信,SQL是一套标准。但是每一个数据库都有自己的特性别的数据库没有,当使用这个数据库特性相关的功能,这…

CDH6.3.2集成Kerberos

CDH6.3.2集成Kerberos 一.参考doc CDH enable kerberos: Kerberos Security Artifacts Overview | 6.3.x | Cloudera Documentation CDH disable kerberos:https://www.sameerahmad.net/blog/disable-kerberos-on-CDH; https://community.cloudera.com/t5/Support-Questions…

ModaHub魔搭社区:自动化机器学习神器Auto-Sklearn

Auto-Sklearn Auto-Sklearn是一个开源库,用于在 Python 中执行 AutoML。它利用流行的 Scikit-Learn 机器学习库进行数据转换和机器学习算法。 它是由Matthias Feurer等人开发的。并在他们 2015 年题为“efficient and robust automated machine learning 高效且稳健的自动…

LeetCode刷题---Two Sum(一)

文章目录 🍀题目🍀解法一🍀解法二🍀哈希表 🍀题目 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每…

javaee spring aop 注解实现

切面类 package com.test.advice;import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.*;//切面类 Aspect public class MyAdvice {//定义切点表达式Pointcut("execution(* com.test.service.impl.*.add(..))")public void pc(){}//B…

Spring-Kafka生产者源码分析

文章目录 概要初始化消息发送小结 概要 本文主要概括Spring Kafka生产者发送消息的主流程 代码准备&#xff1a; SpringBoot项目中maven填加以下依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent&…

大胆尝试这些新的CSS属性,释放CSS的力量吧(一)

本文章系《Unleashing the Power of CSS》&#xff08;释放CSS的力量&#xff0c;暂且这么翻译吧&#xff09;一书的学习笔记&#xff0c;希望通本书的学习&#xff0c;系统的梳理下CSS相关的高级新特性。本篇文章是其第一部分&#xff0c;由于全书英文版&#xff0c;理解和阅读…

mac制作ssl证书|生成自签名证书,nodejs+express在mac上搭建https+wss(websocket)服务器

注意 mac 自带 openssl 所以没必要像 windows 一样先安装 openssl&#xff0c;直接生成即可 生成 ssl/自签名 证书 生成 key # 生成rsa私钥&#xff0c;des3算法&#xff0c;server_ssl.key是秘钥文件名 1024位强度 openssl genrsa -des3 -out server_ssl.key 1024让输入两…

MySQL开机自启动设置(Windows)

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

黑马 软件测试从0到1 常用分类 模型 流程 用例

课程内容&#xff1a; 1、软件测试基础 2、测试设计 3、缺陷管理 4、Web常用标签 5、项目实战 以终为始&#xff0c;由交付实战目标为终&#xff0c;推出所学知识&#xff1b;从认识软件及软件测试&#xff0c;到如何设计测试、缺陷标准及缺陷管理&#xff0c;最终以项目实战贯…

pytorch异常——loss异常,不断增大,并且loss出现inf

文章目录 异常报错异常截图异常代码原因解释修正代码执行结果 异常报错 epoch1:loss3667.782471 epoch2:loss65358620.000000 epoch3:loss14979486720.000000 epoch4:loss1739650891776.000000 epoch5:loss12361745880317952.000000 epoch6:loss2740315398365287284736.000000…

PMD代码检查:为了提升性能,正确使用记录日志的语句(GuardLogStatement)

https://docs.pmd-code.org/pmd-doc-6.55.0/pmd_rules_java_bestpractices.html#guardlogstatement 对应记录日志的语句&#xff0c;要首先检查对应的日志级别有没有实际打开&#xff1b;如果没有实际打开&#xff0c;那么就要跳过字符串的生成环节&#xff0c;以提升性能。 另…

C#,数值计算——Midsql的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { public class Midsql : Midpnt { private double aorig { get; set; } 0.0; public new double func(double x) { return 2.0 * x * funk.funk(aorig x * x); } p…