Skip to content

Commit

Permalink
embed(): return dataframe rather than matrix
Browse files Browse the repository at this point in the history
the dataframe index is needed to associate embeddings with specific audio files and time ranges
  • Loading branch information
sammlapp committed Sep 6, 2024
1 parent a10df73 commit 51263bc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
13 changes: 11 additions & 2 deletions opensoundscape/ml/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2125,9 +2125,18 @@ def embed(
avgpool_intermediates=avgpool,
)

# put embeddings
embeddings = pd.DataFrame(
data=embeddings[0], index=dataloader.dataset.dataset.label_df.index
)

if return_preds:
return embeddings[0], preds
return embeddings[0]
# put predictions in a DataFrame with same index as embeddings
preds = pd.DataFrame(
data=preds, index=dataloader.dataset.dataset.label_df.index
)
return embeddings, preds
return embeddings

@property
def device(self):
Expand Down
1 change: 1 addition & 0 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def test_embed(test_df):
embeddings = m.embed(samples=test_df, avgpool=True, progress_bar=False)
assert embeddings.shape[0] == 2
assert len(embeddings.shape) == 2
assert isinstance(embeddings, pd.DataFrame)
except Exception as e:
raise Exception(f"{arch} failed") from e

Expand Down

0 comments on commit 51263bc

Please sign in to comment.