Skip to content

Commit

Permalink
Remove inbox validation for uploading key packages (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas authored Aug 14, 2024
1 parent 24971f4 commit f7a54a0
Show file tree
Hide file tree
Showing 16 changed files with 122 additions and 203 deletions.
2 changes: 1 addition & 1 deletion dev/docker/env
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
set -e

function docker_compose() {
docker-compose -f dev/docker/docker-compose.yml -p xmtpd "$@"
docker compose -f dev/docker/docker-compose.yml -p xmtpd "$@"
}
2 changes: 1 addition & 1 deletion dev/e2e/docker/env
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
set -e

function docker_compose() {
docker-compose -f dev/e2e/docker/docker-compose.yml -p xmtpd-e2e "$@"
docker compose -f dev/e2e/docker/docker-compose.yml -p xmtpd-e2e "$@"
}
2 changes: 1 addition & 1 deletion dev/lint
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ if [[ $(gofmt -l .) ]]; then
echo "gofmt errors, run 'gofmt -w .' and commit"
fi

golangci-lint --config dev/.golangci.yaml run ./... --deadline=5m
golangci-lint --config dev/.golangci.yaml run ./...

protolint .
13 changes: 7 additions & 6 deletions pkg/api/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,17 @@ func (wa *WalletAuthorizer) requiresAuthorization(req interface{}) bool {
func (wa *WalletAuthorizer) getWallet(ctx context.Context) (types.WalletAddr, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Errorf(codes.Unauthenticated, "metadata is not provided")
return "", status.Error(codes.Unauthenticated, "metadata is not provided")
}

values := md.Get(authorizationMetadataKey)
if len(values) == 0 {
return "", status.Errorf(codes.Unauthenticated, "authorization token is not provided")
return "", status.Error(codes.Unauthenticated, "authorization token is not provided")
}

words := strings.SplitN(values[0], " ", 2)
if len(words) != 2 {
return "", status.Errorf(codes.Unauthenticated, "invalid authorization header")
return "", status.Error(codes.Unauthenticated, "invalid authorization header")
}
if scheme := strings.TrimSpace(words[0]); scheme != "Bearer" {
return "", status.Errorf(codes.Unauthenticated, "unrecognized authorization scheme %s", scheme)
Expand All @@ -127,14 +127,14 @@ func (wa *WalletAuthorizer) authorize(ctx context.Context, req interface{}, wall
if pub, isPublish := req.(*messagev1.PublishRequest); isPublish {
for _, env := range pub.Envelopes {
if !wa.privilegedAddresses[wallet] && !allowedToPublish(env.ContentTopic, wallet) {
return status.Errorf(codes.PermissionDenied, "publishing to restricted topic")
return status.Error(codes.PermissionDenied, "publishing to restricted topic")
}
}
}
if wa.AllowLists {
if wa.AllowLister.IsDenyListed(wallet.String()) {
wa.Log.Debug("wallet deny listed", logging.WalletAddress(wallet.String()))
return status.Errorf(codes.PermissionDenied, ErrDenyListed.Error())
return status.Error(codes.PermissionDenied, ErrDenyListed.Error())
}
}
return nil
Expand Down Expand Up @@ -185,7 +185,8 @@ func (wa *WalletAuthorizer) applyLimits(ctx context.Context, fullMethod string,
logging.String("method", method),
logging.String("limit", string(limitType)),
logging.Int("cost", cost))
return status.Errorf(codes.ResourceExhausted, err.Error())

return status.Error(codes.ResourceExhausted, err.Error())
}

const (
Expand Down
10 changes: 5 additions & 5 deletions pkg/api/message/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot
log.Debug("received message")

if len(env.ContentTopic) > MaxContentTopicNameSize {
return nil, status.Errorf(codes.InvalidArgument, "topic length too big")
return nil, status.Error(codes.InvalidArgument, "topic length too big")
}

if len(env.Message) > MaxMessageSize {
return nil, status.Errorf(codes.InvalidArgument, "message too big")
return nil, status.Error(codes.InvalidArgument, "message too big")
}

if !topic.IsEphemeral(env.ContentTopic) {
_, err := s.store.InsertMessage(env)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
return nil, status.Error(codes.Internal, err.Error())
}
}

Expand All @@ -150,7 +150,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot
Payload: env.Message,
})
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
return nil, status.Error(codes.Internal, err.Error())
}

metrics.EmitPublishedEnvelope(ctx, log, env)
Expand Down Expand Up @@ -393,7 +393,7 @@ func (s *Service) BatchQuery(ctx context.Context, req *proto.BatchQueryRequest)
// We execute the query using the existing Query API
resp, err := s.Query(ctx, query)
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
return nil, status.Error(codes.Internal, err.Error())
}
responses = append(responses, resp)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SET statement_timeout = 0;

--bun:split
ALTER TABLE installations
ADD COLUMN inbox_id BYTEA NOT NULL,
ADD COLUMN expiration BIGINT NOT NULL;

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SET statement_timeout = 0;

--bun:split
ALTER TABLE installations
DROP COLUMN IF EXISTS inbox_id,
DROP COLUMN IF EXISTS expiration;

34 changes: 19 additions & 15 deletions pkg/mls/api/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) er
return nil
}

/*
*
DEPRECATED: Use UploadKeyPackage instead
*
*/
func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterInstallationRequest) (*mlsv1.RegisterInstallationResponse, error) {
if err := validateRegisterInstallationRequest(req); err != nil {
return nil, err
Expand All @@ -126,9 +131,9 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterI
if len(results) != 1 {
return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results))
}

installationKey := results[0].InstallationKey
credential := results[0].Credential
if err = s.store.CreateInstallation(ctx, installationKey, credential.InboxId, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil {
if err = s.store.CreateOrUpdateInstallation(ctx, installationKey, req.KeyPackage.KeyPackageTlsSerialized); err != nil {
return nil, err
}
return &mlsv1.RegisterInstallationResponse{
Expand All @@ -152,7 +157,7 @@ func (s *Service) FetchKeyPackages(ctx context.Context, req *mlsv1.FetchKeyPacka

idx, ok := keyPackageMap[string(installation.ID)]
if !ok {
return nil, status.Errorf(codes.Internal, "could not find key package for installation")
return nil, status.Error(codes.Internal, "could not find key package for installation")
}

resPackages[idx] = &mlsv1.FetchKeyPackagesResponse_KeyPackage{
Expand All @@ -178,21 +183,20 @@ func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPack
}

installationId := validationResults[0].InstallationKey
expiration := validationResults[0].Expiration

if err = s.store.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil {
if err = s.store.CreateOrUpdateInstallation(ctx, installationId, keyPackageBytes); err != nil {
return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err)
}

return &emptypb.Empty{}, nil
}

func (s *Service) RevokeInstallation(ctx context.Context, req *mlsv1.RevokeInstallationRequest) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
return nil, status.Error(codes.Unimplemented, "unimplemented")
}

func (s *Service) GetIdentityUpdates(ctx context.Context, req *mlsv1.GetIdentityUpdatesRequest) (res *mlsv1.GetIdentityUpdatesResponse, err error) {
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
return nil, status.Error(codes.Unimplemented, "unimplemented")
}

func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMessagesRequest) (res *emptypb.Empty, err error) {
Expand Down Expand Up @@ -521,11 +525,11 @@ func buildNatsSubjectForWelcomeMessages(installationId []byte) string {

func validateSendGroupMessagesRequest(req *mlsv1.SendGroupMessagesRequest) error {
if req == nil || len(req.Messages) == 0 {
return status.Errorf(codes.InvalidArgument, "no group messages to send")
return status.Error(codes.InvalidArgument, "no group messages to send")
}
for _, input := range req.Messages {
if input == nil || input.GetV1() == nil {
return status.Errorf(codes.InvalidArgument, "invalid group message")
return status.Error(codes.InvalidArgument, "invalid group message")
}
}
return nil
Expand All @@ -537,37 +541,37 @@ func validateSendWelcomeMessagesRequest(req *mlsv1.SendWelcomeMessagesRequest) e
}
for _, input := range req.Messages {
if input == nil || input.GetV1() == nil {
return status.Errorf(codes.InvalidArgument, "invalid welcome message")
return status.Error(codes.InvalidArgument, "invalid welcome message")
}

v1 := input.GetV1()
if len(v1.Data) == 0 || len(v1.InstallationKey) == 0 || len(v1.HpkePublicKey) == 0 {
return status.Errorf(codes.InvalidArgument, "invalid welcome message")
return status.Error(codes.InvalidArgument, "invalid welcome message")
}
}
return nil
}

func validateRegisterInstallationRequest(req *mlsv1.RegisterInstallationRequest) error {
if req == nil || req.KeyPackage == nil {
return status.Errorf(codes.InvalidArgument, "no key package")
return status.Error(codes.InvalidArgument, "no key package")
}
return nil
}

func validateUploadKeyPackageRequest(req *mlsv1.UploadKeyPackageRequest) error {
if req == nil || req.KeyPackage == nil {
return status.Errorf(codes.InvalidArgument, "no key package")
return status.Error(codes.InvalidArgument, "no key package")
}
return nil
}

func requireReadyToSend(groupId string, message []byte) error {
if len(groupId) == 0 {
return status.Errorf(codes.InvalidArgument, "group id is empty")
return status.Error(codes.InvalidArgument, "group id is empty")
}
if len(message) == 0 {
return status.Errorf(codes.InvalidArgument, "message is empty")
return status.Error(codes.InvalidArgument, "message is empty")
}
return nil
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/mls/api/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ func TestRegisterInstallation(t *testing.T) {
defer cleanup()

installationId := test.RandomBytes(32)
inboxId := test.RandomInboxId()
keyPackage := []byte("test")

mockValidateInboxIdKeyPackages(mlsValidationService, installationId, inboxId)
mockValidateInboxIdKeyPackages(mlsValidationService, installationId, test.RandomInboxId())

res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{
KeyPackage: &mlsv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
KeyPackageTlsSerialized: keyPackage,
},
IsInboxIdCredential: false,
})
Expand All @@ -98,7 +98,8 @@ func TestRegisterInstallation(t *testing.T) {
installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId)
require.NoError(t, err)

require.Equal(t, inboxId, installation.InboxID)
require.Equal(t, installationId, installation.ID)
require.Equal(t, []byte("test"), installation.KeyPackage)
}

func TestRegisterInstallationError(t *testing.T) {
Expand Down
23 changes: 7 additions & 16 deletions pkg/mls/store/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -83,33 +83,24 @@ WHERE (address, inbox_id, association_sequence_id) =(
address,
inbox_id);

-- name: CreateInstallation :exec
INSERT INTO installations(id, created_at, updated_at, inbox_id, key_package, expiration)
VALUES (@id, @created_at, @updated_at, decode(@inbox_id, 'hex'), @key_package, @expiration);
-- name: CreateOrUpdateInstallation :exec
INSERT INTO installations(id, created_at, updated_at, key_package)
VALUES (@id, @created_at, @updated_at, @key_package)
ON CONFLICT (id)
DO UPDATE SET
key_package = @key_package, updated_at = @updated_at;

-- name: GetInstallation :one
SELECT
id,
created_at,
updated_at,
encode(inbox_id, 'hex') AS inbox_id,
key_package,
expiration
key_package
FROM
installations
WHERE
id = $1;

-- name: UpdateKeyPackage :execrows
UPDATE
installations
SET
key_package = @key_package,
updated_at = @updated_at,
expiration = @expiration
WHERE
id = @id;

-- name: FetchKeyPackages :many
SELECT
id,
Expand Down
2 changes: 0 additions & 2 deletions pkg/mls/store/queries/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit f7a54a0

Please sign in to comment.