TensorFlow TFRecords简介

news2024/11/27 22:19:55

TensorFlow TFRecords简介

这篇博客将介绍TensorFlow的TFRecords,提供有关TFRecords的所有信息的一应俱全的介绍。从如何构建基本TFRecords到用于训练 SRGAN 和 ESRGAN 模型的高级TFRecords的所有内容。包括什么是TFRecords,如何序列化,反序列化数据,以及如何使用TFRecords预处理和序列化像div2k这样的大型数据集,如何使用TFRecords及TensorFlow训练深度神经网络。

TFRecord格式的两个主要优点是,高效地存储数据集,并且与从磁盘读取原始数据相比,获得了更快的I/O速度。

当使用TPU训练深度神经网络时,TFRecords非常有用。可以查看SRGAN和ESRGAN教程,其中介绍了如何使用Tensor处理单元(TPUs ensor Processing Units)和图形处理单元(GPUs Graphics Processing Units )训练深度神经网络。

最好不使用tf.image.resize,坑太多

1. 效果图

可以看到原始数据和编码后数据相同,编码数据只是原始数据的字节字符串,TFRecord中的数据是序列化的二进制记录。

$ python single_tf_record.py

Original data: 12345
Encoded data: b'12345'
Data from the TFRecord: b'\x05\x00\x00\x00\x00\x00\x00\x00\xea\xb2\x04>12345z\x1c\xed\xe8'
Decoded data: 12345

从输出中可以明显看出,原始数据被序列化为一系列字节字符串,随后被反序列化为原始数据。

$ python serialization.py

Original Data: [1 2 3 4]
Encoded Data: b'\x08\x04\x12\x04\x12\x02\x08\x04"\x04\x01\x02\x03\x04'
Decoded Data: [1 2 3 4]

根据url下载网络图片,指定文件名,构建为TFRecord 数据,并序列化为二进制字符串保存到文件,然后读取在解析会照片和文件名,效果图如下:
在这里插入图片描述在这里插入图片描述

2. 原理

2.1 安装

pip install tensorflow==2.1.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# pip install tensorflow --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorflow-datasets

2.2 TFRecord是什么

TFRecord是用于存储二进制记录序列的自定义TensorFlow格式。TFRecords针对TensorFlow进行了高度优化,因此具有以下优势:

  • 高效的数据存储形式
  • 与其他类型的格式相比,读取速度更快

TFRecords最重要的用例之一是使用TPU训练模型。TPU功能强大,但需要远程存储与之交互的数据。在TPU上训练模型时,以TFRecord格式远程存储数据集,因为它可以有效地保存数据并更容易地加载数据。

2.3 什么是序列化二进制记录?

TFRecords存储一系列二进制记录。因此首先需要学习如何将数据转换为二进制表示。
TensorFlow有两个公共API,负责将数据编码和解码为二进制记录。这两个公共API来自tf.io.serialize_tensor 和 tf.io.parse_tensor

通过使用tf.train.Feature进行数据的序列化和反序列化,支持的类型如下:
在这里插入图片描述

2.4 DIV2K数据集

DIVerse 2K分辨率高质量图像

  • 1000张2K分辨率的图像分为:800张用于训练的图像、100张用于验证的图像和100张用于测试的图像
  • 对于每个挑战赛道(具有1.双三次或2.未知降级运算符),
  • 高分辨率图像:0001.png,0002.png,…,1000.png
  • 缩小的图像:YYYYx2.png表示缩小因子x2;其中YYYY是图像ID;
    YYYYx3.png,缩小因子x3;
    YYYYx4.png;缩小因子x4
  • DIV2K forder结构如下:
    DIV2K/–DIV2K数据集
    DIV2K/DIV2K_train_HR/–0001.png,0002.png,…,0800.png列车HR图像(提供给参与者)
    DIV2K/DIV2K_train_LR_bicubic/——使用Matlab调大小函数获得的具有默认设置的相应低分辨率图像(双三次插值)

3. 源代码

3.1 example_tf_record.py

# utils.py 从磁盘加载和保存图像到磁盘
# config.py 单个数据tfrecord示例的配置文件
# advance_config.py div2k数据集示例的配置文件
# single_tf_record.py 处理单个二进制记录并显示如何将其保存为TFRecord格式的脚本
# serialization.py 解释数据序列化重要性的脚本
# example_tf_record.py 保存和加载单个图片为TFRecord,如何从磁盘加载原始图像并以TFRecord格式对其进行序列化,以及如何加载序列化的TFRecord并对图像进行反序列化。
# create_tfrecords.py 生成高级TFRecords,保存和加载整个div2k数据集为TFRecords。将使用tfds(表示tensorflow_datasets,一组现成数据集)加载div2k数据集,对其进行预处理,然后将预处理的数据集序列化为TFRecords。
# DIV2K数据集:DIVerse 2K分辨率高质量图像
# 1000张2K分辨率的图像分为:800张用于训练的图像、100张用于验证的图像和100张用于测试的图像
# 对于每个挑战赛道(具有1.双三次或2.未知降级运算符),
# 高分辨率图像:0001.png,0002.png,…,1000.png
# 缩小的图像:YYYYx2.png表示缩小因子x2;其中YYYY是图像ID;
#           YYYYx3.png,缩小因子x3;
#           YYYYx4.png;缩小因子x4
# DIV2K forder结构如下:
# DIV2K/--DIV2K数据集
# DIV2K/DIV2K_train_HR/--0001.png,0002.png,…,0800.png列车HR图像(提供给参与者)
# DIV2K/DIV2K_train_LR_bicubic/——使用Matlab调大小函数获得的具有默认设置的相应低分辨率图像(双三次插值)

# USAGE
# python example_tf_record.py

import os

# 导入必要的包
import tensorflow as tf

from tfrecords_demo import config
from tfrecords_demo import utils

# 结构化的数据示例包括图片和图片名
# 从特定的url下载图像并将图像保存到磁盘。
imagePath = tf.keras.utils.get_file(
    config.IMAGE_FNAME,
    config.IMAGE_URL,
)

# 使用load_image函数从磁盘加载图像作为tf.Tensor
image = utils.load_image(pathToImage=imagePath)
class_name = config.IMAGE_CLASS

# 检查输出文件夹是否存在,不存在则创建
if not os.path.exists(config.OUTPUT_PATH):
    os.makedirs(config.OUTPUT_PATH)

# 保存缩放后的照片
utils.save_image(image=image, saveImagePath=config.RESIZED_IMAGE_PATH)

# 构建图片tf.train.Feature和类名tf.train.Feature
imageFeature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[
        # 注意序列化图像的方法
        tf.io.serialize_tensor(image).numpy(),
    ])
)
classNameFeature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[
        class_name.encode(),
    ])
)

# 包装图片和类名feature到一个feature字典中,并将其作为参数初始化一个类
features = tf.train.Features(feature={
    "image": imageFeature,
    "class_name": classNameFeature,
})
example = tf.train.Example(features=features)

# 序列化整个实例 使用SerializeToString函数直接序列化
serialized = example.SerializeToString()

# 将序列化实例写入 TFRecord
with tf.io.TFRecordWriter(config.TFRECORD_EXAMPLE_FNAME) as recordWriter:
    recordWriter.write(serialized)

# 构建feature模式和 TFRecord数据
featureSchema = {
    "image": tf.io.FixedLenFeature([], dtype=tf.string),
    "class_name": tf.io.FixedLenFeature([], dtype=tf.string),
}
# 读取数据构建TFRecord
dataset = tf.data.TFRecordDataset(config.TFRECORD_EXAMPLE_FNAME)

# 遍历数据
for element in dataset:
    # 获取序列化实例数据,并根据feature模式解析
    # 注意如何使用这里的特征示意图来解析示例。(序列化和反序列化时的数据类型是一样的)
    element = tf.io.parse_single_example(element, featureSchema)

    # 获取序列化后的类名和图像
    className = element["class_name"].numpy().decode()
    image = tf.io.parse_tensor(
        element["image"].numpy(),
        out_type=tf.dtypes.float32
    )

    # 使用图片名和图片保存反序列化后的图像
    utils.save_image(
        image=image,
        saveImagePath=config.DESERIALIZED_IMAGE_PATH,
        title=className
    )

3.2 create_tfrecords.py

# USAGE
# python create_tfrecords.py

# 导入必要的包
import os

import tensorflow as tf
import tensorflow_datasets as tfds

from tfrecords_demo import config

# 定义自动调频对象以优化过程
AUTO = tf.data.experimental.AUTOTUNE


def pre_process(element):
    # 获取低、高分辨率图像
    lrImage = element["lr"]
    hrImage = element["hr"]

    # 将低高分辨率图像从Tensor张量转换为序列化的张量TensorProto proto
    lrByte = tf.io.serialize_tensor(lrImage)
    hrByte = tf.io.serialize_tensor(hrImage)

    # 返回低、高分辨率proto对象
    return (lrByte, hrByte)


def create_dataset(dataDir, split, shardSize):
    print(config.DATASET, dataDir, shardSize)
    # 加载数据集,保存到磁盘,并处理
    ds = tfds.load(name="div2k", split=split, data_dir=dataDir,download=True)
    ds = (ds
          .map(pre_process, num_parallel_calls=AUTO)
          .batch(shardSize)
          )

    # 返回数据集TensorFlow dataset object
    return ds


def create_serialized_example(lrByte, hrByte):
    # 创建低、高分辨率图像字节list
    lrBytesList = tf.train.BytesList(value=[lrByte])
    hrBytesList = tf.train.BytesList(value=[hrByte])

    # 从字节list构建低、高分辨率推向feature
    lrFeature = tf.train.Feature(bytes_list=lrBytesList)
    hrFeature = tf.train.Feature(bytes_list=hrBytesList)

    # 构建低、高分辨率图像feature字典
    featureMap = {
        "lr": lrFeature,
        "hr": hrFeature,
    }

    # 构建一个features集合,构建features实例,序列化实例
    features = tf.train.Features(feature=featureMap)
    example = tf.train.Example(features=features)
    serializedExample = example.SerializeToString()

    # 返回序列化的实例
    return serializedExample


def prepare_tfrecords(dataset, outputDir, name, printEvery=50):
    # 检查输出路径是否存在
    if not os.path.exists(outputDir):
        os.makedirs(outputDir)

    # 遍历数据集,创建 TFRecords
    for (index, images) in enumerate(dataset):
        # 获取分片数,构建名称
        shardSize = images[0].numpy().shape[0]
        tfrecName = f"{index:02d}-{shardSize}.tfrec"
        filename = outputDir + f"/{name}-" + tfrecName

        # 写入 tfrecords
        with tf.io.TFRecordWriter(filename) as outFile:
            # write shard size serialized examples to each TFRecord
            for i in range(shardSize):
                serializedExample = create_serialized_example(
                    images[0].numpy()[i], images[1].numpy()[i])
                outFile.write(serializedExample)

            # 打印进度
            if index % printEvery == 0:
                print("[INFO] wrote file {} containing {} records..."
                      .format(filename, shardSize))


# ds = tfds.load('mnist', split='train', shuffle_files=True)
# ds = tfds.load('div2k', split='train[:5%]', shuffle_files=True)

# 创建div2k images的训练和验证数据集
print("[INFO] creating div2k training and testing dataset...")
trainDs = create_dataset(dataDir=config.DIV2K_PATH, split="train[:5%]",
                         shardSize=config.SHARD_SIZE)
testDs = create_dataset(dataDir=config.DIV2K_PATH, split="validation",
                        shardSize=config.SHARD_SIZE)

# 创建训练和测试 TFRecords,并写入磁盘
print("[INFO] preparing and writing div2k TFRecords to disk...")
prepare_tfrecords(dataset=trainDs, name="train",
                  outputDir=config.GPU_DIV2K_TFR_TRAIN_PATH)
prepare_tfrecords(dataset=testDs, name="test",
                  outputDir=config.GPU_DIV2K_TFR_TEST_PATH)

4. 报错及解决

  1. tf.data.experimental.AUTOTUNE
  2. tensorflow >=2.1.0
    在这里插入图片描述

参考

  • https://pyimagesearch.com/2022/08/08/introduction-to-tfrecords/
  • div2k数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/88415.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL 语句练习03

目录 一、建表 二、插入数据 三、查询 一、建表 这里先建好我们下面查询需要的表,方便后续查询。 建立如下学生表(命名格式“姓名拼音_三位学号_week5s”, 如LBJ_023_week5s)create table LYL_116_week5s(SNO varchar(4) primary key,SNA…

【Kubernetes】DashBoard部署

kubernetes,是一个全新的基于容器技术的分布式架构领先方案,是谷歌严格保密十几年的秘密武器----Borg系统的一个开源版本,于2014年9月发布第一个版本,2015年7月发布第一个正式版本。 kubernetes的本质是一组服务器集群&#xff0…

数字孪生智慧水务建设综述

随着新时期治水方针的逐步落实,水利现代化、智能化建设已全面开启,数字孪生等新技术的成熟,也为智慧水务体系的搭建提供了技术保障,新时代治水新思路正逐步得到落实。本文将重点对智慧水务的内涵及建设内容进行解读,力…

2022年航空与物流行业研究报告

第一章 行业概况 航空与物流行业是指以各种航空飞行器为运输工具,以空中运输的方式运载人员或货物的企业。航空公司是以各种航空飞行器为运输工具为乘客和货物提供民用航空服务的企业。航空公司使用的飞行器可以是他们自己拥有的,也可以是租来的&#x…

物联网通信技术原理-作业汇总(更新ing)

第一章 第二章 第三章 第四章 第五章 1. 移动通信中典型的多址接入方式有哪些?简要说明其工作原理2. 分集技术的基本原理是什么?简要说明空间、频率和时间分集、合并的异同。 1)分集技术的基本原理 通过多个信道(时间、频率或…

25.访客功能

访客功能 一、需求分析 用户在浏览我的主页时,需要记录访客数据,访客在一天内每个用户只记录一次。 首页展示最新5条访客记录 我的模块,分页展示所有的访客记录 二、数据库表 visitors(访客记录表) { “_id”: …

尚医通 (三十五) --------- 预约下单

目录一、预约下单前端1. 封装 api 请求2. 页面修改二、后端逻辑1. 需求分析2. 搭建 service-order 模块3. 添加订单基础类4. 封装 Feign 调用获取就诊人接口5. 封装 Feign 调用获取排班下单信息接口6. 实现下单接口7. 预约成功后处理逻辑① rabbit-util 模块封装② 封装短信接口…

C++ Reference: Standard C++ Library reference: Containers: map: map: cend

C官网参考链接&#xff1a;https://cplusplus.com/reference/map/map/cend/ 公有成员函数 <map> std::map::cend const_iterator cend() const noexcept;返回指向结束的const_iterator 返回一个指向容器结束后元素的const_iterator。 const_iterator是指向const内容的it…

正弦交流电物理量表征

前言 这一讲主要来表征正弦交流电的物理量 文章目录前言一、周期和频率二、最大值、有效值和平均值一、周期和频率 周期&#xff1a;正弦交流电每重复变化1次所需要的时间称为周期&#xff0c;用符号T表示&#xff0c;单位是秒&#xff08;s&#xff09;。 频率&#xff1a;正…

web前端期末大作业 绿色环境保护(4个页面) HTML5网站模板农业展示网站 html5网页制作代码 html5网页设计作业代码 html制作网页案例代码

&#x1f380; 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业…

体育馆场地预约管理系统/球馆管理系统

摘 要 随着体育馆规模的不断扩大&#xff0c;人流数量的急剧增加&#xff0c;有关体育馆的各种信息量也在不断成倍增长。面对庞大的信息量&#xff0c;就需要有体育馆场地预约管理系统来提高体育馆工作的效率。通过这样的系统&#xff0c;我们可以做到信息的规范管理和快速查询…

TCP/IP网络原理 【IP篇】

&#x1f389;&#x1f389;&#x1f389;写在前面&#xff1a; 博主主页&#xff1a;&#x1f339;&#x1f339;&#x1f339;戳一戳&#xff0c;欢迎大佬指点&#xff01; 目标梦想&#xff1a;进大厂&#xff0c;立志成为一个牛掰的Java程序猿&#xff0c;虽然现在还是一个…

聚观早报 | 马斯克丢掉世界首富宝座;加密货币FTX创始人被捕

今日要闻&#xff1a;马斯克丢掉世界首富宝座&#xff1b;加密货币FTX创始人被捕&#xff1b;美团推出高峰打车极速版&#xff1b;魔兽制作组正研发新功能&#xff1b;SpaceX出售公司内部股票马斯克丢掉世界首富宝座 12 月 13 日消息&#xff0c;据国外媒体报道&#xff0c;受特…

7-54 孤岛营救问题——状压bfs+三维标记

1944 年&#xff0c;特种兵麦克接到国防部的命令&#xff0c;要求立即赶赴太平洋上的一个孤岛&#xff0c;营救被敌军俘虏的大兵瑞恩。瑞恩被关押在一个迷宫里&#xff0c;迷宫地形复杂&#xff0c;但幸好麦克得到了迷宫的地形图。迷宫的外形是一个长方形&#xff0c; 其南北方…

二、小程序框架

目录 框架 一、响应的数据绑定 二、页面管理 三、基础组件 四、丰富的API 模块化 一、模块化 二、文件作用域 三、API 视图层 View 一、WXML 事件 什么是事件 事件的使用方式 使用 WXS 函数响应事件 事件详解 框架 小程序开发框架的目标是通过尽可能简单、高效…

万字长文详解 YOLOv1-v5 系列模型

一&#xff0c;YOLOv1二&#xff0c;YOLOv2三&#xff0c;YOLOv3四&#xff0c;YOLOv4五&#xff0c;YOLOv5参考资料 一&#xff0c;YOLOv1 YOLOv1 出自 2016 CVPR 论文 You Only Look Once:Unified, Real-Time Object Detection. YOLO 系列算法的核心思想是将输入的图像经过…

同时安装python3和Python2

一刚开始我很疑惑&#xff0c;Python为何要并行两个版本呢&#xff1f;今天我算知道了&#xff0c;原来是因为有的项目一直在用python2。虽然我已经安装了python3但是那些使用python2进行部署的项目我仍然无法使用&#xff0c;这就导致我要在电脑上同时安装python2和Python3了。…

【无标题】SIP网络广播音频模块

SIP2101V和SIP2103V网络音频模块是一款通用的独立SIP音频功能模块&#xff0c;可以轻松地嵌入到OEM产品中。该模块对来自网络的SIP协议及RTP音频流进行编解码。 该模块支持多种网络协议和音频编解码协议&#xff0c;可用于VoIP和IP寻呼以及高质量音乐流媒体播放等应用。同时&a…

如何将onnx转ncnn供移动端推理使用

ncnn是一个为手机端极致优化的高性能神经网络前向计算框架。基于 ncnn&#xff0c;开发者能够将深度学习算法轻松移植到手机端高效执行&#xff0c;开发出人工智能 APP&#xff0c;将 AI 带到你的指尖。 但是onnx直接转ncnn会存在很多问题&#xff0c;所以一般考虑都是先将onn…

Mysql 进阶(面向面试篇)InnoDB引擎(redo log undolog readview mvcc)

1.1 逻辑存储结构 1). 表空间 表空间是InnoDB存储引擎逻辑结构的最高层&#xff0c; 如果用户启用了参数 innodb_file_per_table(在8.0版本中默认开启) &#xff0c;则每张表都会有一个表空间&#xff08;xxx.ibd&#xff09;&#xff0c;一个mysql实例可以对应多个表空间&…