【TensorFlow Hub】:有 100 个预训练模型等你用

news2024/7/6 19:04:52

要访问TensorFlow Hub,请单击此处 — https://www.tensorflow.org/hub

一、说明

        TensorFlow Hub是一个库,用于在TensorFlow中发布,发现和使用可重用模型。它提供了一种使用预训练模型执行各种任务(如图像分类、文本分析等)的简单方法。

        TensorFlow Hub提供了广泛的预训练模型,由TensorFlow和更广泛的机器学习社区的研究人员和工程师开发。

        以下是TensorFlow Hub中可用的模型类型的一些示例:

  1. 图像分类模型:这些模型在标记图像的大型数据集上进行训练,可以将图像分类为各种类别。TensorFlow Hub中一些流行的图像分类模型包括Inception,MobileNet,ResNet和VGG。
  2. 对象检测模型:这些模型可以检测和定位图像中的对象。TensorFlow Hub中一些流行的对象检测模型包括Faster R-CNN和YOLO。
  3. 自然语言处理 (NLP) 模型:这些模型可以分析文本并执行情绪分析、文本分类和语言翻译等任务。TensorFlow Hub中一些流行的NLP模型包括BERT和ALBERT。
  4. 语音识别模型:这些模型可以将语音转录为文本。TensorFlow Hub 中一些流行的语音识别模型包括 wav2vec2、spice 和 yamnet。
  5. 生成模型:这些模型可以根据输入数据生成新内容,例如图像或文本。TensorFlow Hub中一些流行的生成模型包括Progressive GAN和BigGAN。
  6. 迁移学习模型:这些模型在大型数据集上预先训练,可以使用较小的数据集针对特定任务进行微调。TensorFlow Hub 中的迁移学习模型可用于各种任务,例如图像分类、对象检测、分割和文本分析。

        TensorFlow Hub提供了一种方便的方式来访问这些预先训练的模型,并将其用于各种机器学习任务。这些模型有多种格式,例如 TensorFlow SavedModel、Keras 模型和 TensorFlow.js 模型,可以轻松将它们集成到您的机器学习管道中。

二、安装和使用

2.1 安装 TensorFlow 和 TensorFlow Hub:

        在使用 TensorFlow Hub 之前,您需要同时安装 TensorFlow 和 TensorFlow Hub。您可以在命令提示符或终端中使用 pip 安装它们:

pip install tensorflow
pip install tensorflow-hub 

        这将安装最新版本的TensorFlow和TensorFlow Hub。

三、从 TensorFlow Hub 加载预训练模型:

        您可以在TensorFlow Hub网站(https://tfhub.dev/)上浏览可用的模型。

        要加载预先训练的模型,您首先需要从TensorFlow Hub网站获取其URL。例如,如果要使用在 ImageNet 2K 数据集上预先训练的EfficientNet_v1_s模型进行图像分类,则可以使用以下 URL:

module_url = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/classification/2"

        接下来,您可以使用类和 URL 创建 Keras 层:hub.KerasLayer

import tensorflow as tf
import tensorflow_hub as hub

feature_extractor = hub.KerasLayer(module_url, input_shape=(384,384,3))

        在此示例中,我们将创建一个 Keras 层,用于从图像中提取特征。 input_shape 参数指定输入图像的形状,在本例中为 384x384 像素,具有 3 个颜色通道 (RGB)。

3.1 使用预先训练的模型执行任务 

        加载预训练模型后,可以使用它来执行特定任务。例如,如果您随机拍摄老虎的图像,则可以执行以下操作:

import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt

image = Image.open("image.jpg").resize((384,384))
plt.imshow(image)
plt.show()
image_array = np.array(image) / 255.0
image_batch = np.expand_dims(image_array, axis=0)

features = feature_extractor(image_batch)

输出:

        在此示例中,我们从文件加载图像,将其大小调整为 384x384 像素,显示图像,将其转换为 NumPy 数组,并将其规范化为 0 到 1 之间的值。然后,我们向数组添加一个额外的维度来创建一批图像(因为预训练的模型需要一批图像作为输入)。最后,我们使用预先训练的EfficientNet_v2_s模型对图像进行分类。

labels_file = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"

#download labels and creates a maps
downloaded_file = tf.keras.utils.get_file("labels.txt", origin=labels_file)

classes = []

with open(downloaded_file) as f:
  labels = f.readlines()
  classes = [l.strip() for l in labels]

        此代码段使用 TensorFlow 库从 Google Cloud Storage 存储桶下载包含 ImageNet 数据集标签的文本文件。文本文件的 URL 存储在 labels_file 变量中。tf.keras.utils.get_file() 然后使用该函数从 URL 指定的 labels_file 下载文件,并将其以名称“labels.txt”保存在本地。 get_file() 函数在本地缓存文件,因此如果再次请求相同的 URL,将使用本地缓存的文件,而不是再次下载文件。

        下载文件后,代码将使用 open() 函数打开它并读取其内容。 readlines() 函数返回一个字符串列表,每个字符串代表文件中的一行。然后,代码使用 strip() 函数删除每行中的所有前导或尾随空格,并将清理后的标签存储在名为“classes”的列表中。

        生成的类列表包含 ImageNet 数据集的 1001 个类(1000 个主类 + 1 个“背景”类)的标签。该列表可用于将神经网络的输出映射到相应的标签以进行显示或进一步分析。    

top_5 = tf.argsort(features, axis=-1, direction="DESCENDING")[0][:5].numpy()

features_array = np.array(features)
features_flatten = features_array.flatten()

for i, item in enumerate(top_5):
  class_index = item +1
  line = classes[class_index].upper()
  prob = round(features_flatten[item]*10,2)
  print(f'Predicted class : {line} : with probability : {prob}%')

输出:

Predicted class : TIGER : with probability : 90.79%
Predicted class : TIGER CAT : with probability : 77.36%
Predicted class : JAGUAR : with probability : 37.36%
Predicted class : LYNX : with probability : 29.64%
Predicted class : LEOPARD : with probability : 25.47%

        此代码片段使用 TensorFlow 库来预测输入图像的前 5 个最有可能的 ImageNet 类。

        变量特征包含输入图像的神经网络模型的输出。 tf.argsort() 函数用于按降序对输出值进行排序并返回排序值的索引。通过使用 [0][:5] 对结果张量进行切片来选择前 5 个索引,然后使用 .numpy() 方法将其转换为 NumPy 数组。

      

        features_array 变量是通过将特征张量转换为 NumPy 数组来创建的。 features_flatten 变量是通过将 features_array 展平为一维数组而创建的。

        然后使用 for 循环迭代前 5 个索引。对于每个索引,通过向索引添加 1(因为索引是从 0 开始的,但类标签是从 1 开始的)并将标签转换为大写,从类列表中检索相应的 ImageNet 类标签。

        最后,使用 print() 函数将预测的类标签及其相应的概率打印到控制台。输出将显示前 5 个最有可能的 ImageNet 类别中每个类别的预测类别及其对应的概率。

   

3.2 微调预训练模型(可选):

        如果预训练的模型不能完全满足您的要求,您可以通过添加新层或在新数据集上重新训练它来对其进行微调。下面是如何微调新分类任务的EfficientNet_v2_s模型的示例:

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers

model_url = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/feature_vector/2"
model = hub.KerasLayer(model_url, trainable=True)

num_classes = 10 #just an example

fine_tuning_model = tf.keras.Sequential([
    model,
    layers.Dense(num_classes, activation='softmax')
])

fine_tuning_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.CategoricalAccuracy()]
)

epochs = 10
batch_size = 32

fine_tuning_model.fit(
    train_data,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=val_data
)

test_loss, test_acc = fine_tuning_model.evaluate(test_data)
print('Test accuracy:', test_acc)

        在此示例中,我们通过添加新的密集层进行分类来创建新模型。然后,我们使用优化器、损失函数和指标编译模型。最后,我们在指定数量的 epoch 的新数据集上训练模型。

四、结论

        这些是安装和使用 TensorFlow Hub 的基本步骤。TensorFlow Hub 提供了更多功能,例如缓存、版本控制等,因此请务必查看文档以获取更多信息。

资料引用:

A.瓦斯瓦尼,N.沙泽尔,N.帕尔马,J.乌什科雷特,L.琼斯,A.戈麦斯,{.凯撒  阿琼·萨卡尔

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1063500.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GPU(国内外发展,概念参数(CUDA,Tensor Core等),类别,如何选型,NPU,TPU)

目录 前言 1.国内外GPU发展简述 2.GPU概念参数和选择标准 2.1 CUDA 2.2 Tensor Core 2.3 显存容量和显存位宽 2.4 精度 2.5 如何选择GPU 3.常见GPU类别和价格 3.1 GPU类别 3.2 GPU价格(部分) 3.3 GPU云服务器收费标准(以阿里云为例&a…

机器学习基础之《回归与聚类算法(1)—线性回归》

一、线性回归的原理 1、线性回归应用场景 如何判定一个问题是回归问题的,目标值是连续型的数据的时候 房价预测 销售额度预测 贷款额度预测、利用线性回归以及系数分析因子 2、线性回归定义 线性回归(Linear regression)是利用回归方程(函数)对一个或多个自变量(…

【多线程进阶】synchronized 原理

文章目录 前言1. 基本锁策略2. 加锁工作过程2.1 偏向锁2.2 轻量级锁2.3 重量级锁 3. 其他的优化操作3.1 锁消除3.2 锁粗化 总结 前言 在前面章节中, 提到了多线程中的锁策略, 那么我们 Java 中的锁 synchronized 背后都采取了哪些锁策略呢? 又是如何进行工作的呢? 本节我们就…

第三课 哈希表、集合、映射

文章目录 第三课 哈希表、集合、映射lc1.两数之和--简单题目描述代码展示 lc30.串联所有单词的子串--困难题目描述代码展示 lc49.字母异位分组--中等题目描述代码展示 lc874.模拟行走机器人--中等题目描述代码展示 lc146.LRU缓存--中等题目描述相关补充思路讲解代码展示图示理解…

提升您的 Go 应用性能的 6 种方法

优化您的 Go 应用程序 1. 如果您的应用程序在 Kubernetes 中运行,请自动设置 GOMAXPROCS 以匹配 Linux 容器的 CPU 配额 Go 调度器 可以具有与运行设备的核心数量一样多的线程。由于我们的应用程序在 Kubernetes 环境中的节点上运行,当我们的 Go 应用程…

美国各流域边界下载,并利用arcgis提取与处理

一、边界数据的下载 一般使用最普遍的流域边界数据是从HydroSHEDS官网下载: HydroBASINS代表一系列矢量多边形图层,以全球尺度呈现次级流域边界。该产品的目标是提供一种无缝的全球覆盖,其中包含了不同尺度(从数十到数百万平方千米&#xf…

Docker 配置基础优化

Author:rab 为什么要优化? 你有没有发现,Docker 作为线上环境使用时,Docker 日志驱动程序的日志、存储驱动数据都比较大(尤其是在你容器需要增删比较频繁的时候),动不动就好几百 G 的大小&…

P3-Python学习当中的两大法宝函数

P3-Python学习当中的两大法宝函数 实战操作 打开pycharm,在命令行当中先检测是否是在envs当中的pytorch环境里面,或者导入torch包是否成功 dir(torch)//展示torch以下的分隔的工具包证明torch目录以下有cuda包 dir(torch.cuda.is_available())//可以展示…

Springboot学生成绩管理系统idea开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 springboot 学生成绩管理系统是一套完善的信息系统,结合springboot框架和bootstrap完成本系统,对理解JSP java编程开发语言有帮助系统采用springboot框架(MVC模式开发),系统 具有完整的源代码和数据库&…

golang gin——controller 模型绑定与参数校验

controller 模型绑定与参数校验 gin框架提供了多种方法可以将请求体的内容绑定到对应struct上,并且提供了一些预置的参数校验 绑定方法 根据数据源和类型的不同,gin提供了不同的绑定方法 Bind, shouldBind: 从form表单中去绑定对象BindJSON, shouldB…

【MVC】C# MVC基础知识点、原理以及容器和管道

给自己一个目标,然后坚持一段时间,总会有收获和感悟! 国庆假期马上结束,闲暇时间,重温一遍C#关于MVC的技术,控制器、视图、模型,知识点和原理,小伙伴们还记得吗 目录 一、MVC知识点1…

纸质书籍OCR方案大揭秘,快来看看有哪些神奇的黑科技

随着数字化时代的来临,纸质书籍逐渐被电子书所替代。在将纸质书籍转换为电子格式的过程中,扫描电子书目录并进行文字识别(OCR,Optical Character Recognition)成为了一项重要的工作。OCR技术能够将纸质书籍中的文字内容…

如何使用 Overleaf 编写 LaTeX 文档

如何使用 Overleaf 编写 LaTeX 文档 😇博主简介:我是一名正在攻读研究生学位的人工智能专业学生,我可以为计算机、人工智能相关本科生和研究生提供排忧解惑的服务。如果您有任何问题或困惑,欢迎随时来交流哦!&#x1f…

Firefly-LLaMA2-Chinese - 开源中文LLaMA2大模型

文章目录 关于模型列表 & 数据列表训练细节增量预训练 & 指令微调数据格式 & 数据处理逻辑增量预训练指令微调模型推理权重合并模型推理部署关于 github : https://github.com/yangjianxin1/Firefly-LLaMA2-Chinese本项目与Firefly一脉相承,专注于低资源增量预训练…

模糊搜索利器:Python的thefuzz模块详解

文章目录 thefuzz模块简介thefuzz模块的参数和方法使用thefuzz实现模糊搜索在Python中,thefuzz模块是一个用于实现模糊搜索的强大工具。它可以帮助我们在处理字符串时,快速找到相似的匹配项。本文将详细介绍thefuzz模块的功能和用法,并结合代码示例演示如何实现模糊搜索。 t…

有自动交易股票的软件么,怎么实现全自动交易?

随着技术的发展,我们经常会在看到一些关于自动交易股票软件的宣传。那么,这些软件是否真的存在?如何实现全自动交易呢? 股票量化程序化自动交易接口 一、自动交易股票软件存在吗? 答案是有,部分券商已经对…

Python数据容器——集合的相关操作

作者:Insist-- 个人主页:insist--个人主页 本文专栏:Python专栏 专栏介绍:本专栏为免费专栏,并且会持续更新python基础知识,欢迎各位订阅关注。 目录 一、理解集合 1. 集合是什么? 2. 为什么…

typora + picgo + 对象存储 OSS

文章目录 一、安装软件二、使用阿里云 oss 存储图片三、picgo 设置四、typora 设置自动上传 一、安装软件 Typora1.3.8 (安装即破解) picgo 2.3.0 安装 阿里云盘(软件安装包): https://www.aliyundrive.com/s/saQoS…

Windows10实用的12个快捷组合键

Windows10实用的12个快捷组合键 1、网页多标签切换 CTRL TAB 2、恢复不小心关闭的标签页 CTRLSHIFT T 3、新建标签页 CTRL T 4、高亮选择地址栏 ALT D 5、打开设置 WIN I 6、打开任务管理器 CTRLSHIFT ESC 7、打开文件资源管理器 WIN E 8、黑屏或屏幕卡顿无响应&#x…

python实验(超详细)

目录 实验一 python编程基础实验二 python序列、字符串处理实验三 函数及python类的定义与使用实验四 python综合应用 实验一 python编程基础 在交互式环境中打印“Hello world”字符串。记录操作过程。 略 创建脚本helloworld.py,在命令符提示环境中执行程序&…