Skip to content

Commit

Permalink
Revert "Remove deprecated methods from API (#402)" (#404)
Browse files Browse the repository at this point in the history
This reverts commit 35c40a2.
  • Loading branch information
neekolas authored Sep 6, 2024
1 parent e386f82 commit dd0b6d5
Show file tree
Hide file tree
Showing 18 changed files with 3,989 additions and 1,725 deletions.
43 changes: 43 additions & 0 deletions pkg/mls/api/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,34 @@ func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) er
return nil
}

/*
*
DEPRECATED: Use UploadKeyPackage instead
*
*/
func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterInstallationRequest) (*mlsv1.RegisterInstallationResponse, error) {
if err := validateRegisterInstallationRequest(req); err != nil {
return nil, err
}

results, err := s.validationService.ValidateInboxIdKeyPackages(ctx, [][]byte{req.KeyPackage.KeyPackageTlsSerialized})
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err)
}

if len(results) != 1 {
return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results))
}

installationKey := results[0].InstallationKey
if err = s.store.CreateOrUpdateInstallation(ctx, installationKey, req.KeyPackage.KeyPackageTlsSerialized); err != nil {
return nil, err
}
return &mlsv1.RegisterInstallationResponse{
InstallationKey: installationKey,
}, nil
}

func (s *Service) FetchKeyPackages(ctx context.Context, req *mlsv1.FetchKeyPackagesRequest) (*mlsv1.FetchKeyPackagesResponse, error) {
ids := req.InstallationKeys
installations, err := s.store.FetchKeyPackages(ctx, ids)
Expand Down Expand Up @@ -163,6 +191,14 @@ func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPack
return &emptypb.Empty{}, nil
}

func (s *Service) RevokeInstallation(ctx context.Context, req *mlsv1.RevokeInstallationRequest) (*emptypb.Empty, error) {
return nil, status.Error(codes.Unimplemented, "unimplemented")
}

func (s *Service) GetIdentityUpdates(ctx context.Context, req *mlsv1.GetIdentityUpdatesRequest) (res *mlsv1.GetIdentityUpdatesResponse, err error) {
return nil, status.Error(codes.Unimplemented, "unimplemented")
}

func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMessagesRequest) (res *emptypb.Empty, err error) {
log := s.log.Named("send-group-messages")
if err = validateSendGroupMessagesRequest(req); err != nil {
Expand Down Expand Up @@ -516,6 +552,13 @@ func validateSendWelcomeMessagesRequest(req *mlsv1.SendWelcomeMessagesRequest) e
return nil
}

func validateRegisterInstallationRequest(req *mlsv1.RegisterInstallationRequest) error {
if req == nil || req.KeyPackage == nil {
return status.Error(codes.InvalidArgument, "no key package")
}
return nil
}

func validateUploadKeyPackageRequest(req *mlsv1.UploadKeyPackageRequest) error {
if req == nil || req.KeyPackage == nil {
return status.Error(codes.InvalidArgument, "no key package")
Expand Down
51 changes: 48 additions & 3 deletions pkg/mls/api/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"bytes"
"context"
"errors"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -74,6 +75,50 @@ func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mock
}
}

func TestRegisterInstallation(t *testing.T) {
ctx := context.Background()
svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()

installationId := test.RandomBytes(32)
keyPackage := []byte("test")

mockValidateInboxIdKeyPackages(mlsValidationService, installationId, test.RandomInboxId())

res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{
KeyPackage: &mlsv1.KeyPackageUpload{
KeyPackageTlsSerialized: keyPackage,
},
IsInboxIdCredential: false,
})

require.NoError(t, err)
require.Equal(t, installationId, res.InstallationKey)

installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId)
require.NoError(t, err)

require.Equal(t, installationId, installation.ID)
require.Equal(t, []byte("test"), installation.KeyPackage)
}

func TestRegisterInstallationError(t *testing.T) {
ctx := context.Background()
svc, _, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()

mlsValidationService.EXPECT().ValidateInboxIdKeyPackages(mock.Anything, mock.Anything).Return(nil, errors.New("error validating"))

res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{
KeyPackage: &mlsv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
IsInboxIdCredential: false,
})
require.Error(t, err)
require.Nil(t, res)
}

func TestUploadKeyPackage(t *testing.T) {
ctx := context.Background()
svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx)
Expand All @@ -84,7 +129,7 @@ func TestUploadKeyPackage(t *testing.T) {

mockValidateInboxIdKeyPackages(mlsValidationService, installationId, inboxId)

res, err := svc.UploadKeyPackage(ctx, &mlsv1.UploadKeyPackageRequest{
res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{
KeyPackage: &mlsv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
Expand Down Expand Up @@ -117,7 +162,7 @@ func TestFetchKeyPackages(t *testing.T) {

mockCall := mockValidateInboxIdKeyPackages(mlsValidationService, installationId1, inboxId)

res, err := svc.UploadKeyPackage(ctx, &mlsv1.UploadKeyPackageRequest{
res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{
KeyPackage: &mlsv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
Expand All @@ -133,7 +178,7 @@ func TestFetchKeyPackages(t *testing.T) {

mockValidateInboxIdKeyPackages(mlsValidationService, installationId2, inboxId)

res, err = svc.UploadKeyPackage(ctx, &mlsv1.UploadKeyPackageRequest{
res, err = svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{
KeyPackage: &mlsv1.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test2"),
},
Expand Down
33 changes: 29 additions & 4 deletions pkg/mlsvalidate/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type IdentityInput struct {

type MLSValidationService interface {
ValidateInboxIdKeyPackages(ctx context.Context, keyPackages [][]byte) ([]InboxIdValidationResult, error)
ValidateV3KeyPackages(ctx context.Context, keyPackages [][]byte) ([]IdentityValidationResult, error)
ValidateGroupMessages(ctx context.Context, groupMessages []*mlsv1.GroupMessageInput) ([]GroupMessageValidationResult, error)
GetAssociationState(ctx context.Context, oldUpdates []*associations.IdentityUpdate, newUpdates []*associations.IdentityUpdate) (*AssociationStateResult, error)
}
Expand Down Expand Up @@ -108,15 +109,39 @@ func (s *MLSValidationServiceImpl) ValidateInboxIdKeyPackages(ctx context.Contex
return out, nil
}

func makeValidateKeyPackageRequest(keyPackageBytes [][]byte, isInboxIdCredential bool) *svc.ValidateInboxIdKeyPackagesRequest {
keyPackageRequests := make([]*svc.ValidateInboxIdKeyPackagesRequest_KeyPackage, len(keyPackageBytes))
func (s *MLSValidationServiceImpl) ValidateV3KeyPackages(ctx context.Context, keyPackages [][]byte) ([]IdentityValidationResult, error) {
req := makeValidateKeyPackageRequest(keyPackages, false)

response, err := s.grpcClient.ValidateKeyPackages(ctx, req)
if err != nil {
return nil, err
}

out := make([]IdentityValidationResult, len(response.Responses))
for i, response := range response.Responses {
if !response.IsOk {
return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage)
}
out[i] = IdentityValidationResult{
AccountAddress: response.AccountAddress,
InstallationKey: response.InstallationId,
CredentialIdentity: response.CredentialIdentityBytes,
Expiration: response.Expiration,
}
}

return out, nil
}

func makeValidateKeyPackageRequest(keyPackageBytes [][]byte, isInboxIdCredential bool) *svc.ValidateKeyPackagesRequest {
keyPackageRequests := make([]*svc.ValidateKeyPackagesRequest_KeyPackage, len(keyPackageBytes))
for i, keyPackage := range keyPackageBytes {
keyPackageRequests[i] = &svc.ValidateInboxIdKeyPackagesRequest_KeyPackage{
keyPackageRequests[i] = &svc.ValidateKeyPackagesRequest_KeyPackage{
KeyPackageBytesTlsSerialized: keyPackage,
IsInboxIdCredential: isInboxIdCredential,
}
}
return &svc.ValidateInboxIdKeyPackagesRequest{
return &svc.ValidateKeyPackagesRequest{
KeyPackages: keyPackageRequests,
}
}
Expand Down
39 changes: 38 additions & 1 deletion pkg/mlsvalidate/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,30 @@ func (m *MockedGRPCService) GetAssociationState(ctx context.Context, in *svc.Get
return nil, nil
}

func (m *MockedGRPCService) ValidateKeyPackages(ctx context.Context, req *svc.ValidateKeyPackagesRequest, opts ...grpc.CallOption) (*svc.ValidateKeyPackagesResponse, error) {
args := m.Called(ctx, req)

return args.Get(0).(*svc.ValidateKeyPackagesResponse), args.Error(1)
}

func (m *MockedGRPCService) ValidateGroupMessages(ctx context.Context, req *svc.ValidateGroupMessagesRequest, opts ...grpc.CallOption) (*svc.ValidateGroupMessagesResponse, error) {
args := m.Called(ctx, req)

return args.Get(0).(*svc.ValidateGroupMessagesResponse), args.Error(1)
}

func (m *MockedGRPCService) ValidateInboxIdKeyPackages(ctx context.Context, req *svc.ValidateInboxIdKeyPackagesRequest, opts ...grpc.CallOption) (*svc.ValidateInboxIdKeyPackagesResponse, error) {
func (m *MockedGRPCService) ValidateInboxIdKeyPackages(ctx context.Context, req *svc.ValidateKeyPackagesRequest, opts ...grpc.CallOption) (*svc.ValidateInboxIdKeyPackagesResponse, error) {
args := m.Called(ctx, req)

return args.Get(0).(*svc.ValidateInboxIdKeyPackagesResponse), args.Error(1)
}

func (m *MockedGRPCService) ValidateInboxIds(ctx context.Context, req *svc.ValidateInboxIdsRequest, opts ...grpc.CallOption) (*svc.ValidateInboxIdsResponse, error) {
args := m.Called(ctx, req)

return args.Get(0).(*svc.ValidateInboxIdsResponse), args.Error(1)
}

func getMockedService() (*MockedGRPCService, MLSValidationService) {
mockService := new(MockedGRPCService)
service := &MLSValidationServiceImpl{
Expand All @@ -39,6 +51,31 @@ func getMockedService() (*MockedGRPCService, MLSValidationService) {
return mockService, service
}

func TestValidateKeyPackages(t *testing.T) {
mockGrpc, service := getMockedService()

ctx := context.Background()

firstResponse := svc.ValidateKeyPackagesResponse_ValidationResponse{
IsOk: true,
AccountAddress: "0x123",
InstallationId: []byte("123"),
CredentialIdentityBytes: []byte("456"),
ErrorMessage: "",
}

mockGrpc.On("ValidateKeyPackages", ctx, mock.Anything).Return(&svc.ValidateKeyPackagesResponse{
Responses: []*svc.ValidateKeyPackagesResponse_ValidationResponse{&firstResponse},
}, nil)

res, err := service.ValidateV3KeyPackages(ctx, nil)
assert.NoError(t, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, "0x123", res[0].AccountAddress)
assert.Equal(t, []byte("123"), res[0].InstallationKey)
assert.Equal(t, []byte("456"), res[0].CredentialIdentity)
}

func TestValidateInboxIdKeyPackages(t *testing.T) {
mockGrpc, service := getMockedService()

Expand Down
61 changes: 60 additions & 1 deletion pkg/mocks/mock_MLSValidationService.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/mocks/mock_MlsApi_SubscribeGroupMessagesServer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/mocks/mock_MlsApi_SubscribeWelcomeMessagesServer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit dd0b6d5

Please sign in to comment.