Skip to content

Commit

Permalink
Merge pull request #218 from openforcefield/alchemiscale-fah
Browse files Browse the repository at this point in the history
Changes needed to support execution via `alchemiscale-fah`
  • Loading branch information
dotsdl authored Jul 15, 2024
2 parents 9386907 + 208eeb3 commit e1b408b
Show file tree
Hide file tree
Showing 23 changed files with 421 additions and 171 deletions.
3 changes: 2 additions & 1 deletion alchemiscale/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def get_settings_override():
def synchronous(config_file):
from alchemiscale.models import Scope
from alchemiscale.compute.service import SynchronousComputeService
from alchemiscale.compute.settings import ComputeServiceSettings

params = yaml.safe_load(config_file)

Expand All @@ -373,7 +374,7 @@ def synchronous(config_file):
Scope.from_str(scope) for scope in params_init["scopes"]
]

service = SynchronousComputeService(**params_init)
service = SynchronousComputeService(ComputeServiceSettings(**params_init))

# add signal handling
for signame in {"SIGHUP", "SIGINT", "SIGTERM"}:
Expand Down
81 changes: 81 additions & 0 deletions alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import json
from datetime import datetime, timedelta
import random

from fastapi import FastAPI, APIRouter, Body, Depends
from fastapi.middleware.gzip import GZipMiddleware
Expand All @@ -23,6 +24,7 @@
get_cred_entity,
validate_scopes,
validate_scopes_query,
minimize_scope_space,
_check_store_connectivity,
gufe_to_json,
GzipRoute,
Expand Down Expand Up @@ -177,6 +179,7 @@ def claim_taskhub_tasks(
*,
compute_service_id: str = Body(),
count: int = Body(),
protocols: Optional[List[str]] = Body(None, embed=True),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
Expand All @@ -187,13 +190,91 @@ def claim_taskhub_tasks(
taskhub=taskhub_scoped_key,
compute_service_id=ComputeServiceID(compute_service_id),
count=count,
protocols=protocols,
)

return [str(t) if t is not None else None for t in tasks]


@router.post("/claim")
def claim_tasks(
scopes: List[Scope] = Body(),
compute_service_id: str = Body(),
count: int = Body(),
protocols: Optional[List[str]] = Body(None, embed=True),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
# intersect query scopes with accessible scopes in the token
scopes_reduced = minimize_scope_space(scopes)
query_scopes = []
for scope in scopes_reduced:
query_scopes.extend(validate_scopes_query(scope, token))

taskhubs = dict()
# query each scope for available taskhubs
# loop might be more removable in the future with a Union like operator on scopes
for single_query_scope in set(query_scopes):
taskhubs.update(n4js.query_taskhubs(scope=single_query_scope, return_gufe=True))

# list of tasks to return
tasks = []

if len(taskhubs) == 0:
return []

# claim tasks from taskhubs based on weight; keep going till we hit our
# total desired task count, or we run out of taskhubs to draw from
while len(tasks) < count and len(taskhubs) > 0:
weights = [th.weight for th in taskhubs.values()]

if sum(weights) == 0:
break

# based on weights, choose taskhub to draw from
taskhub: ScopedKey = random.choices(list(taskhubs.keys()), weights=weights)[0]

# claim tasks from the taskhub
claimed_tasks = n4js.claim_taskhub_tasks(
taskhub,
compute_service_id=ComputeServiceID(compute_service_id),
count=(count - len(tasks)),
protocols=protocols,
)

# gather up claimed tasks, if present
for t in claimed_tasks:
if t is not None:
tasks.append(t)

# remove this taskhub from the options available; repeat
taskhubs.pop(taskhub)

return [str(t) for t in tasks] + [None] * (count - len(tasks))


@router.get("/tasks/{task_scoped_key}/transformation")
def get_task_transformation(
task_scoped_key,
*,
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
sk = ScopedKey.from_str(task_scoped_key)
validate_scopes(sk.scope, token)

transformation: ScopedKey

transformation, _ = n4js.get_task_transformation(
task=task_scoped_key,
return_gufe=False,
)

return str(transformation)


@router.get("/tasks/{task_scoped_key}/transformation/gufe")
def retrieve_task_transformation(
task_scoped_key,
*,
n4js: Neo4jStore = Depends(get_n4js_depends),
Expand Down
49 changes: 40 additions & 9 deletions alchemiscale/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ class AlchemiscaleComputeClient(AlchemiscaleBaseClient):
_exception = AlchemiscaleComputeClientError

def register(self, compute_service_id: ComputeServiceID):
res = self._post_resource(f"computeservice/{compute_service_id}/register", {})
res = self._post_resource(f"/computeservice/{compute_service_id}/register", {})
return ComputeServiceID(res)

def deregister(self, compute_service_id: ComputeServiceID):
res = self._post_resource(f"computeservice/{compute_service_id}/deregister", {})
res = self._post_resource(
f"/computeservice/{compute_service_id}/deregister", {}
)
return ComputeServiceID(res)

def heartbeat(self, compute_service_id: ComputeServiceID):
res = self._post_resource(f"computeservice/{compute_service_id}/heartbeat", {})
res = self._post_resource(f"/computeservice/{compute_service_id}/heartbeat", {})
return ComputeServiceID(res)

def list_scopes(self) -> List[Scope]:
Expand Down Expand Up @@ -71,19 +73,48 @@ def query_taskhubs(
return taskhubs

def claim_taskhub_tasks(
self, taskhub: ScopedKey, compute_service_id: ComputeServiceID, count: int = 1
self,
taskhub: ScopedKey,
compute_service_id: ComputeServiceID,
count: int = 1,
protocols: Optional[List[str]] = None,
) -> Task:
"""Claim a `Task` from the specified `TaskHub`"""
data = dict(compute_service_id=str(compute_service_id), count=count)
tasks = self._post_resource(f"taskhubs/{taskhub}/claim", data)
data = dict(
compute_service_id=str(compute_service_id), count=count, protocols=protocols
)
tasks = self._post_resource(f"/taskhubs/{taskhub}/claim", data)

return [ScopedKey.from_str(t) if t is not None else None for t in tasks]

def claim_tasks(
self,
scopes: List[Scope],
compute_service_id: ComputeServiceID,
count: int = 1,
protocols: Optional[List[str]] = None,
):
"""Claim Tasks from TaskHubs within a list of Scopes."""
data = dict(
scopes=[scope.dict() for scope in scopes],
compute_service_id=str(compute_service_id),
count=count,
protocols=protocols,
)
tasks = self._post_resource("/claim", data)

return [ScopedKey.from_str(t) if t is not None else None for t in tasks]

def get_task_transformation(
def get_task_transformation(self, task: ScopedKey) -> ScopedKey:
"""Get the Transformation associated with the given Task."""
transformation = self._get_resource(f"/tasks/{task}/transformation")
return ScopedKey.from_str(transformation)

def retrieve_task_transformation(
self, task: ScopedKey
) -> Tuple[Transformation, Optional[ProtocolDAGResult]]:
transformation, protocoldagresult = self._get_resource(
f"tasks/{task}/transformation"
f"/tasks/{task}/transformation/gufe"
)

return (
Expand All @@ -104,6 +135,6 @@ def set_task_result(
compute_service_id=str(compute_service_id),
)

pdr_sk = self._post_resource(f"tasks/{task}/results", data)
pdr_sk = self._post_resource(f"/tasks/{task}/results", data)

return ScopedKey.from_dict(pdr_sk)
Loading

0 comments on commit e1b408b

Please sign in to comment.