Skip to content

通过带有 Flask 的 REST API 在 Python 中部署 PyTorch

原文: https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html

注意

单击此处的下载完整的示例代码

作者Avinash Sajjanshetty

在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开用于模型推理的 REST API。 特别是,我们将部署预训练的 DenseNet 121 模型来检测图像。

小贴士

此处使用的所有代码均以 MIT 许可发布,可在 Github 上找到。

这是在生产中部署 PyTorch 模型的系列教程中的第一篇。 到目前为止,以这种方式使用 Flask 是开始为 PyTorch 模型提供服务的最简单方法,但不适用于具有高性能要求的用例。 为了那个原因:

API 定义

我们将首先定义 API 端点,请求和响应类型。 我们的 API 端点将位于/predict,该端点通过包含图片的file参数接受 HTTP POST 请求。 响应将是包含预测的 JSON 响应:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

依存关系

通过运行以下命令来安装所需的依赖项:

$ pip install Flask==1.0.3 torchvision-0.3.0

简单的 Web 服务器

以下是一个简单的网络服务器,摘自 Flask 的文档

from flask import Flask
app = Flask(__name__)

@app.route('/')
def hello():
    return 'Hello World!'

将以上代码段保存在名为app.py的文件中,您现在可以通过输入以下内容来运行 Flask 开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当您在网络浏览器中访问http://localhost:5000/时,将看到Hello World!文字

我们将对上面的代码段进行一些更改,以使其适合我们的 API 定义。 首先,我们将方法重命名为predict。 我们将端点路径更新为/predict。 由于图像文件将通过 HTTP POST 请求发送,因此我们将对其进行更新,使其也仅接受 POST 请求:

@app.route('/predict', methods=['POST'])
def predict():
    return 'Hello World!'

我们还将更改响应类型,以使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。 更新后的app.py文件现在为:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在下一部分中,我们将重点介绍编写推理代码。 这将涉及两部分,第一部分是准备图像,以便可以将其馈送到 DenseNet;第二部分,我们将编写代码以从模型中获取实际的预测。

准备图像

DenseNet 模型要求图像为尺寸为 224 x 224 的 3 通道 RGB 图像。我们还将使用所需的均值和标准偏差值对图像张量进行归一化。 您可以在上阅读有关它的更多信息。

我们将使用torchvision库中的transforms并建立一个转换管道,该转换管道可根据需要转换图像。 您可以在上阅读有关变换的更多信息。

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。 要测试上述方法,请以字节模式读取图像文件(首先将 <cite>../_static/img/sample_file.jpeg</cite> 替换为计算机上文件的实际路径),然后查看是否获得张量 背部:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

出:

tensor([[[[ 0.4508,  0.4166,  0.3994,  ..., -1.3473, -1.3302, -1.3473],
          [ 0.5364,  0.4851,  0.4508,  ..., -1.2959, -1.3130, -1.3302],
          [ 0.7077,  0.6392,  0.6049,  ..., -1.2959, -1.3302, -1.3644],
          ...,
          [ 1.3755,  1.3927,  1.4098,  ...,  1.1700,  1.3584,  1.6667],
          [ 1.8893,  1.7694,  1.4440,  ...,  1.2899,  1.4783,  1.5468],
          [ 1.6324,  1.8379,  1.8379,  ...,  1.4783,  1.7352,  1.4612]],

         [[ 0.5728,  0.5378,  0.5203,  ..., -1.3704, -1.3529, -1.3529],
          [ 0.6604,  0.6078,  0.5728,  ..., -1.3004, -1.3179, -1.3354],
          [ 0.8529,  0.7654,  0.7304,  ..., -1.3004, -1.3354, -1.3704],
          ...,
          [ 1.4657,  1.4657,  1.4832,  ...,  1.3256,  1.5357,  1.8508],
          [ 2.0084,  1.8683,  1.5182,  ...,  1.4657,  1.6583,  1.7283],
          [ 1.7458,  1.9384,  1.9209,  ...,  1.6583,  1.9209,  1.6408]],

         [[ 0.7228,  0.6879,  0.6531,  ..., -1.6476, -1.6302, -1.6476],
          [ 0.8099,  0.7576,  0.7228,  ..., -1.6476, -1.6476, -1.6650],
          [ 1.0017,  0.9145,  0.8797,  ..., -1.6476, -1.6650, -1.6999],
          ...,
          [ 1.6291,  1.6291,  1.6465,  ...,  1.6291,  1.8208,  2.1346],
          [ 2.1868,  2.0300,  1.6814,  ...,  1.7685,  1.9428,  2.0125],
          [ 1.9254,  2.0997,  2.0823,  ...,  1.9428,  2.2043,  1.9080]]]])

预测

现在将使用预训练的 DenseNet 121 模型来预测图像类别。 我们将使用torchvision库中的一个,加载模型并进行推断。 在此示例中,我们将使用预训练的模型,但您可以对自己的模型使用相同的方法。 在此教程中查看有关加载模型的更多信息。

from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

张量y_hat将包含预测的类 ID 的索引。 但是,我们需要一个人类可读的类名。 为此,我们需要一个类 ID 来进行名称映射。 将这个文件下载为imagenet_class_index.json,并记住它的保存位置(或者,如果您按照本教程中的确切步骤操作,请将其保存在 <cite>tutorials / _static</cite> 中)。 该文件包含 ImageNet 类 ID 到 ImageNet 类名称的映射。 我们将加载此 JSON 文件并获取预测索引的类名称。

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用imagenet_class_index字典之前,首先我们将张量值转换为字符串值,因为imagenet_class_index字典中的键是字符串。 我们将测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

Out:

['n02124075', 'Egyptian_cat']

您应该得到如下响应:

['n02124075', 'Egyptian_cat']

数组中的第一项是 ImageNet 类 ID,第二项是人类可读的名称。

Note

您是否注意到model变量不属于get_prediction方法? 还是为什么模型是全局变量? 就内存和计算而言,加载模型可能是一项昂贵的操作。 如果我们在get_prediction方法中加载模型,则每次调用该方法时都会不必要地加载该模型。 由于我们正在构建一个 Web 服务器,因此每秒可能有成千上万的请求,因此我们不应该浪费时间为每个推断重复加载模型。 因此,我们仅将模型加载到内存中一次。 在生产系统中,必须有效利用计算以能够大规模处理请求,因此通常应在处理请求之前加载模型。

将模型集成到我们的 API 服务器中

在最后一部分中,我们将模型添加到 Flask API 服务器中。 由于我们的 API 服务器应该获取图像文件,因此我们将更新predict方法以从请求中读取文件:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件现在完成。 以下是完整版本; 将路径替换为保存文件的路径,它应运行:

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

if __name__ == '__main__':
    app.run()

让我们测试一下我们的网络服务器! 跑:

$ FLASK_ENV=development FLASK_APP=app.py flask run

我们可以使用请求库将 POST 请求发送到我们的应用:

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

现在打印 <cite>resp.json()</cite>将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

下一步

我们编写的服务器非常琐碎,可能无法完成生产应用程序所需的一切。 因此,您可以采取一些措施来改善它:

  • 端点/predict假定请求中始终会有一个图像文件。 并非所有请求都适用。 我们的用户可能发送带有其他参数的图像,或者根本不发送任何图像。
  • 用户也可以发送非图像类型的文件。 由于我们没有处理错误,因此这将破坏我们的服务器。 添加显式的错误处理路径将引发异常,这将使我们能够更好地处理错误的输入
  • 即使模型可以识别大量类别的图像,也可能无法识别所有图像。 增强实现以处理模型无法识别图像中的任何情况的情况。
  • 我们在开发模式下运行 Flask 服务器,该服务器不适合在生产中进行部署。 您可以查看本教程的,以在生产环境中部署 Flask 服务器。
  • 您还可以通过创建一个带有表单的页面来添加 UI,该表单可以拍摄图像并显示预测。 查看类似项目的演示及其源代码
  • 在本教程中,我们仅展示了如何构建可以一次返回单个图像预测的服务。 我们可以修改我们的服务,以便能够一次返回多个图像的预测。 此外,服务流媒体库会自动将对服务的请求排队,并将请求采样到微型批次中,这些微型批次可输入模型中。 您可以查看本教程
  • 最后,我们鼓励您在页面顶部查看链接到的有关部署 PyTorch 模型的其他教程。

脚本的总运行时间:(0 分钟 1.971 秒)

Download Python source code: flask_rest_api_tutorial.py Download Jupyter notebook: flask_rest_api_tutorial.ipynb

由狮身人面像画廊生成的画廊



回到顶部