diff --git a/pkg/gql/README.md b/pkg/gql/README.md index a9343909a..28a797ac2 100644 --- a/pkg/gql/README.md +++ b/pkg/gql/README.md @@ -12,8 +12,9 @@ GraphQL package is here to provide a read-only access to any persistent/non-pers ### API Endpoints -* `/graphql` - Support data fetching -* `/ws` - Support websocket notifications +* `/graphql` - data fetching +* `/ws` - websocket notifications, if TLS is disabled +* `/wss` - secure websocket notifications, if TLS is enabled ### Scenarios diff --git a/pkg/gql/httpserver.go b/pkg/gql/httpserver.go index a40e01dee..5bfdbcef5 100644 --- a/pkg/gql/httpserver.go +++ b/pkg/gql/httpserver.go @@ -10,6 +10,7 @@ import ( "context" "net" "net/http" + "os" "sync/atomic" "time" @@ -33,6 +34,7 @@ var log = logger.WithFields(logger.Fields{"prefix": "gql"}) const ( endpointWS = "/ws" + endpointWSS = "/wss" endpointGQL = "/graphql" ) @@ -110,6 +112,10 @@ func (s *Server) listenOnHTTPServer(l net.Listener) { var err error if conf.EnableTLS { + if _, err = os.Stat(conf.CertFile); err != nil { + log.WithError(err).Fatal("failed to enable TLS") + } + err = s.httpServer.ServeTLS(l, conf.CertFile, conf.KeyFile) } else { err = s.httpServer.Serve(l) @@ -196,7 +202,13 @@ func (s *Server) EnableNotifications(serverMux *http.ServeMux) error { } middleware := tollbooth.LimitFuncHandler(s.lmt, wsHandler) - serverMux.Handle(endpointWS, middleware) + + endpoint := endpointWS + if cfg.Get().Gql.EnableTLS { + endpoint = endpointWSS + } + + serverMux.Handle(endpoint, middleware) return nil } diff --git a/pkg/gql/httpserver_test.go b/pkg/gql/httpserver_test.go index f87183362..a9d4a635e 100644 --- a/pkg/gql/httpserver_test.go +++ b/pkg/gql/httpserver_test.go @@ -8,6 +8,7 @@ package gql import ( "context" + "crypto/tls" "net/url" "syscall" "testing" @@ -36,6 +37,11 @@ const ( // srvAddr Notification service address. srvAddr = "127.0.0.1:22222" + // TestMultiClient can be executed over TLS, if certFile, keyFile are provided. + enableTLS = false + certFile = "./example.crt" + keyFile = "./example.key" + maxOpenFiles = 10 * 1000 ) @@ -50,7 +56,7 @@ func TestMultiClient(t *testing.T) { // Set up HTTP server with notifications enabled // config - s, eb, err := createServer(srvAddr, brokersNum, maxAllowedClients/brokersNum) + s, eb, err := createServer(srvAddr, brokersNum, maxAllowedClients/brokersNum, enableTLS) assert.NoError(err) defer s.Stop() @@ -59,7 +65,7 @@ func TestMultiClient(t *testing.T) { resp := make(chan string, clientsNum) for i := 0; i < clientsNum; i++ { - if err := createClient(srvAddr, resp); err != nil { + if err := createClient(srvAddr, resp, enableTLS); err != nil { assert.NoError(err) } } @@ -94,12 +100,21 @@ func TestMultiClient(t *testing.T) { } } -func createClient(addr string, resp chan string) error { - dialCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second) +func createClient(addr string, resp chan string, enableTLS bool) error { + dialCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - // Set up a websocket client - u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"} + // Set up a websocket (secure/insecure) client + scheme := "ws" + + if enableTLS { + websocket.DefaultDialer.TLSClientConfig = &tls.Config{} + websocket.DefaultDialer.TLSClientConfig.InsecureSkipVerify = true + + scheme = "wss" + } + + u := url.URL{Scheme: scheme, Host: addr, Path: "/" + scheme} c, _, err := websocket.DefaultDialer.DialContext(dialCtx, u.String(), nil) if err != nil { @@ -120,14 +135,16 @@ func createClient(addr string, resp chan string) error { return nil } -func createServer(addr string, brokerNum, clientsPerBroker uint) (*Server, *eventbus.EventBus, error) { +func createServer(addr string, brokerNum, clientsPerBroker uint, enableTLS bool) (*Server, *eventbus.EventBus, error) { // Set up HTTP server with notifications enabled // config r := config.Registry{} r.Gql.Network = "tcp" r.Gql.Address = addr r.Gql.Enabled = true - r.Gql.EnableTLS = false + r.Gql.EnableTLS = enableTLS + r.Gql.CertFile = certFile + r.Gql.KeyFile = keyFile r.Gql.Notification.BrokersNum = brokerNum r.Gql.Notification.ClientsPerBroker = clientsPerBroker r.Database.Driver = lite.DriverName