Skip to content

Commit

Permalink
add tests for find_overlapping_idxs_in_clip_df
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfh committed Jun 7, 2024
1 parent d00281a commit f2543ec
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
33 changes: 9 additions & 24 deletions opensoundscape/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,38 +873,23 @@ def find_overlapping_idxs_in_clip_df(
clip_df = clip_df.loc[clip_df.index.get_level_values(1) < annotation_end]
# and all rows that end before the annotation starts. End is level 2 of multi-index
clip_df = clip_df.loc[clip_df.index.get_level_values(2) > annotation_start]
# don't calculate overlaps if there are no overlapping rows
if clip_df.empty:
return None
# now for each row, calculate the overlap
clip_df["overlap"] = clip_df.apply(
lambda row: overlap(
[annotation_start, annotation_end],
[
row.name[1],
row.name[2],
], # row.name is the multi-index. So row.name[1] is the start_time and row.name[2] is the end_time
),
axis=1,
)
# now for each time-window, calculate the overlaps
clip_df["overlap"] = [
overlap([annotation_start, annotation_end], [row[1], row[2]])
for row in clip_df.index
]

# discard annotations that do not overlap with the time window
clip_df = clip_df[clip_df["overlap"] > 0]

# calculate the fraction of each annotation that overlaps with this time window
clip_df["overlap_fraction"] = clip_df.apply(
lambda row: overlap_fraction(
[annotation_start, annotation_end], [row.name[1], row.name[2]]
),
axis=1,
)
clip_df["overlap_fraction"] = [
overlap_fraction([annotation_start, annotation_end], [row[1], row[2]])
for row in clip_df.index
]

if min_label_overlap is not None:
clip_df = clip_df[clip_df["overlap"] >= min_label_overlap]
if min_label_fraction is not None:
clip_df = clip_df[clip_df["overlap_fraction"] >= min_label_fraction]

# return the indices of the overlapping rows
if clip_df.empty:
return None
return clip_df.index
22 changes: 22 additions & 0 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,25 @@ def test_from_raven_files_one_audio_file(raven_file):
ba = BoxedAnnotations.from_raven_files(Path(raven_file), Path("path1"))
assert str(ba.audio_files[0]) == "path1"
assert len(ba.audio_files) == 1


def test_find_overlapping_idxs_in_clip_df(boxed_annotations):
clip_df = generate_clip_times_df(5, clip_duration=1.0, clip_overlap=0)
# make it a multi-index, with the first level being the audio file, second being start, third being end time
clip_df["audio_file"] = "audio_file.wav"
clip_df = clip_df.set_index(["audio_file", "start_time", "end_time"])
# annotation overlaps with 1 time-window
idxs = annotations.find_overlapping_idxs_in_clip_df(
0, 1, clip_df, min_label_overlap=0.25
)
assert len(idxs) == 1
# annotation overlaps with 2 time-windows
idxs = annotations.find_overlapping_idxs_in_clip_df(
0, 1.3, clip_df, min_label_overlap=0.25
)
assert len(idxs) == 2
# annotation-overlaps with no time-windows
idxs = annotations.find_overlapping_idxs_in_clip_df(
1000, 1001, clip_df, min_label_overlap=0.25
)
assert len(idxs) == 0

0 comments on commit f2543ec

Please sign in to comment.