diff --git a/src/wormhole_mailbox_server/server.py b/src/wormhole_mailbox_server/server.py index fa8ab20..690c7c2 100644 --- a/src/wormhole_mailbox_server/server.py +++ b/src/wormhole_mailbox_server/server.py @@ -554,7 +554,7 @@ def _shutdown(self): class Server(service.MultiService): def __init__(self, db, allow_list, welcome, - blur_usage, usage_db=None, log_file=None, permission_provider=None): + blur_usage, usage_db=None, log_file=None, permission_providers=None): service.MultiService.__init__(self) self._db = db self._allow_list = allow_list @@ -563,12 +563,17 @@ def __init__(self, db, allow_list, welcome, self._log_requests = blur_usage is None self._usage_db = usage_db self._log_file = log_file - if permission_provider is not None: - if not IPermission.implementedBy(permission_provider): + if permission_providers is not None: + for perm in permission_providers: + if not IPermission.implementedBy(perm): + raise ValueError( + "All permission_providers must be IPermission" + ) + if not permission_providers: raise ValueError( - "permission_provider must be IPermission" + "Need at least one permission provider" ) - self._permission_provider = permission_provider + self._permission_providers = permission_providers self._apps = {} def get_welcome(self): @@ -578,18 +583,19 @@ def get_welcome(self): """ return self._welcome - def create_permission_provider(self): + def instantiate_permission_providers(self): """ - An object that encapsulates how to grant permission. - In prinicipal (in the protocol) we could support many, we - currently only support two (and only one at a time): 'none' - and 'hashcash'. + A list of objects that encapsulate ways to grant permission. - The `none` one does nothing. - - :returns IPermission: a method of granting permission + :returns list[IPermission]: at least one (but possibly more) + methods of granting permission """ - return self._permission_provider() + # Since providers can have state, we need to instantiate them + # once for each connection + return [ + provider() + for provider in self._permission_providers + ] def get_log_requests(self): return self._log_requests @@ -705,7 +711,7 @@ def make_server(db, allow_list=True, advertise_version=None, signal_error=None, blur_usage=None, - permission_provider=None, + permission_providers=None, usage_db=None, log_file=None, welcome_motd=None, @@ -731,9 +737,11 @@ def make_server(db, allow_list=True, if signal_error: welcome["error"] = signal_error - if permission_provider is None: - permission_provider = create_permission_provider("none") + if permission_providers is None: + permission_providers = [ + create_permission_provider("none") + ] return Server(db, allow_list=allow_list, welcome=welcome, blur_usage=blur_usage, usage_db=usage_db, log_file=log_file, - permission_provider=permission_provider) + permission_providers=permission_providers) diff --git a/src/wormhole_mailbox_server/server_tap.py b/src/wormhole_mailbox_server/server_tap.py index da4ce7e..cd39338 100644 --- a/src/wormhole_mailbox_server/server_tap.py +++ b/src/wormhole_mailbox_server/server_tap.py @@ -31,7 +31,7 @@ class Options(usage.Options): ("advertise-version", None, None, "version to recommend to clients"), ("signal-error", None, None, "force all clients to fail with a message"), ("motd", None, None, "Send a Message of the Day in the welcome"), - ("permissions", None, "none", + ("permissions", None, None, "demand permissions ({})".format(", ".join(valid_permissions))), ] optFlags = [ @@ -42,6 +42,7 @@ def __init__(self): super(Options, self).__init__() self["websocket-protocol-options"] = [] self["allow-list"] = True + self["permissions"] = set() def opt_permissions(self, arg): if arg not in valid_permissions: @@ -50,7 +51,7 @@ def opt_permissions(self, arg): ", ".join(valid_permissions) ) ) - self["permissions"] = arg + self["permissions"].add(arg) def opt_disallow_list(self): self["allow-list"] = False @@ -94,13 +95,19 @@ def makeService(config, channel_db="relay.sqlite", reactor=reactor): if config["log-fd"] is not None else None) + # if the user specified any permissions at all, then we use only + # that set. Otherwise, we use the set ["none"] + permissions = [ + create_permission_provider(perm_name) + for perm_name in (config["permissions"] or ["none"]) + ] server = make_server( channel_db, allow_list=config["allow-list"], advertise_version=config["advertise-version"], signal_error=config["signal-error"], blur_usage=config["blur-usage"], - permission_provider=create_permission_provider(config.get("permissions", "none")), + permission_providers=permissions, usage_db=usage_db, log_file=log_file, welcome_motd=config["motd"], diff --git a/src/wormhole_mailbox_server/server_websocket.py b/src/wormhole_mailbox_server/server_websocket.py index 66ae8ab..cc315e6 100644 --- a/src/wormhole_mailbox_server/server_websocket.py +++ b/src/wormhole_mailbox_server/server_websocket.py @@ -110,14 +110,14 @@ def __init__(self): self._mailbox = None self._mailbox_id = None self._did_close = False - self._permission = None + self._permissions = None self._permission_passed = False def onConnect(self, request): rv = self.factory.server if rv.get_log_requests(): log.msg("ws client connecting: %s" % (request.peer,)) - self._permission = rv.create_permission_provider() + self._permissions = rv.instantiate_permission_providers() self._reactor = self.factory.reactor def _generate_welcome(self): @@ -130,10 +130,17 @@ def _generate_welcome(self): rv = self.factory.server static_welcome = rv.get_welcome() - if not isinstance(self._permission, NoPermission): + permissions = [ + perm.name + for perm in self._permissions + if not isinstance(perm, NoPermission) + ] + if permissions: welcome = { "permission-required": { - self._permission.name: self._permission.get_welcome_data(), + perm.name: perm.get_welcome_data() + for perm in self._permissions + if not isinstance(perm, NoPermission) } } welcome.update(static_welcome) @@ -189,7 +196,10 @@ def handle_ping(self, msg): def handle_bind(self, msg, server_rx): # if demanding permission, but no permission yet .. error - if not isinstance(self._permission, NoPermission) and not self._permission_passed: + # unless there's a "NoPermission" option, then we pass + if any(isinstance(p, NoPermission) for p in self._permissions): + self._permission_passed = True + if not self._permission_passed: raise Error("must submit-permission first") if self._app or self._side: