Tensorflow2.0笔记 - 不使用layer方式,简单的MNIST训练

news2025/1/23 9:30:32

        本笔记不使用layer相关API,搭建一个三层的神经网络来训练MNIST数据集。

        前向传播和梯度更新都使用最基础的tensorflow API来做。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import numpy as np

def load_mnist():
    path = r'./mnist.npz' #放置mnist.py的目录。注意斜杠
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)

#加载mnist数据集
#X_train: [60000, 28, 28] 图片
#Y_train: [60000] 标签
#mnist数据集下载:https://blog.csdn.net/charles_neil/article/details/107851880
#                https://www.zhihu.com/question/56773355
(X_train,Y_train),(X_test,Y_test) = load_mnist()

#转换为tensor
#图片数据值转换到0-1
x = tf.convert_to_tensor(X_train, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(Y_train, dtype=tf.int32)
print(x.shape,y.shape)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))

#数据集切分为多个batch
train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter = iter(train_db)

sample = next(train_iter)
print(sample[0].shape, sample[1].shape)


#学习率
lr = 0.1
#用三个神经元,[b:784] => [b,256] => [b,128] => [b,10]
w1 = tf.Variable(tf.random.truncated_normal([784,256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256,128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128,10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

for epoch in range(10):
    print("[==================Epoch ", epoch, "========================]")
    for step, (x,y) in enumerate(train_db):
        x = tf.reshape(x, [-1, 28*28])
        #对标签进行onehot编码
        y_onehot = tf.one_hot(y, depth=10)
    
        with tf.GradientTape() as tape:
            #第一层,输入x [128,784]
            #x@w + b: [batch, 784] [784,256] + [256] => [batch,256]
            h1 = x@w1 + b1
            h1 = tf.nn.relu(h1)
            #第二层:[batch, 256] => [batch, 128]
            h2 = h1@w2 + b2
            h2 = tf.nn.relu(h2)
            #输出层:[batch,128] => [batch,10]
            out = h2@w3 + b3
        
            #计算损失
            #使用MSE: mean(sum(y - out)^2)
            loss = tf.reduce_mean(tf.square(y_onehot - out))
        #计算梯度
        grads = tape.gradient(loss, [w1,b1,w2,b2,w3,b3])
        #更新w和b: w = w - lr * w_grad
        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1])
        w2.assign_sub(lr * grads[2])
        b2.assign_sub(lr * grads[3])
        w3.assign_sub(lr * grads[4])
        b3.assign_sub(lr * grads[5])
        
        if (step % 100 == 0):
            print("Batch:", step, "loss:", float(loss))

        运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1397426.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

arthas(阿尔萨斯)日常java代码调优使用命令

官方项目文档:https://gitee.com/arthas/arthas (最权威的教学还是得官网,这里仅作简单记录) 1:启动 java -jar arthas-boot.jar 2:查看cpu占用排名前三 thread -3 3:查看指定id thread 203 4:查…

ui 开发 剪辑等工具集成网站

这里给大家推荐一个工具集成网站,总体来说还是挺不错的 菜鸟工具 - 不止于工具

mysql中DATE_FORMAT() 函数详解

mysql中DATE_FORMAT() 函数详解 一. 说明 在 MySQL 中,DATE_FORMAT() 函数用于将日期/时间类型的值按照指定的格式进行格式化输出。它的一般语法如下: DATE_FORMAT(date, format)其中,date 参数是要被格式化的日期/时间值,form…

C++03:条件与分支语句

2024年1月14日 内容来自The Cherno:C系列 2024年1月17日 更新内容整理自: 南京大学 陈佳俊 郑涛 《程序设计教程 用C语言编程》 --------------------------------------------------------------------------------------------------------------…

Java和SpringBoot学习路线图

看了一下油管博主Amigoscode的相关视频,提到了Java和SpringBoot的学习路线,相关视频地址为: How To Master Java - Java for Beginners RoadmapSpring Boot Roadmap - How To Master Spring Boot 如下图所示: 当然关于Java和Spr…

【声光语音告警】小机房-动环系统与服务器监控二合一告警方案

目前场景及存在的问题 目前有很多小规模机房,服务器数量不多,机房面积也较小,例如医院、车站、博物馆、学校、工厂等环境。机房虽小,但仍然需要进行服务器性能监控以及机房动力环境监控,例如漏水、温湿度、烟感、电压…

ChatGPT企业版跟个人版有什么区别?

ChatGPT企业版(ChatGPT Enterprise)除了有和个人版GPT一样的功能外,企业版还可提供企业级的安全和隐私、以及数据分析功能。 订阅实用可以看下这篇文章: ChatGPT企业版的区别是,企业版允许客户输入公司的数据&#xf…

6.4.2转换文件

6.4.2转换文件 利用Swf2VideoConverter2可以很方便地将Flash动画(*.swf)转换为其它的视频格式。 1.单击“添加”按钮,在弹出的下拉菜单中选择“添加文件”,在弹出的“Open Swf Files(打开Swf文件)”窗口中选择swf文件(如:那些花…

拉索回归(Lasso Regression)的原理是什么?

拉索回归(Lasso Regression),全称Least Absolute Shrinkage and Selection Operator回归,是一种线性回归的改进方法,主要用于数据分析和特征选择。其核心原理是在传统的线性回归损失函数中加入了一个L1正则化项&#x…

Javaweb超详细实现模拟支付宝扫码支付

1.普通方式创建Javaweb项目 首先创建Java项目 2.创建好的项目添加web框架支持 如图选择确定 在项目结构中配置有关信息 右键创建classes文件夹与lib文件夹 如图 此处找到刚才的项目的classes路径设置 在依赖中设置lib路径的设置 找到刚才的lib路径 选择此选项 结束项目结构中模…

web蓝桥杯真题--10、灯的颜色变化

介绍 我们经常会看到各种颜色的灯光,本题我们将实现一个颜色会变化的灯的效果。 准备 开始答题前,需要先打开本题的项目代码文件夹,目录结构如下: ├── effect.gif ├── images │ ├── greenlight.svg │ ├── l…

初识SpringBoot

SpringBoot以约定大于配置的核心思想,默认帮我们进行了很多设置,简单来说就是SpringBoot其实不是什么新的框架,它默认配置了很多框架的使用方式,就像maven整合了所有的jar包,spring boot整合了所有的框架 。 创建的包一定要在项目主程序入口…

MATLAB - 计算机械臂关节扭矩以平衡末端力和力矩

系列文章目录 前言 产生力矩以平衡作用在平面机器人末端执行器体上的端点力。要使用各种方法计算关节力矩,请使用刚体树机器人模型的几何雅各比(geometricJacobian)和反动力学(inverseDynamics)对象函数。 一、初始化…

JavaScript的代码执行顺序

(1). js的执行顺序,先同步后异步 (2). 异步中任务队列的执行顺序: 先微任务microtask队列,再宏任务macrotask队列 注意,按顺序从上到下时,没有轮到执行的任务会进入相应…

PowerScale重磅升级,加速迈进AI时代

2024开年 给大伙报告一则好消息 Dell非结构化数据存储的扛把子 PowerScale迎来重大升级 第二代PowerScale全闪存系统 即将闪亮登场 此次升级主要涉及硬件、软件及与NVIDIA的合作关系三个方面,升级后的PowerScale有望成为第一个通过 NVIDIA DGX SuperPOD验证的以…

Linux———sort命令总结详解(狠狠爱住)

目录 sort命令: 命令参数及描述: 示例: 使用-b参数,忽略行首空白字符,按照第一列进行排序: -d 选项是 sort 命令中一个非常有用的选项,它可以按照字典顺序进行排序,同时忽略非字…

创业前先把刘强东这两句琢磨明白!不然大概率失败!2024最适合创业的行业!2024年普通人的创业机会在哪里

第一句,真正解决一个问题。 这句话表达了,你的项目一定是要建立在解决具体的问题上,而不是你觉得自己有个好点子,或者好产品就可以了。因为即使你的产品很好,服务很好,如果不能切实的解决某个问题&#xf…

渐开线齿轮计算软件开发Python

从0开始开发计算软件,欢迎大家加入 源代码仓库

【C++】std::string 转换成非const类型 char* 的三种方法记录

std::string 有两个方法:data() 和 c_str(),都是返回该字符串的const char类型,那如何转换成非const的char呢? 下面展示三种方法: 强转:char* char_test (char*)test.c_str();使用string的地址&#xff…

Android 查看 md5

网上看了一大批文章老实说 百分之80的都是垃圾 , 都没有说明白怎么看 keytool -list -v -keyst xxx.jks 在自己的项目中 , terminal 输入上面命令 跟本就没有用看不到 md5 很多的文章让你找 signingReport , 但是你查看 自己的目录可能压根就没有这个 自己直接用手敲就可以…