Skip to content

Commit

Permalink
Handle RPC errors (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheRangiCrew authored Oct 31, 2024
1 parent f8d60dd commit 8448815
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
12 changes: 12 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,15 @@ func (s *SurrealDBTestSuite) TestQueryRaw() {
fmt.Println(created)
fmt.Println(selected)
}

func (s *SurrealDBTestSuite) TestRPCError() {
s.Run("Test valid query", func() {
_, err := surrealdb.Query[[]testUser](s.db, "SELECT * FROM users", map[string]interface{}{})
s.Require().NoError(err)
})

s.Run("Test invalid query", func() {
_, err := surrealdb.Query[[]testUser](s.db, "SELEC * FROM users", map[string]interface{}{})
s.Require().Error(err)
})
}
30 changes: 30 additions & 0 deletions pkg/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ type BaseConnection struct {
responseChannels map[string]chan []byte
responseChannelsLock sync.RWMutex

errorChannels map[string]chan error
errorChannelsLock sync.RWMutex

notificationChannels map[string]chan Notification
notificationChannelsLock sync.RWMutex
}
Expand All @@ -60,6 +63,20 @@ func (bc *BaseConnection) createResponseChannel(id string) (chan []byte, error)
return ch, nil
}

func (bc *BaseConnection) createErrorChannel(id string) (chan error, error) {
bc.errorChannelsLock.Lock()
defer bc.errorChannelsLock.Unlock()

if _, ok := bc.errorChannels[id]; ok {
return nil, fmt.Errorf("%w: %v", constants.ErrIDInUse, id)
}

ch := make(chan error)
bc.errorChannels[id] = ch

return ch, nil
}

func (bc *BaseConnection) createNotificationChannel(liveQueryID string) (chan Notification, error) {
bc.notificationChannelsLock.Lock()
defer bc.notificationChannelsLock.Unlock()
Expand All @@ -80,13 +97,26 @@ func (bc *BaseConnection) removeResponseChannel(id string) {
delete(bc.responseChannels, id)
}

func (bc *BaseConnection) removeErrorChannel(id string) {
bc.errorChannelsLock.Lock()
defer bc.errorChannelsLock.Unlock()
delete(bc.errorChannels, id)
}

func (bc *BaseConnection) getResponseChannel(id string) (chan []byte, bool) {
bc.responseChannelsLock.RLock()
defer bc.responseChannelsLock.RUnlock()
ch, ok := bc.responseChannels[id]
return ch, ok
}

func (bc *BaseConnection) getErrorChannel(id string) (chan error, bool) {
bc.errorChannelsLock.RLock()
defer bc.errorChannelsLock.RUnlock()
ch, ok := bc.errorChannels[id]
return ch, ok
}

func (bc *BaseConnection) getLiveChannel(id string) (chan Notification, bool) {
bc.notificationChannelsLock.RLock()
defer bc.notificationChannelsLock.RUnlock()
Expand Down
22 changes: 22 additions & 0 deletions pkg/connection/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func NewWebSocketConnection(p NewConnectionParams) *WebSocketConnection {
unmarshaler: p.Unmarshaler,

responseChannels: make(map[string]chan []byte),
errorChannels: make(map[string]chan error),
notificationChannels: make(map[string]chan Notification),
},

Expand Down Expand Up @@ -159,7 +160,12 @@ func (ws *WebSocketConnection) Send(dest interface{}, method string, params ...i
if err != nil {
return err
}
errorChan, err := ws.createErrorChannel(id)
if err != nil {
return err
}
defer ws.removeResponseChannel(id)
defer ws.removeErrorChannel(id)

if err := ws.write(request); err != nil {
return err
Expand All @@ -177,6 +183,11 @@ func (ws *WebSocketConnection) Send(dest interface{}, method string, params ...i
return ws.unmarshaler.Unmarshal(resBytes, dest)
}
return nil
case resErr, open := <-errorChan:
if !open {
return errors.New("error channel closed")
}
return resErr
}
}

Expand Down Expand Up @@ -234,6 +245,17 @@ func (ws *WebSocketConnection) handleResponse(res []byte) {
if rpcRes.Error != nil {
err := fmt.Errorf("rpc request err %w", rpcRes.Error)
ws.logger.Error(err.Error())

errChan, ok := ws.getErrorChannel(fmt.Sprintf("%v", rpcRes.ID))
if !ok {
err := fmt.Errorf("unavailable ErrorChannel %+v", rpcRes.ID)
ws.logger.Error(err.Error())
return
}

defer close(errChan)
errChan <- rpcRes.Error

return
}

Expand Down

0 comments on commit 8448815

Please sign in to comment.