Skip to content

Commit

Permalink
Create TaskRunner for spawning ot-rcp (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
erjiaqing authored Dec 25, 2023
1 parent 910ef2b commit 16e8668
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 3 deletions.
92 changes: 92 additions & 0 deletions cirque/common/taskrunner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import queue
import threading
import time

from cirque.common.cirquelog import CirqueLog

class Task:
def __init__(self, fn):
self.fn = fn
self.result = None
self.completed = None
self.exception = None
self.cv = threading.Condition()

def run(self):
try:
self.result = self.fn()
except Exception as ex:
self.exception = ex

with self.cv:
self.completed = True
self.cv.notify_all()

def wait(self):
with self.cv:
while not self.completed:
self.cv.wait()
if self.exception:
raise self.exception
return self.result

class _TaskRunner:
def __init__(self):
self.logger = CirqueLog.get_cirque_logger(self.__class__.__name__)
self.queue_cv = threading.Condition()
self.queue = queue.Queue()

def post_task(self, fn) -> Task:
task = Task(fn)
with self.queue_cv:
self.queue.put(task)
self.queue_cv.notify()
self.logger.info("Task sent to runner thread.")
return task

def start(self):
self.running = True
self.logger.info("Starting task runner.")
self.th = threading.Thread(target=lambda:self._run())
self.th.start()
self.logger.info("Task runner started.")


def stop(self):
self.logger.info("Stopping runner thread.")
with self.queue_cv:
self.running = False
self.queue_cv.notify()
self.th.join()

def _run(self):
self.logger.info("Task runner running.")
with self.queue_cv:
while self.running:
try:
task = self.queue.get_nowait()
self.logger.info("Handled task.")
taskStart = time.time()
task.run()
self.logger.info(f"Task handling duration: {time.time() - taskStart}s")
except queue.Empty:
self.logger.info(f"No task")
pass
self.queue_cv.wait()
self.logger.info("Task runner stopped.")

TaskRunner = _TaskRunner()
12 changes: 9 additions & 3 deletions cirque/connectivity/threadsimpipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import shutil
from threading import Lock

from cirque.common.cirquelog import CirqueLog
from cirque.common.taskrunner import TaskRunner
from cirque.connectivity.socatpipepair import SocatPipePair


Expand All @@ -39,6 +41,7 @@ def __init__(self, node_id, petition_id=0, rcp=False):
self.radio_fd = None
self.radio_process = None
self.petition_id = petition_id
self.logger = CirqueLog.get_cirque_logger(self.__class__.__name__)
if rcp:
self.radio_command = 'ot-rcp'
else:
Expand All @@ -51,11 +54,14 @@ def open(self):
self.radio_fd = os.open(self.pipe_path_for_ncp, os.O_RDWR)
env = os.environ
env['PORT_OFFSET'] = str(self.petition_id * self.__THREAD_GROUP_SIZE)
self.radio_process = subprocess.Popen(
[self.radio_command, '{}'.format(self.node_id)],
command = [self.radio_command, '{}'.format(self.node_id)]
self.logger.info("-> Start virtual OpenThread Radio: command=%s, env=%s", command, env)
self.radio_process = TaskRunner.post_task(lambda:subprocess.Popen(
command,
env=env,
stdout=self.radio_fd,
stdin=self.radio_fd)
stdin=self.radio_fd,
stderr=subprocess.PIPE)).wait()

def close(self):
if self.radio_fd is not None:
Expand Down
3 changes: 3 additions & 0 deletions cirque/grpc/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import cirque.proto.service_pb2_grpc as service_pb2_grpc

from cirque.common.cirquelog import CirqueLog
from cirque.common.taskrunner import TaskRunner
from cirque.home.home import CirqueHome

logger = None
Expand Down Expand Up @@ -191,6 +192,7 @@ def __exit_handler():
global cirque_service
for home in cirque_service.homes.values():
home.destroy_home()
taskrunner.TaskRunner.stop()


class CirqueService(service_pb2_grpc.CirqueServiceServicer):
Expand Down Expand Up @@ -354,6 +356,7 @@ def serve(service_port=50051):
def main(service_port):
global logger
CirqueLog.setup_cirque_logger()
TaskRunner.start()
logger = CirqueLog.get_cirque_logger('grpc_service')
serve(service_port)

Expand Down
3 changes: 3 additions & 0 deletions cirque/restservice/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flask import Response

from cirque.common.cirquelog import CirqueLog
from cirque.common.taskrunner import TaskRunner
from cirque.home.home import CirqueHome

app = Flask(__name__)
Expand Down Expand Up @@ -119,8 +120,10 @@ def destroy_homes():


# becareful not to remove this part
atexit.register(lambda: TaskRunner.stop())
if service_mode:
atexit.register(destroy_homes)
TaskRunner.start()

if __name__ == '__main__':
app.run()

0 comments on commit 16e8668

Please sign in to comment.