diff --git a/src/wormhole_mailbox_server/database.py b/src/wormhole_mailbox_server/database.py index 1db7d4e..fa99b75 100644 --- a/src/wormhole_mailbox_server/database.py +++ b/src/wormhole_mailbox_server/database.py @@ -14,9 +14,12 @@ def get_schema(name, version): return schema_bytes.decode("utf-8") def get_upgrader(name, new_version): - schema_bytes = resource_string("wormhole_mailbox_server", - "db-schemas/upgrade-%s-to-v%d.sql" % - (name, new_version)) + try: + schema_bytes = resource_string("wormhole_mailbox_server", + "db-schemas/upgrade-%s-to-v%d.sql" % + (name, new_version)) + except FileNotFoundError: + raise ValueError("no upgrader for %d" % new_version) return schema_bytes.decode("utf-8") CHANNELDB_TARGET_VERSION = 1 diff --git a/src/wormhole_mailbox_server/test/test_database.py b/src/wormhole_mailbox_server/test/test_database.py index 52adeeb..bfe9440 100644 --- a/src/wormhole_mailbox_server/test/test_database.py +++ b/src/wormhole_mailbox_server/test/test_database.py @@ -91,6 +91,24 @@ def test_upgrade(self): # debug with "diff -u _trial_temp/up.sql _trial_temp/new.sql" self.assertEqual(dbA_text, latest_text) + def test_upgrade_fails(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "upgrade.db") + self.assertNotEqual(USAGEDB_TARGET_VERSION, 1) + + # create an old-version DB in a file + db = _get_db(fn, "usage", 1) + rows = db.execute("SELECT * FROM version").fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["version"], 1) + del db + + # then upgrade the file to a too-new version, for which we have no + # upgrader + with self.assertRaises(DBError): + _get_db(fn, "usage", USAGEDB_TARGET_VERSION+1) + class CreateChannel(unittest.TestCase): def test_memory(self): db = database.create_channel_db(":memory:") diff --git a/src/wormhole_mailbox_server/test/test_web.py b/src/wormhole_mailbox_server/test/test_web.py index 6398091..a5dbc10 100644 --- a/src/wormhole_mailbox_server/test/test_web.py +++ b/src/wormhole_mailbox_server/test/test_web.py @@ -7,6 +7,7 @@ from twisted.internet.defer import inlineCallbacks, returnValue from ..web import make_web_server from ..server import SidedMessage +from ..database import create_or_upgrade_usage_db from .common import ServerBase, _Util from .ws_client import WSFactory @@ -90,8 +91,10 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): def setUp(self): self._lp = None self._clients = [] + self._usage_db = usage_db = create_or_upgrade_usage_db(":memory:") yield self._setup_relay(do_listen=True, - advertise_version="advertised.version") + advertise_version="advertised.version", + usage_db=usage_db) def tearDown(self): for c in self._clients: @@ -158,6 +161,36 @@ def test_bind(self): self.assertEqual(err["type"], "error") self.assertEqual(err["error"], "ping requires 'ping'") + @inlineCallbacks + def test_bind_with_client_version(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + + c1.send("bind", appid="appid", side="side", + client_version=("python", "1.2.3")) + yield c1.sync() + self.assertEqual(list(self._server._apps.keys()), ["appid"]) + v = self._usage_db.execute("SELECT * FROM `client_versions`").fetchall() + self.assertEqual(v[0]["app_id"], "appid") + self.assertEqual(v[0]["side"], "side") + self.assertEqual(v[0]["implementation"], "python") + self.assertEqual(v[0]["version"], "1.2.3") + + @inlineCallbacks + def test_bind_with_client_version_extra_junk(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + + c1.send("bind", appid="appid", side="side", + client_version=("python", "1.2.3", "extra ignore me")) + yield c1.sync() + self.assertEqual(list(self._server._apps.keys()), ["appid"]) + v = self._usage_db.execute("SELECT * FROM `client_versions`").fetchall() + self.assertEqual(v[0]["app_id"], "appid") + self.assertEqual(v[0]["side"], "side") + self.assertEqual(v[0]["implementation"], "python") + self.assertEqual(v[0]["version"], "1.2.3") + @inlineCallbacks def test_list(self): c1 = yield self.make_client()