Deploying a Deep Learning Model as a Web Service - Zhihu

Posted by urb on Wed, 27 May 2020 07:12:58 +0200

This article shows how to deploy a deep learning model quickly as a web service. TensorFlow has TF Serving for model deployment, but it offers nothing for PyTorch: to use it, you would first have to convert the torch model, which is somewhat cumbersome. This article therefore presents a way to deploy deep learning models through a plain web service. It is simple and effective; please go easy on the criticism.

This article uses a simple news classification model as its example. Model: BERT. Data source: the Tsinghua news corpus (THUCTC: an efficient Chinese text classification tool). The corpus has 14 categories: sports, entertainment, home furnishing, lottery, real estate, education, fashion, current affairs, constellation (horoscopes), game, society, technology, stock, and finance. To train the model quickly, I randomly selected 1,000 training samples and 200 validation samples from each category. The code for data preprocessing, model training, and saving the .pb model is in the news classification model training repository on GitHub. (That part is not the focus here, so I won't go into it; the repository has detailed instructions, and if you have any questions, please leave a comment.)
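The per-category sampling described above (1,000 training and 200 validation samples from each category) can be sketched as follows. This is only an illustration, not the author's preprocessing code, which lives in the GitHub repository. It assumes the corpus has already been loaded as a list of `(text, label)` pairs; the function name `sample_per_category` and the fixed seed are hypothetical.

```python
import random
from collections import defaultdict


def sample_per_category(samples, n_train=1000, n_dev=200, seed=42):
    """Hypothetical helper: split (text, label) pairs into per-category train/dev sets."""
    by_label = defaultdict(list)
    for text, label in samples:
        by_label[label].append(text)

    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train, dev = [], []
    for label, texts in sorted(by_label.items()):
        if len(texts) < n_train + n_dev:
            raise ValueError(f"category {label!r} has only {len(texts)} samples")
        # Sample without replacement, so train and dev never share a document.
        picked = rng.sample(texts, n_train + n_dev)
        train.extend((t, label) for t in picked[:n_train])
        dev.extend((t, label) for t in picked[n_train:])
    # Mix categories so training batches are not grouped by label.
    rng.shuffle(train)
    rng.shuffle(dev)
    return train, dev
```

With the defaults, 14 categories yield 14,000 training and 2,800 validation samples.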

To keep the web service deployment simple, I wrap everything in a class that loads the .pb model, preprocesses the incoming text, and runs prediction.

The model initialization code is as follows:

import bert_tokenization
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
import os

class ClassificationModel(object):
    def __init__(self):
        self.tokenizer = None
        self.sess = None
        self.is_train = None
        self.input_ids = None
        self.input_mask = None
        self.segment_ids = None
        self.predictions = None
        self.max_seq_length = None
        # Index -> category name; the order must match the label ids used in training.
        self.label_dict = ['sports', 'entertainment', 'home furnishing', 'lottery', 'real estate', 'education', 'fashion', 'current affairs', 'constellation', 'game', 'society', 'technology', 'stock', 'finance']

Here, tokenizer is the BERT tokenizer; sess is the TensorFlow session; is_train, input_ids, input_mask, and segment_ids are the input tensors of the .pb model; predictions is its output tensor; max_seq_length is the maximum input length of the model; label_dict maps category ids to the news category labels.

The code for loading the .pb model is as follows:

def load_model(self, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):
    # Must be set before TensorFlow initializes CUDA; gpu_id is a string such as "0".
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    self.tokenizer = bert_tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
    # Cap the share of GPU memory this process may allocate.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
    sess_config = tf.ConfigProto(gpu_options=gpu_options)
    self.sess = tf.Session(config=sess_config)
    # Read the serialized graph and import it into the session's graph.
    # as_default() is a context manager, so it only takes effect inside a with block.
    with gfile.FastGFile(model_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with self.sess.graph.as_default():
        tf.import_graph_def(graph_def, name="")

    # A frozen .pb stores its weights as constants, so no variable initialization is needed.
    # Look up the input and output tensors by the names used when the graph was built.
    self.is_train = self.sess.graph.get_tensor_by_name("input/is_train:0")
    self.input_ids = self.sess.graph.get_tensor_by_name("input/input_ids:0")
    self.input_mask = self.sess.graph.get_tensor_by_name("input/input_mask:0")
    self.segment_ids = self.sess.graph.get_tensor_by_name("input/segment_ids:0")
    self.predictions = self.sess.graph.get_tensor_by_name("output_layer/predictions:0")
    self.max_seq_length = max_seq_length

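Here, gpu_id is the index of the GPU to use (a string, since it is written to the CUDA_VISIBLE_DEVICES environment variable); vocab_file is the path to the vocabulary file used by the BERT model; gpu_memory_fraction is the fraction of GPU memory the process may use; model_path is the path to the .pb model; max_seq_length is the maximum input length of the BERT model.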

The code that converts the incoming text into the input format the model expects is as follows:

def convert_fearture(self, text):
    max_seq_length = self.max_seq_length
    # Reserve two positions for the [CLS] and [SEP] special tokens.
    max_length_context = max_seq_length - 2

    content_token = self.tokenizer.tokenize(text)
    if len(content_token) > max_length_context:
        content_token = content_token[:max_length_context]

    # Single-sentence input: [CLS] tokens [SEP], all in segment 0.
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in content_token:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
    # 1 marks real tokens; padding positions get 0.
    input_mask = [1] * len(input_ids)
    # Zero-pad every sequence to exactly max_seq_length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    input_ids = np.array(input_ids)
    input_mask = np.array(input_mask)
    segment_ids = np.array(segment_ids)
    return input_ids, input_mask, segment_ids
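To make the [CLS]/[SEP] framing, truncation, and zero-padding concrete, here is a dependency-free sketch of the same logic. The whitespace tokenizer and the tiny vocabulary are hypothetical stand-ins for bert_tokenization.FullTokenizer and the real vocab.txt, so the ids below are illustrative only; the one real convention kept is that [PAD] has id 0.

```python
def convert_feature_sketch(text, vocab, max_seq_length):
    """Mirror of convert_fearture's logic using a toy whitespace tokenizer."""
    tokens = text.split()                      # stand-in for tokenizer.tokenize
    tokens = tokens[:max_seq_length - 2]       # leave room for [CLS] and [SEP]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens]
    segment_ids = [0] * len(input_ids)         # single sentence: every token in segment 0
    input_mask = [1] * len(input_ids)          # 1 = real token
    padding = max_seq_length - len(input_ids)
    input_ids += [0] * padding                 # 0 = [PAD] id
    input_mask += [0] * padding                # 0 = padding position
    segment_ids += [0] * padding
    return input_ids, input_mask, segment_ids


# Hypothetical toy vocabulary.
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "yao": 4, "ming": 5, "nba": 6}
ids, mask, seg = convert_feature_sketch("yao ming nba", vocab, max_seq_length=8)
print(ids)   # [2, 4, 5, 6, 3, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Because every sequence is padded to the same length, predict can reshape each array to (1, max_seq_length) and feed it directly.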

The prediction code is as follows:

def predict(self, text):
    input_ids_temp, input_mask_temp, segment_ids_temp = self.convert_fearture(text)
    feed = {self.is_train: False,
            self.input_ids: input_ids_temp.reshape(1, self.max_seq_length),
            self.input_mask: input_mask_temp.reshape(1, self.max_seq_length),
            self.segment_ids: segment_ids_temp.reshape(1, self.max_seq_length)}
    [label] = self.sess.run([self.predictions], feed)
    label_name = self.label_dict[label[0]]
    return label[0], label_name

The input is a news text; the output is the category id and the corresponding label name. For the complete code, see ClassificationModel.py on GitHub.

The key point: everything above is about loading the model in a simple, concise way. Next, we serve the model through a web service. In short, I build a web service with the Flask framework that receives external input, runs the loaded model on it, and returns the prediction through the same web service.

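from gevent import monkey
monkey.patch_all()  # patch the standard library before other imports use it
import json

from flask import Flask, Response, request
from gevent import pywsgi  # gevent.wsgi was a deprecated alias of pywsgi and is gone in newer gevent

from ClassificationModel import ClassificationModel


def start_server(http_id, port, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):
    # Load the model once at startup; every request reuses it.
    model = ClassificationModel()
    model.load_model(gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length)
    print("Model loaded.")
    app = Flask(__name__)

    @app.route('/')
    def index():
        # Health check: confirms the service is reachable.
        return "This is News Classification Model Server"

    @app.route('/news-classification', methods=['GET', 'POST'])
    def response_request():
        if request.method == 'POST':
            text = request.form.get('text')  # form-encoded request body
        else:
            text = request.args.get('text')  # URL query string
        if not text:
            return Response("missing 'text' parameter", status=400)
        label, label_name = model.predict(text)
        d = {"label": str(label), "label_name": label_name}
        print(d)
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) label names readable.
        return Response(json.dumps(d, ensure_ascii=False), mimetype='application/json')

    server = pywsgi.WSGIServer((str(http_id), int(port)), app)
    server.serve_forever()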
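The handler serializes its result with ensure_ascii=False. This matters when label names are Chinese (for example 体育, "sports", used here purely as an illustration): with the default ensure_ascii=True, json.dumps escapes every non-ASCII character as a \uXXXX sequence. A small comparison:

```python
import json

result = {"label": "0", "label_name": "体育"}  # example payload; "体育" means "sports"

escaped = json.dumps(result)                       # default ensure_ascii=True: \u-escapes
readable = json.dumps(result, ensure_ascii=False)  # characters kept as-is

print(readable)  # {"label": "0", "label_name": "体育"}

# Both strings decode to the same object, so a JSON client sees no difference;
# only humans reading the raw response (e.g. in a browser) benefit.
assert json.loads(escaped) == json.loads(readable) == result
```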

Here, http_id is the IP address the web service binds to (for example 127.0.0.1 for local access only, or 0.0.0.0 to accept requests from other machines); port is the port number; gpu_id, vocab_file, gpu_memory_fraction, model_path, and max_seq_length are the model-loading parameters described above.
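One way to supply these parameters when launching the service is a small command-line front end. This is a hypothetical sketch: the flag names simply mirror the parameter names above, and the defaults (port 5555 taken from the client example below, a 0.5 memory fraction, a max_seq_length of 128) are assumptions, not values from the original project.

```python
import argparse


def parse_server_args(argv=None):
    """Hypothetical CLI for the server's start-up parameters."""
    parser = argparse.ArgumentParser(description="News classification model server")
    parser.add_argument("--http_id", default="127.0.0.1",
                        help="IP address to bind (0.0.0.0 accepts remote requests)")
    parser.add_argument("--port", type=int, default=5555)
    parser.add_argument("--gpu_id", default="0",
                        help="value for CUDA_VISIBLE_DEVICES, kept as a string")
    parser.add_argument("--vocab_file", required=True, help="path to BERT's vocab.txt")
    parser.add_argument("--model_path", required=True, help="path to the frozen .pb model")
    parser.add_argument("--gpu_memory_fraction", type=float, default=0.5)
    parser.add_argument("--max_seq_length", type=int, default=128)
    args = parser.parse_args(argv)
    # per_process_gpu_memory_fraction only makes sense in (0, 1].
    if not 0.0 < args.gpu_memory_fraction <= 1.0:
        parser.error("--gpu_memory_fraction must be in (0, 1]")
    return args
```

Since the option names match the server function's parameter names, the parsed namespace can be passed straight through as keyword arguments with `**vars(args)`.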

The index function serves as a simple check that the web service is up and reachable, as shown in Figure 1.

Figure 1
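The same check can be scripted, for example before sending prediction requests. A minimal standard-library sketch, assuming the service listens on 127.0.0.1:5555 as in the client example; `is_service_up` is a hypothetical helper name.

```python
import urllib.request


def is_service_up(base_url, timeout=3.0):
    """Return True if a GET on base_url (e.g. the index route) answers with HTTP 200."""
    try:
        with urllib.request.urlopen(base_url, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # URLError, HTTPError, refused connections and timeouts are all OSError subclasses.
        return False


# Usage: is_service_up("http://127.0.0.1:5555/")
```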

response_request is the function that handles prediction requests. It accepts data in two ways, GET and POST: with GET, the text is read from the query string via request.args.get('text'); with POST, it is read from the form body via request.form.get('text').
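Because gevent's WSGIServer serves any WSGI callable, and a Flask app is one, the GET/POST branching can also be illustrated with the standard library alone. This is a swapped-in stdlib sketch, not the Flask code above: `fake_predict` is a hypothetical stand-in for model.predict, and the 400/405 responses are illustrative additions. It lets you exercise the request handling without TensorFlow or a running server.

```python
import io
import json
from urllib.parse import parse_qs
from wsgiref.util import setup_testing_defaults


def fake_predict(text):
    """Hypothetical stand-in for ClassificationModel.predict."""
    return 0, "sports"


def news_classification_app(environ, start_response):
    method = environ["REQUEST_METHOD"]
    if method == "POST":
        # Form body: the counterpart of Flask's request.form.get('text').
        size = int(environ.get("CONTENT_LENGTH") or 0)
        params = parse_qs(environ["wsgi.input"].read(size).decode("utf-8"))
    elif method == "GET":
        # Query string: the counterpart of Flask's request.args.get('text').
        params = parse_qs(environ.get("QUERY_STRING", ""))
    else:
        start_response("405 Method Not Allowed", [("Allow", "GET, POST")])
        return [b""]
    text = params.get("text", [None])[0]
    if not text:
        start_response("400 Bad Request", [("Content-Type", "text/plain")])
        return [b"missing 'text' parameter"]
    label, label_name = fake_predict(text)
    body = json.dumps({"label": str(label), "label_name": label_name},
                      ensure_ascii=False).encode("utf-8")
    start_response("200 OK", [("Content-Type", "application/json; charset=utf-8")])
    return [body]


def call(method, query="", body=b""):
    """Invoke the WSGI app directly with a hand-built environ; returns (status, body)."""
    environ = {"REQUEST_METHOD": method, "QUERY_STRING": query,
               "CONTENT_LENGTH": str(len(body)), "wsgi.input": io.BytesIO(body)}
    setup_testing_defaults(environ)  # fill in the remaining required WSGI keys
    status = []
    chunks = news_classification_app(environ, lambda s, headers: status.append(s))
    return status[0], b"".join(chunks).decode("utf-8")


print(call("GET", "text=Yao+Ming"))   # ('200 OK', '{"label": "0", "label_name": "sports"}')
print(call("POST", body=b"text=Yao+Ming")[0])  # 200 OK
```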


Once the web service is running, it can be called.

The browser call is shown in Figure 2.

Figure 2

Calling it from code looks like this:

import requests


def http_test(text):
    url = 'http://127.0.0.1:5555/news-classification'
    # data= sends a form-encoded POST, read on the server by request.form.get('text').
    res = requests.post(url, data={'text': text}, timeout=10)
    res.raise_for_status()  # surface HTTP errors instead of failing later in .json()
    return res.json()


if __name__ == "__main__":
    text = "Yao Ming plays in the NBA and is very strong."
    result = http_test(text)
    print(result["label_name"])

That is everything about deploying a deep learning model as a web service.

Some articles I wrote earlier:

Liu Cong NLP: A study of short-text similarity algorithms

Liu Cong NLP: Reading notes: open-domain retrieval question answering (ORQA)

Liu Cong NLP: Paper reading notes: BiMPM for textual entailment

If you liked this, feel free to follow the column and the author, and please give it a like!
