Hugging Face应用——图像识别

news2024/11/18 1:30:13

在这里插入图片描述
利用人工智能解决音频、视觉和语言问题。音频分类、图像分类、物体检测、问答、总结、文本分类、翻译等均有大量模型进行参考。

Eg1: 图像识别

图像分类是为整个图像分配标签或类别的任务。每张图像预计只有一个类别。图像分类模型将图像作为输入并返回有关图像所属类别的预测

在这里插入图片描述

借助该transformers库,可以使用image-classification管道来推断图像分类模型。在不提供模型ID时,默认使用google/vit-base-patch16-224进行初始化pipeline。 调用pipeline管道时,只需要指定路径、http链接或PIL(Python Imaging Library)中加载的图标;还可以提供一个top_k参数来确定应返回多少结果

在这里插入图片描述

使用Transformer微调ViT

如何像对句子标记一样对图像进行标记,以便将其传递到Transformer模型进行训练。

  1. 将图像分割成子图像块的网格
  2. 使用线性投影嵌入每个补丁
  3. 每个嵌入的补丁都会成为一个令牌,嵌入补丁的结果序列就是传递给模型的序列

在这里插入图片描述

  • MLP:Multilayer Perceptron 前向结构的人工神经网络——多层感知器
  • Embedded Patches 嵌入补丁

如何使用datasets下载和处理的图像分类数据集通过微调预训练的ViT transformer

  1. 首先安装软件包

    pip install datasets transformers Pillow
    
  2. 加载数据集

    使用beans数据集,是健康和非健康豆叶的图片集合

    from datasets import load_dataset
    
    ds = load_dataset('beans')
    //DatasetDict({
    //    train: Dataset({
    //        features: ['image_file_path', 'image', 'labels'],
    //        num_rows: 1034
    //    })
    //    validation: Dataset({
    //        features: ['image_file_path', 'image', 'labels'],
    //       num_rows: 133
    //    })
    //    test: Dataset({
    //        features: ['image_file_path', 'image', 'labels'],
    //        num_rows: 128
    //    })
    //})
    

    每个数据集中每个示例都有3个特征:

    • image: PIL图像
    • image_file_path: str加载的图像文件的路径image
    • labels:一个datasets.CLassLabel特征,是标签的整数表示
    {
      'image': <PIL.JpegImagePlugin ...>,
      'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
      'labels': 1
    }
    
    ex = ds['train'][400]
    image = ex['image']
    


    由于'labels'该数据集的特征是 datasets.features.ClassLabel,我们可以使用它来查找本示例的标签 ID 的相应名称

labels = ds['train'].features['labels']

// ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

使用int2str 函数来打印示例的类标签

labels.int2str(ex['labels'])

// 'bean_rust'

上面图片叶子感染了“豆锈病”,是一种豆科植物的严重疾病

编写一个函数显示每个类的示例网格:

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)

在这里插入图片描述

  • 角叶斑:有不规则的棕色斑块
  • 豆锈病:有圆形棕色斑点,周围有白黄色环
  • 健康:……看起来很健康

加载ViT特征提取器

现在知道图像是什么样子,并且更好地理解我们要解决的问题。让我们看看如何为我们的模型准备这些图像!

当训练 ViT 模型时,特定的转换将应用于输入到其中的图像。对图像使用错误的转换,模型将无法理解它所看到的内容!🖼➡️🔢

为了确保我们应用正确的转换,我们将使用ViTFeatureExtractor与我们计划使用的预训练模型一起保存的配置进行初始化。在我们的例子中,我们将使用google/vit-base-patch16-224-in21k模型,因此让我们从 Hugging Face Hub 加载其特征提取器。

from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

//ViTFeatureExtractor {
//  "do_normalize": true,
//  "do_resize": true,
//  "feature_extractor_type": "ViTFeatureExtractor",
//  "image_mean": [
//    0.5,
//    0.5,
//    0.5
//  ],
//  "image_std": [
//    0.5,
//    0.5,
//    0.5
//  ],
//  "resample": 2,
//  "size": 224
//}

要处理图像,只需将其传递给特征提取器的调用函数即可。这将返回一个包含 的字典pixel values,它是要传递给模型的数字表示形式

默认情况下,您会得到一个 NumPy 数组,但如果添加参数return_tensors='pt',您将得到torch张量…张量的形状是(1, 3, 224, 224)

feature_extractor(image, return_tensors='pt')

//{
//  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
//}
  1. 处理数据集

    编写一个函数,读取图像并将其转换为输入——>处理数据集中的单个示例

    def process_example(example):
        inputs = feature_extractor(example['image'], return_tensors='pt')
        inputs['labels'] = example['labels']
        return inputs
      
    process_example(ds['train'][0])
    
    //{
    //  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
    //  'labels': 0
    //}
    

    ⚠️: 虽然可以ds.map立即调用方法并将其应用于每个示例,但在大数据集时会出现性能问题。故在每次转换发生于索引示例时应用于示例

    ds = load_dataset('beans')
    
    def transform(example_batch):
        # Take a list of PIL images and turn them to pixel values
        inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    
        # Don't forget to include the labels!
        inputs['labels'] = example_batch['labels']
        return inputs
    

    也可以直接将其应用到数据集

    prepared_ds = ds.with_transform(transform)
    

    每次从数据集中获取示例时,都会实时应用转换

    prepared_ds['train'][0:2]
    
    //{
    //  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
    //  'labels': [0, 0]
    //}
    

    训练

    数据已处理完毕,就可以开始设置训练管道了。

    这篇博文使用了 🤗 的 Trainer,但这需要我们首先做一些事情:

    • 定义一个整理函数。
    • 定义评估指标。在训练期间,应评估模型的预测准确性。应该使用compute_metrics相应地定义一个函数。
    • 加载预训练的检查点。需要加载预训练的检查点并正确配置它以进行训练。
    • 定义训练配置。

    对模型进行微调后,我们将在评估数据上正确评估它,并验证它确实学会了正确分类图像。

    1. 定义我们的数据整理器

      批次以字典列表的形式出现,因此可以将它们解压+堆叠到批次张量中。

      由于collate_fn将返回一个批处理字典,因此可以**unpack稍后将输入输入到模型中

      import torch
      
      def collate_fn(batch):
          return {
              'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
              'labels': torch.tensor([x['labels'] for x in batch])
          }
      
    2. 定义评估指标

      数据集的准确性度量可以轻松用于将预测与标签进行比较。 下面可以看到如何在 Trainer 将使用的compute_metrics 函数中使用它

      import numpy as np
      from datasets import load_metric
      
      metric = load_metric("accuracy")
      def compute_metrics(p):
          return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
      

      加载预训练的模型。我们将添加num_labelsinit,以便模型创建具有正确数量的单元的分类头。

      (我们还将包括id2labellabel2id映射,以便在 Hub 小部件中具有人类可读的标签。。如果您选择这样做push_to_hub)

      from transformers import ViTForImageClassification
      
      labels = ds['train'].features['labels'].names
      
      model = ViTForImageClassification.from_pretrained(
          model_name_or_path,
          num_labels=len(labels),
          id2label={str(i): c for i, c in enumerate(labels)},
          label2id={c: str(i) for i, c in enumerate(labels)}
      )
      

      在此之前需要的最后一件事是通过定义 来设置训练配置TrainingArguments

      其中大多数都是不言自明的,但这里非常重要的是remove_unused_columns=False。这将删除模型调用函数未使用的所有功能。默认情况下,这是True因为通常最好删除未使用的特征列,从而更容易将输入解包到模型的调用函数中。但是,在我们的例子中,我们需要未使用的特征(特别是“图像”)来创建“像素值”。

      from transformers import TrainingArguments
      
      training_args = TrainingArguments(
        output_dir="./vit-base-beans",
        per_device_train_batch_size=16,
        evaluation_strategy="steps",
        num_train_epochs=4,
        fp16=True,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=False,
        report_to='tensorboard',
        load_best_model_at_end=True,
      )
      

      接下来便可以开始训练

      from transformers import Trainer
      
      trainer = Trainer(
          model=model,
          args=training_args,
          data_collator=collate_fn,
          compute_metrics=compute_metrics,
          train_dataset=prepared_ds["train"],
          eval_dataset=prepared_ds["validation"],
          tokenizer=feature_extractor,
      )
      

      Train 🚀

      train_results = trainer.train()
      trainer.save_model()
      trainer.log_metrics("train", train_results.metrics)
      trainer.save_metrics("train", train_results.metrics)
      trainer.save_state()
      

      Evaluate 📊

      metrics = trainer.evaluate(prepared_ds['validation'])
      trainer.log_metrics("eval", metrics)
      trainer.save_metrics("eval", metrics)
      

      结果如下:

      ***** eval metrics *****
        epoch                   =        4.0
        eval_accuracy           =      0.985
        eval_loss               =     0.0637
        eval_runtime            = 0:00:02.13
        eval_samples_per_second =     62.356
        eval_steps_per_second   =       7.97
      

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/727130.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OPPO手机上怎么设置阴历或阳历生日提醒?

有不少手机用户现在使用的都是OPPO这个品牌的手机&#xff0c;并且绝大多数用户都表示OPPO手机是比较好用的&#xff0c;不过也有一部分用户在使用手机的过程中遇到了一些问题&#xff0c;例如不知道在OPPO手机上怎么设置阴历或阳历生日提醒&#xff0c;这应该怎么办呢&#xf…

基于matlab开发和评估停车场场景中的视觉定位算法(附源码)

一、前言 本示例展示了如何使用虚幻引擎模拟环境中的合成图像数据开发视觉定位系统。 获取基本事实以评估定位算法在不同条件下的性能是一项具有挑战性的任务。与使用高精度惯性导航系统或差分GPS等更昂贵的方法相比&#xff0c;不同场景下的虚拟仿真是一种经济高效的方法来获…

数字化时代,到底如何认识商业智能BI?

数字化时代&#xff0c;商业智能BI对于企业的落地应用有着巨大价值&#xff0c;逐渐成为了现代企业信息化、数字化转型中的基础建设。 我曾经看到有人在讨论过商业智能BI的部署对于企业是否有实际意义&#xff0c;现在市场的数据已经证明商业智能BI在商业世界中&#xff0c;在…

使用Docker安装RabbitMQ并实现入门案例“Hello World”

RabbitMQ官方文档&#xff1a;RabbitMQ Tutorials — RabbitMQ 一、RabbitMQ安装&#xff08;Linux下&#xff09; 你可以选择原始的方式安装配置&#xff0c;也可以使用docker进行安装&#xff0c;方便快捷&#xff01; 1. 安装docker 没有docker的先安装一下docker&#x…

谷歌和edge浏览器升级到94及以上版本后反复提示安装pageoffice客户端

原因&#xff1a;Chrome开发团队以网络安全为由&#xff0c;强推ssl证书&#xff0c;希望所有部署在公网的网站&#xff0c;全部改用https访问&#xff0c;所以最新的谷歌和edge升级到94版本后对公网上的http请求下的非同域的http请求进行了拦截&#xff0c;于是就出现了目前遇…

一分钟告诉你国内和国外的ai绘画软件哪个好

前几天&#xff0c;我在一次聚会上偶然听到朋友们谈论起创作ai绘画的问题&#xff0c;大家都很热衷于用国内的ai绘画软件来生成自己喜欢的艺术作品&#xff0c;但又不知道国内和国外的ai绘画软件哪个好。正当我们陷入无尽的思考中时&#xff0c;其中一位朋友突然站出来说&#…

【计算机网络】1.5——计算机网络的体系结构

计算机网络的体系结构 概述 计算机网络的体系结构是计算机网络及其构建所应完成功能的精确定义 考题 不属于网络体系结构所描述的内容的是 A、网络的层次 B、每层使用的协议 C、协议的内部实现细节 D、每层必须完成的功能 这些功能的「实现细节」&#xff0c;是遵守这种体系…

SPEC CPU 2017 Ubuntu 20.04 LTS cpu2017-1_0_5.iso 安装、测试 单核成绩 笔记

环境 $ gcc -v Using built-in specs. COLLECT_GCCgcc COLLECT_LTO_WRAPPER/usr/lib/gcc/x86_64-linux-gnu/11/lto-wrapper OFFLOAD_TARGET_NAMESnvptx-none:amdgcn-amdhsa OFFLOAD_TARGET_DEFAULT1 Target: x86_64-linux-gnu Configured with: ../src/configure -v --with-pk…

vue3中的computed和watch

一、computed 1. vue2和vue3中计算属性用法对比 Vue2中的计算属性 Vue2中的计算属性是通过在Vue实例的computed选项中定义函数来创建的。计算属性会根据依赖的响应式属性进行缓存&#xff0c;只有当依赖的属性发生变化时&#xff0c;计算属性才会重新求值。 举个例子&#x…

【环境配置】Conda报错 requests.exceptions.HTTPError

问题&#xff1a; conda 创建新的虚拟环境时报错 Collecting package metadata (current_repodata.json): done Solving environment: done# >>>>>>>>>>>>>>>>>>>>>> ERROR REPORT <<<<<<…

OpenCVForUnity(二)基本图像容器Mat

这里写目录标题 前言Mat指针引用说明存储的方式如何创建一个Mat对像 前言 今天继续学习OpenCV的基本单位Mat. 学计算机的同学都知道在计算机中,你所看到的一切其都是数据的呈现.期最底层的本质皆是0和1的构成的.当然图片,视频等等也不例外.我们用相机,扫描仪核磁共振成像等方式…

OpenAI深夜放大招,GPT4 API全面开放并弃用一系列旧模型

GPT-4 API 现已向所有付费 OpenAI API 客户开放。GPT-3.5 Turbo、DALLE 和 Whisper API 现已普遍可用&#xff0c;我们宣布了一些旧型号的弃用计划&#xff0c;这些型号将于 2024 年初退役。 ✅ GPT4 API面向付费用户开放&#xff0c;不需要再额外申请,并且具有8K上下文&#…

bash文件输入到txt文件中

bash test_bct.sh >> test.txt结果如下

WeeChat 4.0.0 正式发布

导读WeeChat (Wee Enhanced Environment for Chat) 是一款自由开源的轻量级 IRC 客户端&#xff0c;具有高度的可定制特性&#xff0c;并且可以通过脚本进行扩展。 WeeChat 支持大多数的平台和操作系统&#xff0c;例如 Linux、BSD、macOS、Debian GNU/Hurd、HP-UX、Solaris、…

全国产化适配低代码平台,政企数字化的不二选择

编者按&#xff1a;在国家政策及战略方向的指导下&#xff0c;信创产业已成为奠定中国未来发展的重要数字基础&#xff0c;而国产化则可以解决核心技术关键被“卡脖子”的问题。另一方面&#xff0c;低代码平台能够为企业加速交付业务应用&#xff0c;降低运营成本&#xff0c;…

插入排序(思路+代码)

变量&#xff1a; index &#xff1a;代表待插入数的前一个数的下标&#xff0c;依次往回找&#xff0c;找到找到结果。 indexvalue&#xff1a;代表待插入元素的值&#xff0c;找到位置之后往index1的位置插入元素 代码&#xff1a; import java.util.Arrays;public class …

【库表操作】

一、数据库Market中创建表customers 1、创建数据库 #创建数据库 mysql> create database Market; mysql> use Market;2、创建数据表 #创建数据表 mysql> create table customers(-> c_num int(11) primary key auto_increment,-> c_name varchar(50),-> c_…

iOS-配置Universal Links通用链接

1、开启Associated Domains服务 登录苹果开发者网站&#xff0c;在Certificates, Identifiers & Profiles页面左侧选择Identifiers&#xff0c;右侧选择对应的App ID&#xff0c;点击进入配置详情页&#xff0c;开启Associated Domains服务&#xff1b; 2、更新Profile文件…

【动手学习深度学习--逐行代码解析合集】09权重衰减

【动手学习深度学习】逐行代码解析合集 09权重衰减 视频链接&#xff1a;动手学习深度学习–权重衰减 课程主页&#xff1a;https://courses.d2l.ai/zh-v2/ 教材&#xff1a;https://zh-v2.d2l.ai/ 0、准备工作 import matplotlib # 注意这个也要import一次 import matplotli…

Wordpress的mysql迁库遇到问题

在我们迁移库的时候经常会出现如下问题&#xff1a; 5.7日期默认0000-00-00 00:00:00 设置错误。 MySQL默认设置中不支持日期datetime格式下的0000-00-00 00:00:00。 解决方法如下&#xff1a; select sql_mode 来查看对应内容 ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO…