Skip to content

Commit

Permalink
Refactor logging to make it thread safe
Browse files Browse the repository at this point in the history
  • Loading branch information
DerAndereAndi committed Jan 5, 2024
1 parent 7590ab7 commit e0c054c
Show file tree
Hide file tree
Showing 20 changed files with 81 additions and 68 deletions.
17 changes: 15 additions & 2 deletions logging/log.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package logging

import "sync"

//go:generate mockery --name=Logging

// Logging needs to be implemented, if the internal logs should be printed
Expand All @@ -26,7 +28,8 @@ func (l *NoLogging) Infof(format string, args ...interface{}) {}
func (l *NoLogging) Error(args ...interface{}) {}
func (l *NoLogging) Errorf(format string, args ...interface{}) {}

var Log Logging = &NoLogging{}
var log Logging = &NoLogging{}
var mux sync.Mutex

// Sets a custom logging implementation
// By default NoLogging is used, so no logs are printed
Expand All @@ -35,5 +38,15 @@ func SetLogging(logger Logging) {
if logger == nil {
return
}
Log = logger
mux.Lock()
defer mux.Unlock()

log = logger
}

func Log() Logging {
mux.Lock()
defer mux.Unlock()

return log
}
34 changes: 17 additions & 17 deletions service/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ func newConnectionsHub(serviceProvider ServiceProvider, mdns MdnsService, spineL
func (h *connectionsHubImpl) Start() {
// start the websocket server
if err := h.startWebsocketServer(); err != nil {
logging.Log.Debug("error during websocket server starting:", err)
logging.Log().Debug("error during websocket server starting:", err)
}

// start mDNS
err := h.mdns.SetupMdnsService()
if err != nil {
logging.Log.Debug("error during mdns setup:", err)
logging.Log().Debug("error during mdns setup:", err)
}

h.checkRestartMdnsSearch()
Expand Down Expand Up @@ -398,7 +398,7 @@ func (h *connectionsHubImpl) verifyPeerCertificate(rawCerts [][]byte, verifiedCh
// start the ship websocket server
func (h *connectionsHubImpl) startWebsocketServer() error {
addr := fmt.Sprintf(":%d", h.configuration.port)
logging.Log.Debug("starting websocket server on", addr)
logging.Log().Debug("starting websocket server on", addr)

h.httpServer = &http.Server{
Addr: addr,
Expand All @@ -413,7 +413,7 @@ func (h *connectionsHubImpl) startWebsocketServer() error {

go func() {
if err := h.httpServer.ListenAndServeTLS("", ""); err != nil {
logging.Log.Debug("websocket server error:", err)
logging.Log().Debug("websocket server error:", err)
// TODO: decide how to handle this case
}
}()
Expand All @@ -434,34 +434,34 @@ func (h *connectionsHubImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) {

conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logging.Log.Debug("error during connection upgrading:", err)
logging.Log().Debug("error during connection upgrading:", err)
return
}

// check if the client supports the ship sub protocol
if conn.Subprotocol() != shipWebsocketSubProtocol {
logging.Log.Debug("client does not support the ship sub protocol")
logging.Log().Debug("client does not support the ship sub protocol")
_ = conn.Close()
return
}

// check if the clients certificate provides a SKI
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
logging.Log.Debug("client does not provide a certificate")
logging.Log().Debug("client does not provide a certificate")
_ = conn.Close()
return
}

ski, err := skiFromCertificate(r.TLS.PeerCertificates[0])
if err != nil {
logging.Log.Debug(err)
logging.Log().Debug(err)
_ = conn.Close()
return
}

// normalize the incoming SKI
remoteService := NewServiceDetails(ski)
logging.Log.Debug("incoming connection request from", remoteService.SKI)
logging.Log().Debug("incoming connection request from", remoteService.SKI)

// Check if the remote service is paired
service := h.ServiceForSKI(remoteService.SKI)
Expand Down Expand Up @@ -495,7 +495,7 @@ func (h *connectionsHubImpl) connectFoundService(remoteService *ServiceDetails,
return nil
}

logging.Log.Debugf("initiating connection to %s at %s:%s", remoteService.SKI, host, port)
logging.Log().Debugf("initiating connection to %s at %s:%s", remoteService.SKI, host, port)

dialer := &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
Expand Down Expand Up @@ -584,14 +584,14 @@ func (h *connectionsHubImpl) keepThisConnection(conn *websocket.Conn, incomingRe
if keep {
// we have an existing connection
// so keep the new (most recent) and close the old one
logging.Log.Debug("closing existing double connection")
logging.Log().Debug("closing existing double connection")
go existingC.CloseConnection(false, 0, "")
} else {
connType := "incoming"
if !incomingRequest {
connType = "outgoing"
}
logging.Log.Debugf("closing %s double connection, as the existing connection will be used", connType)
logging.Log().Debugf("closing %s double connection, as the existing connection will be used", connType)
if conn != nil {
go func() {
_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "double connection"))
Expand Down Expand Up @@ -741,7 +741,7 @@ func (h *connectionsHubImpl) coordinateConnectionInitations(ski string, entry *M
return
}

logging.Log.Debugf("delaying connection to %s by %s to minimize double connection probability", ski, duration)
logging.Log().Debugf("delaying connection to %s by %s to minimize double connection probability", ski, duration)

// we do not stop this thread and just let the timer run out
// otherwise we would need a stop channel for each ski
Expand Down Expand Up @@ -799,19 +799,19 @@ func (h *connectionsHubImpl) initateConnection(remoteService *ServiceDetails, en
return false
}

logging.Log.Debug("trying to connect to", remoteService.SKI, "at", address)
logging.Log().Debug("trying to connect to", remoteService.SKI, "at", address)
if err = h.connectFoundService(remoteService, address.String(), strconv.Itoa(entry.Port)); err != nil {
logging.Log.Debug("connection to", remoteService.SKI, "failed: ", err)
logging.Log().Debug("connection to", remoteService.SKI, "failed: ", err)
} else {
return true
}
}

// connectdion via IP address failed, try hostname
if len(entry.Host) > 0 {
logging.Log.Debug("trying to connect to", remoteService.SKI, "at", entry.Host)
logging.Log().Debug("trying to connect to", remoteService.SKI, "at", entry.Host)
if err = h.connectFoundService(remoteService, entry.Host, strconv.Itoa(entry.Port)); err != nil {
logging.Log.Debugf("connection to %s failed: %s", remoteService.SKI, err)
logging.Log().Debugf("connection to %s failed: %s", remoteService.SKI, err)
} else {
return true
}
Expand Down
12 changes: 6 additions & 6 deletions service/mdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ func (m *mdnsManager) AnnounceMdnsEntry() error {
"register=" + fmt.Sprintf("%v", m.configuration.registerAutoAccept),
}

logging.Log.Debug("mdns: announce")
logging.Log().Debug("mdns: announce")

serviceName := m.configuration.MdnsServiceName()

if err := m.mdnsProvider.Announce(serviceName, m.configuration.port, txt); err != nil {
logging.Log.Debug("mdns: failure announcing service", err)
logging.Log().Debug("mdns: failure announcing service", err)
return err
}

Expand All @@ -183,7 +183,7 @@ func (m *mdnsManager) UnannounceMdnsEntry() {
}

m.mdnsProvider.Unannounce()
logging.Log.Debug("mdns: stop announcement")
logging.Log().Debug("mdns: stop announcement")

m.isAnnounced = false
}
Expand Down Expand Up @@ -289,7 +289,7 @@ func (m *mdnsManager) resolveEntries() {
return
}
go func() {
logging.Log.Debug("mdns: start search")
logging.Log().Debug("mdns: start search")
m.mdnsProvider.ResolveEntries(m.cancelChan, m.processMdnsEntry)

m.setIsSearchingServices(false)
Expand All @@ -306,7 +306,7 @@ func (m *mdnsManager) stopResolvingEntries() {
return
}

logging.Log.Debug("mdns: stop search")
logging.Log().Debug("mdns: stop search")

m.cancelChan <- true
}
Expand Down Expand Up @@ -407,7 +407,7 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st
}
m.setMdnsEntry(ski, newEntry)

logging.Log.Debug("ski:", ski, "name:", name, "brand:", brand, "model:", model, "typ:", deviceType, "identifier:", identifier, "register:", register, "host:", host, "port:", port, "addresses:", addresses)
logging.Log().Debug("ski:", ski, "name:", name, "brand:", brand, "model:", model, "typ:", deviceType, "identifier:", identifier, "register:", register, "host:", host, "port:", port, "addresses:", addresses)
} else {
return
}
Expand Down
12 changes: 6 additions & 6 deletions service/mdns/avahi.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (a *AvahiProvider) Shutdown() {
}

func (a *AvahiProvider) Announce(serviceName string, port int, txt []string) error {
logging.Log.Debug("mdns: using avahi")
logging.Log().Debug("mdns: using avahi")

entryGroup, err := a.avServer.EntryGroupNew()
if err != nil {
Expand Down Expand Up @@ -108,12 +108,12 @@ func (a *AvahiProvider) ResolveEntries(cancelChan chan bool, callback func(eleme

// instead of limiting search on specific allowed interfaces, we allow all and filter the results
if avBrowser, err = a.avServer.ServiceBrowserNew(avahi.InterfaceUnspec, avahi.ProtoUnspec, shipZeroConfServiceType, shipZeroConfDomain, 0); err != nil {
logging.Log.Debug("mdns: error setting up avahi browser:", err)
logging.Log().Debug("mdns: error setting up avahi browser:", err)
return
}

if avBrowser == nil {
logging.Log.Debug("mdns: avahi browser is not available")
logging.Log().Debug("mdns: avahi browser is not available")
return
}

Expand Down Expand Up @@ -149,13 +149,13 @@ func (a *AvahiProvider) processService(service avahi.Service, remove bool, callb
}

if !allow {
logging.Log.Debug("avahi - ignoring service as its interface is not in the allowed list:", service.Name)
logging.Log().Debug("avahi - ignoring service as its interface is not in the allowed list:", service.Name)
return
}

resolved, err := a.avServer.ResolveService(service.Interface, service.Protocol, service.Name, service.Type, service.Domain, avahi.ProtoUnspec, 0)
if err != nil {
logging.Log.Debug("avahi - error resolving service:", service, "error:", err)
logging.Log().Debug("avahi - error resolving service:", service, "error:", err)
return
}

Expand All @@ -170,7 +170,7 @@ func (a *AvahiProvider) processService(service avahi.Service, remove bool, callb
address := net.ParseIP(resolved.Address)
// if the address can not be used, ignore the entry
if address == nil || address.IsUnspecified() {
logging.Log.Debug("avahi - service provides unusable address:", service.Name)
logging.Log().Debug("avahi - service provides unusable address:", service.Name)
return
}

Expand Down
2 changes: 1 addition & 1 deletion service/mdns/zeroconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (z *ZeroconfProvider) CheckAvailability() bool {
func (z *ZeroconfProvider) Shutdown() {}

func (z *ZeroconfProvider) Announce(serviceName string, port int, txt []string) error {
logging.Log.Debug("mdns: using zeroconf")
logging.Log().Debug("mdns: using zeroconf")

// use Zeroconf library if avahi is not available
// Set TTL to 2 minutes as defined in SHIP chapter 7
Expand Down
2 changes: 1 addition & 1 deletion service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (s *EEBUSService) Setup() error {
s.LocalService.DeviceType = sd.deviceType
s.LocalService.RegisterAutoAccept = sd.registerAutoAccept

logging.Log.Info("Local SKI: ", ski)
logging.Log().Info("Local SKI: ", ski)

vendor := sd.vendorCode
if vendor == "" {
Expand Down
6 changes: 3 additions & 3 deletions service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ func (s *ServiceSuite) Test_ConnectionsHub() {

func (s *ServiceSuite) Test_SetLogging() {
s.sut.SetLogging(nil)
assert.Equal(s.T(), &logging.NoLogging{}, logging.Log)
assert.Equal(s.T(), &logging.NoLogging{}, logging.Log())

s.sut.SetLogging(s.logging)
assert.Equal(s.T(), s.logging, logging.Log)
assert.Equal(s.T(), s.logging, logging.Log())

s.sut.SetLogging(&logging.NoLogging{})
assert.Equal(s.T(), &logging.NoLogging{}, logging.Log)
assert.Equal(s.T(), &logging.NoLogging{}, logging.Log())
}

func (s *ServiceSuite) Test_Setup() {
Expand Down
8 changes: 4 additions & 4 deletions ship/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ var _ spine.SpineDataConnection = (*ShipConnectionImpl)(nil)
// SpineDataConnection interface implementation
func (c *ShipConnectionImpl) WriteSpineMessage(message []byte) {
if err := c.sendSpineData(message); err != nil {
logging.Log.Debug(c.RemoteSKI, "Error sending spine message: ", err)
logging.Log().Debug(c.RemoteSKI, "Error sending spine message: ", err)
return
}
}
Expand All @@ -210,13 +210,13 @@ func (c *ShipConnectionImpl) shipModelFromMessage(message []byte) (*model.ShipDa
// Get the datagram from the message
data := model.ShipData{}
if err := json.Unmarshal(jsonData, &data); err != nil {
logging.Log.Debug(c.RemoteSKI, "error unmarshalling message: ", err)
logging.Log().Debug(c.RemoteSKI, "error unmarshalling message: ", err)
return nil, err
}

if data.Data.Payload == nil {
errorMsg := "received no valid payload"
logging.Log.Debug(c.RemoteSKI, errorMsg)
logging.Log().Debug(c.RemoteSKI, errorMsg)
return nil, errors.New(errorMsg)
}

Expand Down Expand Up @@ -342,7 +342,7 @@ func (c *ShipConnectionImpl) sendSpineData(data []byte) error {

err = c.dataHandler.WriteMessageToDataConnection(shipMsg)
if err != nil {
logging.Log.Debug("error sending message: ", err)
logging.Log().Debug("error sending message: ", err)
return err
}

Expand Down
4 changes: 2 additions & 2 deletions ship/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (c *ShipConnectionImpl) getState() ShipMessageExchangeState {
func (c *ShipConnectionImpl) handleState(timeout bool, message []byte) {
switch c.getState() {
case SmeStateError:
logging.Log.Debug(c.RemoteSKI, "connection is in error state")
logging.Log().Debug(c.RemoteSKI, "connection is in error state")
return

// cmiStateInit
Expand Down Expand Up @@ -210,7 +210,7 @@ func (c *ShipConnectionImpl) endHandshakeWithError(err error) {

c.setState(SmeStateError, err)

logging.Log.Debug(c.RemoteSKI, "SHIP handshake error:", err)
logging.Log().Debug(c.RemoteSKI, "SHIP handshake error:", err)

c.CloseConnection(true, 0, err.Error())

Expand Down
4 changes: 2 additions & 2 deletions ship/hs_hello.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (c *ShipConnectionImpl) handshakeHello_ReadyListen(timeout bool, message []

default:
// don't accept any other responses
logging.Log.Errorf("Unexpected connection hello phase: %s", hello.Phase)
logging.Log().Errorf("Unexpected connection hello phase: %s", hello.Phase)
c.setAndHandleState(SmeHelloStateAbort)
return
}
Expand Down Expand Up @@ -201,7 +201,7 @@ func (c *ShipConnectionImpl) handshakeHello_PendingListen(timeout bool, message

default:
// don't accept any other responses
logging.Log.Errorf("Unexpected connection hello phase: %s", hello.Phase)
logging.Log().Errorf("Unexpected connection hello phase: %s", hello.Phase)
c.setAndHandleState(SmeHelloStateAbort)
return
}
Expand Down
Loading

0 comments on commit e0c054c

Please sign in to comment.