【tensorflow框架神经网络实现鸢尾花分类】

news2024/12/26 12:04:53

文章目录

  • 1、数据获取
  • 2、数据集构建
  • 3、模型的训练验证
  • 可视化训练过程

1、数据获取

  • 从sklearn中获取鸢尾花数据,并合并处理
from sklearn.datasets import load_iris
import pandas as pd

x_data = load_iris().data
y_data = load_iris().target

x_data = pd.DataFrame(x_data, columns=['花萼长度','花萼宽度','花瓣长度','花瓣宽度'])
pd.set_option('display.unicode.east_asian_width', True)

x_data['类别'] = y_data
x_data

在这里插入图片描述

2、数据集构建

  • 数据集构建包括:
    • 数据读取
    • 数据打乱
    • 数据划分
    • 小批量迭代器生成
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris

# 1、从sklearn包中datasets读取数据集
x_data = load_iris().data
y_data = load_iris().target

# 2、数据打乱
np.random.seed(1)   # 使用相同的seed,使输入特征/标签一一对应
np.random.shuffle(x_data)
np.random.seed(1)
np.random.shuffle(y_data)
tf.random.set_seed(1)

# 3、训练集、测试集划分
x_train, x_test = x_data[:-30], x_data[-30:]
y_train, y_test = y_data[:-30], y_data[-30:]

# 4、小批量数据
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
train_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

3、模型的训练验证

# 定义超参数,预设变量
lr = 0.1
loss_all = 0
Epoch = 500
train_loss_list = []
test_acc = []

# 定义神经网络的可训练参数
w = tf.Variable(tf.random.truncated_normal([4,3], stddev=0.1, seed=1))
b = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))

# 循环迭代,训练参数
for epoch in range(Epoch):
    for step, (x_, y_) in enumerate(train_db):
        with tf.GradientTape() as tape:
            x_ = tf.cast(x_, tf.float32)
            y_pre = tf.matmul(x_, w) + b
            y_pre = tf.nn.softmax(y_pre)

            y_lab = tf.one_hot(y_, depth=3)
            loss = tf.reduce_mean(tf.square(y_lab - y_pre))
            loss_all += loss.numpy()

        grads = tape.gradient(loss, [w,b])
        w.assign_sub(lr * grads[0])
        b.assign_sub(lr * grads[1])
    print(f'Epoch: {epoch}, loss: {loss_all/4}')
    train_loss_list.append(loss_all/4)
    loss_all = 0

    # 测试部分
    total_correct, total_number = 0, 0
    for x_,y_ in test_db:
        x_ = tf.cast(x_, tf.float32)
        y_pre = tf.matmul(x_, w) + b
        y_pre = tf.nn.softmax(y_pre)
        y_p = tf.argmax(y_pre, axis=1)
        y_p = tf.cast(y_p, dtype=y_.dtype)

        correct = tf.cast(tf.equal(y_p, y_), dtype=tf.int32)
        correct = tf.reduce_sum(correct)
        total_correct += int(correct) 
        total_number += x_.shape[0]
    acc = total_correct / total_number
    test_acc.append(acc)
    print("Test_acc:", acc)
    print("-"*30)

在这里插入图片描述

可视化训练过程

# 绘制测试Acc曲线和训练loss曲线
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(train_loss_list,'b-')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')

ax1 = ax.twinx()
ax1.plot(test_acc,'r-')
ax1.set_ylabel('Acc')

ax1.spines['left'].set_color('blue')
ax1.spines['right'].set_color('red')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1553366.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flask学习(六):蓝图(Blueprint)

蓝图(Blueprint):将各个业务进行区分,然后每一个业务单元可以独立维护,Blueprint可以单独具有自己的模板、静态文件或者其它的通用操作方法,它并不是必须要实现应用的视图和函数的。 Demo目录结构&#xf…

八大技术趋势案例(人工智能物联网)

科技巨变,未来已来,八大技术趋势引领数字化时代。信息技术的迅猛发展,深刻改变了我们的生活、工作和生产方式。人工智能、物联网、云计算、大数据、虚拟现实、增强现实、区块链、量子计算等新兴技术在各行各业得到广泛应用,为各个领域带来了新的活力和变革。 为了更好地了解…

利用Java代码混淆技术提升应用程序抗逆向工程能力

摘要 本文探讨了代码混淆在保护Java代码安全性和知识产权方面的重要意义。通过混淆技术,可以有效防止代码被反编译、逆向工程或恶意篡改,提高代码的安全性。常见的Java代码混淆工具如IPAGuard、Allatori、DashO、Zelix KlassMaster和yGuard等&#xff0…

Python人工智能:气象数据可视化的新工具

Python是功能强大、免费、开源,实现面向对象的编程语言,在数据处理、科学计算、数学建模、数据挖掘和数据可视化方面具备优异的性能,这些优势使得Python在气象、海洋、地理、气候、水文和生态等地学领域的科研和工程项目中得到广泛应用。可以…

物联网实战--入门篇之(一)物联网概述

目录 一、前言 二、知识梳理 三、项目体验 四、项目分解 一、前言 近几年很多学校开设了物联网专业,但是确却地讲,物联网属于一个领域,包含了很多的专业或者说技能树,例如计算机、电子设计、传感器、单片机、网…

葵花卫星影像应用场景及数据获取

一、卫星参数 葵花卫星是由中国航天科技集团公司研制的一颗光学遥感卫星,代号CAS-03。该卫星于2016年11月9日成功发射,位于地球同步轨道,轨道高度约为35786公里,倾角为0。卫星设计寿命为5年,搭载了高分辨率光学相机和多…

Oracle存数字精度问题number、binary_double、binary_float类型

--表1 score是number(10,5)类型 create table TEST1 (score number(10,5) ); --表2 score是binary_double类型 create table TEST2 (score binary_double ); --表3 score是binary_float类型 create table TEST3 (score binary_float );实验一:分别往三张表插入 小数…

抖音视频关键词无水印下载软件|手机网页视频批量提取工具

全新视频关键词无水印下载软件,助您快速获取所需视频! 随着时代的发展,视频内容已成为人们获取信息和娱乐的重要途径。为了方便用户获取所需视频,推出了一款功能强大的视频关键词无水印下载软件。该软件主要功能包括关键词批量提取…

【话题】AI大模型学习:理论、技术与应用探索

大家好,我是全栈小5,欢迎阅读小5的系列文章,这是《话题》系列文章 目录 背景1. AI大模型学习的基础理论1.1 机器学习1.2 深度学习 2. AI大模型学习的技术要点2.1 模型结构设计2.2 算法优化2.3 大规模数据处理 3. AI大模型学习的应用场景3.1 自…

网络爬虫框架Scrapy的入门使用

Scrapy的入门使用 Scrapy概述引擎(Engine)调度器(Scheduler)下载器(Downloader)SpiderItem Pipeline 基本使用安装scrapy创建项目定义Item数据模型对象创建爬虫(Spider)管道pipeline来保存数据启动爬虫 其他…

Netty核心原理剖析与RPC实践6-10

Netty核心原理剖析与RPC实践6-10 06-粘包拆包问题:如何获取一个完整的网络包 本节课开始我们将学习 Netty 通信过程中的编解码技术。编解码技术这是实现网络通信的基础,让我们可以定义任何满足业务需求的应用层协议。在网络编程中,我们经常…

高风险IP来自哪里:探讨IP地址来源及其风险性质

在网络安全领域,高风险IP地址是指那些可能涉及恶意活动或网络攻击的IP地址。了解这些高风险IP地址的来源可以帮助网络管理员更好地识别和应对潜在的安全威胁。本文将探讨高风险IP地址的来源及其风险性质,并提供一些有效的应对措施。 风险IP查询&#xf…

Sourcetree如何解决冲突和重置

解决冲突:找到冲突的文件然后点恢复(其实是丢弃的意思) 重置回某个分支节点:

HTTP——Cookie

HTTP——Cookie 什么是Cookie通过Cookie访问网站 我们之前了解了HTTP协议,如果还有小伙伴还不清楚HTTP协议,可以点击这里: https://blog.csdn.net/qq_67693066/article/details/136895597 我们今天来稍微了解一下HTTP里面一个很小的部分&…

Redis中的LRU算法分析

LRU算法 概述 Redis作为缓存使用时,一些场景下要考虑内容的空间消耗问题。Redis会删除过期键以释放空间,过期键的删除策略 有两种: 1.惰性删除:每次从键空间中获取键时,都检查取得的键是否过期,如果过期的话,就删除…

Adobe最近推出了Firefly AI的结构参考以及面向品牌的GenStudio

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Rust编程(四)PackageCrateModule

这一部分的中文教程/文档都很混乱,翻译也五花八门,所以我建议直接看英文官方文档,对于一些名词不要进行翻译,翻译只会让事情更混乱,本篇从实战和实际需求出发,讲解几个名称的关系。 Module & Crate & Package & Workspace 英文中的意思: Cargo:货物 Crate:…

Apache HBase(二)

目录 一、Apache HBase 1、HBase Shell操作 1.1、DDL创建修改表格 1、创建命名空间和表格 2、查看表格 3、修改表 4、删除表 1.2、DML写入读取数据 1、写入数据 2、读取数据 3、删除数据 2、大数据软件启动 一、Apache HBase 1、HBase Shell操作 先启动HBase。再…

【已修复】iPhone13 Pro 长焦相机水印(黑斑)修复 洗水印

iPhone13 Pro 长焦相机水印(黑斑)修复 洗水印 问题描述 iPhone13 Pro 后摄3倍相机有黑色斑点(水印),如图所示, 后摄相机布局如图所示, 修复过程 拆机过程有风险,没有把握最好不要…

【算法刷题 | 二叉树 05】3.28(左叶子之和、找树 左下角的值)

文章目录 11.左叶子之和11.1问题11.2解法一:递归11.2.1递归思路11.2.2代码实现 11.3解法二:栈11.3.1栈思想11.3.2代码实现 12.找树左下角的值12.1问题12.2解法一:层序遍历 11.左叶子之和 11.1问题 给定二叉树的根节点 root ,返回…