From 3da60abee98808c07afb72230754958274c29922 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 31 Jan 2023 18:59:10 -0700 Subject: [PATCH] refactor permissions --- src/wormhole_mailbox_server/permission.py | 126 ++++++++++++++++++ src/wormhole_mailbox_server/server.py | 116 ++-------------- src/wormhole_mailbox_server/server_tap.py | 23 ++-- .../server_websocket.py | 9 +- .../test/test_config.py | 12 ++ .../test/test_server.py | 15 ++- .../test/test_service.py | 8 +- src/wormhole_mailbox_server/test/test_web.py | 13 +- 8 files changed, 190 insertions(+), 132 deletions(-) create mode 100644 src/wormhole_mailbox_server/permission.py diff --git a/src/wormhole_mailbox_server/permission.py b/src/wormhole_mailbox_server/permission.py new file mode 100644 index 0000000..aae5472 --- /dev/null +++ b/src/wormhole_mailbox_server/permission.py @@ -0,0 +1,126 @@ +import os +import base64 +import hashlib +from zope.interface import ( + Interface, + Attribute, + implementer, +) + + +class IPermission(Interface): + """ + A server-side method of granting permission to a client. + """ + name = Attribute("name") + + def get_welcome_data(): + """ + return a dict of information to include under the name of this + Permission granter (under "permission-required" in the Welcome) + """ + + def verify_permission(submit_permission): + """ + return a bool indicating if the submit_permission data is a valid + permission (or not) + """ + + +def create_permission_provider(kind): + """ + returns a permissions-provider + """ + if kind == "none": + return NoPermission + elif kind == "hashcash": + return HashcashPermission + raise ValueError( + "Unknown permission provider '{}'".format(kind) + ) + + +@implementer(IPermission) +class NoPermission(object): + """ + A no-op permission provider used to grant any client access (the + default). + """ + name = "none" + + def get_welcome_data(self): + return {} + + def verify_permission(self, submit_permission): + return True + + +@implementer(IPermission) +class HashcashPermission(object): + """ + A permission provider that generates a random 'resource' string + and checks a proof-of-work from the client. + """ + name = "hashcash" + + def __init__(self, bits=20): + self._bits = bits + + def get_welcome_data(self): + """ + Generate the data to include under this method's key in the + `permission-required` value of the welcome message. + + Should be called at most once per connection. + """ + self._hashcash_resource = base64.b64encode(os.urandom(8)).decode("utf8") + return { + "bits": self._bits, + "resource": self._hashcash_resource, + } + + def verify_permission(self, perms): + """ + :returns bool: an indication of whether the provided permissions + reply from a client is valid + """ + # XXX THINK do we need this whole method to be constant-time? + # (basically impossible if it's not even syntactially valid?) + stamp = perms.get("stamp", "") + fields = stamp.split(":") + if len(fields) != 7: + return False + vers, claimed_bits, date, resource, ext, rand, counter = fields + vers = int(vers) + if vers != 1: + return False + if resource != self._hashcash_resource: + return False + + claimed_bits = int(claimed_bits) + if claimed_bits < self._bits: + return False + + h = hashlib.sha1() + h.update(stamp.encode("utf8")) + measured_hash = h.digest() + if leading_zero_bits(measured_hash) < claimed_bits: + return False + return True + + +def leading_zero_bits(bytestring): + """ + :returns int: the number of leading zeros in the given byte-string + """ + measured_bits = 0 + for byte in bytestring: + bit = 1 << 7 + while bit: + if byte & bit: + return measured_bits + else: + measured_bits += 1 + bit = bit >> 1 + + diff --git a/src/wormhole_mailbox_server/server.py b/src/wormhole_mailbox_server/server.py index 7aef321..8a7534b 100644 --- a/src/wormhole_mailbox_server/server.py +++ b/src/wormhole_mailbox_server/server.py @@ -4,6 +4,7 @@ from collections import namedtuple from twisted.python import log from twisted.application import service +from .permission import create_permission_provider def generate_mailbox_id(): return base64.b32encode(os.urandom(8)).lower().strip(b"=").decode("ascii") @@ -552,102 +553,9 @@ def _shutdown(self): channel._shutdown() -def leading_zero_bits(bytestring): - """ - :returns int: the number of leading zeros in the given byte-string - """ - measured_bits = 0 - for byte in bytestring: - bit = 1 << 7 - while bit: - if byte & bit: - return measured_bits - else: - measured_bits += 1 - bit = bit >> 1 - - -class NoPermission(object): - """ - A no-op permission provider used to grant any client access (the - default). - """ - name = "none" - - def get_welcome_data(self): - return {} - - def verify_permission(self, submit_permission): - return True - - def is_passed(self): - return True - - -class HashcashPermission(object): - """ - A permission provider that generates a random 'resource' string - and checks a proof-of-work from the client. - """ - name = "hashcash" - - def __init__(self, bits=20): - self._bits = bits - self._passed = False - - def get_welcome_data(self): - """ - Generate the data to include under this method's key in the - `permission-required` value of the welcome message. - - Should be called at most once per connection. - """ - self._hashcash_resource = base64.b64encode(os.urandom(8)).decode("utf8") - return { - "bits": self._bits, - "resource": self._hashcash_resource, - } - - def is_passed(self): - """ - :returns bool: True if verify_permission has been called successfully - """ - return self._passed - - def verify_permission(self, perms): - """ - :returns bool: an indication of whether the provided permissions - reply from a client is valid - """ - # XXX THINK do we need this whole method to be constant-time? - # (basically impossible if it's not even syntactially valid?) - stamp = perms.get("stamp", "") - fields = stamp.split(":") - if len(fields) != 7: - return False - vers, claimed_bits, date, resource, ext, rand, counter = fields - vers = int(vers) - if vers != 1: - return False - if resource != self._hashcash_resource: - return False - - claimed_bits = int(claimed_bits) - if claimed_bits < self._bits: - return False - - h = hashlib.sha1() - h.update(stamp.encode("utf8")) - measured_hash = h.digest() - if leading_zero_bits(measured_hash) < claimed_bits: - return False - self._passed = True - return True - - class Server(service.MultiService): def __init__(self, db, allow_list, welcome, - blur_usage, usage_db=None, log_file=None, permissions="none"): + blur_usage, usage_db=None, log_file=None, permission_provider=None): service.MultiService.__init__(self) self._db = db self._allow_list = allow_list @@ -656,8 +564,8 @@ def __init__(self, db, allow_list, welcome, self._log_requests = blur_usage is None self._usage_db = usage_db self._log_file = log_file - self._permissions = permissions - assert self._permissions in ("none", "hashcash") + self._permission_provider = permission_provider + # XXX assert interface instead assert self._permissions in ("none", "hashcash") self._apps = {} def get_welcome(self): @@ -678,14 +586,7 @@ def get_permission_method(self): :returns IPermissionGranter: a method of permission """ - if self._permissions == "none": - return NoPermission() - elif self._permissions == "hashcash": - return HashcashPermission() - else: - raise ValueError( - 'Unknown permission "{}"'.format(self._permissions) - ) + return self._permission_provider() def get_log_requests(self): return self._log_requests @@ -801,7 +702,7 @@ def make_server(db, allow_list=True, advertise_version=None, signal_error=None, blur_usage=None, - permissions="none", + permission_provider=None, usage_db=None, log_file=None, welcome_motd=None, @@ -827,6 +728,9 @@ def make_server(db, allow_list=True, if signal_error: welcome["error"] = signal_error + if permission_provider is None: + permission_provider = create_permission_provider("none") + return Server(db, allow_list=allow_list, welcome=welcome, blur_usage=blur_usage, usage_db=usage_db, log_file=log_file, - permissions=permissions) + permission_provider=permission_provider) diff --git a/src/wormhole_mailbox_server/server_tap.py b/src/wormhole_mailbox_server/server_tap.py index 5ce4e22..da4ce7e 100644 --- a/src/wormhole_mailbox_server/server_tap.py +++ b/src/wormhole_mailbox_server/server_tap.py @@ -9,6 +9,7 @@ from .server import make_server from .web import make_web_server from .database import create_or_upgrade_channel_db, create_or_upgrade_usage_db +from .permission import create_permission_provider LONGDESC = """This plugin sets up a 'Mailbox' server for magic-wormhole. This service forwards short messages between clients, to perform key exchange @@ -92,16 +93,18 @@ def makeService(config, channel_db="relay.sqlite", reactor=reactor): log_file = (os.fdopen(int(config["log-fd"]), "w") if config["log-fd"] is not None else None) - server = make_server(channel_db, - allow_list=config["allow-list"], - advertise_version=config["advertise-version"], - signal_error=config["signal-error"], - blur_usage=config["blur-usage"], - permissions=config["permissions"], - usage_db=usage_db, - log_file=log_file, - welcome_motd=config["motd"], - ) + + server = make_server( + channel_db, + allow_list=config["allow-list"], + advertise_version=config["advertise-version"], + signal_error=config["signal-error"], + blur_usage=config["blur-usage"], + permission_provider=create_permission_provider(config.get("permissions", "none")), + usage_db=usage_db, + log_file=log_file, + welcome_motd=config["motd"], + ) server.setServiceParent(parent) rebooted = time.time() def expire(): diff --git a/src/wormhole_mailbox_server/server_websocket.py b/src/wormhole_mailbox_server/server_websocket.py index c3acd1c..aca4ac6 100644 --- a/src/wormhole_mailbox_server/server_websocket.py +++ b/src/wormhole_mailbox_server/server_websocket.py @@ -3,7 +3,8 @@ from twisted.internet import reactor from twisted.python import log from autobahn.twisted import websocket -from .server import CrowdedError, ReclaimedError, SidedMessage, NoPermission +from .server import CrowdedError, ReclaimedError, SidedMessage +from .permission import NoPermission from .util import dict_to_bytes, bytes_to_dict # The WebSocket allows the client to send "commands" to the server, and the @@ -110,6 +111,7 @@ def __init__(self): self._mailbox_id = None self._did_close = False self._permission = None + self._permission_passed = False def onConnect(self, request): rv = self.factory.server @@ -187,7 +189,7 @@ def handle_ping(self, msg): def handle_bind(self, msg, server_rx): # if demanding permission, but no permission yet .. error - if self._permission is not None and not self._permission.is_passed(): + if not isinstance(self._permission, NoPermission) and not self._permission_passed: raise Error("must submit-permission first") if self._app or self._side: @@ -205,7 +207,8 @@ def handle_bind(self, msg, server_rx): def handle_submit_permissions(self, msg, server_rx): if msg.get("method", None) != self._permission.name: raise Error("need permission method '{}'".format(self._permission.name)) - if not self._permission.verify_permission(msg): + self._permission_passed = self._permission.verify_permission(msg) + if not self._permission_passed: raise Error("submit-permission failed") def handle_list(self): diff --git a/src/wormhole_mailbox_server/test/test_config.py b/src/wormhole_mailbox_server/test/test_config.py index d83e382..f4ea550 100644 --- a/src/wormhole_mailbox_server/test/test_config.py +++ b/src/wormhole_mailbox_server/test/test_config.py @@ -21,6 +21,7 @@ def test_defaults(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_advertise_version(self): @@ -38,6 +39,7 @@ def test_advertise_version(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_blur(self): @@ -55,6 +57,7 @@ def test_blur(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_channel_db(self): @@ -72,6 +75,7 @@ def test_channel_db(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_disallow_list(self): @@ -89,6 +93,7 @@ def test_disallow_list(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_log_fd(self): @@ -106,6 +111,7 @@ def test_log_fd(self): "log-fd": 5, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_port(self): @@ -123,6 +129,7 @@ def test_port(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) o = server_tap.Options() @@ -139,6 +146,7 @@ def test_port(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_signal_error(self): @@ -156,6 +164,7 @@ def test_signal_error(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_usage_db(self): @@ -173,6 +182,7 @@ def test_usage_db(self): "log-fd": None, "websocket-protocol-options": [], "permissions": "none", + "external-permission": {}, }) def test_websocket_protocol_option_1(self): @@ -190,6 +200,7 @@ def test_websocket_protocol_option_1(self): "log-fd": None, "websocket-protocol-options": [("foo", "bar")], "permissions": "none", + "external-permission": {}, }) def test_websocket_protocol_option_2(self): @@ -211,6 +222,7 @@ def test_websocket_protocol_option_2(self): ("baz", [1, "buz"]), ], "permissions": "none", + "external-permission": {}, }) def test_websocket_protocol_option_errors(self): diff --git a/src/wormhole_mailbox_server/test/test_server.py b/src/wormhole_mailbox_server/test/test_server.py index ef81ae1..a897c0d 100644 --- a/src/wormhole_mailbox_server/test/test_server.py +++ b/src/wormhole_mailbox_server/test/test_server.py @@ -13,8 +13,13 @@ from autobahn.twisted.websocket import WebSocketClientProtocol from .common import ServerBase, _Util from ..server import (make_server, Usage, - SidedMessage, CrowdedError, AppNamespace, - HashcashPermission, NoPermission) + SidedMessage, CrowdedError, AppNamespace) +from ..permission import ( + create_permission_provider, + NoPermission, + HashcashPermission, +) + from ..server_websocket import WebSocketServerFactory from ..util import bytes_to_dict, dict_to_bytes @@ -649,7 +654,7 @@ class Permissions(unittest.TestCase): def test_hashcash_permission(self): db = create_channel_db(":memory:") - s = make_server(db, permissions="hashcash") + s = make_server(db, permission_provider=create_permission_provider("hashcash")) self.assertIsInstance( s.get_permission_method(), HashcashPermission @@ -657,7 +662,7 @@ def test_hashcash_permission(self): def test_no_permission(self): db = create_channel_db(":memory:") - s = make_server(db, permissions="none") + s = make_server(db, permission_provider=create_permission_provider("none")) self.assertIsInstance( s.get_permission_method(), NoPermission @@ -693,7 +698,7 @@ def test_submit_success(self): self.addCleanup(pump.stop) def create_proto(): - server = make_server(create_channel_db(":memory:"), permissions="hashcash") + server = make_server(create_channel_db(":memory:"), permission_provider=create_permission_provider("hashcash")) factory = WebSocketServerFactory("ws://127.0.0.1:1", server) addr = IPv4Address("TCP", "127.0.0.1", "0") proto = factory.buildProtocol(addr) diff --git a/src/wormhole_mailbox_server/test/test_service.py b/src/wormhole_mailbox_server/test/test_service.py index 3f19cc6..86962a3 100644 --- a/src/wormhole_mailbox_server/test/test_service.py +++ b/src/wormhole_mailbox_server/test/test_service.py @@ -2,6 +2,10 @@ from twisted.trial import unittest import mock from twisted.application.service import MultiService +from ..permission import ( + NoPermission, + HashcashPermission, +) from .. import server_tap class Service(unittest.TestCase): @@ -24,7 +28,7 @@ def test_defaults(self): signal_error=None, welcome_motd=None, blur_usage=None, - permissions="none", + permission_provider=NoPermission, usage_db=udb, log_file=None)]) self.assertEqual(mws.mock_calls, [mock.call(r, True, [])]) @@ -52,6 +56,6 @@ def test_log_fd(self): signal_error=None, welcome_motd=None, blur_usage=None, - permissions="none", + permission_provider=NoPermission, usage_db=udb, log_file=fd)]) diff --git a/src/wormhole_mailbox_server/test/test_web.py b/src/wormhole_mailbox_server/test/test_web.py index c62a3e2..0ab068e 100644 --- a/src/wormhole_mailbox_server/test/test_web.py +++ b/src/wormhole_mailbox_server/test/test_web.py @@ -10,6 +10,7 @@ from ..web import make_web_server from ..server import SidedMessage from ..database import create_or_upgrade_usage_db +from ..permission import create_permission_provider from .common import ServerBase, _Util from .ws_client import WSFactory @@ -765,7 +766,7 @@ def make_client(self): @inlineCallbacks def test_hashcash(self): - yield self._setup_relay(do_listen=True, permissions="hashcash") + yield self._setup_relay(do_listen=True, permission_provider=create_permission_provider("hashcash")) c = yield self.make_client() welcome = yield c.next_non_ack() self.assertIn( @@ -779,7 +780,7 @@ def test_hashcash(self): @inlineCallbacks def test_hashcash_invalid_fields(self): - yield self._setup_relay(do_listen=True, permissions="hashcash") + yield self._setup_relay(do_listen=True, permission_provider=create_permission_provider("hashcash")) c = yield self.make_client() yield c.next_non_ack() yield c.send("submit-permissions", method="hashcash", stamp="wrong") @@ -789,7 +790,7 @@ def test_hashcash_invalid_fields(self): @inlineCallbacks def test_hashcash_wrong_version(self): - yield self._setup_relay(do_listen=True, permissions="hashcash") + yield self._setup_relay(do_listen=True, permission_provider=create_permission_provider("hashcash")) c = yield self.make_client() yield c.next_non_ack() yield c.send("submit-permissions", method="hashcash", stamp="0:2:*:*:*:*:*") @@ -799,7 +800,7 @@ def test_hashcash_wrong_version(self): @inlineCallbacks def test_hashcash_wrong_resource(self): - yield self._setup_relay(do_listen=True, permissions="hashcash") + yield self._setup_relay(do_listen=True, permission_provider=create_permission_provider("hashcash")) c = yield self.make_client() yield c.next_non_ack() yield c.send("submit-permissions", method="hashcash", stamp="1:2:date:resource:*:*:*") @@ -809,7 +810,7 @@ def test_hashcash_wrong_resource(self): @inlineCallbacks def test_hashcash_correct(self): - yield self._setup_relay(do_listen=True, permissions="hashcash") + yield self._setup_relay(do_listen=True, permission_provider=create_permission_provider("hashcash")) if not hasattr(shutil, "which") or not shutil.which("hashcash"): raise unittest.SkipTest("no 'hashcash' binary installed") @@ -841,7 +842,7 @@ def test_hashcash_correct(self): @inlineCallbacks def test_hashcash_wrong_bits(self): - yield self._setup_relay(do_listen=True, permissions="hashcash") + yield self._setup_relay(do_listen=True, permission_provider=create_permission_provider("hashcash")) if not hasattr(shutil, "which") or not shutil.which("hashcash"): raise unittest.SkipTest("no 'hashcash' binary installed")