-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi.py
68 lines (52 loc) · 1.96 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import json
import logging
from time import perf_counter, time

from flask import Blueprint, Flask
from flask_restx import Api, Resource, fields

from aic_summarization_rest_api.summarizer_factory import create_summarizer
# Root logging at INFO so the configuration dump in __main__ is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Every endpoint below is served under the /api/v1 URL prefix.
api_v1 = Blueprint("api", __name__, url_prefix="/api/v1")

# flask-restx API object bound to the blueprint; title/description/version
# feed the generated Swagger documentation.
api = Api(
    api_v1,
    title="AIC Summarization API",
    description="A simple text summarization API",
    version="1.0",
)

ns = api.namespace("summarize", description="generate summaries")

# Request parser: one or more "sources" form fields (action="append" collects
# repeated fields into a list), each holding a text to summarize.
parser = api.parser()
parser.add_argument(
    "sources",
    location="form",
    type=str,
    action="append",
    required=True,
    help="Source texts to summarize",
)

# Response schema: the list of summaries plus the time spent producing them.
apimodel = api.model(
    "SumModel",
    {
        "summaries": fields.List(
            fields.String,
            description="List of summaries, one per each source text",
        ),
        "duration_s": fields.Float(
            description="Summarization duration in seconds",
        ),
    },
)
@ns.route("/")
class Summarize(Resource):
    """Generate summaries"""

    @api.marshal_with(apimodel, envelope="resource")
    @api.expect(parser)
    def post(self):
        """Generate one summary per submitted source text.

        Accepts one or more ``sources`` form fields and returns the list of
        summaries together with the elapsed summarization time in seconds.
        The body is wrapped in a ``resource`` envelope:
        ``{"resource": {"summaries": [...], "duration_s": ...}}``.
        """
        args = parser.parse_args()
        # perf_counter is monotonic and high-resolution, so the reported
        # duration cannot go negative or jump if the system clock changes
        # mid-request (time.time() can).
        start = perf_counter()
        # NOTE(review): ``summarizer`` is a module-level name assigned only in
        # the ``__main__`` block; serving this module through another entry
        # point (e.g. a WSGI server importing it) leaves it undefined.
        summaries = summarizer.summarize_batch(args["sources"])
        duration_s = perf_counter() - start
        return {"summaries": summaries, "duration_s": duration_s}, 200
if __name__ == "__main__":
    # Command line: a JSON config file for the summarizer and a target device.
    cparser = argparse.ArgumentParser()
    cparser.add_argument(
        "cfgfile", help="JSON configuration file, such as cfg/mbart_headline.json"
    )
    cparser.add_argument(
        "--device",
        default="cpu",
        choices=["cpu", "cuda"],
        help="target device CPU or CUDA",
        type=str,
    )
    cargs = cparser.parse_args()

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(cargs.cfgfile, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    logger.info(json.dumps(cfg, indent=4))

    # Module-level name that Summarize.post reads at request time.
    summarizer = create_summarizer(cfg, device=cargs.device)

    app = Flask(__name__)
    app.register_blueprint(api_v1)
    # debug=True normally also starts the Werkzeug reloader, which re-runs this
    # whole script in a child process and so builds the summarizer a second
    # time (doubling model memory, including GPU memory with --device cuda).
    app.run(debug=True, use_reloader=False)