Skip to content

Commit

Permalink
add startup event in main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ympaik87 committed Jan 15, 2022
1 parent a33a8fe commit 5b392c3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
22 changes: 16 additions & 6 deletions starter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


app = FastAPI()
MODEL_CONFIGS = {}


class InferenceRequest(BaseModel):
Expand Down Expand Up @@ -48,19 +49,28 @@ async def create_item(item: dict):
return item


@app.on_event("startup")
async def startup_event():
cwd_p = os.getcwd()
MODEL_CONFIGS['trained_model'] = \
joblib.load(f"{cwd_p}/starter/model/model_trained.joblib")
MODEL_CONFIGS['encoder'] = \
joblib.load(f"{cwd_p}/starter/model/encoder.joblib")
MODEL_CONFIGS['labels'] = \
joblib.load(f"{cwd_p}/starter/model/lb.joblib")


@app.post('/predict')
async def get_prediction(request_data: InferenceRequest):
cwd_p = os.getcwd()
trained_model = joblib.load(f"{cwd_p}/starter/model/model_trained.joblib")
encoder = joblib.load(f"{cwd_p}/starter/model/encoder.joblib")
labels = joblib.load(f"{cwd_p}/starter/model/lb.joblib")
request_dict = request_data.dict(by_alias=True)
request_df = pd.DataFrame(request_dict, index=[0])
processed_data, _, _, _ = process_data(
request_df, categorical_features=CAT_FEATURES, label=None,
training=False, encoder=encoder, lb=labels
training=False, encoder=MODEL_CONFIGS['encoder'],
lb=MODEL_CONFIGS['labels']
)
preds = inference(trained_model, np.array(processed_data))
preds = inference(MODEL_CONFIGS['trained_model'],
np.array(processed_data))
if preds[0]:
pred_cat = '>50K'
else:
Expand Down
16 changes: 11 additions & 5 deletions starter/test_main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
from fastapi.testclient import TestClient
import logging
import pytest
from main import app


logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
client = TestClient(app)


def test_welcome():
@pytest.fixture
def client():
with TestClient(app) as clt:
yield clt


def test_welcome(client):
req = client.get('/')
assert req.status_code == 200, "Status code is not 200"
assert req.json() == "Welcome, this API returns predictions on Salary", "Wrong json output"


def test_post():
def test_post(client):
sample_dict = {
"age": 49,
"workclass": "State-gov",
Expand All @@ -24,7 +30,7 @@ def test_post():
assert response.json() == sample_dict


def test_get_prediction_negative():
def test_get_prediction_negative(client):
input_dict = {
"age": 49,
"workclass": "State-gov",
Expand All @@ -47,7 +53,7 @@ def test_get_prediction_negative():
"Wrong json output"


def test_get_prediction_positive():
def test_get_prediction_positive(client):
input_dict = {
"age": 41,
"workclass": "Private",
Expand Down

0 comments on commit 5b392c3

Please sign in to comment.