diff --git a/src/client/client_node.py b/src/client/client_node.py index 54eaa38..969d8d0 100755 --- a/src/client/client_node.py +++ b/src/client/client_node.py @@ -29,6 +29,7 @@ def run(self): updates = self.conn.updates() for v in updates.values(): self.pub_man.publish(v) + self.conn.stop() def create_subscriber(self, topic, msg_type, trusted): namespace, msg_name = msg_type.split("/") diff --git a/src/client/connection.py b/src/client/connection.py index 04e0189..cc24f7a 100644 --- a/src/client/connection.py +++ b/src/client/connection.py @@ -4,66 +4,11 @@ import json import copy import struct -from twisted.internet import reactor -from twisted.internet.protocol import ReconnectingClientFactory -from autobahn.twisted.websocket import WebSocketClientProtocol -from autobahn.twisted.websocket import WebSocketClientFactory import rospy - - -class MMClient(WebSocketClientProtocol): - - client = None - updates = dict() - acknowledged = True - timer = threading.Timer - - def onConnect(self, reponse): - MMClient.client = self - MMClient.acknowledged = True - MMClient.timer = threading.Timer - - def onMessage(self, payload, is_binary): - if not is_binary: - data = json.loads(payload) - MMClient.updates[data["topic"]] = data - else: - if len(payload) == 1: - MMClient.acknowledged = True - MMClient.timer.cancel() - else: - decompressed = zlib.decompress(payload) - size = struct.unpack('=I', decompressed[:4]) - frmt = "%ds" % size[0] - unpacked = struct.unpack('=I' + frmt, decompressed) - data = json.loads(unpacked[1]) - MMClient.updates[data["topic"]] = data - - def onClose(self, wasClean, code, reason): - rospy.logwarn("WebSocket connection closed: {0}".format(reason)) - - @staticmethod - def timeout(): - MMClient.acknowledged = True - - @staticmethod - def send_message(payload, is_binary): - if not MMClient.client is None: - # rospy.loginfo(MMClient.acknowledged) - if MMClient.acknowledged: - MMClient.acknowledged = False - MMClient.client.sendMessage(payload, is_binary) - MMClient.timer = threading.Timer(1, MMClient.timeout) - MMClient.timer.start() - - -class ClientFactory(WebSocketClientFactory, ReconnectingClientFactory): - def clientConnectionFailed(self, connector, reason): - print "Connection Failed {} -- {}".format(connector, reason) - - def clientConnectionLost(self, connector, reason): - print "Connection Failed {} -- {}".format(connector, reason) - +import tornado.web +import tornado.websocket +import tornado.httpserver +import tornado.ioloop class Connection(threading.Thread): def __init__(self, host, port, name): @@ -72,27 +17,70 @@ def __init__(self, host, port, name): self.port = port self.name = name self.url = "ws://{}:{}/{}".format(host, port, name) - self.factory = ClientFactory(self.url, debug=False) - self.daemon = True + self.ioloop = tornado.ioloop.IOLoop.current() + self.connection = None + self.values = dict() + self.acknowledged = True + self.timer = threading.Timer def run(self): - self.factory.protocol = MMClient - reactor.connectTCP(self.host, self.port, self.factory) - reactor.run(installSignalHandlers=0) + tornado.websocket.websocket_connect( + self.url, + self.ioloop, + callback = self.on_connected, + on_message_callback = self.on_message) + self.ioloop.start() def stop(self): - reactor.stop() + self.ioloop.stop() - def send_message(self, data): + def send_message_cb(self, data): payload = json.dumps(data) frmt = "%ds" % len(payload) binary = struct.pack(frmt, payload) binLen = len(binary) binary = struct.pack('=I' + frmt, binLen, payload) compressed = zlib.compress(binary) - return MMClient.send_message(compressed, True) + if not self.connection is None: + # rospy.loginfo(self.acknowledged) + if self.acknowledged: + self.acknowledged = False + self.connection.write_message(compressed, True) + self.timer = threading.Timer(1, self.timeout) + self.timer.start() + + def send_message(self, data): + self.ioloop.add_callback(self.send_message_cb, data) def updates(self): - payloads = copy.copy(MMClient.updates) - MMClient.updates = dict() + payloads = copy.copy(self.values) + self.values = dict() return payloads + + def on_connected(self, res): + try: + self.connection = res.result() + except Exception, e: + print "Failed to connect: {}".format(e) + tornado.websocket.websocket_connect( + self.url, + self.ioloop, + callback = self.on_connected, + on_message_callback = self.on_message) + + + def on_message(self, payload): + if len(payload) == 1: + self.acknowledged = True + self.timer.cancel() + else: + decompressed = zlib.decompress(payload) + size = struct.unpack('=I', decompressed[:4]) + frmt = "%ds" % size[0] + unpacked = struct.unpack('=I' + frmt, decompressed) + data = json.loads(unpacked[1]) + self.values[data["topic"]] = data + + def timeout(self): + self.acknowledged = True + diff --git a/src/server/server_node.py b/src/server/server_node.py index 68fc130..6f7cc51 100755 --- a/src/server/server_node.py +++ b/src/server/server_node.py @@ -2,24 +2,36 @@ import ws import rospy -from twisted.internet import reactor -from autobahn.twisted.websocket import WebSocketServerFactory +import signal +import tornado.web +import tornado.websocket +import tornado.httpserver +import tornado.ioloop NODE_NAME = "jammi_server" +settings = {'debug': True} +app = tornado.web.Application([ + (r'/(.*)', ws.MMServerProtocol), + ], **settings) + +def sig_handler(sig, frame): + tornado.ioloop.IOLoop.instance().add_callback(shutdown) + +def shutdown(): + tornado.ioloop.IOLoop.instance().stop() def run_server(host, port): url = "ws://{}:{}".format(host, port) - factory = WebSocketServerFactory(url, debug=True) - factory.protocol = ws.MMServerProtocol - reactor.listenTCP(port, factory) - while not rospy.is_shutdown(): - reactor.iterate() - + http_server = tornado.httpserver.HTTPServer(app) + http_server.listen(port) + tornado.ioloop.IOLoop.instance().start() if __name__ == "__main__": rospy.init_node(NODE_NAME, anonymous=False) host = rospy.get_param("~host", "localhost") port = rospy.get_param("~port", 9000) + signal.signal(signal.SIGTERM, sig_handler) + signal.signal(signal.SIGINT, sig_handler) run_server(host, port) diff --git a/src/server/ws.py b/src/server/ws.py index 87d18eb..e7257e8 100644 --- a/src/server/ws.py +++ b/src/server/ws.py @@ -6,44 +6,42 @@ import common import struct from std_msgs.msg import Float32 -from autobahn.twisted.websocket import WebSocketServerProtocol +import tornado.web +import tornado.websocket +import tornado.httpserver +import tornado.ioloop -class MMServerProtocol(WebSocketServerProtocol): +class MMServerProtocol(tornado.websocket.WebSocketHandler): + lat_pubs = dict() - def __init__(self): - self.lat_pubs = dict() - - def onConnect(self, request): - name = request.path[1:] + def open(self, name): common.add_client(name, self) self.name_of_client = name - self.lat_pubs[name] = rospy.Publisher("/jammi/" + name + "/latency", - Float32, queue_size=2) + MMServerProtocol.lat_pubs[name] = rospy.Publisher("/jammi/" + name + + "/latency", Float32, queue_size=2) + print "Connected to: {}".format(name) + - def onMessage(self, payload, is_binary): - if is_binary: - try: - received_time = time.time() - decompressed = zlib.decompress(payload) - size = struct.unpack('=I', decompressed[:4]) - frmt = "%ds" % size[0] - unpacked = struct.unpack('=I' + frmt, decompressed) - msg = json.loads(unpacked[1]) - acknowledge = struct.pack('=b', 0) - common.get_client(msg["from"]).sendMessage(acknowledge, True) - latency = Float32() - latency.data = received_time - msg["stamp"] - self.lat_pubs[msg["from"]].publish(latency) - if msg["to"][0] == "*": - for name in common.clients.keys(): - if name != msg["from"]: - common.get_client(name).sendMessage(payload, True) - else: - for name in msg["to"]: - common.get_client(name).sendMessage(payload, True) - except KeyError: - pass + def on_message(self, message): + received_time = time.time() + decompressed = zlib.decompress(message) + size = struct.unpack('=I', decompressed[:4]) + frmt = "%ds" % size[0] + unpacked = struct.unpack('=I' + frmt, decompressed) + msg = json.loads(unpacked[1]) + acknowledge = struct.pack('=b', 0) + self.write_message(acknowledge, True) + latency = Float32() + latency.data = received_time - msg["stamp"] + MMServerProtocol.lat_pubs[msg["from"]].publish(latency) + if msg["to"][0] == "*": + for name in common.clients.keys(): + if name != msg["from"]: + common.get_client(name).write_message(message, True) + else: + for name in msg["to"]: + common.get_client(name).write_message(message, True) - def onClose(self, was_clean, code, reason): + def on_close(self): common.remove_client(self.name_of_client)