Skip to content

Commit

Permalink
Let each item have their own prompt queue
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelma committed Feb 21, 2024
1 parent 593a931 commit f00035f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
4 changes: 2 additions & 2 deletions spine_engine/server/engine_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def serve(self):
msg = f"Answering prompt failed. Worker for job_id:{request.request_id()} not found."
self.send_init_failed_reply(frontend, request.connection_id(), msg)
continue
item_name, accepted = request.data()
worker.answer_prompt(item_name, accepted)
prompter_id, answer = request.data()
worker.answer_prompt(prompter_id, answer)
continue
elif request.cmd() == "retrieve_project":
project_dir = project_dirs.get(request.request_id(), None) # Get project dir based on job_id
Expand Down
4 changes: 2 additions & 2 deletions spine_engine/server/remote_execution_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def stop_engine(self):
"""Stops DAG execution."""
self.engine.stop()

def answer_prompt(self, item_name, accepted):
def answer_prompt(self, prompter_id, answer):
"""Answers prompt."""
self.engine.answer_prompt(item_name, accepted)
self.engine.answer_prompt(prompter_id, answer)

def close(self):
"""Cleans up sockets after worker is finished."""
Expand Down
11 changes: 5 additions & 6 deletions spine_engine/spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,8 @@ def make_item(self, item_name, direction):
Note that this method is called multiple times for each item:
Once for the backward pipeline, and once for each filtered execution in the forward pipeline."""
item_dict = self._items[item_name]
if item_name not in self._prompt_queues:
self._prompt_queues[item_name] = mp.Queue()
prompt_queue = self._prompt_queues[item_name]
prompt_queue = mp.Queue()
self._prompt_queues[id(prompt_queue)] = prompt_queue
logger = QueueLogger(
self._queue, item_name, prompt_queue, self._answered_prompts, silent=direction is ED.BACKWARD
)
Expand Down Expand Up @@ -299,9 +298,9 @@ def _get_event_stream(self):
break
self._thread.join()

def answer_prompt(self, item_name, accepted):
"""Answers the prompt for the specified item, either accepting or rejecting it."""
self._prompt_queues[item_name].put(accepted)
def answer_prompt(self, prompter_id, answer):
"""Answers the prompt for the specified prompter id."""
self._prompt_queues[prompter_id].put(answer)

def wait(self):
"""Waits until engine execution has finished."""
Expand Down
2 changes: 1 addition & 1 deletion spine_engine/utils/queue_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def emit(self, prompt_data):
key = str(prompt_data)
if key not in self._answered_prompts:
self._answered_prompts[key] = self._PENDING
prompt = {"item_name": self._item_name, "data": prompt_data}
prompt = {"prompter_id": id(self._prompt_queue), "data": prompt_data}
self._queue.put(("prompt", prompt))
self._answered_prompts[key] = self._prompt_queue.get()
while self._answered_prompts[key] is self._PENDING:
Expand Down

0 comments on commit f00035f

Please sign in to comment.