diff --git a/service/cert_test.go b/service/cert_test.go index 636b3be9..fc61bc31 100644 --- a/service/cert_test.go +++ b/service/cert_test.go @@ -1,8 +1,16 @@ package service import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha1" + "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "math/big" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -31,5 +39,70 @@ func (c *CertSuite) Test_SkiFromCertificate() { ski, err := skiFromCertificate(leaf) assert.Nil(c.T(), err) - assert.NotNil(c.T(), ski) + assert.NotEqual(c.T(), "", ski) + + cert, err = CreateInvalidCertificate("unit", "org", "DE", "CN") + assert.Nil(c.T(), err) + + leaf, err = x509.ParseCertificate(cert.Certificate[0]) + assert.Nil(c.T(), err) + + ski, err = skiFromCertificate(leaf) + assert.NotNil(c.T(), err) + assert.Equal(c.T(), "", ski) +} + +func CreateInvalidCertificate(organizationalUnit, organization, country, commonName string) (tls.Certificate, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + // Create the EEBUS service SKI using the private key + asn1, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return tls.Certificate{}, err + } + // SHIP 12.2: Required to be created according to RFC 3280 4.2.1.2 + ski := sha1.Sum(asn1) + + subject := pkix.Name{ + OrganizationalUnit: []string{organizationalUnit}, + Organization: []string{organization}, + Country: []string{country}, + CommonName: commonName, + } + + // Create a random serial big int value + maxValue := new(big.Int) + maxValue.Exp(big.NewInt(2), big.NewInt(130), nil).Sub(maxValue, big.NewInt(1)) + serialNumber, err := rand.Int(rand.Reader, maxValue) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SignatureAlgorithm: x509.ECDSAWithSHA256, + SerialNumber: serialNumber, + Subject: subject, + NotBefore: time.Now(), // Valid starting now + NotAfter: time.Now().Add(time.Hour * 24 * 365 * 10), // Valid for 10 years + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + SubjectKeyId: ski[:19], + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return tls.Certificate{}, err + } + + tlsCertificate := tls.Certificate{ + Certificate: [][]byte{certBytes}, + PrivateKey: privateKey, + SupportedSignatureAlgorithms: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256}, + } + + return tlsCertificate, nil } diff --git a/service/hub.go b/service/hub.go index 2b2a73ce..98873024 100644 --- a/service/hub.go +++ b/service/hub.go @@ -103,7 +103,7 @@ type connectionsHubImpl struct { knownMdnsEntries []*MdnsEntry // the SPINE local device - spineLocalDevice *spine.DeviceLocalImpl + spineLocalDevice spine.DeviceLocalConnection muxCon sync.Mutex muxConAttempt sync.Mutex @@ -111,7 +111,7 @@ type connectionsHubImpl struct { muxMdns sync.Mutex } -func newConnectionsHub(serviceProvider ServiceProvider, mdns MdnsService, spineLocalDevice *spine.DeviceLocalImpl, configuration *Configuration, localService *ServiceDetails) ConnectionsHub { +func newConnectionsHub(serviceProvider ServiceProvider, mdns MdnsService, spineLocalDevice spine.DeviceLocalConnection, configuration *Configuration, localService *ServiceDetails) ConnectionsHub { hub := &connectionsHubImpl{ connections: make(map[string]*ship.ShipConnection), connectionAttemptCounter: make(map[string]int), @@ -257,9 +257,9 @@ func (h *connectionsHubImpl) HandleShipHandshakeStateUpdate(ski string, state sh service := h.ServiceForSKI(ski) - existingDetails := service.ConnectionStateDetail + existingDetails := service.ConnectionStateDetail() if existingDetails.State() != pairingState || existingDetails.Error() != state.Error { - service.ConnectionStateDetail = pairingDetail + service.SetConnectionStateDetail(pairingDetail) h.serviceProvider.ServicePairingDetailUpdate(ski, pairingDetail) } @@ -280,7 +280,7 @@ func (h *connectionsHubImpl) PairingDetailForSki(ski string) *ConnectionStateDet return NewConnectionStateDetail(state, shipError) } - return service.ConnectionStateDetail + return service.ConnectionStateDetail() } // maps ShipMessageExchangeState to PairingState @@ -375,6 +375,25 @@ func (h *connectionsHubImpl) isSkiConnected(ski string) bool { } // Websocket connection handling +func (h *connectionsHubImpl) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + skiFound := false + for _, v := range rawCerts { + cert, err := x509.ParseCertificate(v) + if err != nil { + return err + } + + if _, err := skiFromCertificate(cert); err == nil { + skiFound = true + break + } + } + if !skiFound { + return errors.New("no valid SKI provided in certificate") + } + + return nil +} // start the ship websocket server func (h *connectionsHubImpl) startWebsocketServer() error { @@ -385,28 +404,10 @@ func (h *connectionsHubImpl) startWebsocketServer() error { Addr: addr, Handler: h, TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{h.configuration.certificate}, - ClientAuth: tls.RequireAnyClientCert, // SHIP 9: Client authentication is required - CipherSuites: ciperSuites, - VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - skiFound := false - for _, v := range rawCerts { - cert, err := x509.ParseCertificate(v) - if err != nil { - return err - } - - if _, err := skiFromCertificate(cert); err == nil { - skiFound = true - break - } - } - if !skiFound { - return errors.New("no valid SKI provided in certificate") - } - - return nil - }, + Certificates: []tls.Certificate{h.configuration.certificate}, + ClientAuth: tls.RequireAnyClientCert, // SHIP 9: Client authentication is required + CipherSuites: ciperSuites, + VerifyPeerCertificate: h.verifyPeerCertificate, }, } @@ -445,7 +446,7 @@ func (h *connectionsHubImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // check if the clients certificate provides a SKI - if len(r.TLS.PeerCertificates) == 0 { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { logging.Log.Debug("client does not provide a certificate") _ = conn.Close() return @@ -464,9 +465,10 @@ func (h *connectionsHubImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Check if the remote service is paired service := h.ServiceForSKI(remoteService.SKI) - if service.ConnectionStateDetail.State() == ConnectionStateQueued { - service.ConnectionStateDetail.SetState(ConnectionStateReceivedPairingRequest) - h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) + connectionStateDetail := service.ConnectionStateDetail() + if connectionStateDetail.State() == ConnectionStateQueued { + connectionStateDetail.SetState(ConnectionStateReceivedPairingRequest) + h.serviceProvider.ServicePairingDetailUpdate(ski, connectionStateDetail) } remoteService = service @@ -608,7 +610,7 @@ func (h *connectionsHubImpl) ServiceForSKI(ski string) *ServiceDetails { service, ok := h.remoteServices[ski] if !ok { service = NewServiceDetails(ski) - service.ConnectionStateDetail.SetState(ConnectionStateNone) + service.ConnectionStateDetail().SetState(ConnectionStateNone) h.remoteServices[ski] = service } @@ -629,9 +631,9 @@ func (h *connectionsHubImpl) RegisterRemoteSKI(ski string, enable bool) { h.removeConnectionAttemptCounter(ski) - service.ConnectionStateDetail.SetState(ConnectionStateNone) + service.ConnectionStateDetail().SetState(ConnectionStateNone) - h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) + h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail()) if existingC := h.connectionForSKI(ski); existingC != nil { existingC.CloseConnection(true, 4500, "User close") @@ -651,9 +653,9 @@ func (h *connectionsHubImpl) InitiatePairingWithSKI(ski string) { // locally initiated service := h.ServiceForSKI(ski) - service.ConnectionStateDetail.SetState(ConnectionStateQueued) + service.ConnectionStateDetail().SetState(ConnectionStateQueued) - h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) + h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail()) // initiate a search and also a connection if it does not yet exist if !h.isSkiConnected(service.SKI) { @@ -670,10 +672,10 @@ func (h *connectionsHubImpl) CancelPairingWithSKI(ski string) { } service := h.ServiceForSKI(ski) - service.ConnectionStateDetail.SetState(ConnectionStateNone) + service.ConnectionStateDetail().SetState(ConnectionStateNone) service.Trusted = false - h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) + h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail()) } // Process reported mDNS services @@ -694,7 +696,7 @@ func (h *connectionsHubImpl) ReportMdnsEntries(entries map[string]*MdnsEntry) { // Check if the remote service is paired or queued for connection service := h.ServiceForSKI(ski) if !h.IsRemoteServiceForSKIPaired(ski) && - service.ConnectionStateDetail.State() != ConnectionStateQueued { + service.ConnectionStateDetail().State() != ConnectionStateQueued { continue } @@ -732,7 +734,7 @@ func (h *connectionsHubImpl) coordinateConnectionInitations(ski string, entry *M counter, duration := h.getConnectionInitiationDelayTime(ski) service := h.ServiceForSKI(ski) - if service.ConnectionStateDetail.State() == ConnectionStateQueued { + if service.ConnectionStateDetail().State() == ConnectionStateQueued { go h.prepareConnectionInitation(ski, counter, entry) return } @@ -762,7 +764,7 @@ func (h *connectionsHubImpl) prepareConnectionInitation(ski string, counter int, // connection attempt is not relevant if the device is no longer paired // or it is not queued for pairing - pairingState := h.ServiceForSKI(ski).ConnectionStateDetail.State() + pairingState := h.ServiceForSKI(ski).ConnectionStateDetail().State() if !h.IsRemoteServiceForSKIPaired(ski) && pairingState != ConnectionStateQueued { return } @@ -790,7 +792,7 @@ func (h *connectionsHubImpl) initateConnection(remoteService *ServiceDetails, en for _, address := range entry.Addresses { // connection attempt is not relevant if the device is no longer paired // or it is not queued for pairing - pairingState := h.ServiceForSKI(remoteService.SKI).ConnectionStateDetail.State() + pairingState := h.ServiceForSKI(remoteService.SKI).ConnectionStateDetail().State() if !h.IsRemoteServiceForSKIPaired(remoteService.SKI) && pairingState != ConnectionStateQueued { return false } diff --git a/service/hub_test.go b/service/hub_test.go index ec6d9680..1ff44960 100644 --- a/service/hub_test.go +++ b/service/hub_test.go @@ -1,13 +1,18 @@ package service import ( + "crypto/tls" "errors" + "net/http" + "net/http/httptest" + "strings" "testing" "time" "github.com/enbility/eebus-go/ship" "github.com/enbility/eebus-go/spine/model" gomock "github.com/golang/mock/gomock" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -28,6 +33,8 @@ type HubSuite struct { mdnsService *MockMdnsService tests []testStruct + + sut *connectionsHubImpl } func (s *HubSuite) SetupSuite() { @@ -62,6 +69,27 @@ func (s *HubSuite) SetupSuite() { s.mdnsService.EXPECT().UnregisterMdnsSearch(gomock.Any()).AnyTimes() } +func (s *HubSuite) BeforeTest(suiteName, testName string) { + localService := &ServiceDetails{ + SKI: "localSKI", + } + + s.sut = &connectionsHubImpl{ + connections: make(map[string]*ship.ShipConnection), + connectionAttemptCounter: make(map[string]int), + connectionAttemptRunning: make(map[string]bool), + remoteServices: make(map[string]*ServiceDetails), + serviceProvider: s.serviceProvider, + localService: localService, + mdns: s.mdnsService, + } + + certificate, _ := CreateCertificate("unit", "org", "DE", "CN") + s.sut.configuration, _ = NewConfiguration("vendor", "brand", "model", "serial", + model.DeviceTypeTypeGeneric, []model.EntityTypeType{model.EntityTypeTypeCEM}, + 4567, certificate, 230, time.Second*4) +} + func (s *HubSuite) Test_NewConnectionsHub() { ski := "12af9e" localService := NewServiceDetails(ski) @@ -76,224 +104,251 @@ func (s *HubSuite) Test_NewConnectionsHub() { } func (s *HubSuite) Test_IsRemoteSKIPaired() { - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - connectionAttemptCounter: make(map[string]int), - remoteServices: make(map[string]*ServiceDetails), - serviceProvider: s.serviceProvider, - } ski := "test" - paired := sut.IsRemoteServiceForSKIPaired(ski) + paired := s.sut.IsRemoteServiceForSKIPaired(ski) assert.Equal(s.T(), false, paired) // mark it as connected, so mDNS is not triggered con := &ship.ShipConnection{ RemoteSKI: ski, } - sut.registerConnection(con) - sut.RegisterRemoteSKI(ski, true) + s.sut.registerConnection(con) + s.sut.RegisterRemoteSKI(ski, true) - paired = sut.IsRemoteServiceForSKIPaired(ski) + paired = s.sut.IsRemoteServiceForSKIPaired(ski) assert.Equal(s.T(), true, paired) // remove the connection, so the test doesn't try to close it - delete(sut.connections, ski) - sut.RegisterRemoteSKI(ski, false) - paired = sut.IsRemoteServiceForSKIPaired(ski) + delete(s.sut.connections, ski) + s.sut.RegisterRemoteSKI(ski, false) + paired = s.sut.IsRemoteServiceForSKIPaired(ski) assert.Equal(s.T(), false, paired) } func (s *HubSuite) Test_HandleConnecitonClosed() { - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - connectionAttemptCounter: make(map[string]int), - remoteServices: make(map[string]*ServiceDetails), - serviceProvider: s.serviceProvider, - } ski := "test" con := &ship.ShipConnection{ RemoteSKI: ski, } - sut.HandleConnectionClosed(con, false) + s.sut.HandleConnectionClosed(con, false) - sut.registerConnection(con) + s.sut.registerConnection(con) - sut.HandleConnectionClosed(con, true) + s.sut.HandleConnectionClosed(con, true) - assert.Equal(s.T(), 0, len(sut.connections)) + assert.Equal(s.T(), 0, len(s.sut.connections)) } func (s *HubSuite) Test_Mdns() { - localService := ServiceDetails{ - DeviceType: model.DeviceTypeTypeElectricitySupplySystem, - } - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - connectionAttemptCounter: make(map[string]int), - remoteServices: make(map[string]*ServiceDetails), - localService: &localService, - mdns: s.mdnsService, - serviceProvider: s.serviceProvider, - } - sut.checkRestartMdnsSearch() + s.sut.checkRestartMdnsSearch() - pairedServices := sut.numberPairedServices() - assert.Equal(s.T(), 0, len(sut.connections)) + pairedServices := s.sut.numberPairedServices() + assert.Equal(s.T(), 0, len(s.sut.connections)) assert.Equal(s.T(), 0, pairedServices) ski := "testski" - sut.RegisterRemoteSKI(ski, true) - pairedServices = sut.numberPairedServices() - assert.Equal(s.T(), 0, len(sut.connections)) + s.sut.RegisterRemoteSKI(ski, true) + pairedServices = s.sut.numberPairedServices() + assert.Equal(s.T(), 0, len(s.sut.connections)) assert.Equal(s.T(), 1, pairedServices) - sut.StartBrowseMdnsSearch() + s.sut.StartBrowseMdnsSearch() - sut.StopBrowseMdnsSearch() + s.sut.StopBrowseMdnsSearch() } func (s *HubSuite) Test_Ship() { - localService := ServiceDetails{ - DeviceType: model.DeviceTypeTypeElectricitySupplySystem, - } - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - connectionAttemptCounter: make(map[string]int), - remoteServices: make(map[string]*ServiceDetails), - localService: &localService, - mdns: s.mdnsService, - serviceProvider: s.serviceProvider, - } - ski := "testski" - sut.HandleShipHandshakeStateUpdate(ski, ship.ShipState{ + s.sut.HandleShipHandshakeStateUpdate(ski, ship.ShipState{ State: ship.SmeStateError, Error: errors.New("test"), }) - sut.HandleShipHandshakeStateUpdate(ski, ship.ShipState{ + s.sut.HandleShipHandshakeStateUpdate(ski, ship.ShipState{ State: ship.SmeHelloStateOk, }) - sut.ReportServiceShipID(ski, "test") + s.sut.ReportServiceShipID(ski, "test") - trust := sut.AllowWaitingForTrust(ski) + trust := s.sut.AllowWaitingForTrust(ski) assert.Equal(s.T(), true, trust) - trust = sut.AllowWaitingForTrust("test") + trust = s.sut.AllowWaitingForTrust("test") assert.Equal(s.T(), false, trust) - detail := sut.PairingDetailForSki(ski) + detail := s.sut.PairingDetailForSki(ski) assert.NotNil(s.T(), detail) con := &ship.ShipConnection{ RemoteSKI: ski, } - sut.registerConnection(con) + s.sut.registerConnection(con) - detail = sut.PairingDetailForSki(ski) + detail = s.sut.PairingDetailForSki(ski) assert.NotNil(s.T(), detail) } func (s *HubSuite) Test_MapShipMessageExchangeState() { - sut := connectionsHubImpl{} - ski := "test" - state := sut.mapShipMessageExchangeState(ship.CmiStateInitStart, ski) + state := s.sut.mapShipMessageExchangeState(ship.CmiStateInitStart, ski) assert.Equal(s.T(), ConnectionStateQueued, state) - state = sut.mapShipMessageExchangeState(ship.CmiStateClientSend, ski) + state = s.sut.mapShipMessageExchangeState(ship.CmiStateClientSend, ski) assert.Equal(s.T(), ConnectionStateInitiated, state) - state = sut.mapShipMessageExchangeState(ship.SmeHelloStateReadyInit, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeHelloStateReadyInit, ski) assert.Equal(s.T(), ConnectionStateInProgress, state) - state = sut.mapShipMessageExchangeState(ship.SmeHelloStatePendingInit, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeHelloStatePendingInit, ski) assert.Equal(s.T(), ConnectionStateReceivedPairingRequest, state) - state = sut.mapShipMessageExchangeState(ship.SmeHelloStateOk, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeHelloStateOk, ski) assert.Equal(s.T(), ConnectionStateTrusted, state) - state = sut.mapShipMessageExchangeState(ship.SmeHelloStateAbort, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeHelloStateAbort, ski) assert.Equal(s.T(), ConnectionStateNone, state) - state = sut.mapShipMessageExchangeState(ship.SmeHelloStateRemoteAbortDone, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeHelloStateRemoteAbortDone, ski) assert.Equal(s.T(), ConnectionStateRemoteDeniedTrust, state) - state = sut.mapShipMessageExchangeState(ship.SmePinStateCheckInit, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmePinStateCheckInit, ski) assert.Equal(s.T(), ConnectionStatePin, state) - state = sut.mapShipMessageExchangeState(ship.SmeAccessMethodsRequest, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeAccessMethodsRequest, ski) assert.Equal(s.T(), ConnectionStateInProgress, state) - state = sut.mapShipMessageExchangeState(ship.SmeStateComplete, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeStateComplete, ski) assert.Equal(s.T(), ConnectionStateCompleted, state) - state = sut.mapShipMessageExchangeState(ship.SmeStateError, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeStateError, ski) assert.Equal(s.T(), ConnectionStateError, state) - state = sut.mapShipMessageExchangeState(ship.SmeProtHStateTimeout, ski) + state = s.sut.mapShipMessageExchangeState(ship.SmeProtHStateTimeout, ski) assert.Equal(s.T(), ConnectionStateInProgress, state) } func (s *HubSuite) Test_DisconnectSKI() { - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - } ski := "test" - sut.DisconnectSKI(ski, "none") + s.sut.DisconnectSKI(ski, "none") } func (s *HubSuite) Test_RegisterConnection() { - ski := "12af9e" - localService := NewServiceDetails(ski) - - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - mdns: s.mdnsService, - localService: localService, - } - - ski = "test" + ski := "test" con := &ship.ShipConnection{ RemoteSKI: ski, } - sut.registerConnection(con) - assert.Equal(s.T(), 1, len(sut.connections)) - con = sut.connectionForSKI(ski) + s.sut.registerConnection(con) + assert.Equal(s.T(), 1, len(s.sut.connections)) + con = s.sut.connectionForSKI(ski) assert.NotNil(s.T(), con) } func (s *HubSuite) Test_Shutdown() { - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - mdns: s.mdnsService, - } s.mdnsService.EXPECT().ShutdownMdnsService() - sut.Shutdown() + s.sut.Shutdown() } -func (s *HubSuite) Test_IncreaseConnectionAttemptCounter() { +func (s *HubSuite) Test_VerifyPeerCertificate() { + testCert, _ := CreateCertificate("unit", "org", "DE", "CN") + var rawCerts [][]byte + rawCerts = append(rawCerts, testCert.Certificate...) + err := s.sut.verifyPeerCertificate(rawCerts, nil) + assert.Nil(s.T(), err) - // we just need a dummy for this test - sut := connectionsHubImpl{ - connectionAttemptCounter: make(map[string]int), + rawCerts = nil + rawCerts = append(rawCerts, []byte{100}) + err = s.sut.verifyPeerCertificate(rawCerts, nil) + assert.NotNil(s.T(), err) + + rawCerts = nil + invalidCert, _ := CreateInvalidCertificate("unit", "org", "DE", "CN") + rawCerts = append(rawCerts, invalidCert.Certificate...) + + err = s.sut.verifyPeerCertificate(rawCerts, nil) + assert.NotNil(s.T(), err) +} + +func (s *HubSuite) Test_ServeHTTP() { + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + w := httptest.NewRecorder() + s.sut.ServeHTTP(w, req) + + server := httptest.NewServer(s.sut) + wsURL := strings.Replace(server.URL, "http://", "ws://", -1) + + // Connect to the server + con, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + assert.Nil(s.T(), err) + con.Close() + + dialer := &websocket.Dialer{ + Subprotocols: []string{shipWebsocketSubProtocol}, } + con, _, err = dialer.Dial(wsURL, nil) + assert.Nil(s.T(), err) + con.Close() + server.Close() + + server = httptest.NewUnstartedServer(s.sut) + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{s.sut.configuration.certificate}, + ClientAuth: tls.RequireAnyClientCert, + CipherSuites: ciperSuites, + InsecureSkipVerify: true, + } + server.StartTLS() + wsURL = strings.Replace(server.URL, "https://", "wss://", -1) + + invalidCert, _ := CreateInvalidCertificate("unit", "org", "DE", "CN") + dialer = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 5 * time.Second, + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{invalidCert}, + InsecureSkipVerify: true, + CipherSuites: ciperSuites, + }, + Subprotocols: []string{shipWebsocketSubProtocol}, + } + con, _, err = dialer.Dial(wsURL, nil) + assert.Nil(s.T(), err) + + con.Close() + + validCert, _ := CreateCertificate("unit", "org", "DE", "CN") + dialer = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 5 * time.Second, + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{validCert}, + InsecureSkipVerify: true, + CipherSuites: ciperSuites, + }, + Subprotocols: []string{shipWebsocketSubProtocol}, + } + con, _, err = dialer.Dial(wsURL, nil) + assert.Nil(s.T(), err) + + con.Close() + server.Close() +} + +func (s *HubSuite) Test_IncreaseConnectionAttemptCounter() { ski := "test" for _, test := range s.tests { - sut.increaseConnectionAttemptCounter(ski) + s.sut.increaseConnectionAttemptCounter(ski) - sut.muxConAttempt.Lock() - counter, exists := sut.connectionAttemptCounter[ski] + s.sut.muxConAttempt.Lock() + counter, exists := s.sut.connectionAttemptCounter[ski] timeRange := connectionInitiationDelayTimeRanges[counter] - sut.muxConAttempt.Unlock() + s.sut.muxConAttempt.Unlock() assert.Equal(s.T(), true, exists) assert.Equal(s.T(), test.timeRange.min, timeRange.min) @@ -302,153 +357,103 @@ func (s *HubSuite) Test_IncreaseConnectionAttemptCounter() { } func (s *HubSuite) Test_RemoveConnectionAttemptCounter() { - // we just need a dummy for this test - sut := connectionsHubImpl{ - connectionAttemptCounter: make(map[string]int), - } ski := "test" - sut.increaseConnectionAttemptCounter(ski) - _, exists := sut.connectionAttemptCounter[ski] + s.sut.increaseConnectionAttemptCounter(ski) + _, exists := s.sut.connectionAttemptCounter[ski] assert.Equal(s.T(), true, exists) - sut.removeConnectionAttemptCounter(ski) - _, exists = sut.connectionAttemptCounter[ski] + s.sut.removeConnectionAttemptCounter(ski) + _, exists = s.sut.connectionAttemptCounter[ski] assert.Equal(s.T(), false, exists) } func (s *HubSuite) Test_GetCurrentConnectionAttemptCounter() { - // we just need a dummy for this test - sut := connectionsHubImpl{ - connectionAttemptCounter: make(map[string]int), - } ski := "test" - sut.increaseConnectionAttemptCounter(ski) - _, exists := sut.connectionAttemptCounter[ski] + s.sut.increaseConnectionAttemptCounter(ski) + _, exists := s.sut.connectionAttemptCounter[ski] assert.Equal(s.T(), exists, true) - sut.increaseConnectionAttemptCounter(ski) + s.sut.increaseConnectionAttemptCounter(ski) - value, exists := sut.getCurrentConnectionAttemptCounter(ski) + value, exists := s.sut.getCurrentConnectionAttemptCounter(ski) assert.Equal(s.T(), 1, value) assert.Equal(s.T(), true, exists) } func (s *HubSuite) Test_GetConnectionInitiationDelayTime() { - // we just need a dummy for this test - ski := "12af9e" - localService := NewServiceDetails(ski) - sut := connectionsHubImpl{ - localService: localService, - connectionAttemptCounter: make(map[string]int), - } + ski := "test" - counter, duration := sut.getConnectionInitiationDelayTime(ski) + counter, duration := s.sut.getConnectionInitiationDelayTime(ski) assert.Equal(s.T(), 0, counter) assert.LessOrEqual(s.T(), float64(s.tests[counter].timeRange.min), float64(duration/time.Second)) assert.GreaterOrEqual(s.T(), float64(s.tests[counter].timeRange.max), float64(duration/time.Second)) } func (s *HubSuite) Test_ConnectionAttemptRunning() { - // we just need a dummy for this test ski := "test" - sut := connectionsHubImpl{ - connectionAttemptRunning: make(map[string]bool), - } - sut.setConnectionAttemptRunning(ski, true) - status := sut.isConnectionAttemptRunning(ski) + s.sut.setConnectionAttemptRunning(ski, true) + status := s.sut.isConnectionAttemptRunning(ski) assert.Equal(s.T(), true, status) - sut.setConnectionAttemptRunning(ski, false) - status = sut.isConnectionAttemptRunning(ski) + s.sut.setConnectionAttemptRunning(ski, false) + status = s.sut.isConnectionAttemptRunning(ski) assert.Equal(s.T(), false, status) } func (s *HubSuite) Test_InitiatePairingWithSKI() { ski := "test" - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - connectionAttemptRunning: make(map[string]bool), - remoteServices: make(map[string]*ServiceDetails), - serviceProvider: s.serviceProvider, - mdns: s.mdnsService, - } - sut.InitiatePairingWithSKI(ski) - assert.Equal(s.T(), 0, len(sut.connections)) + s.sut.InitiatePairingWithSKI(ski) + assert.Equal(s.T(), 0, len(s.sut.connections)) con := &ship.ShipConnection{ RemoteSKI: ski, } - sut.registerConnection(con) - sut.InitiatePairingWithSKI(ski) - assert.Equal(s.T(), 1, len(sut.connections)) + s.sut.registerConnection(con) + s.sut.InitiatePairingWithSKI(ski) + assert.Equal(s.T(), 1, len(s.sut.connections)) } func (s *HubSuite) Test_CancelPairingWithSKI() { ski := "test" - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - connectionAttemptRunning: make(map[string]bool), - remoteServices: make(map[string]*ServiceDetails), - serviceProvider: s.serviceProvider, - mdns: s.mdnsService, - } - sut.CancelPairingWithSKI(ski) - assert.Equal(s.T(), 0, len(sut.connections)) - assert.Equal(s.T(), 0, len(sut.connectionAttemptRunning)) + s.sut.CancelPairingWithSKI(ski) + assert.Equal(s.T(), 0, len(s.sut.connections)) + assert.Equal(s.T(), 0, len(s.sut.connectionAttemptRunning)) con := &ship.ShipConnection{ RemoteSKI: ski, } - sut.registerConnection(con) - assert.Equal(s.T(), 1, len(sut.connections)) + s.sut.registerConnection(con) + assert.Equal(s.T(), 1, len(s.sut.connections)) - sut.CancelPairingWithSKI(ski) - assert.Equal(s.T(), 0, len(sut.connectionAttemptRunning)) + s.sut.CancelPairingWithSKI(ski) + assert.Equal(s.T(), 0, len(s.sut.connectionAttemptRunning)) } func (s *HubSuite) Test_ReportMdnsEntries() { - localService := &ServiceDetails{ - SKI: "localSKI", - } - sut := connectionsHubImpl{ - connections: make(map[string]*ship.ShipConnection), - connectionAttemptCounter: make(map[string]int), - connectionAttemptRunning: make(map[string]bool), - remoteServices: make(map[string]*ServiceDetails), - serviceProvider: s.serviceProvider, - localService: localService, - mdns: s.mdnsService, - } - - certificate, _ := CreateCertificate("unit", "org", "DE", "CN") - sut.configuration, _ = NewConfiguration("vendor", "brand", "model", "serial", - model.DeviceTypeTypeGeneric, []model.EntityTypeType{model.EntityTypeTypeCEM}, - 4567, certificate, 230, time.Second*4) - testski1 := "test1" testski2 := "test2" entries := make(map[string]*MdnsEntry) s.serviceProvider.EXPECT().VisibleMDNSRecordsUpdated(gomock.Any()).AnyTimes() - sut.ReportMdnsEntries(entries) + s.sut.ReportMdnsEntries(entries) entries[testski1] = &MdnsEntry{ Ski: testski1, } - service1 := sut.ServiceForSKI(testski1) + service1 := s.sut.ServiceForSKI(testski1) service1.Trusted = true service1.IPv4 = "127.0.0.1" entries[testski2] = &MdnsEntry{ Ski: testski2, } - service2 := sut.ServiceForSKI(testski2) + service2 := s.sut.ServiceForSKI(testski2) service2.Trusted = true service2.IPv4 = "127.0.0.1" - sut.ReportMdnsEntries(entries) + s.sut.ReportMdnsEntries(entries) } diff --git a/service/types.go b/service/types.go index aa537f54..4b5b4a83 100644 --- a/service/types.go +++ b/service/types.go @@ -101,7 +101,9 @@ type ServiceDetails struct { Trusted bool // the current connection state details - ConnectionStateDetail *ConnectionStateDetail + connectionStateDetail *ConnectionStateDetail + + mux sync.Mutex } // create a new ServiceDetails record with a SKI @@ -109,12 +111,26 @@ func NewServiceDetails(ski string) *ServiceDetails { connState := NewConnectionStateDetail(ConnectionStateNone, nil) service := &ServiceDetails{ SKI: util.NormalizeSKI(ski), // standardize the provided SKI strings - ConnectionStateDetail: connState, + connectionStateDetail: connState, } return service } +func (s *ServiceDetails) ConnectionStateDetail() *ConnectionStateDetail { + s.mux.Lock() + defer s.mux.Unlock() + + return s.connectionStateDetail +} + +func (s *ServiceDetails) SetConnectionStateDetail(detail *ConnectionStateDetail) { + s.mux.Lock() + defer s.mux.Unlock() + + s.connectionStateDetail = detail +} + // defines requires meta information about this service type Configuration struct { // The vendors IANA PEN, optional but highly recommended. diff --git a/service/types_test.go b/service/types_test.go index 28db095c..29ee5de7 100644 --- a/service/types_test.go +++ b/service/types_test.go @@ -40,6 +40,12 @@ func (s *TypesSuite) Test_ServiceDetails() { details := NewServiceDetails(testSki) assert.NotNil(s.T(), details) + + conState := NewConnectionStateDetail(ConnectionStateNone, nil) + details.SetConnectionStateDetail(conState) + + state := details.ConnectionStateDetail() + assert.Equal(s.T(), ConnectionStateNone, state.State()) } func (s *TypesSuite) Test_Configuration() { diff --git a/ship/connection.go b/ship/connection.go index 583aecd7..6434943e 100644 --- a/ship/connection.go +++ b/ship/connection.go @@ -132,6 +132,9 @@ func (c *ShipConnection) AbortPendingHandshake() { // report removing a connection func (c *ShipConnection) removeRemoteDeviceConnection() { + if c.deviceLocalCon == nil { + return + } c.deviceLocalCon.RemoveRemoteDeviceConnection(c.RemoteSKI) }