From 07791edc165a1f00f95549c8bdcaef0f55c854a0 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:14:39 -0700 Subject: [PATCH] Validation service scaffold (#320) * Scaffold MLS server * Update go.mod * Fix missing argument * Add unsaved file * Lint * Working end-to-end * Lint * Add new push action * Address review comments * Change method casing * Change casing of server options * Change casing of validation options * Remove unused function * Remove double pointer * Make private again --- .github/workflows/push-mls.yml | 34 ++++ cmd/xmtpd/main.go | 7 + go.mod | 19 +- go.sum | 58 +++--- pkg/api/config.go | 12 +- pkg/api/message/v3/service.go | 90 +++++++-- pkg/api/server.go | 6 +- .../mls/20231023050806_init-schema.down.sql | 8 + .../mls/20231023050806_init-schema.up.sql | 44 +++++ pkg/migrations/mls/migrations.go | 18 ++ pkg/mlsstore/models.go | 24 +++ pkg/mlsstore/store.go | 157 +++++++++++++++- pkg/mlsstore/store_test.go | 171 ++++++++++++++++++ pkg/mlsvalidate/config.go | 5 + pkg/mlsvalidate/service.go | 107 +++++++++++ pkg/mlsvalidate/service_test.go | 63 +++++++ pkg/server/options.go | 19 +- pkg/server/server.go | 51 +++++- pkg/testing/store.go | 15 ++ 19 files changed, 831 insertions(+), 77 deletions(-) create mode 100644 .github/workflows/push-mls.yml create mode 100644 pkg/migrations/mls/20231023050806_init-schema.down.sql create mode 100644 pkg/migrations/mls/20231023050806_init-schema.up.sql create mode 100644 pkg/migrations/mls/migrations.go create mode 100644 pkg/mlsstore/models.go create mode 100644 pkg/mlsstore/store_test.go create mode 100644 pkg/mlsvalidate/config.go create mode 100644 pkg/mlsvalidate/service.go create mode 100644 pkg/mlsvalidate/service_test.go diff --git a/.github/workflows/push-mls.yml b/.github/workflows/push-mls.yml new file mode 100644 index 00000000..8ad7f1e5 --- /dev/null +++ b/.github/workflows/push-mls.yml @@ -0,0 +1,34 @@ +name: Deploy Nodes +on: + push: + branches: + - mls +jobs: + deploy: + concurrency: main + runs-on: ubuntu-latest + steps: + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + + - name: Login to Docker Hub + uses: docker/login-action@v1 + with: + username: xmtpeng + password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} + + - name: Git Checkout + uses: actions/checkout@v3 + + - uses: actions/setup-go@v3 + with: + go-version-file: go.mod + + - name: Push + id: push + run: | + export DOCKER_IMAGE_TAG=mls + IMAGE_TO_DEPLOY=xmtp/node-go@$(dev/docker/build) + echo Successfully pushed $IMAGE_TO_DEPLOY + echo "docker_image=${IMAGE_TO_DEPLOY}" >> $GITHUB_OUTPUT diff --git a/cmd/xmtpd/main.go b/cmd/xmtpd/main.go index cd948fff..c26364a5 100644 --- a/cmd/xmtpd/main.go +++ b/cmd/xmtpd/main.go @@ -92,6 +92,13 @@ func main() { return } + if options.CreateMlsMigration != "" && options.MLSStore.DbConnectionString != "" { + if err := server.CreateMlsMigration(options.CreateMlsMigration, options.MLSStore.DbConnectionString, options.WaitForDB, options.MLSStore.ReadTimeout, options.MLSStore.WriteTimeout, options.Store.MaxOpenConns); err != nil { + log.Fatal("creating authz db migration", zap.Error(err)) + } + return + } + if options.Tracing.Enable { log.Info("starting tracer") tracing.Start(Commit, utils.Logger()) diff --git a/go.mod b/go.mod index 026e3958..e86b30f8 100644 --- a/go.mod +++ b/go.mod @@ -25,12 +25,12 @@ require ( github.com/prometheus/client_golang v1.14.0 github.com/stretchr/testify v1.8.4 github.com/swaggest/swgui v1.6.2 - github.com/uptrace/bun v1.1.3 - github.com/uptrace/bun/dialect/pgdialect v1.1.3 - github.com/uptrace/bun/driver/pgdriver v1.1.3 + github.com/uptrace/bun v1.1.16 + github.com/uptrace/bun/dialect/pgdialect v1.1.16 + github.com/uptrace/bun/driver/pgdriver v1.1.16 github.com/waku-org/go-waku v0.8.0 github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 - github.com/xmtp/proto/v3 v3.30.0 + github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4 github.com/yoheimuta/protolint v0.39.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.3.0 @@ -108,7 +108,7 @@ require ( github.com/libp2p/go-yamux/v4 v4.0.1 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect - github.com/mattn/go-colorable v0.1.11 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/miekg/dns v1.1.55 // indirect @@ -148,6 +148,7 @@ require ( github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect github.com/shurcooL/httpgzip v0.0.0-20190720172056-320755c1c1b0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a // indirect github.com/tinylib/msgp v1.1.2 // indirect github.com/tklauser/go-sysconf v0.3.5 // indirect @@ -167,12 +168,12 @@ require ( go.uber.org/dig v1.17.0 // indirect go.uber.org/fx v1.20.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.12.0 // indirect + golang.org/x/crypto v0.13.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect golang.org/x/tools v0.12.1-0.20230818130535-1517d1a3ba60 // indirect golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect @@ -181,7 +182,7 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.2.1 // indirect - mellium.im/sasl v0.2.1 // indirect + mellium.im/sasl v0.3.1 // indirect ) // From node-go diff --git a/go.sum b/go.sum index 165565fe..506a1663 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,7 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20191024131854-af6fa24be0db/go.mod h1:VTxUBvSJ3s3eHAg65PNgrsn5BtqCRPdmyXh6rAfdxN0= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= @@ -298,6 +299,7 @@ github.com/go-ldap/ldap/v3 v3.1.3/go.mod h1:3rbOH3jRS2u6jg2rJnKAMLE/xQyCKIveG2Sa github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E= @@ -633,16 +635,19 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jsternberg/zap-logfmt v1.0.0/go.mod h1:uvPs/4X51zdkcm5jXl5SYoN+4RK21K8mysFmDaM/h+o= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/jwilder/encoding v0.0.0-20170811194829-b4e1701a28ef/go.mod h1:Ct9fl0F6iIOGgxJ5npU/IUOhOhqlVrGjyIZc8/MagT0= github.com/karalabe/usb v0.0.2/go.mod h1:Od972xHfMJowv7NGVDiWVxk2zxnWgjLlJzE+F4F7AGU= @@ -733,8 +738,9 @@ github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVc github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149/go.mod h1:31jz6HNzdxOmlERGGEc4v/dMssOfmp2p5bT/okiKFFc= github.com/mattn/go-ieproxy v0.0.0-20190702010315-6dee0af9227d/go.mod h1:31jz6HNzdxOmlERGGEc4v/dMssOfmp2p5bT/okiKFFc= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= @@ -747,6 +753,7 @@ github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcME github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= @@ -803,6 +810,7 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/mr-tron/base58 v1.1.2/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= @@ -834,6 +842,7 @@ github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/n github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/naoina/go-stringutil v0.1.0/go.mod h1:XJ2SJL9jCtBh+P9q5btrd/Ylo8XwT/h1USek5+NqSA0= github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E= @@ -965,6 +974,7 @@ github.com/retailnext/hllpp v1.0.1-0.20180308014038-101a6d2f8b52/go.mod h1:RDpi1 github.com/rjeczalik/notify v0.9.1/go.mod h1:rKwnCoCGeuQnwBtTSPL9Dad03Vh2n40ePRrjvIXnJho= github.com/rjeczalik/notify v0.9.3 h1:6rJAzHTGKXGj76sbRgDiDcYj/HniypXmSJo1SWakZeY= github.com/rjeczalik/notify v0.9.3/go.mod h1:gF3zSOrafR9DQEWSE8TjfI9NkooDxbyT4UgRGKZA0lc= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= @@ -1037,7 +1047,9 @@ github.com/status-im/keycard-go v0.0.0-20190316090335-8537d3370df4/go.mod h1:RZL github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= @@ -1046,7 +1058,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/swaggest/swgui v1.6.2 h1:DR+ioYt11YrXMaEmLcgaOEFSZ/8QW30uYYE/Ck41cPA= @@ -1086,12 +1100,12 @@ github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef h1:wHSqTBrZ github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef/go.mod h1:sJ5fKU0s6JVwZjjcUEX2zFOnvq0ASQ2K9Zr6cf67kNs= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/uptrace/bun v1.1.3 h1:v62tsUyKjVCR5q7J49uckM6CVVTqMO26aV73F3G6RFk= -github.com/uptrace/bun v1.1.3/go.mod h1:aQbKvxs7/n9MMef/b8lYOh5Rwlo4Jd5A31E4HlYNqSc= -github.com/uptrace/bun/dialect/pgdialect v1.1.3 h1:EMRCC98YKSpo/EXyujsr+5v0PKYkRE0rwxJKKEcrOuE= -github.com/uptrace/bun/dialect/pgdialect v1.1.3/go.mod h1:2GJogfkVHmCKxt6N88vRbJNSUV5wfPym/rp6N25dShc= -github.com/uptrace/bun/driver/pgdriver v1.1.3 h1:WWxEfGnJQCXgODtjU37E+XWEVvCGwvs2fRgCYFqmKAY= -github.com/uptrace/bun/driver/pgdriver v1.1.3/go.mod h1:D7tTNXLIR9udcf/Dm9W+x1qvY+GDCkYVIRLgQyMElCY= +github.com/uptrace/bun v1.1.16 h1:cn9cgEMFwcyYRsQLfxCRMUxyK1WaHwOVrR3TvzEFZ/A= +github.com/uptrace/bun v1.1.16/go.mod h1:7HnsMRRvpLFUcquJxp22JO8PsWKpFQO/gNXqqsuGWg8= +github.com/uptrace/bun/dialect/pgdialect v1.1.16 h1:eUPZ+YCJ69BA+W1X1ZmpOJSkv1oYtinr0zCXf7zCo5g= +github.com/uptrace/bun/dialect/pgdialect v1.1.16/go.mod h1:KQjfx/r6JM0OXfbv0rFrxAbdkPD7idK8VitnjIV9fZI= +github.com/uptrace/bun/driver/pgdriver v1.1.16 h1:b/NiSXk6Ldw7KLfMLbOqIkm4odHd7QiNOCPLqPFJjK4= +github.com/uptrace/bun/driver/pgdriver v1.1.16/go.mod h1:Rmfbc+7lx1z/umjMyAxkOHK81LgnGj71XC5YpA6k1vU= github.com/urfave/cli v1.22.2 h1:gsqYFH8bb9ekPA12kRo0hfjngWQjkJPlN9R0N78BoUo= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= @@ -1140,16 +1154,10 @@ github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0 github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 h1:wzUffJGCTBGXIDyNU+1UBu1fn2Nzo+OQzM1pLrheh58= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3/go.mod h1:bJREWk+NDnZYjgLQdAi8SUWuq/5pkMme4GqiffEhUF4= -github.com/xmtp/proto/v3 v3.27.0 h1:G70006UEffkCmWvp9G/7Dywosj1sLm9StR5HWEb891U= -github.com/xmtp/proto/v3 v3.27.0/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= -github.com/xmtp/proto/v3 v3.29.1-0.20231019020501-b49bc6ffb5eb h1:q2lR64lGFehm8m0FtcdRDMeH8MlkMyU4sz235+Ufq9E= -github.com/xmtp/proto/v3 v3.29.1-0.20231019020501-b49bc6ffb5eb/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= -github.com/xmtp/proto/v3 v3.29.1-0.20231019022514-1cc4b0d5a51a h1:kDEPyzhqQO9YdRAfvl21ysitvzWjdu4Ai8YCvHwqqbY= -github.com/xmtp/proto/v3 v3.29.1-0.20231019022514-1cc4b0d5a51a/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= -github.com/xmtp/proto/v3 v3.29.1-0.20231019163152-2a17d00f45f4 h1:Mxnc833msN9gX8DJyELd+E7oUJNHlhbIsZlTd88kg5M= -github.com/xmtp/proto/v3 v3.29.1-0.20231019163152-2a17d00f45f4/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= -github.com/xmtp/proto/v3 v3.30.0 h1:x6LGCWpO2HTQNhUiTXfE0l+u2HSL3Z35p41xhgy6hlw= -github.com/xmtp/proto/v3 v3.30.0/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= +github.com/xmtp/proto/v3 v3.29.1-0.20231019225839-328520e94f34 h1:rR10cJ5RTlw7OWdw5fnNF3WpByRybvGC3xmOmELd7JY= +github.com/xmtp/proto/v3 v3.29.1-0.20231019225839-328520e94f34/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= +github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4 h1:Qc2ed8NrlosJnPMNxVriugcFB21d4V90HKZdO83yV2M= +github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/yoheimuta/go-protoparser/v4 v4.6.0 h1:uvz1e9/5Ihsm4Ku8AJeDImTpirKmIxubZdSn0QJNdnw= github.com/yoheimuta/go-protoparser/v4 v4.6.0/go.mod h1:AHNNnSWnb0UoL4QgHPiOAg2BniQceFscPI5X/BZNHl8= @@ -1239,8 +1247,8 @@ golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1439,10 +1447,11 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1456,8 +1465,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1705,8 +1714,9 @@ k8s.io/kube-openapi v0.0.0-20191107075043-30be4d16710a/go.mod h1:1TqjTSzOxsLGIKf k8s.io/utils v0.0.0-20191114184206-e782cd3c129f/go.mod h1:sZAwmy6armz5eXlNoLmJcl4F1QuKu7sr+mFQ0byX7Ew= lukechampine.com/blake3 v1.2.1 h1:YuqqRuaqsGV71BV/nm9xlI0MKUv4QC54jQnBChWbGnI= lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k= -mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w= mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ= +mellium.im/sasl v0.3.1 h1:wE0LW6g7U83vhvxjC1IY8DnXM+EU095yeo8XClvCdfo= +mellium.im/sasl v0.3.1/go.mod h1:xm59PUYpZHhgQ9ZqoJ5QaCqzWMi8IeS49dhp6plPCzw= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/pkg/api/config.go b/pkg/api/config.go index f61d8ff4..723679bf 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -8,6 +8,7 @@ import ( wakunode "github.com/waku-org/go-waku/waku/v2/node" "github.com/xmtp/xmtp-node-go/pkg/authz" "github.com/xmtp/xmtp-node-go/pkg/mlsstore" + "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" "github.com/xmtp/xmtp-node-go/pkg/ratelimiter" "github.com/xmtp/xmtp-node-go/pkg/store" "go.uber.org/zap" @@ -31,11 +32,12 @@ type Options struct { type Config struct { Options - AllowLister authz.WalletAllowLister - Waku *wakunode.WakuNode - Log *zap.Logger - Store *store.Store - MlsStore mlsstore.MlsStore + AllowLister authz.WalletAllowLister + Waku *wakunode.WakuNode + Log *zap.Logger + Store *store.Store + MLSStore mlsstore.MlsStore + MLSValidator mlsvalidate.MLSValidationService } // AuthnOptions bundle command line options associated with the authn package. diff --git a/pkg/api/message/v3/service.go b/pkg/api/message/v3/service.go index 54df462b..9ea61de4 100644 --- a/pkg/api/message/v3/service.go +++ b/pkg/api/message/v3/service.go @@ -6,6 +6,7 @@ import ( wakunode "github.com/waku-org/go-waku/waku/v2/node" proto "github.com/xmtp/proto/v3/go/message_api/v3" "github.com/xmtp/xmtp-node-go/pkg/mlsstore" + "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" "github.com/xmtp/xmtp-node-go/pkg/store" "go.uber.org/zap" "google.golang.org/grpc/codes" @@ -16,21 +17,23 @@ import ( type Service struct { proto.UnimplementedMlsApiServer - log *zap.Logger - waku *wakunode.WakuNode - messageStore *store.Store - mlsStore mlsstore.MlsStore + log *zap.Logger + waku *wakunode.WakuNode + messageStore *store.Store + mlsStore mlsstore.MlsStore + validationService mlsvalidate.MLSValidationService ctx context.Context ctxCancel func() } -func NewService(node *wakunode.WakuNode, logger *zap.Logger, messageStore *store.Store, mlsStore mlsstore.MlsStore) (s *Service, err error) { +func NewService(node *wakunode.WakuNode, logger *zap.Logger, messageStore *store.Store, mlsStore mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService) (s *Service, err error) { s = &Service{ - log: logger.Named("message/v3"), - waku: node, - messageStore: messageStore, - mlsStore: mlsStore, + log: logger.Named("message/v3"), + waku: node, + messageStore: messageStore, + mlsStore: mlsStore, + validationService: validationService, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) @@ -45,11 +48,54 @@ func (s *Service) Close() { } func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterInstallationRequest) (*proto.RegisterInstallationResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") + results, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{req.LastResortKeyPackage.KeyPackageTlsSerialized}) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) + } + if len(results) != 1 { + return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results)) + } + + installationId := results[0].InstallationId + walletAddress := results[0].WalletAddress + + err = s.mlsStore.CreateInstallation(ctx, installationId, walletAddress, req.LastResortKeyPackage.KeyPackageTlsSerialized) + if err != nil { + return nil, err + } + + return &proto.RegisterInstallationResponse{ + InstallationId: installationId, + }, nil } func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyPackagesRequest) (*proto.ConsumeKeyPackagesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") + ids := req.InstallationIds + keyPackages, err := s.mlsStore.ConsumeKeyPackages(ctx, ids) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to consume key packages: %s", err) + } + keyPackageMap := make(map[string]int) + for idx, id := range ids { + keyPackageMap[id] = idx + } + + resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages)) + for _, keyPackage := range keyPackages { + + idx, ok := keyPackageMap[keyPackage.InstallationId] + if !ok { + return nil, status.Errorf(codes.Internal, "could not find key package for installation") + } + + resPackages[idx] = &proto.ConsumeKeyPackagesResponse_KeyPackage{ + KeyPackageTlsSerialized: keyPackage.Data, + } + } + + return &proto.ConsumeKeyPackagesResponse{ + KeyPackages: resPackages, + }, nil } func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupRequest) (*emptypb.Empty, error) { @@ -61,7 +107,27 @@ func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcome } func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPackagesRequest) (*emptypb.Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") + keyPackageBytes := make([][]byte, len(req.KeyPackages)) + for i, keyPackage := range req.KeyPackages { + keyPackageBytes[i] = keyPackage.KeyPackageTlsSerialized + } + validationResults, err := s.validationService.ValidateKeyPackages(ctx, keyPackageBytes) + if err != nil { + // TODO: Differentiate between validation errors and internal errors + return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) + } + + keyPackageModels := make([]*mlsstore.KeyPackage, len(validationResults)) + for i, validationResult := range validationResults { + kp := mlsstore.NewKeyPackage(validationResult.InstallationId, keyPackageBytes[i], false) + keyPackageModels[i] = kp + } + err = s.mlsStore.InsertKeyPackages(ctx, keyPackageModels) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) + } + + return &emptypb.Empty{}, nil } func (s *Service) RevokeInstallation(ctx context.Context, req *proto.RevokeInstallationRequest) (*emptypb.Empty, error) { diff --git a/pkg/api/server.go b/pkg/api/server.go index e7f1b643..9648c3a1 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -130,8 +130,8 @@ func (s *Server) startGRPC() error { proto.RegisterMessageApiServer(grpcServer, s.messagev1) // Enable the MLS server if a store is provided - if s.Config.MlsStore != nil && s.Config.EnableMls { - s.messagev3, err = messagev3.NewService(s.Waku, s.Log, s.Store, s.Config.MlsStore) + if s.Config.MLSStore != nil && s.Config.MLSValidator != nil && s.Config.EnableMls { + s.messagev3, err = messagev3.NewService(s.Waku, s.Log, s.Store, s.Config.MLSStore, s.Config.MLSValidator) if err != nil { return errors.Wrap(err, "creating mls service") } @@ -183,7 +183,7 @@ func (s *Server) startHTTP() error { return errors.Wrap(err, "registering message handler") } - if s.Config.MlsStore != nil && s.Config.EnableMls { + if s.Config.MLSStore != nil && s.Config.EnableMls { err = v3Proto.RegisterMlsApiHandler(s.ctx, gwmux, conn) if err != nil { return errors.Wrap(err, "registering mls handler") diff --git a/pkg/migrations/mls/20231023050806_init-schema.down.sql b/pkg/migrations/mls/20231023050806_init-schema.down.sql new file mode 100644 index 00000000..b1077270 --- /dev/null +++ b/pkg/migrations/mls/20231023050806_init-schema.down.sql @@ -0,0 +1,8 @@ +SET + statement_timeout = 0; + +--bun:split +DROP TABLE IF EXISTS installations; + +--bun:split +DROP TABLE IF EXISTS key_packages; \ No newline at end of file diff --git a/pkg/migrations/mls/20231023050806_init-schema.up.sql b/pkg/migrations/mls/20231023050806_init-schema.up.sql new file mode 100644 index 00000000..b6aa90d5 --- /dev/null +++ b/pkg/migrations/mls/20231023050806_init-schema.up.sql @@ -0,0 +1,44 @@ +SET + statement_timeout = 0; + +--bun:split +CREATE TABLE installations ( + id TEXT PRIMARY KEY, + wallet_address TEXT NOT NULL, + created_at BIGINT NOT NULL, + revoked_at BIGINT +); + +--bun:split +CREATE TABLE key_packages ( + id TEXT PRIMARY KEY, + installation_id TEXT NOT NULL, + created_at BIGINT NOT NULL, + consumed_at BIGINT, + not_consumed BOOLEAN DEFAULT TRUE NOT NULL, + is_last_resort BOOLEAN NOT NULL, + data BYTEA NOT NULL, + -- Add a foreign key constraint to ensure key packages cannot be added for unregistered installations + CONSTRAINT fk_installation_id FOREIGN KEY (installation_id) REFERENCES installations (id) +); + +--bun:split +CREATE INDEX idx_installations_wallet_address ON installations(wallet_address); + +--bun:split +CREATE INDEX idx_installations_created_at ON installations(created_at); + +--bun:split +CREATE INDEX idx_installations_revoked_at ON installations(revoked_at); + +--bun:split +-- Adding indexes for the key_packages table +CREATE INDEX idx_key_packages_installation_id_not_is_last_resort_created_at ON key_packages( + installation_id, + not_consumed, + is_last_resort, + created_at +); + +--bun:split +CREATE INDEX idx_key_packages_is_last_resort_id ON key_packages(is_last_resort, id); \ No newline at end of file diff --git a/pkg/migrations/mls/migrations.go b/pkg/migrations/mls/migrations.go new file mode 100644 index 00000000..860e0445 --- /dev/null +++ b/pkg/migrations/mls/migrations.go @@ -0,0 +1,18 @@ +package mls + +import ( + "embed" + + "github.com/uptrace/bun/migrate" +) + +var Migrations = migrate.NewMigrations() + +//go:embed *.sql +var sqlMigrations embed.FS + +func init() { + if err := Migrations.Discover(sqlMigrations); err != nil { + panic(err) + } +} diff --git a/pkg/mlsstore/models.go b/pkg/mlsstore/models.go new file mode 100644 index 00000000..533c595d --- /dev/null +++ b/pkg/mlsstore/models.go @@ -0,0 +1,24 @@ +package mlsstore + +import "github.com/uptrace/bun" + +type Installation struct { + bun.BaseModel `bun:"table:installations"` + + ID string `bun:",pk"` + WalletAddress string `bun:"wallet_address,notnull"` + CreatedAt int64 `bun:"created_at,notnull"` + RevokedAt *int64 `bun:"revoked_at"` +} + +type KeyPackage struct { + bun.BaseModel `bun:"table:key_packages"` + + ID string `bun:",pk"` // ID is the hash of the data field + InstallationId string `bun:"installation_id,notnull"` + CreatedAt int64 `bun:"created_at,notnull"` + ConsumedAt *int64 `bun:"consumed_at"` + NotConsumed bool `bun:"not_consumed,default:true"` + IsLastResort bool `bun:"is_last_resort,notnull"` + Data []byte `bun:"data,notnull,type:bytea"` +} diff --git a/pkg/mlsstore/store.go b/pkg/mlsstore/store.go index 4dcfd277..f528b78d 100644 --- a/pkg/mlsstore/store.go +++ b/pkg/mlsstore/store.go @@ -2,27 +2,172 @@ package mlsstore import ( "context" + "crypto/sha256" + "database/sql" + "encoding/hex" + "errors" + "time" "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate" + "github.com/xmtp/xmtp-node-go/pkg/migrations/messages" "go.uber.org/zap" ) type Store struct { - ctx context.Context - cancel context.CancelFunc + config Config log *zap.Logger db *bun.DB } type MlsStore interface { + CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error + InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error + ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error) } -func New(config Config) (*Store, error) { +func New(ctx context.Context, config Config) (*Store, error) { s := &Store{ - log: config.Log.Named("mlsstore"), - db: config.DB, + log: config.Log.Named("mlsstore"), + db: config.DB, + config: config, + } + + if err := s.migrate(ctx); err != nil { + return nil, err } - s.ctx, s.cancel = context.WithCancel(context.Background()) return s, nil } + +func (s *Store) Close() { + if s.db != nil { + s.db.Close() + } +} + +// Creates the installation and last resort key package +func (s *Store) CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error { + createdAt := nowNs() + + installation := Installation{ + ID: installationId, + WalletAddress: walletAddress, + CreatedAt: createdAt, + } + + keyPackage := NewKeyPackage(installationId, lastResortKeyPackage, true) + + return s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert(). + Model(&installation). + Ignore(). + Exec(ctx) + + if err != nil { + return err + } + + _, err = tx.NewInsert(). + Model(keyPackage). + Ignore(). + Exec(ctx) + + if err != nil { + return err + } + + return nil + }) +} + +// Insert a batch of key packages, ignoring any that may already exist +func (s *Store) InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error { + _, err := s.db.NewInsert().Model(&keyPackages).Ignore().Exec(ctx) + return err +} + +func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error) { + keyPackages := make([]*KeyPackage, 0) + err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + err := tx.NewRaw(` + SELECT DISTINCT ON(installation_id) * FROM key_packages + WHERE "installation_id" IN (?) + AND not_consumed = TRUE + ORDER BY installation_id ASC, is_last_resort ASC, created_at ASC + `, + bun.In(installationIds)). + Scan(ctx, &keyPackages) + + if err != nil { + return err + } + + if len(keyPackages) < len(installationIds) { + return errors.New("key packages not found") + } + + _, err = tx.NewUpdate(). + Table("key_packages"). + Set("consumed_at = ?", nowNs()). + Set("not_consumed = FALSE"). + Where("is_last_resort = FALSE"). + Where("id IN (?)", bun.In(extractIds(keyPackages))). + Exec(ctx) + + return err + }) + + if err != nil { + return nil, err + } + + return keyPackages, nil +} + +func NewKeyPackage(installationId string, data []byte, isLastResort bool) *KeyPackage { + return &KeyPackage{ + ID: buildKeyPackageId(data), + InstallationId: installationId, + CreatedAt: nowNs(), + IsLastResort: isLastResort, + NotConsumed: true, + Data: data, + } +} + +func extractIds(keyPackages []*KeyPackage) []string { + out := make([]string, len(keyPackages)) + for i, keyPackage := range keyPackages { + out[i] = keyPackage.ID + } + return out +} + +func (s *Store) migrate(ctx context.Context) error { + migrator := migrate.NewMigrator(s.db, messages.Migrations) + err := migrator.Init(ctx) + if err != nil { + return err + } + + group, err := migrator.Migrate(ctx) + if err != nil { + return err + } + + if group.IsZero() { + s.log.Info("No new migrations to run") + } + + return nil +} + +func nowNs() int64 { + return time.Now().UTC().UnixNano() +} + +func buildKeyPackageId(keyPackageData []byte) string { + digest := sha256.Sum256(keyPackageData) + return hex.EncodeToString(digest[:]) +} diff --git a/pkg/mlsstore/store_test.go b/pkg/mlsstore/store_test.go new file mode 100644 index 00000000..026b3da9 --- /dev/null +++ b/pkg/mlsstore/store_test.go @@ -0,0 +1,171 @@ +package mlsstore + +import ( + "context" + "crypto/rand" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + test "github.com/xmtp/xmtp-node-go/pkg/testing" +) + +func NewTestStore(t *testing.T) (*Store, func()) { + log := test.NewLog(t) + db, _, dbCleanup := test.NewMlsDB(t) + ctx := context.Background() + c := Config{ + Log: log, + DB: db, + } + + store, err := New(ctx, c) + require.NoError(t, err) + + return store, dbCleanup +} + +func randomBytes(n int) []byte { + b := make([]byte, n) + _, _ = rand.Reader.Read(b) + return b +} + +func randomString(n int) string { + return fmt.Sprintf("%x", randomBytes(n)) +} + +func TestCreateInstallation(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + installationId := randomString(32) + walletAddress := randomString(32) + + err := store.CreateInstallation(ctx, installationId, walletAddress, randomBytes(32)) + require.NoError(t, err) + + installationFromDb := &Installation{} + require.NoError(t, store.db.NewSelect().Model(installationFromDb).Where("id = ?", installationId).Scan(ctx)) + require.Equal(t, walletAddress, installationFromDb.WalletAddress) + + keyPackageFromDB := &KeyPackage{} + require.NoError(t, store.db.NewSelect().Model(keyPackageFromDB).Where("installation_id = ?", installationId).Scan(ctx)) + require.Equal(t, installationId, keyPackageFromDB.InstallationId) +} + +func TestCreateInstallationIdempotent(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + installationId := randomString(32) + walletAddress := randomString(32) + keyPackage := randomBytes(32) + + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + require.NoError(t, err) + err = store.CreateInstallation(ctx, installationId, walletAddress, randomBytes(32)) + require.NoError(t, err) + + keyPackageFromDb := &KeyPackage{} + require.NoError(t, store.db.NewSelect().Model(keyPackageFromDb).Where("installation_id = ?", installationId).Scan(ctx)) + require.Equal(t, keyPackage, keyPackageFromDb.Data) +} + +func TestInsertKeyPackages(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + installationId := randomString(32) + walletAddress := randomString(32) + keyPackage := randomBytes(32) + + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + require.NoError(t, err) + + keyPackage2 := randomBytes(32) + err = store.InsertKeyPackages(ctx, []*KeyPackage{{ + ID: buildKeyPackageId(keyPackage2), + InstallationId: installationId, + CreatedAt: nowNs(), + IsLastResort: false, + Data: keyPackage2, + }}) + require.NoError(t, err) + + keyPackagesFromDb := []*KeyPackage{} + require.NoError(t, store.db.NewSelect().Model(&keyPackagesFromDb).Where("installation_id = ?", installationId).Scan(ctx)) + require.Len(t, keyPackagesFromDb, 2) + + hasLastResort := false + hasRegular := false + for _, keyPackageFromDb := range keyPackagesFromDb { + require.Equal(t, installationId, keyPackageFromDb.InstallationId) + if keyPackageFromDb.IsLastResort { + hasLastResort = true + } + if !keyPackageFromDb.IsLastResort { + hasRegular = true + require.Equal(t, keyPackage2, keyPackageFromDb.Data) + } + } + + require.True(t, hasLastResort) + require.True(t, hasRegular) +} + +func TestConsumeLastResortKeyPackage(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + installationId := randomString(32) + walletAddress := randomString(32) + keyPackage := randomBytes(32) + + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + require.NoError(t, err) + + consumeResult, err := store.ConsumeKeyPackages(ctx, []string{installationId}) + require.NoError(t, err) + require.Len(t, consumeResult, 1) + require.Equal(t, keyPackage, consumeResult[0].Data) + require.Equal(t, installationId, consumeResult[0].InstallationId) +} + +func TestConsumeMultipleKeyPackages(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + installationId := randomString(32) + walletAddress := randomString(32) + keyPackage := randomBytes(32) + + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + require.NoError(t, err) + + keyPackage2 := randomBytes(32) + require.NoError(t, store.InsertKeyPackages(ctx, []*KeyPackage{{ + ID: buildKeyPackageId(keyPackage2), + InstallationId: installationId, + CreatedAt: nowNs(), + IsLastResort: false, + Data: keyPackage2, + }})) + + consumeResult, err := store.ConsumeKeyPackages(ctx, []string{installationId}) + require.NoError(t, err) + require.Len(t, consumeResult, 1) + require.Equal(t, keyPackage2, consumeResult[0].Data) + require.Equal(t, installationId, consumeResult[0].InstallationId) + + consumeResult, err = store.ConsumeKeyPackages(ctx, []string{installationId}) + require.NoError(t, err) + require.Len(t, consumeResult, 1) + // Now we are out of regular key packages. Expect to consume the last resort + require.Equal(t, keyPackage, consumeResult[0].Data) +} diff --git a/pkg/mlsvalidate/config.go b/pkg/mlsvalidate/config.go new file mode 100644 index 00000000..dd644eab --- /dev/null +++ b/pkg/mlsvalidate/config.go @@ -0,0 +1,5 @@ +package mlsvalidate + +type MLSValidationOptions struct { + GRPCAddress string `long:"grpc-address" description:"Address for the GRPC validation service"` +} diff --git a/pkg/mlsvalidate/service.go b/pkg/mlsvalidate/service.go new file mode 100644 index 00000000..36756ee2 --- /dev/null +++ b/pkg/mlsvalidate/service.go @@ -0,0 +1,107 @@ +package mlsvalidate + +import ( + "context" + "fmt" + + svc "github.com/xmtp/proto/v3/go/mls_validation/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type IdentityValidationResult struct { + WalletAddress string + InstallationId string +} + +type GroupMessageValidationResult struct { + GroupId string +} + +type IdentityInput struct { + SigningPublicKey []byte + Identity []byte +} + +type MLSValidationService interface { + ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]IdentityValidationResult, error) + ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]GroupMessageValidationResult, error) +} + +type MLSValidationServiceImpl struct { + grpcClient svc.ValidationApiClient +} + +func NewMlsValidationService(ctx context.Context, options MLSValidationOptions) (*MLSValidationServiceImpl, error) { + conn, err := grpc.DialContext(ctx, options.GRPCAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + return &MLSValidationServiceImpl{ + grpcClient: svc.NewValidationApiClient(conn), + }, nil +} + +func (s *MLSValidationServiceImpl) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]IdentityValidationResult, error) { + req := makeValidateKeyPackageRequest(keyPackages) + response, err := s.grpcClient.ValidateKeyPackages(ctx, req) + if err != nil { + return nil, err + } + out := make([]IdentityValidationResult, len(response.Responses)) + for i, response := range response.Responses { + if !response.IsOk { + return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage) + } + out[i] = IdentityValidationResult{ + WalletAddress: response.WalletAddress, + InstallationId: response.InstallationId, + } + } + return out, nil +} + +func makeValidateKeyPackageRequest(keyPackageBytes [][]byte) *svc.ValidateKeyPackagesRequest { + keyPackageRequests := make([]*svc.ValidateKeyPackagesRequest_KeyPackage, len(keyPackageBytes)) + for i, keyPackage := range keyPackageBytes { + keyPackageRequests[i] = &svc.ValidateKeyPackagesRequest_KeyPackage{ + KeyPackageBytesTlsSerialized: keyPackage, + } + } + return &svc.ValidateKeyPackagesRequest{ + KeyPackages: keyPackageRequests, + } +} + +func (s *MLSValidationServiceImpl) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]GroupMessageValidationResult, error) { + req := makeValidateGroupMessagesRequest(groupMessages) + + response, err := s.grpcClient.ValidateGroupMessages(ctx, req) + if err != nil { + return nil, err + } + + out := make([]GroupMessageValidationResult, len(response.Responses)) + for i, response := range response.Responses { + if !response.IsOk { + return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage) + } + out[i] = GroupMessageValidationResult{ + GroupId: response.GroupId, + } + } + + return out, nil +} + +func makeValidateGroupMessagesRequest(groupMessages [][]byte) *svc.ValidateGroupMessagesRequest { + groupMessageRequests := make([]*svc.ValidateGroupMessagesRequest_GroupMessage, len(groupMessages)) + for i, groupMessage := range groupMessages { + groupMessageRequests[i] = &svc.ValidateGroupMessagesRequest_GroupMessage{ + GroupMessageBytesTlsSerialized: groupMessage, + } + } + return &svc.ValidateGroupMessagesRequest{ + GroupMessages: groupMessageRequests, + } +} diff --git a/pkg/mlsvalidate/service_test.go b/pkg/mlsvalidate/service_test.go new file mode 100644 index 00000000..2fb2cca3 --- /dev/null +++ b/pkg/mlsvalidate/service_test.go @@ -0,0 +1,63 @@ +package mlsvalidate + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + svc "github.com/xmtp/proto/v3/go/mls_validation/v1" + "google.golang.org/grpc" +) + +type MockedGRPCService struct { + mock.Mock +} + +func (m *MockedGRPCService) ValidateKeyPackages(ctx context.Context, req *svc.ValidateKeyPackagesRequest, opts ...grpc.CallOption) (*svc.ValidateKeyPackagesResponse, error) { + args := m.Called(ctx, req) + + return args.Get(0).(*svc.ValidateKeyPackagesResponse), args.Error(1) +} + +func (m *MockedGRPCService) ValidateGroupMessages(ctx context.Context, req *svc.ValidateGroupMessagesRequest, opts ...grpc.CallOption) (*svc.ValidateGroupMessagesResponse, error) { + args := m.Called(ctx, req) + + return args.Get(0).(*svc.ValidateGroupMessagesResponse), args.Error(1) +} + +func getMockedService() (*MockedGRPCService, MLSValidationService) { + mockService := new(MockedGRPCService) + service := &MLSValidationServiceImpl{ + grpcClient: mockService, + } + + return mockService, service +} + +func TestValidateKeyPackages(t *testing.T) { + mockGrpc, service := getMockedService() + + ctx := context.Background() + + firstResponse := svc.ValidateKeyPackagesResponse_ValidationResponse{ + IsOk: true, + WalletAddress: "0x123", + InstallationId: "123", + ErrorMessage: "", + } + + mockGrpc.On("ValidateKeyPackages", ctx, mock.Anything).Return(&svc.ValidateKeyPackagesResponse{ + Responses: []*svc.ValidateKeyPackagesResponse_ValidationResponse{&firstResponse}, + }, nil) + + res, err := service.ValidateKeyPackages(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, "0x123", res[0].WalletAddress) + assert.Equal(t, "123", res[0].InstallationId) +} + +func TestValidateKeyPackagesError(t *testing.T) { + +} diff --git a/pkg/server/options.go b/pkg/server/options.go index 40041f4d..e21e3efb 100644 --- a/pkg/server/options.go +++ b/pkg/server/options.go @@ -5,6 +5,7 @@ import ( "github.com/xmtp/xmtp-node-go/pkg/api" "github.com/xmtp/xmtp-node-go/pkg/mlsstore" + "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" "github.com/xmtp/xmtp-node-go/pkg/store" ) @@ -62,17 +63,19 @@ type Options struct { LogEncoding string `long:"log-encoding" description:"Log encoding format. Either console or json" choice:"console" choice:"json" default:"console"` CreateMessageMigration string `long:"create-message-migration" default:"" description:"Create a migration. Must provide a name"` CreateAuthzMigration string `long:"create-authz-migration" default:"" description:"Create a migration for the auth db. Must provide a name"` + CreateMlsMigration string `long:"create-mls-migration" default:"" description:"Create a migration for the mls db. Must provide a name"` WaitForDB time.Duration `long:"wait-for-db" description:"wait for DB on start, up to specified duration"` Version bool `long:"version" description:"Output binary version and exit"` GoProfiling bool `long:"go-profiling" description:"Enable Go profiling"` MetricsPeriod time.Duration `long:"metrics-period" description:"Polling period for server status metrics" default:"30s"` - API api.Options `group:"API Options" namespace:"api"` - Authz AuthzOptions `group:"Authz Options"` - Relay RelayOptions `group:"Relay Options"` - Store store.Options `group:"Store Options" namespace:"store"` - Metrics MetricsOptions `group:"Metrics Options"` - Tracing TracingOptions `group:"DD APM Tracing Options"` - Profiling ProfilingOptions `group:"DD APM Profiling Options" namespace:"profiling"` - MlsStore mlsstore.StoreOptions `group:"MLS Options" namespace:"mlsstore"` + API api.Options `group:"API Options" namespace:"api"` + Authz AuthzOptions `group:"Authz Options"` + Relay RelayOptions `group:"Relay Options"` + Store store.Options `group:"Store Options" namespace:"store"` + Metrics MetricsOptions `group:"Metrics Options"` + Tracing TracingOptions `group:"DD APM Tracing Options"` + Profiling ProfilingOptions `group:"DD APM Profiling Options" namespace:"profiling"` + MLSStore mlsstore.StoreOptions `group:"MLS Options" namespace:"mls-store"` + MlsValidation mlsvalidate.MLSValidationOptions `group:"MLS Validation Options" namespace:"mls-validation"` } diff --git a/pkg/server/server.go b/pkg/server/server.go index adf2ca53..8887fb82 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -37,7 +37,9 @@ import ( "github.com/xmtp/xmtp-node-go/pkg/metrics" authzmigrations "github.com/xmtp/xmtp-node-go/pkg/migrations/authz" messagemigrations "github.com/xmtp/xmtp-node-go/pkg/migrations/messages" + mlsmigrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls" "github.com/xmtp/xmtp-node-go/pkg/mlsstore" + "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" xmtpstore "github.com/xmtp/xmtp-node-go/pkg/store" "github.com/xmtp/xmtp-node-go/pkg/tracing" "go.uber.org/zap" @@ -59,6 +61,7 @@ type Server struct { allowLister authz.WalletAllowLister authenticator *authn.XmtpAuthentication grpc *api.Server + MLSStore *mlsstore.Store } // Create a new Server @@ -226,32 +229,43 @@ func New(ctx context.Context, log *zap.Logger, options Options) (*Server, error) } s.log.With(logging.MultiAddrs("listen", maddrs...)).Info("got server") - var mlsStore mlsstore.MlsStore + var MLSStore mlsstore.MlsStore - if options.MlsStore.DbConnectionString != "" { - mlsDb, err := createBunDB(options.MlsStore.DbConnectionString, options.WaitForDB, options.MlsStore.ReadTimeout, options.MlsStore.WriteTimeout, options.MlsStore.MaxOpenConns) + if options.MLSStore.DbConnectionString != "" { + mlsDb, err := createBunDB(options.MLSStore.DbConnectionString, options.WaitForDB, options.MLSStore.ReadTimeout, options.MLSStore.WriteTimeout, options.MLSStore.MaxOpenConns) if err != nil { return nil, errors.Wrap(err, "creating mls db") } - mlsStore, err = mlsstore.New(mlsstore.Config{ + s.MLSStore, err = mlsstore.New(s.ctx, mlsstore.Config{ Log: s.log, DB: mlsDb, }) + if err != nil { return nil, errors.Wrap(err, "creating mls store") } } + var MLSValidator mlsvalidate.MLSValidationService + if options.MlsValidation.GRPCAddress != "" { + MLSValidator, err = mlsvalidate.NewMlsValidationService(ctx, options.MlsValidation) + if err != nil { + return nil, errors.Wrap(err, "creating mls validation service") + } + + } + // Initialize gRPC server. s.grpc, err = api.New( &api.Config{ - Options: options.API, - Log: s.log.Named("api"), - Waku: s.wakuNode, - Store: s.store, - MlsStore: mlsStore, - AllowLister: s.allowLister, + Options: options.API, + Log: s.log.Named("`api"), + Waku: s.wakuNode, + Store: s.store, + MLSStore: MLSStore, + AllowLister: s.allowLister, + MLSValidator: MLSValidator, }, ) if err != nil { @@ -292,6 +306,9 @@ func (s *Server) Shutdown() { if s.store != nil { s.store.Close() } + if s.MLSStore != nil { + s.MLSStore.Close() + } // Close metrics server. if s.metricsServer != nil { @@ -510,6 +527,20 @@ func CreateAuthzMigration(migrationName, dbConnectionString string, waitForDb, r return err } +func CreateMlsMigration(migrationName, dbConnectionString string, waitForDb, readTimeout, writeTimeout time.Duration, maxOpenConns int) error { + db, err := createBunDB(dbConnectionString, waitForDb, readTimeout, writeTimeout, maxOpenConns) + if err != nil { + return err + } + migrator := migrate.NewMigrator(db, mlsmigrations.Migrations) + files, err := migrator.CreateSQLMigrations(context.Background(), migrationName) + for _, mf := range files { + fmt.Printf("created authz migration %s (%s)\n", mf.Name, mf.Path) + } + + return err +} + func createBunDB(dsn string, waitForDB, readTimeout, writeTimeout time.Duration, maxOpenConns int) (*bun.DB, error) { db, err := createDB(dsn, waitForDB, readTimeout, writeTimeout, maxOpenConns) if err != nil { diff --git a/pkg/testing/store.go b/pkg/testing/store.go index 35806f45..f9a2266c 100644 --- a/pkg/testing/store.go +++ b/pkg/testing/store.go @@ -11,6 +11,7 @@ import ( "github.com/uptrace/bun/driver/pgdriver" "github.com/uptrace/bun/migrate" "github.com/xmtp/xmtp-node-go/pkg/migrations/authz" + "github.com/xmtp/xmtp-node-go/pkg/migrations/mls" ) const ( @@ -48,3 +49,17 @@ func NewAuthzDB(t *testing.T) (*bun.DB, string, func()) { return bunDB, dsn, cleanup } + +func NewMlsDB(t *testing.T) (*bun.DB, string, func()) { + db, dsn, cleanup := NewDB(t) + bunDB := bun.NewDB(db, pgdialect.New()) + + ctx := context.Background() + migrator := migrate.NewMigrator(bunDB, mls.Migrations) + err := migrator.Init(ctx) + require.NoError(t, err) + _, err = migrator.Migrate(ctx) + require.NoError(t, err) + + return bunDB, dsn, cleanup +}