引言
随着深度学习的发展,图像分类已成为一项基础的技术,被广泛应用于各种场景之中。本文将介绍如何使用Flask框架和PyTorch库来构建一个简单的图像分类Web服务。通过这个服务,用户可以通过HTTP POST请求上传花朵图片,然后由后端的深度学习模型对其进行分类,并返回分类结果。
环境搭建
首先,确保安装了以下Python库:
- Flask:用于构建Web应用。
- PyTorch:用于加载和运行深度学习模型。
- torchvision:用于图像处理和加载预训练模型。
- PIL:用于图像处理。
1. 初始化Flask应用
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
# 初始化Flask app
app = flask.Flask(__name__)# 创建一个新的Flask应用程序实例
# __name__参数通常被传递给FasK应用程序来定位应用程序的根路径,这样FlasK就可以知道在哪里找到模板、静态文件等。
# 总体来说app = flask.Flask(__name_)是FLaSK应用程序的起点。它初始化了一个新的FLaSK应用程序实例。为后续添加路由、配置等莫定
2. 加载模型
为了方便,我们将预训练好的ResNet18模型,保存在一个名为best.pth
的检查点文件中。我们将加载这个模型,并准备好用于推理。
def load_model():
"""Load the pre-trained model, you can use your model just as easily."""
global model
# 加载resnet18网络。ResNet(残差网络)是一种深度学习架构,设计用于解决深层神经网络中的梯度消失问题。
model = models.resnet18()
# num_ftrs 被赋值为模型全连接层(fc)的输入特征数量。
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 类别数自己根据自己任务来
# print(model)
#导入最优模型
#这行代码实际上是加载了一个预先训练好的模型的权重。
# torch.load('best.pth') 会加载保存在 best.pth 文件中的模型检查点,
# 通常这个检查点包含模型的状态字典(state dict),即模型所有层的权重和偏置。
# model.load_state_dict(checkpoint['state_dict']) 会将加载的状态字典应用到我们的模型上,使模型具有之前训练时学到的参数。
checkpoint = torch.load('best.pth')
model.load_state_dict(checkpoint['state_dict'])
# 将模型指定为测试格式
model.eval()
# 是否使用gpu
if use_gpu:
model.cuda()
3. 预处理图像
为了使图像符合模型的要求,我们需要对其进行预处理,包括调整大小、转换为张量以及标准化。
def prepare_image(image, target_size):
# 检查输入图像的颜色模式是否为 RGB。如果不是,则将其转换为 RGB 模式。
if image.mode != 'RGB':
image = image.convert('RGB')
# Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
# 使用 transforms.Resize 对象将图像调整为目标尺寸 target_size。
image = transforms.Resize(target_size)(image)
# 使用 transforms.ToTensor() 将图像转换为 PyTorch 的 Tensor 类型。
image = transforms.ToTensor()(image)
# Convert to Torch, Tensor and normalize. mean与std
# 对图像张量进行标准化处理。
# 标准化的参数 [0.485, 0.456, 0.406] 是均值,代表每个颜色通道(红、绿、蓝)的平均值;
# [0.229, 0.224, 0.225] 是标准差,代表每个颜色通道的标准差。
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
# Add batch_size axis 增加一个维度,用于按batch测试本次这里一次测试一张
image = image[None]
if use_gpu:
image = image.cuda() # return torch.tensor(image
return torch.tensor(image)
4. 设置路由和处理请求
使用Flask设置路由,并处理POST请求中的图像数据。
# 定义了一个名为 predict 的视图函数,并通过装饰器 @app.route 绑定了路由 /predict,允许该路由接收 HTTP POST 请求。
@app.route("/predict", methods=["POST"])
def predict():
# 做一个标志,刚开始无图像传入时为false,传入图像时为true
data = {"success": False}
if flask.request.method == 'POST': # 检查请求的方法是否为 POST
if flask.request.files.get("image"): # 判断是否为图像
image = flask.request.files["image"].read() # 将收到的图像进行读取,内容为二进制
image = Image.open(io.BytesIO(image)) # 将这个二进制字符串转换为一个 PIL 图像对象。
# 利用上面的预处理函数将读入的图像进行预处理
image = prepare_image(image, target_size=(224, 224))
# 将预处理后的图像输入到模型中,并得到一个未归一化的输出向量。
# 使用 F.softmax 函数将这个输出向量转换为概率分布,这表示模型对于每个类别的预测概率。
preds = F.softmax(model(image), dim=1) # 得到各个类别的概率
# cpu().data 确保结果在 CPU 上,并且不包含梯度信息。dim=1 表示沿着列方向查找最大值。
results = torch.topk(preds.cpu().data, k=3, dim=1) # 概率最大的前3个结果# torch.topk用于返回输入张量中每行最大的k个元素及其对应的索引
# 将结果从 PyTorch 张量转换为 NumPy 数组,以便更容易地处理。results[0] 包含了概率值,而 results[1] 包含了类别索引。
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
# 将data字典增加一个key,value,其中value为ist格式
data['predictions'] = list()
for probability, label in zip(results[0][0], results[1][0]):
# Label name =idx2labellstr(label)]
r = {"label": str(label), "probability": float(probability)}
# 将预测结果添加至data字典
data['predictions'].append(r)
# Indicate that the reguest was a success.
data["success"] = True
return flask.jsonify(data) # 将最后结果以json格式文件传出,并返回给客户端。
5. 启动服务
最后,在主入口处启动Flask服务,并加载模型。
if __name__ == '__main__':
print("Loading PyTorch model and Flask starting server ...")
print("Please wait until server has fully started")
load_model() #加载模型
app.run(host='192.168.24.45', port=5012) #启动服务器,IP地址,端口
我们点击运行即可启动服务器,保持程序运行客户端即可通过ip地址和端口访问
接口客户端实现
在上一部分中,我们完成了基于Flask和PyTorch的图像分类Web服务的搭建。接下来,我们将继续探讨如何编写客户端代码来与该服务进行交互。通过编写一个简单的Python脚本来发送HTTP请求,我们可以测试我们的Web服务是否正常工作。
客户端代码实现
为了测试我们的图像分类服务,我们需要编写一段代码来模拟客户端的行为。这段代码将负责向服务端发送包含图像的POST请求,并接收返回的分类结果。
import requests
flask_url = 'http://192.168.24.45:5012/predict'
# 定义一个名为 predict_result 的函数,该函数接受一个参数 image_path,表示要发送给 Flask 应用的图像文件的路径。
def predict_result(image_path):
# 使用 open 函数以二进制模式 ('rb') 打开图像文件,并读取其内容。
image = open(image_path, 'rb').read()
# 将图像内容包装到一个字典 payload 中,键为 'image',值为图像的二进制内容。
payload = {'image': image}
# 使用 requests.post 方法发送一个 POST 请求到 Flask 应用,其中 files 参数用于上传文件。
# files=payload 表示将 payload 字典中的内容作为文件上传。
r = requests.post(flask_url, files=payload).json() # .json() 方法将响应内容解析为 Python 字典形式,方便后续处理。
if r['success']: # 检查响应中的 success 键是否为 True。如果为 True,则意味着请求成功,并且会打印出预测结果。
for (i, result) in enumerate(r['predictions']): print(
'{}.预测类别为{}:的概率:{}'.format(i + 1, result['label'], result['probability']))
print('OK') # 预测结果存储在 r['predictions'] 列表中,每个预测结果都是一个字典,包含类别标签 ("label") 和概率 ("probability")。
else: # 失败打印
print('Request failed')
if __name__ == '__main__':
predict_result('../data/6/image_07162.jpg')
预测图像
本次实验随机采用一张花的图片上传到到服务端
预测结果
客户端访问记录
当我们通过客户端访问服务端时,可通过后台查看访问记录
总结
通过以上步骤,我们构建了一个简单的图像分类Web服务。用户可以通过发送POST请求并将图像作为附件上传,然后服务端会对图像进行分类,并返回最有可能的三个类别及其概率。这种服务可以用于各种场合,如在线图像识别、产品分类等。
希望这篇文章能帮助你了解如何使用Flask和PyTorch快速搭建一个图像分类的服务,并激发你在实际项目中的应用。