Skip to content

Commit

Permalink
Merge branch 'record-client-versions'
Browse files Browse the repository at this point in the history
  • Loading branch information
warner committed Jun 24, 2018
2 parents 35774ad + 59f3ec9 commit 7d90055
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 58 deletions.
66 changes: 41 additions & 25 deletions src/wormhole_mailbox_server/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import unicode_literals
import os
import os, shutil
import sqlite3
import tempfile
from pkg_resources import resource_string
Expand All @@ -13,12 +13,21 @@ def get_schema(name, version):
"db-schemas/%s-v%d.sql" % (name, version))
return schema_bytes.decode("utf-8")

## def get_upgrader(new_version):
## schema_bytes = resource_string("wormhole_transit_relay",
## "db-schemas/upgrade-to-v%d.sql" % new_version)
## return schema_bytes.decode("utf-8")
def get_upgrader(name, new_version):
try:
FileNotFoundError # py3
except NameError:
FileNotFoundError = EnvironmentError # py2
try:
schema_bytes = resource_string("wormhole_mailbox_server",
"db-schemas/upgrade-%s-to-v%d.sql" %
(name, new_version))
except FileNotFoundError:
raise ValueError("no upgrader for %d" % new_version)
return schema_bytes.decode("utf-8")

TARGET_VERSION = 1
CHANNELDB_TARGET_VERSION = 1
USAGEDB_TARGET_VERSION = 2

def dict_factory(cursor, row):
d = {}
Expand Down Expand Up @@ -81,7 +90,7 @@ def _atomic_create_and_initialize_db(dbfile, name, target_version):
os.rename(temp_dbfile, dbfile)
return _open_db_connection(dbfile)

def _get_db(dbfile, name, target_version=TARGET_VERSION):
def _get_db(dbfile, name, target_version):
"""Open or create the given db file. The parent directory must exist.
Returns the db connection object, or raises DBError.
"""
Expand All @@ -95,31 +104,36 @@ def _get_db(dbfile, name, target_version=TARGET_VERSION):

version = db.execute("SELECT version FROM version").fetchone()["version"]

## while version < target_version:
## log.msg(" need to upgrade from %s to %s" % (version, target_version))
## try:
## upgrader = get_upgrader(version+1)
## except ValueError: # ResourceError??
## log.msg(" unable to upgrade %s to %s" % (version, version+1))
## raise DBError("Unable to upgrade %s to version %s, left at %s"
## % (dbfile, version+1, version))
## log.msg(" executing upgrader v%s->v%s" % (version, version+1))
## db.executescript(upgrader)
## db.commit()
## version = version+1
if version < target_version and dbfile != ":memory:":
backup_fn = "%s-backup-v%d" % (dbfile, version)
log.msg(" storing backup of v%d db in %s" % (version, backup_fn))
shutil.copy(dbfile, backup_fn)

while version < target_version:
log.msg(" need to upgrade from %s to %s" % (version, target_version))
try:
upgrader = get_upgrader(name, version+1)
except ValueError:
log.msg(" unable to upgrade %s to %s" % (version, version+1))
raise DBError("Unable to upgrade %s to version %s, left at %s"
% (dbfile, version+1, version))
log.msg(" executing upgrader v%s->v%s" % (version, version+1))
db.executescript(upgrader)
db.commit()
version = version+1

if version != target_version:
raise DBError("Unable to handle db version %s" % version)

return db

def create_or_upgrade_channel_db(dbfile):
return _get_db(dbfile, "channel")
return _get_db(dbfile, "channel", CHANNELDB_TARGET_VERSION)

def create_or_upgrade_usage_db(dbfile):
if dbfile is None:
return None
return _get_db(dbfile, "usage")
return _get_db(dbfile, "usage", USAGEDB_TARGET_VERSION)

class DBDoesntExist(Exception):
pass
Expand All @@ -140,21 +154,23 @@ def create_channel_db(dbfile):

if dbfile == ":memory:":
db = _open_db_connection(dbfile)
_initialize_db_schema(db, "channel", TARGET_VERSION)
_initialize_db_schema(db, "channel", CHANNELDB_TARGET_VERSION)
elif os.path.exists(dbfile):
raise DBAlreadyExists()
else:
db = _atomic_create_and_initialize_db(dbfile, "channel", TARGET_VERSION)
db = _atomic_create_and_initialize_db(dbfile, "channel",
CHANNELDB_TARGET_VERSION)
return db

def create_usage_db(dbfile):
if dbfile == ":memory:":
db = _open_db_connection(dbfile)
_initialize_db_schema(db, "usage", TARGET_VERSION)
_initialize_db_schema(db, "usage", USAGEDB_TARGET_VERSION)
elif os.path.exists(dbfile):
raise DBAlreadyExists()
else:
db = _atomic_create_and_initialize_db(dbfile, "usage", TARGET_VERSION)
db = _atomic_create_and_initialize_db(dbfile, "usage",
USAGEDB_TARGET_VERSION)
return db

def dump_db(db):
Expand Down
15 changes: 15 additions & 0 deletions src/wormhole_mailbox_server/db-schemas/upgrade-usage-to-v2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
CREATE TABLE `client_versions`
(
`app_id` VARCHAR,
`side` VARCHAR, -- for deduplication of reconnects
`connect_time` INTEGER, -- seconds since epoch, rounded to "blur time"
-- the client sends us a 'client_version' tuple of (implementation, version)
-- the Python client sends e.g. ("python", "0.11.0")
`implementation` VARCHAR,
`version` VARCHAR
);
CREATE INDEX `client_versions_time_idx` on `client_versions` (`connect_time`);
CREATE INDEX `client_versions_appid_time_idx` on `client_versions` (`app_id`, `connect_time`);

DELETE FROM `version`;
INSERT INTO `version` (`version`) VALUES (2);
2 changes: 1 addition & 1 deletion src/wormhole_mailbox_server/db-schemas/usage-v1.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
CREATE TABLE `version`
(
`version` INTEGER -- contains one row, set to 1
`version` INTEGER -- contains one row
);

CREATE TABLE `current`
Expand Down
61 changes: 61 additions & 0 deletions src/wormhole_mailbox_server/db-schemas/usage-v2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
CREATE TABLE `version`
(
`version` INTEGER -- contains one row
);

CREATE TABLE `current`
(
`rebooted` INTEGER, -- seconds since epoch of most recent reboot
`updated` INTEGER, -- when `current` was last updated
`blur_time` INTEGER, -- `started` is rounded to this, or None
`connections_websocket` INTEGER -- number of live clients via websocket
);

-- one row is created each time a nameplate is retired
CREATE TABLE `nameplates`
(
`app_id` VARCHAR,
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`total_time` INTEGER, -- seconds from open to last close/prune
`result` VARCHAR -- happy, lonely, pruney, crowded
-- nameplate moods:
-- "happy": two sides open and close
-- "lonely": one side opens and closes (no response from 2nd side)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `started`);

-- one row is created each time a mailbox is retired
CREATE TABLE `mailboxes`
(
`app_id` VARCHAR,
`for_nameplate` BOOLEAN, -- allocated for a nameplate, not standalone
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`total_time` INTEGER, -- seconds from open to last close
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`result` VARCHAR -- happy, scary, lonely, errory, pruney
-- rendezvous moods:
-- "happy": both sides close with mood=happy
-- "scary": any side closes with mood=scary (bad MAC, probably wrong pw)
-- "lonely": any side closes with mood=lonely (no response from 2nd side)
-- "errory": any side closes with mood=errory (other errors)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `started`);
CREATE INDEX `mailboxes_result_idx` ON `mailboxes` (`result`);

CREATE TABLE `client_versions`
(
`app_id` VARCHAR,
`side` VARCHAR, -- for deduplication of reconnects
`connect_time` INTEGER, -- seconds since epoch, rounded to "blur time"
-- the client sends us a 'client_version' tuple of (implementation, version)
-- the Python client sends e.g. ("python", "0.11.0")
`implementation` VARCHAR,
`version` VARCHAR
);
CREATE INDEX `client_versions_time_idx` on `client_versions` (`connect_time`);
CREATE INDEX `client_versions_appid_time_idx` on `client_versions` (`app_id`, `connect_time`);
13 changes: 13 additions & 0 deletions src/wormhole_mailbox_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,19 @@ def __init__(self, db, usage_db, blur_usage, log_requests, app_id,
self._mailboxes = {}
self._allow_list = allow_list

def log_client_version(self, server_rx, side, client_version):
if self._blur_usage:
server_rx = self._blur_usage * (server_rx // self._blur_usage)
implementation = client_version[0]
version = client_version[1]
self._usage_db.execute("INSERT INTO `client_versions`"
" (`app_id`, `side`, `connect_time`,"
" `implementation`, `version`)"
" VALUES(?,?,?,?,?)",
(self._app_id, side, server_rx,
implementation, version))
self._usage_db.commit()

def get_nameplate_ids(self):
if not self._allow_list:
return []
Expand Down
7 changes: 5 additions & 2 deletions src/wormhole_mailbox_server/server_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def onMessage(self, payload, isBinary):
if mtype == "ping":
return self.handle_ping(msg)
if mtype == "bind":
return self.handle_bind(msg)
return self.handle_bind(msg, server_rx)

if not self._app:
raise Error("must bind first")
Expand Down Expand Up @@ -161,7 +161,7 @@ def handle_ping(self, msg):
raise Error("ping requires 'ping'")
self.send("pong", pong=msg["ping"])

def handle_bind(self, msg):
def handle_bind(self, msg, server_rx):
if self._app or self._side:
raise Error("already bound")
if "appid" not in msg:
Expand All @@ -170,6 +170,9 @@ def handle_bind(self, msg):
raise Error("bind requires 'side'")
self._app = self.factory.server.get_app(msg["appid"])
self._side = msg["side"]
client_version = msg.get("client_version") # ("python", "0.xyz")
if client_version:
self._app.log_client_version(server_rx, self._side, client_version)


def handle_list(self):
Expand Down
72 changes: 43 additions & 29 deletions src/wormhole_mailbox_server/test/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,40 @@
from twisted.python import filepath
from twisted.trial import unittest
from .. import database
from ..database import _get_db, TARGET_VERSION, dump_db, DBError
from ..database import (CHANNELDB_TARGET_VERSION, USAGEDB_TARGET_VERSION,
_get_db, dump_db, DBError)

class Get(unittest.TestCase):
def test_create_default(self):
db_url = ":memory:"
db = _get_db(db_url, "channel")
db = _get_db(db_url, "channel", CHANNELDB_TARGET_VERSION)
rows = db.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], TARGET_VERSION)
self.assertEqual(rows[0]["version"], CHANNELDB_TARGET_VERSION)

def test_open_existing_file(self):
basedir = self.mktemp()
os.mkdir(basedir)
fn = os.path.join(basedir, "normal.db")
db = _get_db(fn, "channel")
db = _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
rows = db.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], TARGET_VERSION)
db2 = _get_db(fn, "channel")
self.assertEqual(rows[0]["version"], CHANNELDB_TARGET_VERSION)
db2 = _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
rows = db2.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], TARGET_VERSION)
self.assertEqual(rows[0]["version"], CHANNELDB_TARGET_VERSION)

def test_open_bad_version(self):
basedir = self.mktemp()
os.mkdir(basedir)
fn = os.path.join(basedir, "old.db")
db = _get_db(fn, "channel")
db = _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
db.execute("UPDATE version SET version=999")
db.commit()

with self.assertRaises(DBError) as e:
_get_db(fn, "channel")
_get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
self.assertIn("Unable to handle db version 999", str(e.exception))

def test_open_corrupt(self):
Expand All @@ -45,55 +46,68 @@ def test_open_corrupt(self):
with open(fn, "wb") as f:
f.write(b"I am not a database")
with self.assertRaises(DBError) as e:
_get_db(fn, "channel")
_get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
self.assertIn("not a database", str(e.exception))

def test_failed_create_allows_subsequent_create(self):
patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
dbfile = filepath.FilePath(self.mktemp())
self.assertRaises(Exception, lambda: _get_db(dbfile.path))
patch.restore()
_get_db(dbfile.path, "channel")
_get_db(dbfile.path, "channel", CHANNELDB_TARGET_VERSION)

def OFF_test_upgrade(self): # disabled until we add a v2 schema
def test_upgrade(self):
basedir = self.mktemp()
os.mkdir(basedir)
fn = os.path.join(basedir, "upgrade.db")
self.assertNotEqual(TARGET_VERSION, 2)
self.assertNotEqual(USAGEDB_TARGET_VERSION, 1)

# create an old-version DB in a file
db = _get_db(fn, "channel", 2)
db = _get_db(fn, "usage", 1)
rows = db.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], 2)
self.assertEqual(rows[0]["version"], 1)
del db

# then upgrade the file to the latest version
dbA = _get_db(fn, "channel", TARGET_VERSION)
dbA = _get_db(fn, "usage", USAGEDB_TARGET_VERSION)
rows = dbA.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], TARGET_VERSION)
self.assertEqual(rows[0]["version"], USAGEDB_TARGET_VERSION)
dbA_text = dump_db(dbA)
del dbA

# make sure the upgrades got committed to disk
dbB = _get_db(fn, "channel", TARGET_VERSION)
dbB = _get_db(fn, "usage", USAGEDB_TARGET_VERSION)
dbB_text = dump_db(dbB)
del dbB
self.assertEqual(dbA_text, dbB_text)

# The upgraded schema should be equivalent to that of a new DB.
# However a text dump will differ because ALTER TABLE always appends
# the new column to the end of a table, whereas our schema puts it
# somewhere in the middle (wherever it fits naturally). Also ALTER
# TABLE doesn't include comments.
if False:
latest_db = _get_db(":memory:", "channel", TARGET_VERSION)
latest_text = dump_db(latest_db)
with open("up.sql","w") as f: f.write(dbA_text)
with open("new.sql","w") as f: f.write(latest_text)
# check with "diff -u _trial_temp/up.sql _trial_temp/new.sql"
self.assertEqual(dbA_text, latest_text)
latest_db = _get_db(":memory:", "usage", USAGEDB_TARGET_VERSION)
latest_text = dump_db(latest_db)
with open("up.sql","w") as f: f.write(dbA_text)
with open("new.sql","w") as f: f.write(latest_text)
# debug with "diff -u _trial_temp/up.sql _trial_temp/new.sql"
self.assertEqual(dbA_text, latest_text)

def test_upgrade_fails(self):
basedir = self.mktemp()
os.mkdir(basedir)
fn = os.path.join(basedir, "upgrade.db")
self.assertNotEqual(USAGEDB_TARGET_VERSION, 1)

# create an old-version DB in a file
db = _get_db(fn, "usage", 1)
rows = db.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], 1)
del db

# then upgrade the file to a too-new version, for which we have no
# upgrader
with self.assertRaises(DBError):
_get_db(fn, "usage", USAGEDB_TARGET_VERSION+1)

class CreateChannel(unittest.TestCase):
def test_memory(self):
Expand Down
Loading

0 comments on commit 7d90055

Please sign in to comment.