From 24971f48cfde5dfef3b87b182ce2747a5c0d24fa Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Fri, 31 May 2024 15:10:57 -0700 Subject: [PATCH] Get Identity DB ready for prod (#394) * Deal with transaction serialization issues * Revert schema change * Use pg_advisory_xact_lock * Check return value of lock * Add random sleep * Add initial migration * Remove unused line * Update code to use new schema * Remove unused file * Fix failing tests * Remove unused functions * Move hex encoding to the db --- .mockery.yaml | 18 + dev/generate | 4 +- dev/up | 2 +- go.mod | 2 +- go.sum | 5 +- pkg/identity/api/v1/identity_service_test.go | 14 +- .../mls/20231023050806_init-schema.down.sql | 5 - .../mls/20231023050806_init-schema.up.sql | 23 -- .../mls/20240109001927_add-messages.up.sql | 32 -- .../mls/20240122230601_add-hpke-key.down.sql | 5 - .../mls/20240122230601_add-hpke-key.up.sql | 5 - .../mls/20240411200242_init-identity.down.sql | 8 - .../mls/20240411200242_init-identity.up.sql | 24 -- .../20240425021053_add-inbox-filters.up.sql | 8 - ...wn.sql => 20240528181822_wipe-db.down.sql} | 6 +- .../mls/20240528181822_wipe-db.up.sql | 8 + ...ql => 20240528181851_init-schema.down.sql} | 6 +- .../mls/20240528181851_init-schema.up.sql | 71 ++++ pkg/mls/api/v1/mock.gen.go | 257 ------------- pkg/mls/api/v1/service.go | 144 ++----- pkg/mls/api/v1/service_test.go | 213 ++++------- pkg/mls/store/queries.sql | 57 ++- pkg/mls/store/queries/models.go | 20 +- pkg/mls/store/queries/queries.sql.go | 189 ++++------ pkg/mls/store/store.go | 71 +--- pkg/mls/store/store_test.go | 157 ++------ pkg/mlsvalidate/mocks/mock.gen.go | 142 ------- pkg/mocks/mock_MLSValidationService.go | 278 ++++++++++++++ ...ock_MlsApi_SubscribeGroupMessagesServer.go | 350 ++++++++++++++++++ ...k_MlsApi_SubscribeWelcomeMessagesServer.go | 350 ++++++++++++++++++ pkg/testing/random.go | 8 + pkg/utils/hex.go | 19 + tools.go | 1 - 33 files changed, 1378 insertions(+), 1124 deletions(-) create mode 100644 .mockery.yaml delete mode 100644 pkg/migrations/mls/20231023050806_init-schema.down.sql delete mode 100644 pkg/migrations/mls/20231023050806_init-schema.up.sql delete mode 100644 pkg/migrations/mls/20240109001927_add-messages.up.sql delete mode 100644 pkg/migrations/mls/20240122230601_add-hpke-key.down.sql delete mode 100644 pkg/migrations/mls/20240122230601_add-hpke-key.up.sql delete mode 100644 pkg/migrations/mls/20240411200242_init-identity.down.sql delete mode 100644 pkg/migrations/mls/20240411200242_init-identity.up.sql delete mode 100644 pkg/migrations/mls/20240425021053_add-inbox-filters.up.sql rename pkg/migrations/mls/{20240109001927_add-messages.down.sql => 20240528181822_wipe-db.down.sql} (56%) create mode 100644 pkg/migrations/mls/20240528181822_wipe-db.up.sql rename pkg/migrations/mls/{20240425021053_add-inbox-filters.down.sql => 20240528181851_init-schema.down.sql} (54%) create mode 100644 pkg/migrations/mls/20240528181851_init-schema.up.sql delete mode 100644 pkg/mls/api/v1/mock.gen.go delete mode 100644 pkg/mlsvalidate/mocks/mock.gen.go create mode 100644 pkg/mocks/mock_MLSValidationService.go create mode 100644 pkg/mocks/mock_MlsApi_SubscribeGroupMessagesServer.go create mode 100644 pkg/mocks/mock_MlsApi_SubscribeWelcomeMessagesServer.go create mode 100644 pkg/utils/hex.go diff --git a/.mockery.yaml b/.mockery.yaml new file mode 100644 index 00000000..c21bd77d --- /dev/null +++ b/.mockery.yaml @@ -0,0 +1,18 @@ +with-expecter: true +packages: + github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1: + config: + dir: ./pkg/mocks + outpkg: mocks + interfaces: + MlsApi_SubscribeGroupMessagesServer: + config: + MlsApi_SubscribeWelcomeMessagesServer: + config: + github.com/xmtp/xmtp-node-go/pkg/mlsvalidate: + config: + dir: ./pkg/mocks + outpkg: mocks + interfaces: + MLSValidationService: + config: diff --git a/dev/generate b/dev/generate index 1896e37c..e82af341 100755 --- a/dev/generate +++ b/dev/generate @@ -3,8 +3,8 @@ set -e go generate ./... -mockgen -package api github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1 MlsApi_SubscribeGroupMessagesServer,MlsApi_SubscribeWelcomeMessagesServer > pkg/mls/api/v1/mock.gen.go -mockgen -package mocks -source ./pkg/mlsvalidate/service.go MLSValidationService > pkg/mlsvalidate/mocks/mock.gen.go +# Generate mocks +mockery rm -rf pkg/proto/**/*.pb.go pkg/proto/**/*.pb.gw.go pkg/proto/**/*.swagger.json if ! buf generate https://github.com/xmtp/proto.git#branch=main,subdir=proto; then echo "Failed to generate protobuf definitions" diff --git a/dev/up b/dev/up index 8247ba0a..15f07274 100755 --- a/dev/up +++ b/dev/up @@ -8,7 +8,7 @@ if ! which golangci-lint &>/dev/null; then brew install golangci-lint; fi if ! which shellcheck &>/dev/null; then brew install shellcheck; fi if ! which protoc &>/dev/null; then brew install protobuf; fi if ! which protoc-gen-go &>/dev/null; then go install google.golang.org/protobuf/cmd/protoc-gen-go@latest; fi -if ! which mockgen &>/dev/null || [ `mockgen --version` != "v0.4.0" ]; then go install go.uber.org/mock/mockgen@v0.4.0; fi +if ! which mockery &>/dev/null; then brew install mockery; fi if ! which protolint &>/dev/null; then go install github.com/yoheimuta/protolint/cmd/protolint@latest; fi dev/generate diff --git a/go.mod b/go.mod index a8e2bc9b..843b10fa 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,6 @@ require ( github.com/waku-org/go-waku v0.8.0 github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 github.com/yoheimuta/protolint v0.39.0 - go.uber.org/mock v0.4.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.3.0 google.golang.org/genproto v0.0.0-20230223222841-637eb2293923 @@ -42,6 +41,7 @@ require ( ) require ( + github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect diff --git a/go.sum b/go.sum index 16c75288..21ad6545 100644 --- a/go.sum +++ b/go.sum @@ -141,7 +141,8 @@ github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQ github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= github.com/btcsuite/btcd/btcec/v2 v2.2.1 h1:xP60mv8fvp+0khmrN0zTdPC3cNm24rfeE6lh2R/Yv3E= github.com/btcsuite/btcd/btcec/v2 v2.2.1/go.mod h1:9/CSmJxmuvqzX9Wh2fXMWToLOHhPd11lSPuIupwTkI8= -github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U= +github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 h1:59Kx4K6lzOW5w6nFlA0v5+lk/6sjybR934QNHSJZPTQ= +github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg= @@ -1191,8 +1192,6 @@ go.uber.org/fx v1.20.0 h1:ZMC/pnRvhsthOZh9MZjMq5U8Or3mA9zBSPaLnzs3ihQ= go.uber.org/fx v1.20.0/go.mod h1:qCUj0btiR3/JnanEr1TYEePfSw6o/4qYJscgvzQ5Ub0= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= diff --git a/pkg/identity/api/v1/identity_service_test.go b/pkg/identity/api/v1/identity_service_test.go index 84318c45..4afe4d37 100644 --- a/pkg/identity/api/v1/identity_service_test.go +++ b/pkg/identity/api/v1/identity_service_test.go @@ -152,7 +152,7 @@ func TestPublishedUpdatesCanBeRead(t *testing.T) { svc, _, cleanup := newTestService(t, ctx) defer cleanup() - inbox_id := "test_inbox" + inbox_id := test.RandomInboxId() address := "test_address" _, err := svc.PublishIdentityUpdate(ctx, publishIdentityUpdateRequest(inbox_id, makeCreateInbox(address))) @@ -173,7 +173,7 @@ func TestPublishedUpdatesAreInOrder(t *testing.T) { svc, _, cleanup := newTestService(t, ctx) defer cleanup() - inbox_id := "test_inbox" + inbox_id := test.RandomInboxId() address := "test_address" _, err := svc.PublishIdentityUpdate(ctx, publishIdentityUpdateRequest(inbox_id, makeCreateInbox(address))) @@ -210,10 +210,10 @@ func TestQueryMultipleInboxes(t *testing.T) { svc, _, cleanup := newTestService(t, ctx) defer cleanup() - first_inbox_id := "test_inbox" - second_inbox_id := "second_inbox" - first_address := "test_address" - second_address := "test_address" + first_inbox_id := test.RandomInboxId() + second_inbox_id := test.RandomInboxId() + first_address := test.RandomInboxId() + second_address := test.RandomInboxId() _, err := svc.PublishIdentityUpdate(ctx, publishIdentityUpdateRequest(first_inbox_id, makeCreateInbox(first_address))) require.NoError(t, err) @@ -233,7 +233,7 @@ func TestInboxSizeLimit(t *testing.T) { svc, _, cleanup := newTestService(t, ctx) defer cleanup() - inbox_id := "test_inbox" + inbox_id := test.RandomInboxId() address := "test_address" _, err := svc.PublishIdentityUpdate(ctx, publishIdentityUpdateRequest(inbox_id, makeCreateInbox(address))) diff --git a/pkg/migrations/mls/20231023050806_init-schema.down.sql b/pkg/migrations/mls/20231023050806_init-schema.down.sql deleted file mode 100644 index 77c9a8e7..00000000 --- a/pkg/migrations/mls/20231023050806_init-schema.down.sql +++ /dev/null @@ -1,5 +0,0 @@ -SET statement_timeout = 0; - ---bun:split -DROP TABLE IF EXISTS installations; - diff --git a/pkg/migrations/mls/20231023050806_init-schema.up.sql b/pkg/migrations/mls/20231023050806_init-schema.up.sql deleted file mode 100644 index b57561fc..00000000 --- a/pkg/migrations/mls/20231023050806_init-schema.up.sql +++ /dev/null @@ -1,23 +0,0 @@ -SET statement_timeout = 0; - ---bun:split -CREATE TABLE installations( - id BYTEA PRIMARY KEY, - wallet_address TEXT NOT NULL, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - credential_identity BYTEA NOT NULL, - revoked_at BIGINT, - key_package BYTEA NOT NULL, - expiration BIGINT NOT NULL -); - ---bun:split -CREATE INDEX idx_installations_wallet_address ON installations(wallet_address); - ---bun:split -CREATE INDEX idx_installations_created_at ON installations(created_at); - ---bun:split -CREATE INDEX idx_installations_revoked_at ON installations(revoked_at); - diff --git a/pkg/migrations/mls/20240109001927_add-messages.up.sql b/pkg/migrations/mls/20240109001927_add-messages.up.sql deleted file mode 100644 index aa76a2be..00000000 --- a/pkg/migrations/mls/20240109001927_add-messages.up.sql +++ /dev/null @@ -1,32 +0,0 @@ -SET statement_timeout = 0; - ---bun:split -CREATE TABLE group_messages( - id BIGSERIAL PRIMARY KEY, - created_at TIMESTAMP NOT NULL DEFAULT NOW(), - group_id BYTEA NOT NULL, - data BYTEA NOT NULL, - group_id_data_hash BYTEA NOT NULL -); - ---bun:split -CREATE INDEX idx_group_messages_group_id_created_at ON group_messages(group_id, created_at); - ---bun:split -CREATE UNIQUE INDEX idx_group_messages_group_id_data_hash ON group_messages(group_id_data_hash); - ---bun:split -CREATE TABLE welcome_messages( - id BIGSERIAL PRIMARY KEY, - created_at TIMESTAMP NOT NULL DEFAULT NOW(), - installation_key BYTEA NOT NULL, - data BYTEA NOT NULL, - installation_key_data_hash BYTEA NOT NULL -); - ---bun:split -CREATE INDEX idx_welcome_messages_installation_key_created_at ON welcome_messages(installation_key, created_at); - ---bun:split -CREATE UNIQUE INDEX idx_welcome_messages_group_key_data_hash ON welcome_messages(installation_key_data_hash); - diff --git a/pkg/migrations/mls/20240122230601_add-hpke-key.down.sql b/pkg/migrations/mls/20240122230601_add-hpke-key.down.sql deleted file mode 100644 index b51fe300..00000000 --- a/pkg/migrations/mls/20240122230601_add-hpke-key.down.sql +++ /dev/null @@ -1,5 +0,0 @@ -SET statement_timeout = 0; - -ALTER TABLE welcome_messages - DROP COLUMN IF EXISTS hpke_public_key BYTEA; - diff --git a/pkg/migrations/mls/20240122230601_add-hpke-key.up.sql b/pkg/migrations/mls/20240122230601_add-hpke-key.up.sql deleted file mode 100644 index f38bf60d..00000000 --- a/pkg/migrations/mls/20240122230601_add-hpke-key.up.sql +++ /dev/null @@ -1,5 +0,0 @@ -SET statement_timeout = 0; - -ALTER TABLE welcome_messages - ADD COLUMN hpke_public_key BYTEA; - diff --git a/pkg/migrations/mls/20240411200242_init-identity.down.sql b/pkg/migrations/mls/20240411200242_init-identity.down.sql deleted file mode 100644 index 29753e66..00000000 --- a/pkg/migrations/mls/20240411200242_init-identity.down.sql +++ /dev/null @@ -1,8 +0,0 @@ -SET statement_timeout = 0; - ---bun:split -DROP TABLE IF EXISTS inbox_log; - ---bun:split -DROP TABLE IF EXISTS address_log; - diff --git a/pkg/migrations/mls/20240411200242_init-identity.up.sql b/pkg/migrations/mls/20240411200242_init-identity.up.sql deleted file mode 100644 index cb42e660..00000000 --- a/pkg/migrations/mls/20240411200242_init-identity.up.sql +++ /dev/null @@ -1,24 +0,0 @@ -SET statement_timeout = 0; - ---bun:split -CREATE TABLE inbox_log( - sequence_id BIGSERIAL PRIMARY KEY, - inbox_id TEXT NOT NULL, - server_timestamp_ns BIGINT NOT NULL, - identity_update_proto BYTEA NOT NULL -); - ---bun:split -CREATE INDEX idx_inbox_log_inbox_id ON inbox_log(inbox_id); - ---bun:split -CREATE TABLE address_log( - address TEXT NOT NULL, - inbox_id TEXT NOT NULL, - association_sequence_id BIGINT, - revocation_sequence_id BIGINT -); - ---bun:split -CREATE INDEX idx_address_log_address_inbox_id ON address_log(address, inbox_id); - diff --git a/pkg/migrations/mls/20240425021053_add-inbox-filters.up.sql b/pkg/migrations/mls/20240425021053_add-inbox-filters.up.sql deleted file mode 100644 index bee29307..00000000 --- a/pkg/migrations/mls/20240425021053_add-inbox-filters.up.sql +++ /dev/null @@ -1,8 +0,0 @@ -SET statement_timeout = 0; - ---bun:split -CREATE TYPE inbox_filter AS ( - inbox_id TEXT, - sequence_id BIGINT -); - diff --git a/pkg/migrations/mls/20240109001927_add-messages.down.sql b/pkg/migrations/mls/20240528181822_wipe-db.down.sql similarity index 56% rename from pkg/migrations/mls/20240109001927_add-messages.down.sql rename to pkg/migrations/mls/20240528181822_wipe-db.down.sql index 3c0d999e..87d60f3e 100644 --- a/pkg/migrations/mls/20240109001927_add-messages.down.sql +++ b/pkg/migrations/mls/20240528181822_wipe-db.down.sql @@ -1,5 +1,9 @@ SET statement_timeout = 0; --bun:split -DROP TABLE IF EXISTS messages; +SELECT 1 + +--bun:split + +SELECT 2 diff --git a/pkg/migrations/mls/20240528181822_wipe-db.up.sql b/pkg/migrations/mls/20240528181822_wipe-db.up.sql new file mode 100644 index 00000000..f58fdc8a --- /dev/null +++ b/pkg/migrations/mls/20240528181822_wipe-db.up.sql @@ -0,0 +1,8 @@ +SET statement_timeout = 0; + +--bun:split +DROP TABLE IF EXISTS installations, group_messages, welcome_messages, inbox_log, address_log; + +--bun:split +DROP TYPE IF EXISTS inbox_filter; + diff --git a/pkg/migrations/mls/20240425021053_add-inbox-filters.down.sql b/pkg/migrations/mls/20240528181851_init-schema.down.sql similarity index 54% rename from pkg/migrations/mls/20240425021053_add-inbox-filters.down.sql rename to pkg/migrations/mls/20240528181851_init-schema.down.sql index 6a628122..87d60f3e 100644 --- a/pkg/migrations/mls/20240425021053_add-inbox-filters.down.sql +++ b/pkg/migrations/mls/20240528181851_init-schema.down.sql @@ -1,5 +1,9 @@ SET statement_timeout = 0; --bun:split -DROP TYPE IF EXISTS inbox_filter; +SELECT 1 + +--bun:split + +SELECT 2 diff --git a/pkg/migrations/mls/20240528181851_init-schema.up.sql b/pkg/migrations/mls/20240528181851_init-schema.up.sql new file mode 100644 index 00000000..104d8526 --- /dev/null +++ b/pkg/migrations/mls/20240528181851_init-schema.up.sql @@ -0,0 +1,71 @@ +SET statement_timeout = 0; + +--bun:split +CREATE TABLE installations( + id BYTEA PRIMARY KEY, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, + inbox_id BYTEA NOT NULL, + key_package BYTEA NOT NULL, + expiration BIGINT NOT NULL +); + +--bun:split +CREATE TABLE group_messages( + id BIGSERIAL PRIMARY KEY, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + group_id BYTEA NOT NULL, + data BYTEA NOT NULL, + group_id_data_hash BYTEA NOT NULL +); + +--bun:split +CREATE INDEX idx_group_messages_group_id_id ON group_messages(group_id, id); + +--bun:split +CREATE UNIQUE INDEX idx_group_messages_group_id_data_hash ON group_messages(group_id_data_hash); + +--bun:split +CREATE TABLE welcome_messages( + id BIGSERIAL PRIMARY KEY, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + installation_key BYTEA NOT NULL, + data BYTEA NOT NULL, + hpke_public_key BYTEA NOT NULL, + installation_key_data_hash BYTEA NOT NULL +); + +--bun:split +CREATE INDEX idx_welcome_messages_installation_key_id ON welcome_messages(installation_key, id); + +--bun:split +CREATE UNIQUE INDEX idx_welcome_messages_group_key_data_hash ON welcome_messages(installation_key_data_hash); + +--bun:split +CREATE TABLE inbox_log( + sequence_id BIGSERIAL PRIMARY KEY, + inbox_id BYTEA NOT NULL, + server_timestamp_ns BIGINT NOT NULL, + identity_update_proto BYTEA NOT NULL +); + +--bun:split +CREATE INDEX idx_inbox_log_inbox_id_sequence_id ON inbox_log(inbox_id, sequence_id); + +--bun:split +CREATE TABLE address_log( + address TEXT NOT NULL, + inbox_id BYTEA NOT NULL, + association_sequence_id BIGINT, + revocation_sequence_id BIGINT +); + +--bun:split +CREATE INDEX idx_address_log_address_inbox_id ON address_log(address, inbox_id); + +--bun:split +CREATE TYPE inbox_filter AS ( + inbox_id TEXT, -- Because this is serialized as JSON, we can't use a BYTEA type + sequence_id BIGINT +); + diff --git a/pkg/mls/api/v1/mock.gen.go b/pkg/mls/api/v1/mock.gen.go deleted file mode 100644 index 3c23a6ef..00000000 --- a/pkg/mls/api/v1/mock.gen.go +++ /dev/null @@ -1,257 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1 (interfaces: MlsApi_SubscribeGroupMessagesServer,MlsApi_SubscribeWelcomeMessagesServer) -// -// Generated by this command: -// -// mockgen -package api github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1 MlsApi_SubscribeGroupMessagesServer,MlsApi_SubscribeWelcomeMessagesServer -// - -// Package api is a generated GoMock package. -package api - -import ( - context "context" - reflect "reflect" - - apiv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" - gomock "go.uber.org/mock/gomock" - metadata "google.golang.org/grpc/metadata" -) - -// MockMlsApi_SubscribeGroupMessagesServer is a mock of MlsApi_SubscribeGroupMessagesServer interface. -type MockMlsApi_SubscribeGroupMessagesServer struct { - ctrl *gomock.Controller - recorder *MockMlsApi_SubscribeGroupMessagesServerMockRecorder -} - -// MockMlsApi_SubscribeGroupMessagesServerMockRecorder is the mock recorder for MockMlsApi_SubscribeGroupMessagesServer. -type MockMlsApi_SubscribeGroupMessagesServerMockRecorder struct { - mock *MockMlsApi_SubscribeGroupMessagesServer -} - -// NewMockMlsApi_SubscribeGroupMessagesServer creates a new mock instance. -func NewMockMlsApi_SubscribeGroupMessagesServer(ctrl *gomock.Controller) *MockMlsApi_SubscribeGroupMessagesServer { - mock := &MockMlsApi_SubscribeGroupMessagesServer{ctrl: ctrl} - mock.recorder = &MockMlsApi_SubscribeGroupMessagesServerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMlsApi_SubscribeGroupMessagesServer) EXPECT() *MockMlsApi_SubscribeGroupMessagesServerMockRecorder { - return m.recorder -} - -// Context mocks base method. -func (m *MockMlsApi_SubscribeGroupMessagesServer) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).Context)) -} - -// RecvMsg mocks base method. -func (m *MockMlsApi_SubscribeGroupMessagesServer) RecvMsg(arg0 any) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RecvMsg", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RecvMsg indicates an expected call of RecvMsg. -func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) RecvMsg(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).RecvMsg), arg0) -} - -// Send mocks base method. -func (m *MockMlsApi_SubscribeGroupMessagesServer) Send(arg0 *apiv1.GroupMessage) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Send", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Send indicates an expected call of Send. -func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) Send(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).Send), arg0) -} - -// SendHeader mocks base method. -func (m *MockMlsApi_SubscribeGroupMessagesServer) SendHeader(arg0 metadata.MD) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendHeader", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendHeader indicates an expected call of SendHeader. -func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SendHeader(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SendHeader), arg0) -} - -// SendMsg mocks base method. -func (m *MockMlsApi_SubscribeGroupMessagesServer) SendMsg(arg0 any) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMsg", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendMsg indicates an expected call of SendMsg. -func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SendMsg(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SendMsg), arg0) -} - -// SetHeader mocks base method. -func (m *MockMlsApi_SubscribeGroupMessagesServer) SetHeader(arg0 metadata.MD) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetHeader", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetHeader indicates an expected call of SetHeader. -func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SetHeader(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SetHeader), arg0) -} - -// SetTrailer mocks base method. -func (m *MockMlsApi_SubscribeGroupMessagesServer) SetTrailer(arg0 metadata.MD) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetTrailer", arg0) -} - -// SetTrailer indicates an expected call of SetTrailer. -func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SetTrailer(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SetTrailer), arg0) -} - -// MockMlsApi_SubscribeWelcomeMessagesServer is a mock of MlsApi_SubscribeWelcomeMessagesServer interface. -type MockMlsApi_SubscribeWelcomeMessagesServer struct { - ctrl *gomock.Controller - recorder *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder -} - -// MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder is the mock recorder for MockMlsApi_SubscribeWelcomeMessagesServer. -type MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder struct { - mock *MockMlsApi_SubscribeWelcomeMessagesServer -} - -// NewMockMlsApi_SubscribeWelcomeMessagesServer creates a new mock instance. -func NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl *gomock.Controller) *MockMlsApi_SubscribeWelcomeMessagesServer { - mock := &MockMlsApi_SubscribeWelcomeMessagesServer{ctrl: ctrl} - mock.recorder = &MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) EXPECT() *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder { - return m.recorder -} - -// Context mocks base method. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).Context)) -} - -// RecvMsg mocks base method. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) RecvMsg(arg0 any) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RecvMsg", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RecvMsg indicates an expected call of RecvMsg. -func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) RecvMsg(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).RecvMsg), arg0) -} - -// Send mocks base method. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) Send(arg0 *apiv1.WelcomeMessage) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Send", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Send indicates an expected call of Send. -func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) Send(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).Send), arg0) -} - -// SendHeader mocks base method. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SendHeader(arg0 metadata.MD) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendHeader", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendHeader indicates an expected call of SendHeader. -func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SendHeader(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SendHeader), arg0) -} - -// SendMsg mocks base method. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SendMsg(arg0 any) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMsg", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendMsg indicates an expected call of SendMsg. -func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SendMsg(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SendMsg), arg0) -} - -// SetHeader mocks base method. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SetHeader(arg0 metadata.MD) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetHeader", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetHeader indicates an expected call of SetHeader. -func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SetHeader(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SetHeader), arg0) -} - -// SetTrailer mocks base method. -func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SetTrailer(arg0 metadata.MD) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetTrailer", arg0) -} - -// SetTrailer indicates an expected call of SetTrailer. -func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SetTrailer(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SetTrailer), arg0) -} diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index ecc0755e..56c18c67 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -118,45 +118,22 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterI return nil, err } - if req.IsInboxIdCredential { - results, err := s.validationService.ValidateInboxIdKeyPackages(ctx, [][]byte{req.KeyPackage.KeyPackageTlsSerialized}) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) - } - - if len(results) != 1 { - return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results)) - } - installationKey := results[0].InstallationKey - credential := results[0].Credential - if err = s.store.CreateInstallation(ctx, installationKey, "", []byte(credential.InboxId), req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil { - return nil, err - } - return &mlsv1.RegisterInstallationResponse{ - InstallationKey: installationKey, - }, nil - } else { - results, err := s.validationService.ValidateV3KeyPackages(ctx, [][]byte{req.KeyPackage.KeyPackageTlsSerialized}) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) - } - - if len(results) != 1 { - return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results)) - } - - installationId := results[0].InstallationKey - accountAddress := results[0].AccountAddress - credentialIdentity := results[0].CredentialIdentity - - if err = s.store.CreateInstallation(ctx, installationId, accountAddress, credentialIdentity, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil { - return nil, err - } + results, err := s.validationService.ValidateInboxIdKeyPackages(ctx, [][]byte{req.KeyPackage.KeyPackageTlsSerialized}) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) + } - return &mlsv1.RegisterInstallationResponse{ - InstallationKey: installationId, - }, nil + if len(results) != 1 { + return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results)) + } + installationKey := results[0].InstallationKey + credential := results[0].Credential + if err = s.store.CreateInstallation(ctx, installationKey, credential.InboxId, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil { + return nil, err } + return &mlsv1.RegisterInstallationResponse{ + InstallationKey: installationKey, + }, nil } func (s *Service) FetchKeyPackages(ctx context.Context, req *mlsv1.FetchKeyPackagesRequest) (*mlsv1.FetchKeyPackagesResponse, error) { @@ -195,35 +172,19 @@ func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPack // Extract the key packages from the request keyPackageBytes := req.KeyPackage.KeyPackageTlsSerialized - if req.IsInboxIdCredential { - validationResults, err := s.validationService.ValidateInboxIdKeyPackages(ctx, [][]byte{keyPackageBytes}) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) - } - - installationId := validationResults[0].InstallationKey - expiration := validationResults[0].Expiration - - if err = s.store.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil { - return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) - } + validationResults, err := s.validationService.ValidateInboxIdKeyPackages(ctx, [][]byte{keyPackageBytes}) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) + } - return &emptypb.Empty{}, nil - } else { - validationResults, err := s.validationService.ValidateV3KeyPackages(ctx, [][]byte{keyPackageBytes}) - if err != nil { - // TODO: Differentiate between validation errors and internal errors - return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) - } - installationId := validationResults[0].InstallationKey - expiration := validationResults[0].Expiration + installationId := validationResults[0].InstallationKey + expiration := validationResults[0].Expiration - if err = s.store.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil { - return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) - } - - return &emptypb.Empty{}, nil + if err = s.store.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil { + return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) } + + return &emptypb.Empty{}, nil } func (s *Service) RevokeInstallation(ctx context.Context, req *mlsv1.RevokeInstallationRequest) (*emptypb.Empty, error) { @@ -231,32 +192,7 @@ func (s *Service) RevokeInstallation(ctx context.Context, req *mlsv1.RevokeInsta } func (s *Service) GetIdentityUpdates(ctx context.Context, req *mlsv1.GetIdentityUpdatesRequest) (res *mlsv1.GetIdentityUpdatesResponse, err error) { - if err = validateGetIdentityUpdatesRequest(req); err != nil { - return nil, err - } - - accountAddresses := req.AccountAddresses - updates, err := s.store.GetIdentityUpdates(ctx, req.AccountAddresses, int64(req.StartTimeNs)) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get identity updates: %s", err) - } - - resUpdates := make([]*mlsv1.GetIdentityUpdatesResponse_WalletUpdates, len(accountAddresses)) - for i, accountAddress := range accountAddresses { - walletUpdates := updates[accountAddress] - - resUpdates[i] = &mlsv1.GetIdentityUpdatesResponse_WalletUpdates{ - Updates: []*mlsv1.GetIdentityUpdatesResponse_Update{}, - } - - for _, walletUpdate := range walletUpdates { - resUpdates[i].Updates = append(resUpdates[i].Updates, buildIdentityUpdate(walletUpdate)) - } - } - - return &mlsv1.GetIdentityUpdatesResponse{ - Updates: resUpdates, - }, nil + return nil, status.Errorf(codes.Unimplemented, "unimplemented") } func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMessagesRequest) (res *emptypb.Empty, err error) { @@ -583,29 +519,6 @@ func buildNatsSubjectForWelcomeMessages(installationId []byte) string { return envelopes.BuildNatsSubject(contentTopic) } -func buildIdentityUpdate(update mlsstore.IdentityUpdate) *mlsv1.GetIdentityUpdatesResponse_Update { - base := mlsv1.GetIdentityUpdatesResponse_Update{ - TimestampNs: update.TimestampNs, - } - switch update.Kind { - case mlsstore.Create: - base.Kind = &mlsv1.GetIdentityUpdatesResponse_Update_NewInstallation{ - NewInstallation: &mlsv1.GetIdentityUpdatesResponse_NewInstallationUpdate{ - InstallationKey: update.InstallationKey, - CredentialIdentity: update.CredentialIdentity, - }, - } - case mlsstore.Revoke: - base.Kind = &mlsv1.GetIdentityUpdatesResponse_Update_RevokedInstallation{ - RevokedInstallation: &mlsv1.GetIdentityUpdatesResponse_RevokedInstallationUpdate{ - InstallationKey: update.InstallationKey, - }, - } - } - - return &base -} - func validateSendGroupMessagesRequest(req *mlsv1.SendGroupMessagesRequest) error { if req == nil || len(req.Messages) == 0 { return status.Errorf(codes.InvalidArgument, "no group messages to send") @@ -649,13 +562,6 @@ func validateUploadKeyPackageRequest(req *mlsv1.UploadKeyPackageRequest) error { return nil } -func validateGetIdentityUpdatesRequest(req *mlsv1.GetIdentityUpdatesRequest) error { - if req == nil || len(req.AccountAddresses) == 0 { - return status.Errorf(codes.InvalidArgument, "no wallet addresses to get updates for") - } - return nil -} - func requireReadyToSend(groupId string, message []byte) error { if len(groupId) == 0 { return status.Errorf(codes.InvalidArgument, "group id is empty") diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index 416c730e..479bdd78 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -16,67 +16,36 @@ import ( mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" - "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" + "github.com/xmtp/xmtp-node-go/pkg/mocks" + "github.com/xmtp/xmtp-node-go/pkg/proto/identity" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" test "github.com/xmtp/xmtp-node-go/pkg/testing" "github.com/xmtp/xmtp-node-go/pkg/topic" - "go.uber.org/mock/gomock" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" ) -type mockedMLSValidationService struct { - mock.Mock -} - -func (m *mockedMLSValidationService) GetAssociationState(ctx context.Context, oldUpdates []*associations.IdentityUpdate, newUpdates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) { - return nil, nil -} - -func (m *mockedMLSValidationService) ValidateV3KeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { - args := m.Called(ctx, keyPackages) - - response := args.Get(0) - if response == nil { - return nil, args.Error(1) - } - - return response.([]mlsvalidate.IdentityValidationResult), args.Error(1) -} - -func (m *mockedMLSValidationService) ValidateInboxIdKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.InboxIdValidationResult, error) { - return nil, nil -} - -func (m *mockedMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages []*mlsv1.GroupMessageInput) ([]mlsvalidate.GroupMessageValidationResult, error) { - args := m.Called(ctx, groupMessages) - - return args.Get(0).([]mlsvalidate.GroupMessageValidationResult), args.Error(1) -} - -func newMockedValidationService() *mockedMLSValidationService { - return new(mockedMLSValidationService) -} - -func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []byte, accountAddress string) *mock.Call { - return m.On("ValidateV3KeyPackages", mock.Anything, mock.Anything).Return([]mlsvalidate.IdentityValidationResult{ +func mockValidateInboxIdKeyPackages(m *mocks.MockMLSValidationService, installationId []byte, inboxId string) *mocks.MockMLSValidationService_ValidateInboxIdKeyPackages_Call { + return m.EXPECT().ValidateInboxIdKeyPackages(mock.Anything, mock.Anything).Return([]mlsvalidate.InboxIdValidationResult{ { - InstallationKey: installationId, - AccountAddress: accountAddress, - CredentialIdentity: []byte("test"), - Expiration: 0, + InstallationKey: installationId, + Credential: &identity.MlsCredential{ + InboxId: inboxId, + }, + Expiration: 0, }, }, nil) } -func (m *mockedMLSValidationService) mockValidateGroupMessages(groupId []byte) *mock.Call { - return m.On("ValidateGroupMessages", mock.Anything, mock.Anything).Return([]mlsvalidate.GroupMessageValidationResult{ +func mockValidateGroupMessages(m *mocks.MockMLSValidationService, groupId []byte) *mocks.MockMLSValidationService_ValidateGroupMessages_Call { + return m.EXPECT().ValidateGroupMessages(mock.Anything, mock.Anything).Return([]mlsvalidate.GroupMessageValidationResult{ { GroupId: fmt.Sprintf("%x", groupId), }, }, nil) } -func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mockedMLSValidationService, func()) { +func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mocks.MockMLSValidationService, func()) { log := test.NewLog(t) db, _, mlsDbCleanup := test.NewMLSDB(t) store, err := mlsstore.New(ctx, mlsstore.Config{ @@ -84,7 +53,7 @@ func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mock DB: db, }) require.NoError(t, err) - mlsValidationService := newMockedValidationService() + mockMlsValidation := mocks.NewMockMLSValidationService(t) natsServer, err := server.NewServer(&server.Options{ Port: server.RANDOM_PORT, }) @@ -94,12 +63,12 @@ func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mock t.Fail() } - svc, err := NewService(log, store, mlsValidationService, natsServer, func(ctx context.Context, wm *wakupb.WakuMessage) error { + svc, err := NewService(log, store, mockMlsValidation, natsServer, func(ctx context.Context, wm *wakupb.WakuMessage) error { return nil }) require.NoError(t, err) - return svc, db, mlsValidationService, func() { + return svc, db, mockMlsValidation, func() { svc.Close() natsServer.Shutdown() mlsDbCleanup() @@ -112,9 +81,9 @@ func TestRegisterInstallation(t *testing.T) { defer cleanup() installationId := test.RandomBytes(32) - accountAddress := test.RandomString(32) + inboxId := test.RandomInboxId() - mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) + mockValidateInboxIdKeyPackages(mlsValidationService, installationId, inboxId) res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ KeyPackage: &mlsv1.KeyPackageUpload{ @@ -129,7 +98,7 @@ func TestRegisterInstallation(t *testing.T) { installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId) require.NoError(t, err) - require.Equal(t, accountAddress, installation.WalletAddress) + require.Equal(t, inboxId, installation.InboxID) } func TestRegisterInstallationError(t *testing.T) { @@ -137,7 +106,7 @@ func TestRegisterInstallationError(t *testing.T) { svc, _, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() - mlsValidationService.On("ValidateV3KeyPackages", ctx, mock.Anything).Return(nil, errors.New("error validating")) + mlsValidationService.EXPECT().ValidateInboxIdKeyPackages(mock.Anything, mock.Anything).Return(nil, errors.New("error validating")) res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ KeyPackage: &mlsv1.KeyPackageUpload{ @@ -155,15 +124,15 @@ func TestUploadKeyPackage(t *testing.T) { defer cleanup() installationId := test.RandomBytes(32) - accountAddress := test.RandomString(32) + inboxId := test.RandomInboxId() - mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) + mockValidateInboxIdKeyPackages(mlsValidationService, installationId, inboxId) res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, - IsInboxIdCredential: false, + IsInboxIdCredential: true, }) require.NoError(t, err) require.NotNil(t, res) @@ -172,14 +141,14 @@ func TestUploadKeyPackage(t *testing.T) { KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test2"), }, - IsInboxIdCredential: false, + IsInboxIdCredential: true, }) require.NoError(t, err) require.NotNil(t, uploadRes) installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId) require.NoError(t, err) - require.Equal(t, accountAddress, installation.WalletAddress) + require.Equal(t, []byte("test2"), installation.KeyPackage) } func TestFetchKeyPackages(t *testing.T) { @@ -188,9 +157,9 @@ func TestFetchKeyPackages(t *testing.T) { defer cleanup() installationId1 := test.RandomBytes(32) - accountAddress1 := test.RandomString(32) + inboxId := test.RandomInboxId() - mockCall := mlsValidationService.mockValidateKeyPackages(installationId1, accountAddress1) + mockCall := mockValidateInboxIdKeyPackages(mlsValidationService, installationId1, inboxId) res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ KeyPackage: &mlsv1.KeyPackageUpload{ @@ -203,10 +172,10 @@ func TestFetchKeyPackages(t *testing.T) { // Add a second key package installationId2 := test.RandomBytes(32) - accountAddress2 := test.RandomString(32) // Unset the original mock so we can set a new one mockCall.Unset() - mlsValidationService.mockValidateKeyPackages(installationId2, accountAddress2) + + mockValidateInboxIdKeyPackages(mlsValidationService, installationId2, inboxId) res, err = svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ KeyPackage: &mlsv1.KeyPackageUpload{ @@ -258,7 +227,7 @@ func TestSendGroupMessages(t *testing.T) { groupId := []byte(test.RandomString(32)) - mlsValidationService.mockValidateGroupMessages(groupId) + mockValidateGroupMessages(mlsValidationService, groupId) _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ Messages: []*mlsv1.GroupMessageInput{ @@ -313,57 +282,6 @@ func TestSendWelcomeMessages(t *testing.T) { require.NotEmpty(t, resp.Messages[0].GetV1().CreatedNs) } -func TestGetIdentityUpdates(t *testing.T) { - ctx := context.Background() - svc, _, mlsValidationService, cleanup := newTestService(t, ctx) - defer cleanup() - - installationId := test.RandomBytes(32) - accountAddress := test.RandomString(32) - - mockCall := mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) - - _, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ - KeyPackage: &mlsv1.KeyPackageUpload{ - KeyPackageTlsSerialized: []byte("test"), - }, - IsInboxIdCredential: false, - }) - require.NoError(t, err) - - identityUpdates, err := svc.GetIdentityUpdates(ctx, &mlsv1.GetIdentityUpdatesRequest{ - AccountAddresses: []string{accountAddress}, - }) - require.NoError(t, err) - require.NotNil(t, identityUpdates) - require.Len(t, identityUpdates.Updates, 1) - require.Equal(t, identityUpdates.Updates[0].Updates[0].GetNewInstallation().InstallationKey, installationId) - require.Equal(t, identityUpdates.Updates[0].Updates[0].GetNewInstallation().CredentialIdentity, []byte("test")) - - for _, walletUpdate := range identityUpdates.Updates { - for _, update := range walletUpdate.Updates { - require.Equal(t, installationId, update.GetNewInstallation().InstallationKey) - } - } - - mockCall.Unset() - mlsValidationService.mockValidateKeyPackages(test.RandomBytes(32), accountAddress) - _, err = svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ - KeyPackage: &mlsv1.KeyPackageUpload{ - KeyPackageTlsSerialized: []byte("test"), - }, - IsInboxIdCredential: false, - }) - require.NoError(t, err) - - identityUpdates, err = svc.GetIdentityUpdates(ctx, &mlsv1.GetIdentityUpdatesRequest{ - AccountAddresses: []string{accountAddress}, - }) - require.NoError(t, err) - require.Len(t, identityUpdates.Updates, 1) - require.Len(t, identityUpdates.Updates[0].Updates, 2) -} - func TestSubscribeGroupMessages_WithoutCursor(t *testing.T) { ctx := context.Background() svc, _, _, cleanup := newTestService(t, ctx) @@ -385,11 +303,10 @@ func TestSubscribeGroupMessages_WithoutCursor(t *testing.T) { } } - ctrl := gomock.NewController(t) - stream := NewMockMlsApi_SubscribeGroupMessagesServer(ctrl) - stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + stream := mocks.NewMockMlsApi_SubscribeGroupMessagesServer(t) + stream.EXPECT().SendHeader(metadata.New(map[string]string{"subscribed": "true"})).Return(nil) for _, msg := range msgs { - stream.EXPECT().Send(newGroupMessageEqualsMatcher(msg)).Return(nil).Times(1) + stream.EXPECT().Send(mock.MatchedBy(newGroupMessageEqualsMatcher(msg).Matches)).Return(nil).Times(1) } stream.EXPECT().Context().Return(ctx) @@ -417,7 +334,7 @@ func TestSubscribeGroupMessages_WithoutCursor(t *testing.T) { require.NoError(t, err) } - require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) + assertExpectationsWithTimeout(t, &stream.Mock, 5*time.Second, 100*time.Millisecond) } func TestSubscribeGroupMessages_WithCursor(t *testing.T) { @@ -428,7 +345,7 @@ func TestSubscribeGroupMessages_WithCursor(t *testing.T) { groupId := []byte(test.RandomString(32)) // Initial message before stream starts. - mlsValidationService.mockValidateGroupMessages(groupId) + mockValidateGroupMessages(mlsValidationService, groupId) initialMsgs := []*mlsv1.GroupMessageInput{ { Version: &mlsv1.GroupMessageInput_V1_{ @@ -474,22 +391,20 @@ func TestSubscribeGroupMessages_WithCursor(t *testing.T) { } } - // Set up expectations of streaming the 11 messages from cursor. - ctrl := gomock.NewController(t) - stream := NewMockMlsApi_SubscribeGroupMessagesServer(ctrl) - stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) - stream.EXPECT().Send(newGroupMessageIdAndDataEqualsMatcher(&mlsv1.GroupMessage{ + stream := mocks.NewMockMlsApi_SubscribeGroupMessagesServer(t) + stream.EXPECT().SendHeader(metadata.New(map[string]string{"subscribed": "true"})).Return(nil) + stream.EXPECT().Send(mock.MatchedBy(newGroupMessageIdAndDataEqualsMatcher(&mlsv1.GroupMessage{ Version: &mlsv1.GroupMessage_V1_{ V1: &mlsv1.GroupMessage_V1{ Id: 3, Data: []byte("data3"), }, }, - })).Return(nil).Times(1) + }).Matches)).Return(nil).Times(1) for _, msg := range msgs { - stream.EXPECT().Send(newGroupMessageEqualsMatcher(msg)).Return(nil).Times(1) + stream.EXPECT().Send(mock.MatchedBy(newGroupMessageEqualsMatcher(msg).Matches)).Return(nil).Times(1) } - stream.EXPECT().Context().Return(ctx).AnyTimes() + stream.EXPECT().Context().Return(ctx) go func() { err := svc.SubscribeGroupMessages(&mlsv1.SubscribeGroupMessagesRequest{ @@ -518,7 +433,7 @@ func TestSubscribeGroupMessages_WithCursor(t *testing.T) { } // Expectations should eventually be satisfied. - require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) + assertExpectationsWithTimeout(t, &stream.Mock, 5*time.Second, 100*time.Millisecond) } func TestSubscribeWelcomeMessages_WithoutCursor(t *testing.T) { @@ -543,11 +458,10 @@ func TestSubscribeWelcomeMessages_WithoutCursor(t *testing.T) { } } - ctrl := gomock.NewController(t) - stream := NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl) - stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + stream := mocks.NewMockMlsApi_SubscribeWelcomeMessagesServer(t) + stream.EXPECT().SendHeader(metadata.New(map[string]string{"subscribed": "true"})).Return(nil) for _, msg := range msgs { - stream.EXPECT().Send(newWelcomeMessageEqualsMatcher(msg)).Return(nil).Times(1) + stream.EXPECT().Send(mock.MatchedBy(newWelcomeMessageEqualsMatcher(msg).Matches)).Return(nil).Times(1) } stream.EXPECT().Context().Return(ctx) @@ -575,7 +489,7 @@ func TestSubscribeWelcomeMessages_WithoutCursor(t *testing.T) { require.NoError(t, err) } - require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) + assertExpectationsWithTimeout(t, &stream.Mock, 5*time.Second, 100*time.Millisecond) } func TestSubscribeWelcomeMessages_WithCursor(t *testing.T) { @@ -640,10 +554,9 @@ func TestSubscribeWelcomeMessages_WithCursor(t *testing.T) { } // Set up expectations of streaming the 11 messages from cursor. - ctrl := gomock.NewController(t) - stream := NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl) - stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) - stream.EXPECT().Send(newWelcomeMessageEqualsMatcherWithoutTimestamp(&mlsv1.WelcomeMessage{ + stream := mocks.NewMockMlsApi_SubscribeWelcomeMessagesServer(t) + stream.EXPECT().SendHeader(metadata.New(map[string]string{"subscribed": "true"})).Return(nil) + stream.EXPECT().Send(mock.MatchedBy(newWelcomeMessageEqualsMatcherWithoutTimestamp(&mlsv1.WelcomeMessage{ Version: &mlsv1.WelcomeMessage_V1_{ V1: &mlsv1.WelcomeMessage_V1{ Id: 3, @@ -652,11 +565,11 @@ func TestSubscribeWelcomeMessages_WithCursor(t *testing.T) { Data: []byte("data3"), }, }, - })).Return(nil).Times(1) + }).Matches)).Return(nil).Times(1) for _, msg := range msgs { - stream.EXPECT().Send(newWelcomeMessageEqualsMatcher(msg)).Return(nil).Times(1) + stream.EXPECT().Send(mock.MatchedBy(newWelcomeMessageEqualsMatcher(msg).Matches)).Return(nil).Times(1) } - stream.EXPECT().Context().Return(ctx).AnyTimes() + stream.EXPECT().Context().Return(ctx) go func() { err := svc.SubscribeWelcomeMessages(&mlsv1.SubscribeWelcomeMessagesRequest{ @@ -685,7 +598,7 @@ func TestSubscribeWelcomeMessages_WithCursor(t *testing.T) { } // Expectations should eventually be satisfied. - require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) + assertExpectationsWithTimeout(t, &stream.Mock, 5*time.Second, 100*time.Millisecond) } type groupMessageEqualsMatcher struct { @@ -755,3 +668,23 @@ func (m *welcomeMessageEqualsMatcherWithoutTimestamp) Matches(obj interface{}) b func (m *welcomeMessageEqualsMatcherWithoutTimestamp) String() string { return m.obj.String() } + +func assertExpectationsWithTimeout(t *testing.T, mockObj *mock.Mock, timeout, interval time.Duration) { + timeoutChan := time.After(timeout) + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-timeoutChan: + if !mockObj.AssertExpectations(t) { + t.Error("Expectations were not met within the timeout period") + } + return + case <-ticker.C: + if mockObj.AssertExpectations(t) { + return + } + } + } +} diff --git a/pkg/mls/store/queries.sql b/pkg/mls/store/queries.sql index 13420c6c..561c69ae 100644 --- a/pkg/mls/store/queries.sql +++ b/pkg/mls/store/queries.sql @@ -4,17 +4,22 @@ SELECT -- name: GetAllInboxLogs :many SELECT - * + sequence_id, + encode(inbox_id, 'hex') AS inbox_id, + identity_update_proto FROM inbox_log WHERE - inbox_id = $1 + inbox_id = decode(@inbox_id, 'hex') ORDER BY sequence_id ASC; -- name: GetInboxLogFiltered :many SELECT - a.* + a.sequence_id, + encode(a.inbox_id, 'hex') AS inbox_id, + a.identity_update_proto, + a.server_timestamp_ns FROM inbox_log AS a JOIN ( @@ -22,7 +27,7 @@ FROM * FROM json_populate_recordset(NULL::inbox_filter, @filters) AS b(inbox_id, - sequence_id)) AS b ON b.inbox_id = a.inbox_id + sequence_id)) AS b ON decode(b.inbox_id, 'hex') = a.inbox_id::BYTEA AND a.sequence_id > b.sequence_id ORDER BY a.sequence_id ASC; @@ -30,7 +35,7 @@ FROM -- name: GetAddressLogs :many SELECT a.address, - a.inbox_id, + encode(a.inbox_id, 'hex') AS inbox_id, a.association_sequence_id FROM address_log a @@ -49,13 +54,13 @@ FROM -- name: InsertAddressLog :one INSERT INTO address_log(address, inbox_id, association_sequence_id, revocation_sequence_id) - VALUES ($1, $2, $3, $4) + VALUES (@address, decode(@inbox_id, 'hex'), @association_sequence_id, @revocation_sequence_id) RETURNING *; -- name: InsertInboxLog :one INSERT INTO inbox_log(inbox_id, server_timestamp_ns, identity_update_proto) - VALUES ($1, $2, $3) + VALUES (decode(@inbox_id, 'hex'), @server_timestamp_ns, @identity_update_proto) RETURNING sequence_id; @@ -63,7 +68,7 @@ RETURNING UPDATE address_log SET - revocation_sequence_id = $1 + revocation_sequence_id = @revocation_sequence_id WHERE (address, inbox_id, association_sequence_id) =( SELECT address, @@ -72,19 +77,24 @@ WHERE (address, inbox_id, association_sequence_id) =( FROM address_log AS a WHERE - a.address = $2 - AND a.inbox_id = $3 + a.address = @address + AND a.inbox_id = decode(@inbox_id, 'hex') GROUP BY address, inbox_id); -- name: CreateInstallation :exec -INSERT INTO installations(id, wallet_address, created_at, updated_at, credential_identity, key_package, expiration) - VALUES ($1, $2, $3, $3, $4, $5, $6); +INSERT INTO installations(id, created_at, updated_at, inbox_id, key_package, expiration) + VALUES (@id, @created_at, @updated_at, decode(@inbox_id, 'hex'), @key_package, @expiration); -- name: GetInstallation :one SELECT - * + id, + created_at, + updated_at, + encode(inbox_id, 'hex') AS inbox_id, + key_package, + expiration FROM installations WHERE @@ -109,27 +119,6 @@ FROM WHERE id = ANY (@installation_ids::BYTEA[]); --- name: GetIdentityUpdates :many -SELECT - * -FROM - installations -WHERE - wallet_address = ANY (@wallet_addresses::TEXT[]) - AND (created_at > @start_time - OR revoked_at > @start_time) -ORDER BY - created_at ASC; - --- name: RevokeInstallation :exec -UPDATE - installations -SET - revoked_at = @revoked_at -WHERE - id = @installation_id - AND revoked_at IS NULL; - -- name: InsertGroupMessage :one INSERT INTO group_messages(group_id, data, group_id_data_hash) VALUES ($1, $2, $3) diff --git a/pkg/mls/store/queries/models.go b/pkg/mls/store/queries/models.go index 711ac5b8..462220c6 100644 --- a/pkg/mls/store/queries/models.go +++ b/pkg/mls/store/queries/models.go @@ -11,7 +11,7 @@ import ( type AddressLog struct { Address string - InboxID string + InboxID []byte AssociationSequenceID sql.NullInt64 RevocationSequenceID sql.NullInt64 } @@ -26,20 +26,18 @@ type GroupMessage struct { type InboxLog struct { SequenceID int64 - InboxID string + InboxID []byte ServerTimestampNs int64 IdentityUpdateProto []byte } type Installation struct { - ID []byte - WalletAddress string - CreatedAt int64 - UpdatedAt int64 - CredentialIdentity []byte - RevokedAt sql.NullInt64 - KeyPackage []byte - Expiration int64 + ID []byte + CreatedAt int64 + UpdatedAt int64 + InboxID []byte + KeyPackage []byte + Expiration int64 } type WelcomeMessage struct { @@ -47,6 +45,6 @@ type WelcomeMessage struct { CreatedAt time.Time InstallationKey []byte Data []byte - InstallationKeyDataHash []byte HpkePublicKey []byte + InstallationKeyDataHash []byte } diff --git a/pkg/mls/store/queries/queries.sql.go b/pkg/mls/store/queries/queries.sql.go index 2543dd85..33428a08 100644 --- a/pkg/mls/store/queries/queries.sql.go +++ b/pkg/mls/store/queries/queries.sql.go @@ -14,25 +14,25 @@ import ( ) const createInstallation = `-- name: CreateInstallation :exec -INSERT INTO installations(id, wallet_address, created_at, updated_at, credential_identity, key_package, expiration) - VALUES ($1, $2, $3, $3, $4, $5, $6) +INSERT INTO installations(id, created_at, updated_at, inbox_id, key_package, expiration) + VALUES ($1, $2, $3, decode($4, 'hex'), $5, $6) ` type CreateInstallationParams struct { - ID []byte - WalletAddress string - CreatedAt int64 - CredentialIdentity []byte - KeyPackage []byte - Expiration int64 + ID []byte + CreatedAt int64 + UpdatedAt int64 + InboxID string + KeyPackage []byte + Expiration int64 } func (q *Queries) CreateInstallation(ctx context.Context, arg CreateInstallationParams) error { _, err := q.db.ExecContext(ctx, createInstallation, arg.ID, - arg.WalletAddress, arg.CreatedAt, - arg.CredentialIdentity, + arg.UpdatedAt, + arg.InboxID, arg.KeyPackage, arg.Expiration, ) @@ -80,7 +80,7 @@ func (q *Queries) FetchKeyPackages(ctx context.Context, installationIds [][]byte const getAddressLogs = `-- name: GetAddressLogs :many SELECT a.address, - a.inbox_id, + encode(a.inbox_id, 'hex') AS inbox_id, a.association_sequence_id FROM address_log a @@ -167,30 +167,33 @@ func (q *Queries) GetAllGroupMessages(ctx context.Context) ([]GroupMessage, erro const getAllInboxLogs = `-- name: GetAllInboxLogs :many SELECT - sequence_id, inbox_id, server_timestamp_ns, identity_update_proto + sequence_id, + encode(inbox_id, 'hex') AS inbox_id, + identity_update_proto FROM inbox_log WHERE - inbox_id = $1 + inbox_id = decode($1, 'hex') ORDER BY sequence_id ASC ` -func (q *Queries) GetAllInboxLogs(ctx context.Context, inboxID string) ([]InboxLog, error) { +type GetAllInboxLogsRow struct { + SequenceID int64 + InboxID string + IdentityUpdateProto []byte +} + +func (q *Queries) GetAllInboxLogs(ctx context.Context, inboxID string) ([]GetAllInboxLogsRow, error) { rows, err := q.db.QueryContext(ctx, getAllInboxLogs, inboxID) if err != nil { return nil, err } defer rows.Close() - var items []InboxLog + var items []GetAllInboxLogsRow for rows.Next() { - var i InboxLog - if err := rows.Scan( - &i.SequenceID, - &i.InboxID, - &i.ServerTimestampNs, - &i.IdentityUpdateProto, - ); err != nil { + var i GetAllInboxLogsRow + if err := rows.Scan(&i.SequenceID, &i.InboxID, &i.IdentityUpdateProto); err != nil { return nil, err } items = append(items, i) @@ -206,7 +209,7 @@ func (q *Queries) GetAllInboxLogs(ctx context.Context, inboxID string) ([]InboxL const getAllWelcomeMessages = `-- name: GetAllWelcomeMessages :many SELECT - id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key + id, created_at, installation_key, data, hpke_public_key, installation_key_data_hash FROM welcome_messages ORDER BY @@ -227,58 +230,8 @@ func (q *Queries) GetAllWelcomeMessages(ctx context.Context) ([]WelcomeMessage, &i.CreatedAt, &i.InstallationKey, &i.Data, - &i.InstallationKeyDataHash, &i.HpkePublicKey, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getIdentityUpdates = `-- name: GetIdentityUpdates :many -SELECT - id, wallet_address, created_at, updated_at, credential_identity, revoked_at, key_package, expiration -FROM - installations -WHERE - wallet_address = ANY ($1::TEXT[]) - AND (created_at > $2 - OR revoked_at > $2) -ORDER BY - created_at ASC -` - -type GetIdentityUpdatesParams struct { - WalletAddresses []string - StartTime int64 -} - -func (q *Queries) GetIdentityUpdates(ctx context.Context, arg GetIdentityUpdatesParams) ([]Installation, error) { - rows, err := q.db.QueryContext(ctx, getIdentityUpdates, pq.Array(arg.WalletAddresses), arg.StartTime) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Installation - for rows.Next() { - var i Installation - if err := rows.Scan( - &i.ID, - &i.WalletAddress, - &i.CreatedAt, - &i.UpdatedAt, - &i.CredentialIdentity, - &i.RevokedAt, - &i.KeyPackage, - &i.Expiration, + &i.InstallationKeyDataHash, ); err != nil { return nil, err } @@ -295,7 +248,10 @@ func (q *Queries) GetIdentityUpdates(ctx context.Context, arg GetIdentityUpdates const getInboxLogFiltered = `-- name: GetInboxLogFiltered :many SELECT - a.sequence_id, a.inbox_id, a.server_timestamp_ns, a.identity_update_proto + a.sequence_id, + encode(a.inbox_id, 'hex') AS inbox_id, + a.identity_update_proto, + a.server_timestamp_ns FROM inbox_log AS a JOIN ( @@ -303,26 +259,33 @@ FROM inbox_id, sequence_id FROM json_populate_recordset(NULL::inbox_filter, $1) AS b(inbox_id, - sequence_id)) AS b ON b.inbox_id = a.inbox_id + sequence_id)) AS b ON decode(b.inbox_id, 'hex') = a.inbox_id::BYTEA AND a.sequence_id > b.sequence_id ORDER BY a.sequence_id ASC ` -func (q *Queries) GetInboxLogFiltered(ctx context.Context, filters json.RawMessage) ([]InboxLog, error) { +type GetInboxLogFilteredRow struct { + SequenceID int64 + InboxID string + IdentityUpdateProto []byte + ServerTimestampNs int64 +} + +func (q *Queries) GetInboxLogFiltered(ctx context.Context, filters json.RawMessage) ([]GetInboxLogFilteredRow, error) { rows, err := q.db.QueryContext(ctx, getInboxLogFiltered, filters) if err != nil { return nil, err } defer rows.Close() - var items []InboxLog + var items []GetInboxLogFilteredRow for rows.Next() { - var i InboxLog + var i GetInboxLogFilteredRow if err := rows.Scan( &i.SequenceID, &i.InboxID, - &i.ServerTimestampNs, &i.IdentityUpdateProto, + &i.ServerTimestampNs, ); err != nil { return nil, err } @@ -339,23 +302,35 @@ func (q *Queries) GetInboxLogFiltered(ctx context.Context, filters json.RawMessa const getInstallation = `-- name: GetInstallation :one SELECT - id, wallet_address, created_at, updated_at, credential_identity, revoked_at, key_package, expiration + id, + created_at, + updated_at, + encode(inbox_id, 'hex') AS inbox_id, + key_package, + expiration FROM installations WHERE id = $1 ` -func (q *Queries) GetInstallation(ctx context.Context, id []byte) (Installation, error) { +type GetInstallationRow struct { + ID []byte + CreatedAt int64 + UpdatedAt int64 + InboxID string + KeyPackage []byte + Expiration int64 +} + +func (q *Queries) GetInstallation(ctx context.Context, id []byte) (GetInstallationRow, error) { row := q.db.QueryRowContext(ctx, getInstallation, id) - var i Installation + var i GetInstallationRow err := row.Scan( &i.ID, - &i.WalletAddress, &i.CreatedAt, &i.UpdatedAt, - &i.CredentialIdentity, - &i.RevokedAt, + &i.InboxID, &i.KeyPackage, &i.Expiration, ) @@ -364,7 +339,7 @@ func (q *Queries) GetInstallation(ctx context.Context, id []byte) (Installation, const insertAddressLog = `-- name: InsertAddressLog :one INSERT INTO address_log(address, inbox_id, association_sequence_id, revocation_sequence_id) - VALUES ($1, $2, $3, $4) + VALUES ($1, decode($2, 'hex'), $3, $4) RETURNING address, inbox_id, association_sequence_id, revocation_sequence_id ` @@ -421,7 +396,7 @@ func (q *Queries) InsertGroupMessage(ctx context.Context, arg InsertGroupMessage const insertInboxLog = `-- name: InsertInboxLog :one INSERT INTO inbox_log(inbox_id, server_timestamp_ns, identity_update_proto) - VALUES ($1, $2, $3) + VALUES (decode($1, 'hex'), $2, $3) RETURNING sequence_id ` @@ -443,7 +418,7 @@ const insertWelcomeMessage = `-- name: InsertWelcomeMessage :one INSERT INTO welcome_messages(installation_key, data, installation_key_data_hash, hpke_public_key) VALUES ($1, $2, $3, $4) RETURNING - id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key + id, created_at, installation_key, data, hpke_public_key, installation_key_data_hash ` type InsertWelcomeMessageParams struct { @@ -466,8 +441,8 @@ func (q *Queries) InsertWelcomeMessage(ctx context.Context, arg InsertWelcomeMes &i.CreatedAt, &i.InstallationKey, &i.Data, - &i.InstallationKeyDataHash, &i.HpkePublicKey, + &i.InstallationKeyDataHash, ) return i, err } @@ -632,7 +607,7 @@ func (q *Queries) QueryGroupMessagesWithCursorDesc(ctx context.Context, arg Quer const queryWelcomeMessages = `-- name: QueryWelcomeMessages :many SELECT - id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key + id, created_at, installation_key, data, hpke_public_key, installation_key_data_hash FROM welcome_messages WHERE @@ -667,8 +642,8 @@ func (q *Queries) QueryWelcomeMessages(ctx context.Context, arg QueryWelcomeMess &i.CreatedAt, &i.InstallationKey, &i.Data, - &i.InstallationKeyDataHash, &i.HpkePublicKey, + &i.InstallationKeyDataHash, ); err != nil { return nil, err } @@ -685,7 +660,7 @@ func (q *Queries) QueryWelcomeMessages(ctx context.Context, arg QueryWelcomeMess const queryWelcomeMessagesWithCursorAsc = `-- name: QueryWelcomeMessagesWithCursorAsc :many SELECT - id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key + id, created_at, installation_key, data, hpke_public_key, installation_key_data_hash FROM welcome_messages WHERE @@ -716,8 +691,8 @@ func (q *Queries) QueryWelcomeMessagesWithCursorAsc(ctx context.Context, arg Que &i.CreatedAt, &i.InstallationKey, &i.Data, - &i.InstallationKeyDataHash, &i.HpkePublicKey, + &i.InstallationKeyDataHash, ); err != nil { return nil, err } @@ -734,7 +709,7 @@ func (q *Queries) QueryWelcomeMessagesWithCursorAsc(ctx context.Context, arg Que const queryWelcomeMessagesWithCursorDesc = `-- name: QueryWelcomeMessagesWithCursorDesc :many SELECT - id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key + id, created_at, installation_key, data, hpke_public_key, installation_key_data_hash FROM welcome_messages WHERE @@ -765,8 +740,8 @@ func (q *Queries) QueryWelcomeMessagesWithCursorDesc(ctx context.Context, arg Qu &i.CreatedAt, &i.InstallationKey, &i.Data, - &i.InstallationKeyDataHash, &i.HpkePublicKey, + &i.InstallationKeyDataHash, ); err != nil { return nil, err } @@ -795,7 +770,7 @@ WHERE (address, inbox_id, association_sequence_id) =( address_log AS a WHERE a.address = $2 - AND a.inbox_id = $3 + AND a.inbox_id = decode($3, 'hex') GROUP BY address, inbox_id) @@ -812,26 +787,6 @@ func (q *Queries) RevokeAddressFromLog(ctx context.Context, arg RevokeAddressFro return err } -const revokeInstallation = `-- name: RevokeInstallation :exec -UPDATE - installations -SET - revoked_at = $1 -WHERE - id = $2 - AND revoked_at IS NULL -` - -type RevokeInstallationParams struct { - RevokedAt sql.NullInt64 - InstallationID []byte -} - -func (q *Queries) RevokeInstallation(ctx context.Context, arg RevokeInstallationParams) error { - _, err := q.db.ExecContext(ctx, revokeInstallation, arg.RevokedAt, arg.InstallationID) - return err -} - const updateKeyPackage = `-- name: UpdateKeyPackage :execrows UPDATE installations diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index 2d1bb6dd..0aa62af4 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "database/sql" "errors" - "sort" "strings" "time" @@ -40,10 +39,9 @@ type IdentityStore interface { type MlsStore interface { IdentityStore - CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, credentialIdentity, keyPackage []byte, expiration uint64) error + CreateInstallation(ctx context.Context, installationId []byte, inboxId string, keyPackage []byte, expiration uint64) error UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error) - GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*queries.GroupMessage, error) InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*queries.WelcomeMessage, error) QueryGroupMessagesV1(ctx context.Context, query *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) @@ -162,7 +160,7 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish if address, ok := new_member.Kind.(*associations.MemberIdentifier_Address); ok { _, err = txQueries.InsertAddressLog(ctx, queries.InsertAddressLogParams{ Address: address.Address, - InboxID: state.AssociationState.InboxId, + InboxID: inboxId, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id}, RevocationSequenceID: sql.NullInt64{Valid: false}, }) @@ -177,7 +175,7 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish if address, ok := removed_member.Kind.(*associations.MemberIdentifier_Address); ok { err = txQueries.RevokeAddressFromLog(ctx, queries.RevokeAddressFromLogParams{ Address: address.Address, - InboxID: state.AssociationState.InboxId, + InboxID: inboxId, RevocationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id}, }) if err != nil { @@ -200,7 +198,7 @@ func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdent filters := make(queries.InboxLogFilterList, len(reqs)) for i, req := range reqs { filters[i] = queries.InboxLogFilter{ - InboxId: req.InboxId, + InboxId: req.InboxId, // InboxLogFilters take inbox_id as text and decode it inside Postgres, since the filters are JSON SequenceId: int64(req.SequenceId), } } @@ -215,9 +213,10 @@ func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdent } // Organize the results by inbox ID - resultMap := make(map[string][]queries.InboxLog) + resultMap := make(map[string][]queries.GetInboxLogFilteredRow) for _, result := range results { - resultMap[result.InboxID] = append(resultMap[result.InboxID], result) + inboxId := result.InboxID + resultMap[inboxId] = append(resultMap[inboxId], result) } resps := make([]*identity.GetIdentityUpdatesResponse_Response, len(reqs)) @@ -247,16 +246,15 @@ func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdent } // Creates the installation and last resort key package -func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, credentialIdentity, keyPackage []byte, expiration uint64) error { +func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, inboxId string, keyPackage []byte, expiration uint64) error { createdAt := nowNs() return s.queries.CreateInstallation(ctx, queries.CreateInstallationParams{ - ID: installationId, - WalletAddress: walletAddress, - CreatedAt: createdAt, - CredentialIdentity: credentialIdentity, - KeyPackage: keyPackage, - Expiration: int64(expiration), + ID: installationId, + CreatedAt: createdAt, + InboxID: inboxId, + KeyPackage: keyPackage, + Expiration: int64(expiration), }) } @@ -284,49 +282,6 @@ func (s *Store) FetchKeyPackages(ctx context.Context, installationIds [][]byte) return s.queries.FetchKeyPackages(ctx, installationIds) } -func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) { - updated, err := s.queries.GetIdentityUpdates(ctx, queries.GetIdentityUpdatesParams{ - WalletAddresses: walletAddresses, - StartTime: startTimeNs, - }) - if err != nil { - return nil, err - } - - // The returned list is only partially sorted - out := make(map[string]IdentityUpdateList) - for _, installation := range updated { - if installation.CreatedAt > startTimeNs { - out[installation.WalletAddress] = append(out[installation.WalletAddress], IdentityUpdate{ - Kind: Create, - InstallationKey: installation.ID, - CredentialIdentity: installation.CredentialIdentity, - TimestampNs: uint64(installation.CreatedAt), - }) - } - if installation.RevokedAt.Valid && installation.RevokedAt.Int64 > startTimeNs { - out[installation.WalletAddress] = append(out[installation.WalletAddress], IdentityUpdate{ - Kind: Revoke, - InstallationKey: installation.ID, - TimestampNs: uint64(installation.RevokedAt.Int64), - }) - } - } - // Sort the updates by timestamp now that the full list is assembled - for _, updates := range out { - sort.Sort(updates) - } - - return out, nil -} - -func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) error { - return s.queries.RevokeInstallation(ctx, queries.RevokeInstallationParams{ - RevokedAt: sql.NullInt64{Valid: true, Int64: nowNs()}, - InstallationID: installationId, - }) -} - func (s *Store) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*queries.GroupMessage, error) { dataHash := sha256.Sum256(append(groupId, data...)) message, err := s.queries.InsertGroupMessage(ctx, queries.InsertGroupMessageParams{ diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 65109fbc..1ff9f73d 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -5,20 +5,21 @@ import ( "database/sql" "errors" "fmt" - "sort" "sync" "testing" "time" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" - "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate/mocks" + "github.com/xmtp/xmtp-node-go/pkg/mocks" identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" test "github.com/xmtp/xmtp-node-go/pkg/testing" - "go.uber.org/mock/gomock" + testutils "github.com/xmtp/xmtp-node-go/pkg/testing" + "go.uber.org/zap" ) func NewTestStore(t *testing.T) (*Store, func()) { @@ -44,14 +45,13 @@ func TestPublishIdentityUpdateParallel(t *testing.T) { // Create a mapping of inboxes to addresses inboxes := make(map[string]string) for i := 0; i < 50; i++ { - inboxes[fmt.Sprintf("inbox_%d", i)] = fmt.Sprintf("address_%d", i) + inboxes[testutils.RandomInboxId()] = fmt.Sprintf("address_%d", i) } - mockController := gomock.NewController(t) - mockMlsValidation := mocks.NewMockMLSValidationService(mockController) + mockMlsValidation := mocks.NewMockMLSValidationService(t) // For each inbox_id in the map, return an AssociationStateDiff that adds the corresponding address - mockMlsValidation.EXPECT().GetAssociationState(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ any, _ any, updates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) { + mockMlsValidation.EXPECT().GetAssociationState(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(_ context.Context, _ []*associations.IdentityUpdate, updates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) { inboxId := updates[0].InboxId address, ok := inboxes[inboxId] @@ -71,7 +71,7 @@ func TestPublishIdentityUpdateParallel(t *testing.T) { }}, }, }, nil - }).AnyTimes() + }) var wg sync.WaitGroup for inboxId := range inboxes { @@ -95,13 +95,18 @@ func TestInboxIds(t *testing.T) { defer cleanup() ctx := context.Background() - _, err := store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 1}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + inbox1 := testutils.RandomInboxId() + inbox2 := testutils.RandomInboxId() + correctInbox := testutils.RandomInboxId() + correctInbox2 := testutils.RandomInboxId() + + _, err := store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: inbox1, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 1}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) - _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 2}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: inbox1, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 2}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) - _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 3}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: inbox1, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 3}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) - _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 4}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: correctInbox, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 4}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) reqs := make([]*identity.GetInboxIdsRequest_Request, 0) @@ -113,21 +118,21 @@ func TestInboxIds(t *testing.T) { } resp, _ := store.GetInboxIds(context.Background(), req) - require.Equal(t, "correct", *resp.Responses[0].InboxId) + require.Equal(t, correctInbox, *resp.Responses[0].InboxId) - _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct_inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 5}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: correctInbox2, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 5}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) resp, _ = store.GetInboxIds(context.Background(), req) - require.Equal(t, "correct_inbox2", *resp.Responses[0].InboxId) + require.Equal(t, correctInbox2, *resp.Responses[0].InboxId) reqs = append(reqs, &identity.GetInboxIdsRequest_Request{Address: "address2"}) req = &identity.GetInboxIdsRequest{ Requests: reqs, } - _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address2", InboxID: "inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 8}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address2", InboxID: inbox2, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 8}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) resp, _ = store.GetInboxIds(context.Background(), req) - require.Equal(t, "inbox2", *resp.Responses[1].InboxId) + require.Equal(t, inbox2, *resp.Responses[1].InboxId) } func TestMultipleInboxIds(t *testing.T) { @@ -135,9 +140,12 @@ func TestMultipleInboxIds(t *testing.T) { defer cleanup() ctx := context.Background() - _, err := store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address_1", InboxID: "inbox_1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 1}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + inbox1 := testutils.RandomInboxId() + inbox2 := testutils.RandomInboxId() + + _, err := store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address_1", InboxID: inbox1, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 1}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) - _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address_2", InboxID: "inbox_2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 2}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address_2", InboxID: inbox2, AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 2}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) reqs := make([]*identity.GetInboxIdsRequest_Request, 0) @@ -151,8 +159,8 @@ func TestMultipleInboxIds(t *testing.T) { Requests: reqs, } resp, _ := store.GetInboxIds(context.Background(), req) - require.Equal(t, "inbox_1", *resp.Responses[0].InboxId) - require.Equal(t, "inbox_2", *resp.Responses[1].InboxId) + require.Equal(t, inbox1, *resp.Responses[0].InboxId) + require.Equal(t, inbox2, *resp.Responses[1].InboxId) } func TestCreateInstallation(t *testing.T) { @@ -161,14 +169,14 @@ func TestCreateInstallation(t *testing.T) { ctx := context.Background() installationId := test.RandomBytes(32) - walletAddress := test.RandomString(32) + inboxId := test.RandomInboxId() - err := store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32), 0) + err := store.CreateInstallation(ctx, installationId, inboxId, test.RandomBytes(32), 0) require.NoError(t, err) installationFromDb, err := store.queries.GetInstallation(ctx, installationId) require.NoError(t, err) - require.Equal(t, walletAddress, installationFromDb.WalletAddress) + require.Equal(t, installationFromDb.ID, installationId) } func TestUpdateKeyPackage(t *testing.T) { @@ -177,10 +185,10 @@ func TestUpdateKeyPackage(t *testing.T) { ctx := context.Background() installationId := test.RandomBytes(32) - walletAddress := test.RandomString(32) + inboxId := test.RandomInboxId() keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage, 0) + err := store.CreateInstallation(ctx, installationId, inboxId, keyPackage, 0) require.NoError(t, err) keyPackage2 := test.RandomBytes(32) @@ -200,10 +208,10 @@ func TestConsumeLastResortKeyPackage(t *testing.T) { ctx := context.Background() installationId := test.RandomBytes(32) - walletAddress := test.RandomString(32) keyPackage := test.RandomBytes(32) + inboxId := test.RandomInboxId() - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage, 0) + err := store.CreateInstallation(ctx, installationId, inboxId, keyPackage, 0) require.NoError(t, err) fetchResult, err := store.FetchKeyPackages(ctx, [][]byte{installationId}) @@ -213,96 +221,6 @@ func TestConsumeLastResortKeyPackage(t *testing.T) { require.Equal(t, installationId, fetchResult[0].ID) } -func TestGetIdentityUpdates(t *testing.T) { - store, cleanup := NewTestStore(t) - defer cleanup() - - ctx := context.Background() - walletAddress := test.RandomString(32) - - installationId1 := test.RandomBytes(32) - keyPackage1 := test.RandomBytes(32) - - err := store.CreateInstallation(ctx, installationId1, walletAddress, keyPackage1, keyPackage1, 0) - require.NoError(t, err) - - installationId2 := test.RandomBytes(32) - keyPackage2 := test.RandomBytes(32) - - err = store.CreateInstallation(ctx, installationId2, walletAddress, keyPackage2, keyPackage2, 0) - require.NoError(t, err) - - identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress}, 0) - require.NoError(t, err) - require.Len(t, identityUpdates[walletAddress], 2) - require.Equal(t, identityUpdates[walletAddress][0].InstallationKey, installationId1) - require.Equal(t, identityUpdates[walletAddress][0].Kind, Create) - require.Equal(t, identityUpdates[walletAddress][1].InstallationKey, installationId2) - - // Make sure that date filtering works - identityUpdates, err = store.GetIdentityUpdates(ctx, []string{walletAddress}, nowNs()+1000000) - require.NoError(t, err) - require.Len(t, identityUpdates[walletAddress], 0) -} - -func TestGetIdentityUpdatesMultipleWallets(t *testing.T) { - store, cleanup := NewTestStore(t) - defer cleanup() - - ctx := context.Background() - walletAddress1 := test.RandomString(32) - installationId1 := test.RandomBytes(32) - keyPackage1 := test.RandomBytes(32) - - err := store.CreateInstallation(ctx, installationId1, walletAddress1, keyPackage1, keyPackage1, 0) - require.NoError(t, err) - - walletAddress2 := test.RandomString(32) - installationId2 := test.RandomBytes(32) - keyPackage2 := test.RandomBytes(32) - - err = store.CreateInstallation(ctx, installationId2, walletAddress2, keyPackage2, keyPackage2, 0) - require.NoError(t, err) - - identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress1, walletAddress2}, 0) - require.NoError(t, err) - require.Len(t, identityUpdates[walletAddress1], 1) - require.Len(t, identityUpdates[walletAddress2], 1) -} - -func TestGetIdentityUpdatesNoResult(t *testing.T) { - store, cleanup := NewTestStore(t) - defer cleanup() - - ctx := context.Background() - walletAddress := test.RandomString(32) - - identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress}, 0) - require.NoError(t, err) - require.Len(t, identityUpdates[walletAddress], 0) -} - -func TestIdentityUpdateSort(t *testing.T) { - updates := IdentityUpdateList([]IdentityUpdate{ - { - Kind: Create, - TimestampNs: 2, - }, - { - Kind: Create, - TimestampNs: 3, - }, - { - Kind: Create, - TimestampNs: 1, - }, - }) - sort.Sort(updates) - require.Equal(t, updates[0].TimestampNs, uint64(1)) - require.Equal(t, updates[1].TimestampNs, uint64(2)) - require.Equal(t, updates[2].TimestampNs, uint64(3)) -} - func TestInsertGroupMessage_Single(t *testing.T) { started := time.Now().UTC().Add(-time.Minute) store, cleanup := NewTestStore(t) @@ -313,7 +231,8 @@ func TestInsertGroupMessage_Single(t *testing.T) { require.NoError(t, err) require.NotNil(t, msg) require.Equal(t, int64(1), msg.ID) - require.True(t, msg.CreatedAt.Before(time.Now().UTC()) && msg.CreatedAt.After(started)) + store.log.Info("Created at", zap.Time("created_at", msg.CreatedAt)) + require.True(t, msg.CreatedAt.Before(time.Now().UTC().Add(1*time.Minute)) && msg.CreatedAt.After(started)) require.Equal(t, []byte("group"), msg.GroupID) require.Equal(t, []byte("data"), msg.Data) diff --git a/pkg/mlsvalidate/mocks/mock.gen.go b/pkg/mlsvalidate/mocks/mock.gen.go deleted file mode 100644 index fef93f1f..00000000 --- a/pkg/mlsvalidate/mocks/mock.gen.go +++ /dev/null @@ -1,142 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./pkg/mlsvalidate/service.go -// -// Generated by this command: -// -// mockgen -package mocks -source ./pkg/mlsvalidate/service.go MLSValidationService -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - mlsvalidate "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" - apiv1 "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" - associations "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" - apiv10 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" - gomock "go.uber.org/mock/gomock" -) - -// MockMLSValidationService is a mock of MLSValidationService interface. -type MockMLSValidationService struct { - ctrl *gomock.Controller - recorder *MockMLSValidationServiceMockRecorder -} - -// MockMLSValidationServiceMockRecorder is the mock recorder for MockMLSValidationService. -type MockMLSValidationServiceMockRecorder struct { - mock *MockMLSValidationService -} - -// NewMockMLSValidationService creates a new mock instance. -func NewMockMLSValidationService(ctrl *gomock.Controller) *MockMLSValidationService { - mock := &MockMLSValidationService{ctrl: ctrl} - mock.recorder = &MockMLSValidationServiceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMLSValidationService) EXPECT() *MockMLSValidationServiceMockRecorder { - return m.recorder -} - -// GetAssociationState mocks base method. -func (m *MockMLSValidationService) GetAssociationState(ctx context.Context, oldUpdates, newUpdates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAssociationState", ctx, oldUpdates, newUpdates) - ret0, _ := ret[0].(*mlsvalidate.AssociationStateResult) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetAssociationState indicates an expected call of GetAssociationState. -func (mr *MockMLSValidationServiceMockRecorder) GetAssociationState(ctx, oldUpdates, newUpdates any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAssociationState", reflect.TypeOf((*MockMLSValidationService)(nil).GetAssociationState), ctx, oldUpdates, newUpdates) -} - -// ValidateGroupMessages mocks base method. -func (m *MockMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages []*apiv10.GroupMessageInput) ([]mlsvalidate.GroupMessageValidationResult, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidateGroupMessages", ctx, groupMessages) - ret0, _ := ret[0].([]mlsvalidate.GroupMessageValidationResult) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ValidateGroupMessages indicates an expected call of ValidateGroupMessages. -func (mr *MockMLSValidationServiceMockRecorder) ValidateGroupMessages(ctx, groupMessages any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateGroupMessages", reflect.TypeOf((*MockMLSValidationService)(nil).ValidateGroupMessages), ctx, groupMessages) -} - -// ValidateInboxIdKeyPackages mocks base method. -func (m *MockMLSValidationService) ValidateInboxIdKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.InboxIdValidationResult, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidateInboxIdKeyPackages", ctx, keyPackages) - ret0, _ := ret[0].([]mlsvalidate.InboxIdValidationResult) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ValidateInboxIdKeyPackages indicates an expected call of ValidateInboxIdKeyPackages. -func (mr *MockMLSValidationServiceMockRecorder) ValidateInboxIdKeyPackages(ctx, keyPackages any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateInboxIdKeyPackages", reflect.TypeOf((*MockMLSValidationService)(nil).ValidateInboxIdKeyPackages), ctx, keyPackages) -} - -// ValidateV3KeyPackages mocks base method. -func (m *MockMLSValidationService) ValidateV3KeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidateV3KeyPackages", ctx, keyPackages) - ret0, _ := ret[0].([]mlsvalidate.IdentityValidationResult) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ValidateV3KeyPackages indicates an expected call of ValidateV3KeyPackages. -func (mr *MockMLSValidationServiceMockRecorder) ValidateV3KeyPackages(ctx, keyPackages any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateV3KeyPackages", reflect.TypeOf((*MockMLSValidationService)(nil).ValidateV3KeyPackages), ctx, keyPackages) -} - -// MockIdentityStore is a mock of IdentityStore interface. -type MockIdentityStore struct { - ctrl *gomock.Controller - recorder *MockIdentityStoreMockRecorder -} - -// MockIdentityStoreMockRecorder is the mock recorder for MockIdentityStore. -type MockIdentityStoreMockRecorder struct { - mock *MockIdentityStore -} - -// NewMockIdentityStore creates a new mock instance. -func NewMockIdentityStore(ctrl *gomock.Controller) *MockIdentityStore { - mock := &MockIdentityStore{ctrl: ctrl} - mock.recorder = &MockIdentityStoreMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockIdentityStore) EXPECT() *MockIdentityStoreMockRecorder { - return m.recorder -} - -// GetInboxLogs mocks base method. -func (m *MockIdentityStore) GetInboxLogs(ctx context.Context, req *apiv1.GetIdentityUpdatesRequest) (*apiv1.GetIdentityUpdatesResponse, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInboxLogs", ctx, req) - ret0, _ := ret[0].(*apiv1.GetIdentityUpdatesResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetInboxLogs indicates an expected call of GetInboxLogs. -func (mr *MockIdentityStoreMockRecorder) GetInboxLogs(ctx, req any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInboxLogs", reflect.TypeOf((*MockIdentityStore)(nil).GetInboxLogs), ctx, req) -} diff --git a/pkg/mocks/mock_MLSValidationService.go b/pkg/mocks/mock_MLSValidationService.go new file mode 100644 index 00000000..14cca455 --- /dev/null +++ b/pkg/mocks/mock_MLSValidationService.go @@ -0,0 +1,278 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + associations "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" + apiv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" + + context "context" + + mlsvalidate "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" + + mock "github.com/stretchr/testify/mock" +) + +// MockMLSValidationService is an autogenerated mock type for the MLSValidationService type +type MockMLSValidationService struct { + mock.Mock +} + +type MockMLSValidationService_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMLSValidationService) EXPECT() *MockMLSValidationService_Expecter { + return &MockMLSValidationService_Expecter{mock: &_m.Mock} +} + +// GetAssociationState provides a mock function with given fields: ctx, oldUpdates, newUpdates +func (_m *MockMLSValidationService) GetAssociationState(ctx context.Context, oldUpdates []*associations.IdentityUpdate, newUpdates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) { + ret := _m.Called(ctx, oldUpdates, newUpdates) + + if len(ret) == 0 { + panic("no return value specified for GetAssociationState") + } + + var r0 *mlsvalidate.AssociationStateResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []*associations.IdentityUpdate, []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error)); ok { + return rf(ctx, oldUpdates, newUpdates) + } + if rf, ok := ret.Get(0).(func(context.Context, []*associations.IdentityUpdate, []*associations.IdentityUpdate) *mlsvalidate.AssociationStateResult); ok { + r0 = rf(ctx, oldUpdates, newUpdates) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*mlsvalidate.AssociationStateResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []*associations.IdentityUpdate, []*associations.IdentityUpdate) error); ok { + r1 = rf(ctx, oldUpdates, newUpdates) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMLSValidationService_GetAssociationState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAssociationState' +type MockMLSValidationService_GetAssociationState_Call struct { + *mock.Call +} + +// GetAssociationState is a helper method to define mock.On call +// - ctx context.Context +// - oldUpdates []*associations.IdentityUpdate +// - newUpdates []*associations.IdentityUpdate +func (_e *MockMLSValidationService_Expecter) GetAssociationState(ctx interface{}, oldUpdates interface{}, newUpdates interface{}) *MockMLSValidationService_GetAssociationState_Call { + return &MockMLSValidationService_GetAssociationState_Call{Call: _e.mock.On("GetAssociationState", ctx, oldUpdates, newUpdates)} +} + +func (_c *MockMLSValidationService_GetAssociationState_Call) Run(run func(ctx context.Context, oldUpdates []*associations.IdentityUpdate, newUpdates []*associations.IdentityUpdate)) *MockMLSValidationService_GetAssociationState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]*associations.IdentityUpdate), args[2].([]*associations.IdentityUpdate)) + }) + return _c +} + +func (_c *MockMLSValidationService_GetAssociationState_Call) Return(_a0 *mlsvalidate.AssociationStateResult, _a1 error) *MockMLSValidationService_GetAssociationState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMLSValidationService_GetAssociationState_Call) RunAndReturn(run func(context.Context, []*associations.IdentityUpdate, []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error)) *MockMLSValidationService_GetAssociationState_Call { + _c.Call.Return(run) + return _c +} + +// ValidateGroupMessages provides a mock function with given fields: ctx, groupMessages +func (_m *MockMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages []*apiv1.GroupMessageInput) ([]mlsvalidate.GroupMessageValidationResult, error) { + ret := _m.Called(ctx, groupMessages) + + if len(ret) == 0 { + panic("no return value specified for ValidateGroupMessages") + } + + var r0 []mlsvalidate.GroupMessageValidationResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []*apiv1.GroupMessageInput) ([]mlsvalidate.GroupMessageValidationResult, error)); ok { + return rf(ctx, groupMessages) + } + if rf, ok := ret.Get(0).(func(context.Context, []*apiv1.GroupMessageInput) []mlsvalidate.GroupMessageValidationResult); ok { + r0 = rf(ctx, groupMessages) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]mlsvalidate.GroupMessageValidationResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []*apiv1.GroupMessageInput) error); ok { + r1 = rf(ctx, groupMessages) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMLSValidationService_ValidateGroupMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateGroupMessages' +type MockMLSValidationService_ValidateGroupMessages_Call struct { + *mock.Call +} + +// ValidateGroupMessages is a helper method to define mock.On call +// - ctx context.Context +// - groupMessages []*apiv1.GroupMessageInput +func (_e *MockMLSValidationService_Expecter) ValidateGroupMessages(ctx interface{}, groupMessages interface{}) *MockMLSValidationService_ValidateGroupMessages_Call { + return &MockMLSValidationService_ValidateGroupMessages_Call{Call: _e.mock.On("ValidateGroupMessages", ctx, groupMessages)} +} + +func (_c *MockMLSValidationService_ValidateGroupMessages_Call) Run(run func(ctx context.Context, groupMessages []*apiv1.GroupMessageInput)) *MockMLSValidationService_ValidateGroupMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]*apiv1.GroupMessageInput)) + }) + return _c +} + +func (_c *MockMLSValidationService_ValidateGroupMessages_Call) Return(_a0 []mlsvalidate.GroupMessageValidationResult, _a1 error) *MockMLSValidationService_ValidateGroupMessages_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMLSValidationService_ValidateGroupMessages_Call) RunAndReturn(run func(context.Context, []*apiv1.GroupMessageInput) ([]mlsvalidate.GroupMessageValidationResult, error)) *MockMLSValidationService_ValidateGroupMessages_Call { + _c.Call.Return(run) + return _c +} + +// ValidateInboxIdKeyPackages provides a mock function with given fields: ctx, keyPackages +func (_m *MockMLSValidationService) ValidateInboxIdKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.InboxIdValidationResult, error) { + ret := _m.Called(ctx, keyPackages) + + if len(ret) == 0 { + panic("no return value specified for ValidateInboxIdKeyPackages") + } + + var r0 []mlsvalidate.InboxIdValidationResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, [][]byte) ([]mlsvalidate.InboxIdValidationResult, error)); ok { + return rf(ctx, keyPackages) + } + if rf, ok := ret.Get(0).(func(context.Context, [][]byte) []mlsvalidate.InboxIdValidationResult); ok { + r0 = rf(ctx, keyPackages) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]mlsvalidate.InboxIdValidationResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, [][]byte) error); ok { + r1 = rf(ctx, keyPackages) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMLSValidationService_ValidateInboxIdKeyPackages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateInboxIdKeyPackages' +type MockMLSValidationService_ValidateInboxIdKeyPackages_Call struct { + *mock.Call +} + +// ValidateInboxIdKeyPackages is a helper method to define mock.On call +// - ctx context.Context +// - keyPackages [][]byte +func (_e *MockMLSValidationService_Expecter) ValidateInboxIdKeyPackages(ctx interface{}, keyPackages interface{}) *MockMLSValidationService_ValidateInboxIdKeyPackages_Call { + return &MockMLSValidationService_ValidateInboxIdKeyPackages_Call{Call: _e.mock.On("ValidateInboxIdKeyPackages", ctx, keyPackages)} +} + +func (_c *MockMLSValidationService_ValidateInboxIdKeyPackages_Call) Run(run func(ctx context.Context, keyPackages [][]byte)) *MockMLSValidationService_ValidateInboxIdKeyPackages_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([][]byte)) + }) + return _c +} + +func (_c *MockMLSValidationService_ValidateInboxIdKeyPackages_Call) Return(_a0 []mlsvalidate.InboxIdValidationResult, _a1 error) *MockMLSValidationService_ValidateInboxIdKeyPackages_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMLSValidationService_ValidateInboxIdKeyPackages_Call) RunAndReturn(run func(context.Context, [][]byte) ([]mlsvalidate.InboxIdValidationResult, error)) *MockMLSValidationService_ValidateInboxIdKeyPackages_Call { + _c.Call.Return(run) + return _c +} + +// ValidateV3KeyPackages provides a mock function with given fields: ctx, keyPackages +func (_m *MockMLSValidationService) ValidateV3KeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { + ret := _m.Called(ctx, keyPackages) + + if len(ret) == 0 { + panic("no return value specified for ValidateV3KeyPackages") + } + + var r0 []mlsvalidate.IdentityValidationResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, [][]byte) ([]mlsvalidate.IdentityValidationResult, error)); ok { + return rf(ctx, keyPackages) + } + if rf, ok := ret.Get(0).(func(context.Context, [][]byte) []mlsvalidate.IdentityValidationResult); ok { + r0 = rf(ctx, keyPackages) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]mlsvalidate.IdentityValidationResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, [][]byte) error); ok { + r1 = rf(ctx, keyPackages) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMLSValidationService_ValidateV3KeyPackages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateV3KeyPackages' +type MockMLSValidationService_ValidateV3KeyPackages_Call struct { + *mock.Call +} + +// ValidateV3KeyPackages is a helper method to define mock.On call +// - ctx context.Context +// - keyPackages [][]byte +func (_e *MockMLSValidationService_Expecter) ValidateV3KeyPackages(ctx interface{}, keyPackages interface{}) *MockMLSValidationService_ValidateV3KeyPackages_Call { + return &MockMLSValidationService_ValidateV3KeyPackages_Call{Call: _e.mock.On("ValidateV3KeyPackages", ctx, keyPackages)} +} + +func (_c *MockMLSValidationService_ValidateV3KeyPackages_Call) Run(run func(ctx context.Context, keyPackages [][]byte)) *MockMLSValidationService_ValidateV3KeyPackages_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([][]byte)) + }) + return _c +} + +func (_c *MockMLSValidationService_ValidateV3KeyPackages_Call) Return(_a0 []mlsvalidate.IdentityValidationResult, _a1 error) *MockMLSValidationService_ValidateV3KeyPackages_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMLSValidationService_ValidateV3KeyPackages_Call) RunAndReturn(run func(context.Context, [][]byte) ([]mlsvalidate.IdentityValidationResult, error)) *MockMLSValidationService_ValidateV3KeyPackages_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMLSValidationService creates a new instance of MockMLSValidationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMLSValidationService(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMLSValidationService { + mock := &MockMLSValidationService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mocks/mock_MlsApi_SubscribeGroupMessagesServer.go b/pkg/mocks/mock_MlsApi_SubscribeGroupMessagesServer.go new file mode 100644 index 00000000..02d2a168 --- /dev/null +++ b/pkg/mocks/mock_MlsApi_SubscribeGroupMessagesServer.go @@ -0,0 +1,350 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + apiv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" + + metadata "google.golang.org/grpc/metadata" + + mock "github.com/stretchr/testify/mock" +) + +// MockMlsApi_SubscribeGroupMessagesServer is an autogenerated mock type for the MlsApi_SubscribeGroupMessagesServer type +type MockMlsApi_SubscribeGroupMessagesServer struct { + mock.Mock +} + +type MockMlsApi_SubscribeGroupMessagesServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMlsApi_SubscribeGroupMessagesServer) EXPECT() *MockMlsApi_SubscribeGroupMessagesServer_Expecter { + return &MockMlsApi_SubscribeGroupMessagesServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *MockMlsApi_SubscribeGroupMessagesServer) Context() context.Context { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Context") + } + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// MockMlsApi_SubscribeGroupMessagesServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type MockMlsApi_SubscribeGroupMessagesServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *MockMlsApi_SubscribeGroupMessagesServer_Expecter) Context() *MockMlsApi_SubscribeGroupMessagesServer_Context_Call { + return &MockMlsApi_SubscribeGroupMessagesServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_Context_Call) Run(run func()) *MockMlsApi_SubscribeGroupMessagesServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_Context_Call) Return(_a0 context.Context) *MockMlsApi_SubscribeGroupMessagesServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_Context_Call) RunAndReturn(run func() context.Context) *MockMlsApi_SubscribeGroupMessagesServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *MockMlsApi_SubscribeGroupMessagesServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for RecvMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockMlsApi_SubscribeGroupMessagesServer_Expecter) RecvMsg(m interface{}) *MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call { + return &MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call) Run(run func(m interface{})) *MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call) Return(_a0 error) *MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockMlsApi_SubscribeGroupMessagesServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeGroupMessagesServer) Send(_a0 *apiv1.GroupMessage) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Send") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*apiv1.GroupMessage) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeGroupMessagesServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type MockMlsApi_SubscribeGroupMessagesServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *apiv1.GroupMessage +func (_e *MockMlsApi_SubscribeGroupMessagesServer_Expecter) Send(_a0 interface{}) *MockMlsApi_SubscribeGroupMessagesServer_Send_Call { + return &MockMlsApi_SubscribeGroupMessagesServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_Send_Call) Run(run func(_a0 *apiv1.GroupMessage)) *MockMlsApi_SubscribeGroupMessagesServer_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*apiv1.GroupMessage)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_Send_Call) Return(_a0 error) *MockMlsApi_SubscribeGroupMessagesServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_Send_Call) RunAndReturn(run func(*apiv1.GroupMessage) error) *MockMlsApi_SubscribeGroupMessagesServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeGroupMessagesServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockMlsApi_SubscribeGroupMessagesServer_Expecter) SendHeader(_a0 interface{}) *MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call { + return &MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call) Return(_a0 error) *MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockMlsApi_SubscribeGroupMessagesServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *MockMlsApi_SubscribeGroupMessagesServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for SendMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockMlsApi_SubscribeGroupMessagesServer_Expecter) SendMsg(m interface{}) *MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call { + return &MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call) Run(run func(m interface{})) *MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call) Return(_a0 error) *MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockMlsApi_SubscribeGroupMessagesServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeGroupMessagesServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SetHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockMlsApi_SubscribeGroupMessagesServer_Expecter) SetHeader(_a0 interface{}) *MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call { + return &MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call) Return(_a0 error) *MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockMlsApi_SubscribeGroupMessagesServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeGroupMessagesServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockMlsApi_SubscribeGroupMessagesServer_Expecter) SetTrailer(_a0 interface{}) *MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call { + return &MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call) Return() *MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *MockMlsApi_SubscribeGroupMessagesServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMlsApi_SubscribeGroupMessagesServer creates a new instance of MockMlsApi_SubscribeGroupMessagesServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMlsApi_SubscribeGroupMessagesServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMlsApi_SubscribeGroupMessagesServer { + mock := &MockMlsApi_SubscribeGroupMessagesServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mocks/mock_MlsApi_SubscribeWelcomeMessagesServer.go b/pkg/mocks/mock_MlsApi_SubscribeWelcomeMessagesServer.go new file mode 100644 index 00000000..bfc77b8a --- /dev/null +++ b/pkg/mocks/mock_MlsApi_SubscribeWelcomeMessagesServer.go @@ -0,0 +1,350 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + apiv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" + + metadata "google.golang.org/grpc/metadata" + + mock "github.com/stretchr/testify/mock" +) + +// MockMlsApi_SubscribeWelcomeMessagesServer is an autogenerated mock type for the MlsApi_SubscribeWelcomeMessagesServer type +type MockMlsApi_SubscribeWelcomeMessagesServer struct { + mock.Mock +} + +type MockMlsApi_SubscribeWelcomeMessagesServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) EXPECT() *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter { + return &MockMlsApi_SubscribeWelcomeMessagesServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) Context() context.Context { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Context") + } + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter) Context() *MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call { + return &MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call) Run(run func()) *MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call) Return(_a0 context.Context) *MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call) RunAndReturn(run func() context.Context) *MockMlsApi_SubscribeWelcomeMessagesServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for RecvMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter) RecvMsg(m interface{}) *MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call { + return &MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call) Run(run func(m interface{})) *MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call) Return(_a0 error) *MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockMlsApi_SubscribeWelcomeMessagesServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) Send(_a0 *apiv1.WelcomeMessage) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Send") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*apiv1.WelcomeMessage) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *apiv1.WelcomeMessage +func (_e *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter) Send(_a0 interface{}) *MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call { + return &MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call) Run(run func(_a0 *apiv1.WelcomeMessage)) *MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*apiv1.WelcomeMessage)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call) Return(_a0 error) *MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call) RunAndReturn(run func(*apiv1.WelcomeMessage) error) *MockMlsApi_SubscribeWelcomeMessagesServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter) SendHeader(_a0 interface{}) *MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call { + return &MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call) Return(_a0 error) *MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockMlsApi_SubscribeWelcomeMessagesServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for SendMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter) SendMsg(m interface{}) *MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call { + return &MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call) Run(run func(m interface{})) *MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call) Return(_a0 error) *MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockMlsApi_SubscribeWelcomeMessagesServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SetHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter) SetHeader(_a0 interface{}) *MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call { + return &MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call) Return(_a0 error) *MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockMlsApi_SubscribeWelcomeMessagesServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *MockMlsApi_SubscribeWelcomeMessagesServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockMlsApi_SubscribeWelcomeMessagesServer_Expecter) SetTrailer(_a0 interface{}) *MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call { + return &MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call) Return() *MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *MockMlsApi_SubscribeWelcomeMessagesServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMlsApi_SubscribeWelcomeMessagesServer creates a new instance of MockMlsApi_SubscribeWelcomeMessagesServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMlsApi_SubscribeWelcomeMessagesServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMlsApi_SubscribeWelcomeMessagesServer { + mock := &MockMlsApi_SubscribeWelcomeMessagesServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/testing/random.go b/pkg/testing/random.go index 508331db..7f71a047 100644 --- a/pkg/testing/random.go +++ b/pkg/testing/random.go @@ -4,6 +4,8 @@ import ( cryptoRand "crypto/rand" "math/rand" "strings" + + "github.com/xmtp/xmtp-node-go/pkg/utils" ) var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") @@ -25,3 +27,9 @@ func RandomBytes(n int) []byte { _, _ = cryptoRand.Read(b) return b } + +func RandomInboxId() string { + bytes := RandomBytes(32) + + return utils.HexEncode(bytes) +} diff --git a/pkg/utils/hex.go b/pkg/utils/hex.go new file mode 100644 index 00000000..02fc8f47 --- /dev/null +++ b/pkg/utils/hex.go @@ -0,0 +1,19 @@ +package utils + +import "encoding/hex" + +func HexEncode(data []byte) string { + return hex.EncodeToString(data) +} + +func HexDecode(s string) ([]byte, error) { + return hex.DecodeString(s) +} + +func AssertHexDecode(s string) []byte { + data, err := HexDecode(s) + if err != nil { + panic(err) + } + return data +} diff --git a/tools.go b/tools.go index 5a194c14..6df5e5e9 100644 --- a/tools.go +++ b/tools.go @@ -5,6 +5,5 @@ package tools import ( _ "github.com/yoheimuta/protolint/cmd/protolint" - _ "go.uber.org/mock/mockgen" _ "google.golang.org/protobuf/cmd/protoc-gen-go" )