Skip to content

Commit

Permalink
inputs_shape and fps args now required for running face detection
Browse files Browse the repository at this point in the history
  • Loading branch information
prouast committed Jul 20, 2024
1 parent 55e747b commit b79b4bc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@ def test_video_fps():
fps, *_ = probe_video(TEST_VIDEO_PATH)
return fps

@pytest.fixture(scope='session')
def test_video_shape():
_, n, w, h, _, _, _ = probe_video(TEST_VIDEO_PATH)
return (n, h, w, 3)

@pytest.fixture(scope='session')
def test_video_faces(request):
det = FaceDetector(
max_faces=1, fs=1.0, iou_threshold=0.45, score_threshold=0.9)
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
boxes, _ = det(test_video_ndarray, fps=test_video_fps)
boxes, _ = det(test_video_ndarray,
inputs_shape=test_video_ndarray.shape,
fps=test_video_fps)
boxes = (boxes * [test_video_ndarray.shape[2], test_video_ndarray.shape[1], test_video_ndarray.shape[2], test_video_ndarray.shape[1]]).astype(int)
return boxes[:,0].astype(np.int64)

Expand Down
11 changes: 9 additions & 2 deletions tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# SOFTWARE.

import numpy as np
from prpy.ffmpeg.probe import probe_video
import pytest

import sys
Expand Down Expand Up @@ -138,11 +139,17 @@ def test_FaceDetector(request, file):
max_faces=2, fs=1.0, iou_threshold=0.45, score_threshold=0.9)
if file:
test_video_path = request.getfixturevalue('test_video_path')
boxes, info = det(test_video_path)
test_video_shape = request.getfixturevalue('test_video_shape')
test_video_fps = request.getfixturevalue('test_video_fps')
boxes, info = det(inputs=test_video_path,
inputs_shape=test_video_shape,
fps=test_video_fps)
else:
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
boxes, info = det(test_video_ndarray, fps=test_video_fps)
boxes, info = det(inputs=test_video_ndarray,
inputs_shape=test_video_ndarray.shape,
fps=test_video_fps)
assert boxes.shape == (360, 1, 4)
assert info.shape == (360, 1, 5)
np.testing.assert_allclose(boxes[0,0],
Expand Down
4 changes: 2 additions & 2 deletions vitallens/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,15 @@ def __call__(
self,
inputs: Tuple[np.ndarray, str],
inputs_shape: Tuple[tuple, float],
fps: float = None
fps: float
) -> Tuple[np.ndarray, np.ndarray]:
"""Run inference.
Args:
inputs: The video to analyze. Either a np.ndarray of shape (n_frames, h, w, 3)
with a sequence of frames in unscaled uint8 RGB format, or a path to a video file.
inputs_shape: The shape of the input video as (n_frames, h, w, 3)
fps: Sampling frequency of the input video. Required if type(video) == np.ndarray.
fps: Sampling frequency of the input video.
Returns:
boxes: Detected face boxes in relative flat point form (n_frames, n_faces, 4)
info: Tuple (idx, scanned, scan_found_face, interp_valid, confidence) (n_frames, n_faces, 5)
Expand Down

0 comments on commit b79b4bc

Please sign in to comment.