TensorFlow手动搭建神经网络实现鸢尾花分类

news2025/2/6 23:03:19

步骤

  • 准备数据

  • 搭建网络

    • 定义神经网络中所有可训练参数
  • 参数优化

    • 嵌套循环迭代,with结构更新参数,显示当前loss
  • 测试效果

    • 计算当前参数前向传播后的准确率,显示当前acc
  • acc/loss可视化

这里使用一个最简单的网络实现鸢尾花分类

在这里插入图片描述

完整代码

import tensorflow as tf
from sklearn.datasets import load_iris
import numpy as np
from matplotlib import pyplot as plt

# import data
src = load_iris()
x_data = src.data
y_data = src.target

# shuffle data
np.random.seed(0)
np.random.shuffle(x_data)
np.random.seed(0)
np.random.shuffle(y_data)
tf.random.set_seed(0)

# split train and test data
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]

# 数据类型转换
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)

# 构建数据集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

# 定义神经网络的可训练参数

w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))

# 定义超参数

lr = 0.1
train_loss_result = []
test_acc = []
epoch = 500
loss_all = 0

# 循环嵌套迭代

for epoch in range(epoch):
    for step, (x_train, y_train) in enumerate(train_db):  # batch级别迭代
        with tf.GradientTape() as tape:
            y = tf.matmul(x_train, w1) + b1
            y = tf.nn.softmax(y)
            y_ = tf.one_hot(y_train, depth=3)
            loss = tf.reduce_mean(tf.square(y_ - y))
            loss_all += loss.numpy()
        #     计算loss对各个参数的梯度
        grads = tape.gradient(loss, [w1, b1])

        # 实现梯度更新
        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1])
    # 每个epoch,打印loss信息
    print("Epoch {},loss:{}".format(epoch, loss_all / 4))
    train_loss_result.append(loss_all / 4)
    loss_all = 0

    # 测试部分
    total_correct, total_number = 0, 0
    for x_test, y_test in test_db:
        y = tf.matmul(x_test, w1) + b1
        y=tf.nn.softmax(y)
        pred=tf.argmax(y,axis=1)
        pred=tf.cast(pred,dtype=y_test.dtype)

        correct=tf.cast(tf.equal(pred,y_test),dtype=tf.int32)
        correct=tf.reduce_sum(correct)
        total_correct+=int(correct)
        total_number+=x_test.shape[0]
    acc=total_correct/total_number
    test_acc.append(acc)
    print("Test_acc:",acc)
    print("-------------")

plt.title('Loss Function Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(train_loss_result,label='$Loss$')
plt.legend()
plt.show()

plt.title('Acc Curve')
plt.xlabel('Epoch')
plt.ylabel('Acc')
plt.plot(test_acc,label='$Accuracy$')
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/102807.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一文搞定 Postman 接口自动化测试

本文适合已经掌握 Postman 基本用法的读者,即对接口相关概念有一定了解、已经会使用 Postman 进行模拟请求等基本操作。 工作环境与版本: Window 7(64位) Postman (Chrome App v5.5.3) P.S. 不同版本页面…

Vue中在组件中单独使用this

目录 🔽 全局注册 🔽 局部注册 🔽 组件使用总结 🔽 全局注册 1、Vue.prototype 在多个地方都需要使用但不想污染全局作用域的情况下,这样定义,在每个 Vue 实例中都可用。$ 表示这是一个在 Vue 所有实…

Allegro如何查看PCB进度百分比操作指导

Allegro如何查看PCB进度百分比操作指导 Allegro支持实时查看PCB进度百分比,让设计者实时了解设计进度,具体操作如下 选择Display-StatusUnrouted connections这里就是就剩下未完成的百分比,如果是0,代表已经完成除了可以在这里快捷的查看,也可以通过报表实现,选择Tools-r…

InstructPix2Pix: 随口修图

InstructPix2Pix Learning to Follow Image Editing Instructions是一篇非常有意思的文章,有意思说的是效果,要做出论文的效果过程并没那么顺利。首先需要微调GPT3模型,这个花钱花力气,在之前的文章里已经提过,可以参考…

RedisSon分布式锁 源码解析,在 java 中使用 redis + lua 做秒杀

1. RedisSon 分布式锁 <dependency><groupId>org.redisson</groupId><artifactId>redisson-spring-boot-starter</artifactId><version>3.17.0</version> </dependency>spring:profiles:active: devredis:cluster:nodes: 192…

Mac OSX 安装 MongoDB

1&#xff0c;简介 MongoDB是由C语言编写&#xff0c;开源而且基于分布式文件存储的介于关系数据库和非关系数据库之间的产品&#xff1b;在高负载的情况下&#xff0c;通过添加更多节点保证服务器性能&#xff1b;旨在为WEB应用提供可扩展的高性能数据存储解决方案&#xff1…

Copy-Paste

在2D视觉目标检测领域&#xff0c;由相似目标之间的重叠引起的拥挤是普遍存在的挑战。 文章地址&#xff1a;https://arxiv.org/pdf/2211.12110.pdf 研究者首先强调了拥挤问题的两个主要影响&#xff1a;1&#xff09;IoU置信度相关干扰&#xff08;ICD&#xff09;和2&#…

桥接模式

文章目录桥接模式1.桥接模式的本质2.何时选用桥接模式3.优缺点4.桥接模式的结构5.实现模拟消息发送MVC在桥接模式的体现桥接模式 桥接模式实质就是分离抽象和实现&#xff0c;抽象部分有多种&#xff0c;实现部分有多种&#xff0c;耦合在一起很难扩展&#xff0c;将其分离开来…

excel如何排序?两个方法汇总

排序是Excel中最常用的功能之一&#xff0c;也是数据分类和汇总操作的重要前提。excel如何排序&#xff1f;本文介绍如何给Excel里面的数据进行排序&#xff0c;方法很简单。感兴趣的朋友&#xff0c;赶紧来看看吧&#xff01; 操作环境&#xff1a; 演示机型&#xff1a;Dell …

PostgreSQL 导入 SLS,从业务到监控数据

日志服务SLS数据导入简介 日志服务SLS是云原生观测和分析平台&#xff0c;为Log、Metric、Trace等数据提供大规模、低成本、实时的平台化服务。日志服务是提供一站式数据采集、加工、查询与分析、可视化、告警、消费与投递等功能。全面提升在研发、运维、运营、安全等场景的数…

web常见的攻击方式有哪些,以及如何进行防御?

一、是什么 Web攻击&#xff08;WebAttack&#xff09;是针对用户上网行为或网站服务器等设备进行攻击的行为 如植入恶意代码&#xff0c;修改网站权限&#xff0c;获取网站用户隐私信息等等 Web应用程序的安全性是任何基于Web业务的重要组成部分 确保Web应用程序安全十分重…

python中的模块与包详解

目录 一.什么是模块 二.模块的导入 1.import 模块名 2.from 模块名 import 功能名 3.from 模块名 import * 4.as定义别名 模块导入总结 三.自定义模块 制作自定义模块 用pycharm演示 测试模块_ _main_ _变量的作用 演示 ‘_ _all_ _’变量 自定义模块小结 四.python中的包…

Flink集成Seatunnel

安装包下载 相关包的下载地址 Apache SeaTunnel | Apache SeaTunnel Apache Flink: Downloads 解压&#xff08;注意下载scala_2.11&#xff09; tar -zxvf flink-1.13.6-bin-scala_2.11.tgz -C ../module/ Yarn模式部署 环境准备 sudo vi /etc/profile.d/my_env.sh 修…

中国清洁清洗行业等级资质

中国商业企业管理协会清洁服务商专业委员会——“中清委”&#xff08;以下简称评定单位&#xff09;承担组织等级清洁清洗服务机构评定工作。 申请资料 (1)专业清洁清洗服务机构等级评定申请表&#xff08;附录B&#xff09;&#xff1b; (2)法人代表资格证明&#xff1…

小林Coding阅读笔记:操作系统篇之硬件结构,伪共享问题及CPU的任务执行

前言 参考/导流&#xff1a; 小林coding - 2.5 CPU 是如何执行任务的&#xff1f;学习意义 底层基础知识&#xff0c;了解CPU执行过程&#xff0c;让上层编码有效并发问题处理、思考理解调度策略、思想借鉴分析 相关说明 该篇博文是个人阅读的重要梳理&#xff0c;仅做简单参…

Transformer实现以及Pytorch源码解读(一)-数据输入篇

目标 以词性标注任务为例子&#xff0c;实现Transformer&#xff0c;并分析实现Pytorch的源码解读。 数据准备 所选的数据为nltk数据工具中的treebank数据集。treebank数据集的样子如以下两幅图所示&#xff1a; 该数据集中解释变量为若干句完整的句子&#xff1a; 被解释变…

Docker-DockerFile制定镜像

什么是DockerFile&#xff1f; DockerFile是一个用来编写Docker镜像的文本文件&#xff0c;文本内容包含了一条条构建镜像所需要的指令和说明。DockerFile就想要一个脚本文件一样。把我们想要执行的操作放到文本文件里&#xff0c;一键执行。这样我们就可以复用这个DockerFile…

读论文:Learning to Compare: Relation Network for Few-Shot Learning

Abstract 我们提出了一个概念上简单、灵活且通用的少镜头学习框架&#xff0c;其中分类器必须学习识别每个只给出少量示例的新类。我们的方法称为关系网络(RN)&#xff0c;从头到尾进行训练。在元学习过程中&#xff0c;它学习学习一个深度距离度量来比较插曲中的少量图像&…

RNA-seq 详细教程:时间点分析(14)

学习内容 了解如何使用 DESeq2 进行时间的分析LRT 使用 LRT 进行 Time course 分析尽管基因表达的静态测量很受欢迎&#xff0c;但生物过程的时程捕获对于反映其动态性质至关重要&#xff0c;特别是当模式复杂且不仅仅是上升或下降时。在处理此类数据时&#xff0c;似然比检验 …

doris入门后遇到的几个问题总结

文章目录1. Access denied for user anonymnull (using password: NO)2. timeout when waiting for send fragments RPC. Wait(sec): 5, host: xxx(ip)3. Failed to initialize JNI: Failed to find the library libjvm.so.4. 从mysql库导出的json文件大于100M时报错5. csv格式…