From 0dc0bde019b89f123c42bb02839d5b9b7aa82f92 Mon Sep 17 00:00:00 2001
From: Philipp Rouast
Date: Sun, 21 Jul 2024 19:16:48 +0200
Subject: [PATCH] In multi-request inference for VITALLENS, use a separate ROI per request

---
 tests/test_signal.py           |  15 ++++-
 vitallens/methods/vitallens.py | 110 ++++++++++++++++++++-----------
 vitallens/signal.py            |  29 +++++++++
 vitallens/ssd.py               |   2 +-
 vitallens/utils.py             |   1 +
 5 files changed, 115 insertions(+), 42 deletions(-)

diff --git a/tests/test_signal.py b/tests/test_signal.py
index 1282f4e..81f5ecc 100644
--- a/tests/test_signal.py
+++ b/tests/test_signal.py
@@ -24,7 +24,7 @@ import sys
 
 sys.path.append('../vitallens-python')
 
-from vitallens.signal import windowed_mean, windowed_freq
+from vitallens.signal import windowed_mean, windowed_freq, reassemble_from_windows
 
 def test_windowed_mean():
   x = np.asarray([0., 1., 2., 3., 4., 5., 6.])
@@ -47,4 +47,17 @@ def test_estimate_freq_periodogram(num, freq, window_size):
     windowed_freq(x=y, window_size=window_size, overlap=window_size//2, f_s=len(x), f_range=(max(freq-2,1),freq+2), f_res=0.05),
     np.full((num,), fill_value=freq),
     rtol=1)
+
+def test_reassemble_from_windows():
+  x = np.array([[[2.0, 4.0, 6.0, 8.0, 10.0], [7.0, 1.0, 10.0, 12.0, 18.0]],
+                [[2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0]]], dtype=np.float32).transpose(1, 0, 2)
+  idxs = np.array([[1, 3, 5, 7, 9], [5, 6, 9, 11, 13]], dtype=np.int64)
+  out_x, out_idxs = reassemble_from_windows(x=x, idxs=idxs)
+  np.testing.assert_equal(
+    out_x,
+    np.asarray([[2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 18.0],
+                [2.0, 3.0, 4.0, 5.0, 6.0, 10.0, 11.0]]))
+  np.testing.assert_equal(
+    out_idxs,
+    np.asarray([1, 3, 5, 7, 9, 11, 13]))
\ No newline at end of file
diff --git a/vitallens/methods/vitallens.py b/vitallens/methods/vitallens.py
index db2c7af..ff98420 100644
--- a/vitallens/methods/vitallens.py
+++ b/vitallens/methods/vitallens.py
@@ -36,6 +36,7 @@
 from vitallens.methods.rppg_method import RPPGMethod
 from vitallens.signal import detrend_lambda_for_hr_response, detrend_lambda_for_rr_response
 from vitallens.signal import moving_average_size_for_hr_response, moving_average_size_for_rr_response
+from vitallens.signal import reassemble_from_windows
 from vitallens.utils import probe_video_inputs, parse_video_inputs
 
 class VitalLensRPPGMethod(RPPGMethod):
@@ -72,49 +73,42 @@ def __call__(
       live: The face live confidence. Shape (1, n_frames)
     """
     inputs_shape, fps = probe_video_inputs(video=frames, fps=fps)
-    # Choose representative face detection
-    # TODO: For longer videos extract chunks from separate locations?
-    face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))]
-    roi = get_roi_from_det(
-      face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1]))
-    if np.any(np.logical_or(
-      (faces[:,2] - faces[:,0]) * 0.5 < np.maximum(0, faces[:,0] - roi[0]) + np.maximum(0, faces[:,2] - roi[2]),
-      (faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))):
-      logging.warn("Large face movement detected")
-    # Parse the inputs
-    logging.debug("Preparing video for inference...")
-    frames_ds, fps, inputs_shape, ds_factor, _ = parse_video_inputs(
-      video=frames, fps=fps, target_size=self.input_size, roi=roi,
-      target_fps=override_fps_target if override_fps_target is not None else self.fps_target,
-      library='prpy', scale_algorithm='bilinear')
     # Check the number of frames to be processed
-    ds_len = frames_ds.shape[0]
-    if ds_len <= API_MAX_FRAMES:
+    inputs_n = inputs_shape[0]
+    fps_target = override_fps_target if override_fps_target is not None else self.fps_target
+    expected_ds_factor = round(fps / fps_target)
+    expected_ds_n = math.ceil(inputs_n / expected_ds_factor)
+    if expected_ds_n <= API_MAX_FRAMES:
       # API supports up to MAX_FRAMES at once - process all frames
-      sig_ds, conf_ds, live_ds = self.process_api(frames_ds)
+      sig_ds, conf_ds, live_ds, idxs = self.process_api_batch(
+        batch=1, n_batches=1, inputs=frames, inputs_shape=inputs_shape,
+        faces=faces, fps_target=fps_target, fps=fps)
     else:
       # Longer videos are split up with small overlaps
-      ds_len = frames_ds.shape[0]
-      n_splits = math.ceil((ds_len - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
-      split_len = math.ceil((ds_len + (n_splits-1) * API_OVERLAP) / n_splits)
-      start_idxs = [i for i in range(0, ds_len - n_splits * API_OVERLAP, split_len - API_OVERLAP)]
-      end_idxs = [min(i + split_len, ds_len) for i in start_idxs]
-      logging.info("Running inference for {} frames using {} requests...".format(ds_len, n_splits))
+      n_splits = math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
+      split_len = math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits)
+      start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)]
+      end_idxs = [min(start + split_len, inputs_n) for start in start_idxs]
+      start_idxs = [max(0, end - split_len) for end in end_idxs]
+      logging.info("Running inference for {} frames using {} requests...".format(expected_ds_n, n_splits))
       # Process the splits in parallel
       with concurrent.futures.ThreadPoolExecutor() as executor:
-        results = list(executor.map(lambda i: self.process_api(frames_ds[start_idxs[i]:end_idxs[i]]), range(n_splits)))
+        results = list(executor.map(lambda i: self.process_api_batch(
+          batch=i, n_batches=n_splits, inputs=frames, inputs_shape=inputs_shape,
+          faces=faces, fps_target=fps_target, start=start_idxs[i], end=end_idxs[i],
+          fps=fps), range(n_splits)))
       # Aggregate the results
-      sig_results, conf_results, live_results = zip(*results)
-      sig_ds = np.concatenate([sig_results[0]] + [[x[API_OVERLAP:] for x in e] for e in sig_results[1:]], axis=-1)
-      conf_ds = np.concatenate([conf_results[0]] + [[x[API_OVERLAP:] for x in e] for e in conf_results[1:]], axis=-1)
-      live_ds = np.concatenate([live_results[0]] + [x[API_OVERLAP:] for x in live_results[1:]], axis=-1)
+      sig_results, conf_results, live_results, idxs_results = zip(*results)
+      sig_ds, idxs = reassemble_from_windows(x=sig_results, idxs=idxs_results)
+      conf_ds, _ = reassemble_from_windows(x=conf_results, idxs=idxs_results)
+      live_ds = reassemble_from_windows(x=np.asarray(live_results)[:,np.newaxis], idxs=idxs_results)[0][0]
     # Interpolate to original sampling rate (n_frames,)
     sig = interpolate_cubic_spline(
-      x=np.arange(inputs_shape[0])[0::ds_factor], y=sig_ds, xs=np.arange(inputs_shape[0]), axis=1)
+      x=idxs, y=sig_ds, xs=np.arange(inputs_n), axis=1)
     conf = interpolate_cubic_spline(
-      x=np.arange(inputs_shape[0])[0::ds_factor], y=conf_ds, xs=np.arange(inputs_shape[0]), axis=1)
+      x=idxs, y=conf_ds, xs=np.arange(inputs_n), axis=1)
     live = interpolate_cubic_spline(
-      x=np.arange(inputs_shape[0])[0::ds_factor], y=live_ds, xs=np.arange(inputs_shape[0]), axis=0)
+      x=idxs, y=live_ds, xs=np.arange(inputs_n), axis=0)
     # Filter (n_frames,)
     sig = np.asarray([self.postprocess(p, fps, type=name) for p, name in zip(sig, ['ppg', 'resp'])])
     # Estimate summary vitals
@@ -153,23 +147,58 @@ def __call__(
       out_conf[name] = conf[1]
       out_note[name] = 'Estimate of the respiratory waveform using VitalLens, along with frame-wise confidences between 0 and 1.'
     return out_data, out_unit, out_conf, out_note, live
-  def process_api(
+  def process_api_batch(
       self,
-      frames: np.ndarray,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Process frames with the VitalLens API.
+      batch: int,
+      n_batches: int,
+      inputs: Union[np.ndarray, str],
+      inputs_shape: tuple,
+      faces: np.ndarray,
+      fps_target: float,
+      start: int = None,
+      end: int = None,
+      fps: float = None
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    """Process a batch of frames with the VitalLens API.
     Args:
-      frames: The video frames. Shape (n_frames<=MAX_FRAMES, h==input_size, w==input_size, 3)
+      batch: The number of this batch.
+      n_batches: The total number of batches.
+      inputs: The video to analyze. Either a np.ndarray of shape (n_frames, h, w, 3)
+        with a sequence of frames in unscaled uint8 RGB format, or a path to a video file.
+      inputs_shape: The shape of the inputs.
+      faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1)
+      fps_target: The target frame rate at which to run inference.
+      start: The index of the first frame of the video to analyze in this batch.
+      end: The exclusive end index of the frames to analyze in this batch.
+      fps: The frame rate of the input video. Required if type(inputs) == np.ndarray
     Returns:
       sig: Estimated signals. Shape (n_sig, n_frames)
      conf: Estimation confidences. Shape (n_sig, n_frames)
      live: Liveness estimation. Shape (n_frames,)
+      idxs: Indices in inputs that were processed. Shape (n_frames,)
     """
-    assert frames.shape[0] <= API_MAX_FRAMES
+    logging.debug("Batch {}/{}...".format(batch, n_batches))
+    # Trim face detections to batch if necessary
+    if start is not None and end is not None:
+      faces = faces[start:end]
+    # Choose representative face detection
+    face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))]
+    roi = get_roi_from_det(
+      face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1]))
+    if np.any(np.logical_or(
+      (faces[:,2] - faces[:,0]) * 0.5 < np.maximum(0, faces[:,0] - roi[0]) + np.maximum(0, faces[:,2] - roi[2]),
+      (faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))):
+      logging.warning("Large face movement detected")
+    # Parse the inputs
+    frames_ds, fps, inputs_shape, _, idxs = parse_video_inputs(
+      video=inputs, fps=fps, target_size=self.input_size, roi=roi, target_fps=fps_target,
+      trim=(start, end) if start is not None and end is not None else None,
+      library='prpy', scale_algorithm='bilinear')
+    assert frames_ds.shape[0] <= API_MAX_FRAMES
     # Prepare API header and payload
     headers = {"x-api-key": self.api_key}
-    payload = {"video": base64.b64encode(frames.tobytes()).decode('utf-8')}
+    payload = {"video": base64.b64encode(frames_ds.tobytes()).decode('utf-8')}
     # Ask API to process video
     response = requests.post(API_URL, headers=headers, json=payload)
     response_body = json.loads(response.text)
@@ -194,7 +223,8 @@ def process_api(
       np.asarray(response_body["vital_signs"]["respiratory_waveform"]["confidence"]),
     ], axis=0)
     live_ds = np.asarray(response_body["face"]["confidence"])
-    return sig_ds, conf_ds, live_ds
+    idxs = np.asarray(idxs)
+    return sig_ds, conf_ds, live_ds, idxs
   def postprocess(self, sig, fps, type='ppg', filter=True):
     """Apply filters to the estimated signal.
     Args:
diff --git a/vitallens/signal.py b/vitallens/signal.py
index c8924f3..9834790 100644
--- a/vitallens/signal.py
+++ b/vitallens/signal.py
@@ -122,3 +122,32 @@ def windowed_freq(
       freq_vals.shape[0], n)
   # Return
   return freq_vals
+
+def reassemble_from_windows(
+    x: np.ndarray,
+    idxs: np.ndarray
+  ) -> Tuple[np.ndarray, np.ndarray]:
+  """Reassemble windowed data using corresponding idxs.
+
+  Args:
+    x: Data generated using a windowing operation. Shape (n_windows, n, window_size)
+    idxs: Indices of x in the original 1-d array. Shape (n_windows, window_size)
+
+  Returns:
+    out: Reassembled data. Shape (n, n_idxs)
+    idxs: Reassembled idxs. Shape (n_idxs)
+  """
+  x = np.asarray(x)
+  idxs = np.asarray(idxs)
+  # Transpose x to shape (n, n_windows, window_size)
+  x = np.transpose(x, (1, 0, 2))
+  # Offset each window's indices by its window number
+  offset_idxs = idxs - np.arange(idxs.shape[0])[:, np.newaxis]
+  flat_offset_idxs = offset_idxs.flatten()
+  # Keep positions where the offset index equals the running max (earliest window wins in overlaps)
+  max_so_far = np.maximum.accumulate(flat_offset_idxs)
+  mask = (flat_offset_idxs == max_so_far)
+  # Filter data based on mask and extract the final result values
+  result = x.reshape(x.shape[0], -1)[:, mask]
+  idxs = idxs.flatten()[mask]
+  return result, idxs
diff --git a/vitallens/ssd.py b/vitallens/ssd.py
index 86098c2..44f240c 100644
--- a/vitallens/ssd.py
+++ b/vitallens/ssd.py
@@ -280,7 +280,7 @@ def scan_batch(
 
   Args:
     batch: The number of this batch.
-    b_batches: The total number of batches.
+    n_batches: The total number of batches.
     inputs: The video to analyze. Either a np.ndarray of shape (n_frames, h, w, 3)
       with a sequence of frames in unscaled uint8 RGB format, or a path to a video file.
     start: The index of first frame of the video to analyze in this batch.
diff --git a/vitallens/utils.py b/vitallens/utils.py
index f8e4816..2c73343 100644
--- a/vitallens/utils.py
+++ b/vitallens/utils.py
@@ -158,6 +158,7 @@ def parse_video_inputs(
   else:
     ds_factor = max(round(fps / target_fps), 1)
     target_idxs = None if ds_factor == 1 else list(range(video.shape[0])[0::ds_factor])
   if trim is not None:
+    if target_idxs is None: target_idxs = range(video_shape_in[0])
     target_idxs = [idx for idx in target_idxs if trim[0] <= idx < trim[1]]
   if roi is not None or target_size is not None or target_idxs is not None:
     if target_size is None and roi is not None: target_size = (int(roi[3]-roi[1]), int(roi[2]-roi[0]))
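
A note on the request-planning arithmetic in `vitallens/methods/vitallens.py` above: start and end indices are computed in original-frame space so that each request stays within the API frame limit after downsampling, while overlapping its neighbor by roughly `API_OVERLAP` downsampled frames. Below is a self-contained sketch of the same arithmetic; the constant values are illustrative placeholders, not the real values from `vitallens.constants`:

```python
import math

# Illustrative placeholders - the real values live in vitallens.constants
API_MAX_FRAMES = 900  # assumed max downsampled frames per request
API_OVERLAP = 30      # assumed overlap between requests, in downsampled frames

def plan_requests(inputs_n, fps, fps_target):
  """Sketch of the request-window arithmetic from the patch."""
  expected_ds_factor = round(fps / fps_target)
  expected_ds_n = math.ceil(inputs_n / expected_ds_factor)
  if expected_ds_n <= API_MAX_FRAMES:
    return [(0, inputs_n)]  # a single request covers everything
  n_splits = math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
  split_len = math.ceil((inputs_n + (n_splits - 1) * API_OVERLAP * expected_ds_factor) / n_splits)
  start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)]
  end_idxs = [min(start + split_len, inputs_n) for start in start_idxs]
  # Re-derive starts from the clipped ends so the last window keeps full length
  start_idxs = [max(0, end - split_len) for end in end_idxs]
  return list(zip(start_idxs, end_idxs))

print(plan_requests(inputs_n=2000, fps=30., fps_target=30.))
# -> [(0, 687), (657, 1344), (1313, 2000)]: overlapping requests, each within the limit
```

Clipping `end_idxs` to `inputs_n` and then re-deriving `start_idxs` keeps the final window at full length, so the last request is never shorter than the others.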
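The merge performed by the new `reassemble_from_windows` works by subtracting each window's row index from its frame indices and then keeping only positions where that offset index equals the running maximum, so in overlap regions the earliest window's samples survive. A minimal usage sketch, with values mirroring the new unit test:

```python
import numpy as np
from vitallens.signal import reassemble_from_windows

# Two windows over a single signal (n=1), overlapping in original index space
x = np.array([[[2., 4., 6., 8., 10.]],     # window 0 covers indices 1, 3, 5, 7, 9
              [[7., 1., 10., 12., 18.]]])  # window 1 covers indices 5, 6, 9, 11, 13
idxs = np.array([[1, 3, 5, 7, 9],
                 [5, 6, 9, 11, 13]])
out, out_idxs = reassemble_from_windows(x=x, idxs=idxs)
print(out)       # [[ 2.  4.  6.  8. 10. 12. 18.]]
print(out_idxs)  # [ 1  3  5  7  9 11 13]
```

Note that window 1's samples at indices 5, 6 and 9 are all dropped: they fall inside the span already covered by window 0.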
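Finally, the one-line guard added in `vitallens/utils.py` fixes a latent crash in the new trim path: when `ds_factor == 1`, `target_idxs` used to stay `None`, and trimming would try to iterate it. A condensed before/after repro with made-up frame counts:

```python
n_frames, ds_factor, trim = 300, 1, (100, 200)

# Before: ds_factor == 1 left target_idxs as None ...
target_idxs = None if ds_factor == 1 else list(range(n_frames))[0::ds_factor]
# ... so the trim below would raise TypeError: 'NoneType' object is not iterable

# After: the added guard falls back to all frame indices before trimming
if target_idxs is None: target_idxs = range(n_frames)
target_idxs = [idx for idx in target_idxs if trim[0] <= idx < trim[1]]
assert target_idxs == list(range(100, 200))
```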