深度学习 之 模型部署 使用Flask和PyTorch构建图像分类Web服务

news2024/10/24 7:55:42

引言

随着深度学习的发展,图像分类已成为一项基础的技术,被广泛应用于各种场景之中。本文将介绍如何使用Flask框架和PyTorch库来构建一个简单的图像分类Web服务。通过这个服务,用户可以通过HTTP POST请求上传花朵图片,然后由后端的深度学习模型对其进行分类,并返回分类结果。

环境搭建

首先,确保安装了以下Python库:

  • Flask:用于构建Web应用。
  • PyTorch:用于加载和运行深度学习模型。
  • torchvision:用于图像处理和加载预训练模型。
  • PIL:用于图像处理。

1. 初始化Flask应用

import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models

# 初始化Flask app
app = flask.Flask(__name__)# 创建一个新的Flask应用程序实例
# __name__参数通常被传递给FasK应用程序来定位应用程序的根路径,这样FlasK就可以知道在哪里找到模板、静态文件等。
# 总体来说app = flask.Flask(__name_)是FLaSK应用程序的起点。它初始化了一个新的FLaSK应用程序实例。为后续添加路由、配置等莫定

2. 加载模型

为了方便,我们将预训练好的ResNet18模型,保存在一个名为best.pth的检查点文件中。我们将加载这个模型,并准备好用于推理。

def load_model():
    """Load the pre-trained model, you can use your model just as easily."""
    global model
    # 加载resnet18网络。ResNet(残差网络)是一种深度学习架构,设计用于解决深层神经网络中的梯度消失问题。
    model = models.resnet18()
    # num_ftrs 被赋值为模型全连接层(fc)的输入特征数量。
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 类别数自己根据自己任务来

    # print(model)
    #导入最优模型
    #这行代码实际上是加载了一个预先训练好的模型的权重。
    # torch.load('best.pth') 会加载保存在 best.pth 文件中的模型检查点,
    # 通常这个检查点包含模型的状态字典(state dict),即模型所有层的权重和偏置。
    # model.load_state_dict(checkpoint['state_dict']) 会将加载的状态字典应用到我们的模型上,使模型具有之前训练时学到的参数。
    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    # 将模型指定为测试格式
    model.eval()
    # 是否使用gpu
    if use_gpu:
        model.cuda()

3. 预处理图像

为了使图像符合模型的要求,我们需要对其进行预处理,包括调整大小、转换为张量以及标准化。

def prepare_image(image, target_size):
    # 检查输入图像的颜色模式是否为 RGB。如果不是,则将其转换为 RGB 模式。
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
    # 使用 transforms.Resize 对象将图像调整为目标尺寸 target_size。
    image = transforms.Resize(target_size)(image)
    # 使用 transforms.ToTensor() 将图像转换为 PyTorch 的 Tensor 类型。
    image = transforms.ToTensor()(image)

    # Convert to Torch, Tensor and normalize. mean与std
    # 对图像张量进行标准化处理。
    # 标准化的参数 [0.485, 0.456, 0.406] 是均值,代表每个颜色通道(红、绿、蓝)的平均值;
    # [0.229, 0.224, 0.225] 是标准差,代表每个颜色通道的标准差。
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
    # Add batch_size axis 增加一个维度,用于按batch测试本次这里一次测试一张
    image = image[None]
    if use_gpu:
        image = image.cuda()  # return torch.tensor(image
    return torch.tensor(image)

4. 设置路由和处理请求

使用Flask设置路由,并处理POST请求中的图像数据。

# 定义了一个名为 predict 的视图函数,并通过装饰器 @app.route 绑定了路由 /predict,允许该路由接收 HTTP POST 请求。
@app.route("/predict", methods=["POST"])
def predict():
    # 做一个标志,刚开始无图像传入时为false,传入图像时为true
    data = {"success": False}
    if flask.request.method == 'POST':  # 检查请求的方法是否为 POST
        if flask.request.files.get("image"):  # 判断是否为图像
            image = flask.request.files["image"].read()  # 将收到的图像进行读取,内容为二进制
            image = Image.open(io.BytesIO(image)) # 将这个二进制字符串转换为一个 PIL 图像对象。

            # 利用上面的预处理函数将读入的图像进行预处理
            image = prepare_image(image, target_size=(224, 224))

            # 将预处理后的图像输入到模型中,并得到一个未归一化的输出向量。
            # 使用 F.softmax 函数将这个输出向量转换为概率分布,这表示模型对于每个类别的预测概率。
            preds = F.softmax(model(image), dim=1)  # 得到各个类别的概率
            # cpu().data 确保结果在 CPU 上,并且不包含梯度信息。dim=1 表示沿着列方向查找最大值。
            results = torch.topk(preds.cpu().data, k=3, dim=1)  # 概率最大的前3个结果# torch.topk用于返回输入张量中每行最大的k个元素及其对应的索引
            # 将结果从 PyTorch 张量转换为 NumPy 数组,以便更容易地处理。results[0] 包含了概率值,而 results[1] 包含了类别索引。
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())
            # 将data字典增加一个key,value,其中value为ist格式
            data['predictions'] = list()
            for probability, label in zip(results[0][0], results[1][0]):
                # Label name =idx2labellstr(label)]
                r = {"label": str(label), "probability": float(probability)}
                # 将预测结果添加至data字典
                data['predictions'].append(r)
    # Indicate that the reguest was a success.
    data["success"] = True


    return flask.jsonify(data)  # 将最后结果以json格式文件传出,并返回给客户端。

5. 启动服务

最后,在主入口处启动Flask服务,并加载模型。

if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    load_model() #加载模型
    app.run(host='192.168.24.45', port=5012) #启动服务器,IP地址,端口

我们点击运行即可启动服务器,保持程序运行客户端即可通过ip地址和端口访问

接口客户端实现

在上一部分中,我们完成了基于Flask和PyTorch的图像分类Web服务的搭建。接下来,我们将继续探讨如何编写客户端代码来与该服务进行交互。通过编写一个简单的Python脚本来发送HTTP请求,我们可以测试我们的Web服务是否正常工作。

客户端代码实现

为了测试我们的图像分类服务,我们需要编写一段代码来模拟客户端的行为。这段代码将负责向服务端发送包含图像的POST请求,并接收返回的分类结果。

import requests

flask_url = 'http://192.168.24.45:5012/predict'

# 定义一个名为 predict_result 的函数,该函数接受一个参数 image_path,表示要发送给 Flask 应用的图像文件的路径。
def predict_result(image_path):
    # 使用 open 函数以二进制模式 ('rb') 打开图像文件,并读取其内容。
    image = open(image_path, 'rb').read()
    # 将图像内容包装到一个字典 payload 中,键为 'image',值为图像的二进制内容。
    payload = {'image': image}
    # 使用 requests.post 方法发送一个 POST 请求到 Flask 应用,其中 files 参数用于上传文件。
    # files=payload 表示将 payload 字典中的内容作为文件上传。
    r = requests.post(flask_url, files=payload).json()  # .json() 方法将响应内容解析为 Python 字典形式,方便后续处理。
    if r['success']:  # 检查响应中的 success 键是否为 True。如果为 True,则意味着请求成功,并且会打印出预测结果。
        for (i, result) in enumerate(r['predictions']): print(
            '{}.预测类别为{}:的概率:{}'.format(i + 1, result['label'], result['probability']))
        print('OK')  # 预测结果存储在 r['predictions'] 列表中,每个预测结果都是一个字典,包含类别标签 ("label") 和概率 ("probability")。
    else:  # 失败打印
        print('Request failed')
if __name__ == '__main__':
    predict_result('../data/6/image_07162.jpg')

预测图像

本次实验随机采用一张花的图片上传到到服务端


预测结果

客户端访问记录

当我们通过客户端访问服务端时,可通过后台查看访问记录

总结

通过以上步骤,我们构建了一个简单的图像分类Web服务。用户可以通过发送POST请求并将图像作为附件上传,然后服务端会对图像进行分类,并返回最有可能的三个类别及其概率。这种服务可以用于各种场合,如在线图像识别、产品分类等。

希望这篇文章能帮助你了解如何使用Flask和PyTorch快速搭建一个图像分类的服务,并激发你在实际项目中的应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2222224.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Qt】QTableView添加下拉框过滤条件

实现通过带复选框的下拉框来为表格添加过滤条件 带复选框的下拉框 .h文件 #pragma once #include <QCheckBox> #include <QComboBox> #include <QEvent> #include <QLineEdit> #include <QListWidget>class TableComboBox : public QComboBox …

前端构建工具vite的优势

1. 极速冷启动 Vite 使用原生 ES 模块 (ESM) 在开发环境下进行工作。相比于传统构建工具需要打包所有的文件&#xff0c;Vite 只在浏览器请求模块时动态加载所需的文件。无打包冷启动&#xff1a;无需预先打包&#xff0c;项目启动非常快&#xff0c;尤其对于大型项目效果更明…

顺序表(一)(数据结构)

一. 线性表 线性表&#xff08;linear list&#xff09;是n个具有相同特性的数据元素的有限序列 。 线性表是一种在实际中广泛使用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串... 线性表在逻辑上是线性结构&#xff0c;是人为想象出来的数…

HCIP--1

同一区域内的OSPF路由器拥有一致的 LSDB, 在区域内&#xff0c;OSPF 采用 SPF算法计算路由一个区域太多路由器&#xff0c;硬件资源跟不上&#xff0c;所以多划分区域 OSPF 路由计算原理 1. 区域内路由计算 LSA 在OSPF中&#xff0c;每个路由器生成 LSA&#xff0c;用于告诉…

【部署篇】RabbitMq-03集群模式部署

一、准备主机 准备3台主机用于rabbitmq部署&#xff0c;文章中是在centos7上安装部署rabbitmq3.8通过文章中介绍的方式可以同样在centos8、centos9上部署&#xff0c;只需下载对应的版本进行相同的操作。 主机IP角色说明192.168.128.31种子节点192.168.128.32普通节点192.16…

Matlab学习01-矩阵

目录 一&#xff0c;矩阵的创建 1&#xff0c;直接输入法创建矩阵 2&#xff0c;利用M文件创建矩阵 3&#xff0c;利用其它文本编辑器创建矩阵 二&#xff0c;矩阵的拼接 1&#xff0c;基本拼接 1&#xff09; 水平方向的拼接 2&#xff09;垂直方向的拼接 3&#xf…

Linux系统基础-进程间通信(5)_模拟实现命名管道和共享内存

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 Linux系统基础-进程间通信(5)_模拟实现命名管道和共享内存 收录于专栏[Linux学习] 本专栏旨在分享学习Linux的一点学习笔记&#xff0c;欢迎大家在评论区交流讨…

LeetCode--删除并获得点数--动态规划

一、题目解析 二、算法原理 根据题意&#xff0c;在选择了元素 x 后&#xff0c;该元素以及所有等于 x−1 或 x1 的元素会从数组中删去。若还有多个值为 x 的元素&#xff0c;由于所有等于 x−1 或 x1 的元素已经被删除&#xff0c;我们可以直接删除 x 并获得其点数。因此若选…

Win10+MinGW13.1.0编译Qt5.15.15

安装windows SDK、python、ruby、cmake、Perl[可选]安装MySQL解压qt-everywhere-opensource-src-5.15.15.zip&#xff08;注&#xff1a;不要使用qt-everywhere-opensource-src-5.15.15.tar.xz&#xff09;修改源代码 E:\qt-everywhere-src-5.15.15\qtbase\src\3rdparty\angle\…

hive数据库,表操作

1.创建; create database if not exists myhive; use myhive; 2.查看: 查看数据库详细信息:desc database myhive; 默认数据库的存放路径是 HDFS 的&#xff1a; /user/hive/warehouse 内 补充:创建数据库并指定 hdfs 存储位置:create database myhive2 location /myhive2 3.…

<项目代码>YOLOv8路面垃圾识别<目标检测>

YOLOv8是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为一个回归问题&#xff0c;能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法&#xff08;如Faster R-CNN&#xff09;&#xff0c;YOLOv8具有更高的…

Codeforces Round 881 (Div. 3)(A~F1题解)

这场比赛可能是因为比较老吧&#xff0c;我感觉很轻松&#xff0c;就是第五个卡了一下&#xff0c;看错题了&#xff0c;原本应该是严格大于的&#xff0c;从而导致代码一直出现bug&#xff0c;在1小时20分钟才解决 A. Sasha and Array Coloring 题意&#xff1a;就是说给你n个…

提权 | Windows系统

文章目录 cmd提权meterpreter提权getsystemsteal_tokenmigrate 令牌窃取(MS16-075)烂土豆提权步骤烂土豆提权原理 sc命令提权抓本地密码提权其他工具pr工具 内核提权WindowsVulScan cmd提权 前言&#xff1a;我们getshell一个用windows部署的网站后&#xff0c;通过蚁剑或者其…

ESP32 S3 语音识别 语音唤醒程序流程

ESP32 S3 语音识别 语音唤醒程序流程 参考例程首先进行esp_periph_set_init 初始化之后执行setup_player&#xff0c;之后执行start_recorder&#xff0c;识别的主处理voice_read_task 参考例程 D:\Espressif\esp-adf\examples\speech_recognition\wwe\ 首先进行esp_periph_se…

零知识学习WLAN漫游二、无线漫游介绍(2)

接前一篇文章&#xff1a;零知识学习WLAN漫游一、无线漫游介绍&#xff08;1&#xff09; 本文内容参考&#xff1a; WLAN漫游简介_漫游主动性-CSDN博客 无线漫游_百度百科 无线漫游简述-CSDN博客 特此致谢&#xff01; 一、WLAN漫游简介 3. 漫游协议和快速漫游协议 802.…

算法的学习笔记—数字在排序数组中出现的次数(牛客JZ53)

&#x1f600;前言 在编程中&#xff0c;查找有序数组中特定元素的出现次数是一个常见的问题。本文将详细讲解这个问题的解决方案&#xff0c;并通过二分查找法优化效率。 &#x1f3e0;个人主页&#xff1a;尘觉主页 文章目录 &#x1f970;数字在排序数组中出现的次数&#x…

九、pico+Unity交互开发——触碰抓取

一、VR交互的类型 Hover&#xff08;悬停&#xff09; 定义&#xff1a;发起交互的对象停留在可交互对象的交互区域。例如&#xff0c;当手触摸到物品表面&#xff08;可交互区域&#xff09;时&#xff0c;视为触发了Hover。 Grab&#xff08;抓取&#xff09; 概念&#xff…

深入浅出:深度学习模型部署全流程详解

博主简介&#xff1a;努力学习的22级计算机科学与技术本科生一枚&#x1f338;博主主页&#xff1a; Yaoyao2024往期回顾&#xff1a; 【论文精读】PSAD&#xff1a;小样本部件分割揭示工业异常检测的合成逻辑每日一言&#x1f33c;: 生活要有所期待&#xff0c; 否则就如同罩在…

【国潮来袭】华为原生鸿蒙 HarmonyOS NEXT(5.0)正式发布:鸿蒙诞生以来最大升级,碰一碰、小艺圈选重磅上线

在昨日晚间的原生鸿蒙之夜暨华为全场景新品发布会上&#xff0c;华为原生鸿蒙 HarmonyOS NEXT&#xff08;5.0&#xff09;正式发布。 华为官方透露&#xff0c;截至目前&#xff0c;鸿蒙操作系统在中国市场份额占据 Top2 的领先地位&#xff0c;拥有超过 1.1 亿 的代码行和 6…

想让前后端交互更轻松?alovajs了解一下?

作为一个前端开发者&#xff0c;我最近发现了一个超赞的请求库 alovajs&#xff0c;它真的让我眼前一亮&#xff01;说实话&#xff0c;我感觉自己找到了前端开发的新大陆。大家知道&#xff0c;在前端开发中&#xff0c;处理 Client-Server 交互一直是个老大难的问题&#xff…