From c3d96087aaf475f8580854efaf934b644eb62f91 Mon Sep 17 00:00:00 2001 From: Philipp Rouast Date: Tue, 23 Jul 2024 14:45:20 +0200 Subject: [PATCH] Either parse video globally or not depending on circumstances; Add robustness for videos with issues which result in unexpected number of frames; Allow overwrite of global parse setting; Return more informative errors --- pyproject.toml | 2 +- tests/conftest.py | 2 +- tests/test_ssd.py | 1 - tests/test_utils.py | 5 +- tests/test_vitallens.py | 11 ++- vitallens/client.py | 9 +- vitallens/constants.py | 3 + vitallens/methods/simple_rppg_method.py | 4 +- vitallens/methods/vitallens.py | 107 +++++++++++++++--------- vitallens/utils.py | 32 +++---- 10 files changed, 111 insertions(+), 65 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ff2e71..9fba0ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "importlib_resources", "numpy", "onnxruntime", - "prpy[ffmpeg,numpy_min]>=0.2.8", + "prpy[ffmpeg,numpy_min]==0.2.10", "python-dotenv", "pyyaml", "requests", diff --git a/tests/conftest.py b/tests/conftest.py index cee2760..766da65 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,7 +52,7 @@ def test_video_fps(): @pytest.fixture(scope='session') def test_video_shape(): - _, n, w, h, _, _, _ = probe_video(TEST_VIDEO_PATH) + _, n, w, h, *_ = probe_video(TEST_VIDEO_PATH) return (n, h, w, 3) @pytest.fixture(scope='session') diff --git a/tests/test_ssd.py b/tests/test_ssd.py index bccfc47..c5d1468 100644 --- a/tests/test_ssd.py +++ b/tests/test_ssd.py @@ -19,7 +19,6 @@ # SOFTWARE. import numpy as np -from prpy.ffmpeg.probe import probe_video import pytest import sys diff --git a/tests/test_utils.py b/tests/test_utils.py index 2bd5db0..01974e1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -37,13 +37,14 @@ def test_load_config(method): def test_probe_video_inputs(request, file): if file: test_video_path = request.getfixturevalue('test_video_path') - video_shape, fps = probe_video_inputs(test_video_path) + video_shape, fps, i = probe_video_inputs(test_video_path) else: test_video_ndarray = request.getfixturevalue('test_video_ndarray') test_video_fps = request.getfixturevalue('test_video_fps') - video_shape, fps = probe_video_inputs(test_video_ndarray, fps=test_video_fps) + video_shape, fps, i = probe_video_inputs(test_video_ndarray, fps=test_video_fps) assert video_shape == (360, 480, 768, 3) assert fps == 30 + assert i == False def test_probe_video_inputs_no_file(): with pytest.raises(Exception): diff --git a/tests/test_vitallens.py b/tests/test_vitallens.py index 33dce8c..36d6f43 100644 --- a/tests/test_vitallens.py +++ b/tests/test_vitallens.py @@ -94,8 +94,9 @@ def create_mock_api_response( @pytest.mark.parametrize("file", [True, False]) @pytest.mark.parametrize("override_fps_target", [None, 15, 10]) @pytest.mark.parametrize("long", [False, True]) +@pytest.mark.parametrize("override_global_parse", [False, True]) @patch('requests.post', side_effect=create_mock_api_response) -def test_VitalLensRPPGMethod_mock(mock_post, request, file, override_fps_target, long): +def test_VitalLensRPPGMethod_mock(mock_post, request, file, long, override_fps_target, override_global_parse): if long and file: pytest.skip("Skip because parameter combination does not work") config = load_config("vitallens.yaml") @@ -108,15 +109,17 @@ def test_VitalLensRPPGMethod_mock(mock_post, request, file, override_fps_target, if file: data, unit, conf, note, live = method( frames=test_video_path, faces=test_video_faces, - override_fps_target=override_fps_target) - else: + override_fps_target=override_fps_target, + override_global_parse=override_global_parse) + else: if long: n_repeats = (API_MAX_FRAMES * 3) // test_video_ndarray.shape[0] + 1 test_video_ndarray = np.repeat(test_video_ndarray, repeats=n_repeats, axis=0) test_video_faces = np.repeat(test_video_faces, repeats=n_repeats, axis=0) data, unit, conf, note, live = method( frames=test_video_ndarray, faces=test_video_faces, - fps=test_video_fps, override_fps_target=override_fps_target) + fps=test_video_fps, override_fps_target=override_fps_target, + override_global_parse=override_global_parse) assert all(key in data for key in method.signals) assert all(key in unit for key in method.signals) assert all(key in conf for key in method.signals) diff --git a/vitallens/client.py b/vitallens/client.py index 207f0f9..ab1746f 100644 --- a/vitallens/client.py +++ b/vitallens/client.py @@ -107,6 +107,7 @@ def __call__( faces: Union[np.ndarray, list] = None, fps: float = None, override_fps_target: float = None, + override_global_parse: bool = None, export_filename: str = None ) -> list: """Run rPPG inference. @@ -124,6 +125,8 @@ def __call__( fps: Sampling frequency of the input video. Required if type(video) == np.ndarray. override_fps_target: Target fps at which rPPG inference should be run (optional). If not provided, will use default of the selected method. + override_global_parse: If True, always use global parse. If False, don't use global parse. + If None, choose based on video. export_filename: Filename for json export if applicable. Returns: result: Analysis results as a list of faces in the following format: @@ -169,7 +172,7 @@ def __call__( ] """ # Probe inputs - inputs_shape, fps = probe_video_inputs(video=video, fps=fps) + inputs_shape, fps, _ = probe_video_inputs(video=video, fps=fps) # TODO: Optimize performance of simple rPPG methods for long videos # Warning if using long video target_fps = override_fps_target if override_fps_target is not None else self.rppg.fps_target @@ -194,7 +197,9 @@ def __call__( for face in faces: # Run selected rPPG method data, unit, conf, note, live = self.rppg( - frames=video, faces=face, fps=fps, override_fps_target=override_fps_target) + frames=video, faces=face, fps=fps, + override_fps_target=override_fps_target, + override_global_parse=override_global_parse) # Parse face results face_result = {'face': { 'coordinates': face, diff --git a/vitallens/constants.py b/vitallens/constants.py index ec3fe38..5dfff40 100644 --- a/vitallens/constants.py +++ b/vitallens/constants.py @@ -38,5 +38,8 @@ if 'API_URL' in os.environ: API_URL = os.getenv('API_URL') +# Video error message +VIDEO_PARSE_ERROR = "Unable to parse input video. There may be an issue with the video file." + # Disclaimer message DISCLAIMER = "The provided values are estimates and should be interpreted according to the provided confidence levels ranging from 0 to 1. The VitalLens API is not a medical device and its estimates are not intended for any medical purposes." diff --git a/vitallens/methods/simple_rppg_method.py b/vitallens/methods/simple_rppg_method.py index f3cf8f9..a763410 100644 --- a/vitallens/methods/simple_rppg_method.py +++ b/vitallens/methods/simple_rppg_method.py @@ -57,7 +57,8 @@ def __call__( frames: Union[np.ndarray, str], faces: np.ndarray, fps: float, - override_fps_target: float = None + override_fps_target: float = None, + override_global_parse: float = None, ) -> Tuple[dict, dict, dict, dict, np.ndarray]: """Estimate pulse signal from video frames using the subclass algorithm. @@ -66,6 +67,7 @@ def __call__( faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1) fps: The rate at which video was sampled. override_fps_target: Override the method's default inference fps (optional). + override_global_parse: Has no effect here. Returns: data: A dictionary with the values of the estimated vital signs. unit: A dictionary with the units of the estimated vital signs. diff --git a/vitallens/methods/vitallens.py b/vitallens/methods/vitallens.py index d4035d3..a85ca54 100644 --- a/vitallens/methods/vitallens.py +++ b/vitallens/methods/vitallens.py @@ -26,6 +26,7 @@ from prpy.numpy.face import get_roi_from_det from prpy.numpy.signal import detrend, moving_average, standardize from prpy.numpy.signal import interpolate_cubic_spline, estimate_freq +from prpy.numpy.utils import enough_memory_for_ndarray import json import logging import requests @@ -38,7 +39,7 @@ from vitallens.signal import detrend_lambda_for_hr_response, detrend_lambda_for_rr_response from vitallens.signal import moving_average_size_for_hr_response, moving_average_size_for_rr_response from vitallens.signal import reassemble_from_windows -from vitallens.utils import probe_video_inputs, parse_video_inputs +from vitallens.utils import probe_video_inputs, parse_video_inputs, check_faces_in_roi class VitalLensRPPGMethod(RPPGMethod): def __init__( @@ -56,7 +57,8 @@ def __call__( frames: Union[np.ndarray, str], faces: np.ndarray, fps: float = None, - override_fps_target: float = None + override_fps_target: float = None, + override_global_parse: bool = None ) -> Tuple[dict, dict, dict, dict, np.ndarray]: """Estimate vitals from video frames using the VitalLens API. @@ -66,6 +68,8 @@ def __call__( faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1) fps: The rate at which video was sampled. override_fps_target: Override the method's default inference fps (optional). + override_global_parse: If True, always use global parse. If False, don't use global parse. + If None, choose based on video. Returns: out_data: The estimated data/value for each signal. out_unit: The estimation unit for each signal. @@ -73,37 +77,45 @@ def __call__( out_note: An explanatory note for each signal. live: The face live confidence. Shape (1, n_frames) """ - inputs_shape, fps = probe_video_inputs(video=frames, fps=fps) + inputs_shape, fps, video_issues = probe_video_inputs(video=frames, fps=fps) + video_fits_in_memory = enough_memory_for_ndarray( + shape=(inputs_shape[0], self.input_size, self.input_size, 3), dtype=np.uint8) # Check the number of frames to be processed inputs_n = inputs_shape[0] fps_target = override_fps_target if override_fps_target is not None else self.fps_target expected_ds_factor = round(fps / fps_target) expected_ds_n = math.ceil(inputs_n / expected_ds_factor) - if expected_ds_n <= API_MAX_FRAMES: - # API supports up to MAX_FRAMES at once - process all frames - sig_ds, conf_ds, live_ds, idxs = self.process_api_batch( - batch=1, n_batches=1, inputs=frames, inputs_shape=inputs_shape, - faces=faces, fps_target=fps_target, fps=fps) - else: - # Longer videos are split up with small overlaps - n_splits = math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1 - split_len = math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits) - # start_idxs = [i for i in range(0, expected_ds_len - n_splits * API_OVERLAP, split_len - API_OVERLAP)] - start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)] - end_idxs = [min(start + split_len, inputs_n) for start in start_idxs] - start_idxs = [max(0, end - split_len) for end in end_idxs] - logging.info("Running inference for {} frames using {} requests...".format(expected_ds_n, n_splits)) - # Process the splits in parallel - with concurrent.futures.ThreadPoolExecutor() as executor: - results = list(executor.map(lambda i: self.process_api_batch( - batch=i, n_batches=n_splits, inputs=frames, inputs_shape=inputs_shape, - faces=faces, fps_target=fps_target, start=start_idxs[i], end=end_idxs[i], - fps=fps), range(n_splits))) - # Aggregate the results - sig_results, conf_results, live_results, idxs_results = zip(*results) - sig_ds, idxs = reassemble_from_windows(x=sig_results, idxs=idxs_results) - conf_ds, _ = reassemble_from_windows(x=conf_results, idxs=idxs_results) - live_ds = reassemble_from_windows(x=np.asarray(live_results)[:,np.newaxis], idxs=idxs_results)[0][0] + # Check if we can parse the video globally + global_face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))] + global_roi = get_roi_from_det( + global_face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1])) + global_faces_in_roi = check_faces_in_roi(faces=faces, roi=global_roi) + global_parse = isinstance(frames, str) and video_fits_in_memory and (video_issues or global_faces_in_roi) + if override_global_parse is not None: global_parse = override_global_parse + if global_parse: + # Parse entire video for inference globally + frames, _, _, _, idxs = parse_video_inputs( + video=frames, fps=fps, target_size=self.input_size, roi=global_roi, target_fps=fps_target, + library='prpy', scale_algorithm='bilinear', dim_deltas=(API_OVERLAP, 0, 0)) + # Longer videos are split up with small overlaps + n_splits = 1 if expected_ds_n <= API_MAX_FRAMES else math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1 + split_len = expected_ds_n if n_splits == 1 else math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits) + start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)] + end_idxs = [min(start + split_len, inputs_n) for start in start_idxs] + start_idxs = [max(0, end - split_len) for end in end_idxs] + logging.info("Running inference for {} frames using {} request(s)...".format(expected_ds_n, n_splits)) + # Process the splits in parallel + with concurrent.futures.ThreadPoolExecutor() as executor: + results = list(executor.map(lambda i: self.process_api_batch( + batch=i, n_batches=n_splits, inputs=frames, inputs_shape=inputs_shape, + faces=faces, fps_target=fps_target, fps=fps, global_parse=global_parse, + start=None if n_splits == 1 else start_idxs[i], + end=None if n_splits == 1 else end_idxs[i]), range(n_splits))) + # Aggregate the results + sig_results, conf_results, live_results, idxs_results = zip(*results) + sig_ds, idxs = reassemble_from_windows(x=sig_results, idxs=idxs_results) + conf_ds, _ = reassemble_from_windows(x=conf_results, idxs=idxs_results) + live_ds = reassemble_from_windows(x=np.asarray(live_results)[:,np.newaxis], idxs=idxs_results)[0][0] # Interpolate to original sampling rate (n_frames,) sig = interpolate_cubic_spline( x=idxs, y=sig_ds, xs=np.arange(inputs_n), axis=1) @@ -159,7 +171,8 @@ def process_api_batch( fps_target: float, start: int = None, end: int = None, - fps: float = None + fps: float = None, + global_parse: bool = False ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Process a batch of frames with the VitalLens API. @@ -168,12 +181,13 @@ def process_api_batch( n_batches: The total number of batches. inputs: The video to analyze. Either a np.ndarray of shape (n_frames, h, w, 3) with a sequence of frames in unscaled uint8 RGB format, or a path to a video file. - inputs_shape: The shape of the inputs. + inputs_shape: The original shape of the inputs. faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1) fps_target: The target frame rate at which to run inference. start: The index of first frame of the video to analyze in this batch. end: The index of the last frame of the video to analyze in this batch. fps: The frame rate of the input video. Required if type(video) == np.ndarray + global_parse: Flag that indicates whether video has already been parsed. Returns: sig: Estimated signals. Shape (n_sig, n_frames) conf: Estimation confidences. Shape (n_sig, n_frames) @@ -188,16 +202,31 @@ def process_api_batch( face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))] roi = get_roi_from_det( face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1])) - if np.any(np.logical_or( - (faces[:,2] - faces[:,0]) * 0.5 < np.maximum(0, faces[:,0] - roi[0]) + np.maximum(0, faces[:,2] - roi[2]), - (faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))): + if not check_faces_in_roi(faces=faces, roi=roi): logging.warning("Large face movement detected") - # Parse the inputs - frames_ds, fps, inputs_shape, _, idxs = parse_video_inputs( - video=inputs, fps=fps, target_size=self.input_size, roi=roi, target_fps=fps_target, - trim=(start, end) if start is not None and end is not None else None, - library='prpy', scale_algorithm='bilinear') - assert frames_ds.shape[0] <= API_MAX_FRAMES + if global_parse: + # Inputs have already been parsed globally. + assert isinstance(inputs, np.ndarray) + frames_ds = inputs + ds_factor = math.ceil(inputs_shape[0] / frames_ds.shape[0]) + # Trim frames to batch if necessary + if start is not None and end is not None: + start_ds = start // ds_factor + end_ds = math.ceil((end-start)/ds_factor) + start_ds + frames_ds = frames_ds[start_ds:end_ds] + idxs = list(range(start, end, ds_factor)) + else: + idxs = list(range(0, inputs_shape[0], ds_factor)) + else: + # Inputs have not been parsed globally. Parse the inputs + frames_ds, _, _, ds_factor, idxs = parse_video_inputs( + video=inputs, fps=fps, target_size=self.input_size, roi=roi, target_fps=fps_target, + trim=(start, end) if start is not None and end is not None else None, + library='prpy', scale_algorithm='bilinear', dim_deltas=(API_OVERLAP, 0, 0)) + # Make sure we have the correct number of frames + expected_n = math.ceil(((end-start) if start is not None and end is not None else inputs_shape[0]) / ds_factor) + if frames_ds.shape[0] != expected_n or len(idxs) != expected_n: + raise ValueError("Unexpected number of frames returned. Try to set `override_global_parse` to `True` or `False`.") # Prepare API header and payload headers = {"x-api-key": self.api_key} payload = {"video": base64.b64encode(frames_ds.tobytes()).decode('utf-8')} diff --git a/vitallens/utils.py b/vitallens/utils.py index 30f9880..7818828 100644 --- a/vitallens/utils.py +++ b/vitallens/utils.py @@ -30,7 +30,7 @@ import urllib.request import yaml -from vitallens.constants import API_MIN_FRAMES +from vitallens.constants import API_MIN_FRAMES, VIDEO_PARSE_ERROR def load_config(filename: str) -> dict: """Load a yaml config file. @@ -72,6 +72,7 @@ def probe_video_inputs( Returns: video_shape: The shape of the input video as (n_frames, h, w, 3) fps: Sampling frequency of the input video. + issues: True if a possible issue with the video has been detected. """ # Check that fps is correct type if not (fps is None or isinstance(fps, (int, float))): @@ -80,11 +81,11 @@ def probe_video_inputs( if isinstance(video, str): if os.path.isfile(video): try: - fps_, n, w_, h_, _, _, r = probe_video(video) + fps_, n, w_, h_, _, _, r, i = probe_video(video) if fps is None: fps = fps_ if abs(r) == 90: h = w_; w = h_ else: h = h_; w = w_ - return (n, h, w, 3), fps + return (n, h, w, 3), fps, i except Exception as e: raise ValueError("Problem probing video at {}: {}".format(video, e)) else: @@ -96,7 +97,7 @@ def probe_video_inputs( raise ValueError("video.dtype should be uint8, but got {}".format(video.dtype)) if len(video.shape) != 4 or video.shape[0] < API_MIN_FRAMES or video.shape[3] != 3: raise ValueError("video should have shape (n_frames [>= {}], h, w, 3), but found {}".format(API_MIN_FRAMES, video.shape)) - return video.shape, fps + return video.shape, fps, False else: raise ValueError("Invalid video {}, type {}".format(video, type(input))) @@ -135,15 +136,18 @@ def parse_video_inputs( if isinstance(video, str): if os.path.isfile(video): try: - fps_, n, w_, h_, _, _, r = probe_video(video) + fps_, n, w_, h_, _, _, r, i = probe_video(video) if fps is None: fps = fps_ if roi is not None: roi = (int(roi[0]), int(roi[1]), int(roi[2]-roi[0]), int(roi[3]-roi[1])) if isinstance(target_size, tuple): target_size = (target_size[1], target_size[0]) if abs(r) == 90: h = w_; w = h_ else: h = h_; w = w_ - video, ds_factor = read_video_from_path( - path=video, target_fps=target_fps, crop=roi, scale=target_size, trim=trim, - pix_fmt='rgb24', dim_deltas=dim_deltas, scale_algorithm=scale_algorithm) + try: + video, ds_factor = read_video_from_path( + path=video, target_fps=target_fps, crop=roi, scale=target_size, trim=trim, + pix_fmt='rgb24', dim_deltas=dim_deltas, scale_algorithm=scale_algorithm) + except: + ValueError(VIDEO_PARSE_ERROR) expected_n = math.ceil(((trim[1]-trim[0]) if trim is not None else n) / ds_factor) if video.shape[0] < expected_n: logging.warning("Less frames received than expected (delta = {}) - this may indicate an issue with the video file. Padding to avoid issues.".format(video.shape[0]-expected_n)) @@ -155,7 +159,7 @@ def parse_video_inputs( end_idx = min(n, trim[1]) if trim is not None else n idxs = list(range(start_idx, end_idx, ds_factor)) if video.shape[0] != expected_n or len(idxs) != expected_n: - raise ValueError("Unable to parse input video. Possible issue with video file.") + raise ValueError(VIDEO_PARSE_ERROR) return video, fps, (n, h, w, 3), ds_factor, idxs except Exception as e: raise ValueError("Problem reading video from {}: {}".format(video, e)) @@ -248,7 +252,7 @@ def check_faces( def check_faces_in_roi( faces: np.ndarray, - roi: np.ndarray, + roi: Union[np.ndarray, tuple, list], percentage_required_inside_roi: tuple = (0.5, 0.5) ) -> bool: """Check whether all faces are sufficiently inside the ROI. @@ -264,10 +268,10 @@ def check_faces_in_roi( faces_w = faces[:,2] - faces[:,0] faces_h = faces[:,3] - faces[:,1] faces_inside_roi = np.logical_and( - np.logical_and(faces[:,2] - roi[0] > percentage_required_inside_roi * faces_w, - roi[2] - faces[:,0] > percentage_required_inside_roi * faces_w), - np.logical_and(faces[:,3] - roi[1] > percentage_required_inside_roi * faces_h, - roi[3] - faces[:,1] > percentage_required_inside_roi * faces_h)) + np.logical_and(faces[:,2] - roi[0] > percentage_required_inside_roi[0] * faces_w, + roi[2] - faces[:,0] > percentage_required_inside_roi[0] * faces_w), + np.logical_and(faces[:,3] - roi[1] > percentage_required_inside_roi[1] * faces_h, + roi[3] - faces[:,1] > percentage_required_inside_roi[1] * faces_h)) facess_inside_roi = np.all(faces_inside_roi) return facess_inside_roi