From 8ade093fe09788e9d7642b7dbe41ec0e74ea93e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 26 Nov 2024 13:47:10 -0500 Subject: [PATCH] File reading IO refactoring into backends (#1421) * File reading IO refactoring into backends * Fixes * fix for older python version --- lhotse/__init__.py | 11 +- lhotse/serialization.py | 374 +++++++++++++++++++++++++++----- test/test_missing_torchaudio.py | 1 + 3 files changed, 325 insertions(+), 61 deletions(-) diff --git a/lhotse/__init__.py b/lhotse/__init__.py index 82bdb809d..01b1ea2f7 100644 --- a/lhotse/__init__.py +++ b/lhotse/__init__.py @@ -19,7 +19,16 @@ from .lazy import dill_enabled, is_dill_enabled, set_dill_enabled from .manipulation import combine, split_parallelize_combine, to_manifest from .qa import fix_manifests, validate, validate_recordings_and_supervisions -from .serialization import load_manifest, load_manifest_lazy, store_manifest +from .serialization import ( + available_io_backends, + get_current_io_backend, + get_default_io_backend, + io_backend, + load_manifest, + load_manifest_lazy, + set_current_io_backend, + store_manifest, +) from .supervision import SupervisionSegment, SupervisionSet from .tools.env import add_tools_to_path as _add_tools_to_path from .utils import ( diff --git a/lhotse/serialization.py b/lhotse/serialization.py index 76822ae08..1cad33cb5 100644 --- a/lhotse/serialization.py +++ b/lhotse/serialization.py @@ -4,15 +4,16 @@ import sys import warnings from codecs import StreamReader, StreamWriter +from contextlib import contextmanager from functools import lru_cache from io import BytesIO, StringIO from pathlib import Path -from typing import Any, Dict, Generator, Iterable, Optional, Type, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union import yaml from packaging.version import parse as parse_version -from lhotse.utils import Pathlike, SmartOpen, is_module_available, is_valid_url +from lhotse.utils import Pathlike, 
Pipe, SmartOpen, is_module_available, is_valid_url from lhotse.workarounds import gzip_open_robust # TODO: figure out how to use some sort of typing stubs @@ -31,59 +32,22 @@ def open_best(path: Pathlike, mode: str = "r"): either stdin or stdout depending on the mode. The concept is similar to Kaldi's "generalized pipes", but uses WebDataset syntax. """ - strpath = str(path) - if strpath == "-": - if mode == "r": - return StdStreamWrapper(sys.stdin) - elif mode == "w": - return StdStreamWrapper(sys.stdout) - else: - raise ValueError( - f"Cannot open stream for '-' with mode other 'r' or 'w' (got: '{mode}')" - ) - if isinstance(path, (BytesIO, StringIO, StreamWriter, StreamReader)): return path - - if strpath.startswith("pipe:"): - return open_pipe(path[5:], mode) - - if strpath.startswith("ais://"): - return open_aistore(path, mode) - - if is_valid_url(strpath): - if is_aistore_available(): - return open_aistore(path, mode) - elif is_module_available("smart_open"): - return SmartOpen.open(path, mode) - else: + assert isinstance( + path, (str, Path) + ), f"Unexpected identifier type {type(path)} for object {path}. Expected str or pathlib.Path." + try: + return get_current_io_backend().open(path, mode) + except Exception: + if is_valid_url(path): raise ValueError( + f"Error trying to open what seems to be a URI: '{path}'\n" f"In order to open URLs/URIs please run 'pip install smart_open' " f"(if you're trying to use AIStore, either the Python SDK is not installed (pip install aistore) " f"or {AIS_ENDPOINT_ENVVAR} is not defined." ) - - if is_module_available("smart_open"): - return SmartOpen.open(path, mode) - - compressed = strpath.endswith(".gz") - if compressed: - if "t" not in mode and "b" not in mode: - # Opening as bytes not requested explicitly, use "t" to tell gzip to handle unicode. 
- mode = mode + "t" - return gzip_open_robust(path, mode) - - return open(path, mode) - - -def open_pipe(cmd: str, mode: str): - """ - Runs the command and redirects stdin/stdout depending on the mode. - Returns a file-like object that can be read from or written to. - """ - from lhotse.utils import Pipe - - return Pipe(cmd, mode=mode, shell=True, bufsize=8092) + raise AIS_ENDPOINT_ENVVAR = "AIS_ENDPOINT" @@ -113,18 +77,6 @@ def get_aistore_client(): return aistore.Client(endpoint_url), version -def open_aistore(uri: str, mode: str): - assert "r" in mode, "We only support reading from AIStore at this time." - client, version = get_aistore_client() - object = client.fetch_object_by_url(uri) - request = object.get() - if version >= parse_version("1.9.1"): - # AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency - return request.as_file() - else: - return request.raw() - - def save_to_yaml(data: Any, path: Pathlike) -> None: with open_best(path, "w") as f: try: @@ -684,3 +636,305 @@ def __getattr__(self, item: str): if item == "close": return self.close return getattr(self.stream, item) + + +class IOBackend: + """ + Base class for IO backends supported by Lhotse. + An IO backend supports open() operations for reads and/or writes to file-like objects. + Deriving classes are auto-registered under their class name, and auto-discoverable + through functions: + + * :func:`~lhotse.serialization.available_io_backends` + + * :func:`~lhotse.serialization.get_current_io_backend` + + * :func:`~lhotse.serialization.set_current_io_backend` + + The default composite backend that tries to figure out the best solution + can be obtained via :func:`~lhotse.serialization.get_default_io_backend`. + + New IO backends are expected to define the following methods: + + * `open(identifier: str, mode: str)` which returns a file-like object. + Must be implemented. 
+ + * `is_applicable(identifier: str) -> bool` returns `True` if a given + backend can be used for a given identifier. True by default. + + * `is_available() -> bool` Class method. Only define it when + the availability of the backend depends on some special actions, + such as installing an optional dependency. + + * `handles_special_case(identifier: str) -> bool` defined ONLY when + a given IO Backend MUST be selected for a specific identifier. + For example, only :class:`~lhotse.serialization.PipeIOBackend` handles + piped commands like `"pipe:gunzip -c manifest.jsonl.gz"`. + """ + + KNOWN_BACKENDS = {} + + def __init_subclass__(cls, **kwargs): + if cls.__name__ not in IOBackend.KNOWN_BACKENDS: + IOBackend.KNOWN_BACKENDS[cls.__name__] = cls + super().__init_subclass__(**kwargs) + + @classmethod + def new(cls, name: str) -> "IOBackend": + if name not in cls.KNOWN_BACKENDS: + raise RuntimeError(f"Unknown IO backend name: {name}") + return cls.KNOWN_BACKENDS[name]() + + def open(self, identifier: Pathlike, mode: str): + raise NotImplementedError() + + @classmethod + def is_available(cls) -> bool: + return True + + def handles_special_case(self, identifier: Pathlike) -> bool: + return False + + def is_applicable(self, identifier: Pathlike) -> bool: + return True + + +class BuiltinIOBackend(IOBackend): + """Calls Python's built-in `open`.""" + + def open(self, identifier: Pathlike, mode: str): + return open(identifier, mode=mode) + + def is_applicable(self, identifier: Pathlike) -> bool: + return not is_valid_url(identifier) + + +class RedirectIOBackend(IOBackend): + """Opens a stream to stdin or stdout.""" + + def open(self, identifier: Pathlike, mode: str): + if mode == "r": + return StdStreamWrapper(sys.stdin) + elif mode == "w": + return StdStreamWrapper(sys.stdout) + raise ValueError( + f"Cannot open stream for '-' with mode other 'r' or 'w' (got: '{mode}')" + ) + + def handles_special_case(self, identifier: Pathlike) -> bool: + return identifier == 
"-" + + def is_applicable(self, identifier: Pathlike) -> bool: + return self.handles_special_case(identifier) + + +class PipeIOBackend(IOBackend): + """Executes the provided command / pipe and wraps it into a file-like object.""" + + def open(self, identifier: Pathlike, mode: str): + """ + Runs the command and redirects stdin/stdout depending on the mode. + Returns a file-like object that can be read from or written to. + """ + return Pipe(str(identifier)[5:], mode=mode, shell=True, bufsize=8092) + + def handles_special_case(self, identifier: Pathlike) -> bool: + return str(identifier).startswith("pipe:") + + def is_applicable(self, identifier: Pathlike) -> bool: + return self.handles_special_case(identifier) + + +class GzipIOBackend(IOBackend): + """Uses gzip.open to automatically (de)compress.""" + + def open(self, identifier: Pathlike, mode: str): + if "t" not in mode and "b" not in mode: + # Opening as bytes not requested explicitly, use "t" to tell gzip to handle unicode. + mode = mode + "t" + return gzip_open_robust(identifier, mode) + + def handles_special_case(self, identifier: Pathlike) -> bool: + identifier = str(identifier) + return identifier.endswith(".gz") and not is_valid_url(identifier) + + def is_applicable(self, identifier: Pathlike) -> bool: + return self.handles_special_case(identifier) + + +class SmartOpenIOBackend(IOBackend): + """Uses `smart_open` library (if installed) to auto-determine how to handle the URI.""" + + def open(self, identifier: Pathlike, mode: str): + return SmartOpen.open(identifier, mode) + + @classmethod + def is_available(cls) -> bool: + return is_module_available("smart_open") + + +class AIStoreIOBackend(IOBackend): + """ + Uses `aistore` client (if installed and enabled via AIS_ENDPOINT env var) + to download data from AIStore if the identifier is a URL/URI. + """ + + def open(self, identifier: str, mode: str): + assert "r" in mode, "We only support reading from AIStore at this time." 
+ client, version = get_aistore_client() + object = client.fetch_object_by_url(identifier) + request = object.get() + if version >= parse_version("1.9.1"): + # AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency + return request.as_file() + else: + return request.raw() + + @classmethod + def is_available(cls) -> bool: + return ( + is_module_available("aistore") + and AIS_ENDPOINT_ENVVAR in os.environ + and is_valid_url(os.environ[AIS_ENDPOINT_ENVVAR]) + ) + + def handles_special_case(self, identifier: Pathlike) -> bool: + return str(identifier).startswith("ais://") + + def is_applicable(self, identifier: Pathlike) -> bool: + return is_valid_url(identifier) + + +class CompositeIOBackend(IOBackend): + """ + Composes multiple IO backends together. + Uses `handles_special_case` and `is_applicable` of sub-backends to auto-detect + which backend to select. + + In case of `handles_special_case`, if multiple backends could have worked, + we'll use the first one in the list. + """ + + def __init__(self, backends: List[IOBackend]): + self.backends = backends + + def open(self, identifier: Pathlike, mode: str): + for b in self.backends: + if b.handles_special_case(identifier): + return b.open(identifier, mode) + + for b in self.backends: + if b.is_applicable(identifier): + return b.open(identifier, mode) + + raise RuntimeError( + f"Couldn't find a suitable IOBackend for input '{identifier}'" + ) + + def handles_special_case(self, identifier: Pathlike) -> bool: + return any(b.handles_special_case(identifier) for b in self.backends) + + def is_applicable(self, identifier: Pathlike) -> bool: + return any(b.is_applicable(identifier) for b in self.backends) + + +CURRENT_IO_BACKEND: Optional["IOBackend"] = None + + +def available_io_backends() -> List[str]: + """ + Return a list of names of available IO backends, including "default". 
+ """ + return ["default"] + sorted( + b + for b in IOBackend.KNOWN_BACKENDS + if IOBackend.KNOWN_BACKENDS[b].is_available() + ) + + +@contextmanager +def io_backend(backend: Union["IOBackend", str]) -> Generator["IOBackend", None, None]: + """ + Context manager that sets Lhotse's IO backend to the specified value + and restores the previous IO backend at the end of its scope. + + Example:: + + >>> with io_backend("AIStoreIOBackend"): + ... cuts = CutSet.from_file(...) # forced open() via AIStore client + """ + previous = get_current_io_backend() + b = set_current_io_backend(backend) + yield b + set_current_io_backend(previous) + + +def get_current_io_backend() -> "IOBackend": + """ + Return the backend currently set by the user, or default. + """ + global CURRENT_IO_BACKEND + + # First check if the user has programmatically overridden the backend. + if CURRENT_IO_BACKEND is not None: + return CURRENT_IO_BACKEND + + # Then, check if the user has overridden the IO backend via an env var. + maybe_backend = os.environ.get("LHOTSE_IO_BACKEND") + if maybe_backend is not None: + return set_current_io_backend(maybe_backend) + + # Lastly, fall back to the default backend. + return set_current_io_backend("default") + + +def set_current_io_backend(backend: Union["IOBackend", str]) -> "IOBackend": + """ + Force Lhotse to use a specific IO backend to open every path/URL/URI, + overriding the default behaviour of "educated guessing". 
+ + Example forcing Lhotse to use ``aistore`` library for every IO open() operation:: + + >>> set_current_io_backend(AIStoreIOBackend()) + """ + global CURRENT_IO_BACKEND + if backend == "default": + backend = get_default_io_backend() + elif isinstance(backend, str): + backend = IOBackend.new(backend) + else: + if isinstance(backend, type): + backend = backend() + assert isinstance( + backend, IOBackend + ), f"Expected str or IOBackend, got: {backend}" + CURRENT_IO_BACKEND = backend + return CURRENT_IO_BACKEND + + +@lru_cache(maxsize=1) +def get_default_io_backend() -> "IOBackend": + """ + Return a composite backend that auto-infers which internal backend can support reading + from a given identifier. + + It first looks for special cases that need very specific handling + (such as: stdin/stdout redirects, pipes) + and tries to match them against relevant IO backends. + """ + # Start with the special cases. + backends = [ + RedirectIOBackend(), + PipeIOBackend(), + ] + if AIStoreIOBackend.is_available(): + # Try AIStore before other generalist backends, + # but only if it's installed and enabled via AIS_ENDPOINT env var. + backends.append(AIStoreIOBackend()) + if SmartOpenIOBackend.is_available(): + backends.append(SmartOpenIOBackend()) + backends += [ + GzipIOBackend(), + BuiltinIOBackend(), + ] + return CompositeIOBackend(backends) diff --git a/test/test_missing_torchaudio.py b/test/test_missing_torchaudio.py index 12765d00b..5ba3d9e34 100644 --- a/test/test_missing_torchaudio.py +++ b/test/test_missing_torchaudio.py @@ -6,6 +6,7 @@ def is_torchaudio_available(): + return False return importlib.util.find_spec("torchaudio") is not None