Skip to content

Commit

Permalink
Add some docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmc2005 committed Nov 13, 2023
1 parent d76a6ef commit 240b3ac
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/diart/blocks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,28 @@

@dataclass
class HyperParameter:
"""Represents a pipeline hyper-parameter that can be tuned by diart"""

name: Text
"""Name of the hyper-parameter (e.g. tau_active)"""
low: float
"""Lowest value that this parameter can take"""
high: float
"""Highest value that this parameter can take"""

@staticmethod
def from_name(name: Text) -> "HyperParameter":
"""Create a HyperParameter object given its name.
Parameters
----------
name: str
Name of the hyper-parameter
Returns
-------
HyperParameter
"""
if name == "tau_active":
return TauActive
if name == "rho_update":
Expand All @@ -32,24 +48,32 @@ def from_name(name: Text) -> "HyperParameter":


class PipelineConfig(ABC):
"""Configuration containing the required parameters to build and run a pipeline"""

@property
@abstractmethod
def duration(self) -> float:
"""The duration of an input audio chunk (in seconds)"""
pass

@property
@abstractmethod
def step(self) -> float:
"""The step between two consecutive input audio chunks (in seconds)"""
pass

@property
@abstractmethod
def latency(self) -> float:
"""The algorithmic latency of the pipeline (in seconds).
At time `t` of the audio stream, the pipeline will output predictions for time `t - latency`.
"""
pass

@property
@abstractmethod
def sample_rate(self) -> int:
"""The sample rate of the input audio stream"""
pass

def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
Expand All @@ -60,6 +84,8 @@ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:


class Pipeline(ABC):
"""Represents a streaming audio pipeline"""

@staticmethod
@abstractmethod
def get_config_class() -> type:
Expand Down Expand Up @@ -92,4 +118,16 @@ def set_timestamp_shift(self, shift: float):
def __call__(
self, waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
"""Runs the next steps of the pipeline given a list of consecutive audio chunks.
Parameters
----------
waveforms: Sequence[SlidingWindowFeature]
Consecutive chunk waveforms for the pipeline to ingest
Returns
-------
Sequence[Tuple[Any, SlidingWindowFeature]]
For each input waveform, a tuple containing the pipeline output and its respective audio
"""
pass

0 comments on commit 240b3ac

Please sign in to comment.