From 5dfccbccf88c1f33a5a3964f42b989567258260f Mon Sep 17 00:00:00 2001 From: Philipp Rouast Date: Fri, 15 Nov 2024 16:21:19 +1100 Subject: [PATCH] Improve handling of indices --- pyproject.toml | 2 +- vitallens/methods/vitallens.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d352902..6cffc35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "importlib_resources", "numpy", "onnxruntime", - "prpy[ffmpeg,numpy_min]>=0.2.15", + "prpy[ffmpeg,numpy_min]>=0.2.17", "python-dotenv", "pyyaml", "requests", diff --git a/vitallens/methods/vitallens.py b/vitallens/methods/vitallens.py index 0e0917c..41a4d27 100644 --- a/vitallens/methods/vitallens.py +++ b/vitallens/methods/vitallens.py @@ -118,8 +118,8 @@ def __call__( frames = inputs # Longer videos are split up with small overlaps n_splits = 1 if expected_ds_n <= API_MAX_FRAMES else math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1 - split_len = expected_ds_n if n_splits == 1 else math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits) - start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)] + split_len = inputs_n if n_splits == 1 else math.ceil((inputs_n + (n_splits-1) * API_OVERLAP) / n_splits) + start_idxs = [i * (split_len - API_OVERLAP) for i in range(n_splits)] end_idxs = [min(start + split_len, inputs_n) for start in start_idxs] start_idxs = [max(0, end - split_len) for end in end_idxs] logging.info("Running inference for {} frames using {} request(s)...".format(expected_ds_n, n_splits)) @@ -243,7 +243,7 @@ def process_api_batch( assert self.input_buffer is not None payload["state"] = base64.b64encode(self.state.astype(np.float32).tobytes()).decode('utf-8') # Adjust idxs - idxs = idxs[3:] - 3 + idxs = idxs[(self.n_inputs-1):] - (self.n_inputs-1) # Ask API to process video response = requests.post(API_URL, headers=headers, json=payload) response_body = json.loads(response.text)