"""Inference server. Currently works with TF OD API."""
from concurrent import futures
import time
from absl import app
from absl import flags
from absl import logging
import grpc
import numpy as np
import service_pb2
import service_pb2_grpc
import tensorflow as tf
FLAGS = flags.FLAGS
flags.DEFINE_string('model_path', None, 'Path to inference SavedModel.')
flags.DEFINE_string('model_signature', 'serving_default',
                    'Signature of the model to run.')
flags.DEFINE_float('detection_threshold', 0.4,
                   'Detection confidence threshold to return.')
flags.DEFINE_integer('batch_size', 4,
                     'Batch size that the model expects.')
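
# Example invocation (paths and values are illustrative, not taken from the
# original code):
#   python serving.py --model_path=/path/to/saved_model \
#       --model_signature=serving_default --detection_threshold=0.4 \
#       --batch_size=4
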
_OUTPUT_NAMES = (
    'detection_anchor_indices',
    'detection_boxes',
    'detection_classes',
    'detection_multiclass_scores',
    'detection_scores',
    'num_detections',
    'raw_detection_boxes',
    'raw_detection_scores',
)
_MAX_MESSAGE_LENGTH = 100 * 1024 * 1024 # 100 MiB
_IMAGE_TYPE = tf.uint8


class Detector(service_pb2_grpc.Detector):
  """Detector service. Gets serialized tensors and returns detection results."""

  def __init__(self, model):
    super().__init__()
    self._model = model
    try:
      serving_fn = model.signatures[FLAGS.model_signature]
    except KeyError:
      raise KeyError(f'Model does not have signature {FLAGS.model_signature}. '
                     f'Available signatures: {list(model.signatures)}')

    @tf.function(
        input_signature=[tf.TensorSpec((None, None, None, 3), _IMAGE_TYPE)])
    def model_fn(data):
      return serving_fn(data)

    self._model_fn = model_fn
    logging.info('Warming up...')
    # Warm-up: run a few dummy batches so the first real request does not pay
    # the tracing/compilation cost.
    # TODO: Read the input size from the model.
    for _ in range(10):
      self._model_fn(tf.zeros((4, 736, 1280, 3), dtype=tf.uint8))

  def Inference(self, request, context):
    images = tf.io.parse_tensor(request.data, _IMAGE_TYPE)
    start = time.time()
    logging.info(f'Inference request with tensor shape: {images.shape}')
    # Pad the batch dimension up to the fixed batch size the model expects.
    if images.shape[0] < FLAGS.batch_size:
      images = tf.pad(images,
                      [[0, FLAGS.batch_size - images.shape[0]], [0, 0], [0, 0],
                       [0, 0]])
    detections = self._model_fn(images)
    result = service_pb2.InferenceReply()
    img_h, img_w = images.shape[1:3]
    num_detections = detections['num_detections'].numpy().astype(np.int32)
    detection_boxes = detections['detection_boxes'].numpy()
    detection_classes = detections['detection_classes'].numpy().astype(np.int32)
    detection_scores = detections['detection_scores'].numpy()
    # Temporarily disable padding compensation because the current model will
    # handle this.
    #
    # model_y_padding = (
    #     (request.original_image_width - request.original_image_height) / 2 /
    #     request.original_image_width)
    model_y_padding = 0
    # print('Detected:', num_detections)
    for file_idx, file_path in enumerate(request.file_paths):
      scores = detection_scores[file_idx]
      valid_indices = detection_scores[file_idx, :] >= FLAGS.detection_threshold
      scores = scores[valid_indices]
      classes = detection_classes[file_idx, valid_indices]
      bbox = detection_boxes[file_idx, valid_indices, :]
      for i, pos in enumerate(bbox):
        # TF OD API boxes are normalized [ymin, xmin, ymax, xmax]; convert to
        # pixel-space left/top/width/height.
        box_x1 = pos[1]
        box_y1 = (pos[0] - model_y_padding) / (1 - 2 * model_y_padding)
        box_x2 = pos[3]
        box_y2 = (pos[2] - model_y_padding) / (1 - 2 * model_y_padding)
        detection = service_pb2.BoundingBox(
            file_path=file_path,
            class_id=classes[i],
            score=scores[i],
            left=box_x1 * img_w,
            top=box_y1 * img_h,
            width=(box_x2 - box_x1) * img_w,
            height=(box_y2 - box_y1) * img_h,
        )
        result.detections.append(detection)
    inference_ms = int(1000 * (time.time() - start))
    logging.info(f'Inference request done in {inference_ms}ms')
    return result
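

# A minimal client sketch (illustrative only): the stub class name and the
# request message name `InferenceRequest` are assumed from standard gRPC
# naming conventions and are not confirmed by this file; only `InferenceReply`,
# `BoundingBox`, `data` and `file_paths` appear above.
#
#   channel = grpc.insecure_channel('localhost:50051')
#   stub = service_pb2_grpc.DetectorStub(channel)
#   images = tf.zeros((4, 736, 1280, 3), dtype=tf.uint8)
#   request = service_pb2.InferenceRequest(
#       data=tf.io.serialize_tensor(images).numpy(),
#       file_paths=['frame_000.jpg'] * 4)
#   reply = stub.Inference(request)
#   for box in reply.detections:
#     print(box.file_path, box.class_id, box.score, box.left, box.top)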


def serve():
  """Starts gRPC service."""
  # These are not available in TF 2.5.
  # options = tf.saved_model.LoadOptions(
  #     allow_partial_checkpoint=True, experimental_skip_checkpoint=True)
  start = time.time()
  model = tf.saved_model.load(FLAGS.model_path)
  server = grpc.server(
      futures.ThreadPoolExecutor(max_workers=1),
      options=[
          ('grpc.max_send_message_length', _MAX_MESSAGE_LENGTH),
          ('grpc.max_receive_message_length', _MAX_MESSAGE_LENGTH),
      ])
  service_pb2_grpc.add_DetectorServicer_to_server(Detector(model), server)
  logging.info('Model loading done in %.2fs. Inference server is ready.',
               time.time() - start)
  server.add_insecure_port('[::]:50051')
  server.start()
  server.wait_for_termination()


def main(unused_argv):
  # Enable XLA JIT compilation and mixed-precision graph rewriting for faster
  # inference where the hardware supports it.
  tf.config.optimizer.set_jit(True)
  tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})
  serve()


if __name__ == '__main__':
  app.run(main)
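

# Rough sketch of the proto shapes implied by this server, for reference only.
# Message and field names are taken from the usages above; field numbers and
# types are assumptions, and the real service.proto may differ.
#
#   service Detector {
#     rpc Inference(InferenceRequest) returns (InferenceReply);
#   }
#   message InferenceRequest {
#     bytes data = 1;                      // tf.io.serialize_tensor output.
#     repeated string file_paths = 2;
#     int32 original_image_width = 3;
#     int32 original_image_height = 4;
#   }
#   message InferenceReply {
#     repeated BoundingBox detections = 1;
#   }
#   message BoundingBox {
#     string file_path = 1;
#     int32 class_id = 2;
#     float score = 3;
#     float left = 4;
#     float top = 5;
#     float width = 6;
#     float height = 7;
#   }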