diff --git a/lib/galaxy/app.py b/lib/galaxy/app.py
index 443119607f6c..110a8d9d8d70 100644
--- a/lib/galaxy/app.py
+++ b/lib/galaxy/app.py
@@ -262,6 +262,8 @@ def __init__(self, fsmon=False, **kwargs) -> None:
self._register_singleton(GalaxyModelMapping, self.model)
self._register_singleton(galaxy_scoped_session, self.model.context)
self._register_singleton(install_model_scoped_session, self.install_model.context)
+ # Load quota management.
+ self.quota_agent = self._register_singleton(QuotaAgent, get_quota_agent(self.config, self.model))
def configure_fluent_log(self):
if self.config.fluent_log:
@@ -573,8 +575,6 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl
self.host_security_agent = galaxy.model.security.HostAgent(
model=self.security_agent.model, permitted_actions=self.security_agent.permitted_actions
)
- # Load quota management.
- self.quota_agent = self._register_singleton(QuotaAgent, get_quota_agent(self.config, self.model))
# We need the datatype registry for running certain tasks that modify HDAs, and to build the registry we need
# to setup the installed repositories ... this is not ideal
diff --git a/lib/galaxy/managers/datasets.py b/lib/galaxy/managers/datasets.py
index f97bbfbbbecd..9ef937de8b1f 100644
--- a/lib/galaxy/managers/datasets.py
+++ b/lib/galaxy/managers/datasets.py
@@ -60,6 +60,8 @@ def __init__(self, app: MinimalManagerApp):
self.permissions = DatasetRBACPermissions(app)
# needed for admin test
self.user_manager = users.UserManager(app)
+ self.quota_agent = app.quota_agent
+ self.security_agent = app.model.security_agent
def create(self, manage_roles=None, access_roles=None, flush=True, **kwargs):
"""
@@ -143,6 +145,36 @@ def has_access_permission(self, dataset, user):
roles = user.all_roles_exploiting_cache() if user else []
return self.app.security_agent.can_access_dataset(roles, dataset)
+ def update_object_store_id(self, trans, dataset, object_store_id: str):
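+ # Swapping stores only rewrites the dataset's object_store_id; the file itself is not
+ # copied, so the old and new stores must resolve to the same underlying device.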
+ device_source_map = self.app.object_store.get_device_source_map()
+ old_object_store_id = dataset.object_store_id
+ new_object_store_id = object_store_id
+ if old_object_store_id == new_object_store_id:
+ return None
+ old_device_id = device_source_map.get_device_id(old_object_store_id)
+ new_device_id = device_source_map.get_device_id(new_object_store_id)
+ if old_device_id != new_device_id:
+ raise exceptions.RequestParameterInvalidException(
+ "Cannot swap object store IDs for object stores that don't share a device ID."
+ )
+
+ if not self.security_agent.can_change_object_store_id(trans.user, dataset):
+ # TODO: probably want separate exceptions for when the user doesn't own the dataset
+ # and when the dataset has been shared.
+ raise exceptions.InsufficientPermissionsException("Cannot change the object store of a dataset that has been shared with other users or added to a data library.")
+
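+ # If the stores bill against different quota source labels, shift the dataset's disk
+ # usage from the old label's accounting to the new label's.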
+ quota_source_map = self.app.object_store.get_quota_source_map()
+ if quota_source_map:
+ old_label = quota_source_map.get_quota_source_label(old_object_store_id)
+ new_label = quota_source_map.get_quota_source_label(new_object_store_id)
+ if old_label != new_label:
+ self.quota_agent.relabel_quota_for_dataset(dataset, old_label, new_label)
+ sa_session = self.app.model.context
+ with transaction(sa_session):
+ dataset.object_store_id = new_object_store_id
+ sa_session.add(dataset)
+ sa_session.commit()
+
def compute_hash(self, request: ComputeDatasetHashTaskRequest):
# For files in extra_files_path
dataset = self.by_id(request.dataset_id)
diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py
index 34af7c1d5e81..6e1ebf714be2 100644
--- a/lib/galaxy/model/__init__.py
+++ b/lib/galaxy/model/__init__.py
@@ -4024,6 +4024,16 @@ def quota_source_info(self):
quota_source_map = self.object_store.get_quota_source_map()
return quota_source_map.get_quota_source_info(object_store_id)
+ @property
+ def device_source_label(self):
+ return self.device_source_info
+
+ @property
+ def device_source_info(self):
+ object_store_id = self.object_store_id
+ device_source_map = self.object_store.get_device_source_map()
+ return device_source_map.get_device_id(object_store_id)
+
def set_file_name(self, filename):
if not filename:
self.external_filename = None
diff --git a/lib/galaxy/model/security.py b/lib/galaxy/model/security.py
index dc5ed7f32592..6be99b8320e8 100644
--- a/lib/galaxy/model/security.py
+++ b/lib/galaxy/model/security.py
@@ -15,6 +15,7 @@
select,
)
from sqlalchemy.orm import joinedload
+from sqlalchemy.sql import text
import galaxy.model
from galaxy.model import (
@@ -634,6 +635,24 @@ def can_modify_library_item(self, roles, item):
def can_manage_library_item(self, roles, item):
return self.allow_action(roles, self.permitted_actions.LIBRARY_MANAGE, item)
+ def can_change_object_store_id(self, user, dataset):
+ # Prevent the update if the dataset is shared with anyone other than the current user.
+ # A private object store already guarantees this, but a dataset merely kept private in
+ # a sharable object store should still be allowed to swap.
+ if dataset.library_associations:
+ return False
+ else:
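+ # Count copies of this dataset appearing in histories owned by other users; any such
+ # copy means the dataset has been shared and the swap must be refused.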
+ query = text(
+ """
+SELECT COUNT(*)
+FROM history
+INNER JOIN
+ history_dataset_association on history_dataset_association.history_id = history.id
+WHERE history.user_id != :user_id and history_dataset_association.dataset_id = :dataset_id
+"""
+ ).bindparams(dataset_id=dataset.id, user_id=user.id)
+ return self.sa_session.scalars(query).first() == 0
+
def get_item_actions(self, action, item):
# item must be one of: Dataset, Library, LibraryFolder, LibraryDataset, LibraryDatasetDatasetAssociation
# SM: Accessing item.actions emits a query to Library_Dataset_Permissions
diff --git a/lib/galaxy/objectstore/__init__.py b/lib/galaxy/objectstore/__init__.py
index d799ad674bbc..efa33353560c 100644
--- a/lib/galaxy/objectstore/__init__.py
+++ b/lib/galaxy/objectstore/__init__.py
@@ -63,7 +63,7 @@
DEFAULT_PRIVATE = False
DEFAULT_QUOTA_SOURCE = None # Just track quota right on user object in Galaxy.
DEFAULT_QUOTA_ENABLED = True # enable quota tracking in object stores by default
-
+DEFAULT_DEVICE_ID = None
log = logging.getLogger(__name__)
@@ -329,6 +329,10 @@ def to_dict(self) -> Dict[str, Any]:
def get_quota_source_map(self):
"""Return QuotaSourceMap describing mapping of object store IDs to quota sources."""
+ @abc.abstractmethod
+ def get_device_source_map(self) -> "DeviceSourceMap":
+ """Return DeviceSourceMap describing mapping of object store IDs to device sources."""
+
class BaseObjectStore(ObjectStore):
store_by: str
@@ -491,6 +495,9 @@ def get_quota_source_map(self):
# I'd rather keep this abstract... but register_singleton wants it to be instantiable...
raise NotImplementedError()
+ def get_device_source_map(self):
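+ # Stores that declare no device fall back to a map with the default (None) device ID.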
+ return DeviceSourceMap()
+
class ConcreteObjectStore(BaseObjectStore):
"""Subclass of ObjectStore for stores that don't delegate (non-nested).
@@ -501,6 +508,7 @@ class ConcreteObjectStore(BaseObjectStore):
"""
badges: List[StoredBadgeDict]
+ device_id: Optional[str] = None
def __init__(self, config, config_dict=None, **kwargs):
"""
@@ -528,6 +536,7 @@ def __init__(self, config, config_dict=None, **kwargs):
quota_config = config_dict.get("quota", {})
self.quota_source = quota_config.get("source", DEFAULT_QUOTA_SOURCE)
self.quota_enabled = quota_config.get("enabled", DEFAULT_QUOTA_ENABLED)
+ self.device_id = config_dict.get("device", None)
self.badges = read_badges(config_dict)
def to_dict(self):
@@ -541,6 +550,7 @@ def to_dict(self):
"enabled": self.quota_enabled,
}
rval["badges"] = self._get_concrete_store_badges(None)
+ rval["device"] = self.device_id
return rval
def to_model(self, object_store_id: str) -> "ConcreteObjectStoreModel":
@@ -551,6 +561,7 @@ def to_model(self, object_store_id: str) -> "ConcreteObjectStoreModel":
description=self.description,
quota=QuotaModel(source=self.quota_source, enabled=self.quota_enabled),
badges=self._get_concrete_store_badges(None),
+ device=self.device_id,
)
def _get_concrete_store_badges(self, obj) -> List[BadgeDict]:
@@ -587,6 +598,9 @@ def get_quota_source_map(self):
)
return quota_source_map
+ def get_device_source_map(self) -> "DeviceSourceMap":
+ return DeviceSourceMap(self.device_id)
+
class DiskObjectStore(ConcreteObjectStore):
"""
@@ -637,6 +651,8 @@ def parse_xml(clazz, config_xml):
name = config_xml.attrib.get("name", None)
if name is not None:
config_dict["name"] = name
+ device = config_xml.attrib.get("device", None)
+ config_dict["device"] = device
for e in config_xml:
if e.tag == "quota":
config_dict["quota"] = {
@@ -1036,7 +1052,7 @@ def __init__(self, config, config_dict, fsmon=False):
"""
super().__init__(config, config_dict)
self._quota_source_map = None
-
+ self._device_source_map = None
self.backends = {}
self.weighted_backend_ids = []
self.original_weighted_backend_ids = []
@@ -1208,6 +1224,13 @@ def get_quota_source_map(self):
self._quota_source_map = quota_source_map
return self._quota_source_map
+ def get_device_source_map(self) -> "DeviceSourceMap":
+ if self._device_source_map is None:
+ device_source_map = DeviceSourceMap()
+ self._merge_device_source_map(device_source_map, self)
+ self._device_source_map = device_source_map
+ return self._device_source_map
+
@classmethod
def _merge_quota_source_map(clz, quota_source_map, object_store):
for backend_id, backend in object_store.backends.items():
@@ -1216,6 +1239,14 @@ def _merge_quota_source_map(clz, quota_source_map, object_store):
else:
quota_source_map.backends[backend_id] = backend.get_quota_source_map()
+ @classmethod
+ def _merge_device_source_map(clz, device_source_map: "DeviceSourceMap", object_store):
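+ # Recurse through nested distributed stores so each concrete backend's device mapping
+ # is registered under its backend ID.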
+ for backend_id, backend in object_store.backends.items():
+ if isinstance(backend, DistributedObjectStore):
+ clz._merge_device_source_map(device_source_map, backend)
+ else:
+ device_source_map.backends[backend_id] = backend.get_device_source_map()
+
def __get_store_id_for(self, obj, **kwargs):
if obj.object_store_id is not None:
if obj.object_store_id in self.backends:
@@ -1364,6 +1395,7 @@ class ConcreteObjectStoreModel(BaseModel):
description: Optional[str] = None
quota: QuotaModel
badges: List[BadgeDict]
+ device: Optional[str] = None
def type_to_object_store_class(store: str, fsmon: bool = False) -> Tuple[Type[BaseObjectStore], Dict[str, Any]]:
@@ -1506,6 +1538,21 @@ class QuotaSourceInfo(NamedTuple):
use: bool
+class DeviceSourceMap:
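+ """Map object store IDs to the device ID of the backend that holds their data."""
+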
+ def __init__(self, device_id=DEFAULT_DEVICE_ID):
+ self.default_device_id = device_id
+ self.backends = {}
+
+ def get_device_id(self, object_store_id: str) -> Optional[str]:
+ if object_store_id in self.backends:
+ device_map = self.backends.get(object_store_id)
+ if device_map:
+ return device_map.get_device_id(object_store_id)
+
+ return self.default_device_id
+
+
class QuotaSourceMap:
def __init__(self, source=DEFAULT_QUOTA_SOURCE, enabled=DEFAULT_QUOTA_ENABLED):
self.default_quota_source = source
diff --git a/lib/galaxy/quota/__init__.py b/lib/galaxy/quota/__init__.py
index 88254be3438f..43505c6ed2b5 100644
--- a/lib/galaxy/quota/__init__.py
+++ b/lib/galaxy/quota/__init__.py
@@ -3,6 +3,7 @@
from typing import Optional
from sqlalchemy import select
+from sqlalchemy.orm import object_session
from sqlalchemy.sql import text
import galaxy.util
@@ -25,6 +26,13 @@ class QuotaAgent: # metaclass=abc.ABCMeta
the quota in other apps (LDAP maybe?) or via configuration files.
"""
+ def relabel_quota_for_dataset(self, dataset, from_label: Optional[str], to_label: Optional[str]):
+ """Update the quota source label for dataset and adjust relevant quotas.
+
+ Subtract the dataset's disk usage from each affected user's total for the old label
+ and add it to their total for the new label.
+ """
+
# TODO: make abstractmethod after they work better with mypy
def get_quota(self, user, quota_source_label=None) -> Optional[int]:
"""Return quota in bytes or None if no quota is set."""
@@ -81,6 +89,9 @@ def __init__(self):
def get_quota(self, user, quota_source_label=None) -> Optional[int]:
return None
+ def relabel_quota_for_dataset(self, dataset, from_label: Optional[str], to_label: Optional[str]):
+ return None
+
@property
def default_quota(self):
return None
@@ -173,6 +184,97 @@ def get_quota(self, user, quota_source_label=None) -> Optional[int]:
else:
return None
+ def relabel_quota_for_dataset(self, dataset, from_label: Optional[str], to_label: Optional[str]):
+ adjust = dataset.get_total_size()
+ with_quota_affected_users = """WITH quota_affected_users AS
+(
+ SELECT DISTINCT user_id
+ FROM history
+ INNER JOIN
+ history_dataset_association on history_dataset_association.history_id = history.id
+ INNER JOIN
+ dataset on history_dataset_association.dataset_id = dataset.id
+ WHERE
+ dataset_id = :dataset_id
+)"""
+ engine = object_session(dataset).bind
+
+ # Workaround for SQLite: ON CONFLICT ... DO UPDATE is only available in SQLite >= 3.24.0,
+ # so fall back to INSERT OR REPLACE there.
+ for_sqlite = "sqlite" in engine.dialect.name
+
+ if to_label == from_label:
+ return
+ if to_label is None:
+ to_statement = f"""
+{with_quota_affected_users}
+UPDATE galaxy_user
+SET disk_usage = coalesce(disk_usage, 0) + :adjust
+WHERE id IN (SELECT user_id FROM quota_affected_users)
+"""
+ else:
+ if for_sqlite:
+ to_statement = f"""
+{with_quota_affected_users},
+new_quota_sources (user_id, disk_usage, quota_source_label) AS (
+ SELECT user_id, :adjust as disk_usage, :to_label as quota_source_label
+ FROM quota_affected_users
+)
+INSERT OR REPLACE INTO user_quota_source_usage (id, user_id, quota_source_label, disk_usage)
+SELECT old.id, new.user_id, new.quota_source_label, COALESCE(old.disk_usage + :adjust, :adjust)
+FROM new_quota_sources as new LEFT JOIN user_quota_source_usage AS old ON new.user_id = old.user_id AND NEW.quota_source_label = old.quota_source_label"""
+ else:
+ to_statement = f"""
+{with_quota_affected_users},
+new_quota_sources (user_id, disk_usage, quota_source_label) AS (
+ SELECT user_id, :adjust as disk_usage, :to_label as quota_source_label
+ FROM quota_affected_users
+)
+INSERT INTO user_quota_source_usage(user_id, disk_usage, quota_source_label)
+SELECT * FROM new_quota_sources
+ON CONFLICT
+ ON constraint uqsu_unique_label_per_user
+ DO UPDATE SET disk_usage = user_quota_source_usage.disk_usage + :adjust
+"""
+
+ if from_label is None:
+ from_statement = f"""
+{with_quota_affected_users}
+UPDATE galaxy_user
+SET disk_usage = coalesce(disk_usage - :adjust, 0)
+WHERE id IN (SELECT user_id FROM quota_affected_users)
+"""
+ else:
+ if for_sqlite:
+ from_statement = f"""
+{with_quota_affected_users},
+new_quota_sources (user_id, disk_usage, quota_source_label) AS (
+ SELECT user_id, :adjust as disk_usage, :from_label as quota_source_label
+ FROM quota_affected_users
+)
+INSERT OR REPLACE INTO user_quota_source_usage (id, user_id, quota_source_label, disk_usage)
+SELECT old.id, new.user_id, new.quota_source_label, COALESCE(old.disk_usage - :adjust, 0)
+FROM new_quota_sources as new LEFT JOIN user_quota_source_usage AS old ON new.user_id = old.user_id AND NEW.quota_source_label = old.quota_source_label"""
+ else:
+ from_statement = f"""
+{with_quota_affected_users},
+new_quota_sources (user_id, disk_usage, quota_source_label) AS (
+ SELECT user_id, 0 as disk_usage, :from_label as quota_source_label
+ FROM quota_affected_users
+)
+INSERT INTO user_quota_source_usage(user_id, disk_usage, quota_source_label)
+SELECT * FROM new_quota_sources
+ON CONFLICT
+ ON constraint uqsu_unique_label_per_user
+ DO UPDATE SET disk_usage = user_quota_source_usage.disk_usage - :adjust
+"""
+
+ bind = {"dataset_id": dataset.id, "adjust": int(adjust), "to_label": to_label, "from_label": from_label}
+ engine = self.sa_session.get_bind()
+ with engine.connect() as conn:
+ conn.execute(text(from_statement), bind)
+ conn.execute(text(to_statement), bind)
+ return None
+
def _default_unregistered_quota(self, quota_source_label):
return self._default_quota(self.model.DefaultQuotaAssociation.types.UNREGISTERED, quota_source_label)
diff --git a/lib/galaxy/security/__init__.py b/lib/galaxy/security/__init__.py
index 0867f7cc6cf7..ce023824457c 100644
--- a/lib/galaxy/security/__init__.py
+++ b/lib/galaxy/security/__init__.py
@@ -88,6 +88,9 @@ def can_add_library_item(self, roles, item):
def can_modify_library_item(self, roles, item):
raise Exception("Unimplemented Method")
+ def can_change_object_store_id(self, user, dataset):
+ raise Exception("Unimplemented Method")
+
def can_manage_library_item(self, roles, item):
raise Exception("Unimplemented Method")
diff --git a/lib/galaxy/webapps/galaxy/api/datasets.py b/lib/galaxy/webapps/galaxy/api/datasets.py
index c735ef74ecb6..ad7a6bbf87b2 100644
--- a/lib/galaxy/webapps/galaxy/api/datasets.py
+++ b/lib/galaxy/webapps/galaxy/api/datasets.py
@@ -68,6 +68,7 @@
DeleteDatasetBatchPayload,
DeleteDatasetBatchResult,
RequestDataType,
+ UpdateObjectStoreIdPayload,
)
log = logging.getLogger(__name__)
@@ -484,3 +485,17 @@ def compute_hash(
payload: ComputeDatasetHashPayload = Body(...),
) -> AsyncTaskResultSummary:
return self.service.compute_hash(trans, dataset_id, payload, hda_ldda=hda_ldda)
+
+ @router.put(
+ "/api/datasets/{dataset_id}/object_store_id",
+ summary="Update an object store ID for a dataset you own.",
+ operation_id="datasets__update_object_store_id",
+ )
+ def update_object_store_id(
+ self,
+ dataset_id: HistoryDatasetIDPathParam,
+ trans=DependsOnTrans,
+ payload: UpdateObjectStoreIdPayload = Body(...),
+ ) -> bool:
+ self.service.update_object_store_id(trans, dataset_id, payload)
+ return True
diff --git a/lib/galaxy/webapps/galaxy/services/datasets.py b/lib/galaxy/webapps/galaxy/services/datasets.py
index f8208b088c96..ea8a0b6633d7 100644
--- a/lib/galaxy/webapps/galaxy/services/datasets.py
+++ b/lib/galaxy/webapps/galaxy/services/datasets.py
@@ -31,7 +31,10 @@
from galaxy.datatypes.dataproviders.exceptions import NoProviderAvailable
from galaxy.managers.base import ModelSerializer
from galaxy.managers.context import ProvidesHistoryContext
-from galaxy.managers.datasets import DatasetAssociationManager
+from galaxy.managers.datasets import (
+ DatasetAssociationManager,
+ DatasetManager,
+)
from galaxy.managers.hdas import (
HDAManager,
HDASerializer,
@@ -247,6 +250,13 @@ class ComputeDatasetHashPayload(Model):
model_config = ConfigDict(use_enum_values=True)
+class UpdateObjectStoreIdPayload(Model):
+ object_store_id: str = Field(
+ ...,
+ description="Object store ID to update to, it must be an object store with the same device ID as the target dataset currently.",
+ )
+
+
class DatasetErrorMessage(Model):
dataset: EncodedDatasetSourceId = Field(
description="The encoded ID of the dataset and its source.",
@@ -281,6 +291,7 @@ def __init__(
history_contents_manager: HistoryContentsManager,
history_contents_filters: HistoryContentsFilters,
data_provider_registry: DataProviderRegistry,
+ dataset_manager: DatasetManager,
):
super().__init__(security)
self.history_manager = history_manager
@@ -291,6 +302,7 @@ def __init__(
self.history_contents_manager = history_contents_manager
self.history_contents_filters = history_contents_filters
self.data_provider_registry = data_provider_registry
+ self.dataset_manager = dataset_manager
@property
def serializer_by_type(self) -> Dict[str, ModelSerializer]:
@@ -747,6 +759,11 @@ def get_structured_content(
raise galaxy_exceptions.InternalServerError(f"Could not get content for dataset: {util.unicodify(e)}")
return content, headers
+ def update_object_store_id(self, trans, dataset_id: DecodedDatabaseIdField, payload: UpdateObjectStoreIdPayload):
+ hda = self.hda_manager.get_accessible(dataset_id, trans.user)
+ dataset = hda.dataset
+ self.dataset_manager.update_object_store_id(trans, dataset, payload.object_store_id)
+
def _get_or_create_converted(self, trans, original: model.DatasetInstance, target_ext: str):
try:
original.get_converted_dataset(trans, target_ext)
diff --git a/lib/galaxy_test/base/populators.py b/lib/galaxy_test/base/populators.py
index 82490889e1f1..198a614f0891 100644
--- a/lib/galaxy_test/base/populators.py
+++ b/lib/galaxy_test/base/populators.py
@@ -1181,6 +1181,13 @@ def total_disk_usage(self) -> float:
assert "total_disk_usage" in user_object
return user_object["total_disk_usage"]
+ def update_object_store_id(self, dataset_id: str, object_store_id: str):
+ payload = {"object_store_id": object_store_id}
+ url = f"datasets/{dataset_id}/object_store_id"
+ update_response = self._put(url, payload, json=True)
+ update_response.raise_for_status()
+ return update_response
+
def create_role(self, user_ids: list, description: Optional[str] = None) -> dict:
using_requirement("admin")
payload = {
diff --git a/test/integration/objectstore/test_changing_objectstore.py b/test/integration/objectstore/test_changing_objectstore.py
new file mode 100644
index 000000000000..4cf9e53102c2
--- /dev/null
+++ b/test/integration/objectstore/test_changing_objectstore.py
@@ -0,0 +1,98 @@
+"""Integration tests for changing object stores."""
+
+import string
+
+from galaxy_test.base import api_asserts
+from ._base import BaseObjectStoreIntegrationTestCase
+
+DISTRIBUTED_OBJECT_STORE_CONFIG_TEMPLATE = string.Template(
+ """
+<object_store type="distributed">
+    <backends>
+        <backend id="default" type="disk" weight="1" device="primary_disk">
+            <files_dir path="${temp_directory}/files_default"/>
+            <extra_dir type="temp" path="${temp_directory}/tmp_default"/>
+            <extra_dir type="job_work" path="${temp_directory}/job_working_directory_default"/>
+        </backend>
+        <backend id="temp_short" type="disk" weight="0" device="primary_disk">
+            <quota source="shorter_term"/>
+            <files_dir path="${temp_directory}/files_short"/>
+            <extra_dir type="temp" path="${temp_directory}/tmp_short"/>
+            <extra_dir type="job_work" path="${temp_directory}/job_working_directory_short"/>
+        </backend>
+        <backend id="temp_long" type="disk" weight="0" device="primary_disk">
+            <quota source="longer_term"/>
+            <files_dir path="${temp_directory}/files_long"/>
+            <extra_dir type="temp" path="${temp_directory}/tmp_long"/>
+            <extra_dir type="job_work" path="${temp_directory}/job_working_directory_long"/>
+        </backend>
+    </backends>
+</object_store>
+"""
+)
+
+TEST_INPUT_FILES_CONTENT = "1 2 3"
+
+
+class TestChangingStoreObjectStoreIntegration(BaseObjectStoreIntegrationTestCase):
+ @classmethod
+ def handle_galaxy_config_kwds(cls, config):
+ config["new_user_dataset_access_role_default_private"] = True
+ cls._configure_object_store(DISTRIBUTED_OBJECT_STORE_CONFIG_TEMPLATE, config)
+ config["enable_quotas"] = True
+
+ def test_valid_in_device_swap(self):
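+ # The dataset is created in the default (unlabeled) backend; each swap below should
+ # shift its 6 bytes of usage between the unlabeled, shorter_term, and longer_term sources.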
+ with self.dataset_populator.test_history() as history_id:
+ hda = self.dataset_populator.new_dataset(history_id, content=TEST_INPUT_FILES_CONTENT, wait=True)
+
+ payload = {
+ "name": "quota_longer",
+ "description": "quota_longer desc",
+ "amount": "10 MB",
+ "operation": "=",
+ "default": "registered",
+ "quota_source_label": "longer_term",
+ }
+ self.dataset_populator.create_quota(payload)
+
+ payload = {
+ "name": "quota_shorter",
+ "description": "quota_shorter desc",
+ "amount": "20 MB",
+ "operation": "=",
+ "default": "registered",
+ "quota_source_label": "shorter_term",
+ }
+ self.dataset_populator.create_quota(payload)
+
+ quotas = self.dataset_populator.get_quotas()
+ assert len(quotas) == 2
+
+ usage = self.dataset_populator.get_usage_for("longer_term")
+ assert usage["total_disk_usage"] == 0
+ usage = self.dataset_populator.get_usage_for("shorter_term")
+ assert usage["total_disk_usage"] == 0
+ usage = self.dataset_populator.get_usage_for(None)
+ assert int(usage["total_disk_usage"]) == 6
+
+ self.dataset_populator.update_object_store_id(hda["id"], "temp_short")
+ usage = self.dataset_populator.get_usage_for("shorter_term")
+ assert int(usage["total_disk_usage"]) == 6
+ usage = self.dataset_populator.get_usage_for(None)
+ assert int(usage["total_disk_usage"]) == 0
+
+ self.dataset_populator.update_object_store_id(hda["id"], "temp_long")
+ usage = self.dataset_populator.get_usage_for("shorter_term")
+ assert int(usage["total_disk_usage"]) == 0
+ usage = self.dataset_populator.get_usage_for("longer_term")
+ assert int(usage["total_disk_usage"]) == 6
+ usage = self.dataset_populator.get_usage_for(None)
+ assert int(usage["total_disk_usage"]) == 0
+
+ self.dataset_populator.update_object_store_id(hda["id"], "temp_short")
+ usage = self.dataset_populator.get_usage_for("shorter_term")
+ assert int(usage["total_disk_usage"]) == 6
+ usage = self.dataset_populator.get_usage_for("longer_term")
+ assert int(usage["total_disk_usage"]) == 0
+ usage = self.dataset_populator.get_usage_for(None)
+ assert int(usage["total_disk_usage"]) == 0
diff --git a/test/integration/objectstore/test_private_handling.py b/test/integration/objectstore/test_private_handling.py
index f89d23d72cea..66ec63f0671d 100644
--- a/test/integration/objectstore/test_private_handling.py
+++ b/test/integration/objectstore/test_private_handling.py
@@ -1,4 +1,4 @@
-"""Integration tests for mixing store_by."""
+"""Integration tests for private object store handling."""
import string
diff --git a/test/unit/data/test_quota.py b/test/unit/data/test_quota.py
index 9d9c8573650c..5aea03dac771 100644
--- a/test/unit/data/test_quota.py
+++ b/test/unit/data/test_quota.py
@@ -284,10 +284,58 @@ def test_calculate_usage_default_storage_disabled(self):
assert usages[1].quota_source_label == "alt_source"
assert usages[1].total_disk_usage == 15
- def _refresh_user_and_assert_disk_usage_is(self, usage):
+ def test_update_usage_from_labeled_to_unlabeled(self):
+ model = self.model
+ quota_agent = DatabaseQuotaAgent(model)
+ u = self.u
+
+ self._add_dataset(10)
+ alt_d = self._add_dataset(15, "alt_source_store")
+ self.model.session.flush()
+ assert quota_agent
+
+ quota_source_map = QuotaSourceMap(None, True)
+ alt_source = QuotaSourceMap("alt_source", True)
+ quota_source_map.backends["alt_source_store"] = alt_source
+
+ object_store = MockObjectStore(quota_source_map)
+ u.calculate_and_set_disk_usage(object_store)
+ self._refresh_user_and_assert_disk_usage_is(10)
+ quota_agent.relabel_quota_for_dataset(alt_d.dataset, "alt_source", None)
+ self._refresh_user_and_assert_disk_usage_is(25)
+ self._refresh_user_and_assert_disk_usage_is(0, "alt_source")
+
+ def test_update_usage_from_unlabeled_to_labeled(self):
+ model = self.model
+ quota_agent = DatabaseQuotaAgent(model)
+ u = self.u
+
+ d = self._add_dataset(10)
+ self._add_dataset(15, "alt_source_store")
+ self.model.session.flush()
+ assert quota_agent
+
+ quota_source_map = QuotaSourceMap(None, True)
+ alt_source = QuotaSourceMap("alt_source", True)
+ quota_source_map.backends["alt_source_store"] = alt_source
+
+ object_store = MockObjectStore(quota_source_map)
+ u.calculate_and_set_disk_usage(object_store)
+ self._refresh_user_and_assert_disk_usage_is(15, "alt_source")
+ quota_agent.relabel_quota_for_dataset(d.dataset, None, "alt_source")
+ self._refresh_user_and_assert_disk_usage_is(25, "alt_source")
+ self._refresh_user_and_assert_disk_usage_is(0, None)
+
+ def _refresh_user_and_assert_disk_usage_is(self, usage, label=None):
u = self.u
self.model.context.refresh(u)
- assert u.disk_usage == usage
+ if label is None:
+ assert u.disk_usage == usage
+ else:
+ for source_usage in u.dictify_usage():
+ if source_usage.quota_source_label == label:
+ assert int(source_usage.total_disk_usage) == int(usage)
class TestQuota(BaseModelTestCase):
diff --git a/test/unit/objectstore/test_objectstore.py b/test/unit/objectstore/test_objectstore.py
index ae922e986ff4..4956ac970834 100644
--- a/test/unit/objectstore/test_objectstore.py
+++ b/test/unit/objectstore/test_objectstore.py
@@ -524,13 +524,13 @@ def test_badges_parsing_conflicts():
DISTRIBUTED_TEST_CONFIG = """
-        <backend id="files1" type="disk" weight="2">
+        <backend id="files1" type="disk" weight="2" device="primary_disk">
-        <backend id="files2" type="disk" weight="1">
+        <backend id="files2" type="disk" weight="1" device="primary_disk">
@@ -549,6 +549,7 @@ def test_badges_parsing_conflicts():
source: 1files
type: disk
weight: 2
+ device: primary_disk
files_dir: "${temp_directory}/files1"
extra_dirs:
- type: temp
@@ -560,6 +561,7 @@ def test_badges_parsing_conflicts():
source: 2files
type: disk
weight: 1
+ device: primary_disk
files_dir: "${temp_directory}/files2"
extra_dirs:
- type: temp
@@ -598,6 +600,12 @@ def test_distributed_store():
extra_dirs = as_dict["extra_dirs"]
assert len(extra_dirs) == 2
+ device_source_map = object_store.get_device_source_map()
+ assert device_source_map
+ assert device_source_map.get_device_id("files1") == "primary_disk"
+ assert device_source_map.get_device_id("files2") == "primary_disk"
+
def test_distributed_store_empty_cache_targets():
for config_str in [DISTRIBUTED_TEST_CONFIG, DISTRIBUTED_TEST_CONFIG_YAML]: