feat: Allow using zcollection without any dask cluster. #16

Draft · wants to merge 2 commits into base: develop
258 changes: 187 additions & 71 deletions zcollection/collection/__init__.py

Large diffs are not rendered by default.
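
The diff below adds a `distributed` flag to the public API so that a collection can be read and written without a dask cluster. A minimal sketch of the usage this change appears to enable, assuming the library's usual `open_collection` entry point; the path, mode, and filter are placeholder values, not taken from this PR:

```python
import zcollection

# Open an existing collection (the path is a placeholder).
collection = zcollection.open_collection("/tmp/my_collection", mode="r")

# With distributed=False no dask.distributed.Client is needed: the data is
# read in-process and `delayed` is forced to False, so variables come back
# as numpy arrays instead of dask arrays.
zds = collection.load(
    filters=lambda keys: keys["year"] == 2019,
    distributed=False,
)
```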

234 changes: 188 additions & 46 deletions zcollection/collection/abc.py
@@ -294,27 +294,47 @@ def _relative_path(self, path: str) -> str:
return pathlib.Path(path).relative_to(
self.partition_properties.dir).as_posix()

def _normalize_partitions(self,
partitions: Iterable[str]) -> Iterable[str]:
"""Normalize the provided list of partitions to include the full
partition's path.

Args:
partitions: The list of partitions to normalize.

Returns:
The list of partitions.
"""
return filter(
self.fs.exists,
map(
lambda partition: self.fs.sep.join(
(self.partition_properties.dir, partition)),
sorted(set(partitions))))

def partitions(
self,
*,
cache: Iterable[str] | None = None,
lock: bool = False,
filters: PartitionFilter = None,
indexer: Indexer | None = None,
selected_partitions: Iterable[str] | None = None,
relative: bool = False,
lock: bool = False,
) -> Iterator[str]:
"""List the partitions of the collection.

Args:
cache: The list of partitions to use. If None, the partitions are
listed.
lock: Whether to lock the collection or not to avoid listing
partitions while the collection is being modified.
filters: The predicate used to filter the partitions to load. If
the predicate is a string, it is a valid python expression to
filter the partitions, using the partitioning scheme as
variables. If the predicate is a function, it is a function that
takes the partition scheme as input and returns a boolean.
indexer: The indexer to apply.
selected_partitions: A list of partitions to load (using the
partition's relative path).
relative: Whether to return the relative path.
lock: Whether to lock the collection or not to avoid listing
partitions while the collection is being modified.

Returns:
The list of partitions.
@@ -336,8 +356,9 @@ def partitions(

base_dir: str = self.partition_properties.dir
sep: str = self.fs.sep
if cache is not None:
partitions: Iterable[str] = cache
if selected_partitions is not None:
partitions: Iterable[str] = self._normalize_partitions(
partitions=selected_partitions)
else:
if lock:
with self.synchronizer:
@@ -347,6 +368,17 @@
partitions = self.partitioning.list_partitions(
self.fs, base_dir)

if indexer is not None:
# Keep only the partitions present in both the indexer and the partitions list
partitions = list(partitions)
partitions = [
p for p in list_partitions_from_indexer(
indexer=indexer,
partition_handler=self.partitioning,
base_dir=self.partition_properties.dir,
sep=self.fs.sep) if p in partitions
]

yield from (self._relative_path(item) if relative else item
for item in partitions
if (item != self._immutable and self._is_selected(
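
As a quick illustration of the updated signature, here is a sketch of how `partitions()` might be called with the new `selected_partitions` keyword; the partition names and the filter are placeholder values, and `collection` stands for an instance of the class modified above:

```python
# Keep only an explicit subset of partitions, given by their relative paths.
# _normalize_partitions() rebuilds the absolute paths and silently drops
# entries that do not exist on the file system.
subset = collection.partitions(
    selected_partitions=[
        "year=2019/month=03/day=02",
        "year=2019/month=03/day=04",
    ],
    # The filter may also be a string expression such as
    # "year == 2019 and month == 3", evaluated against the partitioning keys.
    filters=lambda keys: keys["month"] == 3,
    relative=True,
)
for partition in subset:
    print(partition)
```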
@@ -553,8 +585,11 @@ def load(
filters: PartitionFilter = None,
indexer: Indexer | None = None,
selected_variables: Iterable[str] | None = None,
selected_partitions: Iterable[str] | None = None,
distributed: bool = True,
) -> dataset.Dataset | None:
"""Load the selected partitions.
"""Load collection's data, respecting filters, indexer, and selected
partitions constraints.

Args:
delayed: Whether to load data in a dask array or not.
@@ -564,6 +599,9 @@
indexer: The indexer to apply.
selected_variables: A list of variables to retain from the
collection. If None, all variables are kept.
selected_partitions: A list of partitions to load (using the
partition's relative path).
distributed: Whether to use dask or not. Defaults to True.

Returns:
The dataset containing the selected partitions, or None if no
@@ -582,46 +620,149 @@
... filters=lambda keys: keys["year"] == 2019 and
... keys["month"] == 3 and keys["day"] % 2 == 0)
"""
client: dask.distributed.Client = dask_utils.get_client()
arrays: list[dataset.Dataset]
# Delayed has to be False if dask is disabled
if not distributed:
delayed = False

if indexer is None:
selected_partitions = tuple(self.partitions(filters=filters))
if len(selected_partitions) == 0:
return None
arrays = self._load_partitions(
delayed=delayed,
filters=filters,
selected_variables=selected_variables,
selected_partitions=selected_partitions,
distributed=distributed)
else:
arrays = self._load_partitions_indexer(
indexer=indexer,
delayed=delayed,
filters=filters,
selected_variables=selected_variables,
selected_partitions=selected_partitions,
distributed=distributed)

if arrays is None:
return None

# No indexer, so the dataset is loaded directly for each
# selected partition.
array: dataset.Dataset = arrays.pop(0)
if arrays:
array = array.concat(arrays, self.partition_properties.dim)
if self._immutable:
array.merge(
storage.open_zarr_group(self._immutable,
self.fs,
delayed=delayed,
selected_variables=selected_variables))
array.fill_attrs(self.metadata)
return array

def _load_partitions(
self,
*,
delayed: bool = True,
filters: PartitionFilter = None,
selected_variables: Iterable[str] | None = None,
selected_partitions: Iterable[str] | None = None,
distributed: bool = True,
) -> list[dataset.Dataset] | None:
"""Load collection's partitions, respecting filters, and selected
partitions constraints.

Args:
delayed: Whether to load data in a dask array or not.
filters: The predicate used to filter the partitions to load. To
get more information on the predicate, see the documentation of
the :meth:`partitions` method.
selected_variables: A list of variables to retain from the
collection. If None, all variables are kept.
selected_partitions: A list of partitions to load (using the
partition's relative path).
distributed: Whether to use dask or not. Defaults to True.

Returns:
The list of datasets, one per partition, or None if no
partitions were selected.
"""
# No indexer, so the dataset is loaded directly for each
# selected partition.
selected_partitions = tuple(
self.partitions(filters=filters,
selected_partitions=selected_partitions))

if len(selected_partitions) == 0:
return None

if distributed:
client = dask_utils.get_client()
bag: dask.bag.core.Bag = dask.bag.core.from_sequence(
self.partitions(filters=filters),
selected_partitions,
npartitions=dask_utils.dask_workers(client, cores_only=True))
arrays = bag.map(storage.open_zarr_group,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables).compute()
else:
# We're going to reuse the indexer variable, so ensure it is
# an iterable not a generator.
indexer = tuple(indexer)

# Build the indexer arguments.
partitions = self.partitions(filters=filters,
cache=list_partitions_from_indexer(
indexer, self.partitioning,
self.partition_properties.dir,
self.fs.sep))
args = tuple(
build_indexer_args(self,
filters,
indexer,
partitions=partitions))
if len(args) == 0:
return None
arrays = [
storage.open_zarr_group(dirname=partition,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables)
for partition in selected_partitions
]

return arrays

def _load_partitions_indexer(
self,
*,
indexer: Indexer,
delayed: bool = True,
filters: PartitionFilter = None,
selected_variables: Iterable[str] | None = None,
selected_partitions: Iterable[str] | None = None,
distributed: bool = True,
) -> list[dataset.Dataset] | None:
"""Load collection's partitions, respecting filters, indexer, and
selected partitions constraints.

Args:
indexer: The indexer to apply.
delayed: Whether to load data in a dask array or not.
filters: The predicate used to filter the partitions to load. To
get more information on the predicate, see the documentation of
the :meth:`partitions` method.
selected_variables: A list of variables to retain from the
collection. If None, all variables are kept.
selected_partitions: A list of partitions to load (using the
partition's relative path).
distributed: Whether to use dask or not. Defaults to True.

Returns:
The list of datasets, one per partition, or None if no
partitions were selected.
"""
# We're going to reuse the indexer variable, so ensure it is
# an iterable not a generator.
indexer = tuple(indexer)

# Build the indexer arguments.
partitions = self.partitions(selected_partitions=selected_partitions,
filters=filters,
indexer=indexer)
args = tuple(
build_indexer_args(collection=self,
filters=filters,
indexer=indexer,
partitions=partitions))
if len(args) == 0:
return None

# Finally, load the selected partitions and apply the indexer.
if distributed:
client = dask_utils.get_client()
bag = dask.bag.core.from_sequence(
args,
npartitions=dask_utils.dask_workers(client, cores_only=True))

# Finally, load the selected partitions and apply the indexer.
arrays = list(
itertools.chain.from_iterable(
bag.map(
@@ -632,18 +773,19 @@ def load(
partition_properties=self.partition_properties,
selected_variables=selected_variables,
).compute()))
else:
arrays = list(
itertools.chain.from_iterable([
_load_and_apply_indexer(
args=a,
delayed=delayed,
fs=self.fs,
partition_handler=self.partitioning,
partition_properties=self.partition_properties,
selected_variables=selected_variables) for a in args
]))

array: dataset.Dataset = arrays.pop(0)
if arrays:
array = array.concat(arrays, self.partition_properties.dim)
if self._immutable:
array.merge(
storage.open_zarr_group(self._immutable,
self.fs,
delayed=delayed,
selected_variables=selected_variables))
array.fill_attrs(self.metadata)
return array
return arrays

def _bag_from_partitions(
self,
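
To make the in-process code path easier to follow, here is a rough standalone equivalent of what `_load_partitions` does when `distributed=False`. It is only a sketch: the import path is assumed, and `partitions`, `fs`, and `dim` stand in for the attributes used in the diff above.

```python
from zcollection import storage  # import path assumed


def load_sequentially(partitions, fs, dim, selected_variables=None):
    """Open each partition in-process and concatenate along ``dim``.

    Mirrors the ``distributed=False`` branch above: ``delayed`` is forced
    to False, so every variable is read eagerly instead of lazily.
    """
    arrays = [
        storage.open_zarr_group(dirname=partition,
                                delayed=False,
                                fs=fs,
                                selected_variables=selected_variables)
        for partition in partitions
    ]
    if not arrays:
        return None
    zds = arrays.pop(0)
    if arrays:
        zds = zds.concat(arrays, dim)
    return zds
```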
13 changes: 10 additions & 3 deletions zcollection/collection/detail.py
@@ -394,6 +394,7 @@ def _insert(
fs: fsspec.AbstractFileSystem,
merge_callable: merging.MergeCallable | None,
partitioning_properties: PartitioningProperties,
distributed: bool = True,
**kwargs,
) -> None:
"""Insert or update a partition in the collection.
@@ -405,6 +406,7 @@
fs: The file system that the partition is stored on.
merge_callable: The merge callable.
partitioning_properties: The partitioning properties.
distributed: Whether to use dask or not. Defaults to True.
**kwargs: Additional keyword arguments to pass to the merge callable.
"""
partition: tuple[str, ...]
@@ -423,7 +425,8 @@
axis,
fs,
partitioning_properties.dim,
delayed=zds.delayed,
delayed=zds.delayed if distributed else False,
distributed=distributed,
merge_callable=merge_callable,
**kwargs)
return
@@ -434,7 +437,11 @@
zarr.storage.init_group(store=fs.get_mapper(dirname))

# The synchronization is done by the caller.
write_zarr_group(zds.isel(indexer), dirname, fs, sync.NoSync())
write_zarr_group(zds.isel(indexer),
dirname,
fs,
sync.NoSync(),
distributed=distributed)
except: # noqa: E722
# If the construction of the new dataset fails, the created
# partition is deleted, to guarantee the integrity of the
@@ -459,7 +466,7 @@ def _load_and_apply_indexer(
fs: The file system that the partition is stored on.
partition_handler: The partitioning handler.
partition_properties: The partitioning properties.
selected_variable: The selected variables to load.
selected_variables: The selected variables to load.

Returns:
The list of loaded datasets.
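
The change to `detail.py` follows the same pattern as the loading code: every step that previously went through dask gains a `distributed` switch with an in-process fallback. A hypothetical helper, not part of the library, showing that dispatch pattern in isolation (import paths are assumptions):

```python
from typing import Any, Callable, Sequence


def map_items(func: Callable[..., Any],
              items: Sequence[Any],
              distributed: bool = True,
              **kwargs) -> list[Any]:
    """Run ``func`` over ``items`` through a dask bag when a cluster is
    available, otherwise fall back to a plain in-process loop."""
    if distributed:
        import dask.bag
        from zcollection import dask_utils  # import path assumed

        client = dask_utils.get_client()
        bag = dask.bag.from_sequence(
            items,
            npartitions=dask_utils.dask_workers(client, cores_only=True))
        return bag.map(func, **kwargs).compute()
    return [func(item, **kwargs) for item in items]
```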