拓展神经网络八股(入门级)

news2025/1/8 5:42:31

自制数据集

minst等数据集是别人打包好的,如果是本领域的数据集。自制数据集。

替换

把图片路径和标签文件输入到函数里,并返回输入特征和标签

只需要把图片灰度值数据拼接到特征列表,标签添加到标签列表,提取操作函数如下:

def generateds(path, txt):
    f = open(txt, 'r')
    contents = f.readlines() #读取所有行
    f.close()
    x, y_ = [], []
    for content in contents:
        value = content.split()
        img_path = path + value[0]#找到图片索引路径
        img = Image.open(img_path) #图片打开
        img = np.array(img.convert('L')) # 图片变为8位灰度的npy格式的数据集                    
        img = img / 255.
        x.append(img)
        y_.append(value[1])
        print('loading:' + content) # 打印状态提示
    x = np.array(x)
    y_ = np.array(y_)
    y_ = y_astype(np.int64)
    return x, y_

 完整代码

import tensorflow as tf
from PIL import Image
import numpy as np
import os

train_path = './fashion_image_label/fashion_train_jpg_60000/'
train_txt = './fashion_image_label/fashion_train_jpg_60000.txt'
x_train_savepath = './fashion_image_label/fashion_x_train.npy'
y_train_savepath = './fashion_image_label/fahion_y_train.npy'

test_path = './fashion_image_label/fashion_test_jpg_10000/'
test_txt = './fashion_image_label/fashion_test_jpg_10000.txt'
x_test_savepath = './fashion_image_label/fashion_x_test.npy'
y_test_savepath = './fashion_image_label/fashion_y_test.npy'


def generateds(path, txt):
    f = open(txt, 'r')
    contents = f.readlines()  # 按行读取
    f.close()
    x, y_ = [], []
    for content in contents:
        value = content.split()  # 以空格分开,存入数组
        img_path = path + value[0]
        img = Image.open(img_path)
        img = np.array(img.convert('L'))
        img = img / 255.
        x.append(img)
        y_.append(value[1])
        print('loading : ' + content)

    x = np.array(x)
    y_ = np.array(y_)
    y_ = y_.astype(np.int64)
    return x, y_


if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(
        x_test_savepath) and os.path.exists(y_test_savepath):
    print('-------------Load Datasets-----------------')
    x_train_save = np.load(x_train_savepath)
    y_train = np.load(y_train_savepath)
    x_test_save = np.load(x_test_savepath)
    y_test = np.load(y_test_savepath)
    x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))
    x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:
    print('-------------Generate Datasets-----------------')
    x_train, y_train = generateds(train_path, train_txt)
    x_test, y_test = generateds(test_path, test_txt)

    print('-------------Save Datasets-----------------')
    x_train_save = np.reshape(x_train, (len(x_train), -1))
    x_test_save = np.reshape(x_test, (len(x_test), -1))
    np.save(x_train_savepath, x_train_save)
    np.save(y_train_savepath, y_train)
    np.save(x_test_savepath, x_test_save)
    np.save(y_test_savepath, y_test)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

数据增强

如果数据量过少,模型见识不足。增加数据,提高泛化力。

用来应对因为拍照角度不同引起的图片变形

image_gen_train=tf,keras.preprocessing,image.ImageDataGenneratorP(...)

image_gen)train,fit(x_train)

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,使数据和网络结构匹配

image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1
    rotation_range=45,  # 随机45度旋转
    width_shift_range=.15,  # 宽度偏移
    height_shift_range=.15,  # 高度偏移
    horizontal_flip=True,  # 水平翻转
    zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test),
          validation_freq=1)
model.summary()

 因为是标准MINST数据集,因此在准确度上看不出来,需要在具体应用中才能体现

断点续训

实时保存最优模型

 保存模型参数可以使用tensorflow提供的ModelCheckpoint(filepath=checkpoint_save,

                              save_weight_only,sabe_best_only)

参数提取

获取各层网络最优参数,可以在各个平台实现应用

model.trainable_variables 返回模型中可训练参数

acc/loss可视化

查看训练效果

history=model.fit()

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt

np.set_printoptions(threshold=np.inf)

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1) 画出第一列
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2) #画出第二列
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

应用程序

给图识物

给出一张图片,输出预测结果

1.复现模型 Sequential加载模型

2.加载参数 load_weights(model_save_path)

3.预测结果

我们需要对颜色取反,我们的训练图片是黑底白字

减少了背景噪声的影响

from PIL import Image
import numpy as np
import tensorflow as tf

type = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

model_save_path = './checkpoint/fashion.ckpt'
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
                                        
model.load_weights(model_save_path)

preNum = int(input("input the number of test pictures:"))
for i in range(preNum):
    image_path = input("the path of test picture:")
    img = Image.open(image_path)
    img=img.resize((28,28),Image.ANTIALIAS)
    img_arr = np.array(img.convert('L'))
    img_arr = 255 - img_arr  #每个像素点= 255 - 各自点当前灰度值
    img_arr = img_arr/255.0
    x_predict = img_arr[tf.newaxis,...]

    result = model.predict(x_predict)
    pred=tf.argmax(result, axis=1)
    print('\n')
    print(type[int(pred)])


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1912541.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32快速搭建项目框架

注:编写本博客的原因,学习期间基于复习之前知识点的需要,故撰写本教程,即是复习前面的知识点也是作为博客的补充 1.0 文件夹的创建 创建一个STM32项目为模版工程,问价夹下分别包含4个子文件夹,一个是Librar…

【初阶数据结构】1.算法复杂度

文章目录 1.数据结构前言1.1 数据结构1.2 算法1.3 如何学好数据结构和算法 2.算法效率2.1 复杂度的概念2.2 复杂度的重要性 3.时间复杂度3.1 大O的渐进表示法3.2 时间复杂度计算示例3.2.1 示例13.2.2 示例23.2.3 示例33.2.4 示例43.2.5 示例53.2.6 示例63.2.7 示例7 4.空间复杂…

阻尼振动的可视化 包括源码和推导

阻尼振动的可视化 包括源码和推导 flyfish 牛顿第二定律(加速度定律) 胡克定律(Hooke‘s Law) 阻尼振动是指在振动系统中,由于阻力或能量损耗导致振动幅度随时间减小的现象。 左边为无阻尼,右边为有阻尼…

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第57-agent机器人助理自动获取喵星人资讯

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第57-agent机器人助理自动获取喵星人资讯 使用dtns.network德塔世界(开源的智体世界引擎),策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由JavaScript…

FastReport 指定sql 和修改 数据库连接地址的 工具类 :FastReportHelper

FastReport 指定sql 和修改 数据库连接地址的 工具类 :FastReportHelper 介绍核心代码:完整代码: 介绍 在FastReport中,经常会遇到需要给 sql 加条件的情况,或者给数据库地址做更换。 (废话不多说&#x…

Elasticsearch基础(四):Elasticsearch语法与案例介绍

文章目录 Elasticsearch语法与案例介绍 一、Restful API 二、查询语法 1、ES分词器 2、ES查询 2.1、match 2.2、match_phrase 2.3、multi_match 2.4、term 2.5、terms 2.6、fuzzy 2.7、range 2.8、bool Elasticsearch语法与案例介绍 一、Restful API Elastics…

Echarts实现github提交记录图

最近改个人博客&#xff0c;看了github的提交记录&#xff0c;是真觉得好看。可以移植到自己的博客上做文章统计 效果如下 代码如下 <!DOCTYPE html> <html lang"en" style"height: 100%"><head><meta charset"utf-8"> …

需求分析|泳道图 ProcessOn教学

文章目录 1.为什么使用泳道图2.具体例子一、如何绘制确定好泳道中枢的角色在中央基于事实来绘制过程不要纠结美观先画主干处理流程再画分支处理流程一个图表达不完&#xff0c;切分子流程过程数不超25 &#xff0c;A4纸的幅面处理过程过程用动词短语最后美化并加上序号酌情加上…

未羽研发测试管理平台

突然有一些觉悟&#xff0c;程序猿不能只会吭哧吭哧的低头做事&#xff0c;应该学会怎么去展示自己&#xff0c;怎么去宣传自己&#xff0c;怎么把自己想做的事表述清楚。 于是&#xff0c;这两天一直在整理自己的作品&#xff0c;也为接下来的找工作多做点准备。接下来…

2-29 基于matlab的CEEMD

基于matlab的CEEMD&#xff08;Complementary Ensemble Empirical Mode Decomposition&#xff0c;互补集合经验模态分解&#xff09;&#xff0c;先将数据精心ceemd分解&#xff0c;得到imf分量&#xff0c;然后通过相关系数帅选分量&#xff0c;在求出他们的样本熵的特征。用…

理解点对点协议:构建高效网络通信

在通信线路质量较差的年代&#xff0c;能够实现可靠传输的高级数据链路控制&#xff08;High-level Data Link Control, HDLC&#xff09;协议曾是比较流行的数据链路层协议。HDLC是一个较复杂的协议&#xff0c;实现了滑动窗口协议&#xff0c;并支持点对点和点对多点两种连接…

SpringBoot实现简单AI问答(百度千帆)

第一步&#xff1a;注册并登录百度智能云&#xff0c;创建应用并获取自己的APIKey与SecretKey&#xff0c;参考网址&#xff1a; 点击去百度智能云 第二步&#xff1a;引入千帆的pom依赖 <dependency><groupId>com.baidubce</groupId><artifactId>q…

我的FPGA

1.安装quartus 2.更新usb blaster驱动 3.新建工程 1.随便找一个文件夹&#xff0c;里面新建demo文件夹&#xff0c;表示一个个工程 在demo文件夹里面&#xff0c;新建src&#xff08;源码&#xff09;&#xff0c;prj&#xff08;项目&#xff09;&#xff0c;doc&#xff…

基于单片机的温控光控智能窗帘设计探讨

摘 要&#xff1a; 文章使用的核心原件是 AT89C52 单片机&#xff0c;以此为基础进行模块化的设计&#xff0c;在整个设计中通过加入光检测模块和温度检测模块&#xff0c;从而对室内的温度和光照强度进行检测&#xff0c;然后将检测得到的数据传输给单片机&#xff0c;单片机…

Mosh|内连接、外连接、左连接、右连接(未完)

下图取自菜鸟教程&#xff0c;侵权删&#xff5e; 一、内连接&#xff1a;Inner Joins 模版&#xff1a;SELECT * FROM A JOIN B ON 条件 含义&#xff1a;返回A与B的交集&#xff0c;列为AB列之和 练习&#xff1a;将order_items表和products表连接&#xff0c;返回产品id和…

成为编程大佬!!——数据结构与算法(1)——算法复杂度!!

前言&#xff1a;解决同一个程序问题可以通过多个算法解决&#xff0c;那么要怎样判断一个算法的优劣呢&#xff1f;&#x1f914; 算法复杂度 算法复杂度是对某个程序运行时的时空效率的粗略估算&#xff0c;常用来判断一个算法的好坏。 我们通过两个维度来看算法复杂度——…

记录docker部署好golang web项目后浏览器访问不到的问题

部署好项目&#xff0c;docker ps -a查看没有任何问题 端口映射成功&#xff0c;但是浏览器就是访问不到&#xff0c;排查后发现犯了个错&#xff0c;注意&#xff0c;项目配置文件中的端口&#xff1a; 其实也就是你项目中监听的端口&#xff1a; 必须和容器端口一致&#x…

Linux——多线程(四)

前言 这是之前基于阻塞队列的生产消费模型中Enqueue的代码 void Enqueue(const T &in) // 生产者用的接口{pthread_mutex_lock(&_mutex);while(IsFull())//判断队列是否已经满了{pthread_cond_wait(&_product_cond, &_mutex); //满的时候就在此情况下等待// 1.…

看影视学英语(假如第一季第一集)

in the hour也代表一小时吗&#xff1f;等同于in an hour&#xff1f;

Effective C++笔记之二十一:One Definition Rule(ODR)

ODR细节有点复杂&#xff0c;跨越各种情况。基本内容如下&#xff1a; ●普通&#xff08;非模板&#xff09;的noninline函数和成员函数、noninline全局变量、静态数据成员在整个程序中都应当只定义一次。 ●class类型&#xff08;包括structs和unions&#xff09;、模板&…