From d4662f83c637931f9c1fc64248e085fc928c6a9c Mon Sep 17 00:00:00 2001 From: Thomas Zilio Date: Wed, 20 Nov 2024 14:55:45 +0100 Subject: [PATCH 1/2] feat: Allowing to use zcollection without any dask cluster. --- zcollection/collection/__init__.py | 250 ++++++++++----- zcollection/collection/abc.py | 85 +++-- zcollection/collection/detail.py | 13 +- .../collection/tests/test_collection.py | 213 +++++++++---- zcollection/convenience/view.py | 19 +- zcollection/merging/__init__.py | 20 +- zcollection/merging/tests/test_merging.py | 21 +- zcollection/storage.py | 4 +- zcollection/view/__init__.py | 297 ++++++++++++------ zcollection/view/detail.py | 52 ++- zcollection/view/tests/test_view.py | 70 +++-- 11 files changed, 711 insertions(+), 333 deletions(-) diff --git a/zcollection/collection/__init__.py b/zcollection/collection/__init__.py index 8c09807..0a50e3e 100644 --- a/zcollection/collection/__init__.py +++ b/zcollection/collection/__init__.py @@ -38,6 +38,7 @@ from .abc import PartitionFilter, ReadOnlyCollection from .callable_objects import UpdateCallable, WrappedPartitionCallable from .detail import ( + PartitionSlice, _insert, _try_infer_callable, _wrap_update_func, @@ -306,12 +307,13 @@ def insert( merge_callable: merging.MergeCallable | None = None, npartitions: int | None = None, validate: bool = False, + distributed: bool = True, **kwargs, ) -> Iterable[str]: """Insert a dataset into the collection. Args: - ds: The dataset to insert. It can be either an xarray.Dataset or a + ds: The dataset to insert. It can be either a xarray.Dataset or a dataset.Dataset object. merge_callable: A function to use to merge the existing data set already stored in partitions with the new partitioned data. If @@ -319,6 +321,9 @@ def insert( partitioned data. npartitions: The maximum number of partitions to process in parallel. By default, partitions are processed one by one. + validate: Whether to validate dataset metadata before insertion + or not. + distributed: Whether to use dask or not. Default To True. kwargs: Additional keyword arguments passed to the merge callable. .. note:: @@ -331,9 +336,6 @@ def insert( partition insertion will happen sequentially and changing this parameter will have no effect. - validate: Whether to validate dataset metadata before insertion - or not. - Returns: A list of the inserted partitions. @@ -372,7 +374,6 @@ def insert( "Provided dataset's metadata do not match the collection's ones" ) ds = ds.set_for_insertion(self.metadata) - client: dask.distributed.Client = dask_utils.get_client() # If the dataset contains variables that should not be partitioned. 
if self._immutable is not None: @@ -391,11 +392,33 @@ def insert( partitions = tuple( self.partitioning.split_dataset(ds, self.partition_properties.dim)) + if distributed: + self._insert_distributed(ds=ds, + partitions=partitions, + npartitions=npartitions, + merge_callable=merge_callable, + **kwargs) + else: + self._insert_sequential(ds=ds, + partitions=partitions, + merge_callable=merge_callable, + **kwargs) + + return (fs_utils.join_path(*((self.partition_properties.dir, ) + item)) + for item, _ in partitions) + + def _insert_distributed(self, ds: xarray.Dataset | dataset.Dataset, + partitions: tuple[PartitionSlice, + ...], npartitions: int | None, + merge_callable: merging.MergeCallable | None, + **kwargs): + """Insert a dataset into the collection using dask.""" if npartitions is not None: if npartitions < 1: raise ValueError('The number of partitions must be positive') npartitions = len(partitions) // npartitions + 1 + client: dask.distributed.Client = dask_utils.get_client() scattered_ds: Any = client.scatter(ds) for sequence in dask_utils.split_sequence(partitions, npartitions): futures: list[dask.distributed.Future] = [ @@ -411,8 +434,21 @@ def insert( ] storage.execute_transaction(client, self.synchronizer, futures) - return (fs_utils.join_path(*((self.partition_properties.dir, ) + item)) - for item, _ in partitions) + def _insert_sequential(self, ds: xarray.Dataset | dataset.Dataset, + partitions: tuple[PartitionSlice, ...], + merge_callable: merging.MergeCallable | None, + **kwargs): + """Insert a dataset into the collection without using dask.""" + ds = ds.compute() + for partition in partitions: + _insert(args=partition, + axis=self.axis, + zds=ds, + fs=self.fs, + merge_callable=merge_callable, + partitioning_properties=self.partition_properties, + distributed=False, + **kwargs) # pylint: disable=method-hidden def drop_partitions( @@ -420,6 +456,7 @@ def drop_partitions( *, filters: PartitionFilter = None, timedelta: datetime.timedelta | None = None, + distributed: bool = True, ) -> Iterable[str]: # pylint: disable=method-hidden """Drop the selected partitions. @@ -430,6 +467,7 @@ def drop_partitions( the :meth:`partitions` method. timedelta: Select the partitions created before the specified time delta relative to the current time. + distributed: Whether to use dask or not. Default To True. Returns: A list of the dropped partitions. @@ -440,7 +478,6 @@ def drop_partitions( ... timedelta=datetime.timedelta(days=30)) """ now: datetime.datetime = datetime.datetime.now() - client: dask.distributed.Client = dask_utils.get_client() folders = list(self.partitions(filters=filters, lock=True)) # No partition selected, nothing to do. 
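For context, a minimal usage sketch of the sequential insertion path introduced above. The collection name and the toy xarray dataset are illustrative assumptions (the dataset must match the collection's metadata); only the `distributed` keyword is what this patch adds:

    import numpy
    import xarray
    import zcollection

    # Open an existing collection, as in the add_variable docstring example.
    collection = zcollection.open_collection("my_collection", mode="w")
    # Toy dataset; variable and dimension names must match the collection.
    ds = xarray.Dataset(
        {"var1": ("time", numpy.arange(10.0))},
        coords={"time": numpy.arange("2000-01-01", "2000-01-11",
                                     dtype="datetime64[D]")})
    # With distributed=False no dask.distributed.Client is required:
    # partitions are computed and written one by one in this process.
    collection.insert(ds, distributed=False)
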
@@ -463,9 +500,14 @@ def is_created_before(path: str, now: datetime.datetime, folder, now, timedelta), folders)) - storage.execute_transaction( - client, self.synchronizer, - client.map(self.fs.rm, folders, recursive=True)) + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + storage.execute_transaction( + client, self.synchronizer, + client.map(self.fs.rm, folders, recursive=True)) + else: + for folder in folders: + self.fs.rm(path=folder, recursive=True) def invalidate_cache(path) -> None: """Invalidate the cache.""" @@ -488,6 +530,7 @@ def update( selected_variables: list[str] | None = None, trim: bool = True, variables: Sequence[str] | None = None, + distributed: bool = True, **kwargs, ) -> None: # pylint: disable=method-hidden @@ -517,7 +560,8 @@ def update( the variables are inferred by calling the function on the first partition. In this case, it is important to ensure that the function can be called twice on the same partition without - side-effects. Default is None. + side effects. Default is None. + distributed: Whether to use dask or not. Default To True. **kwargs: The keyword arguments to pass to the function. Raises: @@ -534,6 +578,10 @@ def update( if not callable(func): raise TypeError('func must be a callable') + # Delayed has to be True of dask is disabled + if not distributed: + delayed = False + variables = variables or _infer_callable( self, func, filters, delayed, selected_variables, *args, **kwargs) @@ -575,24 +623,30 @@ def update( trim=trim, **kwargs) - client: dask.distributed.Client = dask_utils.get_client() + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + + batches: Iterator[Sequence[str]] = dask_utils.split_sequence( + selected_partitions, npartitions + or dask_utils.dask_workers(client, cores_only=True)) + storage.execute_transaction( + client, self.synchronizer, + client.map(local_func, tuple(batches), key=func.__name__)) + else: + local_func(selected_partitions) - batches: Iterator[Sequence[str]] = dask_utils.split_sequence( - selected_partitions, npartitions - or dask_utils.dask_workers(client, cores_only=True)) - storage.execute_transaction( - client, self.synchronizer, - client.map(local_func, tuple(batches), key=func.__name__)) tuple(map(self.fs.invalidate_cache, selected_partitions)) def drop_variable( self, variable: str, + distributed: bool = True, ) -> None: """Delete the variable from the collection. Args: variable: The variable to delete. + distributed: Whether to use dask or not. Default To True. 
Raises: ValueError: If the variable doesn't exist in the collection or is @@ -618,23 +672,34 @@ def drop_variable( raise ValueError( f'The variable {variable!r} is part of the immutable ' 'dataset.') - client: dask.distributed.Client = dask_utils.get_client() - bag: dask.bag.core.Bag = self._bag_from_partitions(lock=True) - awaitables: list[ - dask.distributed.Future] = dask.distributed.futures_of( - bag.map(storage.del_zarr_array, variable, self.fs).persist()) - storage.execute_transaction(client, self.synchronizer, awaitables) + + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + bag: dask.bag.core.Bag = self._bag_from_partitions(lock=True) + awaitables: list[ + dask.distributed.Future] = dask.distributed.futures_of( + bag.map(storage.del_zarr_array, variable, + self.fs).persist()) + storage.execute_transaction(client, self.synchronizer, awaitables) + else: + for partition in self.partitions(lock=True): + storage.del_zarr_array(dirname=partition, + name=variable, + fs=self.fs) + del self.metadata.variables[variable] self._write_config() def add_variable( self, variable: meta.Variable | dataset.Variable, + distributed: bool = True, ) -> None: """Add a variable to the collection. Args: variable: The variable to add. + distributed: Whether to use dask or not. Default To True. Raises: ValueError: if the variable is already part of the collection, it @@ -643,6 +708,7 @@ def add_variable( Example: >>> import zcollection + >>> import numpy >>> collection = zcollection.open_collection( ... "my_collection", mode="w") >>> new_variable = meta.Variable( @@ -667,8 +733,6 @@ def add_variable( # from the collection metadata. variable = variable.set_for_insertion() - client: dask.distributed.Client = dask_utils.get_client() - template: meta.Variable = self.metadata.search_same_dimensions_as( variable) chunks: dict[str, int] = { @@ -676,17 +740,26 @@ def add_variable( for dim in self.metadata.chunks } try: - bag: dask.bag.core.Bag = self._bag_from_partitions(lock=True) - futures: list[ - dask.distributed.Future] = dask.distributed.futures_of( - bag.map(storage.add_zarr_array, - variable, - template.name, - self.fs, - chunks=chunks).persist()) - storage.execute_transaction(client, self.synchronizer, futures) + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + bag: dask.bag.core.Bag = self._bag_from_partitions(lock=True) + futures: list[ + dask.distributed.Future] = dask.distributed.futures_of( + bag.map(storage.add_zarr_array, + variable, + template.name, + self.fs, + chunks=chunks).persist()) + storage.execute_transaction(client, self.synchronizer, futures) + else: + for partition in self.partitions(lock=True): + storage.add_zarr_array(dirname=partition, + variable=variable, + template=template.name, + fs=self.fs, + chunks=chunks) except Exception: - self.drop_variable(variable.name) + self.drop_variable(variable.name, distributed=distributed) raise def copy( @@ -698,6 +771,7 @@ def copy( mode: Literal['r', 'w'] = 'w', npartitions: int | None = None, synchronizer: sync.Sync | None = None, + distributed: bool = True, ) -> Collection: """Copy the collection to a new location. @@ -711,6 +785,7 @@ def copy( is number of cores. synchronizer: The synchronizer used to synchronize the collection copied. Default is None. + distributed: Whether to use dask or not. Default To True. Returns: The new collection. 
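The same pattern applies to the variable-management hunks above; this sketch reuses the `collection` object from the previous example, and the variable definition mirrors the test suite (the dimension names must exist in the collection's dataset):

    import numpy
    from zcollection import meta

    new_variable = meta.Variable(name="var3",
                                 dtype=numpy.dtype("int16"),
                                 dimensions=("time", ),
                                 fill_value=32267)
    # Both operations now run sequentially when distributed=False.
    collection.add_variable(new_variable, distributed=False)
    collection.drop_variable("var3", distributed=False)
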
@@ -724,28 +799,44 @@ def copy( _LOGGER.info('Copying of the collection to %r', target) if filesystem is None: filesystem = fs_utils.get_fs(target) - client: dask.distributed.Client = dask_utils.get_client() - npartitions = npartitions or dask_utils.dask_workers(client, - cores_only=True) - - # Sequence of (source, target) to copy split in npartitions - args = tuple( - dask_utils.split_sequence( - [(item, - fs_utils.join_path( - target, - os.path.relpath(item, self.partition_properties.dir))) - for item in self.partitions(filters=filters)], npartitions)) - # Copy the selected partitions - partial = functools.partial(fs_utils.copy_tree, - fs_source=self.fs, - fs_target=filesystem) - - def worker_task(args: Sequence[tuple[str, str]]) -> None: - """Function call on each worker to copy the partitions.""" - tuple(map(lambda arg: partial(*arg), args)) - - client.gather(client.map(worker_task, args)) + + partitions = self.partitions(filters=filters) + + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + npartitions = npartitions or dask_utils.dask_workers( + client, cores_only=True) + + # Sequence of (source, target) to copy split in npartitions + args = tuple( + dask_utils.split_sequence([ + (item, + fs_utils.join_path( + target, + os.path.relpath(item, self.partition_properties.dir))) + for item in partitions + ], npartitions)) + # Copy the selected partitions + partial = functools.partial(fs_utils.copy_tree, + fs_source=self.fs, + fs_target=filesystem) + + def worker_task(args: Sequence[tuple[str, str]]) -> None: + """Function call on each worker to copy the partitions.""" + tuple(map(lambda arg: partial(*arg), args)) + + client.gather(client.map(worker_task, args)) + else: + for source_path in partitions: + target_path = fs_utils.join_path( + target, + os.path.relpath(source_path, + self.partition_properties.dir)) + fs_utils.copy_tree(source=source_path, + target=target_path, + fs_source=self.fs, + fs_target=filesystem) + # Then the remaining files in the root directory (config, metadata, # etc.) fs_utils.copy_files([ @@ -753,6 +844,7 @@ def worker_task(args: Sequence[tuple[str, str]]) -> None: for item in self.fs.listdir(self.partition_properties.dir, detail=True) if item['type'] == 'file' ], target, self.fs, filesystem) + return Collection.from_config(target, mode=mode, filesystem=filesystem, @@ -760,6 +852,7 @@ def worker_task(args: Sequence[tuple[str, str]]) -> None: def validate_partitions(self, filters: PartitionFilter | None = None, + distributed: bool = True, fix: bool = False) -> list[str]: """Validates partitions in the collection by checking if they exist and are readable. If `fix` is True, invalid partitions will be removed from @@ -770,6 +863,7 @@ def validate_partitions(self, validate. By default, all partitions are validated. fix: Whether to fix invalid partitions by removing them from the collection. + distributed: Whether to use dask or not. Default To True. Returns: A list of invalid partitions. 
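Copying and validating a collection can likewise run in the current process; the target path is a placeholder and `collection` is assumed from the sketches above:

    import fsspec

    copied = collection.copy("my_collection_copy",
                             filesystem=fsspec.filesystem("file"),
                             distributed=False)
    # Returns the list of unreadable partitions and, with fix=True,
    # removes them from the collection.
    invalid = collection.validate_partitions(fix=True, distributed=False)
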
@@ -777,19 +871,33 @@ def validate_partitions(self, partitions = tuple(self.partitions(filters=filters)) if not partitions: return [] - client: dask.distributed.Client = dask_utils.get_client() - futures: list[dask.distributed.Future] = client.map( - _check_partition, - partitions, - fs=self.fs, - partitioning_strategy=self.partitioning) + invalid_partitions: list[str] = [] - for item in dask.distributed.as_completed(futures): - partition, valid = item.result() # type: ignore - if not valid: - warnings.warn(f'Invalid partition: {partition}', - category=RuntimeWarning) - invalid_partitions.append(partition) + + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + futures: list[dask.distributed.Future] = client.map( + _check_partition, + partitions, + fs=self.fs, + partitioning_strategy=self.partitioning) + + for item in dask.distributed.as_completed(futures): + partition, valid = item.result() # type: ignore + if not valid: + warnings.warn(f'Invalid partition: {partition}', + category=RuntimeWarning) + invalid_partitions.append(partition) + else: + for partition in partitions: + partition, valid = _check_partition( + partition, + fs=self.fs, + partitioning_strategy=self.partitioning) + if not valid: + warnings.warn(f'Invalid partition: {partition}', + category=RuntimeWarning) + invalid_partitions.append(partition) if fix and invalid_partitions: for item in invalid_partitions: diff --git a/zcollection/collection/abc.py b/zcollection/collection/abc.py index afeab60..08dcce1 100644 --- a/zcollection/collection/abc.py +++ b/zcollection/collection/abc.py @@ -553,6 +553,7 @@ def load( filters: PartitionFilter = None, indexer: Indexer | None = None, selected_variables: Iterable[str] | None = None, + distributed: bool = True, ) -> dataset.Dataset | None: """Load the selected partitions. @@ -564,6 +565,7 @@ def load( indexer: The indexer to apply. selected_variables: A list of variables to retain from the collection. If None, all variables are kept. + distributed: Whether to use dask or not. Default To True. Returns: The dataset containing the selected partitions, or None if no @@ -582,22 +584,42 @@ def load( ... filters=lambda keys: keys["year"] == 2019 and ... keys["month"] == 3 and keys["day"] % 2 == 0) """ - client: dask.distributed.Client = dask_utils.get_client() + # Delayed has to be True of dask is disabled + if not distributed: + delayed = False + arrays: list[dataset.Dataset] + client: dask.distributed.Client + if indexer is None: + # No indexer, so the dataset is loaded directly for each + # selected partition. selected_partitions = tuple(self.partitions(filters=filters)) if len(selected_partitions) == 0: return None - # No indexer, so the dataset is loaded directly for each - # selected partition. 
- bag: dask.bag.core.Bag = dask.bag.core.from_sequence( - self.partitions(filters=filters), - npartitions=dask_utils.dask_workers(client, cores_only=True)) - arrays = bag.map(storage.open_zarr_group, - delayed=delayed, - fs=self.fs, - selected_variables=selected_variables).compute() + partitions = self.partitions(filters=filters) + + if distributed: + client = dask_utils.get_client() + bag: dask.bag.core.Bag = dask.bag.core.from_sequence( + partitions, + npartitions=dask_utils.dask_workers(client, + cores_only=True)) + arrays = bag.map( + storage.open_zarr_group, + delayed=delayed, + fs=self.fs, + selected_variables=selected_variables).compute() + else: + arrays = [ + storage.open_zarr_group( + dirname=partition, + delayed=delayed, + fs=self.fs, + selected_variables=selected_variables) + for partition in partitions + ] else: # We're going to reuse the indexer variable, so ensure it is # an iterable not a generator. @@ -617,21 +639,36 @@ def load( if len(args) == 0: return None - bag = dask.bag.core.from_sequence( - args, - npartitions=dask_utils.dask_workers(client, cores_only=True)) - # Finally, load the selected partitions and apply the indexer. - arrays = list( - itertools.chain.from_iterable( - bag.map( - _load_and_apply_indexer, - delayed=delayed, - fs=self.fs, - partition_handler=self.partitioning, - partition_properties=self.partition_properties, - selected_variables=selected_variables, - ).compute())) + if distributed: + client = dask_utils.get_client() + bag = dask.bag.core.from_sequence( + args, + npartitions=dask_utils.dask_workers(client, + cores_only=True)) + + arrays = list( + itertools.chain.from_iterable( + bag.map( + _load_and_apply_indexer, + delayed=delayed, + fs=self.fs, + partition_handler=self.partitioning, + partition_properties=self.partition_properties, + selected_variables=selected_variables, + ).compute())) + else: + arrays = list( + itertools.chain.from_iterable([ + _load_and_apply_indexer( + args=a, + delayed=delayed, + fs=self.fs, + partition_handler=self.partitioning, + partition_properties=self.partition_properties, + selected_variables=selected_variables) + for a in args + ])) array: dataset.Dataset = arrays.pop(0) if arrays: diff --git a/zcollection/collection/detail.py b/zcollection/collection/detail.py index 7ce0daf..c0df7d4 100644 --- a/zcollection/collection/detail.py +++ b/zcollection/collection/detail.py @@ -394,6 +394,7 @@ def _insert( fs: fsspec.AbstractFileSystem, merge_callable: merging.MergeCallable | None, partitioning_properties: PartitioningProperties, + distributed: bool = True, **kwargs, ) -> None: """Insert or update a partition in the collection. @@ -405,6 +406,7 @@ def _insert( fs: The file system that the partition is stored on. merge_callable: The merge callable. partitioning_properties: The partitioning properties. + distributed: Whether to use dask or not. Default To True. **kwargs: Additional keyword arguments to pass to the merge callable. """ partition: tuple[str, ...] @@ -423,7 +425,8 @@ def _insert( axis, fs, partitioning_properties.dim, - delayed=zds.delayed, + delayed=zds.delayed if distributed else False, + distributed=distributed, merge_callable=merge_callable, **kwargs) return @@ -434,7 +437,11 @@ def _insert( zarr.storage.init_group(store=fs.get_mapper(dirname)) # The synchronization is done by the caller. 
- write_zarr_group(zds.isel(indexer), dirname, fs, sync.NoSync()) + write_zarr_group(zds.isel(indexer), + dirname, + fs, + sync.NoSync(), + distributed=distributed) except: # noqa: E722 # If the construction of the new dataset fails, the created # partition is deleted, to guarantee the integrity of the @@ -459,7 +466,7 @@ def _load_and_apply_indexer( fs: The file system that the partition is stored on. partition_handler: The partitioning handler. partition_properties: The partitioning properties. - selected_variable: The selected variables to load. + selected_variables: The selected variables to load. Returns: The list of loaded datasets. diff --git a/zcollection/collection/tests/test_collection.py b/zcollection/collection/tests/test_collection.py index 1b60111..215d965 100644 --- a/zcollection/collection/tests/test_collection.py +++ b/zcollection/collection/tests/test_collection.py @@ -99,9 +99,11 @@ def test_collection_creation( # pylint: disable=too-many-statements @pytest.mark.parametrize('fs', ['local_fs', 's3_fs']) @pytest.mark.parametrize('arrays_type', ['dask_arrays', 'numpy_arrays']) +@pytest.mark.parametrize('distributed', [False, True]) def test_insert( dask_client, # pylint: disable=redefined-outer-name,unused-argument arrays_type, + distributed, fs, request, tmpdir, @@ -122,21 +124,22 @@ def test_insert( numpy.random.shuffle(indices) for idx in indices: zcollection.insert(datasets[idx], - merge_callable=merging.merge_time_series) + merge_callable=merging.merge_time_series, + distributed=distributed) - data = zcollection.load(delayed=delayed) + data = zcollection.load(delayed=delayed, distributed=distributed) assert data is not None values = data.variables['time'].values assert numpy.all(values == numpy.arange(START_DATE, END_DATE, DELTA)) # Adding same datasets once more (should not change anything) for idx in indices[:5]: - zcollection.insert(datasets[idx]) + zcollection.insert(datasets[idx], distributed=distributed) assert list(zcollection.partitions()) == sorted( list(zcollection.partitions())) - data = zcollection.load(delayed=delayed) + data = zcollection.load(delayed=delayed, distributed=distributed) assert data is not None values = data.variables['time'].values assert numpy.all(values == numpy.arange(START_DATE, END_DATE, DELTA)) @@ -149,22 +152,29 @@ def test_insert( numpy.all(values == numpy.vstack((numpy.arange(values.shape[0]), ) * values.shape[1]).T) - data = zcollection.load(delayed=delayed, filters='year == 2020') + data = zcollection.load(delayed=delayed, + filters='year == 2020', + distributed=distributed) assert data is None - data = zcollection.load(delayed=delayed, filters='year == 2000') + data = zcollection.load(delayed=delayed, + filters='year == 2000', + distributed=distributed) assert data is not None assert data.variables['time'].shape[0] == 61 data = zcollection.load(delayed=delayed, - filters='year == 2000 and month == 4') + filters='year == 2000 and month == 4', + distributed=distributed) assert data is not None dates = data.variables['time'].values assert numpy.all( dates.astype('datetime64[M]') == numpy.datetime64('2000-04-01')) data = zcollection.load( - delayed=delayed, filters='year == 2000 and month == 4 and day == 15') + delayed=delayed, + filters='year == 2000 and month == 4 and day == 15', + distributed=distributed) assert data is not None dates = data.variables['time'].values assert numpy.all( @@ -172,12 +182,14 @@ def test_insert( data = zcollection.load( delayed=delayed, - filters='year == 2000 and month == 4 and day in range(5, 25)') + 
filters='year == 2000 and month == 4 and day in range(5, 25)', + distributed=distributed) assert data is not None data = zcollection.load(delayed=delayed, filters=lambda keys: datetime.date(2000, 4, 5) <= datetime.date(keys['year'], keys['month'], keys[ - 'day']) <= datetime.date(2000, 4, 24)) + 'day']) <= datetime.date(2000, 4, 24), + distributed=distributed) assert data is not None dates = data.variables['time'].values.astype('datetime64[D]') assert dates.min() == numpy.datetime64('2000-04-06') @@ -190,16 +202,22 @@ def test_insert( zcollection = convenience.open_collection(str(tested_fs.collection), mode='r', filesystem=tested_fs.fs) - zds = zcollection.load(delayed=delayed, selected_variables=['var1']) + zds = zcollection.load(delayed=delayed, + selected_variables=['var1'], + distributed=distributed) assert zds is not None assert 'var1' in zds.variables assert 'var2' not in zds.variables - zds = zcollection.load(delayed=delayed, selected_variables=[]) + zds = zcollection.load(delayed=delayed, + selected_variables=[], + distributed=distributed) assert zds is not None assert len(zds.variables) == 0 - zds = zcollection.load(delayed=delayed, selected_variables=['varX']) + zds = zcollection.load(delayed=delayed, + selected_variables=['varX'], + distributed=distributed) assert zds is not None assert len(zds.variables) == 0 @@ -208,10 +226,12 @@ def test_insert( @pytest.mark.parametrize('fs,create_test_data', FILE_SYSTEM_DATASET) @pytest.mark.parametrize('arrays_type', ['dask_arrays', 'numpy_arrays']) +@pytest.mark.parametrize('distributed', [False, True]) def test_update( dask_client, # pylint: disable=redefined-outer-name,unused-argument fs, arrays_type, + distributed, create_test_data, request, ) -> None: @@ -224,15 +244,16 @@ def test_update( partitioning.Date(('time', ), 'D'), str(tested_fs.collection), filesystem=tested_fs.fs) - zcollection.insert(zds) + zcollection.insert(zds, distributed=distributed) def update(zds: dataset.Dataset, shift: int = 3): """Update function used for this test.""" return {'var2': zds.variables['var1'].values * -1 + shift} - zcollection.update(update, delayed=delayed) # type: ignore + zcollection.update(update, delayed=delayed, + distributed=distributed) # type: ignore - data = zcollection.load() + data = zcollection.load(distributed=distributed) assert data is not None assert numpy.allclose(data.variables['var2'].values, data.variables['var1'].values * -1 + 3, @@ -241,11 +262,12 @@ def update(zds: dataset.Dataset, shift: int = 3): zcollection.update( update, # type: ignore delayed=delayed, + distributed=distributed, variables=('var2', ), depth=1, shift=5) - data = zcollection.load(delayed=delayed) + data = zcollection.load(delayed=delayed, distributed=distributed) assert data is not None assert numpy.allclose(data.variables['var2'].values, data.variables['var1'].values * -1 + 5, @@ -256,11 +278,12 @@ def update(zds: dataset.Dataset, shift: int = 3): zcollection.update( update, # type: ignore delayed=delayed, + distributed=distributed, selected_variables=['var1'], depth=1, shift=5) - data = zcollection.load(delayed=delayed) + data = zcollection.load(delayed=delayed, distributed=distributed) assert data is not None assert numpy.allclose(data.variables['var2'].values, data.variables['var1'].values * -1 + 5, @@ -279,10 +302,11 @@ def update_with_info(zds: dataset.Dataset, partition_info=None, shift=3): zcollection.update( update_with_info, # type: ignore delayed=delayed, + distributed=distributed, depth=1, shift=10) - data = zcollection.load(delayed=delayed) + 
data = zcollection.load(delayed=delayed, distributed=distributed) assert data is not None assert numpy.allclose(data.variables['var2'].values, data.variables['var1'].values * -1 + 10, @@ -298,10 +322,11 @@ def update_and_trim(zds: dataset.Dataset, partition_info=None): zcollection.update( update_and_trim, # type: ignore delayed=delayed, + distributed=distributed, trim=False, depth=1) - data = zcollection.load(delayed=delayed) + data = zcollection.load(delayed=delayed, distributed=distributed) assert data is not None assert numpy.allclose(data.variables['var2'].values, data.variables['var1'].values * -1, @@ -313,13 +338,16 @@ def invalid_var_name(zds: dataset.Dataset): return {'var99': zds.variables['var1'].values * -1 + 3} with pytest.raises(ValueError): - zcollection.update(invalid_var_name) # type: ignore + zcollection.update(invalid_var_name, + distributed=distributed) # type: ignore @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_drop_partitions( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ) -> None: """Test the dropping of a dataset.""" @@ -331,18 +359,21 @@ def test_drop_partitions( item.split(zcollection.fs.sep)[-2] for item in all_partitions ] - zcollection.drop_partitions(filters='year == 2000 and month==1') + zcollection.drop_partitions(filters='year == 2000 and month==1', + distributed=distributed) partitions = list(zcollection.partitions()) assert 'month=01' not in [ item.split(zcollection.fs.sep)[-2] for item in partitions ] npartitions = len(partitions) - zcollection.drop_partitions(timedelta=datetime.timedelta(days=1)) + zcollection.drop_partitions(timedelta=datetime.timedelta(days=1), + distributed=distributed) partitions = list(zcollection.partitions()) assert len(partitions) == npartitions - zcollection.drop_partitions(timedelta=datetime.timedelta(0)) + zcollection.drop_partitions(timedelta=datetime.timedelta(0), + distributed=distributed) partitions = list(zcollection.partitions()) assert len(partitions) == 0 @@ -350,13 +381,15 @@ def test_drop_partitions( mode='r', filesystem=tested_fs.fs) with pytest.raises(io.UnsupportedOperation): - zcollection.drop_partitions() + zcollection.drop_partitions(distributed=distributed) @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_drop_variable( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ) -> None: """Test the dropping of a variable.""" @@ -364,13 +397,13 @@ def test_drop_variable( zcollection = create_test_collection(tested_fs, delayed=False) with pytest.raises(ValueError): - zcollection.drop_variable('time') - zcollection.drop_variable('var1') + zcollection.drop_variable('time', distributed=distributed) + zcollection.drop_variable('var1', distributed=distributed) with pytest.raises(ValueError): - zcollection.drop_variable('var1') + zcollection.drop_variable('var1', distributed=distributed) - zds = zcollection.load(delayed=False) + zds = zcollection.load(delayed=False, distributed=distributed) assert zds is not None assert 'var1' not in zds.variables @@ -378,13 +411,15 @@ def test_drop_variable( mode='r', filesystem=tested_fs.fs) with pytest.raises(io.UnsupportedOperation): - zcollection.drop_partitions() + zcollection.drop_partitions(distributed=distributed) @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def 
test_add_variable( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ) -> None: """Test the adding of a variable.""" @@ -396,21 +431,21 @@ def test_add_variable( dtype=numpy.dtype('float64'), dimensions=('time', )) with pytest.raises(ValueError): - zcollection.add_variable(new) + zcollection.add_variable(new, distributed=distributed) # Variable doesn't use the partitioning dimension. new = meta.Variable(name='x', dtype=numpy.dtype('float64'), dimensions=('x', )) with pytest.raises(ValueError): - zcollection.add_variable(new) + zcollection.add_variable(new, distributed=distributed) # Variable doesn't use the dataset dimension. new = meta.Variable(name='x', dtype=numpy.dtype('float64'), dimensions=('time', 'x')) with pytest.raises(ValueError): - zcollection.add_variable(new) + zcollection.add_variable(new, distributed=distributed) new = meta.Variable( name='var3', @@ -419,7 +454,7 @@ def test_add_variable( fill_value=32267, attrs=(dataset.Attribute(name='attr', value=4), ), ) - zcollection.add_variable(new) + zcollection.add_variable(new, distributed=distributed) assert new.name in zcollection.metadata.variables @@ -430,7 +465,7 @@ def test_add_variable( assert new.name in zcollection.metadata.variables - zds = zcollection.load(delayed=False) + zds = zcollection.load(delayed=False, distributed=distributed) assert zds is not None values = zds.variables['var3'].values assert isinstance(values, numpy.ma.MaskedArray) @@ -439,10 +474,12 @@ def test_add_variable( @pytest.mark.parametrize('fs,create_test_data', FILE_SYSTEM_DATASET) @pytest.mark.parametrize('arrays_type', ['dask_arrays', 'numpy_arrays']) +@pytest.mark.parametrize('distributed', [False, True]) def test_add_update( dask_client, # pylint: disable=redefined-outer-name,unused-argument fs, arrays_type, + distributed, create_test_data, request, ) -> None: @@ -455,7 +492,7 @@ def test_add_update( partitioning.Date(('time', ), 'D'), str(tested_fs.collection), filesystem=tested_fs.fs) - zcollection.insert(zds) + zcollection.insert(zds, distributed=distributed) new1 = meta.Variable(name='var3', dtype=numpy.dtype('float64'), @@ -470,10 +507,10 @@ def test_add_update( fill_value=32267, attrs=(dataset.Attribute(name='attr', value=4), ), ) - zcollection.add_variable(new1) - zcollection.add_variable(new2) + zcollection.add_variable(new1, distributed=distributed) + zcollection.add_variable(new2, distributed=distributed) - data = zcollection.load(delayed=delayed) + data = zcollection.load(delayed=delayed, distributed=distributed) assert data is not None def update_1(zds, varname): @@ -487,9 +524,9 @@ def update_2(zds, varname): zcollection.update(update_1, new1.name, delayed=delayed) # type: ignore zcollection.update(update_2, new2.name, delayed=delayed) # type: ignore - if delayed is False: + if not (delayed and distributed): # If the dataset is not delayed, we need to reload it. 
- data = zcollection.load(delayed=False) + data = zcollection.load(delayed=False, distributed=distributed) assert data is not None assert numpy.allclose(data.variables[new1.name].values, @@ -532,9 +569,11 @@ def test_fillvalue( @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_degraded_tests( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ) -> None: """Test the degraded functionality.""" @@ -546,13 +585,15 @@ def test_degraded_tests( fake_ds.variables['var3'].name = 'var3' with pytest.raises(ValueError): - zcollection.insert(fake_ds) + zcollection.insert(fake_ds, distributed=distributed) @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_insert_with_missing_variable( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ) -> None: """Test of the insertion of a dataset in which a variable is missing. @@ -568,11 +609,13 @@ def test_insert_with_missing_variable( partition_handler=partitioning.Date(('time', ), 'M'), partition_base_dir=str(tested_fs.collection), filesystem=tested_fs.fs) - zcollection.insert(zds, merge_callable=merging.merge_time_series) + zcollection.insert(zds, + merge_callable=merging.merge_time_series, + distributed=distributed) zds = next(create_test_dataset_with_fillvalue()) zds.drops_vars('var1') - zcollection.insert(zds) + zcollection.insert(zds, distributed=distributed) data = zcollection.load() assert data is not None @@ -587,10 +630,12 @@ def test_insert_with_missing_variable( @pytest.mark.parametrize('fs', ['local_fs', 's3_fs']) @pytest.mark.parametrize('arrays_type', ['dask_arrays', 'numpy_arrays']) +@pytest.mark.parametrize('distributed', [False, True]) def test_insert_failed( dask_client, # pylint: disable=redefined-outer-name,unused-argument fs, arrays_type, + distributed, request, ) -> None: """Test the insertion of a dataset in which the insertion failed.""" @@ -613,18 +658,20 @@ def test_insert_failed( zcollection.fs.touch(one_directory) with pytest.raises((OSError, ValueError)): - zcollection.insert(zds) + zcollection.insert(zds, distributed=distributed) # Because the insert failed, the partition that was supposed to be created # was deleted. 
assert not zcollection.fs.exists(one_directory) - zcollection.insert(zds) + zcollection.insert(zds, distributed=distributed) @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_insert_validation( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ) -> None: """Test the insertion of a dataset with metadata validation.""" @@ -637,21 +684,22 @@ def test_insert_validation( partition_handler=partitioning.Date(('time', ), 'M'), partition_base_dir=str(tested_fs.collection), filesystem=tested_fs.fs) - zcollection.insert(zds) + zcollection.insert(zds, distributed=distributed) zds = next(create_test_dataset_with_fillvalue()) # Inserting a dataset containing valid attributes zcollection.insert(zds, merge_callable=merging.merge_time_series, - tolerance=numpy.timedelta64(1, 'm')) + tolerance=numpy.timedelta64(1, 'm'), + distributed=distributed) # Inserting a dataset containing an invalid attributes zds = next(create_test_dataset_with_fillvalue()) zds.attrs = (meta.Attribute('invalid', 1), ) with pytest.raises(ValueError): - zcollection.insert(zds, validate=True) + zcollection.insert(zds, validate=True, distributed=distributed) # Inserting a dataset containing variables with invalid attributes zds = next(create_test_dataset_with_fillvalue()) @@ -660,7 +708,7 @@ def test_insert_validation( var.attrs = (meta.Attribute('invalid', 1), ) with pytest.raises(ValueError): - zcollection.insert(zds, validate=True) + zcollection.insert(zds, validate=True, distributed=distributed) @pytest.mark.parametrize('fs', ['local_fs', 's3_fs']) @@ -689,10 +737,12 @@ def test_map_partition( @pytest.mark.parametrize('fs', ['local_fs', 's3_fs']) @pytest.mark.parametrize('arrays_type', ['dask_arrays', 'numpy_arrays']) +@pytest.mark.parametrize('distributed', [False, True]) def test_indexer( dask_client, # pylint: disable=redefined-outer-name,unused-argument fs, arrays_type, + distributed, request, ) -> None: """Test the update of a dataset.""" @@ -703,9 +753,11 @@ def test_indexer( indexers = zcollection.map( lambda x: slice(0, x.dimensions['num_lines']) # type: ignore ).compute() - zds1 = zcollection.load(indexer=indexers, delayed=delayed) + zds1 = zcollection.load(indexer=indexers, + delayed=delayed, + distributed=distributed) assert zds1 is not None - zds2 = zcollection.load(delayed=delayed) + zds2 = zcollection.load(delayed=delayed, distributed=distributed) assert zds2 is not None assert numpy.allclose(zds1.variables['var1'].values, @@ -768,9 +820,11 @@ def func(zds: dataset.Dataset, partition_info: tuple[str, slice]): @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_insert_immutable( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ) -> None: """Test the insertion of a dataset with variables that are immutable @@ -815,10 +869,11 @@ def test_insert_immutable( filesystem=tested_fs.fs) assert zcollection.immutable assert not tested_fs.fs.exists(zcollection._immutable) - zcollection.insert(zds_reference) + + zcollection.insert(zds_reference, distributed=distributed) assert tested_fs.fs.exists(zcollection._immutable) - zds = zcollection.load(delayed=False) + zds = zcollection.load(delayed=False, distributed=distributed) assert zds is not None assert numpy.all( @@ -834,8 +889,12 @@ def update(zds: dataset.Dataset, varname: str) -> dict[str, numpy.ndarray]: """Update function used for this test.""" 
return {varname: zds.variables['grid'].values * -1} - zcollection.update(update, delayed=False, varname='grid') # type: ignore - zds = zcollection.load(delayed=False) + zcollection.update( + update, # type: ignore + delayed=False, + distributed=distributed, + varname='grid') + zds = zcollection.load(delayed=False, distributed=distributed) assert zds is not None assert numpy.all(zds.variables['grid'].values == @@ -853,9 +912,12 @@ def update(zds: dataset.Dataset, varname: str) -> dict[str, numpy.ndarray]: dimensions=('time', 'lon', 'lat'), attrs=(meta.Attribute('units', 'm'), ), ) - zcollection.add_variable(new_variable) - zcollection.update(update, varname='new_var') # type: ignore - zds = zcollection.load() + zcollection.add_variable(new_variable, distributed=distributed) + zcollection.update( + update, # type: ignore + distributed=distributed, + varname='new_var') + zds = zcollection.load(distributed=distributed) assert zds is not None assert numpy.all(zds.variables['new_var'].values == zds_reference.variables['grid'].values) @@ -867,13 +929,15 @@ def update(zds: dataset.Dataset, varname: str) -> dict[str, numpy.ndarray]: attrs=(meta.Attribute('units', 'm'), ), ) with pytest.raises(ValueError): - zcollection.add_variable(new_variable) + zcollection.add_variable(new_variable, distributed=distributed) @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_copy_collection( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, tmpdir) -> None: """Test the dropping of a dataset.""" @@ -881,10 +945,12 @@ def test_copy_collection( zcollection = create_test_collection(tested_fs) target = str(tmpdir / 'copy') - zcopy = zcollection.copy(target, filesystem=fsspec.filesystem('file')) + zcopy = zcollection.copy(target, + filesystem=fsspec.filesystem('file'), + distributed=distributed) - ds_before_copy = zcollection.load() - ds_after_copy = zcopy.load() + ds_before_copy = zcollection.load(distributed=distributed) + ds_after_copy = zcopy.load(distributed=distributed) assert ds_before_copy is not None assert ds_after_copy is not None @@ -1010,8 +1076,10 @@ def new_shape( _ = zds['time'].values +@pytest.mark.parametrize('distributed', [False, True]) def test_invalid_partitions( dask_client, # pylint: disable=redefined-outer-name,unused-argument + distributed, tmpdir) -> None: fs = fsspec.filesystem('file') datasets = list(create_test_dataset()) @@ -1031,22 +1099,25 @@ def test_invalid_partitions( with fs.open(var2, 'wb') as file: file.write(b'invalid') with pytest.raises(ValueError): - _ = zcollection.load(delayed=False) + _ = zcollection.load(delayed=False, distributed=distributed) with pytest.warns(RuntimeWarning, match='Invalid partition'): - invalid_partitions = zcollection.validate_partitions() + invalid_partitions = zcollection.validate_partitions( + distributed=distributed) assert len(invalid_partitions) == 2 assert sorted(invalid_partitions) == sorted(partitions[ix] for ix in choices) with pytest.warns(RuntimeWarning, match='Invalid partition'): - zcollection.validate_partitions(fix=True) + zcollection.validate_partitions(fix=True, distributed=distributed) assert zcollection.load() is not None # pylint: disable=too-many-statements @pytest.mark.parametrize('fs', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_insert_with_chunks( dask_client, # pylint: disable=redefined-outer-name,unused-argument fs, + distributed, request, tmpdir, ) -> None: @@ -1067,7 
+1138,11 @@ def test_insert_with_chunks( str(tmpdir / 'lock.lck'))) # First insertion - zcollection.insert(datasets[0], merge_callable=merging.merge_time_series) + zcollection.insert(datasets[0], + merge_callable=merging.merge_time_series, + distributed=distributed) + + # Not setting distributed to False when loading otherwise we won't have any chunk data = zcollection.load() assert data is not None @@ -1075,7 +1150,9 @@ def test_insert_with_chunks( assert data.variables['var2'].data.chunksize[1] == chunk_size # Insertion with merge - zcollection.insert(datasets[1], merge_callable=merging.merge_time_series) + zcollection.insert(datasets[1], + merge_callable=merging.merge_time_series, + distributed=distributed) data = zcollection.load() assert data is not None diff --git a/zcollection/convenience/view.py b/zcollection/convenience/view.py index 87dde30..3c54e35 100644 --- a/zcollection/convenience/view.py +++ b/zcollection/convenience/view.py @@ -13,14 +13,13 @@ from .. import collection, fs_utils, sync, view -def create_view( - path: str, - view_ref: view.ViewReference, - *, - filesystem: fsspec.AbstractFileSystem | str | None = None, - filters: collection.PartitionFilter = None, - synchronizer: sync.Sync | None = None, -) -> view.View: +def create_view(path: str, + view_ref: view.ViewReference, + *, + filesystem: fsspec.AbstractFileSystem | str | None = None, + filters: collection.PartitionFilter = None, + synchronizer: sync.Sync | None = None, + distributed: bool = True) -> view.View: """Create a new view. Args: @@ -30,6 +29,7 @@ def create_view( filters: The filters used to select the partitions of the reference view. If not provided, all partitions are selected. synchronizer: The synchronizer used to synchronize the view. + distributed: Whether to use dask or not. Default To True. Example: >>> view_ref = ViewReference( @@ -50,7 +50,8 @@ def create_view( ds=None, filesystem=filesystem, filters=filters, - synchronizer=synchronizer) + synchronizer=synchronizer, + distributed=distributed) def open_view( diff --git a/zcollection/merging/__init__.py b/zcollection/merging/__init__.py index 7a39e3e..f6690d3 100644 --- a/zcollection/merging/__init__.py +++ b/zcollection/merging/__init__.py @@ -98,6 +98,7 @@ def _update_fs( zds: dataset.Dataset, fs: fsspec.AbstractFileSystem, *, + distributed: bool = True, synchronizer: sync.Sync | None = None, ) -> None: """Updates a dataset stored in a partition. @@ -106,6 +107,7 @@ def _update_fs( dirname: The name of the partition. zds: The dataset to update. fs: The file system that the partition is stored on. + distributed: Whether to use dask or not. Default To True. synchronizer: The instance handling access to critical resources. """ # Building a temporary directory to store the new data. The name of the @@ -122,7 +124,11 @@ def _update_fs( # Writing new data. try: # The synchronization is done by the caller. - storage.write_zarr_group(zds, temp, fs, synchronizer or sync.NoSync()) + storage.write_zarr_group(zds=zds, + dirname=temp, + fs=fs, + synchronizer=synchronizer or sync.NoSync(), + distributed=distributed) except Exception: # The "write_zarr_group" method throws the exception if all scheduled # tasks are finished. So here we can delete the temporary directory. @@ -141,6 +147,7 @@ def perform( partitioning_dim: str, *, delayed: bool = True, + distributed: bool = True, merge_callable: MergeCallable | None, synchronizer: sync.Sync | None = None, **kwargs, @@ -155,6 +162,7 @@ def perform( partitioning_dim: The partitioning dimension. 
delayed: If True, the existing dataset is loaded lazily. Defaults to True. + distributed: Whether to use dask or not. Default To True. merge_callable: The merge callable. If None, the inserted dataset overwrites the existing dataset stored in the partition. Defaults to None. @@ -166,10 +174,16 @@ def perform( if merge_callable is None: zds = ds_inserted else: - ds = storage.open_zarr_group(dirname, fs, delayed=delayed) + ds = storage.open_zarr_group(dirname, + fs, + delayed=delayed if distributed else False) # Read dataset does not contain insertion properties. # This properties might be loss in the merge_callable depending on which # dataset is used. ds.copy_properties(ds=ds_inserted) zds = merge_callable(ds, ds_inserted, axis, partitioning_dim, **kwargs) - _update_fs(dirname, zds, fs, synchronizer=synchronizer) + _update_fs(dirname, + zds, + fs, + distributed=distributed, + synchronizer=synchronizer) diff --git a/zcollection/merging/tests/test_merging.py b/zcollection/merging/tests/test_merging.py index a8dc080..6ab4ec6 100644 --- a/zcollection/merging/tests/test_merging.py +++ b/zcollection/merging/tests/test_merging.py @@ -45,28 +45,27 @@ def test_update_fs( """Test the _update_fs function.""" generator = data.create_test_dataset(delayed=False) zds = next(generator) + zds_sc = dask_client.scatter(zds) partition_folder = local_fs.root.joinpath('variable=1') zattrs = str(partition_folder.joinpath('.zattrs')) - future = dask_client.submit(_update_fs, str(partition_folder), - dask_client.scatter(zds), local_fs.fs) + future = dask_client.submit(_update_fs, str(partition_folder), zds_sc, + local_fs.fs) dask_client.gather(future) assert local_fs.exists(zattrs) local_fs.fs.rm(str(partition_folder), recursive=True) assert not local_fs.exists(zattrs) - seen_exception = False - try: + + with pytest.raises(MyError): future = dask_client.submit(_update_fs, str(partition_folder), - dask_client.scatter(zds), + zds_sc, local_fs.fs, synchronizer=ThrowError()) dask_client.gather(future) - except MyError: - seen_exception = True - assert seen_exception + assert not local_fs.exists(zattrs) @@ -83,13 +82,13 @@ def test_perform( zds = next(generator) path = str(local_fs.root.joinpath('variable=1')) + zds_sc = dask_client.scatter(zds) - future = dask_client.submit(_update_fs, path, dask_client.scatter(zds), - local_fs.fs) + future = dask_client.submit(_update_fs, path, zds_sc, local_fs.fs) dask_client.gather(future) future = dask_client.submit(perform, - dask_client.scatter(zds), + zds_sc, path, 'time', local_fs.fs, diff --git a/zcollection/storage.py b/zcollection/storage.py index 9442198..2b7d921 100644 --- a/zcollection/storage.py +++ b/zcollection/storage.py @@ -258,8 +258,8 @@ def write_zarr_group( futures: list[dask.distributed.Future] = client.map( write_zarr_variable, iterables, - block_size_limit=zds.block_size_limit, chunks=zds.chunks, + block_size_limit=zds.block_size_limit, dirname=dirname, fs=fs, ) @@ -407,7 +407,7 @@ def add_zarr_array( template: str, fs: fsspec.AbstractFileSystem, *, - chunks: dict[str, int | str] | None = None, + chunks: dict[str, int] | None = None, ) -> None: """Add a variable to a Zarr dataset. diff --git a/zcollection/view/__init__.py b/zcollection/view/__init__.py index 097b56d..c337e3e 100644 --- a/zcollection/view/__init__.py +++ b/zcollection/view/__init__.py @@ -62,6 +62,7 @@ class View: filters: The filters used to select the partitions of the reference. If not provided, all partitions are selected. synchronizer: The synchronizer used to synchronize the view. 
+ distributed: Whether to use dask or not. Default To True. Note: Normally, you should not call this constructor directly. Instead, use @@ -71,16 +72,15 @@ class View: #: Configuration filename of the view. CONFIG: ClassVar[str] = '.view' - def __init__( - self, - base_dir: str, - view_ref: ViewReference, - *, - ds: meta.Dataset | None, - filesystem: fsspec.AbstractFileSystem | str | None = None, - filters: collection.PartitionFilter = None, - synchronizer: sync.Sync | None = None, - ) -> None: + def __init__(self, + base_dir: str, + view_ref: ViewReference, + *, + ds: meta.Dataset | None, + filesystem: fsspec.AbstractFileSystem | str | None = None, + filters: collection.PartitionFilter = None, + synchronizer: sync.Sync | None = None, + distributed: bool = True) -> None: #: The file system used to access the view (default local file system). self.fs: fsspec.AbstractFileSystem = fs_utils.get_fs(filesystem) #: Path to the directory where the view is stored. @@ -100,11 +100,13 @@ def __init__( _LOGGER.info('Creating view %s', self) self.fs.makedirs(self.base_dir) self._write_config() - self._init_partitions(filters) + self._init_partitions(filters, distributed=distributed) else: _LOGGER.info('Opening view %s', self) - def _init_partitions(self, filters: collection.PartitionFilter) -> None: + def _init_partitions(self, + filters: collection.PartitionFilter, + distributed: bool = True) -> None: """Initialize the partitions of the view.""" _LOGGER.info('Populating view %s', self) args = tuple( @@ -116,16 +118,24 @@ def _init_partitions(self, filters: collection.PartitionFilter) -> None: args = tuple(filter(lambda item: not self.fs.exists(item), args)) _LOGGER.info('%d partitions selected from %s', len(args), self.view_ref) - client: dask.distributed.Client = dask_utils.get_client() - storage.execute_transaction( - client, self.synchronizer, - client.map( - _write_checksum, - tuple(args), - base_dir=self.base_dir, - view_ref=self.view_ref, - fs=self.fs, - )) + + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + storage.execute_transaction( + client, self.synchronizer, + client.map( + _write_checksum, + tuple(args), + base_dir=self.base_dir, + view_ref=self.view_ref, + fs=self.fs, + )) + else: + for arg in args: + _write_checksum(arg, + base_dir=self.base_dir, + view_ref=self.view_ref, + fs=self.fs) def __str__(self) -> str: return (f'{self.__class__.__name__}' @@ -229,14 +239,14 @@ def variables( return dataset.get_dataset_variable_properties(self.metadata, selected_variables) - def add_variable( - self, - variable: meta.Variable | dataset.Variable, - ) -> None: + def add_variable(self, + variable: meta.Variable | dataset.Variable, + distributed: bool = True) -> None: """Add a variable to the view. Args: variable: The variable to add + distributed: Whether to use dask or not. Default To True. Raises: ValueError: If the variable already exists @@ -258,7 +268,7 @@ def add_variable( if (variable.name in self.view_ref.metadata.variables or variable.name in self.metadata.variables): raise ValueError(f'Variable {variable.name} already exists') - client: dask.distributed.Client = dask_utils.get_client() + self.metadata.add_variable(variable) template: meta.Variable = \ self.view_ref.metadata.search_same_dimensions_as(variable) @@ -281,24 +291,41 @@ def add_variable( # from the collection metadata. 
variable = variable.set_for_insertion() - try: - storage.execute_transaction( - client, self.synchronizer, - client.map(_create_zarr_array, - tuple(args), - base_dir=self.base_dir, - fs=self.fs, - template=template.name, - variable=variable)) - except Exception: - storage.execute_transaction( - client, self.synchronizer, - client.map(_drop_zarr_zarr, - tuple(self.partitions()), - fs=self.fs, - variable=variable.name, - ignore_errors=True)) - raise + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + try: + storage.execute_transaction( + client, self.synchronizer, + client.map(_create_zarr_array, + tuple(args), + base_dir=self.base_dir, + fs=self.fs, + template=template.name, + variable=variable)) + except Exception: + storage.execute_transaction( + client, self.synchronizer, + client.map(_drop_zarr_zarr, + tuple(self.partitions()), + fs=self.fs, + variable=variable.name, + ignore_errors=True)) + raise + else: + try: + for arg in args: + _create_zarr_array(arg, + base_dir=self.base_dir, + fs=self.fs, + template=template.name, + variable=variable) + except Exception: + for partition in self.partitions(): + _drop_zarr_zarr(partition, + fs=self.fs, + variable=variable.name, + ignore_errors=True) + raise self._write_config() # pylint: enable=duplicate-code @@ -306,11 +333,13 @@ def add_variable( def drop_variable( self, varname: str, + distributed: bool = True, ) -> None: """Drop a variable from the view. Args: varname: The name of the variable to drop. + distributed: Whether to use dask or not. Default To True. Raise: ValueError: If the variable does not exist or if the variable @@ -322,17 +351,21 @@ def drop_variable( _LOGGER.info('Dropping variable %r', varname) _assert_variable_handled(self.view_ref.metadata, self.metadata, varname) - client: dask.distributed.Client = dask_utils.get_client() variable: meta.Variable = self.metadata.variables.pop(varname) self._write_config() - storage.execute_transaction( - client, self.synchronizer, - client.map(_drop_zarr_zarr, - tuple(self.partitions()), - fs=self.fs, - variable=variable.name)) + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + storage.execute_transaction( + client, self.synchronizer, + client.map(_drop_zarr_zarr, + tuple(self.partitions()), + fs=self.fs, + variable=variable.name)) + else: + for partition in self.partitions(): + _drop_zarr_zarr(partition, fs=self.fs, variable=variable.name) def load( self, @@ -341,6 +374,7 @@ def load( filters: collection.PartitionFilter = None, indexer: collection.abc.Indexer | None = None, selected_variables: Iterable[str] | None = None, + distributed: bool = True, ) -> dataset.Dataset | None: """Load the view. @@ -353,6 +387,7 @@ def load( indexer: The indexer to apply. selected_variables: A list of variables to retain from the view. If None, all variables are loaded. + distributed: Whether to use dask or not. Default To True. Returns: The dataset. 
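On the view side the new keyword works the same way; `view` is assumed to have been created beforehand with `create_view`/`open_view` (see the convenience/view.py changes above), and the variable and filter names are illustrative:

    # Assumes `view` is a zcollection.view.View created beforehand.
    view.drop_variable("var2", distributed=False)
    data = view.load(filters=lambda keys: keys["year"] == 2000,
                     distributed=False)
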
@@ -363,6 +398,12 @@ def load( >>> view.load(filters=lambda x: x["time"] == "2020-01-01") """ _assert_have_variables(self.metadata) + # Delayed has to be True of dask is disabled + if not distributed: + delayed = False + + datasets: list[tuple[dataset.Dataset, str] | None] + if indexer is not None: arguments = tuple( collection.abc.build_indexer_args( @@ -376,26 +417,42 @@ def load( arguments = tuple((self.view_ref.partitioning.parse(item), []) for item in self.partitions(filters=filters)) - client: dask.distributed.Client = dask_utils.get_client() - futures: list[dask.distributed.Future] = client.map( - _load_one_dataset, - arguments, - base_dir=self.base_dir, - delayed=delayed, - fs=self.fs, - selected_variables=self.view_ref.metadata.select_variables( - selected_variables), - view_ref=client.scatter(self.view_ref), - variables=self.metadata.select_variables(selected_variables)) + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + futures: list[dask.distributed.Future] = client.map( + _load_one_dataset, + arguments, + base_dir=self.base_dir, + delayed=delayed, + fs=self.fs, + selected_variables=self.view_ref.metadata.select_variables( + selected_variables), + view_ref=client.scatter(self.view_ref), + variables=self.metadata.select_variables(selected_variables)) + datasets = client.gather(futures) + else: + datasets = [ + _load_one_dataset( + arg, + base_dir=self.base_dir, + delayed=delayed, + fs=self.fs, + selected_variables=self.view_ref.metadata.select_variables( + selected_variables), + view_ref=self.view_ref, + variables=self.metadata.select_variables( + selected_variables)) for arg in arguments + ] # The load function returns the path to the partitions and the loaded # datasets. Only the loaded datasets are retrieved here and filter None # values corresponding to empty partitions. arrays: list[dataset.Dataset] = list( map( - lambda item: item[0], # type: ignore[arg-type] + lambda item: item[0], # type: ignore[index] filter(lambda item: item is not None, - client.gather(futures)))) # type: ignore[arg-type] + datasets))) # type: ignore[arg-type] + if arrays: array: dataset.Dataset = arrays.pop(0) if arrays: @@ -418,6 +475,7 @@ def update( npartitions: int | None = None, selected_variables: Iterable[str] | None = None, trim: bool = True, + distributed: bool = True, **kwargs, ) -> None: """Update a variable stored int the view. @@ -446,6 +504,7 @@ def update( trim: Whether to trim ``depth`` items from each partition after calling ``func``. Set it to ``False`` if your function does this for you. + distributed: Whether to use dask or not. Default To True. args: The positional arguments to pass to the function. kwargs: The keyword arguments to pass to the function. @@ -461,13 +520,18 @@ def update( ... dataset: zcollection.dataset.Dataset, ... ) -> Dict[str, numpy.ndarray]: ... return dict( - ... temperature_kelvin=dataset["temperature"].values + 273, - ... 15) - >>> view.update(update_temperature) + ... temperature_kelvin=dataset["temperature"].values + + ... 
273.15) + >>> view.update(temp_celsius_to_kelvin) """ _assert_have_variables(self.metadata) - client: dask.distributed.Client = dask_utils.get_client() + client: dask.distributed.Client | None + + if distributed: + client = dask_utils.get_client() + else: + client = None datasets_list = tuple( _load_datasets_list(client=client, @@ -525,15 +589,18 @@ def update( **kwargs, ) - batchs: Iterator[Sequence[Any]] = dask_utils.split_sequence( - datasets_list, npartitions - or dask_utils.dask_workers(client, cores_only=True)) - awaitables: list[dask.distributed.Future] = client.map( - wrap_function, - tuple(batchs), - key=func.__name__, - base_dir=self.base_dir) - storage.execute_transaction(client, self.synchronizer, awaitables) + if distributed: + batches: Iterator[Sequence[Any]] = dask_utils.split_sequence( + datasets_list, npartitions + or dask_utils.dask_workers(client, cores_only=True)) + awaitables: list[dask.distributed.Future] = client.map( + wrap_function, + tuple(batches), + key=func.__name__, + base_dir=self.base_dir) + storage.execute_transaction(client, self.synchronizer, awaitables) + else: + wrap_function(datasets_list, self.base_dir) # pylint: disable=duplicate-code # false positive, no code duplication @@ -731,31 +798,47 @@ def _wrap( npartitions=npartitions) return bag.map(_wrap, func, datasets_list, depth, *args, **kwargs) - def is_synced(self) -> bool: + def is_synced(self, distributed: bool = True) -> bool: """Check if the view is synchronized with the underlying collection. + Args: + distributed: Whether to use dask or not. Default To True. + Returns: True if the view is synchronized, False otherwise. """ partitions = tuple(self.view_ref.partitions(relative=True)) - client: dask.distributed.Client = dask_utils.get_client() - unsynchronized_partition = storage.execute_transaction( - client, self.synchronizer, - client.map(_sync, - partitions, - base_dir=self.base_dir, - fs=self.fs, - view_ref=self.view_ref, - metadata=self.metadata, - dry_run=True)) + + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + unsynchronized_partition = storage.execute_transaction( + client, self.synchronizer, + client.map(_sync, + partitions, + base_dir=self.base_dir, + fs=self.fs, + view_ref=self.view_ref, + metadata=self.metadata, + dry_run=True)) + else: + unsynchronized_partition = [ + _sync(partition, + base_dir=self.base_dir, + fs=self.fs, + view_ref=self.view_ref, + metadata=self.metadata, + dry_run=True) for partition in partitions + ] + return len( tuple( filter(lambda item: item is not None, unsynchronized_partition))) == 0 def sync( - self, - filters: collection.PartitionFilter = None + self, + filters: collection.PartitionFilter = None, + distributed: bool = True ) -> collection.abc.PartitionFilterCallback: """Synchronize the view with the underlying collection. @@ -773,6 +856,7 @@ def sync( selects the new partitions. Existing partitions are not removed, even if they are not selected by the predicate. + distributed: Whether to use dask or not. Default To True. 
Returns: A function that can be used as a predicate to get the partitions @@ -784,20 +868,31 @@ def sync( if filters is not None: self.filters = filters self._write_config() - self._init_partitions(filters) + self._init_partitions(filters, distributed=distributed) partitions = tuple(self.view_ref.partitions(relative=True)) _LOGGER.info('%d partitions to synchronize', len(partitions)) - client: dask.distributed.Client = dask_utils.get_client() - synchronized_partition: list[str | None] = storage.execute_transaction( - client, self.synchronizer, - client.map(_sync, - partitions, - base_dir=self.base_dir, - fs=self.fs, - view_ref=self.view_ref, - metadata=self.metadata)) + if distributed: + client: dask.distributed.Client = dask_utils.get_client() + synchronized_partition: list[str + | None] = storage.execute_transaction( + client, self.synchronizer, + client.map( + _sync, + partitions, + base_dir=self.base_dir, + fs=self.fs, + view_ref=self.view_ref, + metadata=self.metadata)) + else: + synchronized_partition = [ + _sync(partition, + base_dir=self.base_dir, + fs=self.fs, + view_ref=self.view_ref, + metadata=self.metadata) for partition in partitions + ] partition_ids = tuple( dict(self.view_ref.partitioning.parse( diff --git a/zcollection/view/detail.py b/zcollection/view/detail.py index f6aea02..50fc0ef 100644 --- a/zcollection/view/detail.py +++ b/zcollection/view/detail.py @@ -127,7 +127,6 @@ def _drop_zarr_zarr(partition: str, Args: partition: The partition that contains the array to drop. - base_dir: Base directory for the Zarr array. fs: The filesystem used to delete the Zarr array. variable: The name of the variable to drop. ignore_errors: If True, ignore errors when dropping the array. @@ -151,7 +150,7 @@ def _load_one_dataset( fs: fsspec.AbstractFileSystem, selected_variables: Iterable[str] | None, view_ref: collection.Collection, - variables: Sequence[str], + variables: Iterable[str], ) -> tuple[dataset.Dataset, str] | None: """Load a dataset from a partition stored in the reference collection and merge it with the variables defined in this view. @@ -245,7 +244,7 @@ def _assert_variable_handled(reference: meta.Dataset, view: meta.Dataset, def _load_datasets_list( *, - client: dask.distributed.Client, + client: dask.distributed.Client | None, base_dir: str, delayed: bool, fs: fsspec.AbstractFileSystem, @@ -257,7 +256,8 @@ def _load_datasets_list( """Load datasets from a list of partitions. Args: - client: The client used to load the datasets. + client: The client used to load the datasets (or None to + avoid dask usage). base_dir: Base directory of the view. delayed: If True, load the dataset lazily. fs: The file system used to access the variables in the view. @@ -271,19 +271,37 @@ def _load_datasets_list( """ arguments: tuple[tuple[tuple[tuple[str, int], ...], list], ...] 
= tuple( (view_ref.partitioning.parse(item), []) for item in partitions) - futures: list[dask.distributed.Future] = client.map( - _load_one_dataset, - arguments, - base_dir=base_dir, - delayed=delayed, - fs=fs, - selected_variables=view_ref.metadata.select_variables( - keep_variables=selected_variables), - view_ref=client.scatter(view_ref), - variables=metadata.select_variables(selected_variables)) - - return filter(lambda item: item is not None, - client.gather(futures)) # type: ignore[arg-type] + + datasets: list[tuple[dataset.Dataset, str] | None] + + if client is not None: + futures: list[dask.distributed.Future] = client.map( + _load_one_dataset, + arguments, + base_dir=base_dir, + delayed=delayed, + fs=fs, + selected_variables=view_ref.metadata.select_variables( + keep_variables=selected_variables), + view_ref=client.scatter(view_ref), + variables=metadata.select_variables(selected_variables)) + datasets = client.gather(futures) + else: + datasets = [ + _load_one_dataset( + arg, + base_dir=base_dir, + delayed=False, + fs=fs, + selected_variables=view_ref.metadata.select_variables( + keep_variables=selected_variables), + view_ref=view_ref, + variables=metadata.select_variables(selected_variables)) + for arg in arguments + ] + return filter( + lambda item: item is not None, # type: ignore[arg-type] + datasets) def _assert_have_variables(metadata: meta.Dataset) -> None: diff --git a/zcollection/view/tests/test_view.py b/zcollection/view/tests/test_view.py index 365249c..0083891 100644 --- a/zcollection/view/tests/test_view.py +++ b/zcollection/view/tests/test_view.py @@ -30,10 +30,12 @@ @pytest.mark.parametrize('fs', ['local_fs', 's3_fs']) @pytest.mark.parametrize('arrays_type', ['dask_arrays', 'numpy_arrays']) +@pytest.mark.parametrize('distributed', [False, True]) def test_view( dask_client, # pylint: disable=redefined-outer-name,unused-argument fs, arrays_type, + distributed, request, ): """Test the creation of a view.""" @@ -45,13 +47,14 @@ def test_view( view.ViewReference( str(tested_fs.collection), tested_fs.fs), - filesystem=tested_fs.fs) + filesystem=tested_fs.fs, + distributed=distributed) assert isinstance(instance, view.View) assert isinstance(str(instance), str) # No variable recorded, so no data can be loaded with pytest.raises(ValueError): - instance.load(delayed=delayed) + instance.load(delayed=delayed, distributed=distributed) var = meta.Variable( name='var2', @@ -61,17 +64,17 @@ def test_view( ) with pytest.raises(ValueError): - instance.add_variable(var) + instance.add_variable(var, distributed=distributed) var.name = 'var3' - instance.add_variable(var) + instance.add_variable(var, distributed=distributed) with pytest.raises(ValueError): - instance.add_variable(var) + instance.add_variable(var, distributed=distributed) instance = convenience.open_view(str(tested_fs.view), filesystem=tested_fs.fs) - zds = instance.load(delayed=delayed) + zds = instance.load(delayed=delayed, distributed=distributed) assert zds is not None assert set(zds['time'].values.astype('datetime64[D]')) == { numpy.datetime64('2000-01-01'), @@ -83,7 +86,9 @@ def test_view( } # Loading a variable existing only in the view. - zds = instance.load(delayed=delayed, selected_variables=('var3', )) + zds = instance.load(delayed=delayed, + selected_variables=('var3', ), + distributed=distributed) assert zds is not None assert tuple(zds.variables) == ('var3', ) assert 'var3' in zds.metadata().variables.keys() @@ -91,8 +96,10 @@ def test_view( # The metadata of the reference collection is not modified. 
assert 'var3' not in instance.view_ref.metadata.variables.keys() - # Loading a non existing variable. - zds = instance.load(delayed=delayed, selected_variables=('var55', )) + # Loading a non-existing variable. + zds = instance.load(delayed=delayed, + selected_variables=('var55', ), + distributed=distributed) assert zds is not None assert len(zds.variables) == 0 @@ -105,7 +112,7 @@ def test_view( assert len(tuple(instance.partitions())) == 5 assert len(tuple(instance.view_ref.partitions())) == 6 - zds = instance.load(delayed=delayed) + zds = instance.load(delayed=delayed, distributed=distributed) assert zds is not None assert set(zds['time'].values.astype('datetime64[D]')) == { numpy.datetime64('2000-01-01'), @@ -117,39 +124,51 @@ def test_view( # Create a variable with the unsynchronized view var.name = 'var4' - instance.add_variable(var) + instance.add_variable(var, distributed=distributed) - zds = instance.load(delayed=delayed) + zds = instance.load(delayed=delayed, distributed=distributed) assert zds is not None def update(zds, varname): """Update function used for this test.""" return {varname: zds.variables['var1'].values * 0 + 5} - instance.update(update, 'var3', delayed=delayed) # type: ignore + instance.update( + update, # type: ignore + 'var3', + delayed=delayed, + distributed=distributed) with pytest.raises(ValueError): - instance.update(update, 'varX') # type: ignore + instance.update( + update, # type: ignore + 'varX', + distributed=distributed) with pytest.raises(ValueError): - instance.update(update, 'var2') # type: ignore + instance.update( + update, # type: ignore + 'var2', + distributed=distributed) - zds = instance.load(delayed=delayed) + zds = instance.load(delayed=delayed, distributed=distributed) assert zds is not None numpy.all(zds.variables['var3'].values == 5) indexers = instance.map( lambda x: slice(0, x.dimensions['num_lines']) # type: ignore ).compute() - ds1 = instance.load(delayed=delayed, indexer=indexers) + ds1 = instance.load(delayed=delayed, + indexer=indexers, + distributed=distributed) assert ds1 is not None - ds2 = instance.load(delayed=delayed) + ds2 = instance.load(delayed=delayed, distributed=distributed) assert ds2 is not None assert numpy.allclose(ds1.variables['var1'].values, ds2.variables['var1'].values) - instance.drop_variable('var3') + instance.drop_variable('var3', distributed=distributed) assert tuple( str(pathlib.Path(item)) @@ -268,9 +287,11 @@ def test_view_checksum( @pytest.mark.filterwarnings('ignore:.*cannot be serialized.*') @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +@pytest.mark.parametrize('distributed', [False, True]) def test_view_sync( dask_client, # pylint: disable=redefined-outer-name,unused-argument arg, + distributed, request, ): """Test the synchronization of a view.""" @@ -280,11 +301,12 @@ def test_view_sync( view.ViewReference( str(tested_fs.collection), tested_fs.fs), - filesystem=tested_fs.fs) + filesystem=tested_fs.fs, + distributed=distributed) var = meta.Variable(name='var3', dtype=numpy.float64, dimensions=('num_lines', 'num_pixels')) - instance.add_variable(var) + instance.add_variable(var, distributed=distributed) del instance zcollection = convenience.open_collection(str(tested_fs.collection), @@ -305,7 +327,7 @@ def test_view_sync( instance = convenience.open_view(str(tested_fs.view), filesystem=tested_fs.fs) assert instance is not None - assert instance.is_synced() is False - instance.sync(filters=lambda keys: True) - zds = instance.load() + assert instance.is_synced(distributed=distributed) is 
False + instance.sync(filters=lambda keys: True, distributed=distributed) + zds = instance.load(distributed=distributed) assert zds is not None From 7fbef1834c32b9240e50aee3e6739129a757d5f7 Mon Sep 17 00:00:00 2001 From: Thomas Zilio Date: Sat, 7 Dec 2024 18:02:21 +0100 Subject: [PATCH 2/2] feat: Adding selected_partitions parameter to collection.load() and view.load() functions. refactor: collection.partitions() and view.partitions() now handle indexer and selected_partitions parameters. --- zcollection/collection/__init__.py | 26 +- zcollection/collection/abc.py | 281 ++++++++++++------ .../collection/tests/test_collection.py | 62 +++- zcollection/indexing/abc.py | 4 +- zcollection/view/__init__.py | 33 +- zcollection/view/tests/test_view.py | 12 + 6 files changed, 307 insertions(+), 111 deletions(-) diff --git a/zcollection/collection/__init__.py b/zcollection/collection/__init__.py index 0a50e3e..0d9b985 100644 --- a/zcollection/collection/__init__.py +++ b/zcollection/collection/__init__.py @@ -35,7 +35,7 @@ storage, sync, ) -from .abc import PartitionFilter, ReadOnlyCollection +from .abc import Indexer, PartitionFilter, ReadOnlyCollection from .callable_objects import UpdateCallable, WrappedPartitionCallable from .detail import ( PartitionSlice, @@ -45,6 +45,12 @@ _wrap_update_func_with_overlap, ) +__all__ = ('dask_utils', 'dataset', 'fs_utils', 'merging', 'meta', + 'partitioning', 'storage', 'sync', 'Indexer', 'PartitionFilter', + 'ReadOnlyCollection', 'UpdateCallable', 'WrappedPartitionCallable', + 'PartitionSlice', '_insert', '_try_infer_callable', + '_wrap_update_func', '_wrap_update_func_with_overlap') + #: Module logger. _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -874,6 +880,14 @@ def validate_partitions(self, invalid_partitions: list[str] = [] + def _validity_check(_partition, _valid): + """Check partition validity and add it to invalid partitions if not + valid.""" + if not _valid: + warnings.warn(f'Invalid partition: {_partition}', + category=RuntimeWarning) + invalid_partitions.append(_partition) + if distributed: client: dask.distributed.Client = dask_utils.get_client() futures: list[dask.distributed.Future] = client.map( @@ -884,20 +898,14 @@ def validate_partitions(self, for item in dask.distributed.as_completed(futures): partition, valid = item.result() # type: ignore - if not valid: - warnings.warn(f'Invalid partition: {partition}', - category=RuntimeWarning) - invalid_partitions.append(partition) + _validity_check(_partition=partition, _valid=valid) else: for partition in partitions: partition, valid = _check_partition( partition, fs=self.fs, partitioning_strategy=self.partitioning) - if not valid: - warnings.warn(f'Invalid partition: {partition}', - category=RuntimeWarning) - invalid_partitions.append(partition) + _validity_check(_partition=partition, _valid=valid) if fix and invalid_partitions: for item in invalid_partitions: diff --git a/zcollection/collection/abc.py b/zcollection/collection/abc.py index 08dcce1..c9f565d 100644 --- a/zcollection/collection/abc.py +++ b/zcollection/collection/abc.py @@ -294,27 +294,47 @@ def _relative_path(self, path: str) -> str: return pathlib.Path(path).relative_to( self.partition_properties.dir).as_posix() + def _normalize_partitions(self, + partitions: Iterable[str]) -> Iterable[str]: + """Normalize the provided list of partitions to include the full + partition's path. + + Args: + partitions: The list of partitions to normalize. + + Returns: + The list of partitions. 
+ """ + return filter( + self.fs.exists, + map( + lambda partition: self.fs.sep.join( + (self.partition_properties.dir, partition)), + sorted(set(partitions)))) + def partitions( self, *, - cache: Iterable[str] | None = None, - lock: bool = False, filters: PartitionFilter = None, + indexer: Indexer | None = None, + selected_partitions: Iterable[str] | None = None, relative: bool = False, + lock: bool = False, ) -> Iterator[str]: """List the partitions of the collection. Args: - cache: The list of partitions to use. If None, the partitions are - listed. - lock: Whether to lock the collection or not to avoid listing - partitions while the collection is being modified. filters: The predicate used to filter the partitions to load. If the predicate is a string, it is a valid python expression to filter the partitions, using the partitioning scheme as variables. If the predicate is a function, it is a function that takes the partition scheme as input and returns a boolean. + indexer: The indexer to apply. + selected_partitions: A list of partitions to load (using the + partition relative path). relative: Whether to return the relative path. + lock: Whether to lock the collection or not to avoid listing + partitions while the collection is being modified. Returns: The list of partitions. @@ -336,8 +356,9 @@ def partitions( base_dir: str = self.partition_properties.dir sep: str = self.fs.sep - if cache is not None: - partitions: Iterable[str] = cache + if selected_partitions is not None: + partitions: Iterable[str] = self._normalize_partitions( + partitions=selected_partitions) else: if lock: with self.synchronizer: @@ -347,6 +368,17 @@ def partitions( partitions = self.partitioning.list_partitions( self.fs, base_dir) + if indexer is not None: + # List of partitions existing in the indexer and partitions list + partitions = list(partitions) + partitions = [ + p for p in list_partitions_from_indexer( + indexer=indexer, + partition_handler=self.partitioning, + base_dir=self.partition_properties.dir, + sep=self.fs.sep) if p in partitions + ] + yield from (self._relative_path(item) if relative else item for item in partitions if (item != self._immutable and self._is_selected( @@ -553,9 +585,11 @@ def load( filters: PartitionFilter = None, indexer: Indexer | None = None, selected_variables: Iterable[str] | None = None, + selected_partitions: Iterable[str] | None = None, distributed: bool = True, ) -> dataset.Dataset | None: - """Load the selected partitions. + """Load collection's data, respecting filters, indexer, and selected + partitions constraints. Args: delayed: Whether to load data in a dask array or not. @@ -565,6 +599,8 @@ def load( indexer: The indexer to apply. selected_variables: A list of variables to retain from the collection. If None, all variables are kept. + selected_partitions: A list of partitions to load (using the + partition relative path). distributed: Whether to use dask or not. Default To True. Returns: @@ -588,87 +624,24 @@ def load( if not distributed: delayed = False - arrays: list[dataset.Dataset] - client: dask.distributed.Client - if indexer is None: - # No indexer, so the dataset is loaded directly for each - # selected partition. 
- selected_partitions = tuple(self.partitions(filters=filters)) - if len(selected_partitions) == 0: - return None - - partitions = self.partitions(filters=filters) - - if distributed: - client = dask_utils.get_client() - bag: dask.bag.core.Bag = dask.bag.core.from_sequence( - partitions, - npartitions=dask_utils.dask_workers(client, - cores_only=True)) - arrays = bag.map( - storage.open_zarr_group, - delayed=delayed, - fs=self.fs, - selected_variables=selected_variables).compute() - else: - arrays = [ - storage.open_zarr_group( - dirname=partition, - delayed=delayed, - fs=self.fs, - selected_variables=selected_variables) - for partition in partitions - ] + arrays = self._load_partitions( + delayed=delayed, + filters=filters, + selected_variables=selected_variables, + selected_partitions=selected_partitions, + distributed=distributed) else: - # We're going to reuse the indexer variable, so ensure it is - # an iterable not a generator. - indexer = tuple(indexer) - - # Build the indexer arguments. - partitions = self.partitions(filters=filters, - cache=list_partitions_from_indexer( - indexer, self.partitioning, - self.partition_properties.dir, - self.fs.sep)) - args = tuple( - build_indexer_args(self, - filters, - indexer, - partitions=partitions)) - if len(args) == 0: - return None - - # Finally, load the selected partitions and apply the indexer. - if distributed: - client = dask_utils.get_client() - bag = dask.bag.core.from_sequence( - args, - npartitions=dask_utils.dask_workers(client, - cores_only=True)) - - arrays = list( - itertools.chain.from_iterable( - bag.map( - _load_and_apply_indexer, - delayed=delayed, - fs=self.fs, - partition_handler=self.partitioning, - partition_properties=self.partition_properties, - selected_variables=selected_variables, - ).compute())) - else: - arrays = list( - itertools.chain.from_iterable([ - _load_and_apply_indexer( - args=a, - delayed=delayed, - fs=self.fs, - partition_handler=self.partitioning, - partition_properties=self.partition_properties, - selected_variables=selected_variables) - for a in args - ])) + arrays = self._load_partitions_indexer( + indexer=indexer, + delayed=delayed, + filters=filters, + selected_variables=selected_variables, + selected_partitions=selected_partitions, + distributed=distributed) + + if arrays is None: + return None array: dataset.Dataset = arrays.pop(0) if arrays: @@ -682,6 +655,138 @@ def load( array.fill_attrs(self.metadata) return array + def _load_partitions( + self, + *, + delayed: bool = True, + filters: PartitionFilter = None, + selected_variables: Iterable[str] | None = None, + selected_partitions: Iterable[str] | None = None, + distributed: bool = True, + ) -> list[dataset.Dataset] | None: + """Load collection's partitions, respecting filters, and selected + partitions constraints. + + Args: + delayed: Whether to load data in a dask array or not. + filters: The predicate used to filter the partitions to load. To + get more information on the predicate, see the documentation of + the :meth:`partitions` method. + selected_variables: A list of variables to retain from the + collection. If None, all variables are kept. + selected_partitions: A list of partitions to load (using the + partition relative path). + distributed: Whether to use dask or not. Default To True. + + Returns: + The list of dataset for each partition, or None if no + partitions were selected. + """ + # No indexer, so the dataset is loaded directly for each + # selected partition. 
+ selected_partitions = tuple( + self.partitions(filters=filters, + selected_partitions=selected_partitions)) + + if len(selected_partitions) == 0: + return None + + if distributed: + client = dask_utils.get_client() + bag: dask.bag.core.Bag = dask.bag.core.from_sequence( + selected_partitions, + npartitions=dask_utils.dask_workers(client, cores_only=True)) + arrays = bag.map(storage.open_zarr_group, + delayed=delayed, + fs=self.fs, + selected_variables=selected_variables).compute() + else: + arrays = [ + storage.open_zarr_group(dirname=partition, + delayed=delayed, + fs=self.fs, + selected_variables=selected_variables) + for partition in selected_partitions + ] + + return arrays + + def _load_partitions_indexer( + self, + *, + indexer: Indexer, + delayed: bool = True, + filters: PartitionFilter = None, + selected_variables: Iterable[str] | None = None, + selected_partitions: Iterable[str] | None = None, + distributed: bool = True, + ) -> list[dataset.Dataset] | None: + """Load collection's partitions, respecting filters, indexer, and + selected partitions constraints. + + Args: + indexer: The indexer to apply. + delayed: Whether to load data in a dask array or not. + filters: The predicate used to filter the partitions to load. To + get more information on the predicate, see the documentation of + the :meth:`partitions` method. + selected_variables: A list of variables to retain from the + collection. If None, all variables are kept. + selected_partitions: A list of partitions to load (using the + partition relative path). + distributed: Whether to use dask or not. Default To True. + + Returns: + The list of dataset for each partition, or None if no + partitions were selected. + """ + # We're going to reuse the indexer variable, so ensure it is + # an iterable not a generator. + indexer = tuple(indexer) + + # Build the indexer arguments. + partitions = self.partitions(selected_partitions=selected_partitions, + filters=filters, + indexer=indexer) + args = tuple( + build_indexer_args(collection=self, + filters=filters, + indexer=indexer, + partitions=partitions)) + if len(args) == 0: + return None + + # Finally, load the selected partitions and apply the indexer. 
+ if distributed: + client = dask_utils.get_client() + bag = dask.bag.core.from_sequence( + args, + npartitions=dask_utils.dask_workers(client, cores_only=True)) + + arrays = list( + itertools.chain.from_iterable( + bag.map( + _load_and_apply_indexer, + delayed=delayed, + fs=self.fs, + partition_handler=self.partitioning, + partition_properties=self.partition_properties, + selected_variables=selected_variables, + ).compute())) + else: + arrays = list( + itertools.chain.from_iterable([ + _load_and_apply_indexer( + args=a, + delayed=delayed, + fs=self.fs, + partition_handler=self.partitioning, + partition_properties=self.partition_properties, + selected_variables=selected_variables) for a in args + ])) + + return arrays + def _bag_from_partitions( self, filters: PartitionFilter | None = None, diff --git a/zcollection/collection/tests/test_collection.py b/zcollection/collection/tests/test_collection.py index 215d965..f9f4f6b 100644 --- a/zcollection/collection/tests/test_collection.py +++ b/zcollection/collection/tests/test_collection.py @@ -342,6 +342,62 @@ def invalid_var_name(zds: dataset.Dataset): distributed=distributed) # type: ignore +@pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) +def test_list_partitions( + dask_client, # pylint: disable=redefined-outer-name,unused-argument + arg, + request, +) -> None: + """Test the dropping of a dataset.""" + tested_fs = request.getfixturevalue(arg) + zcollection = create_test_collection(tested_fs, delayed=False) + + all_partitions = list(zcollection.partitions()) + assert len(all_partitions) == 6 + + full_path = lambda partition: zcollection.fs.sep.join( + (zcollection.partition_properties.dir, partition)) + + selected_partitions = ['year=2000/month=01/day=01'] + partitions = list( + zcollection.partitions(selected_partitions=selected_partitions)) + assert partitions == list(map(full_path, selected_partitions)) + + selected_partitions = [ + 'year=2000/month=01/day=01', 'year=2000/month=01/day=01' + ] + partitions = list( + zcollection.partitions(selected_partitions=selected_partitions)) + assert partitions == list(map(full_path, selected_partitions[:1])) + + selected_partitions = ['year=2000/month=01/day=02'] + partitions = list( + zcollection.partitions(selected_partitions=selected_partitions)) + assert not partitions + + selected_partitions = [ + 'year=2000/month=01/day=01', 'year=2000/month=01/day=02', + 'year=2000/month=01/day=07' + ] + partitions = list( + zcollection.partitions(selected_partitions=selected_partitions)) + assert partitions == list( + map(full_path, [selected_partitions[0], selected_partitions[2]])) + + indexer = zcollection.map( + lambda x: slice(0, x.dimensions['num_lines']) # type: ignore + ).compute()[3:] + + selected_partitions = [ + 'year=2000/month=01/day=01', 'year=2000/month=01/day=02', + 'year=2000/month=01/day=13' + ] + partitions = list( + zcollection.partitions(indexer=indexer, + selected_partitions=selected_partitions)) + assert partitions == list(map(full_path, selected_partitions[-1:])) + + @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) @pytest.mark.parametrize('distributed', [False, True]) def test_drop_partitions( @@ -900,11 +956,11 @@ def update(zds: dataset.Dataset, varname: str) -> dict[str, numpy.ndarray]: assert numpy.all(zds.variables['grid'].values == zds_reference.variables['grid'].values * -1) assert numpy.all( - zds.variables['time'].values == zds.variables['time'].values) + zds.variables['time'].values == zds_reference.variables['time'].values) assert numpy.all( - 
zds.variables['lon'].values == zds.variables['lon'].values) + zds.variables['lon'].values == zds_reference.variables['lon'].values) assert numpy.all( - zds.variables['lat'].values == zds.variables['lat'].values) + zds.variables['lat'].values == zds_reference.variables['lat'].values) new_variable = meta.Variable( 'new_var', diff --git a/zcollection/indexing/abc.py b/zcollection/indexing/abc.py index 1044c0d..0c4757d 100644 --- a/zcollection/indexing/abc.py +++ b/zcollection/indexing/abc.py @@ -356,7 +356,7 @@ def _read(self) -> pyarrow.Table: return self._table def _table_2_indexer(self, table: pyarrow.Table, - only_partition_keys: bool) -> collection.abc.Indexer: + only_partition_keys: bool) -> collection.Indexer: """Convert a table to an indexer. Args: @@ -409,7 +409,7 @@ def query( logical_op: str | None = None, mask: pyarrow.ChunkedArray | None = None, only_partition_keys: bool = True, - ) -> collection.abc.Indexer: + ) -> collection.Indexer: """Query the index. Args: diff --git a/zcollection/view/__init__.py b/zcollection/view/__init__.py index c337e3e..860c19b 100644 --- a/zcollection/view/__init__.py +++ b/zcollection/view/__init__.py @@ -209,19 +209,29 @@ def from_config( def partitions( self, filters: collection.PartitionFilter = None, + indexer: collection.Indexer | None = None, + selected_partitions: Iterable[str] | None = None, ) -> Iterator[str]: """Returns the list of partitions in the view. Args: filters: The partition filters. + indexer: The indexer to apply. + selected_partitions: A list of partitions to load (using the + partition relative path). Returns: The list of partitions. """ return filter( self.fs.exists, - map(lambda item: fs_utils.join_path(self.base_dir, item), - self.view_ref.partitions(filters=filters, relative=True))) + map( + lambda item: fs_utils.join_path(self.base_dir, item), + self.view_ref.partitions( + filters=filters, + indexer=indexer, + selected_partitions=selected_partitions, + relative=True))) def variables( self, @@ -372,8 +382,9 @@ def load( *, delayed: bool = True, filters: collection.PartitionFilter = None, - indexer: collection.abc.Indexer | None = None, + indexer: collection.Indexer | None = None, selected_variables: Iterable[str] | None = None, + selected_partitions: Iterable[str] | None = None, distributed: bool = True, ) -> dataset.Dataset | None: """Load the view. @@ -387,6 +398,8 @@ def load( indexer: The indexer to apply. selected_variables: A list of variables to retain from the view. If None, all variables are loaded. + selected_partitions: A list of partitions to load (using the + partition relative path). distributed: Whether to use dask or not. Default To True. 
Returns: @@ -403,19 +416,21 @@ def load( delayed = False datasets: list[tuple[dataset.Dataset, str] | None] + partitions = self.partitions(selected_partitions=selected_partitions, + filters=filters, + indexer=indexer) if indexer is not None: arguments = tuple( - collection.abc.build_indexer_args( - self.view_ref, - filters, - indexer, - partitions=self.partitions())) + collection.abc.build_indexer_args(collection=self.view_ref, + filters=filters, + indexer=indexer, + partitions=partitions)) if len(arguments) == 0: return None else: arguments = tuple((self.view_ref.partitioning.parse(item), []) - for item in self.partitions(filters=filters)) + for item in partitions) if distributed: client: dask.distributed.Client = dask_utils.get_client() diff --git a/zcollection/view/tests/test_view.py b/zcollection/view/tests/test_view.py index 0083891..737a442 100644 --- a/zcollection/view/tests/test_view.py +++ b/zcollection/view/tests/test_view.py @@ -112,6 +112,18 @@ def test_view( assert len(tuple(instance.partitions())) == 5 assert len(tuple(instance.view_ref.partitions())) == 6 + selected_partitions = [ + 'year=2000/month=01/day=01', 'year=2000/month=01/day=07', + 'year=2000/month=01/day=13' + ] + assert len( + tuple( + instance.partitions(selected_partitions=selected_partitions))) == 2 + assert len( + tuple( + instance.view_ref.partitions( + selected_partitions=selected_partitions))) == 3 + zds = instance.load(delayed=delayed, distributed=distributed) assert zds is not None assert set(zds['time'].values.astype('datetime64[D]')) == {