From 51263bca097b91d83cadf1e28ca9997bac6a3f1f Mon Sep 17 00:00:00 2001 From: sammlapp Date: Fri, 6 Sep 2024 12:55:13 -0400 Subject: [PATCH] embed(): return dataframe rather than matrix. The dataframe index is needed to associate embeddings with specific audio files and time ranges. --- opensoundscape/ml/cnn.py | 13 +++++++++++-- tests/test_cnn.py | 1 + 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/opensoundscape/ml/cnn.py b/opensoundscape/ml/cnn.py index edc0db38..7d545ca0 100644 --- a/opensoundscape/ml/cnn.py +++ b/opensoundscape/ml/cnn.py @@ -2125,9 +2125,18 @@ def embed( avgpool_intermediates=avgpool, ) + # put embeddings in a DataFrame indexed by the dataset's label_df index + embeddings = pd.DataFrame( + data=embeddings[0], index=dataloader.dataset.dataset.label_df.index + ) + if return_preds: - return embeddings[0], preds - return embeddings[0] + # put predictions in a DataFrame with same index as embeddings + preds = pd.DataFrame( + data=preds, index=dataloader.dataset.dataset.label_df.index + ) + return embeddings, preds + return embeddings @property def device(self): diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 97eec68b..644f1a6c 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -831,6 +831,7 @@ def test_embed(test_df): embeddings = m.embed(samples=test_df, avgpool=True, progress_bar=False) assert embeddings.shape[0] == 2 assert len(embeddings.shape) == 2 + assert isinstance(embeddings, pd.DataFrame) except Exception as e: raise Exception(f"{arch} failed") from e