diff --git a/vitallens/methods/chrom.py b/vitallens/methods/chrom.py index 308b26e..0d0d265 100644 --- a/vitallens/methods/chrom.py +++ b/vitallens/methods/chrom.py @@ -28,16 +28,22 @@ from vitallens.signal import detrend_lambda_for_hr_response class CHROMRPPGMethod(SimpleRPPGMethod): + """The CHROM algorithm by De Haan and Jeanne (2013)""" def __init__( self, config: dict ): + """Initialize the `CHROMRPPGMethod` + + Args: + config: The configuration dict + """ super(CHROMRPPGMethod, self).__init__(config=config) def algorithm( self, rgb: np.ndarray, fps: float - ): + ) -> np.ndarray: """Use CHROM algorithm to estimate pulse from rgb signal. Args: @@ -77,7 +83,7 @@ def pulse_filter( self, sig: np.ndarray, fps: float - ): + ) -> np.ndarray: """Apply filters to the estimated pulse signal. Args: diff --git a/vitallens/methods/g.py b/vitallens/methods/g.py index 35242f5..74913a3 100644 --- a/vitallens/methods/g.py +++ b/vitallens/methods/g.py @@ -26,10 +26,16 @@ from vitallens.signal import moving_average_size_for_hr_response class GRPPGMethod(SimpleRPPGMethod): + """The G algorithm by Verkruysse (2008)""" def __init__( self, config: dict ): + """Initialize the `GRPPGMethod` + + Args: + config: The configuration dict + """ super(GRPPGMethod, self).__init__(config=config) def algorithm( self, diff --git a/vitallens/methods/pos.py b/vitallens/methods/pos.py index 14ed236..38f0bfe 100644 --- a/vitallens/methods/pos.py +++ b/vitallens/methods/pos.py @@ -28,10 +28,16 @@ from vitallens.signal import moving_average_size_for_hr_response class POSRPPGMethod(SimpleRPPGMethod): + """The POS algorithm by Wang et al. (2017)""" def __init__( self, config: dict ): + """Initialize the `POSRPPGMethod` + + Args: + config: The configuration dict + """ super(POSRPPGMethod, self).__init__(config=config) def algorithm( self, diff --git a/vitallens/methods/rppg_method.py b/vitallens/methods/rppg_method.py index b8d66cb..9552a8e 100644 --- a/vitallens/methods/rppg_method.py +++ b/vitallens/methods/rppg_method.py @@ -19,13 +19,21 @@ # SOFTWARE. import abc +import numpy as np class RPPGMethod(metaclass=abc.ABCMeta): - def __init__(self, config): + """Abstract superclass for rPPG methods""" + def __init__(self, config: dict): + """Initialize the `RPPGMethod` + + Args: + config: The configuration dict + """ self.fps_target = config['fps_target'] self.est_window_length = config['est_window_length'] self.est_window_overlap = config['est_window_overlap'] self.est_window_flexible = self.est_window_length == 0 @abc.abstractmethod - def __call__(self, video, fps, mode): + def __call__(self, frames, faces, fps, override_fps_target, override_global_parse): + """Run inference. Abstract method to be implemented in subclasses.""" pass diff --git a/vitallens/methods/simple_rppg_method.py b/vitallens/methods/simple_rppg_method.py index cb086fe..cc264ac 100644 --- a/vitallens/methods/simple_rppg_method.py +++ b/vitallens/methods/simple_rppg_method.py @@ -31,10 +31,16 @@ from vitallens.utils import parse_video_inputs, merge_faces class SimpleRPPGMethod(RPPGMethod): + """A simple rPPG method using a handcrafted algorithm based on RGB signal trace""" def __init__( self, config: dict ): + """Initialize the `SimpleRPPGMethod` + + Args: + config: The configuration dict + """ super(SimpleRPPGMethod, self).__init__(config=config) self.model = config['model'] self.roi_method = config['roi_method'] @@ -45,12 +51,14 @@ def algorithm( rgb: np.ndarray, fps: float ): + """The algorithm. Abstract method to be implemented by subclasses.""" pass @abc.abstractmethod def pulse_filter(self, sig: np.ndarray, fps: float ) -> np.ndarray: + """The post-processing filter to be applied to estimated pulse signal. Abstract method to be implemented by subclasses.""" pass def __call__( self, @@ -70,11 +78,12 @@ def __call__( override_fps_target: Override the method's default inference fps (optional). override_global_parse: Has no effect here. Returns: - data: A dictionary with the values of the estimated vital signs. - unit: A dictionary with the units of the estimated vital signs. - conf: A dictionary with the confidences of the estimated vital signs. - note: A dictionary with notes on the estimated vital signs. - live: Dummy live confidence estimation (set to always 1). Shape (n_frames,) + Tuple of + - data: A dictionary with the values of the estimated vital signs. + - unit: A dictionary with the units of the estimated vital signs. + - conf: A dictionary with the confidences of the estimated vital signs. + - note: A dictionary with notes on the estimated vital signs. + - live: Dummy live confidence estimation (set to always 1). Shape (n_frames,) """ # Compute temporal union of ROIs u_roi = merge_faces(faces) diff --git a/vitallens/methods/vitallens.py b/vitallens/methods/vitallens.py index c896af9..6c9f7f3 100644 --- a/vitallens/methods/vitallens.py +++ b/vitallens/methods/vitallens.py @@ -42,11 +42,18 @@ from vitallens.utils import probe_video_inputs, parse_video_inputs, check_faces_in_roi class VitalLensRPPGMethod(RPPGMethod): + """RPPG method using the VitalLens API for inference""" def __init__( self, config: dict, api_key: str ): + """Initialize the `VitalLensRPPGMethod` + + Args: + config: The configuration dict + api_key: The API key + """ super(VitalLensRPPGMethod, self).__init__(config=config) self.api_key = api_key self.input_size = config['input_size'] @@ -71,11 +78,12 @@ def __call__( override_global_parse: If True, always use global parse. If False, don't use global parse. If None, choose based on video. Returns: - out_data: The estimated data/value for each signal. - out_unit: The estimation unit for each signal. - out_conf: The estimation confidence for each signal. - out_note: An explanatory note for each signal. - live: The face live confidence. Shape (1, n_frames) + Tuple of + - out_data: The estimated data/value for each signal. + - out_unit: The estimation unit for each signal. + - out_conf: The estimation confidence for each signal. + - out_note: An explanatory note for each signal. + - live: The face live confidence. Shape (1, n_frames) """ inputs_shape, fps, video_issues = probe_video_inputs(video=frames, fps=fps) video_fits_in_memory = enough_memory_for_ndarray( @@ -189,10 +197,11 @@ def process_api_batch( fps: The frame rate of the input video. Required if type(video) == np.ndarray global_parse: Flag that indicates whether video has already been parsed. Returns: - sig: Estimated signals. Shape (n_sig, n_frames) - conf: Estimation confidences. Shape (n_sig, n_frames) - live: Liveness estimation. Shape (n_frames,) - idxs: Indices in inputs that were processed. Shape (n_frames) + Tuple of + - sig: Estimated signals. Shape (n_sig, n_frames) + - conf: Estimation confidences. Shape (n_sig, n_frames) + - live: Liveness estimation. Shape (n_frames,) + - idxs: Indices in inputs that were processed. Shape (n_frames) """ logging.debug("Batch {}/{}...".format(batch, n_batches)) # Trim face detections to batch if necessary @@ -256,7 +265,13 @@ def process_api_batch( live_ds = np.asarray(response_body["face"]["confidence"]) idxs = np.asarray(idxs) return sig_ds, conf_ds, live_ds, idxs - def postprocess(self, sig, fps, type='ppg', filter=True): + def postprocess( + self, + sig: np.ndarray, + fps: float, + type: str = 'ppg', + filter: bool = True + ) -> np.ndarray: """Apply filters to the estimated signal. Args: sig: The estimated signal. Shape (n_frames,) diff --git a/vitallens/signal.py b/vitallens/signal.py index 340ce52..f11526a 100644 --- a/vitallens/signal.py +++ b/vitallens/signal.py @@ -28,22 +28,50 @@ def moving_average_size_for_hr_response( f_s: Union[float, int] - ): + ) -> int: + """Get the moving average window size for a signal with HR information sampled at a given frequency + + Args: + f_s: The sampling frequency + Returns: + The moving average size in number of signal vals + """ return moving_average_size_for_response(f_s, CALC_HR_MAX / SECONDS_PER_MINUTE) def moving_average_size_for_rr_response( f_s: Union[float, int] - ): + ) -> int: + """Get the moving average window size for a signal with RR information sampled at a given frequency + + Args: + f_s: The sampling frequency + Returns: + The moving average size in number of signal vals + """ return moving_average_size_for_response(f_s, CALC_RR_MAX / SECONDS_PER_MINUTE) def detrend_lambda_for_hr_response( f_s: Union[float, int] - ): + ) -> int: + """Get the detrending lambda parameter for a signal with HR information sampled at a given frequency + + Args: + f_s: The sampling frequency + Returns: + The lambda parameter + """ return int(0.1614*np.power(f_s, 1.9804)) def detrend_lambda_for_rr_response( f_s: Union[float, int] - ): + ) -> int: + """Get the detrending lambda parameter for a signal with RR information sampled at a given frequency + + Args: + f_s: The sampling frequency + Returns: + The lambda parameter + """ return int(4.4248*np.power(f_s, 2.1253)) def windowed_mean( @@ -126,16 +154,16 @@ def windowed_freq( def reassemble_from_windows( x: np.ndarray, idxs: np.ndarray - ) -> np.ndarray: + ) -> Tuple[np.ndarray, np.ndarray]: """Reassemble windowed data using corresponding idxs. Args: x: Data generated using a windowing operation. Shape (n_windows, n, window_size) idxs: Indices of x in the original 1-d array. Shape (n_windows, window_size) - Returns: - out: Reassembled data. Shape (n, n_idxs) - idxs: Reassembled idxs. Shape (n_idxs) + Tuple of + - out: Reassembled data. Shape (n, n_idxs) + - idxs: Reassembled idxs. Shape (n_idxs,) """ x = np.asarray(x) idxs = np.asarray(idxs) diff --git a/vitallens/ssd.py b/vitallens/ssd.py index b70c1cb..82f98fc 100644 --- a/vitallens/ssd.py +++ b/vitallens/ssd.py @@ -53,7 +53,9 @@ def nms( iou_threshold: Threshold wrt iou for amount of box overlap. Scalar. score_threshold: Threshold wrt score for removing boxes. Scalar. Returns: - idxs: The selected indices padded with zero. Shape (n_batch, max_output_size) + Tuple of + - idxs: The selected indices padded with zero. Shape (n_batch, max_output_size) + - num_valid: Number of valid elements per batch. Shape (n_batch,) """ n_batch = boxes.shape[0] # Split up box coordinates @@ -108,8 +110,9 @@ def enforce_temporal_consistency( info: Detection info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4) n_frames: Number of frames in the original input. Returns: - boxes: Processed boxes in point form [0, 1], shape (n_frames, n_faces, 4) - info: Processed info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4) + Tuple of + - boxes: Processed boxes in point form [0, 1], shape (n_frames, n_faces, 4) + - info: Processed info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4) """ # Make sure that enough frames are present if n_frames == 1: @@ -162,8 +165,9 @@ def interpolate_unscanned_frames( info: Detection info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (n_frames, n_faces, 5) n_frames: Number of frames in the original input. Returns: - boxes: Processed boxes in point form [0, 1], shape (orig_n_frames, n_faces, 4) - info: Processed info: idx, scanned, scan_found_face, confidence. Shape (orig_n_frames, n_faces, 4) + Tuple of + - boxes: Processed boxes in point form [0, 1], shape (orig_n_frames, n_faces, 4) + - info: Processed info: idx, scanned, scan_found_face, confidence. Shape (orig_n_frames, n_faces, 4) """ _, n_faces, _ = info.shape # Add rows corresponding to unscanned frames @@ -220,8 +224,9 @@ def __call__( inputs_shape: The shape of the input video as (n_frames, h, w, 3) fps: Sampling frequency of the input video. Returns: - boxes: Detected face boxes in relative flat point form (n_frames, n_faces, 4) - info: Tuple (idx, scanned, scan_found_face, interp_valid, confidence) (n_frames, n_faces, 5) + Tuple of + - boxes: Detected face boxes in relative flat point form (n_frames, n_faces, 4) + - info: Tuple (idx, scanned, scan_found_face, interp_valid, confidence) (n_frames, n_faces, 5) """ # Determine number of batches n_frames = inputs_shape[0] @@ -275,7 +280,7 @@ def scan_batch( start: int, end: int, fps: float = None, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, list]: """Parse video and run inference for one batch. Args: @@ -287,9 +292,10 @@ def scan_batch( end: The index of the last frame of the video to analyze in this batch. fps: Sampling frequency of the input video. Required if type(video) == np.ndarray. Returns: - boxes: Scanned boxes in flat point form (n_frames, n_boxes, 4) - classes: Detection scores for boxes (n_frames, n_boxes, 2) - idxs: Indices of the scanned frames from the original video + Tuple of + - boxes: Scanned boxes in flat point form (n_frames, n_boxes, 4) + - classes: Detection scores for boxes (n_frames, n_boxes, 2) + - idxs: Indices of the scanned frames from the original video """ logging.debug("Batch {}/{}...".format(batch, n_batches)) # Parse the inputs diff --git a/vitallens/utils.py b/vitallens/utils.py index e68af5a..d7ce0c4 100644 --- a/vitallens/utils.py +++ b/vitallens/utils.py @@ -61,7 +61,7 @@ def download_file(url: str, dest: str): def probe_video_inputs( video: Union[np.ndarray, str], fps: float = None - ) -> Tuple[tuple, float]: + ) -> Tuple[tuple, float, bool]: """Check the video inputs and probe to extract metadata. Args: @@ -70,9 +70,10 @@ def probe_video_inputs( video file. fps: Sampling frequency of the input video. Required if type(video)==np.ndarray. Returns: - video_shape: The shape of the input video as (n_frames, h, w, 3) - fps: Sampling frequency of the input video. - issues: True if a possible issue with the video has been detected. + Tuple of + - video_shape: The shape of the input video as (n_frames, h, w, 3) + - fps: Sampling frequency of the input video. + - issues: True if a possible issue with the video has been detected. """ # Check that fps is correct type if not (fps is None or isinstance(fps, (int, float))): @@ -111,7 +112,7 @@ def parse_video_inputs( scale_algorithm: str = 'bilinear', trim: tuple = None, dim_deltas: tuple = (1, 1, 1) - ) -> Tuple[np.ndarray, float, tuple, int]: + ) -> Tuple[np.ndarray, float, tuple, int, list]: """Parse video inputs into required shape. Args: @@ -125,12 +126,13 @@ def parse_video_inputs( trim: Frame numbers for temporal trimming (start, end) (optional). dim_deltas: Maximum acceptable deviation from expected video (n, h, w) dims. Returns: - parsed: Parsed inputs as `np.ndarray` with type uint8. Shape (n, h, w, c) - if target_size provided, h = target_size[0] and w = target_size[1]. - fps_in: Frame rate of original inputs - shape_in: Shape of original inputs in form (n, h, w, c) - ds_factor: Temporal downsampling factor applied - idxs: The frame indices returned from original video + Tuple of + - parsed: Parsed inputs as `np.ndarray` with type uint8. Shape (n, h, w, c) + if target_size provided, h = target_size[0] and w = target_size[1]. + - fps_in: Frame rate of original inputs + - shape_in: Shape of original inputs in form (n, h, w, c) + - ds_factor: Temporal downsampling factor applied + - idxs: The frame indices returned from original video """ # Check if input is array or file name if isinstance(video, str):