diff --git a/leap_data_management_utils/data_management_transforms.py b/leap_data_management_utils/data_management_transforms.py
index 645d691..e2cce73 100644
--- a/leap_data_management_utils/data_management_transforms.py
+++ b/leap_data_management_utils/data_management_transforms.py
@@ -116,3 +116,86 @@ def _register_dataset_to_catalog(self, store: zarr.storage.FSStore) -> zarr.stor
 
     def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
         return pcoll | beam.Map(self._register_dataset_to_catalog)
+
+
+@dataclass
+class Copy(beam.PTransform):
+    """Copy a zarr store to ``target`` and emit a store opened on the copy.
+
+    Attributes:
+        target: Destination URL/path. The copy is done by the incoming
+            store's own filesystem object, so the target has to be reachable
+            through that same filesystem.
+    """
+
+    target: str
+
+    def _copy(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore:
+        """Recursively copy the contents of ``store`` to ``self.target``.
+
+        Args:
+            store: Source store; its ``fs`` and ``path`` attributes locate the
+                data to copy.
+
+        Returns:
+            A new ``FSStore`` on ``self.target``, so this stage can be slotted
+            right before testing/logging stages.
+        """
+        import posixpath
+
+        import zarr
+
+        # Store paths are '/'-separated keys, so normalise with posixpath
+        # instead of the platform-dependent os.path. The recursive copy needs
+        # the trailing slash on the source.
+        source = f'{posixpath.normpath(store.path)}/'
+        # Reuse the filesystem the store was opened with rather than
+        # hard-coding GCS and a 'gs://' prefix; it takes protocol-less paths.
+        store.fs.cp(source, self.target, recursive=True)
+        return zarr.storage.FSStore(self.target)
+
+    def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+        """Apply the copy to every store in ``pcoll``."""
+        return pcoll | 'Copying Store' >> beam.Map(self._copy)
+
+
+@dataclass
+class InjectAttrs(beam.PTransform):
+    """Merge ``inject_attrs`` into the root attributes of each zarr store.
+
+    Attributes:
+        inject_attrs: Attribute names and values passed to ``attrs.update``.
+    """
+
+    inject_attrs: dict
+
+    def _update_zarr_attrs(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore:
+        """Update the root attributes of ``store`` and return the same store.
+
+        Args:
+            store: Store whose root attributes are updated in place.
+
+        Returns:
+            The input ``store``, so later stages can keep consuming it.
+        """
+        import warnings
+
+        # mode='a' creates an empty group when nothing exists at the store,
+        # which would hide an upstream failure; make that case visible.
+        if not (zarr.storage.contains_group(store) or zarr.storage.contains_array(store)):
+            warnings.warn(
+                f'No zarr group or array found at {store.path!r}; '
+                'attributes will be written to a newly created group.',
+                stacklevel=2,
+            )
+        attrs = zarr.open(store, mode='a').attrs
+        attrs.update(self.inject_attrs)
+        # NOTE: any consolidated metadata is not refreshed here; the original
+        # author notes consolidation is done explicitly in a later stage.
+        return store
+
+    def expand(
+        self, pcoll: beam.PCollection[zarr.storage.FSStore]
+    ) -> beam.PCollection[zarr.storage.FSStore]:
+        """Apply the attribute update to every store in ``pcoll``."""
+        return pcoll | 'Injecting Attributes' >> beam.Map(self._update_zarr_attrs)