Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing of core classes #96

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __pycache__/
MANIFEST
.Python
env/
venv/
bin/
build/
develop-eggs/
Expand Down Expand Up @@ -61,4 +62,4 @@ doc/.ipynb_checkpoints
# PyCharm
.idea/

.mypy_cache/
.mypy_cache/
58 changes: 45 additions & 13 deletions pyannote/core/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
from collections import defaultdict
from typing import (
Hashable,
Literal,
Optional,
Dict,
Union,
Expand All @@ -122,6 +123,8 @@
Iterator,
Text,
TYPE_CHECKING,
NamedTuple,
overload,
)

import numpy as np
Expand All @@ -139,7 +142,17 @@
from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode

if TYPE_CHECKING:
import pandas as pd
import pandas as pd # type: ignore


class SegmentTrack(NamedTuple):
segment: Segment
track: TrackName

class SegmentTrackLabel(NamedTuple):
segment: Segment
track: TrackName
label: Label


class Annotation:
Expand Down Expand Up @@ -187,7 +200,7 @@ def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None):
self._labelNeedsUpdate: Dict[Label, bool] = {}

# timeline meant to store all annotated segments
self._timeline: Timeline = None
self._timeline: Optional[Timeline] = None
self._timelineNeedsUpdate: bool = True

@property
Expand Down Expand Up @@ -259,9 +272,16 @@ def itersegments(self):
"""
return iter(self._tracks)

@overload
def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[SegmentTrack]: ...
@overload
def itertracks(self, yield_label: Literal[True]) -> Iterator[SegmentTrackLabel]: ...
@overload
def itertracks(self, yield_label: bool) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]: ...

def itertracks(
self, yield_label: bool = False
) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]:
) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]:
"""Iterate over tracks (in chronological order)

Parameters
Expand All @@ -286,9 +306,9 @@ def itertracks(
tracks.items(), key=lambda tl: (str(tl[0]), str(tl[1]))
):
if yield_label:
yield segment, track, lbl
yield SegmentTrackLabel(segment, track, lbl)
else:
yield segment, track
yield SegmentTrack(segment, track)

def _updateTimeline(self):
self._timeline = Timeline(segments=self._tracks, uri=self.uri)
Expand Down Expand Up @@ -317,9 +337,14 @@ def get_timeline(self, copy: bool = True) -> Timeline:
"""
if self._timelineNeedsUpdate:
self._updateTimeline()

timeline_ = self._timeline
if timeline_ is None:
timeline_ = Timeline(uri=self.uri)

if copy:
return self._timeline.copy()
return self._timeline
return timeline_.copy()
return timeline_

def __eq__(self, other: "Annotation"):
"""Equality
Expand Down Expand Up @@ -556,6 +581,9 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation
else:
raise NotImplementedError("unsupported mode: '%s'" % mode)

else:
raise TypeError("unsupported support type: '%s'" % type(support))

def extrude(
self, removed: Support, mode: CropMode = "intersection"
) -> "Annotation":
Expand Down Expand Up @@ -1178,7 +1206,7 @@ def argmax(self, support: Optional[Support] = None) -> Optional[Label]:
key=lambda x: x[1],
)[0]

def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
def rename_tracks(self, generator: Union[LabelGenerator, Iterable[str], Iterable[int]] = "string") -> "Annotation":
"""Rename all tracks

Parameters
Expand Down Expand Up @@ -1215,13 +1243,17 @@ def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
renamed = self.__class__(uri=self.uri, modality=self.modality)

if generator == "string":
generator = string_generator()
generator_ = string_generator()
elif generator == "int":
generator = int_generator()
generator_ = int_generator()
elif isinstance(generator, Iterable):
generator_ = iter(generator)
else:
raise ValueError("generator must be 'string', 'int', or iterable")

# TODO speed things up by working directly with annotation internals
for s, _, label in self.itertracks(yield_label=True):
renamed[s, next(generator)] = label
renamed[s, next(generator_)] = label
return renamed

def rename_labels(
Expand Down Expand Up @@ -1439,11 +1471,11 @@ def discretize(
duration: Optional[float] = None,
):
"""Discretize

Parameters
----------
support : Segment, optional
Part of annotation to discretize.
Part of annotation to discretize.
Defaults to annotation full extent.
resolution : float or SlidingWindow, optional
Defaults to 10ms frames.
Expand Down
8 changes: 4 additions & 4 deletions pyannote/core/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
"""
import numbers
import warnings
from typing import Tuple, Optional, Union, Iterator, List, Text
from typing import Tuple, Optional, Union, Iterator, List

import numpy as np

from pyannote.core.utils.types import Alignment
from pyannote.core.utils.types import Alignment, Label
from .segment import Segment
from .segment import SlidingWindow
from .timeline import Timeline
Expand All @@ -58,7 +58,7 @@ class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin):
"""

def __init__(
self, data: np.ndarray, sliding_window: SlidingWindow, labels: List[Text] = None
self, data: np.ndarray, sliding_window: SlidingWindow, labels: Optional[List[Label]] = None
):
self.sliding_window: SlidingWindow = sliding_window
self.data = data
Expand Down Expand Up @@ -106,7 +106,7 @@ def __next__(self) -> Tuple[Segment, np.ndarray]:
self.__i += 1
try:
return self.sliding_window[self.__i], self.data[self.__i]
except IndexError as e:
except IndexError:
raise StopIteration()

def next(self):
Expand Down
Loading