diff --git a/lib/galaxy/app.py b/lib/galaxy/app.py index 443119607f6c..110a8d9d8d70 100644 --- a/lib/galaxy/app.py +++ b/lib/galaxy/app.py @@ -262,6 +262,8 @@ def __init__(self, fsmon=False, **kwargs) -> None: self._register_singleton(GalaxyModelMapping, self.model) self._register_singleton(galaxy_scoped_session, self.model.context) self._register_singleton(install_model_scoped_session, self.install_model.context) + # Load quota management. + self.quota_agent = self._register_singleton(QuotaAgent, get_quota_agent(self.config, self.model)) def configure_fluent_log(self): if self.config.fluent_log: @@ -573,8 +575,6 @@ def __init__(self, configure_logging=True, use_converters=True, use_display_appl self.host_security_agent = galaxy.model.security.HostAgent( model=self.security_agent.model, permitted_actions=self.security_agent.permitted_actions ) - # Load quota management. - self.quota_agent = self._register_singleton(QuotaAgent, get_quota_agent(self.config, self.model)) # We need the datatype registry for running certain tasks that modify HDAs, and to build the registry we need # to setup the installed repositories ... this is not ideal diff --git a/lib/galaxy/managers/datasets.py b/lib/galaxy/managers/datasets.py index f97bbfbbbecd..9ef937de8b1f 100644 --- a/lib/galaxy/managers/datasets.py +++ b/lib/galaxy/managers/datasets.py @@ -60,6 +60,8 @@ def __init__(self, app: MinimalManagerApp): self.permissions = DatasetRBACPermissions(app) # needed for admin test self.user_manager = users.UserManager(app) + self.quota_agent = app.quota_agent + self.security_agent = app.model.security_agent def create(self, manage_roles=None, access_roles=None, flush=True, **kwargs): """ @@ -143,6 +145,36 @@ def has_access_permission(self, dataset, user): roles = user.all_roles_exploiting_cache() if user else [] return self.app.security_agent.can_access_dataset(roles, dataset) + def update_object_store_id(self, trans, dataset, object_store_id: str): + device_source_map = self.app.object_store.get_device_source_map() + old_object_store_id = dataset.object_store_id + new_object_store_id = object_store_id + if old_object_store_id == new_object_store_id: + return None + old_device_id = device_source_map.get_device_id(old_object_store_id) + new_device_id = device_source_map.get_device_id(new_object_store_id) + if old_device_id != new_device_id: + raise exceptions.RequestParameterInvalidException( + "Cannot swap object store IDs for object stores that don't share a device ID." + ) + + if not self.security_agent.can_change_object_store_id(trans.user, dataset): + # TODO: probably want separate exceptions for doesn't own the dataset and dataset + # has been shared. + raise exceptions.InsufficientPermissionsException("Cannot change dataset permissions...") + + quota_source_map = self.app.object_store.get_quota_source_map() + if quota_source_map: + old_label = quota_source_map.get_quota_source_label(old_object_store_id) + new_label = quota_source_map.get_quota_source_label(new_object_store_id) + if old_label != new_label: + self.quota_agent.relabel_quota_for_dataset(dataset, old_label, new_label) + sa_session = self.app.model.context + with transaction(sa_session): + dataset.object_store_id = new_object_store_id + sa_session.add(dataset) + sa_session.commit() + def compute_hash(self, request: ComputeDatasetHashTaskRequest): # For files in extra_files_path dataset = self.by_id(request.dataset_id) diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index 34af7c1d5e81..6e1ebf714be2 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -4024,6 +4024,16 @@ def quota_source_info(self): quota_source_map = self.object_store.get_quota_source_map() return quota_source_map.get_quota_source_info(object_store_id) + @property + def device_source_label(self): + return self.device_source_info.label + + @property + def device_source_info(self): + object_store_id = self.object_store_id + device_source_map = self.object_store.get_quota_source_map() + return device_source_map.get_device_source_info(object_store_id) + def set_file_name(self, filename): if not filename: self.external_filename = None diff --git a/lib/galaxy/model/security.py b/lib/galaxy/model/security.py index dc5ed7f32592..6be99b8320e8 100644 --- a/lib/galaxy/model/security.py +++ b/lib/galaxy/model/security.py @@ -15,6 +15,7 @@ select, ) from sqlalchemy.orm import joinedload +from sqlalchemy.sql import text import galaxy.model from galaxy.model import ( @@ -634,6 +635,24 @@ def can_modify_library_item(self, roles, item): def can_manage_library_item(self, roles, item): return self.allow_action(roles, self.permitted_actions.LIBRARY_MANAGE, item) + def can_change_object_store_id(self, user, dataset): + # prevent update if dataset shared with anyone but the current user + # private object stores would prevent this but if something has been + # kept private in a sharable object store still allow the swap + if dataset.library_associations: + return False + else: + query = text( + """ +SELECT COUNT(*) +FROM history +INNER JOIN + history_dataset_association on history_dataset_association.history_id = history.id +WHERE history.user_id != :user_id and history_dataset_association.dataset_id = :dataset_id +""" + ).bindparams(dataset_id=dataset.id, user_id=user.id) + return self.sa_session.scalars(query).first() == 0 + def get_item_actions(self, action, item): # item must be one of: Dataset, Library, LibraryFolder, LibraryDataset, LibraryDatasetDatasetAssociation # SM: Accessing item.actions emits a query to Library_Dataset_Permissions diff --git a/lib/galaxy/objectstore/__init__.py b/lib/galaxy/objectstore/__init__.py index d799ad674bbc..efa33353560c 100644 --- a/lib/galaxy/objectstore/__init__.py +++ b/lib/galaxy/objectstore/__init__.py @@ -63,7 +63,7 @@ DEFAULT_PRIVATE = False DEFAULT_QUOTA_SOURCE = None # Just track quota right on user object in Galaxy. DEFAULT_QUOTA_ENABLED = True # enable quota tracking in object stores by default - +DEFAULT_DEVICE_ID = None log = logging.getLogger(__name__) @@ -329,6 +329,10 @@ def to_dict(self) -> Dict[str, Any]: def get_quota_source_map(self): """Return QuotaSourceMap describing mapping of object store IDs to quota sources.""" + @abc.abstractmethod + def get_device_source_map(self) -> "DeviceSourceMap": + """Return DeviceSourceMap describing mapping of object store IDs to device sources.""" + class BaseObjectStore(ObjectStore): store_by: str @@ -491,6 +495,9 @@ def get_quota_source_map(self): # I'd rather keep this abstract... but register_singleton wants it to be instantiable... raise NotImplementedError() + def get_device_source_map(self): + return DeviceSourceMap() + class ConcreteObjectStore(BaseObjectStore): """Subclass of ObjectStore for stores that don't delegate (non-nested). @@ -501,6 +508,7 @@ class ConcreteObjectStore(BaseObjectStore): """ badges: List[StoredBadgeDict] + device_id: Optional[str] = None def __init__(self, config, config_dict=None, **kwargs): """ @@ -528,6 +536,7 @@ def __init__(self, config, config_dict=None, **kwargs): quota_config = config_dict.get("quota", {}) self.quota_source = quota_config.get("source", DEFAULT_QUOTA_SOURCE) self.quota_enabled = quota_config.get("enabled", DEFAULT_QUOTA_ENABLED) + self.device_id = config_dict.get("device", None) self.badges = read_badges(config_dict) def to_dict(self): @@ -541,6 +550,7 @@ def to_dict(self): "enabled": self.quota_enabled, } rval["badges"] = self._get_concrete_store_badges(None) + rval["device"] = self.device_id return rval def to_model(self, object_store_id: str) -> "ConcreteObjectStoreModel": @@ -551,6 +561,7 @@ def to_model(self, object_store_id: str) -> "ConcreteObjectStoreModel": description=self.description, quota=QuotaModel(source=self.quota_source, enabled=self.quota_enabled), badges=self._get_concrete_store_badges(None), + device=self.device_id, ) def _get_concrete_store_badges(self, obj) -> List[BadgeDict]: @@ -587,6 +598,9 @@ def get_quota_source_map(self): ) return quota_source_map + def get_device_source_map(self) -> "DeviceSourceMap": + return DeviceSourceMap(self.device_id) + class DiskObjectStore(ConcreteObjectStore): """ @@ -637,6 +651,8 @@ def parse_xml(clazz, config_xml): name = config_xml.attrib.get("name", None) if name is not None: config_dict["name"] = name + device = config_xml.attrib.get("device", None) + config_dict["device"] = device for e in config_xml: if e.tag == "quota": config_dict["quota"] = { @@ -1036,7 +1052,7 @@ def __init__(self, config, config_dict, fsmon=False): """ super().__init__(config, config_dict) self._quota_source_map = None - + self._device_source_map = None self.backends = {} self.weighted_backend_ids = [] self.original_weighted_backend_ids = [] @@ -1208,6 +1224,13 @@ def get_quota_source_map(self): self._quota_source_map = quota_source_map return self._quota_source_map + def get_device_source_map(self) -> "DeviceSourceMap": + if self._device_source_map is None: + device_source_map = DeviceSourceMap() + self._merge_device_source_map(device_source_map, self) + self._device_source_map = device_source_map + return self._device_source_map + @classmethod def _merge_quota_source_map(clz, quota_source_map, object_store): for backend_id, backend in object_store.backends.items(): @@ -1216,6 +1239,14 @@ def _merge_quota_source_map(clz, quota_source_map, object_store): else: quota_source_map.backends[backend_id] = backend.get_quota_source_map() + @classmethod + def _merge_device_source_map(clz, device_source_map: "DeviceSourceMap", object_store): + for backend_id, backend in object_store.backends.items(): + if isinstance(backend, DistributedObjectStore): + clz._merge_device_source_map(device_source_map, backend) + else: + device_source_map.backends[backend_id] = backend.get_device_source_map() + def __get_store_id_for(self, obj, **kwargs): if obj.object_store_id is not None: if obj.object_store_id in self.backends: @@ -1364,6 +1395,7 @@ class ConcreteObjectStoreModel(BaseModel): description: Optional[str] = None quota: QuotaModel badges: List[BadgeDict] + device: Optional[str] = None def type_to_object_store_class(store: str, fsmon: bool = False) -> Tuple[Type[BaseObjectStore], Dict[str, Any]]: @@ -1506,6 +1538,21 @@ class QuotaSourceInfo(NamedTuple): use: bool +class DeviceSourceMap: + def __init__(self, device_id=DEFAULT_DEVICE_ID): + self.default_device_id = device_id + self.backends = {} + + def get_device_id(self, object_store_id: str) -> Optional[str]: + if object_store_id in self.backends: + device_map = self.backends.get(object_store_id) + if device_map: + print(device_map) + return device_map.get_device_id(object_store_id) + + return self.default_device_id + + class QuotaSourceMap: def __init__(self, source=DEFAULT_QUOTA_SOURCE, enabled=DEFAULT_QUOTA_ENABLED): self.default_quota_source = source diff --git a/lib/galaxy/quota/__init__.py b/lib/galaxy/quota/__init__.py index 88254be3438f..43505c6ed2b5 100644 --- a/lib/galaxy/quota/__init__.py +++ b/lib/galaxy/quota/__init__.py @@ -3,6 +3,7 @@ from typing import Optional from sqlalchemy import select +from sqlalchemy.orm import object_session from sqlalchemy.sql import text import galaxy.util @@ -25,6 +26,13 @@ class QuotaAgent: # metaclass=abc.ABCMeta the quota in other apps (LDAP maybe?) or via configuration files. """ + def relabel_quota_for_dataset(self, dataset, from_label: Optional[str], to_label: Optional[str]): + """Update the quota source label for dataset and adjust relevant quotas. + + Subtract quota for labels from users using old label and quota for new label + for these users. + """ + # TODO: make abstractmethod after they work better with mypy def get_quota(self, user, quota_source_label=None) -> Optional[int]: """Return quota in bytes or None if no quota is set.""" @@ -81,6 +89,9 @@ def __init__(self): def get_quota(self, user, quota_source_label=None) -> Optional[int]: return None + def relabel_quota_for_dataset(self, dataset, from_label: Optional[str], to_label: Optional[str]): + return None + @property def default_quota(self): return None @@ -173,6 +184,97 @@ def get_quota(self, user, quota_source_label=None) -> Optional[int]: else: return None + def relabel_quota_for_dataset(self, dataset, from_label: Optional[str], to_label: Optional[str]): + adjust = dataset.get_total_size() + with_quota_affected_users = """WITH quota_affected_users AS +( + SELECT DISTINCT user_id + FROM history + INNER JOIN + history_dataset_association on history_dataset_association.history_id = history.id + INNER JOIN + dataset on history_dataset_association.dataset_id = dataset.id + WHERE + dataset_id = :dataset_id +)""" + engine = object_session(dataset).bind + + # Hack for older sqlite, would work on newer sqlite - 3.24.0 + for_sqlite = "sqlite" in engine.dialect.name + + if to_label == from_label: + return + if to_label is None: + to_statement = f""" +{with_quota_affected_users} +UPDATE galaxy_user +SET disk_usage = coalesce(disk_usage, 0) + :adjust +WHERE id in quota_affected_users +""" + else: + if for_sqlite: + to_statement = f""" +{with_quota_affected_users}, +new_quota_sources (user_id, disk_usage, quota_source_label) AS ( + SELECT user_id, :adjust as disk_usage, :to_label as quota_source_label + FROM quota_affected_users +) +INSERT OR REPLACE INTO user_quota_source_usage (id, user_id, quota_source_label, disk_usage) +SELECT old.id, new.user_id, new.quota_source_label, COALESCE(old.disk_usage + :adjust, :adjust) +FROM new_quota_sources as new LEFT JOIN user_quota_source_usage AS old ON new.user_id = old.user_id AND NEW.quota_source_label = old.quota_source_label""" + else: + to_statement = f""" +{with_quota_affected_users}, +new_quota_sources (user_id, disk_usage, quota_source_label) AS ( + SELECT user_id, :adjust as disk_usage, :to_label as quota_source_label + FROM quota_affected_users +) +INSERT INTO user_quota_source_usage(user_id, disk_usage, quota_source_label) +SELECT * FROM new_quota_sources +ON CONFLICT + ON constraint uqsu_unique_label_per_user + DO UPDATE SET disk_usage = user_quota_source_usage.disk_usage + :adjust +""" + + if from_label is None: + from_statement = f""" +{with_quota_affected_users} +UPDATE galaxy_user +SET disk_usage = coalesce(disk_usage - :adjust, 0) +WHERE id in quota_affected_users +""" + else: + if for_sqlite: + from_statement = f""" +{with_quota_affected_users}, +new_quota_sources (user_id, disk_usage, quota_source_label) AS ( + SELECT user_id, :adjust as disk_usage, :from_label as quota_source_label + FROM quota_affected_users +) +INSERT OR REPLACE INTO user_quota_source_usage (id, user_id, quota_source_label, disk_usage) +SELECT old.id, new.user_id, new.quota_source_label, COALESCE(old.disk_usage - :adjust, 0) +FROM new_quota_sources as new LEFT JOIN user_quota_source_usage AS old ON new.user_id = old.user_id AND NEW.quota_source_label = old.quota_source_label""" + else: + from_statement = f""" +{with_quota_affected_users}, +new_quota_sources (user_id, disk_usage, quota_source_label) AS ( + SELECT user_id, 0 as disk_usage, :from_label as quota_source_label + FROM quota_affected_users +) +INSERT INTO user_quota_source_usage(user_id, disk_usage, quota_source_label) +SELECT * FROM new_quota_sources +ON CONFLICT + ON constraint uqsu_unique_label_per_user + DO UPDATE SET disk_usage = user_quota_source_usage.disk_usage - :adjust +""" + + bind = {"dataset_id": dataset.id, "adjust": int(adjust), "to_label": to_label, "from_label": from_label} + engine = self.sa_session.get_bind() + with engine.connect() as conn: + conn.execute(text(from_statement), bind) + conn.execute(text(to_statement), bind) + return None + def _default_unregistered_quota(self, quota_source_label): return self._default_quota(self.model.DefaultQuotaAssociation.types.UNREGISTERED, quota_source_label) diff --git a/lib/galaxy/security/__init__.py b/lib/galaxy/security/__init__.py index 0867f7cc6cf7..ce023824457c 100644 --- a/lib/galaxy/security/__init__.py +++ b/lib/galaxy/security/__init__.py @@ -88,6 +88,9 @@ def can_add_library_item(self, roles, item): def can_modify_library_item(self, roles, item): raise Exception("Unimplemented Method") + def can_change_object_store_id(self, user, dataset): + raise Exception("Unimplemented Method") + def can_manage_library_item(self, roles, item): raise Exception("Unimplemented Method") diff --git a/lib/galaxy/webapps/galaxy/api/datasets.py b/lib/galaxy/webapps/galaxy/api/datasets.py index c735ef74ecb6..ad7a6bbf87b2 100644 --- a/lib/galaxy/webapps/galaxy/api/datasets.py +++ b/lib/galaxy/webapps/galaxy/api/datasets.py @@ -68,6 +68,7 @@ DeleteDatasetBatchPayload, DeleteDatasetBatchResult, RequestDataType, + UpdateObjectStoreIdPayload, ) log = logging.getLogger(__name__) @@ -484,3 +485,17 @@ def compute_hash( payload: ComputeDatasetHashPayload = Body(...), ) -> AsyncTaskResultSummary: return self.service.compute_hash(trans, dataset_id, payload, hda_ldda=hda_ldda) + + @router.put( + "/api/datasets/{dataset_id}/object_store_id", + summary="Update an object store ID for a dataset you own.", + operation_id="datasets__update_object_store_id", + ) + def update_object_store_id( + self, + dataset_id: HistoryDatasetIDPathParam, + trans=DependsOnTrans, + payload: UpdateObjectStoreIdPayload = Body(...), + ) -> bool: + self.service.update_object_store_id(trans, dataset_id, payload) + return True diff --git a/lib/galaxy/webapps/galaxy/services/datasets.py b/lib/galaxy/webapps/galaxy/services/datasets.py index f8208b088c96..ea8a0b6633d7 100644 --- a/lib/galaxy/webapps/galaxy/services/datasets.py +++ b/lib/galaxy/webapps/galaxy/services/datasets.py @@ -31,7 +31,10 @@ from galaxy.datatypes.dataproviders.exceptions import NoProviderAvailable from galaxy.managers.base import ModelSerializer from galaxy.managers.context import ProvidesHistoryContext -from galaxy.managers.datasets import DatasetAssociationManager +from galaxy.managers.datasets import ( + DatasetAssociationManager, + DatasetManager, +) from galaxy.managers.hdas import ( HDAManager, HDASerializer, @@ -247,6 +250,13 @@ class ComputeDatasetHashPayload(Model): model_config = ConfigDict(use_enum_values=True) +class UpdateObjectStoreIdPayload(Model): + object_store_id: str = Field( + ..., + description="Object store ID to update to, it must be an object store with the same device ID as the target dataset currently.", + ) + + class DatasetErrorMessage(Model): dataset: EncodedDatasetSourceId = Field( description="The encoded ID of the dataset and its source.", @@ -281,6 +291,7 @@ def __init__( history_contents_manager: HistoryContentsManager, history_contents_filters: HistoryContentsFilters, data_provider_registry: DataProviderRegistry, + dataset_manager: DatasetManager, ): super().__init__(security) self.history_manager = history_manager @@ -291,6 +302,7 @@ def __init__( self.history_contents_manager = history_contents_manager self.history_contents_filters = history_contents_filters self.data_provider_registry = data_provider_registry + self.dataset_manager = dataset_manager @property def serializer_by_type(self) -> Dict[str, ModelSerializer]: @@ -747,6 +759,11 @@ def get_structured_content( raise galaxy_exceptions.InternalServerError(f"Could not get content for dataset: {util.unicodify(e)}") return content, headers + def update_object_store_id(self, trans, dataset_id: DecodedDatabaseIdField, payload: UpdateObjectStoreIdPayload): + hda = self.hda_manager.get_accessible(dataset_id, trans.user) + dataset = hda.dataset + self.dataset_manager.update_object_store_id(trans, dataset, payload.object_store_id) + def _get_or_create_converted(self, trans, original: model.DatasetInstance, target_ext: str): try: original.get_converted_dataset(trans, target_ext) diff --git a/lib/galaxy_test/base/populators.py b/lib/galaxy_test/base/populators.py index 82490889e1f1..198a614f0891 100644 --- a/lib/galaxy_test/base/populators.py +++ b/lib/galaxy_test/base/populators.py @@ -1181,6 +1181,13 @@ def total_disk_usage(self) -> float: assert "total_disk_usage" in user_object return user_object["total_disk_usage"] + def update_object_store_id(self, dataset_id: str, object_store_id: str): + payload = {"object_store_id": object_store_id} + url = f"datasets/{dataset_id}/object_store_id" + update_response = self._put(url, payload, json=True) + update_response.raise_for_status() + return update_response + def create_role(self, user_ids: list, description: Optional[str] = None) -> dict: using_requirement("admin") payload = { diff --git a/test/integration/objectstore/test_changing_objectstore.py b/test/integration/objectstore/test_changing_objectstore.py new file mode 100644 index 000000000000..4cf9e53102c2 --- /dev/null +++ b/test/integration/objectstore/test_changing_objectstore.py @@ -0,0 +1,98 @@ +"""Integration tests for changing object stores.""" + +import string + +from galaxy_test.base import api_asserts +from ._base import BaseObjectStoreIntegrationTestCase + +DISTRIBUTED_OBJECT_STORE_CONFIG_TEMPLATE = string.Template( + """ + + + + + + + + + + + + + + + + + + + + + +""" +) + +TEST_INPUT_FILES_CONTENT = "1 2 3" + + +class TestChangingStoreObjectStoreIntegration(BaseObjectStoreIntegrationTestCase): + @classmethod + def handle_galaxy_config_kwds(cls, config): + config["new_user_dataset_access_role_default_private"] = True + cls._configure_object_store(DISTRIBUTED_OBJECT_STORE_CONFIG_TEMPLATE, config) + config["enable_quotas"] = True + + def test_valid_in_device_swap(self): + with self.dataset_populator.test_history() as history_id: + hda = self.dataset_populator.new_dataset(history_id, content=TEST_INPUT_FILES_CONTENT, wait=True) + + payload = { + "name": "quota_longer", + "description": "quota_longer desc", + "amount": "10 MB", + "operation": "=", + "default": "registered", + "quota_source_label": "longer_term", + } + self.dataset_populator.create_quota(payload) + + payload = { + "name": "quota_shorter", + "description": "quota_shorter desc", + "amount": "20 MB", + "operation": "=", + "default": "registered", + "quota_source_label": "shorter_term", + } + self.dataset_populator.create_quota(payload) + + quotas = self.dataset_populator.get_quotas() + assert len(quotas) == 2 + + usage = self.dataset_populator.get_usage_for("longer_term") + assert usage["total_disk_usage"] == 0 + usage = self.dataset_populator.get_usage_for("shorter_term") + assert usage["total_disk_usage"] == 0 + usage = self.dataset_populator.get_usage_for(None) + assert int(usage["total_disk_usage"]) == 6 + + self.dataset_populator.update_object_store_id(hda["id"], "temp_short") + usage = self.dataset_populator.get_usage_for("shorter_term") + assert int(usage["total_disk_usage"]) == 6 + usage = self.dataset_populator.get_usage_for(None) + assert int(usage["total_disk_usage"]) == 0 + + self.dataset_populator.update_object_store_id(hda["id"], "temp_long") + usage = self.dataset_populator.get_usage_for("shorter_term") + assert int(usage["total_disk_usage"]) == 0 + usage = self.dataset_populator.get_usage_for("longer_term") + assert int(usage["total_disk_usage"]) == 6 + usage = self.dataset_populator.get_usage_for(None) + assert int(usage["total_disk_usage"]) == 0 + + self.dataset_populator.update_object_store_id(hda["id"], "temp_short") + usage = self.dataset_populator.get_usage_for("shorter_term") + assert int(usage["total_disk_usage"]) == 6 + usage = self.dataset_populator.get_usage_for("longer_term") + assert int(usage["total_disk_usage"]) == 0 + usage = self.dataset_populator.get_usage_for(None) + assert int(usage["total_disk_usage"]) == 0 diff --git a/test/integration/objectstore/test_private_handling.py b/test/integration/objectstore/test_private_handling.py index f89d23d72cea..66ec63f0671d 100644 --- a/test/integration/objectstore/test_private_handling.py +++ b/test/integration/objectstore/test_private_handling.py @@ -1,4 +1,4 @@ -"""Integration tests for mixing store_by.""" +"""Integration tests for private object store handling.""" import string diff --git a/test/unit/data/test_quota.py b/test/unit/data/test_quota.py index 9d9c8573650c..5aea03dac771 100644 --- a/test/unit/data/test_quota.py +++ b/test/unit/data/test_quota.py @@ -284,10 +284,58 @@ def test_calculate_usage_default_storage_disabled(self): assert usages[1].quota_source_label == "alt_source" assert usages[1].total_disk_usage == 15 - def _refresh_user_and_assert_disk_usage_is(self, usage): + def test_update_usage_from_labeled_to_unlabeled(self): + model = self.model + quota_agent = DatabaseQuotaAgent(model) + u = self.u + + self._add_dataset(10) + alt_d = self._add_dataset(15, "alt_source_store") + self.model.session.flush() + assert quota_agent + + quota_source_map = QuotaSourceMap(None, True) + alt_source = QuotaSourceMap("alt_source", True) + quota_source_map.backends["alt_source_store"] = alt_source + + object_store = MockObjectStore(quota_source_map) + u.calculate_and_set_disk_usage(object_store) + self._refresh_user_and_assert_disk_usage_is(10) + quota_agent.relabel_quota_for_dataset(alt_d.dataset, "alt_source", None) + self._refresh_user_and_assert_disk_usage_is(25) + self._refresh_user_and_assert_disk_usage_is(0, "alt_source") + + def test_update_usage_from_unlabeled_to_labeled(self): + model = self.model + quota_agent = DatabaseQuotaAgent(model) + u = self.u + + d = self._add_dataset(10) + self._add_dataset(15, "alt_source_store") + self.model.session.flush() + assert quota_agent + + quota_source_map = QuotaSourceMap(None, True) + alt_source = QuotaSourceMap("alt_source", True) + quota_source_map.backends["alt_source_store"] = alt_source + + object_store = MockObjectStore(quota_source_map) + u.calculate_and_set_disk_usage(object_store) + self._refresh_user_and_assert_disk_usage_is(15, "alt_source") + quota_agent.relabel_quota_for_dataset(d.dataset, None, "alt_source") + self._refresh_user_and_assert_disk_usage_is(25, "alt_source") + self._refresh_user_and_assert_disk_usage_is(0, None) + + def _refresh_user_and_assert_disk_usage_is(self, usage, label=None): u = self.u self.model.context.refresh(u) - assert u.disk_usage == usage + if label is None: + assert u.disk_usage == usage + else: + usages = u.dictify_usage() + for u in usages: + if u.quota_source_label == label: + assert int(u.total_disk_usage) == int(usage) class TestQuota(BaseModelTestCase): diff --git a/test/unit/objectstore/test_objectstore.py b/test/unit/objectstore/test_objectstore.py index ae922e986ff4..4956ac970834 100644 --- a/test/unit/objectstore/test_objectstore.py +++ b/test/unit/objectstore/test_objectstore.py @@ -524,13 +524,13 @@ def test_badges_parsing_conflicts(): DISTRIBUTED_TEST_CONFIG = """ - + - + @@ -549,6 +549,7 @@ def test_badges_parsing_conflicts(): source: 1files type: disk weight: 2 + device: primary_disk files_dir: "${temp_directory}/files1" extra_dirs: - type: temp @@ -560,6 +561,7 @@ def test_badges_parsing_conflicts(): source: 2files type: disk weight: 1 + device: primary_disk files_dir: "${temp_directory}/files2" extra_dirs: - type: temp @@ -598,6 +600,12 @@ def test_distributed_store(): extra_dirs = as_dict["extra_dirs"] assert len(extra_dirs) == 2 + device_source_map = object_store.get_device_source_map() + assert device_source_map + print(device_source_map.backends) + assert device_source_map.get_device_id("files1") == "primary_disk" + assert device_source_map.get_device_id("files2") == "primary_disk" + def test_distributed_store_empty_cache_targets(): for config_str in [DISTRIBUTED_TEST_CONFIG, DISTRIBUTED_TEST_CONFIG_YAML]: