From f00035ff22b66ea389bf36450161c6a96d906412 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 21 Feb 2024 10:29:03 +0100 Subject: [PATCH] Let each item have their own prompt queue --- spine_engine/server/engine_server.py | 4 ++-- spine_engine/server/remote_execution_service.py | 4 ++-- spine_engine/spine_engine.py | 11 +++++------ spine_engine/utils/queue_logger.py | 2 +- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/spine_engine/server/engine_server.py b/spine_engine/server/engine_server.py index 19209def..f44658f9 100644 --- a/spine_engine/server/engine_server.py +++ b/spine_engine/server/engine_server.py @@ -162,8 +162,8 @@ def serve(self): msg = f"Answering prompt failed. Worker for job_id:{request.request_id()} not found." self.send_init_failed_reply(frontend, request.connection_id(), msg) continue - item_name, accepted = request.data() - worker.answer_prompt(item_name, accepted) + prompter_id, answer = request.data() + worker.answer_prompt(prompter_id, answer) continue elif request.cmd() == "retrieve_project": project_dir = project_dirs.get(request.request_id(), None) # Get project dir based on job_id diff --git a/spine_engine/server/remote_execution_service.py b/spine_engine/server/remote_execution_service.py index ac3af694..89bc71e6 100644 --- a/spine_engine/server/remote_execution_service.py +++ b/spine_engine/server/remote_execution_service.py @@ -153,9 +153,9 @@ def stop_engine(self): """Stops DAG execution.""" self.engine.stop() - def answer_prompt(self, item_name, accepted): + def answer_prompt(self, prompter_id, answer): """Answers prompt.""" - self.engine.answer_prompt(item_name, accepted) + self.engine.answer_prompt(prompter_id, answer) def close(self): """Cleans up sockets after worker is finished.""" diff --git a/spine_engine/spine_engine.py b/spine_engine/spine_engine.py index dce3ea41..57748fd3 100644 --- a/spine_engine/spine_engine.py +++ b/spine_engine/spine_engine.py @@ -259,9 +259,8 @@ def make_item(self, item_name, direction): Note that this method is called multiple times for each item: Once for the backward pipeline, and once for each filtered execution in the forward pipeline.""" item_dict = self._items[item_name] - if item_name not in self._prompt_queues: - self._prompt_queues[item_name] = mp.Queue() - prompt_queue = self._prompt_queues[item_name] + prompt_queue = mp.Queue() + self._prompt_queues[id(prompt_queue)] = prompt_queue logger = QueueLogger( self._queue, item_name, prompt_queue, self._answered_prompts, silent=direction is ED.BACKWARD ) @@ -299,9 +298,9 @@ def _get_event_stream(self): break self._thread.join() - def answer_prompt(self, item_name, accepted): - """Answers the prompt for the specified item, either accepting or rejecting it.""" - self._prompt_queues[item_name].put(accepted) + def answer_prompt(self, prompter_id, answer): + """Answers the prompt for the specified prompter id.""" + self._prompt_queues[prompter_id].put(answer) def wait(self): """Waits until engine execution has finished.""" diff --git a/spine_engine/utils/queue_logger.py b/spine_engine/utils/queue_logger.py index 55d7f3bc..97ddc0cb 100644 --- a/spine_engine/utils/queue_logger.py +++ b/spine_engine/utils/queue_logger.py @@ -83,7 +83,7 @@ def emit(self, prompt_data): key = str(prompt_data) if key not in self._answered_prompts: self._answered_prompts[key] = self._PENDING - prompt = {"item_name": self._item_name, "data": prompt_data} + prompt = {"prompter_id": id(self._prompt_queue), "data": prompt_data} self._queue.put(("prompt", prompt)) self._answered_prompts[key] = self._prompt_queue.get() while self._answered_prompts[key] is self._PENDING: