Skip to content

Commit

Permalink
Merge pull request #65 from matrix-org/kegan/rpc-orphan
Browse files Browse the repository at this point in the history
Add keep-alive mechanism for the RPC server
  • Loading branch information
kegsay authored May 22, 2024
2 parents c602cd3 + f332973 commit 6b55d02
Showing 1 changed file with 63 additions and 10 deletions.
73 changes: 63 additions & 10 deletions internal/deploy/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,42 @@ package deploy
import (
"fmt"
"log"
"os"
"sync"
"time"

"github.com/matrix-org/complement-crypto/internal/api"
"github.com/matrix-org/complement-crypto/internal/api/langs"
)

const InactivityThreshold = 30 * time.Second

// RPCServer exposes the api.Client interface over the wire, consumed via net/rpc.
// Args and return params must be encodable with encoding/gob.
// All functions on this struct must meet the form:
//
// func (t *T) MethodName(argType T1, replyType *T2) error
type RPCServer struct {
contextID string // test|user|device
bindings api.LanguageBindings
activeClient api.Client
stopSyncing func()
waiters map[int]*RPCServerWaiter
nextWaiterID int
waitersMu *sync.Mutex
contextID string // test|user|device
bindings api.LanguageBindings
activeClient api.Client
stopSyncing func()
waiters map[int]*RPCServerWaiter
nextWaiterID int
waitersMu *sync.Mutex
lastCmdRecv time.Time
lastCmdRecvMu *sync.Mutex
}

func NewRPCServer() *RPCServer {
return &RPCServer{
waiters: make(map[int]*RPCServerWaiter),
waitersMu: &sync.Mutex{},
srv := &RPCServer{
waiters: make(map[int]*RPCServerWaiter),
waitersMu: &sync.Mutex{},
lastCmdRecv: time.Now(),
lastCmdRecvMu: &sync.Mutex{},
}
go srv.checkKeepAlive()
return srv
}

type RPCClientCreationOpts struct {
Expand All @@ -38,8 +47,30 @@ type RPCClientCreationOpts struct {
ContextID string
}

// When the RPC server is run locally, we want to make sure we don't persist as an orphan process
// if the test suite crashes. We do this by checking that we have seen an RPC command within
// InactivityThreshold duration.
func (s *RPCServer) checkKeepAlive() {
ticker := time.NewTicker(time.Second)
for range ticker.C {
s.lastCmdRecvMu.Lock()
if time.Since(s.lastCmdRecv) > InactivityThreshold {
fmt.Printf("terminating RPC server due to inactivity (%v)\n", InactivityThreshold)
os.Exit(0)
}
s.lastCmdRecvMu.Unlock()
}
}

func (s *RPCServer) keepAlive() {
s.lastCmdRecvMu.Lock()
defer s.lastCmdRecvMu.Unlock()
s.lastCmdRecv = time.Now()
}

// MustCreateClient creates a given client and returns it to the caller, else returns an error.
func (s *RPCServer) MustCreateClient(opts RPCClientCreationOpts, void *int) error {
defer s.keepAlive()
fmt.Printf("RPCServer: Received MustCreateClient: %+v\n", opts)
if s.activeClient != nil {
return fmt.Errorf("RPC: MustCreateClient: already have an activeClient")
Expand All @@ -56,27 +87,32 @@ func (s *RPCServer) MustCreateClient(opts RPCClientCreationOpts, void *int) erro
}

func (s *RPCServer) Close(testName string, void *int) error {
defer s.keepAlive()
s.activeClient.Close(&api.MockT{TestName: testName})
// write logs
s.bindings.PostTestRun(s.contextID)
return nil
}

func (s *RPCServer) DeletePersistentStorage(testName string, void *int) error {
defer s.keepAlive()
s.activeClient.DeletePersistentStorage(&api.MockT{TestName: testName})
return nil
}

func (s *RPCServer) Login(opts api.ClientCreationOpts, void *int) error {
defer s.keepAlive()
return s.activeClient.Login(&api.MockT{}, opts)
}

func (s *RPCServer) MustStartSyncing(testName string, void *int) error {
defer s.keepAlive()
s.stopSyncing = s.activeClient.MustStartSyncing(&api.MockT{TestName: testName})
return nil
}

func (s *RPCServer) StartSyncing(testName string, void *int) error {
defer s.keepAlive()
stopSyncing, err := s.activeClient.StartSyncing(&api.MockT{TestName: testName})
if err != nil {
return fmt.Errorf("%s RPCServer.StartSyncing: %v", testName, err)
Expand All @@ -86,6 +122,7 @@ func (s *RPCServer) StartSyncing(testName string, void *int) error {
}

func (s *RPCServer) StopSyncing(testName string, void *int) error {
defer s.keepAlive()
if s.stopSyncing == nil {
return fmt.Errorf("%s RPCServer.StopSyncing: cannot stop syncing as StartSyncing wasn't called", testName)
}
Expand All @@ -95,6 +132,7 @@ func (s *RPCServer) StopSyncing(testName string, void *int) error {
}

func (s *RPCServer) IsRoomEncrypted(roomID string, isEncrypted *bool) error {
defer s.keepAlive()
var err error
*isEncrypted, err = s.activeClient.IsRoomEncrypted(&api.MockT{}, roomID)
return err
Expand All @@ -107,11 +145,13 @@ type RPCSendMessage struct {
}

func (s *RPCServer) SendMessage(msg RPCSendMessage, eventID *string) error {
defer s.keepAlive()
*eventID = s.activeClient.SendMessage(&api.MockT{TestName: msg.TestName}, msg.RoomID, msg.Text)
return nil
}

func (s *RPCServer) TrySendMessage(msg RPCSendMessage, eventID *string) error {
defer s.keepAlive()
var err error
*eventID, err = s.activeClient.TrySendMessage(&api.MockT{TestName: msg.TestName}, msg.RoomID, msg.Text)
if err != nil {
Expand All @@ -126,6 +166,7 @@ type RPCWaitUntilEvent struct {
}

func (s *RPCServer) WaitUntilEventInRoom(input RPCWaitUntilEvent, waiterID *int) error {
defer s.keepAlive()
waiter := s.activeClient.WaitUntilEventInRoom(&api.MockT{TestName: input.TestName}, input.RoomID, func(e api.Event) bool {
s.waitersMu.Lock()
defer s.waitersMu.Unlock()
Expand Down Expand Up @@ -162,6 +203,7 @@ type RPCWait struct {
// WaiterStart is the RPC equivalent to Waiter.Waitf. It begins accumulating events for the RPC client to check.
// Clients need to call WaiterPoll to get these new events.
func (s *RPCServer) WaiterStart(input RPCWait, void *int) error {
defer s.keepAlive()
s.waitersMu.Lock()
w := s.waiters[input.WaiterID]
if w == nil {
Expand All @@ -185,6 +227,7 @@ func (s *RPCServer) WaiterStart(input RPCWait, void *int) error {
}

func (s *RPCServer) WaiterPoll(waiterID int, eventsToCheck *[]api.Event) error {
defer s.keepAlive()
fmt.Println("Acquiring lock")
s.waitersMu.Lock()
defer s.waitersMu.Unlock()
Expand Down Expand Up @@ -213,6 +256,7 @@ type RPCBackpaginate struct {
}

func (s *RPCServer) MustBackpaginate(input RPCBackpaginate, void *int) error {
defer s.keepAlive()
s.activeClient.MustBackpaginate(&api.MockT{TestName: input.TestName}, input.RoomID, input.Count)
return nil
}
Expand All @@ -225,12 +269,14 @@ type RPCGetEvent struct {

// MustGetEvent will return the client's view of this event, or fail the test if the event cannot be found.
func (s *RPCServer) MustGetEvent(input RPCGetEvent, output *api.Event) error {
defer s.keepAlive()
*output = s.activeClient.MustGetEvent(&api.MockT{TestName: input.TestName}, input.RoomID, input.EventID)
return nil
}

// MustBackupKeys will backup E2EE keys, else fail the test.
func (s *RPCServer) MustBackupKeys(testName string, recoveryKey *string) error {
defer s.keepAlive()
*recoveryKey = s.activeClient.MustBackupKeys(&api.MockT{TestName: testName})
return nil
}
Expand All @@ -241,6 +287,7 @@ type RPCGetNotification struct {
}

func (s *RPCServer) GetNotification(input RPCGetNotification, output *api.Notification) (err error) {
defer s.keepAlive()
var n *api.Notification
n, err = s.activeClient.GetNotification(&api.MockT{}, input.RoomID, input.EventID)
if err == nil {
Expand All @@ -251,29 +298,35 @@ func (s *RPCServer) GetNotification(input RPCGetNotification, output *api.Notifi

// MustLoadBackup will recover E2EE keys from the latest backup, else fail the test.
func (s *RPCServer) MustLoadBackup(recoveryKey string, void *int) error {
defer s.keepAlive()
s.activeClient.MustLoadBackup(&api.MockT{}, recoveryKey)
return nil
}

func (s *RPCServer) LoadBackup(recoveryKey string, void *int) error {
defer s.keepAlive()
return s.activeClient.LoadBackup(&api.MockT{}, recoveryKey)
}

func (s *RPCServer) Logf(input string, void *int) error {
defer s.keepAlive()
log.Println(input)
s.activeClient.Logf(&api.MockT{}, input)
return nil
}

func (s *RPCServer) UserID(void int, userID *string) error {
defer s.keepAlive()
*userID = s.activeClient.UserID()
return nil
}
func (s *RPCServer) Type(void int, clientType *api.ClientTypeLang) error {
defer s.keepAlive()
*clientType = s.activeClient.Type()
return nil
}
func (s *RPCServer) Opts(void int, opts *api.ClientCreationOpts) error {
defer s.keepAlive()
*opts = s.activeClient.Opts()
return nil
}
Expand Down

0 comments on commit 6b55d02

Please sign in to comment.