From 5b392c3c0a33b919dd4a6e6abab18a6e09ede836 Mon Sep 17 00:00:00 2001 From: Young Min Paik Date: Sat, 15 Jan 2022 10:37:56 +0900 Subject: [PATCH] add startup event in main.py --- starter/main.py | 22 ++++++++++++++++------ starter/test_main.py | 16 +++++++++++----- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/starter/main.py b/starter/main.py index f276e08..99d33dc 100644 --- a/starter/main.py +++ b/starter/main.py @@ -17,6 +17,7 @@ app = FastAPI() +MODEL_CONFIGS = {} class InferenceRequest(BaseModel): @@ -48,19 +49,28 @@ async def create_item(item: dict): return item +@app.on_event("startup") +async def startup_event(): + cwd_p = os.getcwd() + MODEL_CONFIGS['trained_model'] = \ + joblib.load(f"{cwd_p}/starter/model/model_trained.joblib") + MODEL_CONFIGS['encoder'] = \ + joblib.load(f"{cwd_p}/starter/model/encoder.joblib") + MODEL_CONFIGS['labels'] = \ + joblib.load(f"{cwd_p}/starter/model/lb.joblib") + + @app.post('/predict') async def get_prediction(request_data: InferenceRequest): - cwd_p = os.getcwd() - trained_model = joblib.load(f"{cwd_p}/starter/model/model_trained.joblib") - encoder = joblib.load(f"{cwd_p}/starter/model/encoder.joblib") - labels = joblib.load(f"{cwd_p}/starter/model/lb.joblib") request_dict = request_data.dict(by_alias=True) request_df = pd.DataFrame(request_dict, index=[0]) processed_data, _, _, _ = process_data( request_df, categorical_features=CAT_FEATURES, label=None, - training=False, encoder=encoder, lb=labels + training=False, encoder=MODEL_CONFIGS['encoder'], + lb=MODEL_CONFIGS['labels'] ) - preds = inference(trained_model, np.array(processed_data)) + preds = inference(MODEL_CONFIGS['trained_model'], + np.array(processed_data)) if preds[0]: pred_cat = '>50K' else: diff --git a/starter/test_main.py b/starter/test_main.py index 65ea701..0ae50eb 100644 --- a/starter/test_main.py +++ b/starter/test_main.py @@ -1,20 +1,26 @@ from fastapi.testclient import TestClient import logging +import pytest from main import app logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s") logger = logging.getLogger() -client = TestClient(app) -def test_welcome(): +@pytest.fixture +def client(): + with TestClient(app) as clt: + yield clt + + +def test_welcome(client): req = client.get('/') assert req.status_code == 200, "Status code is not 200" assert req.json() == "Welcome, this API returns predictions on Salary", "Wrong json output" -def test_post(): +def test_post(client): sample_dict = { "age": 49, "workclass": "State-gov", @@ -24,7 +30,7 @@ def test_post(): assert response.json() == sample_dict -def test_get_prediction_negative(): +def test_get_prediction_negative(client): input_dict = { "age": 49, "workclass": "State-gov", @@ -47,7 +53,7 @@ def test_get_prediction_negative(): "Wrong json output" -def test_get_prediction_positive(): +def test_get_prediction_positive(client): input_dict = { "age": 41, "workclass": "Private",