Skip to content

Commit

Permalink
Improve tests and fix a few bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
DerAndereAndi committed Jan 5, 2024
1 parent 36d099b commit 0141f27
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 231 deletions.
75 changes: 74 additions & 1 deletion service/cert_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package service

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -31,5 +39,70 @@ func (c *CertSuite) Test_SkiFromCertificate() {

ski, err := skiFromCertificate(leaf)
assert.Nil(c.T(), err)
assert.NotNil(c.T(), ski)
assert.NotEqual(c.T(), "", ski)

cert, err = CreateInvalidCertificate("unit", "org", "DE", "CN")
assert.Nil(c.T(), err)

leaf, err = x509.ParseCertificate(cert.Certificate[0])
assert.Nil(c.T(), err)

ski, err = skiFromCertificate(leaf)
assert.NotNil(c.T(), err)
assert.Equal(c.T(), "", ski)
}

func CreateInvalidCertificate(organizationalUnit, organization, country, commonName string) (tls.Certificate, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}

// Create the EEBUS service SKI using the private key
asn1, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return tls.Certificate{}, err
}
// SHIP 12.2: Required to be created according to RFC 3280 4.2.1.2
ski := sha1.Sum(asn1)

subject := pkix.Name{
OrganizationalUnit: []string{organizationalUnit},
Organization: []string{organization},
Country: []string{country},
CommonName: commonName,
}

// Create a random serial big int value
maxValue := new(big.Int)
maxValue.Exp(big.NewInt(2), big.NewInt(130), nil).Sub(maxValue, big.NewInt(1))
serialNumber, err := rand.Int(rand.Reader, maxValue)
if err != nil {
return tls.Certificate{}, err
}

template := x509.Certificate{
SignatureAlgorithm: x509.ECDSAWithSHA256,
SerialNumber: serialNumber,
Subject: subject,
NotBefore: time.Now(), // Valid starting now
NotAfter: time.Now().Add(time.Hour * 24 * 365 * 10), // Valid for 10 years
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
IsCA: true,
SubjectKeyId: ski[:19],
}

certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return tls.Certificate{}, err
}

tlsCertificate := tls.Certificate{
Certificate: [][]byte{certBytes},
PrivateKey: privateKey,
SupportedSignatureAlgorithms: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256},
}

return tlsCertificate, nil
}
86 changes: 44 additions & 42 deletions service/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ type connectionsHubImpl struct {
knownMdnsEntries []*MdnsEntry

// the SPINE local device
spineLocalDevice *spine.DeviceLocalImpl
spineLocalDevice spine.DeviceLocalConnection

muxCon sync.Mutex
muxConAttempt sync.Mutex
muxReg sync.Mutex
muxMdns sync.Mutex
}

func newConnectionsHub(serviceProvider ServiceProvider, mdns MdnsService, spineLocalDevice *spine.DeviceLocalImpl, configuration *Configuration, localService *ServiceDetails) ConnectionsHub {
func newConnectionsHub(serviceProvider ServiceProvider, mdns MdnsService, spineLocalDevice spine.DeviceLocalConnection, configuration *Configuration, localService *ServiceDetails) ConnectionsHub {
hub := &connectionsHubImpl{
connections: make(map[string]*ship.ShipConnection),
connectionAttemptCounter: make(map[string]int),
Expand Down Expand Up @@ -257,9 +257,9 @@ func (h *connectionsHubImpl) HandleShipHandshakeStateUpdate(ski string, state sh

service := h.ServiceForSKI(ski)

existingDetails := service.ConnectionStateDetail
existingDetails := service.ConnectionStateDetail()
if existingDetails.State() != pairingState || existingDetails.Error() != state.Error {
service.ConnectionStateDetail = pairingDetail
service.SetConnectionStateDetail(pairingDetail)

h.serviceProvider.ServicePairingDetailUpdate(ski, pairingDetail)
}
Expand All @@ -280,7 +280,7 @@ func (h *connectionsHubImpl) PairingDetailForSki(ski string) *ConnectionStateDet
return NewConnectionStateDetail(state, shipError)
}

return service.ConnectionStateDetail
return service.ConnectionStateDetail()
}

// maps ShipMessageExchangeState to PairingState
Expand Down Expand Up @@ -375,6 +375,25 @@ func (h *connectionsHubImpl) isSkiConnected(ski string) bool {
}

// Websocket connection handling
func (h *connectionsHubImpl) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
skiFound := false
for _, v := range rawCerts {
cert, err := x509.ParseCertificate(v)
if err != nil {
return err
}

if _, err := skiFromCertificate(cert); err == nil {
skiFound = true
break
}
}
if !skiFound {
return errors.New("no valid SKI provided in certificate")
}

return nil
}

// start the ship websocket server
func (h *connectionsHubImpl) startWebsocketServer() error {
Expand All @@ -385,28 +404,10 @@ func (h *connectionsHubImpl) startWebsocketServer() error {
Addr: addr,
Handler: h,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{h.configuration.certificate},
ClientAuth: tls.RequireAnyClientCert, // SHIP 9: Client authentication is required
CipherSuites: ciperSuites,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
skiFound := false
for _, v := range rawCerts {
cert, err := x509.ParseCertificate(v)
if err != nil {
return err
}

if _, err := skiFromCertificate(cert); err == nil {
skiFound = true
break
}
}
if !skiFound {
return errors.New("no valid SKI provided in certificate")
}

return nil
},
Certificates: []tls.Certificate{h.configuration.certificate},
ClientAuth: tls.RequireAnyClientCert, // SHIP 9: Client authentication is required
CipherSuites: ciperSuites,
VerifyPeerCertificate: h.verifyPeerCertificate,
},
}

Expand Down Expand Up @@ -445,7 +446,7 @@ func (h *connectionsHubImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// check if the clients certificate provides a SKI
if len(r.TLS.PeerCertificates) == 0 {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
logging.Log.Debug("client does not provide a certificate")
_ = conn.Close()
return
Expand All @@ -464,9 +465,10 @@ func (h *connectionsHubImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Check if the remote service is paired
service := h.ServiceForSKI(remoteService.SKI)
if service.ConnectionStateDetail.State() == ConnectionStateQueued {
service.ConnectionStateDetail.SetState(ConnectionStateReceivedPairingRequest)
h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)
connectionStateDetail := service.ConnectionStateDetail()
if connectionStateDetail.State() == ConnectionStateQueued {
connectionStateDetail.SetState(ConnectionStateReceivedPairingRequest)
h.serviceProvider.ServicePairingDetailUpdate(ski, connectionStateDetail)
}

remoteService = service
Expand Down Expand Up @@ -608,7 +610,7 @@ func (h *connectionsHubImpl) ServiceForSKI(ski string) *ServiceDetails {
service, ok := h.remoteServices[ski]
if !ok {
service = NewServiceDetails(ski)
service.ConnectionStateDetail.SetState(ConnectionStateNone)
service.ConnectionStateDetail().SetState(ConnectionStateNone)
h.remoteServices[ski] = service
}

Expand All @@ -629,9 +631,9 @@ func (h *connectionsHubImpl) RegisterRemoteSKI(ski string, enable bool) {

h.removeConnectionAttemptCounter(ski)

service.ConnectionStateDetail.SetState(ConnectionStateNone)
service.ConnectionStateDetail().SetState(ConnectionStateNone)

h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)
h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail())

if existingC := h.connectionForSKI(ski); existingC != nil {
existingC.CloseConnection(true, 4500, "User close")
Expand All @@ -651,9 +653,9 @@ func (h *connectionsHubImpl) InitiatePairingWithSKI(ski string) {

// locally initiated
service := h.ServiceForSKI(ski)
service.ConnectionStateDetail.SetState(ConnectionStateQueued)
service.ConnectionStateDetail().SetState(ConnectionStateQueued)

h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)
h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail())

// initiate a search and also a connection if it does not yet exist
if !h.isSkiConnected(service.SKI) {
Expand All @@ -670,10 +672,10 @@ func (h *connectionsHubImpl) CancelPairingWithSKI(ski string) {
}

service := h.ServiceForSKI(ski)
service.ConnectionStateDetail.SetState(ConnectionStateNone)
service.ConnectionStateDetail().SetState(ConnectionStateNone)
service.Trusted = false

h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)
h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail())
}

// Process reported mDNS services
Expand All @@ -694,7 +696,7 @@ func (h *connectionsHubImpl) ReportMdnsEntries(entries map[string]*MdnsEntry) {
// Check if the remote service is paired or queued for connection
service := h.ServiceForSKI(ski)
if !h.IsRemoteServiceForSKIPaired(ski) &&
service.ConnectionStateDetail.State() != ConnectionStateQueued {
service.ConnectionStateDetail().State() != ConnectionStateQueued {
continue
}

Expand Down Expand Up @@ -732,7 +734,7 @@ func (h *connectionsHubImpl) coordinateConnectionInitations(ski string, entry *M
counter, duration := h.getConnectionInitiationDelayTime(ski)

service := h.ServiceForSKI(ski)
if service.ConnectionStateDetail.State() == ConnectionStateQueued {
if service.ConnectionStateDetail().State() == ConnectionStateQueued {
go h.prepareConnectionInitation(ski, counter, entry)
return
}
Expand Down Expand Up @@ -762,7 +764,7 @@ func (h *connectionsHubImpl) prepareConnectionInitation(ski string, counter int,

// connection attempt is not relevant if the device is no longer paired
// or it is not queued for pairing
pairingState := h.ServiceForSKI(ski).ConnectionStateDetail.State()
pairingState := h.ServiceForSKI(ski).ConnectionStateDetail().State()
if !h.IsRemoteServiceForSKIPaired(ski) && pairingState != ConnectionStateQueued {
return
}
Expand Down Expand Up @@ -790,7 +792,7 @@ func (h *connectionsHubImpl) initateConnection(remoteService *ServiceDetails, en
for _, address := range entry.Addresses {
// connection attempt is not relevant if the device is no longer paired
// or it is not queued for pairing
pairingState := h.ServiceForSKI(remoteService.SKI).ConnectionStateDetail.State()
pairingState := h.ServiceForSKI(remoteService.SKI).ConnectionStateDetail().State()
if !h.IsRemoteServiceForSKIPaired(remoteService.SKI) && pairingState != ConnectionStateQueued {
return false
}
Expand Down
Loading

0 comments on commit 0141f27

Please sign in to comment.