Skip to content

Commit

Permalink
Added UNWIND to cancel tasks query
Browse files Browse the repository at this point in the history
Added the UNWIND clause to the cancel tasks query. Additionally, I
expanded the tests to explicitly test for returned `None`.
  • Loading branch information
ianmkenney committed Sep 3, 2024
1 parent 79ec027 commit 7c34d9c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
32 changes: 16 additions & 16 deletions alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,24 +1585,24 @@ def cancel_tasks(
none at all.
"""
canceled_sks = []
with self.transaction() as tx:
for t in tasks:
q = f"""
// get our task hub, as well as the task :ACTIONS relationship we want to remove
MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}})
DELETE ar
RETURN task
"""
_task = tx.run(q).to_eager_result()
query = """
UNWIND $task_scoped_keys AS task_scoped_key
MATCH (:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: task_scoped_key})
DELETE ar
RETURN task._scoped_key as task_scoped_key
"""
results = self.execute_query(
query,
task_scoped_keys=list(map(str, tasks)),
taskhub_scoped_key=str(taskhub),
)

if _task.records:
sk = _task.records[0].data()["task"]["_scoped_key"]
canceled_sks.append(ScopedKey.from_str(sk))
else:
canceled_sks.append(None)
returned_keys = {record["task_scoped_key"] for record in results.records}
filtered_tasks = [
task if str(task) in returned_keys else None for task in tasks
]

return canceled_sks
return filtered_tasks

def get_taskhub_tasks(
self, taskhub: ScopedKey, return_gufe=False
Expand Down
21 changes: 13 additions & 8 deletions alchemiscale/tests/integration/storage/test_statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,17 +1214,22 @@ def test_cancel_task(self, n4js, network_tyk2, scope_test):
canceled = n4js.cancel_tasks(task_sks[1:3], taskhub_sk)

# check that the hub has the contents we expect
q = f"""MATCH (tq:TaskHub {{_scoped_key: '{taskhub_sk}'}})-[:ACTIONS]->(task:Task)
return task
"""
q = """
MATCH (:TaskHub {_scoped_key: $taskhub_scoped_key})-[:ACTIONS]->(task:Task)
RETURN task._scoped_key AS task_scoped_key
"""

tasks = n4js.execute_query(q)
tasks = [record["task"] for record in tasks.records]
tasks = n4js.execute_query(q, taskhub_scoped_key=str(taskhub_sk))
tasks = [
ScopedKey.from_str(record["task_scoped_key"]) for record in tasks.records
]

assert len(tasks) == 8
assert set([ScopedKey.from_str(t["_scoped_key"]) for t in tasks]) == set(
actioned
) - set(canceled)
assert set(tasks) == set(actioned) - set(canceled)

# cancel the remaining tasks and check for Nones
canceled = n4js.cancel_tasks(task_sks, taskhub_sk)
assert canceled == [task_sks[0]] + [None, None] + task_sks[3:]

def test_get_taskhub_tasks(self, n4js, network_tyk2, scope_test):
an = network_tyk2
Expand Down

0 comments on commit 7c34d9c

Please sign in to comment.