Skip to content

Commit

Permalink
Workers count can be given at engine creation
Browse files Browse the repository at this point in the history
  • Loading branch information
sapetnioc committed Apr 26, 2024
1 parent 19cd8b1 commit dbb9b07
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 17 deletions.
5 changes: 4 additions & 1 deletion capsul/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,17 @@ def engines(self):
if field.name != "databases":
yield self.engine(field.name)

def engine(self, name="builtin", update_database=False):
def engine(self, name="builtin", workers_count=None, update_database=False):
"""Get a :class:`~capsul.engine.Engine` instance"""
from .engine import Engine

# get engine type from config
engine_config = getattr(self.config, name, None)
if engine_config is None:
raise ValueError(f'engine "{name}" is not configured.')
if workers_count is not None:
engine_config.start_workers = engine_config.start_workers.copy()
engine_config.start_workers["count"] = workers_count
return Engine(
name,
engine_config,
Expand Down
32 changes: 16 additions & 16 deletions capsul/database/populse_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"label": [str, {"index": True}],
"config": dict,
"workers": list[str],
"executions": list[dict],
"executions": list[str],
"persistent": bool,
"connections": int,
}
Expand Down Expand Up @@ -183,14 +183,14 @@ def worker_database_config(self, engine_id):
return self.config

def worker_started(self, engine_id):
with self.storage.data(write=True) as db:
with self.storage.data(write=True, exclusive=True) as db:
worker_id = str(uuid4())
workers = db.capsul_engine[engine_id].workers.get()
if workers is not None:
workers.append(worker_id)
db.capsul_engine[engine_id].workers = workers
return worker_id
raise ValueError(f"Invalid engine_id: {engine_id}")
return worker_id
raise ValueError(f"Invalid engine_id: {engine_id}")

def worker_ended(self, engine_id, worker_id):
with self.storage.data(write=True) as db:
Expand Down Expand Up @@ -395,9 +395,9 @@ def job_finished_json(
waiting_job["return_code"] = (
"Not started because de dependent job failed"
)
db.capsul_job[
engine_id, execution_id, waiting_id
].job = waiting_job
db.capsul_job[engine_id, execution_id, waiting_id].job = (
waiting_job
)
waiting.remove(waiting_id)
failed.append(waiting_id)
stack.update(waiting_job.get("waited_by", []))
Expand Down Expand Up @@ -427,9 +427,9 @@ def job_finished_json(

if not ongoing and not ready:
if failed:
db.capsul_execution[
engine_id, execution_id
].error = "Some jobs failed"
db.capsul_execution[engine_id, execution_id].error = (
"Some jobs failed"
)
db.capsul_execution[engine_id, execution_id].update(
{
"status": "finalization",
Expand Down Expand Up @@ -483,12 +483,12 @@ def set_job_output_parameters(
indices = job.get("parameters_index", {})
for name, value in output_parameters.items():
values[indices[name]] = value
db.capsul_job[
engine_id, execution_id, job_id
].job.output_parameters = output_parameters
db.capsul_execution[
engine_id, execution_id
].workflow_parameters_values = values
db.capsul_job[engine_id, execution_id, job_id].job.output_parameters = (
output_parameters
)
db.capsul_execution[engine_id, execution_id].workflow_parameters_values = (
values
)

def job_json(self, engine_id, execution_id, job_id):
if os.path.exists(self.path):
Expand Down
27 changes: 27 additions & 0 deletions capsul/test/test_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from capsul.api import Capsul
import time

def noop() -> None:
pass

def test_start_workers():
capsul = Capsul(database_path="")
noop_executable = capsul.executable(noop)
for wc in [3, 2, 1]:
with capsul.engine(workers_count=wc) as engine:
requested = engine.config.start_workers.get("count", 0)
assert requested == wc
noop_id = engine.start(noop_executable)
for i in range(100):
if engine.database.workers_count(engine.engine_id) == wc:
break
time.sleep(0.2)
else:
raise RuntimeError(f'expected {wc} workers to be created, got {engine.database.workers_count(engine.engine_id)}')
engine.dispose(noop_id)
for i in range(100):
if engine.database.workers_count(engine.engine_id) == 0:
break
time.sleep(0.2)
else:
raise RuntimeError(f'expected workers to be stopped; running workers = {engine.database.workers_count(engine.engine_id)}')

0 comments on commit dbb9b07

Please sign in to comment.