Skip to content

Commit

Permalink
fix #119, panic gorilla ws when connection lost. (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
ElecTwix authored Jun 12, 2024
1 parent bc55e64 commit 8f4a698
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 23 deletions.
38 changes: 35 additions & 3 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ import (
"github.com/surrealdb/surrealdb.go/pkg/marshal"
)

// Default consts and vars for testing
const (
defaultURL = "ws://localhost:8000/rpc"
)

var currentURL = os.Getenv("SURREALDB_URL")

//

// TestDBSuite is a test s for the DB struct
type SurrealDBTestSuite struct {
suite.Suite
Expand Down Expand Up @@ -112,13 +121,18 @@ func (t testUser) String() (str string, err error) {
return
}

// openConnection opens a new connection to the database
func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB {
func (s *SurrealDBTestSuite) createTestDB() *surrealdb.DB {
url := os.Getenv("SURREALDB_URL")
if url == "" {
url = "ws://localhost:8000/rpc"
}
impl := s.connImplementations[s.name]
db := s.openConnection(url, impl)
return db
}

// openConnection opens a new connection to the database
func (s *SurrealDBTestSuite) openConnection(url string, impl conn.Connection) *surrealdb.DB {
require.NotNil(s.T(), impl)
db, err := surrealdb.New(url, impl)
s.Require().NoError(err)
Expand All @@ -127,7 +141,7 @@ func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB {

// SetupSuite is called before the s starts running
func (s *SurrealDBTestSuite) SetupSuite() {
db := s.openConnection()
db := s.createTestDB()
s.Require().NotNil(db)
s.db = db
_ = signin(s)
Expand Down Expand Up @@ -766,6 +780,24 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() {
})
}

func (s *SurrealDBTestSuite) TestConnectionBreak() {
ws := gorilla.Create()
var url string
if currentURL == "" {
url = defaultURL
} else {
url = currentURL
}

db := s.openConnection(url, ws)
// Close the connection hard from ws
ws.Conn.Close()

// Needs to be return error when the connection is closed or broken
_, err := db.Select("users")
s.Require().Error(err)
}

// assertContains performs an assertion on a list, asserting that at least one element matches a provided condition.
// All the matching elements are returned from this function, which can be used as a filter.
func assertContains[K any](s *SurrealDBTestSuite, input []K, matcher func(K) bool) []K {
Expand Down
61 changes: 41 additions & 20 deletions pkg/conn/gorilla/gorilla.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"reflect"
"strconv"
Expand Down Expand Up @@ -43,13 +44,14 @@ type WebSocket struct {
notificationChannels map[string]chan model.Notification
notificationChannelsLock sync.RWMutex

close chan int
closeChan chan int
closeError error
}

func Create() *WebSocket {
return &WebSocket{
Conn: nil,
close: make(chan int),
closeChan: make(chan int),
responseChannels: make(map[string]chan rpc.RPCResponse),
notificationChannels: make(map[string]chan model.Notification),
Timeout: DefaultTimeout * time.Second,
Expand All @@ -73,7 +75,7 @@ func (ws *WebSocket) Connect(url string) (conn.Connection, error) {
}
}

ws.initialize()
go ws.initialize()
return ws, nil
}

Expand Down Expand Up @@ -107,7 +109,7 @@ func (ws *WebSocket) SetCompression(compress bool) *WebSocket {
func (ws *WebSocket) Close() error {
ws.connLock.Lock()
defer ws.connLock.Unlock()
close(ws.close)
close(ws.closeChan)
err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, ""))
if err != nil {
return err
Expand Down Expand Up @@ -179,6 +181,12 @@ func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) {
}

func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) {
select {
case <-ws.closeChan:
return nil, ws.closeError
default:
}

id := rand.String(RequestIDLength)
request := &rpc.RPCRequest{
ID: id,
Expand Down Expand Up @@ -235,25 +243,38 @@ func (ws *WebSocket) write(v interface{}) error {
}

func (ws *WebSocket) initialize() {
go func() {
for {
select {
case <-ws.close:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
ws.logger.Error(err.Error())
continue
for {
select {
case <-ws.closeChan:
return
default:
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
shouldExit := ws.handleError(err)
if shouldExit {
return
}
go ws.handleResponse(res)
continue
}
go ws.handleResponse(res)
}
}()
}
}

func (ws *WebSocket) handleError(err error) bool {
if errors.Is(err, net.ErrClosed) {
ws.closeError = net.ErrClosed
return true
}
if gorilla.IsUnexpectedCloseError(err) {
ws.closeError = io.ErrClosedPipe
<-ws.closeChan
return true
}

ws.logger.Error(err.Error())
return false
}

func (ws *WebSocket) handleResponse(res rpc.RPCResponse) {
Expand Down

0 comments on commit 8f4a698

Please sign in to comment.