推理模块
模型训练完成后,需要单独再写一个推理模块来供用户测试或者使用,该模块可以命名为test.py或者inference.py,导入训练好的模型文件和待测试的图像,输出该图像的分割结果。inference.py主体部分如代码11-7所示。
代码11-7 推理模块部分
# 导入相关库
import numpy as np
import torch
from PIL import Image
# 定义推理函数
def inference(model, test_img):
img = Image.open(test_img)
img = val_transform(img)
img = img.unsqueeze(0).to('cuda')
with torch.no_grad():
outputs = model(img)
preds = outputs.detach().max(dim=1)[1].cpu().numpy()
print(preds.shape)
pred = VOCSegmentation.decode_target(preds[0]).astype(np.uint8)
Image.fromarray(pred).save(os.path.join('s%_pred.png' % test_img.split('.')[0]))
上述代码仅展示推理模块的主体部分,完整代码可参考本书配套的对应章节代码文件。实际执行时,我们可以在命令行通过传入待测试图像和模型文件执行inference.py。测试示例如下:
python inference.py --data_root 2007_000676.jpg --model deeplabv3plus_resnet101
测试图像和模型预测结果示例如图11-5所示。
部署模块
虽然我们可以通过推理模块来测试模型效果,但推理毕竟不是面向用户级的使用体验。为了能够在常见的用户端使用我们的分割模型,还需要对模型进行工程化的部署(deployment)。根据分割模型的应用场景,一般最常见的部署场景是web端部署或者是基于C++的软件集成部署。web端部署一般基于Flask等后端部署框架来完成,形式上可以分为为REST API和web应用两种表现形式。
一个web服务简单而言就是用户从客户端发送一个HTTP请求,然后服务器收到请求后生成HTML文档作为响应返回给客户端的过程。当返回的内容需要在前端页面上呈现时,这个服务就是一个web端的应用;当返回内容不需要在前端页面体现,而是直接以JSON等数据结构给用户时,这个服务就是一个REST API。
Flask是一个基于Python的轻量级web应用框架,非常简洁和灵活,也便于初学者快速上手。简单几行代码即可快速定义一个web服务,如代码11-8所示。
# 导入flask相关模块
from flask import Flask, jsonify
# 创建应用
app = Flask(__name__)
# 定义预测路由
@app.route('/predict', methods=['POST'])
def predict():
return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
本节我们将分别展示基于PASCAL VOC 2012训练的Deeplab v3+模型的REST API和web应用部署方式。
REST API部署
基于REST API部署相对较为简单,我们可以直接编写一个api.py的文件,将推理流程融入到Flask的预测路由函数中即可。在此之前需要先导入训练好的模型以及定义跟验证时同样的数据转换方法。基于上述策略可定义api.py如下:
代码11-9 REST API部署
# 导入相关库
import torch
from torchvision import transforms
from PIL import Image
import io
import numpy as np
from utils import ext_transforms_new as et
from datasets import VOCSegmentation
from flask import Flask, request, jsonify
import models
# 创建应用
app = Flask(__name__)
# 模型字典
model_map = {
'deeplabv3plus_resnet50': models.deeplabv3plus_resnet50,
'deeplabv3plus_resnet101': models.deeplabv3plus_resnet101,
}
# 创建模型
model = model_map['deeplabv3plus_resnet101'](num_classes=21,
output_stride=16)
# 导入模型
model.load_state_dict(torch.load('../checkpoints/deeplabv3plus_resnet101_voc.pth')['model_state'])
model.to('cuda')
model.eval()
print('model loaded.')
# 定义数据转换方法
transform = et.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# 定义模型预测路由
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
# 从请求中读取输入图像
image = request.files['image'].read()
image = Image.open(io.BytesIO(image))
# 图像变换
input_tensor = transform(image).unsqueeze(0).to('cuda')
# 模型预测
with torch.no_grad():
output = model(input_tensor)
preds = output.detach().max(dim=1)[1].cpu().numpy()
print(preds.shape)
# 对输出进行解码,转换为mask
preds = VOCSegmentation.decode_target(preds[0]).astype(np.uint8)
# 转换成list
result = preds.tolist()
return jsonify(result)
if __name__ == '__main__':
app.run(debug=True)
定义好app.py以后,直接在命令行启动该REST API服务:
python app.py
然后再单独启动一个Python终端,通过requests库发起post请求,传入一张待分割图像:
resp = requests.post("http://localhost:5000/predict", files={"image": open('./deployment/2007_000676.jpg', 'rb')})
这时候我们可以在服务端看到相关响应信息,如图11-6所示。状态码显示为200,说明请求成功,返回数据可以在requests返回对象中查看。
web端部署
REST API的部署方式更多的是方便开发者使用,对于普通用户可能不是那么友好。为了更加方便用户使用和更直观的展示模型效果,我们可以通过web端部署的方式,让用户上传图像作为输入,并将输入图像和分割结果直接在网页上显示。所以与API部署方式不同的是需要加上一个index.html的网页模板文件,将输入和分割结果在网页模板上进行渲染。同时原先的api.py文件也需要进行修改,修改后的文件可命名为app.py,主体部分如代码11-10所示。
代码11-10 web端应用app.py
# 创建应用
app = Flask(__name__)
# 定义上传和预测路由
@app.route('/', methods=['GET', 'POST'])
def upload_predict():
# POST请求后读取图像
if request.method == 'POST':
image_file = request.files['image']
if image_file:
image_location = os.path.join(
app.config['UPLOAD_FOLDER'],
image_file.filename
)
image_file.save(image_location)
# 图像变换
image = Image.open(image_location).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to('cuda')
# 模型预测
with torch.no_grad():
output = model(input_tensor)
preds = output.detach().max(dim=1)[1].cpu().numpy()
print(preds.shape)
# mask解码
preds = VOCSegmentation.decode_target(preds[0]).astype(np.uint8)
# 保存图像到指定路径
segmented_image = Image.fromarray(preds)
segmented_image_path = image_location.replace('.jpg', '_segmented.jpg')
segmented_image.save(segmented_image_path)
display_input_path = '../' + image_location
display_segmented_path = '../' + segmented_image_path
# 渲染结果到网页
return render_template('index.html', input_image=display_input_path, segmented_image=display_segmented_path)
return render_template('index.html')
代码11-10与api.py的主要区别在于读取图像部分是需要读取用户上传到指定目录下的图像,并且对输入图像和分割结果渲染呈现到网页端。index.html是网页HTML的模板文件,我们可以通过编辑该文件来实现自己想要的网页效果。
执行app.py文件启动web服务,然后打开服务运行地址:http:127.0.0.1:5000即可看到网页端效果,在网页点击“选择文件”上传输入图像,然后点击“Segment”执行模型分割,图11-7为web部署后的使用效果图。
总结
本章以PASCAL VOC 2012数据集和Deeplab v3+分割模型为例,给出了基于PyTorch的深度学习图像分割项目代码框架。一个相对完整的图像分割代码框架应包含:预处理模块、数据导入模块、模型模块、工具函数模块、配置模块、主函数模块、推理模块和部署模块。启中预处理、数据导入、模型、工具函数、配置和主函数模块均为模型训练阶段的工作模块,而推理和部署则属于模型训练完后的测试和使用阶段工作模块。
需要特别说明的是,本章的代码框架仅作为深度学习图像分割项目的一般性框架,具体使用时应根据项目的实际情况酌情参考。
后续全书内容和代码将在github上开源,请关注仓库:
https://github.com/luwill/Deep-Learning-Image-Segmentation
(本章完结,其余章节待续)