TensorFlow介绍二-线性回归案例

news2024/11/11 6:58:12

一.案例步骤

1.准备数据集:y=0.8x+0.7  100个样本

2.建立线性模型,初始化w和b变量

3.确定损失函数(预测值与真实值之间的误差),均方误差

4.梯度下降优化损失

二.完整功能代码:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf

def linear_regression():
    """
    自实现线性回归
    :return: None
    """
    # 构造数据X为一百行一列
    X = tf.random_normal(shape=(100, 1), mean=2, stddev=2)
    # 真实值,y=x*0.8+0.7,这里X为tf.tensor数据在乘的时候要使用二维数据
    y_true = tf.matmul(X, [[0.8]]) + 0.7
    # 使用Variable初始化w,b,因为w和b要参与更新所有要使用变量。trainable是设置这个变量是否参与训练
    weights = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)
    bias = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)
    # 构造预测值,使用X乘上更新后的变量w加上b
    y_predict = tf.matmul(X, weights) + bias
    # 计算均方误差,用真实值减去预测值的平方,因为这是一百个数据,使用要求它的平均值
    error = tf.reduce_mean(tf.square(y_predict - y_true))
    # 构建优化器,这里使用的是梯度下降优化误差来更新w和b,0.01是学习率
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(error)

    # 初始化变量
    init = tf.global_variables_initializer()

    with tf.Session() as sess:  # 会话
        # 运行初始化变量op
        sess.run(init)
        # 打印一下初始化的权重和偏置
        print("随机初始化的权重为%f, 偏置为%f" % (weights.eval(), bias.eval()))
        # 开始训练,训练的次数越多越接近真实值
        for i in range(100):

            sess.run(optimizer)
            # 打印每一次更新后的权重,偏置,误差
            print("第%d步的误差为%f,权重为%f, 偏置为%f" % (i, error.eval(), weights.eval(), bias.eval()))

    return None

if __name__ == '__main__':
    linear_regression()

三.增加其他功能

1.增加命名空间

使代码结构更加清晰,Tensorboard图结构更加清楚,

使用tf.variable_scope方法,里面的名字自己定义

with tf.variable_scope("lr_model"):

2.收集变量

这样更容易观察参数的更新情况 

3.写入事件

使用tensorboard观察,在命令行中切换到事件所在文件目录,使用命令:

tensorboard --logdir="事件所在的文件目录"

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf

def linear_regression():
    """
    自实现线性回归
    :return: None
    """
    # 构造数据X为一百行一列
    with tf.variable_scope("original_data"):  # 表示正在创建数据

        X = tf.random_normal(shape=(100, 1), mean=2, stddev=2)
        # 真实值,y=x*0.8+0.7,这里X为tf.tensor数据在乘的时候要使用二维数据
        y_true = tf.matmul(X, [[0.8]]) + 0.7

    with tf.variable_scope("linear_model"): # 初始化变量
        # 使用Variable初始化w,b,因为w和b要参与更新所有要使用变量。trainable是设置这个变量是否参与训练
        weights = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)
        bias = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)
        # 构造预测值,使用X乘上更新后的变量w加上b
        y_predict = tf.matmul(X, weights) + bias

    with tf.variable_scope("loss"):  # 确定误差
        # 计算均方误差,用真实值减去预测值的平方,因为这是一百个数据,使用要求它的平均值
        error = tf.reduce_mean(tf.square(y_predict - y_true))

    with tf.variable_scope("gd_optimizer"):  # 构建优化器
        # 构建优化器,这里使用的是梯度下降优化误差来更新w和b,0.01是学习率
        optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(error)

    # 收集变量
    tf.summary.scalar("error", error)
    tf.summary.histogram("weights", weights)
    tf.summary.histogram("bias", bias)

    # 合并变量
    merge=tf.summary.merge_all()

    # 初始化变量
    init = tf.global_variables_initializer()

    with tf.Session() as sess:  # 会话
        # 运行初始化变量op
        sess.run(init)
        # 打印一下初始化的权重和偏置
        print("随机初始化的权重为%f, 偏置为%f" % (weights.eval(), bias.eval()))

        # 创建事件文件,将事件写入到ligdir中的目录中
        file_writer=tf.summary.FileWriter(logdir="./summary",graph=sess.graph)

        # 开始训练,训练的次数越多越接近真实值
        for i in range(100):

            sess.run(optimizer)
            # 打印每一次更新后的权重,偏置,误差
            print("第%d步的误差为%f,权重为%f, 偏置为%f" % (i, error.eval(), weights.eval(), bias.eval()))

            # 运行合并变量op
            summary=sess.run(merge)
            file_writer.add_summary(summary,i)


    return None

if __name__ == '__main__':
    linear_regression()

 四.模型的保存和加载

tf.train.Saver(var_list=None,max_to_keep=5)

保存和加载模型(保存文件格式:checkpoint文件)
var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.
max_to_keep:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5(即保留最新的5个检查点文件。)

例如

# 指定目录+模型名字
# 保存
saver.save(sess, '/tmp/ckpt/test/myregression.ckpt')
# 加载
saver.restore(sess, '/tmp/ckpt/test/myregression.ckpt')

如果判断模型是否存在,直接指定目录

checkpoint = tf.train.latest_checkpoint("./tmp/model/")

saver.restore(sess, checkpoint)

五.命令行参数使用

1.tf.app.flags,它支持应用从命令行接收参数,可以用来指定集训配置等,在tf.app.flags下面各种定义参数的类型

2、 tf.app.flags.,在flags有一个FLAGS标志,它在程序中可以调用到我们

前面具体定义的flag_name

3.通过tf.app.run()启动main(argv)函数

# 定义一些常用的命令行参数
# 训练步数
tf.app.flags.DEFINE_integer("max_step", 0, "训练模型的步数")
# 定义模型的路径
tf.app.flags.DEFINE_string("model_dir", " ", "模型保存的路径+模型名字")

# 定义获取命令行参数
FLAGS = tf.app.flags.FLAGS

# 开启训练
# 训练的步数(依据模型大小而定)
for i in range(FLAGS.max_step):
     sess.run(train_op)

六.完整代码

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf

# 模型保存
tf.app.flags.DEFINE_string("model_path", "./linear_regression/", "模型保存的路径和文件名")
FLAGS = tf.app.flags.FLAGS


def linear_regression():
    """
    自实现线性回归
    :return: None
    """
    # 构造数据X为一百行一列
    with tf.variable_scope("original_data"):  # 表示正在创建数据

        X = tf.random_normal(shape=(100, 1), mean=2, stddev=2)
        # 真实值,y=x*0.8+0.7,这里X为tf.tensor数据在乘的时候要使用二维数据
        y_true = tf.matmul(X, [[0.8]]) + 0.7

    with tf.variable_scope("linear_model"): # 初始化变量
        # 使用Variable初始化w,b,因为w和b要参与更新所有要使用变量。trainable是设置这个变量是否参与训练
        weights = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)
        bias = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)
        # 构造预测值,使用X乘上更新后的变量w加上b
        y_predict = tf.matmul(X, weights) + bias

    with tf.variable_scope("loss"):  # 确定误差
        # 计算均方误差,用真实值减去预测值的平方,因为这是一百个数据,使用要求它的平均值
        error = tf.reduce_mean(tf.square(y_predict - y_true))

    with tf.variable_scope("gd_optimizer"):  # 构建优化器
        # 构建优化器,这里使用的是梯度下降优化误差来更新w和b,0.01是学习率
        optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(error)

    # 收集变量
    tf.summary.scalar("error", error)
    tf.summary.histogram("weights", weights)
    tf.summary.histogram("bias", bias)

    # 合并变量
    merge=tf.summary.merge_all()

    # 初始化变量
    init = tf.global_variables_initializer()

    with tf.Session() as sess:  # 会话
        # 运行初始化变量op
        sess.run(init)
        # 打印一下初始化的权重和偏置
        print("随机初始化的权重为%f, 偏置为%f" % (weights.eval(), bias.eval()))

        # 创建事件文件,将事件写入到ligdir中的目录中
        file_writer=tf.summary.FileWriter(logdir="./summary",graph=sess.graph)

        # 开始训练,训练的次数越多越接近真实值
        for i in range(100):

            sess.run(optimizer)
            # 打印每一次更新后的权重,偏置,误差
            print("第%d步的误差为%f,权重为%f, 偏置为%f" % (i, error.eval(), weights.eval(), bias.eval()))

            # 运行合并变量op
            summary=sess.run(merge)
            file_writer.add_summary(summary,i)


    return None


def main(argv):
    print("这是main函数")
    print(argv)
    print(FLAGS.model_path)
    linear_regression()

if __name__ == '__main__':
    tf.app.run()

都看到这里了,点个赞呗!!!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2101965.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端脚手架,自动创建远程仓库并推送

包含命令行选择和输入配置,远程仓库拉取模板,根据配置将代码注入模板框架的代码中,自动创建远程仓库,初始化git并提交至远程仓库,方便项目开发,简化流程。 目录结构 创建一个bin文件夹,添加ind…

KAN 学习 Day2 —— utils.py及spline.py 代码解读及测试

在KAN学习Day1——模型框架解析及HelloKAN中,我对KAN模型的基本原理进行了简单说明,并将作者团队给出的入门教程hellokan跑了一遍,今天我们直接开始进行源码解读。 目录 一、kan目录 二、utils.py 2.1 导入库和模块 2.2 逆函数定义 2.3 …

CentOS 7安装Docker详细步骤-无坑-丝滑-顺畅

一,安装软件包 yum install -y yum-utils device-mapper-persistent-data lvm2二,更换yum源为阿里源: yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo 三,查看docker版本&…

标准库标头 <optional> (C++17)学习之optional

类模板 std::optional 管理一个可选 &#xfeff;的所含值&#xff0c;即既可以存在也可以不存在的值。 一种常见的 optional 使用情况是作为可能失败的函数的返回值。与如 std::pair<T, bool> 等其他手段相比&#xff0c;optional 可以很好地处理构造开销高昂的对象&a…

【科普】双轴测径仪是根据哪个测量值控制外径尺寸?

单轴测径仪与双轴测径仪都是自带闭环控制功能的在线外径测量设备&#xff0c;单轴测径仪只有一个测头&#xff0c;是根据该测头的检测数据进行控制&#xff0c;这点毋庸置疑&#xff0c;那双轴测径仪这种具备两组测头的设备又是如何控制的&#xff0c;本文就来简单的介绍一下。…

Ubuntu安装网卡驱动

没有无线网 给自己装了双系统后&#xff0c;发现没有无线网络 下载驱动文件 打开终端&#xff0c;输入 lspci -k 能看到&#xff0c;虽然我是RTL8125BG&#xff0c;但use的是r8169: 08:00.0 Ethernet controller: Realtek Semiconductor Co., Ltd. RTL8125 2.5GbE Controll…

Vue前端路由详解——以Ruoyi框架为案例学习

Vue路由 Vue路由详解_vue 页面路由-CSDN博客 路由模式 Vue 的路由模式&#xff1a;hash 模式和 history 模式的区别_vue路由history和hash的区别-CSDN博客 URL格式&#xff1a; Hash模式&#xff1a;URL中包含#号&#xff0c;用于区分页面部分&#xff0c;实际请求的页面地址…

OpenCV下的无标定校正(stereoRectifyUncalibrated)

OpenCV下的无标定校正(stereoRectifyUncalibrated) 文章目录 1. 杂话2. 无标定校正2.1 先看代码2.2 一点解释2.3 findFundamentalMat参数2.4 stereoRectifyUncalibrated参数 3. 矫正结果 1. 杂话 咱们在之前的帖子里面讲了一些比较常规的标定和校正OpenCV下的单目标定&#xff…

Unity数据持久化 之 文件操作(增删查改)

本文仅作笔记学习和分享&#xff0c;不用做任何商业用途 本文包括但不限于unity官方手册&#xff0c;unity唐老狮等教程知识&#xff0c;如有不足还请斧正​​ 这里需要弄清几个概念&#xff1a; File&#xff1a;提供文件操作的静态方法&#xff0c;是管理的 Windows.File -…

除浮毛用吸尘器有用吗?除浮毛真正有用浮毛空气净化器总结

我的医生朋友经常给朋友们讲解宠物毛发对呼吸道的潜在影响&#xff0c;这引起了不同的反应。有人采纳了他的建议&#xff0c;采取了防护措施&#xff1b;而有人则认为他在制造恐慌&#xff0c;特别是当听到宠物医生的说法与之相左时。 我曾也心存疑虑&#xff0c;但经过与朋友…

做开发一年多了,分享一下自己的疑惑以及大模型给我的一些建议~

写在最前面,下面的疑问是我自己的一些困惑和想知道背后的答案,回答这块是大模型的一些建议,我觉得对我来说不能说很对,至少给我了启发和思考,分享出来给大家,大家如果也有类似的疑惑,希望能提供到帮助 原先Java生态是出现各种复杂的业务场景,需要使用合理且合适的技术架…

house of cat

文章目录 house of cat概述&#xff1a;_IO_wfile_jumps进入_IO_wfile_seekoffFSOP__malloc_assert 例题&#xff1a;思路&#xff1a;分析&#xff1a;利用&#xff1a; house of cat 概述&#xff1a; house of cat主要的摸底还是覆盖vtable指针&#xff0c;因为在glibc-2.2…

结构型设计模式—桥接模式

结构型设计模式—桥接模式 欢迎长按图片加好友&#xff0c;我会第一时间和你分享持续更多的开发知识&#xff0c;面试资源&#xff0c;学习方法等等。 假设你要买一张新桌子&#xff0c;你有两个选择&#xff1a;一种是木制的桌子&#xff0c;另一种是金属制的桌子。 无论你选…

软件工程知识点总结(1):软件工程概述

1 什么是软件&#xff1f; 定义&#xff1a;计算机系统中的程序及其文档。 ——程序是计算机任务的处理对象和处理规模的描述&#xff1b; ——文档是为了便于了解程序所需要的阐明性资料。 2 软件的特点&#xff1f; 软件是无形的&#xff0c;不可见的逻辑实体 ——它的正确与…

【Python基础】字典类型

本文收录于 《Python编程入门》专栏&#xff0c;从零基础开始&#xff0c;分享一些Python编程基础知识&#xff0c;欢迎关注&#xff0c;谢谢&#xff01; 文章目录 一、前言二、Python 字典类型2.1 访问字典里的值2.2 修改字典2.3 删除字典元素2.4 字典键值的特性2.5 遍历字典…

免费pdf转word软件,为你整理出8种方法,总有一个适合你

在日常办公和学习中&#xff0c;PDF文档因其格式稳定、不易修改的特性而广受欢迎。然而&#xff0c;有时我们需要对PDF内容进行编辑或格式调整&#xff0c;这时将其转换为Word文档便显得尤为重要。下面给大家介绍8种将PDF转换成Word的方法&#xff0c;包括在线网站、专业软件及…

第四篇——数学思维:数学家如何从逻辑出发想问题?

目录 一、背景介绍二、思路&方案三、过程1.思维导图2.文章中经典的句子理解3.学习之后对于投资市场的理解4.通过这篇文章结合我知道的东西我能想到什么&#xff1f; 四、总结五、升华 一、背景介绍 数学思维中的很多方法能够让我们脱离事物的表象去直击本质通过本质进行逻…

在Linux中使用MySQL基础SQL语句及校验规则

卸载内置环境 查看是否存在MySQL ps axj | grep mysql关闭MySQL systemctl stop mysqld MySQL对应的安装文件 rpm -qa | grep mysql 批量卸载 rpm -qa | grep mysql | xargs yum -y remove 上传MySQL rz 查看本地yum源 ls /etc/yum.repos.d/ -a 安装MySQL rpm -ivh…

Linux:手搓shell

之前学了一些和进程有关的特性&#xff0c;什么进程控制啊进程替换啊&#xff0c;我们来尝试自己搓一个shell()吧 首先我们观察shell的界面&#xff0c;发现centos的界面上有命令提示符&#xff1a; [主机名用户名当前路径] 我们可以通过调用系统函数获取当前路径&#xff0…

C语言代码练习(第十二天)

今日练习&#xff1a; 28、&#xff08;指针&#xff09;将字符串 a 复制为字符串 b &#xff0c;然后输出字符串 b 29、改变指针变量的值 30、输入两个整数&#xff0c;然后让用户选择1或者2&#xff0c;选择1是调用 max &#xff0c;输出两者中的大数&#xff0c;选择2是调用…