OpenVINO 2022.3实战三:POT API实现图像分类模型 INT8 量化

news2024/12/30 0:46:00

OpenVINO 2022.3实战三:POT API实现图像分类模型 INT8 量化

1 准备需要量化的模型

这里使用我其他项目里面,使用 hymenoptera 数据集训练好的 MobileNetV2 模型,加载pytorch模型,并转换为onnx。

import os
from pathlib import Path
import sys
import torch
import torch.nn as nn
import torchvision

from torchvision import transforms, datasets

import matplotlib.pyplot as plt
import numpy as np
from openvino.tools.pot.api import DataLoader, Metric
from openvino.tools.pot.engines.ie_engine import IEEngine
from openvino.tools.pot.graph import load_model, save_model
from openvino.tools.pot.graph.model_utils import compress_model_weights
from openvino.tools.pot.pipeline.initializer import create_pipeline
from openvino.runtime import Core
from torchvision import transforms
from subprocess import run
from SlimPytorch.quantization.mobilenet_v2 import MobileNetV2

# Set the data and model directories
DATA_DIR = '/home/liumin/data/hymenoptera/val'
MODEL_DIR = './weights'



def load_pretrain_model(model_dir):
    model = MobileNetV2('mobilenet_v2', classifier=True)
    num_ftrs = model.fc[1].in_features
    model.fc[1] = nn.Linear(num_ftrs, 2)
    model.load_state_dict(torch.load(model_dir, map_location='cpu'))
    return model

def load_val_data(data_dir):
    data_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_dataset = datasets.ImageFolder(data_dir, data_transform)
    # dataload = torch.utils.data.DataLoader(image_dataset, batch_size=1, shuffle=False, num_workers=4)
    return image_dataset


model = load_pretrain_model(Path(MODEL_DIR) / 'mobilenet_v2_train.pt')
dataset = load_val_data(DATA_DIR)

model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

onnx_model_path = Path(MODEL_DIR) / 'mobilenet_v2.onnx'
ir_model_xml = onnx_model_path.with_suffix('.xml')
ir_model_bin = onnx_model_path.with_suffix('.bin')

torch.onnx.export(model, dummy_input, onnx_model_path)

运行模型优化器将ONNX转换为OpenVINO IR:

mo --compress_to_fp16 -m .\weights\mobilenet_v2.onnx  --output_dir .\weights\

2 定义数据加载和精度验证功能

这里注意 需要继承来自 openvino.tools.pot.api 的 DataLoader和 Metric 类

# Create a DataLoader.
class QDataLoader(DataLoader):

    def __init__(self, config):
        """
        Initialize config and dataset.
        :param config: created config with DATA_DIR path.
        """
        super().__init__(config)
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Return one sample of index, label and picture.
        :param index: index of the taken sample.
        """
        image, label = self.dataset[index]
        return (index, label), image.numpy()

    def load_data(self, dataset):
        """
        Load dataset in needed format.
        :param dataset:  downloaded dataset.
        """
        pictures, labels, indexes = [], [], []

        for idx, sample in enumerate(dataset):
            pictures.append(sample[0])
            labels.append(sample[1])
            indexes.append(idx)

        return indexes, pictures, labels


class Accuracy(Metric):

    # Required methods
    def __init__(self, top_k=1):
        super().__init__()
        self._top_k = top_k
        self._name = 'accuracy@top{}'.format(self._top_k)
        self._matches = []

    @property
    def value(self):
        """ Returns accuracy metric value for the last model output. """
        return {self._name: self._matches[-1]}

    @property
    def avg_value(self):
        """ Returns accuracy metric value for all model outputs. """
        return {self._name: np.ravel(self._matches).mean()}

    def update(self, output, target):
        """ Updates prediction matches.
        :param output: model output
        :param target: annotations
        """
        if len(output) > 1:
            raise Exception('The accuracy metric cannot be calculated '
                            'for a model with multiple outputs')
        if isinstance(target, dict):
            target = list(target.values())
        predictions = np.argsort(output[0], axis=1)[:, -self._top_k:]
        match = [float(t in predictions[i]) for i, t in enumerate(target)]

        self._matches.append(match)

    def reset(self):
        """ Resets collected matches """
        self._matches = []

    def get_attributes(self):
        """
        Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
        Required attributes: 'direction': 'higher-better' or 'higher-worse'
                             'type': metric type
        """
        return {self._name: {'direction': 'higher-better',
                             'type': 'accuracy'}}

3 运行优化流程

量化模型

model_config = {
    'model_name': 'mobilenet_v2',
    'model': ir_model_xml,
    'weights': ir_model_bin
}
engine_config = {'device': 'CPU'}
dataset_config = {
    'data_source': DATA_DIR
}
algorithms = [
    {
        'name': 'DefaultQuantization',
        'params': {
            'target_device': 'CPU',
            'preset': 'performance',
            'stat_subset_size': 300
        }
    }
]

# Steps 1-7: Model optimization
# Step 1: Load the model.
model = load_model(model_config)

# Step 2: Initialize the data loader.
data_loader = QDataLoader(dataset_config)

# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric.
metric = Accuracy(top_k=1)

# Step 4: Initialize the engine for metric calculation and statistics collection.
engine = IEEngine(engine_config, data_loader, metric)

# Step 5: Create a pipeline of compression algorithms.
pipeline = create_pipeline(algorithms, engine)

# Step 6: Execute the pipeline.
compressed_model = pipeline.run(model)

# Step 7 (Optional): Compress model weights quantized precision
#                    in order to reduce the size of final .bin file.
compress_model_weights(compressed_model)

# Step 8: Save the compressed model to the desired path.
compressed_model_paths = save_model(model=compressed_model, save_path=MODEL_DIR, model_name="quantized_mobilenet_v2"
)
compressed_model_xml = compressed_model_paths[0]["model"]
compressed_model_bin = Path(compressed_model_paths[0]["model"]).with_suffix(".bin")

4 比较原始模型和量化模型的准确性

# Step 9: Compare accuracy of the original and quantized models.
metric_results = pipeline.evaluate(model)
if metric_results:
    for name, value in metric_results.items():
        print(f"Accuracy of the original model: {name}: {value}")

metric_results = pipeline.evaluate(compressed_model)
if metric_results:
    for name, value in metric_results.items():
        print(f"Accuracy of the optimized model: {name}: {value}")

输出:

Accuracy of the original model: accuracy@top1: 0.9215686274509803
Accuracy of the optimized model: accuracy@top1: 0.921568627450980

5 比较原始模型和量化模型的性能

使用OpenVINO中的Benchmark Tool(推理性能测量工具)测量FP16和INT8模型的推理性能

FP16:

benchmark_app -m .\weights\mobilenet_v2.xml -d CPU -api async

输出:

在这里插入图片描述

INT8:

benchmark_app -m .\weights\quantized_mobilenet_v2.xml -d CPU -api async

输出:

在这里插入图片描述

可以看出吞吐量增大了1.5倍

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/559109.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙Hi3861学习十八-DevEco Device Tool环境搭建

一、简介 在之前的文章中,我们是通过在windows下烧录,在ubuntu下编译的方式进行开发。今天我们同样是采用windowsubuntu混合环境进行开发。为什么要采用这种方式呢?因为就目前而言,大部分的开发板还不支持在Windows环境下进行编译…

典型的高可用设计(二):MySQL

一、高可用模式 MySQL数据库提供了数据库建的复制能力,做到了多个数据库同时拥有同一个数据副本,保证了数据的安全性,一台数据库服务器出现问题,其他数据库可以做到数据不丢失。MySQL的服务高可用设计也是以数据库复制能力为基础&…

云计算专业怎么样,大学应届生学的话难不难?

云计算专业学起来挺难的,一般人建议不要轻易尝试!!! 虽然IT行业一直以来发展前景、技术更新、新领域的开发或者新概念的提出等各方面都还不错,云计算也是当下非常火的一个就业方向,很多人也非常想进入云计…

IT系统方案大纲模版,以智慧工地系统为例

# 咖米智慧工地解决方案 ## 第1章 智慧工地系统概述 ### 1.1应用背景 ### 1.2需求分析 ### 1.3总体目标 ## 第2章 系统总体设计 ### 2.1设计理念 ### 2.2设计依据 ### 2.3设计架构 ### 2.4系统描述 ### 2.5系统特点 ## 第3章 详细设计 ### 3.1工地远程监控子系统 #### 3.1.1需求…

一分钟了解乐观锁、悲观锁、共享锁、排它锁、行锁、表锁以及使用场景

大家好,我是冰点,今天给大家带来,关于MySQL中的锁的使用。 我首先提个问题,大家知道什么是 乐观锁、悲观锁、共享锁,、排它锁、行锁、表锁,以及每种锁的使用场景吗? !! 背景:最近在各…

Unity 使用 VSCode 作为默认编辑器,解决没有代码提示,智能补全功能

文章目录 删除现有编辑器配置选择 VSCode 作为代码编辑器代码补全和智能提示 删除现有编辑器配置 首先打开你的项目文件夹,需要把这几个文件删掉,稍后重新生成~ 选择 VSCode 作为代码编辑器 打开 Edit - Preference: 选择 External Script…

【bsauce读论文】2023-SP-内核Use-After-Cleanup漏洞挖掘与利用

本文参考G.O.S.S.I.P 阅读推荐 2023-01-06 UACatcher做一些补充。 1. UAC漏洞介绍 UAC漏洞介绍:Use-After-Cleanup (UAC)漏洞类似UAF,本文主要检测Linux内核中UAC漏洞。UAC基本原理参见图Fig-1。首先,UAC漏洞和系统中…

众多行业适用的这款Lighthouse Apex Z便携粒子计数器有什么优势

Lighthouse Apex Z粒子计数器围绕易用性和可靠性进行构建。是建立在Lighthouse洁净室行业 40 多年的基于问题的学习基础上的解决方案。 采样设置 ApexZ易于使用的样品设置,可以匹配当前的sop,减少丢失位置或采样错误参数的风险。 用户管理 为了提高效…

ES6:var 、const、let的使用和区别

前言 本文主要介绍了ES6中var、const、let的使用和区别 基本介绍 let let声明变量 const const :声明常量const声明的常量可以修改,但不能重新赋值 如:以下代码是正确的: //引用数据类型 const info {name:Candy }; info.nameJune;而下面的代码是…

GPT-4国内有免费平替吗?

免费/平替永远是最贵的 就如同我们生活中买口红一样,总想找到平替,但永远比不上看中的那只! 但在寻找平替过程中 花出去的时间、金钱成本都是翻倍的。 那么GPT-4呢? GPT-4优于GPT-3.5闪光点,想必大家都十分清楚 不…

基于springboot自动排课系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SpringBoot 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:…

云计算安全

前言 什么是云计算? 云计算就是一种新兴的计算资源利用方式,云计算的服务商通过对硬件资源的虚拟化,将基础IT资源变成了可以自由调度的资源池,从而实现IT资源的按需分配,向客户提供按使用付费的云计算服务。用户可以…

帽子设计作品——蒸汽朋克的乌托邦,机械配件的幻想世界!

蒸汽朋克是由蒸汽steam和朋克punk两个词组成, 蒸汽代表着以蒸汽机作为动力的大型机械,而朋克则代表一种反抗、叛逆的精神。 蒸汽朋克的作品通常以蒸汽时代为背景,通过如新能源、新机械、新材料、新交通工具等新技术,使画面充满想…

基于OpenCV和PyQt5的跳远成果展示程序

基于OpenCV和PyQt5的跳远成果展示程序 近年来,体育运动越来越受到人们的关注,其中跳远是一项备受瞩目的运动项目。为了更好地展示运动员的跳远成果,本文将介绍一种基于OpenCV和PyQt5的跳远成果展示程序实现方法。 本文的跳远成果展示程序主…

基于SSM的校园办公管理系统的设计与实现(源码完整)

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据你想解决的问题,今天给…

【TES641】基于VU13P FPGA的4路FMC接口基带信号处理平台

板卡概述 TES641是一款基于Virtex UltraScale系列FPGA的高性能4路FMC接口基带信号处理平台,该平台采用1片Xilinx的Virtex UltraScale系列FPGA XCVU13P作为信号实时处理单元,该板卡具有4个FMC子卡接口(其中有2个为FMC接口)&#xf…

Sui如何进行独立审计

Sui及其生态项目的第三方审计对于网络的安全至关重要。 类似于Sui这样的L1区块链必须使用多重有效的措施,来确保项目保持尽可能高的安全级别。Sui流程中的一个关键环节就是使用第三方安全审计。了解Sui的安全状态及其维护方式对整个社区来说很重要,因此…

【JavaSE】Java基础语法(二)

文章目录 1. ⛄类型转换1.1 🪂🪂隐式转换1.2 🪂🪂强制转换 2. ⛄运算符2.1 🪂🪂算术运算符2.1.1 算术运算符2.1.2 字符的“”操作2.1.3 字符串的“”操作2.1.4 数值拆分 2.2 🪂🪂自增…

SQL注入 - Part 2

SQL注入 - Part 2 1.sql注入自动化工具--sqlmap配置环境变量/快捷方式一些sqlmap的常用语句前置SQL知识batch批量注入 2.sql注入靶场——sqlilabs3.布尔盲注4.基于时间的盲注5.基于报错的注入总结 1.sql注入自动化工具–sqlmap 配置环境变量/快捷方式 最终效果: …

数据高效转储,生产轻松支撑

在使用WINDOWS或智能手机的时候,经常会遇到存储空间不足的问题,鲜有人会打开文件管理系统自己逐个清理,不仅因为底层的系统文件繁多操作耗时,更有其操作专业度高、风险高的问题。这时我们往往会求助各种各样的清理大师&#xff0c…