【TensorFlow1.X】系列学习笔记【基础一】

news2024/11/18 19:59:40

【TensorFlow1.X】系列学习笔记【基础一】

大量经典论文的算法均采用 TF 1.x 实现, 为了阅读方便, 同时加深对实现细节的理解, 需要 TF 1.x 的知识


文章目录

  • 【TensorFlow1.X】系列学习笔记【基础一】
  • 前言
  • 线性回归
  • 非线性回归
  • 逻辑回归
  • 总结


前言

本篇博主将用最简洁的代码由浅入深实现几个小案例,让读者直观体验最基础的数据的处理、模型的设计以及模型的优化。【代码参考】


线性回归

线性回归是一种常见回归分析方法,它假设目标值与特征之间存在线性关系。线性回归模型通过拟合线性函数来预测目标值。线性回归模型的形式比较单一的,即满足一个多元一次方程。常见的线性方程如: y = w × x + b {\rm{y}} = w \times x + b y=w×x+b,但是观测到的数据往往是带有噪声,于是给现有的模型一个因子 ε \varepsilon ε,并假设该因子符合标准正态分布: y = w × x + b + ε {\rm{y}} = w \times x + b + \varepsilon y=w×x+b+ε。对于线性模型,深度学习可以通过构建单层神经网络来描述,这个单层神经网络通常被称为全连接层(Fully Connected Layer)或线性层(Linear Layer),其中每个神经元都与上一层的所有神经元相连接,且没有非线性激活函数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# 随机生成100个数据点,服从“0~1”均匀分布
x_data = np.random.rand(100)

# 提升维度(100)-->(100,1)
x_data = x_data[:, np.newaxis]

# 制作噪声,shape与x_data一致
noise = np.random.normal(0, 0.02,  x_data.shape)

# 构造目标公式
y_data = 0.8 * x_data + 0.1 + noise

# 输入层:placeholder用于接收训练的数据
x = tf.placeholder(tf.float32, [None, 1], name="x_input")
y = tf.placeholder(tf.float32, [None, 1], name="y_input")

# 构造线性模型
b = tf.Variable(0., name="bias")
w = tf.Variable(0., name="weight")
out = w * x_data + b

# 构建损失函数
loss = 1/2*tf.reduce_mean(tf.square(out - y))
# print(loss)

# 定义优化器
optim = tf.train.GradientDescentOptimizer(0.1)
# print(optim)

# 最小化损失函数
train_step = optim.minimize(loss)

# 初始化全部的变量
init = tf.global_variables_initializer()

# 训练迭代
with tf.Session() as sess:
    sess.run(init)
    for step in range(2000):
        sess.run([loss, train_step], {x: x_data, y: y_data})
        if step % 200 == 0:
            w_value, b_value, loss_value = sess.run([w, b, loss], {x: x_data, y: y_data})
            print("step={}, k={}, b={}, loss={}".format(step, w_value, b_value, loss_value))
    prediction_value = sess.run(out, feed_dict={x: x_data})

plt.figure()
plt.scatter(x_data, y_data)
plt.plot(x_data, prediction_value, "r-", lw=3)
plt.show()


非线性回归

非线性回归也是一种常见回归分析方法,它假设目标值与特征之间存在非线性关系。与线性回归不同,非线性回归模型可以拟合复杂的非线性关系。通过拟合非线性函数到数据中,非线性回归模型可以找到最佳的函数参数,以建立一个能够适应数据的非线性关系的模型。非线性回归模型的形式可以是多项式函数、指数函数、对数函数、三角函数等任意形式的非线性函数,这些函数可以包含自变量的高次项、交互项或其他非线性变换。常见的非线性方程如: y = x 2 {\rm{y}} = {x^2} y=x2,但是观测到的数据往往是带有噪声,于是给现有的模型一个因子 ε \varepsilon ε,并假设该因子符合标准正态分布: y = x 2 + ε {\rm{y}} = {x^2} + \varepsilon y=x2+ε。深度学习模型通常由多个神经网络层组成,每一层都包含许多神经元。每个神经元接收来自前一层的输入,并通过激活函数对输入进行非线性转换,然后将结果传递给下一层,通过多个层的堆叠,深度学习模型可以学习到多个抽象层次的特征表示。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# 生成200个数据点,从“-0.5~0.5”均匀排布
x_data = np.linspace(-0.5, 0.5, 200)

# 提升维度(200)-->(200,1)
x_data = x_data[:, np.newaxis]

# 制作噪声,shape与x_data一致
noise = np.random.normal(0, 0.02,  x_data.shape)

# 构造目标公式
y_data = np.square(x_data) + noise

# 输入层:placeholder用于接收训练的数据
x = tf.placeholder(tf.float32, [None, 1], name="x_input")
y = tf.placeholder(tf.float32, [None, 1], name="y_input")

# 隐藏层
W_1 = tf.Variable(tf.random_normal([1, 10]))
b_1 = tf.Variable(tf.zeros([1, 10]))
a_1 = tf.matmul(x, W_1) + b_1
out_1 = tf.nn.tanh(a_1)

# 输出层
W_2 = tf.Variable(tf.random_normal([10, 1]))
b_2 = tf.Variable(tf.zeros([1, 1]))
a_2 = tf.matmul(out_1, W_2) + b_2
out_2 = tf.nn.tanh(a_2)

# 构建损失函数
loss = 1/2*tf.reduce_mean(tf.square(out_2- y))

# 定义优化器
optim = tf.train.GradientDescentOptimizer(0.1)

# 最小化损失函数
train_step = optim.minimize(loss)

# 初始化全部的变量
init = tf.global_variables_initializer()

# 训练
with tf.Session() as sess:
    sess.run(init)
    for epc in range(10000):
        sess.run([loss, train_step], {x:x_data,y:y_data})
        if epc % 1000 == 0:
            loss_value = sess.run([loss], {x:x_data,y:y_data})
            print("epc={}, loss={}".format(epc, loss_value))
    prediction_value = sess.run(out_2, feed_dict={x:x_data})

plt.figure()
plt.scatter(x_data, y_data)
plt.plot(x_data, prediction_value, "r-", lw=3)
plt.show()


逻辑回归

逻辑回归是一种用于分类问题的统计模型,它假设目标变量与特征之间存在概率关系。逻辑回归模型通过线性函数和逻辑函数的组合来建模概率,以预测样本属于某个类别的概率。逻辑回归本身是一个简单的线性分类模型,但深度学习可以自动地学习特征表示,并通过多层非线性变换来模拟更复杂的关系。MNIST数据集通常被认为是深度学习的入门级别任务之一,可以帮助初学者熟悉深度学习的基本概念、模型构建和训练过程。虽然MNIST是一个入门级别的任务,但它并不能完全代表实际应用中的复杂视觉问题。在实践中,还需要面对更大规模的数据集、多类别分类、图像分割、目标检测等更具挑战性的问题。

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

# 载入数据集:首次调用时自动下载数据集(MNIS 数据集)并将其保存到指定的目录中。
mnist = input_data.read_data_sets("MNIST", one_hot=True)

# 设置batch_size的大小
batch_size = 50
# (几乎)所有数据集被用于训练所需的次数
n_batchs = mnist.train.num_examples // batch_size

# 输入层:placeholder用于接收训练的数据
# 这里图像大小是28×28,对数据集进行压缩28×28=782
x = tf.placeholder(tf.float32, [None, 784],name="x-input")
# 10分类(数字0~9)
y = tf.placeholder(tf.float32, [None, 10], name="y-input")

# 隐藏层
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([1,10]))
# 全连接层
prediction = tf.matmul(x, w) + b
prediction_softmax = tf.nn.softmax(prediction)
# 交叉熵损失函数+计算张量在指定维度(默认0维)上的平均值
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))

# 定义优化器
optim = tf.train.GradientDescentOptimizer(0.01)

# 最小化损失函数
train_step = optim.minimize(loss)

# 初始化全部的变量
init = tf.global_variables_initializer()

# 计算准确率:选择概率最大的数字作为预测值与真实值进行比较,统计正确的个数再计算准确率
correct_prediction = tf.equal(tf.argmax(prediction_softmax, 1), tf.argmax(y, 1))
accuarcy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# GPU使用和显存分配:最大限度为1/3
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
# 用于配置 GPU
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

epoch_arr = np.array([])
acc_arr = np.array([])
loss_arr = np.array([])

with tf.Session() as sess:
    sess.run(init)
    # 训练总次数
    for epoch in range(200):
        # 每轮训练的迭代次数
        for batch in range(n_batchs):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run([train_step],{x:batch_x, y: batch_y})
        # 用训练集每完成一次训练,则用测试集验证
        acc, los = sess.run([accuarcy, loss], feed_dict = {x:mnist.test.images, y:mnist.test.labels})
        epoch_arr= np.append(epoch_arr, epoch)
        acc_arr = np.append(acc_arr, acc)
        loss_arr = np.append(loss_arr, los)
        print("epoch: ", epoch, "acc: ",acc, "loss: ", los)

# 分别显示精度上升趋势和损失下降趋势
fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.set_title('acc_trends')
ax1.set_xlabel('epoch')
ax1.set_ylabel('acc')
ax1.plot(epoch_arr, acc_arr, "r-", lw=3)

ax2.set_title('loss_trends')
ax2.set_xlabel('epoch')
ax2.set_ylabel('loss')
ax2.plot(epoch_arr, loss_arr, "g-", lw=3)
plt.show()


总结

训练深度学习模型通常需要大量的标记数据和计算资源。一种常用的训练算法是反向传播算法,它通过最小化损失函数来优化模型参数。常见的损失函数是均方误差损失函数和交叉熵损失函数,可以度量模型输出的概率分布与实际标签之间的差异。在实际应用中,深度学习通常用于处理非线性回归,而逻辑回归和线性回归则是其中的一些特殊情况。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1117338.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AD20原理图库的制作

1、打开“51单片机最小系统”的工程文件。 2、创建原理图库文件:单击“文件”菜单,选择“新的”选项中的“库”选项,再选择“原理图库”,进入原理图库元件的编辑界面。 3、保存原理图库文件:选择“文件”菜单&#xff…

antd vue 组件 使用下拉框的层级来显示后面的输入框

效果图&#xff1a; 代码&#xff1a; HTML: <dir><a-row><a-col :span"4"><a-form-model-item label"审批层级" ><a-selectplaceholder"请选择审批层级"v-model"form.PlatformPurchaseApproveLevel"cha…

安达发|人工智能在APS高级计划与排程中的应用

随着人工智能&#xff08;AI&#xff09;技术的发展&#xff0c;其在生产计划与排程&#xff08;APS&#xff09;领域的应用也日益广泛。APS是一种复杂的系统工程&#xff0c;它需要处理大量的数据&#xff0c;包括需求预测、资源优化、路径规划等。AI技术的应用可以帮助企业更…

身份证读卡器ubuntu虚拟机实现RK3399 Arm Linux开发板交叉编译libdonsee.so找不到libusb解决办法

昨天一个客户要在RK3399 Linux开发板上面使用身份证读卡器&#xff0c;由于没有客户的开发板&#xff0c;故只能用本机ubuntu虚拟机来交叉编译&#xff0c;用客户发过来的交叉编译工具&#xff0c;已经编译好libusb然后编译libdonsee.so的时候提示找不到libusb&#xff0c;报错…

语音芯片KT142C两种音频输出方式PWM和DAC的区别

目录 语音芯片KT142C两种音频输出方式PWM和DAC的区别 一般的语音芯片&#xff0c;输出方式&#xff0c;无外乎两种&#xff0c;即dac输出&#xff0c;或者PWM输出 其中dac的输出&#xff0c;一般应用场景都是外挂功放芯片&#xff0c;实现声音的放大&#xff0c;比如常用的音箱…

csapp-Machine-Level Representation of Program-review

Machine-Level Representation of Program收获和思考 Basics Machine-Level Programming可以看成是机器执行对于上层代码的一种翻译&#xff0c;即硬件是如何通过一个个的指令去解释每一行代码&#xff0c;然后操纵各种硬件执行出对应的结果。 Machine-Level Programming有2种…

Jprofiler V14中文使用文档

JProfiler介绍 什么是JProfiler? JProfiler是一个用于分析运行JVM内部情况的专业工具。 在开发中你可以使用它,用于质量保证,也可以解决你的生产系统遇到的问题。 JProfiler处理四个主要问题: 方法调用 这通常被称为"CPU分析"。方法调用可以通过不同的方式进行测…

【剑指Offer】33.二叉搜索树的后序遍历序列

题目 输入一个整数数组&#xff0c;判断该数组是不是某二叉搜索树的后序遍历的结果。如果是则返回 true ,否则返回 false 。假设输入的数组的任意两个数字都互不相同。 数据范围&#xff1a; 节点数量 0≤n≤1000 &#xff0c;节点上的值满足 1≤val≤10^5 &#xff0c;保证节…

Xcode报错“compact unwind compressed function offset doesn‘t fit in 24 bits

Assertion failed: (false && “compact unwind compressed function offset doesn’t fit in 24 bits”), function operator(), file Layout.cpp, line 5758. 解决方案&#xff1a;targerts->build settings->other linker Flages增加-ld64

企业数据防泄密软件-文件外发管理,文件,文档,图纸不外泄

企业数据防泄密软件可以帮助保护企业的重要数据和知识产权&#xff0c;其中文件外发管理是一个重要的环节。 PC访问地址&#xff1a;https://isite.baidu.com/site/wjz012xr/2eae091d-1b97-4276-90bc-6757c5dfedee 以下是一些关键功能&#xff1a; 透明加密&#xff1a;加密软件…

【框架源码篇 01】Spring源码-手写IOC

Spring源码手写篇-手写IoC 一、IoC分析 1.Spring的核心 在Spring中非常核心的内容是 IOC和 AOP. 2.IoC的几个疑问? 2.1 IoC是什么&#xff1f; IoC:Inversion of Control 控制反转&#xff0c;简单理解就是&#xff1a;依赖对象的获得被反转了。 2.2 IoC有什么好处? IoC带…

浏览器的四种缓存协议

❤️浏览器缓存 在HTTP里所谓的缓存本质上只是浏览器和业务侧根据不同的报文字段做出不同的缓存动作而已 四种缓存协议如下 Cache-ControlExpiresETag/If-None-MatchLast-Modified/If-Modified-Since &#x1f3a1;Cache-Control 通过响应头设置Cache-Control和max-age&…

【必须安排】书单|1024程序员狂欢节充能书单!

注&#xff1a;以上书单可从京东商城优惠购买&#xff0c;点击以下链接进入图书专题&#xff01;1024程序员狂欢节充能书单 一年一度的1024程序员狂欢节又到啦&#xff01;成为更卓越的自己&#xff0c;坚持阅读和学习&#xff0c;别给自己留遗憾&#xff0c;行动起来吧&#x…

UniApp百度人脸识别插件YL-FaceDetect

插件地址&#xff1a;https://ext.dcloud.net.cn/plugin?id15061 插件说明&#xff1a; 百度离线人脸识别&#xff0c;人脸收集&#xff0c;属性&#xff08;性别年龄&#xff09;识别等&#xff0c;目前只支持安卓端&#xff01; 另&#xff1a;该插件支持的功能为属性识别…

win10下yolox tensorrt模型部署

TensorRT系列之 Win10下yolov8 tensorrt模型加速部署 TensorRT系列之 Linux下 yolov8 tensorrt模型加速部署 TensorRT系列之 Linux下 yolov7 tensorrt模型加速部署 TensorRT系列之 Linux下 yolov6 tensorrt模型加速部署 TensorRT系列之 Linux下 yolov5 tensorrt模型加速部署…

51系列—基于51单片机的数字频率计(代码+文档资料)

本文主要说明基于51单片机的数字频率计设计&#xff0c;完整资料见文末链接 数字频率计概述 数字频率计是计算机、通讯设备、音频视频等科研生产领域不可缺少的测量仪器。它是一种用十进制数字显示被测信号频率的数字测量仪器。它的基本功能是测量正弦信号&#xff0c;方波信…

安达发|AI在APS生产计划排程系统中的应用与优势

随着科技的不断发展&#xff0c;人工智能&#xff08;AI&#xff09;已经在许多领域取得了显著的成果。在生产管理计划系统中&#xff0c;AI技术的应用也日益受到关注。本文将探讨如何将AI人工智能用在生产管理计划系统上&#xff0c;以提高生产效率、降低成本并优化资源配置。…

Qt之使用bitblt抓取bitmap(位图)并转QImage

一.效果 点击按钮抓取窗口自身并显示到QLable中 二.实现 pro文件 QT += core guigreaterThan(QT_MAJOR_VERSION, 4): QT += widgetsCONFIG += c++11SOURCES += \main.cpp \mainwindow.cppHEADERS += \mainwindow.hFORMS += \mainwindow.uiLIBS += -lgdi32 -luser32 -l…

前端数据可视化之【Echarts下载使用】

目录 &#x1f31f;下载&#x1f31f;浏览器引入&#x1f31f;模块化引入 &#x1f31f;使用&#x1f31f;基本使用步骤 &#x1f31f;绘制一个简单的图表&#x1f31f;写在最后 &#x1f31f;下载 &#x1f31f;浏览器引入 官网下载界面&#xff1a;官方网站 或 Echarts中文…

旺店通企业版与金蝶云星辰数据集成方案分享

今天我们将深入介绍旺店通企业版与金蝶云星辰业财一体化数据集成方案&#xff0c;以丰富的业务场景示例展示该平台如何以无缝的方式连接不同系统&#xff0c;实现数据同步、提高效率&#xff0c;同时突显轻易云的出色能力。 概述 旺店通企业版与金蝶云星辰业财一体化数据集成方…