In multi-request inference for VITALLENS, use a separate ROI per request
prouast committed Jul 21, 2024
1 parent a2a4196 commit 0dc0bde
Showing 5 changed files with 116 additions and 42 deletions.
tests/test_signal.py (15 changes: 14 additions & 1 deletion)

@@ -24,7 +24,7 @@
 import sys
 sys.path.append('../vitallens-python')
 
-from vitallens.signal import windowed_mean, windowed_freq
+from vitallens.signal import windowed_mean, windowed_freq, reassemble_from_windows
 
 def test_windowed_mean():
   x = np.asarray([0., 1., 2., 3., 4., 5., 6.])
@@ -47,4 +47,17 @@ def test_estimate_freq_periodogram(num, freq, window_size):
     windowed_freq(x=y, window_size=window_size, overlap=window_size//2, f_s=len(x), f_range=(max(freq-2,1),freq+2), f_res=0.05),
     np.full((num,), fill_value=freq),
     rtol=1)
+
+def test_reassemble_from_windows():
+  x = np.array([[[2.0, 4.0, 6.0, 8.0, 10.0], [7.0, 1.0, 10.0, 12.0, 18.0]],
+                [[2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0]]], dtype=np.float32).transpose(1, 0, 2)
+  idxs = np.array([[1, 3, 5, 7, 9], [5, 6, 9, 11, 13]], dtype=np.int64)
+  out_x, out_idxs = reassemble_from_windows(x=x, idxs=idxs)
+  np.testing.assert_equal(
+    out_x,
+    np.asarray([[2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 18.0],
+                [2.0, 3.0, 4.0, 5.0, 6.0, 10.0, 11.0]]))
+  np.testing.assert_equal(
+    out_idxs,
+    np.asarray([1, 3, 5, 7, 9, 11, 13]))

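For orientation, a quick sketch of the fixture's layout, assuming only numpy: reassemble_from_windows takes x with shape (n_windows, n, window_size), and the literal above is grouped per signal, so the transpose rearranges it into per-window form.

import numpy as np

# Same fixture as the test: grouped per signal in the literal, then
# transposed to (n_windows=2, n_signals=2, window_size=5)
x = np.array([[[2.0, 4.0, 6.0, 8.0, 10.0], [7.0, 1.0, 10.0, 12.0, 18.0]],
              [[2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0]]],
             dtype=np.float32).transpose(1, 0, 2)
print(x.shape)  # (2, 2, 5)
print(x[0])     # window 0, both signals: [[2. 4. 6. 8. 10.], [2. 3. 4. 5. 6.]]
print(x[1])     # window 1, both signals: [[7. 1. 10. 12. 18.], [7. 8. 9. 10. 11.]]
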
vitallens/methods/vitallens.py (111 changes: 71 additions & 40 deletions)
@@ -36,6 +36,7 @@
 from vitallens.methods.rppg_method import RPPGMethod
 from vitallens.signal import detrend_lambda_for_hr_response, detrend_lambda_for_rr_response
 from vitallens.signal import moving_average_size_for_hr_response, moving_average_size_for_rr_response
+from vitallens.signal import reassemble_from_windows
 from vitallens.utils import probe_video_inputs, parse_video_inputs
 
 class VitalLensRPPGMethod(RPPGMethod):
@@ -72,49 +73,43 @@ def __call__(
       live: The face live confidence. Shape (1, n_frames)
     """
     inputs_shape, fps = probe_video_inputs(video=frames, fps=fps)
-    # Choose representative face detection
-    # TODO: For longer videos extract chunks from separate locations?
-    face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))]
-    roi = get_roi_from_det(
-      face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1]))
-    if np.any(np.logical_or(
-        (faces[:,2] - faces[:,0]) * 0.5 < np.maximum(0, faces[:,0] - roi[0]) + np.maximum(0, faces[:,2] - roi[2]),
-        (faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))):
-      logging.warn("Large face movement detected")
-    # Parse the inputs
-    logging.debug("Preparing video for inference...")
-    frames_ds, fps, inputs_shape, ds_factor, _ = parse_video_inputs(
-      video=frames, fps=fps, target_size=self.input_size, roi=roi,
-      target_fps=override_fps_target if override_fps_target is not None else self.fps_target,
-      library='prpy', scale_algorithm='bilinear')
-    # Check the number of frames to be processed
-    ds_len = frames_ds.shape[0]
-    if ds_len <= API_MAX_FRAMES:
+    inputs_n = inputs_shape[0]
+    fps_target = override_fps_target if override_fps_target is not None else self.fps_target
+    expected_ds_factor = round(fps / fps_target)
+    expected_ds_n = math.ceil(inputs_n / expected_ds_factor)
+    if expected_ds_n <= API_MAX_FRAMES:
       # API supports up to MAX_FRAMES at once - process all frames
-      sig_ds, conf_ds, live_ds = self.process_api(frames_ds)
+      sig_ds, conf_ds, live_ds, idxs = self.process_api_batch(
+        batch=1, n_batches=1, inputs=frames, inputs_shape=inputs_shape,
+        faces=faces, fps_target=fps_target, fps=fps)
     else:
       # Longer videos are split up with small overlaps
-      ds_len = frames_ds.shape[0]
-      n_splits = math.ceil((ds_len - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
-      split_len = math.ceil((ds_len + (n_splits-1) * API_OVERLAP) / n_splits)
-      start_idxs = [i for i in range(0, ds_len - n_splits * API_OVERLAP, split_len - API_OVERLAP)]
-      end_idxs = [min(i + split_len, ds_len) for i in start_idxs]
-      logging.info("Running inference for {} frames using {} requests...".format(ds_len, n_splits))
+      n_splits = math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
+      split_len = math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits)
+      # start_idxs = [i for i in range(0, expected_ds_len - n_splits * API_OVERLAP, split_len - API_OVERLAP)]
+      start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)]
+      end_idxs = [min(start + split_len, inputs_n) for start in start_idxs]
+      start_idxs = [max(0, end - split_len) for end in end_idxs]
+      logging.info("Running inference for {} frames using {} requests...".format(expected_ds_n, n_splits))
       # Process the splits in parallel
       with concurrent.futures.ThreadPoolExecutor() as executor:
-        results = list(executor.map(lambda i: self.process_api(frames_ds[start_idxs[i]:end_idxs[i]]), range(n_splits)))
+        results = list(executor.map(lambda i: self.process_api_batch(
+          batch=i, n_batches=n_splits, inputs=frames, inputs_shape=inputs_shape,
+          faces=faces, fps_target=fps_target, start=start_idxs[i], end=end_idxs[i],
+          fps=fps), range(n_splits)))
       # Aggregate the results
-      sig_results, conf_results, live_results = zip(*results)
-      sig_ds = np.concatenate([sig_results[0]] + [[x[API_OVERLAP:] for x in e] for e in sig_results[1:]], axis=-1)
-      conf_ds = np.concatenate([conf_results[0]] + [[x[API_OVERLAP:] for x in e] for e in conf_results[1:]], axis=-1)
-      live_ds = np.concatenate([live_results[0]] + [x[API_OVERLAP:] for x in live_results[1:]], axis=-1)
+      sig_results, conf_results, live_results, idxs_results = zip(*results)
+      sig_ds, idxs = reassemble_from_windows(x=sig_results, idxs=idxs_results)
+      conf_ds, _ = reassemble_from_windows(x=conf_results, idxs=idxs_results)
+      live_ds = reassemble_from_windows(x=np.asarray(live_results)[:,np.newaxis], idxs=idxs_results)[0][0]
     # Interpolate to original sampling rate (n_frames,)
     sig = interpolate_cubic_spline(
-      x=np.arange(inputs_shape[0])[0::ds_factor], y=sig_ds, xs=np.arange(inputs_shape[0]), axis=1)
+      x=idxs, y=sig_ds, xs=np.arange(inputs_n), axis=1)
     conf = interpolate_cubic_spline(
-      x=np.arange(inputs_shape[0])[0::ds_factor], y=conf_ds, xs=np.arange(inputs_shape[0]), axis=1)
+      x=idxs, y=conf_ds, xs=np.arange(inputs_n), axis=1)
     live = interpolate_cubic_spline(
-      x=np.arange(inputs_shape[0])[0::ds_factor], y=live_ds, xs=np.arange(inputs_shape[0]), axis=0)
+      x=idxs, y=live_ds, xs=np.arange(inputs_n), axis=0)
     # Filter (n_frames,)
     sig = np.asarray([self.postprocess(p, fps, type=name) for p, name in zip(sig, ['ppg', 'resp'])])
     # Estimate summary vitals
@@ -153,23 +148,58 @@ def __call__(
       out_conf[name] = conf[1]
       out_note[name] = 'Estimate of the respiratory waveform using VitalLens, along with frame-wise confidences between 0 and 1.'
     return out_data, out_unit, out_conf, out_note, live
-  def process_api(
+  def process_api_batch(
       self,
-      frames: np.ndarray,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Process frames with the VitalLens API.
+      batch: int,
+      n_batches: int,
+      inputs: Tuple[np.ndarray, str],
+      inputs_shape: tuple,
+      faces: np.ndarray,
+      fps_target: float,
+      start: int = None,
+      end: int = None,
+      fps: float = None
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    """Process a batch of frames with the VitalLens API.
     Args:
-      frames: The video frames. Shape (n_frames<=MAX_FRAMES, h==input_size, w==input_size, 3)
+      batch: The number of this batch.
+      n_batches: The total number of batches.
+      inputs: The video to analyze. Either a np.ndarray of shape (n_frames, h, w, 3)
+        with a sequence of frames in unscaled uint8 RGB format, or a path to a video file.
+      inputs_shape: The shape of the inputs.
+      faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1)
+      fps_target: The target frame rate at which to run inference.
+      start: The index of the first frame of the video to analyze in this batch.
+      end: The index of the last frame of the video to analyze in this batch.
+      fps: The frame rate of the input video. Required if type(inputs) == np.ndarray
     Returns:
       sig: Estimated signals. Shape (n_sig, n_frames)
       conf: Estimation confidences. Shape (n_sig, n_frames)
       live: Liveness estimation. Shape (n_frames,)
+      idxs: Indices in inputs that were processed. Shape (n_frames,)
     """
-    assert frames.shape[0] <= API_MAX_FRAMES
+    logging.debug("Batch {}/{}...".format(batch, n_batches))
+    # Trim face detections to batch if necessary
+    if start is not None and end is not None:
+      faces = faces[start:end]
+    # Choose representative face detection
+    face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))]
+    roi = get_roi_from_det(
+      face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1]))
+    if np.any(np.logical_or(
+        (faces[:,2] - faces[:,0]) * 0.5 < np.maximum(0, faces[:,0] - roi[0]) + np.maximum(0, faces[:,2] - roi[2]),
+        (faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))):
+      logging.warn("Large face movement detected")
+    # Parse the inputs
+    frames_ds, fps, inputs_shape, _, idxs = parse_video_inputs(
+      video=inputs, fps=fps, target_size=self.input_size, roi=roi, target_fps=fps_target,
+      trim=(start, end) if start is not None and end is not None else None,
+      library='prpy', scale_algorithm='bilinear')
+    assert frames_ds.shape[0] <= API_MAX_FRAMES
     # Prepare API header and payload
     headers = {"x-api-key": self.api_key}
-    payload = {"video": base64.b64encode(frames.tobytes()).decode('utf-8')}
+    payload = {"video": base64.b64encode(frames_ds.tobytes()).decode('utf-8')}
     # Ask API to process video
     response = requests.post(API_URL, headers=headers, json=payload)
     response_body = json.loads(response.text)
@@ -194,7 +224,8 @@ def process_api(
         np.asarray(response_body["vital_signs"]["respiratory_waveform"]["confidence"]),
       ], axis=0)
     live_ds = np.asarray(response_body["face"]["confidence"])
-    return sig_ds, conf_ds, live_ds
+    idxs = np.asarray(idxs)
+    return sig_ds, conf_ds, live_ds, idxs
   def postprocess(self, sig, fps, type='ppg', filter=True):
     """Apply filters to the estimated signal.
     Args:
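For intuition about the new splitting arithmetic in __call__, here is a standalone sketch. The constant values below are illustrative assumptions, not necessarily those defined in the library's constants module:

import math

API_MAX_FRAMES = 900  # assumed for illustration
API_OVERLAP = 30      # assumed for illustration

def plan_batches(inputs_n, fps, fps_target):
  # Mirrors the splitting arithmetic above (sketch)
  expected_ds_factor = round(fps / fps_target)
  expected_ds_n = math.ceil(inputs_n / expected_ds_factor)
  if expected_ds_n <= API_MAX_FRAMES:
    return [(0, inputs_n)]
  n_splits = math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
  split_len = math.ceil((inputs_n + (n_splits - 1) * API_OVERLAP * expected_ds_factor) / n_splits)
  start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)]
  end_idxs = [min(start + split_len, inputs_n) for start in start_idxs]
  # Re-derive starts from the clamped ends so each batch keeps its full length
  start_idxs = [max(0, end - split_len) for end in end_idxs]
  return list(zip(start_idxs, end_idxs))

batches = plan_batches(inputs_n=60000, fps=30.0, fps_target=30.0)
print(len(batches), batches[0], batches[1], batches[-1])
# 69 (0, 900) (870, 1770) (59100, 60000)

Each batch now covers a contiguous frame range of the original video, so process_api_batch can pick its own representative face detection and ROI per request, which is the point of this commit.
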
vitallens/signal.py (29 changes: 29 additions & 0 deletions)
@@ -122,3 +122,32 @@ def windowed_freq(
       freq_vals.shape[0], n)
   # Return
   return freq_vals
+
+def reassemble_from_windows(
+    x: np.ndarray,
+    idxs: np.ndarray
+  ) -> Tuple[np.ndarray, np.ndarray]:
+  """Reassemble windowed data using corresponding idxs.
+  Args:
+    x: Data generated using a windowing operation. Shape (n_windows, n, window_size)
+    idxs: Indices of x in the original 1-d array. Shape (n_windows, window_size)
+  Returns:
+    out: Reassembled data. Shape (n, n_idxs)
+    idxs: Reassembled idxs. Shape (n_idxs)
+  """
+  x = np.asarray(x)
+  idxs = np.asarray(idxs)
+  # Transpose x to shape (n, n_windows, window_size)
+  x = np.transpose(x, (1, 0, 2))
+  # Adjust indices based on their window position
+  offset_idxs = idxs - np.arange(idxs.shape[0])[:, np.newaxis]
+  # Find strictly increasing indices using np.maximum.accumulate
+  flat_offset_idxs = offset_idxs.flatten()
+  max_so_far = np.maximum.accumulate(flat_offset_idxs)
+  mask = (flat_offset_idxs == max_so_far)  # Mask to keep only strictly increasing indices
+  # Filter data based on mask and extract the final result values
+  result = x.reshape(x.shape[0], -1)[:, mask]
+  idxs = idxs.flatten()[mask]
+  return result, idxs
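To see why the offset trick works, here is a sketch (plain numpy, not part of the library) tracing the indices from the test above:

import numpy as np

# Two overlapping windows; window 1 re-covers ground window 0 already sampled
idxs = np.array([[1, 3, 5, 7, 9], [5, 6, 9, 11, 13]], dtype=np.int64)

# Subtract each window's position: indices inside the span already covered
# by earlier windows are no longer strictly increasing and get masked out
offset_idxs = idxs - np.arange(idxs.shape[0])[:, np.newaxis]
print(offset_idxs)  # [[ 1  3  5  7  9]
                    #  [ 4  5  8 10 12]]

flat = offset_idxs.flatten()
mask = flat == np.maximum.accumulate(flat)
print(mask)                  # [ True  True  True  True  True False False False  True  True]
print(idxs.flatten()[mask])  # [ 1  3  5  7  9 11 13]

So each original index survives exactly once, taken from the earliest window that produced it, which matches the expected output of test_reassemble_from_windows.
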
vitallens/ssd.py (2 changes: 1 addition & 1 deletion)
@@ -280,7 +280,7 @@ def scan_batch(
   Args:
     batch: The number of this batch.
-    b_batches: The total number of batches.
+    n_batches: The total number of batches.
     inputs: The video to analyze. Either a np.ndarray of shape (n_frames, h, w, 3)
       with a sequence of frames in unscaled uint8 RGB format, or a path to a video file.
     start: The index of first frame of the video to analyze in this batch.
vitallens/utils.py (1 change: 1 addition & 0 deletions)
@@ -158,6 +158,7 @@ def parse_video_inputs(
     else: ds_factor = max(round(fps / target_fps), 1)
     target_idxs = None if ds_factor == 1 else list(range(video.shape[0])[0::ds_factor])
     if trim is not None:
+      if target_idxs is None: target_idxs = range(video_shape_in[0])
       target_idxs = [idx for idx in target_idxs if trim[0] <= idx < trim[1]]
     if roi is not None or target_size is not None or target_idxs is not None:
       if target_size is None and roi is not None: target_size = (int(roi[3]-roi[1]), int(roi[2]-roi[0]))
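This one-line guard matters for the trimmed batches introduced above: with ds_factor == 1 (no downsampling), target_idxs used to stay None, so the trim comprehension would iterate over None and raise a TypeError. A minimal sketch of the guarded logic, with hypothetical values:

n_frames_in = 100   # stands in for video_shape_in[0]
ds_factor = 1       # fps already matches target_fps, so no downsampling
trim = (30, 60)     # this batch should only see frames 30..59

target_idxs = None if ds_factor == 1 else list(range(n_frames_in))[0::ds_factor]
if trim is not None:
  # The added guard: materialize all frame indices before trimming
  if target_idxs is None: target_idxs = range(n_frames_in)
  target_idxs = [idx for idx in target_idxs if trim[0] <= idx < trim[1]]

print(target_idxs[0], target_idxs[-1], len(target_idxs))  # 30 59 30
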
