diff --git a/rohmu/object_storage/azure.py b/rohmu/object_storage/azure.py index f627926e..310e3e54 100644 --- a/rohmu/object_storage/azure.py +++ b/rohmu/object_storage/azure.py @@ -24,7 +24,6 @@ KEY_TYPE_OBJECT, KEY_TYPE_PREFIX, ProgressProportionCallbackType, - SourceStorageModelT, ) from rohmu.object_storage.config import ( # noqa: F401 AZURE_ENDPOINT_SUFFIXES as ENDPOINT_SUFFIXES, @@ -33,7 +32,8 @@ calculate_azure_max_block_size as calculate_max_block_size, ) from rohmu.typing import Metadata -from typing import Any, BinaryIO, Collection, Iterator, Optional, Tuple, Union +from typing import Any, BinaryIO, Iterator, Optional, Tuple, Union +from typing_extensions import Self import azure.common import enum @@ -182,7 +182,8 @@ def copy_file( def _copy_file_from_bucket( self, - source_bucket: AzureTransfer, + *, + source_bucket: Self, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, @@ -228,13 +229,6 @@ def _copy_file_from_bucket( f"Copying {repr(source_key)} to {repr(destination_key)} failed, unexpected status: {copy_props.status}" ) - def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None: - if isinstance(source, AzureTransfer): - for key in keys: - self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key, timeout=15) - else: - raise NotImplementedError - def get_metadata_for_key(self, key: str) -> Metadata: path = self.format_key_for_backend(key, remove_slash_prefix=True, trailing_slash=False) items = list(self._iter_key(path=path, with_metadata=True, deep=False)) diff --git a/rohmu/object_storage/base.py b/rohmu/object_storage/base.py index 09b3f082..95766929 100644 --- a/rohmu/object_storage/base.py +++ b/rohmu/object_storage/base.py @@ -56,6 +56,10 @@ class IterKeyItem(NamedTuple): IncrementalProgressCallbackType = Optional[Callable[[int], None]] +class ObjectTransferProgressCallback(Protocol): + def __call__(self, files_completed: int, total_files: int) -> None: ... + + @dataclass(frozen=True, unsafe_hash=True) class ConcurrentUpload: backend: str @@ -202,7 +206,31 @@ def copy_file( cannot be copied with this method. If no metadata is given copies the existing metadata.""" raise NotImplementedError - def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None: + def copy_files_from( + self, + *, + source: BaseTransfer[Any], + keys: Collection[str], + progress_fn: ObjectTransferProgressCallback | None = None, + ) -> None: + if isinstance(source, self.__class__): + total_files = len(keys) + for index, key in enumerate(keys): + self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key, timeout=15) + if progress_fn is not None: + progress_fn(index + 1, total_files) + else: + raise NotImplementedError + + def _copy_file_from_bucket( + self, + *, + source_bucket: Self, + source_key: str, + destination_key: str, + metadata: Optional[Metadata] = None, + timeout: float = 15.0, + ) -> None: raise NotImplementedError def format_key_for_backend(self, key: str, remove_slash_prefix: bool = False, trailing_slash: bool = False) -> str: diff --git a/rohmu/object_storage/google.py b/rohmu/object_storage/google.py index 3977b243..a60c58eb 100644 --- a/rohmu/object_storage/google.py +++ b/rohmu/object_storage/google.py @@ -38,7 +38,6 @@ KEY_TYPE_OBJECT, KEY_TYPE_PREFIX, ProgressProportionCallbackType, - SourceStorageModelT, ) from rohmu.object_storage.config import ( GOOGLE_DOWNLOAD_CHUNK_SIZE as DOWNLOAD_CHUNK_SIZE, @@ -52,7 +51,6 @@ BinaryIO, Callable, cast, - Collection, Iterable, Iterator, Optional, @@ -62,7 +60,7 @@ TypeVar, Union, ) -from typing_extensions import Protocol +from typing_extensions import Protocol, Self import codecs import dataclasses @@ -349,7 +347,13 @@ def copy_file( ) def _copy_file_from_bucket( - self, *, source_bucket: GoogleTransfer, source_key: str, destination_key: str, metadata: Optional[Metadata] = None + self, + *, + source_bucket: Self, + source_key: str, + destination_key: str, + metadata: Optional[Metadata] = None, + timeout: float = 15.0, ) -> None: source_object = source_bucket.format_key_for_backend(source_key) destination_object = self.format_key_for_backend(destination_key) @@ -374,13 +378,6 @@ def _copy_file_from_bucket( self.notifier.object_copied(key=destination_key, size=size, metadata=metadata) reporter.report(self.stats) - def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None: - if isinstance(source, GoogleTransfer): - for key in keys: - self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key) - else: - raise NotImplementedError - def get_metadata_for_key(self, key: str) -> Metadata: path = self.format_key_for_backend(key) with self._object_client(not_found=path) as clob: diff --git a/rohmu/object_storage/local.py b/rohmu/object_storage/local.py index b793c559..84775aea 100644 --- a/rohmu/object_storage/local.py +++ b/rohmu/object_storage/local.py @@ -18,12 +18,12 @@ KEY_TYPE_OBJECT, KEY_TYPE_PREFIX, ProgressProportionCallbackType, - SourceStorageModelT, ) from rohmu.object_storage.config import LOCAL_CHUNK_SIZE as CHUNK_SIZE, LocalObjectStorageConfig as Config from rohmu.typing import Metadata from rohmu.util import BinaryStreamsConcatenation, ProgressStream -from typing import Any, BinaryIO, Collection, Iterator, Optional, TextIO, Tuple, Union +from typing import Any, BinaryIO, Iterator, Optional, TextIO, Tuple, Union +from typing_extensions import Self import contextlib import datetime @@ -81,7 +81,13 @@ def copy_file( ) def _copy_file_from_bucket( - self, *, source_bucket: LocalTransfer, source_key: str, destination_key: str, metadata: Optional[Metadata] = None + self, + *, + source_bucket: Self, + source_key: str, + destination_key: str, + metadata: Optional[Metadata] = None, + timeout: float = 15.0, ) -> None: source_path = source_bucket.format_key_for_backend(source_key.strip("/")) destination_path = self.format_key_for_backend(destination_key.strip("/")) @@ -97,13 +103,6 @@ def _copy_file_from_bucket( self._save_metadata(destination_path, new_metadata) self.notifier.object_copied(key=destination_key, size=os.path.getsize(destination_path), metadata=metadata) - def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None: - if isinstance(source, LocalTransfer): - for key in keys: - self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key) - else: - raise NotImplementedError - def _get_metadata_for_key(self, key: str) -> Metadata: source_path = self.format_key_for_backend(key.strip("/")) if not os.path.exists(source_path): diff --git a/rohmu/object_storage/s3.py b/rohmu/object_storage/s3.py index 22e5e73e..091674a0 100644 --- a/rohmu/object_storage/s3.py +++ b/rohmu/object_storage/s3.py @@ -29,7 +29,6 @@ KEY_TYPE_OBJECT, KEY_TYPE_PREFIX, ProgressProportionCallbackType, - SourceStorageModelT, ) from rohmu.object_storage.config import ( # noqa: F401 calculate_s3_chunk_size as calculate_chunk_size, @@ -42,6 +41,7 @@ from rohmu.util import batched, ProgressStream from threading import RLock from typing import Any, BinaryIO, cast, Collection, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing_extensions import Self import botocore.client import botocore.config @@ -284,7 +284,13 @@ def copy_file( ) def _copy_file_from_bucket( - self, *, source_bucket: S3Transfer, source_key: str, destination_key: str, metadata: Optional[Metadata] = None + self, + *, + source_bucket: Self, + source_key: str, + destination_key: str, + metadata: Optional[Metadata] = None, + timeout: float = 15.0, ) -> None: source_path = ( source_bucket.bucket_name + "/" + source_bucket.format_key_for_backend(source_key, remove_slash_prefix=True) @@ -307,13 +313,6 @@ def _copy_file_from_bucket( else: raise StorageError(f"Copying {source_key!r} to {destination_key!r} failed: {ex!r}") from ex - def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None: - if isinstance(source, S3Transfer): - for key in keys: - self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key) - else: - raise NotImplementedError - def get_metadata_for_key(self, key: str) -> Metadata: path = self.format_key_for_backend(key, remove_slash_prefix=True) return self._metadata_for_key(path) diff --git a/test/object_storage/test_object_storage.py b/test/object_storage/test_object_storage.py index 646f47c0..6f5df2a3 100644 --- a/test/object_storage/test_object_storage.py +++ b/test/object_storage/test_object_storage.py @@ -4,6 +4,8 @@ from rohmu import errors from rohmu.object_storage.local import LocalTransfer from typing import Any +from unittest import mock +from unittest.mock import MagicMock import pytest @@ -69,18 +71,26 @@ def test_copy(transfer_type: str, request: Any) -> None: assert transfer.get_contents_to_string("dummy_copy_metadata") == (DUMMY_CONTENT, {"new_k": "new_v"}) -def test_copy_local_files_from(tmp_path: Path) -> None: +@pytest.mark.parametrize("with_progress_fn", [False, True]) +def test_copy_local_files_from(tmp_path: Path, with_progress_fn: bool) -> None: source = LocalTransfer(tmp_path / "source", prefix="s-prefix") destination = LocalTransfer(tmp_path / "destination", prefix="d-prefix") + mock_progress_fn = MagicMock(return_value=None) source.store_file_from_memory("some/a/key.ext", b"content_a", metadata={"info": "aaa"}) source.store_file_from_memory("some/b/key.ext", b"content_b", metadata={"info": "bbb"}) + source.store_file_from_memory("some/c/key.ext", b"content_c", metadata={"info": "ccc"}) destination.copy_files_from( source=source, - keys=["some/a/key.ext", "some/b/key.ext"], + keys=["some/a/key.ext", "some/b/key.ext", "some/c/key.ext"], + progress_fn=mock_progress_fn if with_progress_fn else None, ) + assert destination.get_contents_to_string("some/a/key.ext") == (b"content_a", {"info": "aaa", "Content-Length": "9"}) assert destination.get_contents_to_string("some/b/key.ext") == (b"content_b", {"info": "bbb", "Content-Length": "9"}) + assert destination.get_contents_to_string("some/c/key.ext") == (b"content_c", {"info": "ccc", "Content-Length": "9"}) + if with_progress_fn: + assert mock_progress_fn.call_args_list == [mock.call(1, 3), mock.call(2, 3), mock.call(3, 3)] @pytest.mark.parametrize("transfer_type", ["local_transfer"])