相关代码链接见文末
1.所需基本环境配置
首先,我们需要一个预先训练好的模型以及相应的配置。接下来,为了实际应用这个模型,我们必须搭建一个功能强大的服务器。这台服务器的核心任务是加载我们的模型,并能够接收用户上传的图片。一旦图片被接收,服务器将使用加载的模型进行预测,并迅速将预测结果返回给用户。这样,整个预测流程就能在服务器上高效、准确地完成。
2.模型加载与数据预处理
在run_pytorch_server.py中定义了模型加载和数据集预处理模块,流程如下:
(1)首先,初始化Flask app
app = flask.Flask(__name__)
model = None
use_gpu = False
(2)加载标签信息,标签信息为字典信息,将预测结果对应到实际的类别
# 加载标签信息
with open('imagenet_class.txt', 'r') as f:
idx2label = eval(f.read())
(3)加载模型,这里加载的resnet50模型
# 加载模型进来
def load_model():
"""Load the pre-trained model, you can use your model just as easily.
"""
global model
model = resnet50(pretrained=True)
model.eval()
if use_gpu:
model.cuda()
(4)数据预处理模块
数据预处理包括对图像进行resize,转化为tensor,对图像进行标准化。
# 数据预处理
def prepare_image(image, target_size):
"""Do image preprocessing before prediction on any data.
:param image: original image
:param target_size: target image size
:return:
preprocessed image
"""
# pytorch输入的是RGB格式
if image.mode != 'RGB':
image = image.convert("RGB")
# 将图像resize特定大小并转化为tensor格式
# Resize the input image nad preprocess it.
image = T.Resize(target_size)(image)
image = T.ToTensor()(image)
# 标准化
# Convert to Torch.Tensor and normalize. mean与std
image = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
# Add batch_size axis.
image = image[None]
if use_gpu:
image = image.cuda()
return Variable(image, volatile=True) #不需要求导
(5)开启服务
然后是开启服务,实现数据输入、数据预处理、模型预测、返回整个标签的整个流程。
# 开启服务
@app.route("/predict", methods=["POST"])
def predict():
# Initialize the data dictionary that will be returned from the view.
data = {"success": False}
# Ensure an image was properly uploaded to our endpoint.
if flask.request.method == 'POST':
if flask.request.files.get("image"):
# Read the image in PIL format
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image)) # 读取数据,二进制数据
# Preprocess the image and prepare it for classification. 数据预处理
image = prepare_image(image, target_size=(224, 224))
# Classify the input image and then initialize the list of predictions to return to the client.
preds = F.softmax(model(image), dim=1) # 预测概率
results = torch.topk(preds.cpu().data, k=3, dim=1) # 返回概率最高的前k个
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
data['predictions'] = list()
# 返回最终的标签
# Loop over the results and add them to the list of returned predictions
for prob, label in zip(results[0][0], results[1][0]):
label_name = idx2label[label]
r = {"label": label_name, "probability": float(prob)}
data['predictions'].append(r)
# Indicate that the request was a success.
data["success"] = True
# Return the data dictionary as a JSON response.
return flask.jsonify(data)
3.预测效果展示
首先,使用命令行,启动服务
simple_request.py定义了post请求及返回结果,执行simple_request.py时,需要指定文件路径。
返回结果
访问记录
链接:https://pan.baidu.com/s/12nhoFcZWLD1_ticGprawUg?pwd=iujk
提取码:iujk