深度学习(三)

news2025/1/11 8:42:15
5.Functional API 搭建神经网络模型
5.1利用Functional API编写宽深神经网络模型进行手写数字识别
import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Dense, concatenate

from tensorflow.keras.models import Model



iris = load_iris()



x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23)

X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12)



print(X_valid.shape)

print(X_train.shape)



inputs = Input(shape=X_train.shape[1:])

hidden1 = Dense(300, activation="relu")(inputs)

hidden2 = Dense(100, activation="relu")(hidden1)

concat = concatenate([inputs, hidden2])

output = Dense(10, activation="softmax")(concat)

model_wide_deep = Model(inputs=inputs, outputs=output)

iris = load_iris():加载iris数据集,这是一个常用的多类分类数据集,包含了150个样本,每个样本有4个特征,属于3个不同的类别。

x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23):将iris数据集分割为训练集和测试集。test_size=0.2表示测试集的大小为原始数据的20%,random_state=23是一个随机种子,确保分割的可重复性。

X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12):进一步将训练集分割为训练集和验证集。同样,test_size=0.2表示验证集的大小为分割后训练数据的20%,random_state=12确保分割的可重复性。

print(X_valid.shape):打印验证集的特征数据的形状。

print(X_train.shape):打印新的训练集的特征数据的形状。

inputs = Input(shape=X_train.shape[1:]):定义模型的输入层,shape=X_train.shape[1:]指定输入的形状,由于X_train是一个二维数组,shape[1:]表示除了第一维(样本数量)之外的所有维度。

hidden1 = Dense(300, activation="relu")(inputs):定义第一个隐藏层,它有300个神经元,并使用ReLU激活函数。

hidden2 = Dense(100, activation="relu")(hidden1):定义第二个隐藏层,它有100个神经元,并使用ReLU激活函数。

concat = concatenate([inputs, hidden2]):将输入层和第二个隐藏层的输出拼接起来,形成更宽的网络。

output = Dense(10, activation="softmax")(concat):定义输出层,它有10个神经元(对应于3个类别和一个额外的神经元,这是常见的做法),并使用softmax激活函数输出概率分布。

model_wide_deep = Model(inputs=inputs, outputs=output):创建一个Keras模型,将输入层和输出层连接起来。

使用scikit-learn库中的load_iris函数来加载iris数据集,然后使用train_test_split函数将数据集分割为训练集和测试集,以及进一步的训练集和验证集。接着,它定义了一个宽深网络(wide and deep network)模型,其中包含了输入层、两个隐藏层和一个输出层。

model_wide_deep.summary()

model_wide_deep.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])

h = model_wide_deep.fit(X_train, y_train, batch_size=32, epochs=30,validation_data=(X_valid, y_valid))

# 绘图

pd.DataFrame(h.history).plot(figsize=(8,5))

plt.grid(True)

plt.gca().set_ylim(0, 1)

plt.show()

# 使用 model_wide_deep 评估测试集

test_loss, test_accuracy = model_wide_deep.evaluate(x_test, y_test, batch_size=32)



print(f"Test Loss: {test_loss}")

print(f"Test Accuracy: {test_accuracy}")

6.SubClassing API 搭建神经网络模型

以前馈全连接神经网络手写数字识别为例

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Dense, concatenate

from tensorflow.keras.models import Model

from tensorflow.keras import backend as K



# 加载数据集

iris = load_iris()

X = iris.data

y = iris.target



# 分割数据集

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=23)

X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=12)



# 打印验证集和训练集的形状

print(X_valid.shape)

print(X_train.shape)



# 定义 Model_sub_fn 类

class Model_sub_fn(Model):

    def __init__(self, units_1, units_2, units_out, activation="relu"):

        super(Model_sub_fn, self).__init__()

        self.hidden1 = Dense(units_1, activation=activation)

        self.hidden2 = Dense(units_2, activation=activation)

        self.main_output = Dense(units_out, activation="softmax")



    def call(self, inputs):

        x = self.hidden1(inputs)

        x = self.hidden2(x)

        return self.main_output(x)

定义了一个名为Model_sub_fn的类,该类继承自tensorflow.keras.Model。这个类用于创建一个简单的神经网络模型,它包含两个隐藏层和一个输出层。

class Model_sub_fn(Model)定义一个名为Model_sub_fn的类,它继承自tensorflow.keras.Model。这意味着Model_sub_fn类可以访问和继承Model类的所有属性和方法。

def __init__(self, units_1, units_2, units_out, activation="relu"):定义类的构造函数__init__,它接受四个参数:units_1(第一个隐藏层的神经元数量)、units_2(第二个隐藏层的神经元数量)、units_out(输出层的神经元数量)和activation(激活函数类型,默认为ReLU)。

super(Model_sub_fn, self).__init__():调用父类的构造函数,这是继承自Model类的标准做法。

self.hidden1 = Dense(units_1, activation=activation):定义第一个隐藏层,它有units_1个神经元,并使用activation作为激活函数。

self.hidden2 = Dense(units_2, activation=activation):定义第二个隐藏层,它有units_2个神经元,并使用activation作为激活函数。

self.main_output = Dense(units_out, activation="softmax"):定义输出层,它有units_out个神经元,并使用softmax作为激活函数。

def call(self, inputs):定义call方法,这是所有Keras模型必须定义的方法,它用于前向传播。在这个方法中,输入数据通过两个隐藏层,最后通过输出层。

x = self.hidden1(inputs):将输入数据通过第一个隐藏层。

x = self.hidden2(x):将第一个隐藏层的输出通过第二个隐藏层。

return self.main_output(x):将第二个隐藏层的输出通过输出层,并返回结果。

model_sub_fn = Model_sub_fn(units_1=64, units_2=32, units_out=3)



# 创建 Model_sub_fn 实例

model_sub_fn = Model_sub_fn(300, 100, 3, activation="relu")  # 假设输出层有3个单元,因为Iris数据集有3个类别



# 编译模型

model_sub_fn.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])



# 训练模型

history = model_sub_fn.fit(X_train, y_train, batch_size=32, epochs=30, validation_data=(X_valid, y_valid))

model_sub_fn.summary()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1797687.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蒋耀锴:剑桥毕业,硅谷敲代码8年,回国创业做低代码

这是《开发者说》的第10期,本期我们邀请的开发者是蒋耀锴,曾在意大利、英国、美国读书,专长高性能计算、高容错分布式系统与软件工程,毕业后在硅谷的Medallia写了8年代码,19年回国创业做低代码,喜欢半夜里一…

常见Rabbitmq面试题及答案总结

1、 什么是 rabbitmq 釆用AMQP高级消息队列协议的一种消息队列技术撮大的特点就是消费并不需要 确保提供方存在,实现了服务之间的高度解耦 2、 为什么要使rabbitmq (1) 在分布式系统下具备异步,削峰,负载均衡等一系列高级功能&…

汽车MCU虚拟化--对中断虚拟化的思考(2)

目录 1.引入 2.TC4xx如何实现中断虚拟化 3.小结 1.引入 其实不管内核怎么变,针对中断虚拟化无非就是上面两种,要么透传给VM,要么由Hypervisor统一分发。汽车MCU虚拟化--对中断虚拟化的思考(1)-CSDN博客 那么,作为车规MCU龙头…

手把手制作Vue3+Flask全栈项目 全栈开发之路实战篇 问卷网站(五)数据处理

全栈开发一条龙——前端篇 第一篇:框架确定、ide设置与项目创建 第二篇:介绍项目文件意义、组件结构与导入以及setup的引入。 第三篇:setup语法,设置响应式数据。 第四篇:数据绑定、计算属性和watch监视 第五篇 : 组件…

Java学习中,如何理解注解的概念及常用注解的使用方法

一、简介 Java注解(Annotation)是一种元数据,提供了一种将数据与程序元素(类、方法、字段等)关联的方法。注解本身不改变程序的执行逻辑,但可以通过工具或框架进行处理,从而影响编译、运行时的…

Suryxin’s ACM退役记

序 我的记忆力很差,经历过的很多事情都已经记不太清了,其中有很多美好回忆也已经消散,我很惋惜没能留存一些照片和声音或是文字供我怀念,这就像《泰坦尼克号》一样,露丝和杰克感人肺腑的爱情故事,最后也仅…

东航携手抖音生活服务开启机票首播,推出国内、国际超值机票次卡

在民航暑运旺季到来之际,越来越多的用户选择提前做好旅行规划,囤下高性价比的出游商品。6月6日18点,中国东方航空(以下简称“东航”)将在抖音开启首次机票直播,推荐多款超值机票次卡及空中Wi-Fi等特色产品&…

SpringBoot发邮件服务如何配置?怎么使用?

SpringBoot发邮件需要的参数?邮件发送性能如何优化? 在SpringBoot项目中配置发邮件服务是一个常见的需求,它允许我们通过应用程序发送通知、验证邮件或其他类型的邮件。AokSend将详细介绍如何在SpringBoot中配置发邮件服务。 SpringBoot发邮…

nginx和proxy_protocol协议

目录 1. 引言2. HTTP server的配置3. Stream server的配置3.1 作为proxy_protocol的前端服务器3.2 作为proxy_protocol的后端服务器1. 引言 proxy_protocol 是haproxy开发的一种用于在代理服务器和后端服务器之间传递客户端连接信息的协议。使用 proxy_protocol 的主要优势是能…

QT获取最小化,最大化,关闭窗口事件

QT获取最小化,最大化,关闭窗口事件 主程序头文件: 实现: changeEvent,状态改变事件 closeEvent触发点击窗口关闭按钮事件 其代码它参考: /*重写该函数*/ void MainWindow::changeEvent(QEvent *event) {…

蚓链数字化营销生态的影响力分享!

​家人们,今天来给大家分享一些关于数字化平台生态化对数字营销影响的具体案例。 比如某电商平台,通过生态化的建设,实现了精准的推荐算法。根据用户的浏览历史和购买行为,为他们推荐最符合需求的商品,大大提高了购买…

JeeSite 快速开发平台 Vue3 前端版介绍

JeeSite 快速开发平台 Vue3 前端版介绍: 它构建于 Vue3、Vite、Ant-Design-Vue、TypeScript 以及 Vue Vben Admin 等最前沿的技术栈之上,能助力初学者迅速上手并顺利融入团队开发进程。涵盖的模块包括组织机构、角色用户、菜单授权、数据权限、系统参数…

小程序开发平台——超级万能DIY商城小程序源码系统 前后端分离 带完整的安装代码包以及搭建教程

系统概述 超级万能 DIY 商城小程序源码系统是一款集前端和后端分离的强大工具,为开发者提供了一站式的解决方案。它不仅具备完整的安装代码包,还附带详细的搭建教程,让即使是没有丰富技术经验的开发者也能轻松上手,快速构建自己的…

Vue3实战笔记(56)—实战:DefineModel的使用方法细节

文章目录 前言一、实战DefineModel二、思考原理总结 前言 今天写个小例子&#xff0c;实战DefineModel的使用方法细节 一、实战DefineModel 上文官方说的挺清楚&#xff0c;实战验证一下&#xff0c;新建DefineModel.vue&#xff08;这是儿子&#xff09;&#xff1a; <te…

珠海鸿瑞毛利率持续下滑:核心产品销量大降,偿债能力偏弱

《港湾商业观察》黄懿 日前&#xff0c;珠海市鸿瑞信息技术股份有限公司&#xff08;下称“珠海鸿瑞”&#xff09;收到了北京证券交易所发出的第三轮审核问询函。 此前&#xff0c;2020年11月&#xff0c;珠海鸿瑞曾向深交所报送上市申请。IPO申请文件获受理后&#xff0c;珠…

MySQL8 全文索引

文章目录 创建索引使用索引总结 创建索引 之前未尝试过使用MySQL8的全文索引&#xff0c;今天试一试看看什么效果&#xff0c;否则跟不上时代了都。   创建索引非常简单&#xff0c;写句SQL就行。 create table goods(id integer primary key auto_increment,name varchar(2…

知识图谱的应用---智能制造

文章目录 智能制造典型应用 智能制造 随着云计算、大数据、人工智能技术的快速发展&#xff0c;越来越多的新技术正在应用于传统工业领域&#xff0c;并在帮助企业实现产业转型、技术升级及效益提升方面起到了关键作用。目前在提升良品率方面&#xff0c;知识图谱通过深度计算所…

Selenium时间等待_显示等待

特点&#xff1a; 针对具体元素进行时间等待 可以自定义等待时长和间隔时间 按照设定的时间&#xff0c;不断定位元素&#xff0c;定位到了直接执行下一步操作 如在设定时间内没定位到元素&#xff0c;则报错&#xff08;TimeOutException&#xff09; 显示等待概念&#x…

【Python报错】已解决NameError: name ‘secrets‘ is not defined

解决Python报错&#xff1a;NameError: name ‘secrets’ is not defined 在使用Python进行安全编程时&#xff0c;我们经常需要使用secrets模块来生成安全的随机数。然而&#xff0c;如果你在尝试使用这个模块时遇到了NameError: name secrets is not defined的错误&#xff0…

【机器学习】机器学习与智能交通在智慧城市中的融合应用与性能优化新探索

文章目录 引言机器学习与智能交通的基本概念机器学习概述监督学习无监督学习强化学习 智能交通概述交通流量预测交通拥堵管理智能信号控制智能停车管理 机器学习与智能交通的融合应用实时交通数据分析数据预处理特征工程 交通流量预测与优化模型训练模型评估 智能信号控制与优化…