-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
48 lines (35 loc) · 1.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Server file
Attributes:
app (fastapi.applications.FastAPI): Fast API app
"""
import os

import dotenv
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
# Load variables from a .env file into os.environ before importing
# core.classifiers: the import is deliberately placed after this call,
# presumably because that module reads environment variables at import
# time (TODO confirm). PORT in the __main__ block is also read from here.
dotenv.load_dotenv()
from core.classifiers import BOWSubclassPredictor, BERTSubclassPredictor
# ASGI application: /classify is registered on it below, and the __main__
# block serves it with uvicorn.
app = FastAPI()
class ClassificationRequest(BaseModel):
    """Request body for the ``/classify`` endpoint.

    Declares the input parameters and their data types; pydantic validates
    incoming JSON against these fields.
    """

    # Text snippet to find relevant CPC subclasses for.
    text: str
    # Number of subclasses to return. Declared as a string; /classify
    # converts it with int(), so it must hold an integer literal.
    n: str
    # Predictor name: "BOWSubclassPredictor" selects BOWSubclassPredictor;
    # any other value selects BERTSubclassPredictor.
    model: str
def _predict_subclasses(model_name, text, n_sub_classes):
    """Build the requested predictor and return its top subclasses.

    Args:
        model_name: ``"BOWSubclassPredictor"`` selects BOWSubclassPredictor;
            any other value falls back to BERTSubclassPredictor.
        text: Text snippet to classify.
        n_sub_classes: Number of subclasses to request from the predictor.

    Returns:
        The result of ``predict_subclasses`` (subclass codes, most relevant
        first).
    """
    # A fresh predictor is constructed on every call, exactly as before.
    # NOTE(review): if construction is expensive (e.g. loads model weights),
    # consider caching instances -- verify thread-safety first.
    if model_name == "BOWSubclassPredictor":
        model = BOWSubclassPredictor()
    else:
        model = BERTSubclassPredictor()
    return model.predict_subclasses(text, n_sub_classes)


@app.post("/classify")
async def classify(item: ClassificationRequest):
    """Find relevant CPC technology subclasses for a given text snippet.

    Args:
        item: Request body carrying the text to classify, the number of
            subclasses to return (``n``) and the predictor name (``model``).

    Returns:
        list: Array of subclass codes, most relevant first.

    Raises:
        HTTPException: 422 if ``n`` is not an integer literal.
    """
    try:
        n_sub_classes = int(item.n)
    except ValueError as err:
        # ``n`` is free-form client input; report a malformed value as a
        # client error instead of letting int() escape as a 500.
        raise HTTPException(
            status_code=422, detail="'n' must be an integer"
        ) from err
    # Predictor construction and inference are synchronous calls; running
    # them directly in this coroutine would block the event loop (and every
    # other request) until they finish, so offload them to the threadpool.
    return await run_in_threadpool(
        _predict_subclasses, item.model, item.text, n_sub_classes
    )
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces. PORT must be set
    # (e.g. via the .env file loaded above); if it is missing, the
    # os.environ lookup raises KeyError before the server starts.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ["PORT"]))