refactor StreamingHandler to use LazyModel for resource mgmt
janaab11 committed Jan 1, 2025
1 parent 1975f75 commit d4380c4
Showing 2 changed files with 33 additions and 57 deletions.
23 changes: 11 additions & 12 deletions src/diart/console/serve.py
@@ -6,7 +6,7 @@
 from diart import argdoc
 from diart import models as m
 from diart import utils
-from diart.handler import StreamingInferenceConfig, StreamingInferenceHandler
+from diart.handler import StreamingHandlerConfig, StreamingHandler
 
 
 def run():
@@ -94,24 +94,23 @@ def run():
     args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
     args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
 
-    # Resolve pipeline
+    # Resolve pipeline configuration
     pipeline_class = utils.get_pipeline_class(args.pipeline)
-    config = pipeline_class.get_config_class()(**vars(args))
-    pipeline = pipeline_class(config)
+    pipeline_config = pipeline_class.get_config_class()(**vars(args))
 
-    # Create inference configuration
-    inference_config = StreamingInferenceConfig(
-        pipeline=pipeline,
+    # Create handler configuration for inference
+    config = StreamingHandlerConfig(
+        pipeline_class=pipeline_class,
+        pipeline_config=pipeline_config,
         batch_size=1,
         do_profile=False,
         do_plot=False,
         show_progress=False,
     )
 
-    # Initialize handler with new configuration
-    handler = StreamingInferenceHandler(
-        inference_config=inference_config,
-        sample_rate=config.sample_rate,
+    # Initialize handler
+    handler = StreamingHandler(
+        config=config,
         host=args.host,
         port=args.port,
     )
@@ -120,4 +119,4 @@ def run():
 
 
 if __name__ == "__main__":
-    run()
+    run()
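
The serve.py change inverts the old flow: instead of building one live pipeline up front and handing it to the handler, the console now passes the pipeline class together with its config, and the handler instantiates a fresh pipeline per client. Since the models referenced by the config are lazy wrappers, weights load once and are shared across those per-client pipelines. A minimal sketch of the new wiring, assuming diart's SpeakerDiarization pipeline and its default models (assumptions, not part of this diff):

from diart import SpeakerDiarization
from diart.handler import StreamingHandler, StreamingHandlerConfig

# get_config_class() is used the same way in serve.py above; calling it with
# no arguments takes diart's defaults (an assumption for this sketch).
pipeline_class = SpeakerDiarization
pipeline_config = pipeline_class.get_config_class()()

config = StreamingHandlerConfig(
    pipeline_class=pipeline_class,
    pipeline_config=pipeline_config,
    batch_size=1,
    do_profile=False,
    do_plot=False,
    show_progress=False,
)

# No pipeline is constructed yet; that happens per client inside the handler.
handler = StreamingHandler(config=config, host="127.0.0.1", port=7007)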
67 changes: 22 additions & 45 deletions src/diart/handler.py
@@ -19,29 +19,15 @@
 
 
 @dataclass
-class WebSocketAudioSourceConfig:
-    """Configuration for WebSocket audio source.
-
-    Parameters
-    ----------
-    uri : str
-        WebSocket URI for the audio source
-    sample_rate : int
-        Audio sample rate in Hz
-    """
-
-    uri: str
-    sample_rate: int = 16000
-
-
-@dataclass
-class StreamingInferenceConfig:
+class StreamingHandlerConfig:
     """Configuration for streaming inference.
 
     Parameters
     ----------
-    pipeline : blocks.Pipeline
-        Diarization pipeline configuration
+    pipeline_class : type
+        Pipeline class
+    pipeline_config : blocks.PipelineConfig
+        Pipeline configuration
     batch_size : int
         Number of inputs to process at once
     do_profile : bool
@@ -54,7 +40,8 @@ class StreamingInferenceConfig:
         Custom progress bar implementation
     """
 
-    pipeline: blocks.Pipeline
+    pipeline_class: type
+    pipeline_config: blocks.PipelineConfig
     batch_size: int = 1
     do_profile: bool = True
    do_plot: bool = False
@@ -70,18 +57,16 @@ class ClientState:
     inference: StreamingInference
 
 
-class StreamingInferenceHandler:
+class StreamingHandler:
     """Handles real-time speaker diarization inference for multiple audio sources over WebSocket.
 
     This handler manages WebSocket connections from multiple clients, processing
     audio streams and performing speaker diarization in real-time.
 
     Parameters
     ----------
-    inference_config : StreamingInferenceConfig
+    config : StreamingHandlerConfig
         Streaming inference configuration
-    sample_rate : int, optional
-        Audio sample rate in Hz, by default 16000
     host : str, optional
         WebSocket server host, by default "127.0.0.1"
     port : int, optional
@@ -94,15 +79,13 @@ class StreamingInferenceHandler:
 
     def __init__(
        self,
-        inference_config: StreamingInferenceConfig,
-        sample_rate: int = 16000,
+        config: StreamingHandlerConfig,
         host: Text = "127.0.0.1",
         port: int = 7007,
         key: Optional[Union[Text, Path]] = None,
         certificate: Optional[Union[Text, Path]] = None,
     ):
-        self.inference_config = inference_config
-        self.sample_rate = sample_rate
+        self.config = config
         self.host = host
         self.port = port
@@ -135,26 +118,21 @@ def _create_client_state(self, client_id: Text) -> ClientState:
         """
         # Create a new pipeline instance with the same config
         # This ensures each client has its own state while sharing model weights
-        pipeline = self.inference_config.pipeline.__class__(
-            self.inference_config.pipeline.config
-        )
-
-        audio_config = WebSocketAudioSourceConfig(
-            uri=f"{self.uri}:{client_id}", sample_rate=self.sample_rate
-        )
+        pipeline = self.config.pipeline_class(self.config.pipeline_config)
 
         audio_source = src.WebSocketAudioSource(
-            uri=audio_config.uri, sample_rate=audio_config.sample_rate
+            uri=f"{self.uri}:{client_id}",
+            sample_rate=self.config.pipeline_config.sample_rate,
         )
 
         inference = StreamingInference(
             pipeline=pipeline,
             source=audio_source,
-            batch_size=self.inference_config.batch_size,
-            do_profile=self.inference_config.do_profile,
-            do_plot=self.inference_config.do_plot,
-            show_progress=self.inference_config.show_progress,
-            progress_bar=self.inference_config.progress_bar,
+            batch_size=self.config.batch_size,
+            do_profile=self.config.do_profile,
+            do_plot=self.config.do_plot,
+            show_progress=self.config.show_progress,
+            progress_bar=self.config.progress_bar,
         )
 
         return ClientState(audio_source=audio_source, inference=inference)
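
_create_client_state is where the refactor pays off: each connection gets its own pipeline built from the shared config, and the audio source now reads its sample rate off pipeline_config instead of a separate handler argument, so source and pipeline can no longer disagree. A standalone sketch of the same wiring for one client, using the calls shown in this hunk plus diart's SpeakerDiarization defaults (the pipeline choice and URI are assumptions):

from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart import sources as src
from diart.inference import StreamingInference

pipeline_config = SpeakerDiarizationConfig()    # carries sample_rate
pipeline = SpeakerDiarization(pipeline_config)  # fresh state for this client

audio_source = src.WebSocketAudioSource(
    uri="ws://127.0.0.1:7007:client-0",         # per-client URI; illustrative
    sample_rate=pipeline_config.sample_rate,    # single source of truth
)

inference = StreamingInference(
    pipeline=pipeline,
    source=audio_source,
    batch_size=1,
    show_progress=False,
)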
@@ -174,16 +152,15 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
 
         if client_id not in self._clients:
             try:
-                client_state = self._create_client_state(client_id)
-                self._clients[client_id] = client_state
+                self._clients[client_id] = self._create_client_state(client_id)
 
                 # Setup RTTM response hook
-                client_state.inference.attach_hooks(
+                self._clients[client_id].inference.attach_hooks(
                     lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm())
                 )
 
                 # Start inference
-                client_state.inference()
+                self._clients[client_id].inference()
                 logger.info(f"Started inference for client: {client_id}")
 
                 # Send ready notification to client
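
The RTTM hook above receives each prediction as an (annotation, waveform) pair; the first element is serialized with to_rttm() and pushed back over the client's socket. A hypothetical standalone hook with the same shape, assuming `inference` was built as in _create_client_state:

# Sketch of a custom hook, mirroring the RTTM response hook in _on_connect.
def log_rttm(ann_wav):
    annotation, _waveform = ann_wav        # (annotation, audio chunk) pair
    print(annotation.to_rttm(), end="")    # RTTM lines for this chunk

inference.attach_hooks(log_rttm)
prediction = inference()                   # blocks until the audio source closes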
