Tensorflow2.0搭建网络八股扩展

news2024/9/24 3:27:22

目录

一、自制数据集

准备:txt和图片

制作函数

二、断点继训,存取模型

1.读取保存的模型

2.保存模型

3.正确使用

三、参数提取,把参数存入txt

参数提取

四、acc/loss可视化,查看效果

1.前提开启:获取history

2.history参数表

3.代码

五、应用程序,给图识物

代码实现

总结


一、自制数据集

准备:txt和图片

需要有图片的名称信息,并且有txt文件进行汇总,下面是txt信息

制作函数

//路径
train_path = './mnist_image_label/mnist_train_jpg_60000/'
train_txt = './mnist_image_label/mnist_train_jpg_60000.txt'
x_train_savepath = './mnist_image_label/mnist_x_train.npy'
y_train_savepath = './mnist_image_label/mnist_y_train.npy'

//函数
def generateds(path, txt):
    f = open(txt, 'r')  # 以只读形式打开txt文件
    contents = f.readlines()  # 读取文件中所有行
    f.close()  # 关闭txt文件
    x, y_ = [], []  # 建立空列表
    for content in contents:  # 逐行取出
        value = content.split()  # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表
        img_path = path + value[0]  # 拼出图片路径和文件名
        img = Image.open(img_path)  # 读入图片
        img = np.array(img.convert('L'))  # 图片变为8位宽灰度值的np.array格式
        img = img / 255.  # 数据归一化 (实现预处理)
        x.append(img)  # 归一化后的数据,贴到列表x
        y_.append(value[1])  # 标签贴到列表y_
        print('loading : ' + content)  # 打印状态提示

    x = np.array(x)  # 变为np.array格式
    y_ = np.array(y_)  # 变为np.array格式
    y_ = y_.astype(np.int64)  # 变为64位整型
    return x, y_  # 返回输入特征x,返回标签y_
//调用函数,并保存numpy格式的训练数据
    x_train, y_train = generateds(train_path, train_txt)
    x_train_save = np.reshape(x_train, (len(x_train), -1))
    np.save(x_train_savepath, x_train_save)
    np.save(y_train_savepath, y_train)

二、断点继训,存取模型

1.读取保存的模型

load_weights(路径文件名)

2.保存模型

tf.keras.callbacks.ModelCheckpoint(
filepath=路径文件名, 
save_weights_only=True/False, 
save_best_only=True/False)
history = model.fit( callbacks=[cp_callback] 

3.正确使用

# 保存训练的模型路径
checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    # 读取文件模型名字
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
# 回调函数,将训练好的模型返回history
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

三、参数提取,把参数存入txt

提取可训练参数

model.trainable_variables 返回模型中可训练的参数

参数提取

print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

四、acc/loss可视化,查看效果

1.前提开启:获取history

history=model.fit(训练集数据, 训练集标签, batch_size= , epochs=, validation_split=用作测试数据的比例,validation_data=测试集, validation_freq=测试频率)

2.history参数表

训练集loss: loss

测试集loss: val_loss

训练集准确率: sparse_categorical_accuracy

测试集准确率: val_sparse_categorical_accuracy

3.代码

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

五、应用程序,给图识物

前向传播执行应用

predict(输入特征, batch_size=整数) 返回前向传播计算结果

代码实现

from PIL import Image
import numpy as np
import tensorflow as tf
# 训练好的模型
model_save_path = './checkpoint/mnist.ckpt'

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')])
# 读取模型
model.load_weights(model_save_path)
# 测试数据
preNum = int(input("input the number of test pictures:"))
# 将输入的数据修改成我们想要的数据
for i in range(preNum):
    image_path = input("the path of test picture:")
    img = Image.open(image_path)
    img = img.resize((28, 28), Image.ANTIALIAS)
    img_arr = np.array(img.convert('L'))

    for i in range(28):
        for j in range(28):
            if img_arr[i][j] < 200:
                img_arr[i][j] = 255
            else:
                img_arr[i][j] = 0

    img_arr = img_arr / 255.0
    x_predict = img_arr[tf.newaxis, ...]
    result = model.predict(x_predict)

    pred = tf.argmax(result, axis=1)

    print('\n')
    tf.print(pred)
















总结

  • 本文主要借鉴:mooc曹健老师的《人工智能实践:Tensorflow笔记》
  • 可以利用这几种方式将八股神经网络进行优化应用到我们想要的场景中
  • 可差分的神经网络体系方便我们使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/935790.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ubuntu学习(六)----文件编程实现cp指令

1 思路 Linux要想复制一份文件通常指令为&#xff1a; cp src.c des.c 其中src.c为源文件&#xff0c;des.c为目标文件。 要想通过文件编程实现cp效果&#xff0c;思路如下 1 首先打开源文件 src.c 2 读src到buf 3 创建des.c 4 将buf写入到des.c 5 close两个文件 2 实现 vi …

并发编程基础知识篇--线程的状态和基本操作

目录 创建线程的四种方式 线程的状态和生命周期 扩展知识 线程的调度 线程状态的基本操作 interrupted 实例 join 实例 sleep 实例 扩展小知识 yield 实例 扩展 创建线程的四种方式 创建线程的四种方式 继承Thread类实现Runnable接口使用Callable和Future创…

博客系统——前端部分

目录 一、博客页面介绍 二、实现博客列表页 1、先实现导航栏 2、页面主体 左侧区域的实现&#xff1a;​编辑 右侧页面的实现&#xff1a;​编辑 博客列表页代码汇总&#xff1a; 三、实现博客详情页 代码实现&#xff1a; 四、实现博客登录页​编辑 五、博客编辑页 …

【赋权算法】Python实现熵权法

在开始之前&#xff0c;我们先说一下信息熵的概念。 当一件事情发生&#xff0c;如果是意料之中&#xff0c;那么这个事情就并不能拿来当做茶余饭后的谈资&#xff0c;我们可以说这个事情并没有什么信息和价值。而当一件不可能发生的事情发生的时候&#xff0c;我们可能就会觉…

挖数据四周年庆典,壕礼不断,惊喜不停!

挖数据四周岁啦&#xff01;为了感谢广大用户们一路以来的支持与陪伴&#xff0c;我们特地准备了丰富的优惠活动&#xff0c;希望能够用最实际的行动来回馈您们的厚爱。四年的成长与蜕变&#xff0c;都是因为有您们的陪伴与鼓励&#xff0c;我们期待与您们一同分享这份喜悦与成…

Linux 基金会宣布正式进驻中国

在 LinuxCon 2017 &#xff08;北京&#xff09;即将召开前夕&#xff0c;我们Linux 中国会同 51CTO、开源中国对 Linux 基金会执行董事 Jim Zemlin 进行了一场远跨大洋的视频专访。 在这次专访中&#xff0c;Jim 先生回答了几个开源界和互联网领域关注的问题&#xff0c;并披…

PCI设备和PCI桥的配置空间(header_type0、header_type1)和配置命令(type0、type1)详解

1、PCI典型拓扑 2、type0和type1 名称含义Bus Number设备所在总线号Device Number设备分配到的设备号Function Number功能号&#xff0c;有的设备是支持多个功能的&#xff0c;最多8种功能Register Number要访问的寄存器地址 (1)type0和type1的区别&#xff1a;AD[1:0]是00代表…

几个nlp的小任务(生成式任务——语言模型(CLM与MLM))

@TOC 本章节需要用到的类库 微调任意Transformers模型(CLM因果语言模型、MLM遮蔽语言模型) CLM MLM 准备数据集 展示几个数据的结构

【AI底层逻辑】——篇章7(下):计算资源软件代码共享

续上篇... 目录 续上篇... 三、计算资源 1、第一阶段&#xff1a;数据大集中 2、第二阶段&#xff1a;资源云化 ①“云”的分类 ②虚拟化技术 ③边缘计算的普及 四、软件代码共享 总结 往期精彩&#xff1a; 三、计算资源 AlphaGo算法论文虽然已经发表&#xff0c;但…

华为OD七日集训第2期 - 按算法分类,由易到难,循序渐进,玩转OD(文末送书)

目录 一、适合人群二、本期训练时间三、如何参加四、7日集训第2期五、精心挑选21道高频100分经典题目&#xff0c;作为入门。第1天、逻辑分析第2天、字符串处理第3天、数据结构第4天、递归回溯第5天、二分查找第6天、深度优先搜索dfs算法第7天、动态规划 六、集训总结1、《代码…

rke安装k8s

1、修改集群中各物理机主机名hostname文件 # 查看 cat /etc/hostname # 命令修改 hostnamectl set-hostname k8s-master2、实现主机名与ip地址解析 # 查看cat /etc/hosts # 修改 vi /etc/hosts3、配置ip_forward过滤机制 # 修改 vi /etc/sysctl.conf net.ipv4.ip_forward1…

RT-Thread IO设备模型

IO设备模型 RTT提供了一套简单的I/O设备模型框架&#xff0c;它位于硬件和应用程序之间&#xff0c;共分成三层&#xff0c;从上到下分别是I/O设备管理层、设备驱动框架层、设备驱动层。 应用程序通过I/O设备管理接口获得正确的设备驱动&#xff0c;然后通过这个设备驱动与底层…

递归算法学习——全排列

目录 ​编辑 一&#xff0c;问题描述 1.例子&#xff1a; 题目接口&#xff1a; 二&#xff0c;问题分析和解决 1.问题分析 2.解题代码 一&#xff0c;问题描述 首先我们得来先看看全排列的问题描述。全排列问题的问题描述如下&#xff1a; 给定一个不含重复数字的数组 n…

DTC状态变化例子 4

例子1&#xff1a; 此示例概述了两个操作周期排放相关的 OBD DTC 中 DTC 状态位的操作。该图显示了两个操作周期排放相关的 OBD DTC 的处理。该处理也可应用于非排放相关的 OBD DTC&#xff0c;此处显示仅供一般参考。 0 接收到清除诊断信息 → DTC 状态字节初始化。 1, 2 相关…

基于类电磁机制算法优化的BP神经网络(预测应用) - 附代码

基于类电磁机制算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码 文章目录 基于类电磁机制算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码1.数据介绍2.类电磁机制优化BP神经网络2.1 BP神经网络参数设置2.2 类电磁机制算法应用 4.测试结果&…

RabbitMQ---订阅模型-Topic

订阅模型-Topic • Topic类型的Exchange与Direct相比&#xff0c;都是可以根据RoutingKey把消息路由到不同的队列。只不过Topic类型Exchange可以让队列在绑定Routing key 的时候使用通配符&#xff01; • Routingkey 一般都是有一个或多个单词组成&#xff0c;多个单词之间以…

【clojure】入门篇-01

一、环境的配置 1.java环境配置 clojureScript 需要java环境的配置需要下载jdk进行java环境变量配置 下载官网 java环境变量的配置教程 2.Leningen环境配置 1.下载.bat文件内容 2.配置环境变量 2.8.3及以上内容进行配置 lein教程 2.使用vscode vscode官网 下载插件 C…

SIP 协议路由规则详解

文章目录 SIP 路由关键字段SIP 路由图解 SIP 路由关键字段 SIP 协议实际上和 HTTP 类似&#xff0c;都是基于文本、可阅读的应用层协议&#xff0c;二者的不同之处在于 SIP 协议是有状态的。在 SIP 协议中&#xff0c;影响报文路由的相关字段如下表所示&#xff0c;总结起来如…

给微软.Net runtime运行时提交的几个Issues

前言 因为目前从事的CLRJIT,所以会遇到一些非常底层的问题&#xff0c;比如涉及到微软的公共运行时和即时编译器或者AOT编译器的编译异常等情况,这里分享下自己提的几个Issues。原文:微软.Net runtime运行时提交的几个Issues Issues 一.issues one 第一个System.Numerics.Vecto…

深度强化学习。介绍。深度 Q 网络 (DQN) 算法

马库斯布赫霍尔茨 一. 引言 深度强化学习的起源是纯粹的强化学习&#xff0c;其中问题通常被框定为马尔可夫决策过程&#xff08;MDP&#xff09;。MDP 由一组状态 S 和操作 A 组成。状态之间的转换使用转移概率 P、奖励 R 和贴现因子 gamma 执行。概率转换P&#xff08;系统动…