Skip to content

Commit

Permalink
Add more lazy imports throughout the code
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 638912770
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed May 31, 2024
1 parent c74827d commit 465d709
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 62 deletions.
2 changes: 2 additions & 0 deletions tensorflow_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
# pylint: enable=line-too-long
# pylint: disable=g-import-not-at-top,g-bad-import-order,wrong-import-position,unused-import

from __future__ import annotations

from absl import logging
from etils import epy as _epy

Expand Down
53 changes: 30 additions & 23 deletions tensorflow_datasets/core/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,18 @@
from __future__ import annotations

import abc
from collections.abc import Iterable
import dataclasses
import json
import os
import posixpath
import tempfile
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Optional

from absl import logging
from etils import epath
from etils import epy
from tensorflow_datasets.core import constants
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import lazy_imports_lib
Expand All @@ -52,16 +54,21 @@
from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.features import top_level_feature
from tensorflow_datasets.core.proto import dataset_info_pb2
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import gcs_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

from google.protobuf import json_format
with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import gcs_utils

from google.protobuf import json_format
# pylint: enable=g-import-not-at-top


# TODO(b/109648354): Remove the "pytype: disable" comment.
Nest = Union[Tuple["Nest", ...], Dict[str, "Nest"], str] # pytype: disable=not-supported-yet
SupervisedKeysType = Union[Tuple[Nest, Nest], Tuple[Nest, Nest, Nest]]
Nest = tuple["Nest", ...] | dict[str, "Nest"] | str # pytype: disable=not-supported-yet
SupervisedKeysType = tuple[Nest, Nest] | tuple[Nest, Nest, Nest]


def dataset_info_path(dataset_info_dir: epath.PathLike) -> epath.Path:
Expand Down Expand Up @@ -108,7 +115,7 @@ class DatasetIdentity:
config_name: str | None = None
config_description: str | None = None
config_tags: list[str] | None = None
release_notes: Dict[str, str] | None = None
release_notes: dict[str, str] | None = None

@classmethod
def from_builder(cls, builder) -> "DatasetIdentity":
Expand Down Expand Up @@ -176,16 +183,16 @@ def __init__(
# LINT.IfChange(dataset_info_args)
self,
*,
builder: Union[DatasetIdentity, Any],
description: Optional[str] = None,
builder: DatasetIdentity | Any,
description: str | None = None,
features: Optional[feature_lib.FeatureConnector] = None,
supervised_keys: Optional[SupervisedKeysType] = None,
disable_shuffling: bool = False,
homepage: Optional[str] = None,
citation: Optional[str] = None,
metadata: Optional[Metadata] = None,
license: Optional[str] = None, # pylint: disable=redefined-builtin
redistribution_info: Optional[Dict[str, str]] = None,
homepage: str | None = None,
citation: str | None = None,
metadata: Metadata | None = None,
license: str | None = None, # pylint: disable=redefined-builtin
redistribution_info: Optional[dict[str, str]] = None,
split_dict: Optional[splits_lib.SplitDict] = None,
# LINT.ThenChange(:setstate)
):
Expand Down Expand Up @@ -347,7 +354,7 @@ def config_description(self) -> str | None:
return self._identity.config_description

@property
def config_tags(self) -> List[str] | None:
def config_tags(self) -> list[str] | None:
return self._identity.config_tags

@property
Expand All @@ -368,7 +375,7 @@ def version(self):
return self._identity.version

@property
def release_notes(self) -> Optional[Dict[str, str]]:
def release_notes(self) -> dict[str, str] | None:
return self._identity.release_notes

@property
Expand Down Expand Up @@ -412,7 +419,7 @@ def features(self):
return self._features

@property
def metadata(self) -> Optional[Metadata]:
def metadata(self) -> Metadata | None:
return self._metadata

@property
Expand All @@ -431,14 +438,14 @@ def module_name(self) -> str:
return self._identity.module_name

@property
def file_format(self) -> Optional[file_adapters.FileFormat]:
def file_format(self) -> file_adapters.FileFormat | None:
if not self.as_proto.file_format:
return None
return file_adapters.FileFormat(self.as_proto.file_format)

def set_file_format(
self,
file_format: Union[None, str, file_adapters.FileFormat],
file_format: None | str | file_adapters.FileFormat,
override: bool = False,
) -> None:
"""Internal function to define the file format.
Expand Down Expand Up @@ -716,8 +723,8 @@ def read_from_directory(self, dataset_info_dir: epath.PathLike) -> None:

def add_file_data_source_access(
self,
path: Union[epath.PathLike, Iterable[epath.PathLike]],
url: Optional[str] = None,
path: epath.PathLike | Iterable[epath.PathLike],
url: str | None = None,
) -> None:
"""Records that the given query was used to generate this dataset.
Expand All @@ -743,7 +750,7 @@ def add_file_data_source_access(
def add_url_access(
self,
url: str,
checksum: Optional[str] = None,
checksum: str | None = None,
) -> None:
"""Records the URL used to generate this dataset."""
self._info_proto.data_source_accesses.append(
Expand All @@ -768,7 +775,7 @@ def add_sql_data_source_access(
def add_tfds_data_source_access(
self,
dataset_reference: naming.DatasetReference,
url: Optional[str] = None,
url: str | None = None,
) -> None:
"""Records that the given query was used to generate this dataset.
Expand Down
11 changes: 9 additions & 2 deletions tensorflow_datasets/core/dataset_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@

"""Logic related to reading datasets metadata from config files."""

from __future__ import annotations

import dataclasses
import functools

from etils import epath
from etils import etree
from etils import epy
from tensorflow_datasets.core import constants
from tensorflow_datasets.core.utils import resource_utils

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from etils import etree
from tensorflow_datasets.core.utils import resource_utils
# pylint: enable=g-import-not-at-top


CITATIONS_FILENAME = "CITATIONS.bib"
Expand Down
19 changes: 10 additions & 9 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,21 @@
from absl import logging
from etils import epath
from etils import epy
import promise
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.download import checksums
from tensorflow_datasets.core.download import extractor
from tensorflow_datasets.core.download import kaggle
from tensorflow_datasets.core.download import resource as resource_lib
from tensorflow_datasets.core.download import util
from tensorflow_datasets.core.utils import shard_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tree

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
import promise

from tensorflow_datasets.core import utils
from tensorflow_datasets.core.download import checksums
from tensorflow_datasets.core.download import downloader
from tensorflow_datasets.core.download import extractor
from tensorflow_datasets.core.download import kaggle
from tensorflow_datasets.core.download import resource as resource_lib
from tensorflow_datasets.core.download import util
from tensorflow_datasets.core.utils import shard_utils
from tensorflow_datasets.core.utils import type_utils
# pylint: enable=g-import-not-at-top

# pylint: disable=logging-fstring-interpolation
Expand Down
24 changes: 16 additions & 8 deletions tensorflow_datasets/core/logging/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@

"""This module defines the methods a logger implementation should define."""

from __future__ import annotations

from typing import Any, Dict, Optional, Union

from etils import epath
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import download as download_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.logging import call_metadata
from tensorflow_datasets.core.utils import read_config as read_config_lib
from tensorflow_datasets.core.utils import type_utils
from etils import epy

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from etils import epath
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import download as download_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.logging import call_metadata
from tensorflow_datasets.core.utils import read_config as read_config_lib
from tensorflow_datasets.core.utils import type_utils
# pylint: enable=g-import-not-at-top


TreeDict = type_utils.TreeDict

Expand Down
12 changes: 7 additions & 5 deletions tensorflow_datasets/core/logging/call_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@

"""To associate metadata with TFDS calls."""

from __future__ import annotations

import enum
import threading
import time
from typing import Dict, Optional, Tuple


# Maps thread_id to "Session ID", if any.
_THREAD_TO_SESSIONID: Dict[int, int] = {}
_THREAD_TO_SESSIONID: dict[int, int] = {}

_NEXT_SESSION_ID = 1
_NEXT_SESSION_ID_LOCK = threading.Lock()
Expand All @@ -33,7 +35,7 @@ class Status(enum.Enum):
ERROR = 2


def _get_session_id(thread_id: int) -> Tuple[int, bool]:
def _get_session_id(thread_id: int) -> tuple[int, bool]:
"""Returns (session_id, direct_call) tuple."""
session_id = _THREAD_TO_SESSIONID.get(thread_id, None)
if session_id:
Expand All @@ -55,8 +57,8 @@ class CallMetadata:
"""

# The start and end times of the event (microseconds since Epoch).
start_time_micros: Optional[int]
end_time_micros: Optional[int]
start_time_micros: int | None
end_time_micros: int | None

# The status (success or error) of the call.
status: Status
Expand Down
35 changes: 21 additions & 14 deletions tensorflow_datasets/core/read_only_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,34 @@

"""Load Datasets without reading dataset generation code."""

from __future__ import annotations

import functools
import os
import typing
from typing import Any, List, Optional, Type

from etils import epath
from etils import etree
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import logging as tfds_logging
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import registered
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.proto import dataset_info_pb2
from tensorflow_datasets.core.utils import error_utils
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import version as version_lib
from etils import epy
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from etils import epath
from etils import etree
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import logging as tfds_logging
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import registered
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.proto import dataset_info_pb2
from tensorflow_datasets.core.utils import error_utils
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import version as version_lib
# pylint: enable=g-import-not-at-top


class ReadOnlyBuilder(
dataset_builder.FileReaderBuilder, skip_registration=True
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_datasets/core/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""Library of helper functions to handle dealing with files."""

from __future__ import annotations

import collections
from collections.abc import Iterator, Sequence
import functools
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_datasets/core/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""Some python utils function and classes."""

from __future__ import annotations

import base64
from collections.abc import Iterator, Sequence
import contextlib
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_datasets/core/utils/tqdm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""Wrapper around tqdm."""

from __future__ import annotations

import contextlib
import os

Expand Down
2 changes: 2 additions & 0 deletions tensorflow_datasets/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""Public API of tfds, without the registered dataset."""

from __future__ import annotations

from etils import epy
from tensorflow_datasets import core
from tensorflow_datasets import typing
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_datasets/rlds/envlogger_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union

import numpy as np
import tree
from tensorflow_datasets.core.utils.lazy_imports_utils import tree


def _get_episode_metadata(episode: Sequence[Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 465d709

Please sign in to comment.