TensorFlow笔记之单神经元完成多分类任务

news2024/12/27 13:43:57

文章目录

  • 前言
  • 一、逻辑回归
    • 1.二分类问题
    • 2.多分类问题
  • 二、数据集调用
  • 三、TensorFlow1.x
    • 1.定义模型
    • 2.训练模型
    • 3.结果可视化
  • 四、TensorFlow2.x
    • 1.定义模型
    • 2.训练模型
    • 3.结果可视化
  • 总结


前言

记录分别在TensorFlow1.x与TensorFlow2.x中使用单神经元完成MNIST手写数字识别的过程。


一、逻辑回归

将回归值映射为各分类的概率

1.二分类问题

1.sigmod函数: y = 1 1 + e − z y= \frac{1}{1+e^{-z}} y=1+ez1
z ∈ ( − ∞ , + ∞ ) z\in ( -\infty,+\infty ) z(,+)映射到 y ∈ [ 0 , 1 ] y\in [0,1 ] y[0,1],0→0.5,连续可微
代入到平方损失函数,为非凸函数,有多个最小值,会产生局部最优
2.对数损失函数: L o s s = ∑ [ − y log ⁡ ( y ^ ) − ( 1 − y ) log ⁡ ( 1 − y ^ ) ] Loss=\sum[-y\log (\hat{y})-(1-y)\log( 1-\hat{y})] Loss=[ylog(y^)(1y)log(1y^)]为凸函数

2.多分类问题

1.softmax函数: P i = e − y i ∑ e − y k {P_i}= \frac{e^{-y_i}}{\sum e^{-y_k}} Pi=eykeyi
增大差距,映射到 y ∈ [ 0 , 1 ] y\in \left [0,1 \right ] y[0,1],各分类概率和为1
2.交叉熵损失函数 L o s s = ∑ − y log ⁡ ( y ^ ) Loss=\sum-y\log (\hat{y}) Loss=ylog(y^)
两个概率分布的距离

二、数据集调用

在tensorflow2.x中调用数据集;
训练集训练模型,验证集调整超参数,测试集测试模型效果
训练集60000个样本,取5000个样本作为验证集;测试集10000个样本

import tensorflow as tf2
import matplotlib.pyplot as plt
import numpy as np
mnist = tf2.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#维度转换,灰度值归一化,标签独热编码
x_train = x_train.reshape((-1, 784))
x_train = tf2.cast(x_train/255.0, tf2.float32)
x_test = x_test.reshape((-1, 784))
x_test = tf2.cast(x_test/255.0, tf2.float32)
y_train = tf2.one_hot(y_train, depth=10)
y_test = tf2.one_hot(y_test, depth=10)
#训练集训练模型,验证集调整超参数,测试集测试模型效果
#训练集60000个样本,取5000个样本作为验证集;测试集10000个样本
x_valid, y_valid = x_train[55000:], y_train[55000:]
x_train, y_train = x_train[:55000], y_train[:55000]

显示图片、标签与预测值

def show(images, labels, preds):
    fig1 = plt.figure(1, figsize=(12, 12))
    for i in range(16):
        ax = fig1.add_subplot(4, 4, i+1)
        ax.imshow(images[i].reshape(28, 28), cmap='binary')
        label = np.argmax(labels[i])
        pred = np.argmax(preds[i])       
        title = 'label:%d,pred:%d' % (label, pred)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

三、TensorFlow1.x

1.定义模型

import tensorflow.compat.v1 as tf
from sklearn.utils import shuffle
tf.disable_eager_execution()
with tf.name_scope('Model'):
    x = tf.placeholder(tf.float32, [None, 784], name='X')
    y = tf.placeholder(tf.float32, [None, 10], name='Y') 
    w = tf.Variable(tf.random_normal((784, 10)), name='W')
    b = tf.Variable(tf.zeros((10)), name='B')
    def model(x, w, b):
        y0 = tf.matmul(x, w) + b#前向计算
        y = tf.nn.softmax(y0)#结果分类
        return y
    pred = model(x, w, b)

2.训练模型

#训练参数
train_epoch = 100
learning_rate = 0.1
batch_size = 100
batch_num = x_train.shape[0] // batch_size
#损失函数与准确率
step = 0
display_step = 5
loss_list = []
acc_list = []
loss_function = tf.reduce_mean(-y*tf.log(pred))
accuracy = tf.reduce_mean(tf.cast\
        (tf.equal(tf.argmax(y, axis=1), tf.argmax(pred, axis=1)), tf.float32))
#优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate)\
    .minimize(loss_function)

变量初始化

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    #tf转为numpy
    x_train = sess.run(x_train)
    x_valid = sess.run(x_valid)
    x_test = sess.run(x_test)
    y_train = sess.run(y_train)
    y_valid = sess.run(y_valid)
    y_test = sess.run(y_test)

迭代训练

    for epoch in range(train_epoch):
        if epoch % 10 == 0:
            print('epoch:%d' % epoch)
        for batch in range(batch_num):
            xi = x_train[batch*batch_size:(batch+1)*batch_size]
            yi = y_train[batch*batch_size:(batch+1)*batch_size]
            sess.run(optimizer, feed_dict={x:xi, y:yi})
            step = step + 1
            if step % display_step == 0:
                loss, acc = sess.run([loss_function, accuracy],\
                                     feed_dict={x:x_valid, y:y_valid})
                loss_list.append(loss)
                acc_list.append(acc)
        #打乱顺序
        x_train, y_train = shuffle(x_train, y_train)

3.结果可视化

    y_pred, acc = sess.run([pred, accuracy],\
                            feed_dict={x:x_test, y:y_test})
fig2 = plt.figure(2, figsize=(12, 6))
ax = fig2.add_subplot(1, 2, 1)
ax.plot(loss_list, 'r-')
ax.set_title('loss')
ax = fig2.add_subplot(1, 2, 2)
ax.plot(acc_list, 'b-')
ax.set_title('acc')
print('Accuracy:{:.2%}'.format(acc))
show(x_test, y_test, y_pred)

测试集上的准确率
验证集上的损失值与准确率曲线

测试集图片标签与预测

四、TensorFlow2.x

1.定义模型

import tensorflow as tf
from sklearn.utils import shuffle
w = tf.Variable(tf.random.normal((784, 10)), tf.float32)
b = tf.Variable(tf.zeros(10), tf.float32)
def model(x, w, b):
    y0 = tf.matmul(x, w) + b
    y = tf.nn.softmax(y0)
    return y
#损失函数
def loss_function(x, y, w, b):
    pred = model(x, w, b)
    loss = tf.keras.losses.categorical_crossentropy(
        y_true=y, y_pred=pred)
    return tf.reduce_mean(loss)
#准确率
def accuracy(x, y, w, b):
    pred = model(x, w, b)  
    acc = tf.equal(tf.argmax(y, axis=1), tf.argmax(pred, axis=1))
    acc = tf.cast(acc, tf.float32)
    return tf.reduce_mean(acc)
#梯度
def grad(x, y, w, b):
    with tf.GradientTape() as tape:
        loss = loss_function(x, y, w, b)
        return  tape.gradient(loss, [w,b])

2.训练模型

#训练参数
train_epoch = 10
learning_rate = 0.01
batch_size = 100
batch_num = x_train.shape[0] // batch_size
#展示间隔
step = 0
display_step = 5
loss_list = []
acc_list = []
#Adam优化器
optimizer = tf.keras.optimizers.Adam(learning_rate)

迭代训练

for epoch in range(train_epoch):
    print('epoch:%d' % epoch)
    for batch in range(batch_num):
        xi = x_train[batch*batch_size: (batch+1)*batch_size]
        yi = y_train[batch*batch_size: (batch+1)*batch_size]
        grads = grad(xi, yi, w, b)
        optimizer.apply_gradients(zip(grads, [w,b]))
        step = step + 1
        if step % display_step == 0:
            loss_list.append(loss_function(x_valid, y_valid, w, b))
            acc_list.append(accuracy(x_valid, y_valid, w, b))
    #打乱顺序
    x_train, y_train = shuffle(x_train.numpy(), y_train.numpy())
    x_train = tf.cast(x_train, tf.float32)
    y_train = tf.cast(y_train, tf.float32)   

3.结果可视化

#验证集结果
fig2 = plt.figure(2, figsize=(12, 6))
ax = fig2.add_subplot(1, 2, 1)
ax.plot(loss_list, 'r-')
ax.set_title('loss')
ax = fig2.add_subplot(1, 2, 2)
ax.plot(acc_list, 'b-')
ax.set_title('acc')
#测试集结果
acc = accuracy(x_test, y_test, w, b)
print('Accuracy:{:.2%}'.format(acc))
y_pred = model(x_test, w, b)
show(x_test.numpy(), y_test, y_pred)

测试集上的准确率
验证集上的损失值与准确率曲线

测试集图片标签与预测


总结

分类在回归的基础上通过softmax函数放大不同类之间的概率差异,损失函数改为凸的交叉熵损失函数。
在tf1.x中,feed_dict需要提交numpy数组,可通过sess.run(Tensor)将张量转换为数组;
sklearn.utils.shuffle不能打乱张量类型,在tf2.x中使用Tensor.numpy()将张量转换为数组。
使用Adam优化器,一轮的训练速度减慢,但收敛速度加快,模型准确率也提高。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/146837.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux出现ping: www.baidu.com: 未知的名称或服务解决方法

文章目录解决对象方法先找到网关在Windows下进行VMnet8的配置ping成功Linux出现ping: www.baidu.com: 未知的名称或服务解决方法 解决对象 本文的方法用于各位大佬已经用过以下方法仍然无法ping成功 Linux防火墙已关闭和Windows防火墙已经关闭已经配置好 vim /etc/sysconfig/…

手撕C语言理论知识(上)粗略讲解C语言的部分入门知识

目录 C语言的一些基础知识 操作符简介 Scanf的%[ ] 语句(分支、循环、goto) 函数 C语言的一些基础知识 主函数 - 程序的入口 - main函数有且仅有一个。char - short - int - long - long long - float - double%d - 十进制整型 %u - 无符号整型 %…

【博学谷学习记录超强总结,用心分享|产品经理基础总结和感悟15】

互联网产品设计背后的心理学02:你就是会被其他人的行为所影响一、前言二、实验设计及结果分析三、实验原理四、实验方法总结五、产品设计中的应用六、结束语前文回顾:让人们做出决定并不是信息本身,而是这些信息呈现的背景或情景。我们这个信…

Spring Cloud Alibaba Dubbo(服务远程调用)

一、软件环境 &#xff08;1&#xff09;自己部署服务器 所有软件及服务器自己进行管理提供&#xff0c;可以直接在项目中添加Spring Cloud依赖。推荐 <dependencyManagement> <dependencies> <dependency> <groupId>com.a…

liunx centos9中安装flask并在pycharm中使用图文攻略

liunx centos9中安装flask并在pycharm中使用图文攻略1.首先在liunx的终端中输入2.安装好flask之后就在pycharm创建新的项目处添加flask项目3.点击绿色三角箭头开始运行flask项目4. 然后登录ip地址就出现Hllo world就代表flask环境搭建完成需要注意事项1.首先在liunx的终端中输入…

ngx_thread_pool_init()

ngx_thread_pool_cycle()函数的主要工作是从待处理的任务队列中获取一个任务&#xff0c;然后调用任务对象的handler()函数处理任务&#xff0c;完成后把任务放置到完成队列中&#xff0c;并通过ngx_notify()通知主线程 手写线程池与性能分析 - 知乎 pthread_cond_wait函数的原…

【5G RRC】5G系统消息介绍

博主未授权任何人或组织机构转载博主任何原创文章&#xff0c;感谢各位对原创的支持&#xff01; 博主链接 本人就职于国际知名终端厂商&#xff0c;负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作&#xff0c;目前牵头6G算力网络技术标准研究。 博客…

一键绕过ID锁激活,为什么很多人都会失败?绕ID这一篇就够了

最近阳了所以暂时断更&#xff0c;你们也要注意身体&#xff0c;最好不要阳 现在绕ID的方法已经非常完善&#xff0c;一个小白选手只要有设备就可以正常绕过ID&#xff0c;总的来说绕ID分为两个步骤&#xff1a;第一步是手机的越狱&#xff0c;这里只能是用checkra1n越狱&…

数据在内存中存储☞(超详解)

目录 一.数据类型大家族 1.了解类型的意义 2.数据类型大家族的分类 二.详解☞数据储存之整形 1.储存方式 &#xff08;1&#xff09;.原码反码补码的概念 &#xff08;2&#xff09;.原码反码补码出现的原因&#xff1a; 计算机中只有加法器没有减法器&#xff0c;所有只…

SemanticKITTI: A Dataset for Semantic Scene Understanding of LiDAR Sequences

Paper name SemanticKITTI: A Dataset for Semantic Scene Understanding of LiDAR Sequences Paper Reading Note URL: https://arxiv.org/pdf/1904.01416.pdf TL;DR 2019 ICCV 论文&#xff0c;提出了一个大规模的真实场景 LiDAR 点云标注数据集 SemanticKITTI&#xff…

数字信号处理第六次试验:数字信号处理在双音多频拨号系统中的应用

数字信号处理第六次试验&#xff1a;数字信号处理在双音多频拨号系统中的应用前言一、实验目的二、实验原理和方法1.关于双音多频拨号系统2.电话中的双音多频&#xff08;DTMF&#xff09;信号的产生与检测3.检测DTMF信号的DFT参数选择4.DTMF信号的产生与识别仿真实验三、实验内…

菜鼠的保研总结

1.个人基本情况 本科学校&#xff1a;山东某双非 本科专业&#xff1a;网络工程 成绩排名&#xff1a;1/46 英语成绩&#xff1a;四级529&#xff0c;六级502 科研竞赛&#xff1a;美国大学生数学建模比赛特等奖提名、全国英语翻译比赛三等奖、山东省蓝桥杯java大学生B组三等奖…

C++基础:KMP

让我们先看一个问题&#xff1a;给定一个字符串 S&#xff0c;以及一个模式串 P&#xff0c;所有字符串中只包含大小写英文字母以及阿拉伯数字。模式串 P 在字符串 S 中多次作为子串出现。求出模式串 P 在字符串 S 中所有出现的位置的起始下标。输入格式第一行输入整数 N&#…

【Python】杨辉三角中的排成一列编号的问题

题目描述 下面的图形是著名的杨辉三角形&#xff1a; 如果我们按从上到下、从左到右的顺序把所有数排成一列&#xff0c;可以得到如下数列&#xff1a; 1,1,1,1,2,1,1,3,3,1,1,4,6,4,1,⋯ 给定一个正整数 N&#xff0c;请你输出数列中第一次出现 N是在第几个数&#xff1f; …

Go语言设计与实现 -- 内存管理器

不同的编程语言选择不同的方式管理内存&#xff0c;本节会介绍Go语言内存分配器。 Go内存分配的设计思想是&#xff1a; 内存分配算法采用Google的TCMalloc算法&#xff0c;每个线程都会自行维护一个独立的内存池&#xff0c;进行内存分配时优先从该内存池中分配&#xff0c;…

第十八章Vue的学习

文章目录什么是VueVue.js的官网介绍环境配置基本语法声明式渲染绑定元素属性双向数据绑定条件渲染列表渲染事件驱动侦听属性Vue对象生命周期什么是Vue 对于Java程序来说&#xff0c;我们使用框架就是导入那些封装了**『固定解决方案』的jar包&#xff0c;然后通过『配置文件』…

CSS3 之选择器

文章目录1、关系性选择器&#xff1a;EFE~F2、属性选择器3、伪元素选择器4、伪类选择器(被选中的元素的一个种状态)calc1、关系性选择器&#xff1a;EFE~F 2、属性选择器 E[attr~“val”]E[attr|“val”]E[attr^“val”]E[attr$“val”]E[attr*“val”]3、伪元素选择器 E::pl…

CesiumLab对BIM模型的输入格式要求 CesiumaLab系列教程

BIM 模型和手工模型最大的区别在于几点&#xff1a; 1.建模目标不同&#xff0c;手工模型的目的是为了可视化&#xff0c;就是为了看的见&#xff0c;看不见的东西能省则省。BIM 完全是按照一些工程标准去创建的&#xff0c;比如路面可能有多个层代表了不同的物理层。手工模型…

windows编译Paraview源码

目录一. 环境准备二. 编译1. CMake2. Visual Studio一. 环境准备 下载基本所需&#xff1a; paraview官方给了编译文档&#xff1a;https://github.com/Kitware/ParaView/blob/master/Documentation/dev/build.md 所需要的基础有&#xff1a; 如图&#xff1a;&#xff08;进入…

2022我的年度总结-- AI遮天之路

我是一个普普通通的大学生&#xff0c;我的博客记录了我学习编程以来共计1年多的水平&#xff0c;我希望能把自己大学的经历、选择、困惑等与同样身处大学&#xff0c;选择AI方向不知如何发展的人进行分享&#xff0c;因此写了这篇年终总结。另外&#xff0c;对于一些刚刚开始写…