Implement advanced transfer options for boto3.
jmchilton committed May 9, 2024
1 parent 8ff1ca6 commit 98c5e69
Showing 2 changed files with 116 additions and 18 deletions.
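For orientation: everything this commit exposes is a thin veneer over boto3's `TransferConfig` (from `boto3.s3.transfer`), which governs how S3 transfers are split into parts and parallelized. A minimal sketch of that underlying API, not part of the diff itself; the values shown are boto3's usual defaults, listed only to name the knobs the new config surfaces:

```python
from boto3.s3.transfer import TransferConfig

# The eight options the new transfer config block exposes, with boto3's
# usual defaults as illustrative values.
config = TransferConfig(
    multipart_threshold=8 * 1024 * 1024,  # objects at or above this size use multipart transfer
    max_concurrency=10,                   # worker threads per transfer
    multipart_chunksize=8 * 1024 * 1024,  # size of each uploaded/downloaded part
    num_download_attempts=5,              # retry budget for downloads
    max_io_queue=100,                     # downloaded chunks buffered before being written
    io_chunksize=256 * 1024,              # read size when draining the IO queue
    use_threads=True,                     # False forces all transfer work onto the main thread
    max_bandwidth=None,                   # bytes/second cap; None means unthrottled
)
```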
91 changes: 73 additions & 18 deletions lib/galaxy/objectstore/s3_boto3.py
@@ -4,6 +4,9 @@
import logging
import os
from typing import (
    Any,
    Callable,
    Dict,
    Literal,
    TYPE_CHECKING,
    TypedDict,
)
@@ -19,10 +22,13 @@
try:
# Imports are done this way to allow objectstore code to be used outside of Galaxy.
import boto3
from boto3.s3.transfer import TransferConfig
from botocore.client import ClientError
except ImportError:
boto3 = None # type: ignore[assignment,unused-ignore]
TransferConfig = None # type: ignore[assignment,unused-ignore,misc]

from galaxy.util import asbool
from ._caching_base import CachingConcreteObjectStore
from .caching import (
enable_cache_monitor,
@@ -35,22 +41,6 @@
)

log = logging.getLogger(__name__)
logging.getLogger("boto").setLevel(logging.INFO) # Otherwise boto is quite noisy


def download_directory(bucket, remote_folder, local_path):
# List objects in the specified S3 folder
objects = bucket.list(prefix=remote_folder)

for obj in objects:
remote_file_path = obj.key
local_file_path = os.path.join(local_path, os.path.relpath(remote_file_path, remote_folder))

# Create directories if they don't exist
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

# Download the file
obj.get_contents_to_filename(local_file_path)


def parse_config_xml(config_xml):
@@ -71,6 +61,28 @@ def parse_config_xml(config_xml):
region = cn_xml.get("region")
cache_dict = parse_caching_config_dict_from_xml(config_xml)

transfer_xml = config_xml.findall("transfer")
if not transfer_xml:
    # An empty dict stands in for a missing <transfer> element: like an
    # Element, it returns None from .get() for every key checked below.
    transfer_xml = {}
else:
    transfer_xml = transfer_xml[0]
transfer_dict = {}
for prefix in ["", "upload_", "download_"]:
for key in [
"multipart_threshold",
"max_concurrency",
"multipart_chunksize",
"num_download_attempts",
"max_io_queue",
"io_chunksize",
"use_threads",
"max_bandwidth",
]:
full_key = f"{prefix}{key}"
value = transfer_xml.get(full_key)
if value is not None:
transfer_dict[full_key] = value

tag, attrs = "extra_dir", ("type", "path")
extra_dirs = config_xml.findall(tag)
if not extra_dirs:
@@ -91,6 +103,7 @@ def parse_config_xml(config_xml):
"endpoint_url": endpoint_url,
"region": region,
},
"transfer": transfer_dict,
"cache": cache_dict,
"extra_dirs": extra_dirs,
"private": CachingConcreteObjectStore.parse_private_from_config_xml(config_xml),
@@ -134,6 +147,26 @@ def __init__(self, config, config_dict):
bucket_dict = config_dict["bucket"]
connection_dict = config_dict.get("connection", {})
cache_dict = config_dict.get("cache") or {}
transfer_dict = config_dict.get("transfer", {})
typed_transfer_dict = {}
for prefix in ["", "upload_", "download_"]:
options: Dict[str, Callable[[Any], Any]] = {
"multipart_threshold": int,
"max_concurrency": int,
"multipart_chunksize": int,
"num_download_attempts": int,
"max_io_queue": int,
"io_chunksize": int,
"use_threads": asbool,
"max_bandwidth": int,
}
for key, key_type in options.items():
full_key = f"{prefix}{key}"
transfer_value = transfer_dict.get(full_key)
if transfer_value is not None:
typed_transfer_dict[full_key] = key_type(transfer_value)
self.transfer_dict = typed_transfer_dict

self.enable_cache_monitor, self.cache_monitor_interval = enable_cache_monitor(config, config_dict)

self.access_key = auth_dict.get("access_key")
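Everything the XML layer hands over is a string, so the `__init__` block above coerces each value through a per-key callable: `int` for sizes and counts, `asbool` for flags. A self-contained sketch of the effect, with `asbool` stubbed to stand in for `galaxy.util.asbool`:

```python
def asbool(value):  # stand-in for galaxy.util.asbool, which parses truthy strings
    return str(value).strip().lower() in ("true", "yes", "on", "1")

raw = {"multipart_threshold": "10", "upload_use_threads": "false"}
typed = {}
for prefix in ("", "upload_", "download_"):
    for key, caster in {"multipart_threshold": int, "use_threads": asbool}.items():
        full_key = f"{prefix}{key}"
        if full_key in raw:
            typed[full_key] = caster(raw[full_key])

assert typed == {"multipart_threshold": 10, "upload_use_threads": False}
```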
@@ -226,6 +259,7 @@ def _config_to_dict(self):
"endpoint_url": self.endpoint_url,
"region": self.region,
},
"transfer": self.transfer_dict,
"cache": {
"size": self.cache_size,
"path": self.staging_path,
@@ -257,7 +291,8 @@ def _download(self, rel_path: str) -> bool:
log.debug("Pulling key '%s' into cache to %s", rel_path, local_destination)
if not self._caching_allowed(rel_path):
return False
self._client.download_file(self.bucket, rel_path, local_destination)
config = self._transfer_config("download")
self._client.download_file(self.bucket, rel_path, local_destination, Config=config)
return True
except ClientError:
log.exception("Failed to download file from S3")
@@ -273,7 +308,8 @@ def _push_string_to_path(self, rel_path: str, from_string: str) -> bool:

def _push_file_to_path(self, rel_path: str, source_file: str) -> bool:
try:
self._client.upload_file(source_file, self.bucket, rel_path)
config = self._transfer_config("upload")
self._client.upload_file(source_file, self.bucket, rel_path, Config=config)
return True
except ClientError:
log.exception("Trouble pushing to S3 '%s' from file '%s'", rel_path, source_file)
@@ -336,5 +372,24 @@ def _get_object_url(self, obj, **kwargs):
def _get_store_usage_percent(self, obj):
return 0.0

def _transfer_config(self, prefix: Literal["upload", "download"]) -> "TransferConfig":
config = {}
for key in [
"multipart_threshold",
"max_concurrency",
"multipart_chunksize",
"num_download_attempts",
"max_io_queue",
"io_chunksize",
"use_threads",
"max_bandwidth",
]:
specific_key = f"{prefix}_key"
if specific_key in self.transfer_dict:
config[key] = self.transfer_dict[specific_key]
elif key in self.transfer_dict:
config[key] = self.transfer_dict[key]
return TransferConfig(**config)

def shutdown(self):
self._shutdown_cache_monitor()
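The other half of the feature is `_transfer_config` above: a direction-specific key (`upload_*` or `download_*`) shadows the unprefixed one, so a single config can, say, throttle downloads while leaving uploads untouched. A standalone sketch mirroring that lookup (a re-implementation for illustration, not an import of the method):

```python
from boto3.s3.transfer import TransferConfig


def transfer_config(transfer_dict: dict, prefix: str) -> TransferConfig:
    """Mirror of _transfer_config: '<prefix>_<key>' wins over plain '<key>'."""
    config = {}
    for key in ("max_concurrency", "multipart_threshold"):  # abridged key list
        specific_key = f"{prefix}_{key}"
        if specific_key in transfer_dict:
            config[key] = transfer_dict[specific_key]
        elif key in transfer_dict:
            config[key] = transfer_dict[key]
    return TransferConfig(**config)


settings = {"max_concurrency": 10, "download_max_concurrency": 2}
assert transfer_config(settings, "upload").max_concurrency == 10   # falls back to the base key
assert transfer_config(settings, "download").max_concurrency == 2  # prefixed key wins
```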
43 changes: 43 additions & 0 deletions test/unit/objectstore/test_objectstore.py
@@ -1322,6 +1322,15 @@ def verify_caching_object_store_functionality(tmp_path, object_store, check_get_
reset_cache(object_store.cache_target)
assert not object_store.exists(to_delete_dataset)

# Test a bigger file to force a multipart upload.
big_file_dataset = MockDataset(6)
size = 1024
path = tmp_path / "big_file.bytes"
with path.open("wb") as f:
    f.write(os.urandom(size))
object_store.update_from_file(big_file_dataset, file_name=path, create=True)

# Test get_object_url returns a read-only URL
url = object_store.get_object_url(hello_world_dataset)
if check_get_url:
@@ -1576,6 +1585,40 @@ def test_real_aws_s3_store_boto3(tmp_path):
verify_caching_object_store_functionality(tmp_path, object_store)


AMAZON_BOTO3_S3_MULTITHREAD_TEMPLATE_TEST_CONFIG_YAML = """
type: boto3
store_by: uuid
auth:
access_key: ${access_key}
secret_key: ${secret_key}
bucket:
name: ${bucket}
transfer:
multipart_threshold: 10
extra_dirs:
- type: job_work
path: database/job_working_directory_azure
- type: temp
path: database/tmp_azure
"""

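The `multipart_threshold: 10` (bytes) in the template is deliberately tiny: it guarantees the 1 KiB file written in `verify_caching_object_store_functionality` crosses the threshold, so the test exercises boto3's multipart machinery rather than a single-request upload. For reference, the `TransferConfig` the template boils down to:

```python
from boto3.s3.transfer import TransferConfig

# Equivalent of the template's transfer block: any object of 10 bytes or
# more takes the multipart upload path.
config = TransferConfig(multipart_threshold=10)
```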

@skip_unless_environ("GALAXY_TEST_AWS_ACCESS_KEY")
@skip_unless_environ("GALAXY_TEST_AWS_SECRET_KEY")
@skip_unless_environ("GALAXY_TEST_AWS_BUCKET")
def test_real_aws_s3_store_boto3_multipart(tmp_path):
template_vars = {
"access_key": os.environ["GALAXY_TEST_AWS_ACCESS_KEY"],
"secret_key": os.environ["GALAXY_TEST_AWS_SECRET_KEY"],
"bucket": os.environ["GALAXY_TEST_AWS_BUCKET"],
}
with TestConfig(AMAZON_BOTO3_S3_MULTITHREAD_TEMPLATE_TEST_CONFIG_YAML, template_vars=template_vars) as (_, object_store):
verify_caching_object_store_functionality(tmp_path, object_store)


@skip_unless_environ("GALAXY_TEST_AWS_ACCESS_KEY")
@skip_unless_environ("GALAXY_TEST_AWS_SECRET_KEY")
def test_real_aws_s3_store_boto3_new_bucket(tmp_path):
