-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
54 lines (44 loc) · 1.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
Diamond-Mate-Backend
Alf-arv, 2021
"""
import sys
import os
import json
from flask import request
from flask import Flask
from model_training import train_regression_estimator
from inference import do_inference, do_batch_inference
# WSGI application object; the @app.route decorators below register their
# view functions on it, and the __main__ guard at the bottom runs it.
app = Flask(__name__)
@app.route('/single_inference', methods=['GET'])
def single_inference():
    """Predict the price of a single diamond described by the query string.

    Reads the ``Shape``, ``Carat``, ``Color``, ``Clarity`` and ``Cut`` query
    parameters and forwards them, as a dict, to ``do_inference`` together with
    the model path ``'model'``. A parameter missing from the query string is
    forwarded as ``None``.

    Returns:
        str: a JSON-encoded string whose content is the ``str()`` form of
        ``{"price_prediction": <result of do_inference>}``.
    """
    # No explicit ``request.method == 'GET'`` check: the route already limits
    # the allowed methods, and Flask implicitly adds HEAD to every GET route.
    # With the check in place a HEAD request fell through, the view returned
    # None, and Flask answered with a 500.
    feature_names = ("Shape", "Carat", "Color", "Clarity", "Cut")
    features = {name: request.args.get(name) for name in feature_names}

    res = do_inference(model_path='model', data=features)

    # NOTE(review): this serialises the *repr* of a Python dict as one JSON
    # string rather than a JSON object. Kept as-is because existing clients
    # presumably parse this exact format -- confirm before switching to
    # json.dumps({...}) / jsonify.
    return json.dumps(str({"price_prediction": res}))
@app.route('/batch_inference', methods=['GET'])
def batch_inference():
    """Run inference for a batch supplied in the ``batch`` query parameter.

    The raw ``batch`` string is passed unchanged to ``do_batch_inference``
    (inside ``{"batch": ...}``) together with the model path ``'model'``.

    Returns:
        str: a JSON-encoded string holding the ``str()`` form of the result
        of ``do_batch_inference``; or, when ``batch`` is missing or empty,
        the ``str()`` form of ``{"error": -1}``.
    """
    # No explicit ``request.method == 'GET'`` check: the route already limits
    # the allowed methods, and Flask implicitly adds HEAD to every GET route.
    # With the check in place a HEAD request made the view return None (500).
    batch = request.args.get('batch')

    # A missing parameter yields None and an empty one yields ''; both are
    # falsy, so a single truthiness test replaces the former
    # ``try: len(...) / except:`` construct without swallowing other errors.
    if not batch:
        return json.dumps(str({"error": -1}))

    res = do_batch_inference(model_path='model', data={"batch": batch})

    # NOTE(review): the result is stringified with str() before JSON encoding,
    # so clients receive one JSON string, not structured JSON. Kept for
    # backward compatibility -- confirm with consumers before changing.
    return json.dumps(str(res))
@app.route('/train_model', methods=['POST'])
def train_model():
    """Train the regression estimator from the CSV under ``data/``.

    Calls ``train_regression_estimator`` with the path ``data/database.csv``
    and the string ``'model'`` (presumably the location the trained model is
    written to -- confirm against ``model_training``).

    Returns:
        str: a JSON-encoded string holding the ``str()`` form of
        ``{"success": <value returned by train_regression_estimator>}``.
    """
    # Guard mirrors the original method check; the route only admits POST.
    if request.method != 'POST':
        return None

    dataset_path = os.path.join('data', 'database.csv')
    outcome = train_regression_estimator(dataset_path, 'model')
    payload = {"success": outcome}
    return json.dumps(str(payload))
if __name__ == '__main__':
    # Start Flask's built-in server only when this file is executed directly
    # (not when imported); threaded=False disables per-request threads.
    app.run(threaded=False)