Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fastapi scoring #1

Open
wants to merge 2 commits into
base: branch-1.22
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlflow/R/mlflow/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: mlflow
Title: Interface to 'MLflow'
Version: 1.21.1
Version: 1.22.0
Authors@R:
c(person(given = "Matei",
family = "Zaharia",
Expand Down
2 changes: 1 addition & 1 deletion mlflow/java/client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>org.mlflow</groupId>
<artifactId>mlflow-parent</artifactId>
<version>1.21.1-SNAPSHOT</version>
<version>1.22.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
4 changes: 2 additions & 2 deletions mlflow/java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.mlflow</groupId>
<artifactId>mlflow-parent</artifactId>
<version>1.21.1-SNAPSHOT</version>
<version>1.22.0</version>
<packaging>pom</packaging>
<name>MLflow Parent POM</name>
<url>http://mlflow.org</url>
Expand Down Expand Up @@ -59,7 +59,7 @@
</distributionManagement>

<properties>
<mlflow-version>1.21.1-SNAPSHOT</mlflow-version>
<mlflow-version>1.22.0</mlflow-version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<scala.version>2.11.12</scala.version>
Expand Down
2 changes: 1 addition & 1 deletion mlflow/java/scoring/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>org.mlflow</groupId>
<artifactId>mlflow-parent</artifactId>
<version>1.21.1-SNAPSHOT</version>
<version>1.22.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
4 changes: 2 additions & 2 deletions mlflow/java/spark/pom.xml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<artifactId>mlflow-spark</artifactId>
<version>1.21.1-SNAPSHOT</version>
<version>1.22.0</version>
<name>${project.artifactId}</name>
<properties>
<maven.compiler.source>1.8</maven.compiler.source>
Expand All @@ -15,7 +15,7 @@
<parent>
<groupId>org.mlflow</groupId>
<artifactId>mlflow-parent</artifactId>
<version>1.21.1-SNAPSHOT</version>
<version>1.22.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
93 changes: 54 additions & 39 deletions mlflow/pyfunc/scoring_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
import pandas as pd
import sys
import traceback
from pydantic import BaseModel
from fastapi import FastAPI, APIRouter, Request, HTTPException, Response, Header, status
from typing import List, Optional, Dict
import uvicorn
import asyncio
import json

# NB: We need to be careful what we import from mlflow here. Scoring server is used from within
# model's conda environment. The version of mlflow doing the serving (outside) and the version of
Expand Down Expand Up @@ -65,13 +71,24 @@

CONTENT_TYPE_FORMAT_RECORDS_ORIENTED = "pandas-records"
CONTENT_TYPE_FORMAT_SPLIT_ORIENTED = "pandas-split"
CONTENT_TYPE_RAW_JSON = "raw-json"

FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED, CONTENT_TYPE_FORMAT_SPLIT_ORIENTED]
FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED, CONTENT_TYPE_FORMAT_SPLIT_ORIENTED, CONTENT_TYPE_RAW_JSON]

PREDICTIONS_WRAPPER_ATTR_NAME_ENV_KEY = "PREDICTIONS_WRAPPER_ATTR_NAME"

_logger = logging.getLogger(__name__)

class RequestData(BaseModel):
    """
    Request body accepted by the scoring endpoint.

    Carries a table as a list of column names plus a list of rows. Both
    fields default to empty lists.
    """

    # Column labels for the values held in ``data``.
    columns: List[str] = []
    # Row values; the element type is not constrained by the model.
    data: list = []

    def is_valid(self):
        """
        Report whether the payload is usable.

        No checks are implemented: every instance is reported as valid.
        """
        return True

    def get_dataframe(self):
        """
        Build a pandas DataFrame from ``data``, labelled with ``columns``.

        When rows are present but no column names were supplied, the
        ``columns`` argument is omitted so that pandas assigns default
        labels instead of receiving an empty label list that cannot
        describe the rows.
        """
        if self.data and not self.columns:
            return pd.DataFrame(data=self.data)
        return pd.DataFrame(data=self.data, columns=self.columns)

def infer_and_parse_json_input(json_input, schema: Schema = None):
"""
Expand Down Expand Up @@ -205,38 +222,38 @@ def _handle_serving_error(error_message, error_code, include_traceback=True):
e = MlflowException(message=error_message, error_code=error_code)
reraise(MlflowException, e)


def init(model: PyFuncModel):

"""
Initialize the server. Loads pyfunc model from the path.
"""
app = flask.Flask(__name__)
fast_app = FastAPI(title= __name__, version= "v1")
fast_app.include_router(APIRouter())
input_schema = model.metadata.get_input_schema()

@app.route("/ping", methods=["GET"])
@fast_app.get("/ping")
def ping(): # pylint: disable=unused-variable
"""
Determine if the container is working and healthy.
We declare it healthy if we can load the model successfully.
"""
health = model is not None
status = 200 if health else 404
return flask.Response(response="\n", status=status, mimetype="application/json")
if model is None:
raise HTTPException(status_code=404, detail="Model not loaded properly")
return {"message": "OK"}

@app.route("/invocations", methods=["POST"])
@catch_mlflow_exception
def transformation(): # pylint: disable=unused-variable
@fast_app.post("/invocations")
def transformation(request_data: RequestData, content_type: Optional[str] = Header(None)): # pylint: disable=unused-variable
"""
Do an inference on a single batch of data. In this sample server,
we take data as CSV or json, convert it to a Pandas DataFrame or Numpy,
generate predictions and convert them back to json.
"""
# data = _dataframe_from_json(request_data.json())

# Content-Type can include other attributes like CHARSET
# Content-type RFC: https://datatracker.ietf.org/doc/html/rfc2045#section-5.1
# TODO: Support ";" in quoted parameter values
type_parts = flask.request.content_type.split(";")
type_parts = content_type.split(";")
type_parts = list(map(str.strip, type_parts))
mime_type = type_parts[0]
parameter_value_pairs = type_parts[1:]
Expand All @@ -247,27 +264,31 @@ def transformation(): # pylint: disable=unused-variable

charset = parameter_values.get("charset", "utf-8").lower()
if charset != "utf-8":
return flask.Response(
response="The scoring server only supports UTF-8",
status=415,
mimetype="text/plain",
return Response(
content="The scoring server only supports UTF-8",
status_code=415,
media_type="text/plain"
)

content_format = parameter_values.get("format")

# Convert from CSV to pandas
if mime_type == CONTENT_TYPE_CSV and not content_format:
data = flask.request.data.decode("utf-8")
data = request_data.json()
csv_input = StringIO(data)
data = parse_csv_input(csv_input=csv_input)
elif mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_RAW_JSON:
if len(request_data.data) != 0:
data = dict(zip(request_data.columns, request_data.data[0]))
else:
data = {}
elif mime_type == CONTENT_TYPE_JSON and not content_format:
json_str = flask.request.data.decode("utf-8")
data = infer_and_parse_json_input(json_str, input_schema)
data = infer_and_parse_json_input(request_data.json(), input_schema)
elif (
mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_FORMAT_SPLIT_ORIENTED
):
data = parse_json_input(
json_input=StringIO(flask.request.data.decode("utf-8")),
json_input=StringIO(request_data.json()),
orient="split",
schema=input_schema,
)
Expand All @@ -276,29 +297,25 @@ def transformation(): # pylint: disable=unused-variable
and content_format == CONTENT_TYPE_FORMAT_RECORDS_ORIENTED
):
data = parse_json_input(
json_input=StringIO(flask.request.data.decode("utf-8")),
json_input=StringIO(request_data.json()),
orient="records",
schema=input_schema,
)
elif mime_type == CONTENT_TYPE_JSON_SPLIT_NUMPY and not content_format:
data = parse_split_oriented_json_input_to_numpy(flask.request.data.decode("utf-8"))
data = parse_split_oriented_json_input_to_numpy(request_data.json())
else:
return flask.Response(
response=(
"This predictor only supports the following content types and formats:"
return Response(
content="This predictor only supports the following content types and formats:"
" Types: {supported_content_types}; Formats: {formats}."
" Got '{received_content_type}'.".format(
supported_content_types=CONTENT_TYPES,
formats=FORMATS,
received_content_type=flask.request.content_type,
)
),
status=415,
mimetype="text/plain",
received_content_type=content_type,
),
status_code=415,
media_type="text/plain"
)

# Do the prediction

try:
raw_predictions = model.predict(data)
except MlflowException as e:
Expand All @@ -314,11 +331,10 @@ def transformation(): # pylint: disable=unused-variable
),
error_code=BAD_REQUEST,
)
result = StringIO()
predictions_to_json(raw_predictions, result)
return flask.Response(response=result.getvalue(), status=200, mimetype="application/json")
predictions = _get_jsonable_obj(raw_predictions, pandas_orient="records")
return predictions

return app
return fast_app


def _predict(model_uri, input_path, output_path, content_type, json_format):
Expand All @@ -342,8 +358,8 @@ def _predict(model_uri, input_path, output_path, content_type, json_format):

def _serve(model_uri, port, host):
    """
    Load the pyfunc model at ``model_uri`` and serve it over HTTP.

    :param model_uri: Location of the model, passed to ``load_model``.
    :param port: Port the server listens on.
    :param host: Host/interface the server binds to.
    """
    pyfunc_model = load_model(model_uri)
    # ``init`` returns a FastAPI application, which is run by uvicorn;
    # it has no Flask-style ``run`` method of its own.
    fast_app = init(pyfunc_model)
    uvicorn.run(fast_app, host=host, port=port, log_level="info")

def get_cmd(
model_uri: str, port: int = None, host: int = None, nworkers: int = None
Expand All @@ -362,8 +378,7 @@ def get_cmd(
args.append(f"-w {nworkers}")

command = (
f"gunicorn {' '.join(args)} ${{GUNICORN_CMD_ARGS}}"
" -- mlflow.pyfunc.scoring_server.wsgi:app"
"gunicorn mlflow.pyfunc.scoring_server.wsgi:app --worker-class uvicorn.workers.UvicornWorker"
)
else:
args = []
Expand Down
2 changes: 1 addition & 1 deletion mlflow/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re


VERSION = "1.21.1.dev0"
VERSION = "1.22.0"


def is_release_version():
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def package_files(directory):
"alembic<=1.4.1",
# Required
"docker>=4.0.0",
"fastapi",
"uvicorn",
"Flask",
"gunicorn; platform_system != 'Windows'",
"numpy",
Expand Down