Skip to content

Commit

Permalink
[WIP] device source stuff...
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Jan 17, 2024
1 parent 1e3051d commit 7780321
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 4 deletions.
10 changes: 10 additions & 0 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4024,6 +4024,16 @@ def quota_source_info(self):
quota_source_map = self.object_store.get_quota_source_map()
return quota_source_map.get_quota_source_info(object_store_id)

@property
def device_source_label(self):
    """Return the ``label`` attribute of this object's device source info."""
    info = self.device_source_info
    return info.label

@property
def device_source_info(self):
    """Return the device source info for this object's ``object_store_id``.

    The lookup goes through the object store's *device* source map; the
    quota source map describes quota sources and has no device information.
    """
    # NOTE(review): the returned map must provide get_device_source_info();
    # confirm that DeviceSourceMap implements it.
    device_source_map = self.object_store.get_device_source_map()
    return device_source_map.get_device_source_info(self.object_store_id)

def set_file_name(self, filename):
if not filename:
self.external_filename = None
Expand Down
40 changes: 38 additions & 2 deletions lib/galaxy/objectstore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
DEFAULT_PRIVATE = False
DEFAULT_QUOTA_SOURCE = None # Just track quota right on user object in Galaxy.
DEFAULT_QUOTA_ENABLED = True # enable quota tracking in object stores by default

# Default device id a DeviceSourceMap falls back to when none is supplied.
DEFAULT_DEVICE_ID = None
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -329,6 +329,10 @@ def to_dict(self) -> Dict[str, Any]:
def get_quota_source_map(self):
"""Return QuotaSourceMap describing mapping of object store IDs to quota sources."""

@abc.abstractmethod
def get_device_source_map(self) -> "DeviceSourceMap":
    """Return DeviceSourceMap describing mapping of object store IDs to device sources.

    Abstract declaration: there is no body here beyond this docstring.
    """


class BaseObjectStore(ObjectStore):
store_by: str
Expand Down Expand Up @@ -491,6 +495,9 @@ def get_quota_source_map(self):
# I'd rather keep this abstract... but register_singleton wants it to be instantiable...
raise NotImplementedError()

def get_device_source_map(self):
    """Return a new ``DeviceSourceMap`` built with its default arguments."""
    default_map = DeviceSourceMap()
    return default_map


class ConcreteObjectStore(BaseObjectStore):
"""Subclass of ObjectStore for stores that don't delegate (non-nested).
Expand All @@ -501,6 +508,7 @@ class ConcreteObjectStore(BaseObjectStore):
"""

badges: List[StoredBadgeDict]
device_id: Optional[str] = None

def __init__(self, config, config_dict=None, **kwargs):
"""
Expand Down Expand Up @@ -528,6 +536,7 @@ def __init__(self, config, config_dict=None, **kwargs):
quota_config = config_dict.get("quota", {})
self.quota_source = quota_config.get("source", DEFAULT_QUOTA_SOURCE)
self.quota_enabled = quota_config.get("enabled", DEFAULT_QUOTA_ENABLED)
self.device_id = config_dict.get("device", None)
self.badges = read_badges(config_dict)

def to_dict(self):
Expand All @@ -541,6 +550,7 @@ def to_dict(self):
"enabled": self.quota_enabled,
}
rval["badges"] = self._get_concrete_store_badges(None)
rval["device"] = self.device_id
return rval

def to_model(self, object_store_id: str) -> "ConcreteObjectStoreModel":
Expand All @@ -551,6 +561,7 @@ def to_model(self, object_store_id: str) -> "ConcreteObjectStoreModel":
description=self.description,
quota=QuotaModel(source=self.quota_source, enabled=self.quota_enabled),
badges=self._get_concrete_store_badges(None),
device=self.device_id,
)

def _get_concrete_store_badges(self, obj) -> List[BadgeDict]:
Expand Down Expand Up @@ -587,6 +598,9 @@ def get_quota_source_map(self):
)
return quota_source_map

def get_device_source_map(self) -> "DeviceSourceMap":
    """Return a new ``DeviceSourceMap`` seeded with this store's ``device_id``."""
    own_device_id = self.device_id
    return DeviceSourceMap(own_device_id)


class DiskObjectStore(ConcreteObjectStore):
"""
Expand Down Expand Up @@ -1036,7 +1050,7 @@ def __init__(self, config, config_dict, fsmon=False):
"""
super().__init__(config, config_dict)
self._quota_source_map = None

self._device_source_map = None
self.backends = {}
self.weighted_backend_ids = []
self.original_weighted_backend_ids = []
Expand Down Expand Up @@ -1208,6 +1222,13 @@ def get_quota_source_map(self):
self._quota_source_map = quota_source_map
return self._quota_source_map

def get_device_source_map(self) -> "DeviceSourceMap":
    """Return the device source map for this store, building it on first use.

    The map is assembled once by merging the maps of all backends and is
    then kept on ``self._device_source_map`` for later calls.
    """
    cached = self._device_source_map
    if cached is not None:
        return cached
    merged = DeviceSourceMap()
    self._merge_device_source_map(merged, self)
    self._device_source_map = merged
    return self._device_source_map

@classmethod
def _merge_quota_source_map(clz, quota_source_map, object_store):
for backend_id, backend in object_store.backends.items():
Expand All @@ -1216,6 +1237,14 @@ def _merge_quota_source_map(clz, quota_source_map, object_store):
else:
quota_source_map.backends[backend_id] = backend.get_quota_source_map()

@classmethod
def _merge_device_source_map(clz, device_source_map: "DeviceSourceMap", object_store):
    """Fill ``device_source_map.backends`` from the backends of ``object_store``.

    Nested distributed stores are descended into recursively, so their
    backends are recorded directly on ``device_source_map``; every other
    backend contributes its own device source map under its backend id.
    """
    for store_id, store in object_store.backends.items():
        if not isinstance(store, DistributedObjectStore):
            device_source_map.backends[store_id] = store.get_device_source_map()
            continue
        clz._merge_device_source_map(device_source_map, store)

def __get_store_id_for(self, obj, **kwargs):
if obj.object_store_id is not None:
if obj.object_store_id in self.backends:
Expand Down Expand Up @@ -1364,6 +1393,7 @@ class ConcreteObjectStoreModel(BaseModel):
description: Optional[str] = None
quota: QuotaModel
badges: List[BadgeDict]
device: Optional[str] = None


def type_to_object_store_class(store: str, fsmon: bool = False) -> Tuple[Type[BaseObjectStore], Dict[str, Any]]:
Expand Down Expand Up @@ -1506,6 +1536,12 @@ class QuotaSourceInfo(NamedTuple):
use: bool


class DeviceSourceMap:
    """Holds a default device id plus per-backend device source maps."""

    def __init__(self, device_id=DEFAULT_DEVICE_ID):
        # Nested maps, filled in by callers and keyed by backend id.
        self.backends = {}
        # Device id of the store this map was created for.
        self.default_device_id = device_id


class QuotaSourceMap:
def __init__(self, source=DEFAULT_QUOTA_SOURCE, enabled=DEFAULT_QUOTA_ENABLED):
self.default_quota_source = source
Expand Down
103 changes: 103 additions & 0 deletions lib/galaxy/quota/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from sqlalchemy import select
from sqlalchemy.orm import object_session
from sqlalchemy.sql import text

import galaxy.util
Expand All @@ -25,6 +26,13 @@ class QuotaAgent: # metaclass=abc.ABCMeta
the quota in other apps (LDAP maybe?) or via configuration files.
"""

def relabel_quota_for_dataset(self, dataset, quota_source_label):
    """Update the quota source label for dataset and adjust relevant quotas.

    Subtract the dataset's usage from the affected users under the old
    label and add it under the new label for those same users.

    NOTE(review): this signature takes a single ``quota_source_label``,
    while the database-backed implementation in this module takes
    ``from_label`` and ``to_label`` - confirm which one is intended.
    """

# TODO: make abstractmethod after they work better with mypy
def get_quota(self, user, quota_source_label=None) -> Optional[int]:
"""Return quota in bytes or None if no quota is set."""
Expand Down Expand Up @@ -81,6 +89,9 @@ def __init__(self):
def get_quota(self, user, quota_source_label=None) -> Optional[int]:
    """Return None for every user and quota source label."""
    return None

def relabel_quota_for_dataset(self, dataset, quota_source_label):
    """Do nothing and return None; both arguments are ignored."""
    return

@property
def default_quota(self):
return None
Expand Down Expand Up @@ -173,6 +184,98 @@ def get_quota(self, user, quota_source_label=None) -> Optional[int]:
else:
return None

def relabel_quota_for_dataset(self, dataset, from_label, to_label):
    """Move the usage of ``dataset`` from one quota source label to another.

    For every user owning a history that contains ``dataset``, the dataset's
    total size is subtracted from the usage tracked under ``from_label`` and
    added to the usage tracked under ``to_label``. A label of ``None`` means
    the ``galaxy_user.disk_usage`` column; any other label means that user's
    row in ``user_quota_source_usage``.

    :param dataset: dataset whose size is moved; must provide
        ``get_total_size()`` and ``id``.
    :param from_label: quota source label currently charged, or ``None``.
    :param to_label: quota source label to charge instead, or ``None``.
    :returns: ``None``.
    """
    if to_label == from_label:
        # Nothing to move; return before touching the dataset or database.
        return None

    adjust = dataset.get_total_size()
    engine = self.sa_session.get_bind()

    # SQLite only supports ON CONFLICT upserts from 3.24.0 on and has no
    # ON CONSTRAINT form, so it gets INSERT OR REPLACE statements instead.
    for_sqlite = "sqlite" in engine.dialect.name

    with_quota_affected_users = """WITH quota_affected_users AS
    (
        SELECT DISTINCT user_id
        FROM history
            INNER JOIN
                history_dataset_association on history_dataset_association.history_id = history.id
            INNER JOIN
                dataset on history_dataset_association.dataset_id = dataset.id
        WHERE
            dataset_id = :dataset_id
    )"""
    # Written as a subquery: a bare table name after IN is SQLite-only syntax.
    affected_user_ids = "(SELECT user_id FROM quota_affected_users)"

    if to_label is None:
        to_statement = f"""
            {with_quota_affected_users}
            UPDATE galaxy_user
            SET disk_usage = coalesce(disk_usage, 0) + :adjust
            WHERE id IN {affected_user_ids}
        """
    elif for_sqlite:
        to_statement = f"""
            {with_quota_affected_users},
            new_quota_sources (user_id, disk_usage, quota_source_label) AS (
                SELECT user_id, :adjust as disk_usage, :to_label as quota_source_label
                FROM quota_affected_users
            )
            INSERT OR REPLACE INTO user_quota_source_usage (id, user_id, quota_source_label, disk_usage)
            SELECT old.id, new.user_id, new.quota_source_label, COALESCE(old.disk_usage + :adjust, :adjust)
            FROM new_quota_sources as new LEFT JOIN user_quota_source_usage AS old ON new.user_id = old.user_id AND NEW.quota_source_label = old.quota_source_label"""
    else:
        to_statement = f"""
            {with_quota_affected_users},
            new_quota_sources (user_id, disk_usage, quota_source_label) AS (
                SELECT user_id, :adjust as disk_usage, :to_label as quota_source_label
                FROM quota_affected_users
            )
            INSERT INTO user_quota_source_usage(user_id, disk_usage, quota_source_label)
            SELECT * FROM new_quota_sources
            ON CONFLICT
                ON constraint uqsu_unique_label_per_user
                DO UPDATE SET disk_usage = user_quota_source_usage.disk_usage + :adjust
        """

    if from_label is None:
        from_statement = f"""
            {with_quota_affected_users}
            UPDATE galaxy_user
            SET disk_usage = coalesce(disk_usage - :adjust, 0)
            WHERE id IN {affected_user_ids}
        """
    elif for_sqlite:
        from_statement = f"""
            {with_quota_affected_users},
            new_quota_sources (user_id, disk_usage, quota_source_label) AS (
                SELECT user_id, :adjust as disk_usage, :from_label as quota_source_label
                FROM quota_affected_users
            )
            INSERT OR REPLACE INTO user_quota_source_usage (id, user_id, quota_source_label, disk_usage)
            SELECT old.id, new.user_id, new.quota_source_label, COALESCE(old.disk_usage - :adjust, 0)
            FROM new_quota_sources as new LEFT JOIN user_quota_source_usage AS old ON new.user_id = old.user_id AND NEW.quota_source_label = old.quota_source_label"""
    else:
        # A missing row is created with usage 0; an existing row is decremented.
        from_statement = f"""
            {with_quota_affected_users},
            new_quota_sources (user_id, disk_usage, quota_source_label) AS (
                SELECT user_id, 0 as disk_usage, :from_label as quota_source_label
                FROM quota_affected_users
            )
            INSERT INTO user_quota_source_usage(user_id, disk_usage, quota_source_label)
            SELECT * FROM new_quota_sources
            ON CONFLICT
                ON constraint uqsu_unique_label_per_user
                DO UPDATE SET disk_usage = user_quota_source_usage.disk_usage - :adjust
        """

    bind = {"dataset_id": dataset.id, "adjust": int(adjust), "to_label": to_label, "from_label": from_label}
    # One explicit transaction so the subtraction and the addition are
    # committed together, or not at all.
    with engine.begin() as conn:
        conn.execute(text(from_statement), bind)
        conn.execute(text(to_statement), bind)
    return None

def _default_unregistered_quota(self, quota_source_label):
    """Delegate to ``_default_quota`` using the UNREGISTERED default-quota type."""
    unregistered_type = self.model.DefaultQuotaAssociation.types.UNREGISTERED
    return self._default_quota(unregistered_type, quota_source_label)

Expand Down
50 changes: 48 additions & 2 deletions test/unit/data/test_quota.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,56 @@ def test_calculate_usage_default_storage_disabled(self):
assert usages[1].quota_source_label == "alt_source"
assert usages[1].total_disk_usage == 15

def _refresh_user_and_assert_disk_usage_is(self, usage):
def test_update_usage_from_labeled_to_unlabeled(self):
    """Relabel the "alt_source" dataset to the default source and check the user total."""
    agent = DatabaseQuotaAgent(self.model)
    user = self.u

    # 10 bytes in the default store, 15 bytes in the alternate store.
    self._add_dataset(10)
    labeled = self._add_dataset(15, "alt_source_store")
    self.model.session.flush()
    assert agent

    source_map = QuotaSourceMap(None, True)
    source_map.backends["alt_source_store"] = QuotaSourceMap("alt_source", True)

    user.calculate_and_set_disk_usage(MockObjectStore(source_map))
    self._refresh_user_and_assert_disk_usage_is(10)
    agent.relabel_quota_for_dataset(labeled.dataset, "alt_source", None)
    self._refresh_user_and_assert_disk_usage_is(25)

def test_update_usage_from_unlabeled_to_labeled(self):
    """Relabel the default-source dataset to "alt_source" and check that label's total."""
    agent = DatabaseQuotaAgent(self.model)
    user = self.u

    # 10 bytes in the default store, 15 bytes in the alternate store.
    unlabeled = self._add_dataset(10)
    self._add_dataset(15, "alt_source_store")
    self.model.session.flush()
    assert agent

    source_map = QuotaSourceMap(None, True)
    source_map.backends["alt_source_store"] = QuotaSourceMap("alt_source", True)

    user.calculate_and_set_disk_usage(MockObjectStore(source_map))
    self._refresh_user_and_assert_disk_usage_is(15, "alt_source")
    agent.relabel_quota_for_dataset(unlabeled.dataset, None, "alt_source")
    self._refresh_user_and_assert_disk_usage_is(25, "alt_source")

def _refresh_user_and_assert_disk_usage_is(self, usage, label=None):
    """Refresh ``self.u`` from the database and assert its recorded usage.

    :param usage: expected usage value.
    :param label: quota source label to check. With ``None`` the user's
        ``disk_usage`` attribute is compared; otherwise the entries of
        ``dictify_usage()`` carrying that label are compared.
    :raises AssertionError: if the usage differs, or if no entry carries
        ``label`` (so a missing label cannot pass silently).
    """
    user = self.u
    self.model.context.refresh(user)
    if label is None:
        assert user.disk_usage == usage
        return
    matching = [entry for entry in user.dictify_usage() if entry.quota_source_label == label]
    assert matching, f"no usage entry found for quota source label {label!r}"
    for entry in matching:
        assert int(entry.total_disk_usage) == int(usage)


class TestQuota(BaseModelTestCase):
Expand Down

0 comments on commit 7780321

Please sign in to comment.