From ec7a7e08e892c866a917cc11ba4bec984cd47e27 Mon Sep 17 00:00:00 2001 From: Joshua Rich Date: Tue, 24 Oct 2023 23:17:04 +1000 Subject: [PATCH] fix(agent,hass,device): better clean-up on agent quit/cancellation --- internal/agent/agent.go | 30 +++++++------- internal/agent/notifications.go | 10 +++-- internal/agent/ui/fyneUI/fyneUI.go | 20 ++++------ internal/device/helpers/polling.go | 2 +- internal/hass/api/websocket.go | 64 +++++++++++++++++------------- 5 files changed, 68 insertions(+), 58 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 5f6dcb9ed..aa2694a86 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -74,7 +74,6 @@ func Run(options AgentOptions) { var err error agent := newAgent(&options) - defer close(agent.done) // Pre-flight: check if agent is registered. If not, run registration flow. var regWait sync.WaitGroup @@ -99,7 +98,6 @@ func Run(options AgentOptions) { log.Warn().Err(err).Msg("Unable to set config version to app version.") } ctx, cancelFunc = agent.setupContext() - agent.handleCancellation(ctx) }() // Start main work goroutines @@ -125,13 +123,22 @@ func Run(options AgentOptions) { }() }() - agent.handleSignals() - agent.handleShutdown() + go func() { + <-agent.done + log.Debug().Msg("Agent done.") + cancelFunc() + }() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + defer close(agent.done) + <-c + log.Debug().Msg("Ctrl-C pressed.") + }() + agent.ui.DisplayTrayIcon(agent) agent.ui.Run() - defer cancelFunc() - - wg.Wait() } // Register runs a registration flow. It either prompts the user for needed @@ -227,14 +234,6 @@ func (agent *Agent) handleShutdown() { }() } -func (agent *Agent) handleCancellation(ctx context.Context) { - go func() { - <-ctx.Done() - log.Debug().Msg("Context canceled.") - os.Exit(1) - }() -} - // Agent satisfies ui.Agent, tracker.Agent and api.Agent interfaces func (agent *Agent) IsHeadless() bool { @@ -254,6 +253,7 @@ func (agent *Agent) AppVersion() string { } func (agent *Agent) Stop() { + log.Debug().Msg("Stopping agent.") close(agent.done) } diff --git a/internal/agent/notifications.go b/internal/agent/notifications.go index 735314bf4..75b7ab925 100644 --- a/internal/agent/notifications.go +++ b/internal/agent/notifications.go @@ -40,9 +40,13 @@ func (agent *Agent) runNotificationsWorker(ctx context.Context, options AgentOpt go func() { defer wg.Done() for { - restartCh := make(chan struct{}) - api.StartWebsocket(ctx, agent, notifyCh, restartCh) - <-restartCh + select { + case <-ctx.Done(): + log.Debug().Msg("Stopping websocket.") + return + default: + api.StartWebsocket(ctx, agent, notifyCh) + } } }() diff --git a/internal/agent/ui/fyneUI/fyneUI.go b/internal/agent/ui/fyneUI/fyneUI.go index 79edd89bd..648cfe6ba 100644 --- a/internal/agent/ui/fyneUI/fyneUI.go +++ b/internal/agent/ui/fyneUI/fyneUI.go @@ -35,9 +35,8 @@ server (if not auto-detected) and long-lived access token.` ) type fyneUI struct { - app fyne.App - mainWindow fyne.Window - text *translations.Translator + app fyne.App + text *translations.Translator } func (i *fyneUI) Run() { @@ -79,10 +78,6 @@ func NewFyneUI(agent ui.Agent) *fyneUI { text: translations.NewTranslator(), } i.app.SetIcon(&ui.TrayIcon{}) - i.mainWindow = i.app.NewWindow(agent.AppName()) - i.mainWindow.SetCloseIntercept(func() { - i.mainWindow.Hide() - }) return i } return &fyneUI{} @@ -97,6 +92,7 @@ func (i *fyneUI) DisplayTrayIcon(agent ui.Agent) { } if desk, ok := i.app.(desktop.App); ok { menuItemQuit := fyne.NewMenuItem(i.text.Translate("Quit"), func() { + i.app.Quit() agent.Stop() }) menuItemQuit.IsQuit = true @@ -144,29 +140,29 @@ func (i *fyneUI) DisplayTrayIcon(agent ui.Agent) { // complete registration. It will populate with any values that were already // provided via the command-line. func (i *fyneUI) DisplayRegistrationWindow(ctx context.Context, agent ui.Agent, done chan struct{}) { - i.mainWindow.SetTitle(i.text.Translate("App Registration")) + w := i.app.NewWindow(i.text.Translate("App Registration")) var allFormItems []*widget.FormItem allFormItems = append(allFormItems, i.serverConfigItems(ctx, agent, i.text)...) registrationForm := widget.NewForm(allFormItems...) registrationForm.OnSubmit = func() { - i.mainWindow.Hide() + w.Close() close(done) } registrationForm.OnCancel = func() { log.Warn().Msg("Canceling registration.") close(done) - i.mainWindow.Close() + w.Close() ctx.Done() } - i.mainWindow.SetContent(container.New(layout.NewVBoxLayout(), + w.SetContent(container.New(layout.NewVBoxLayout(), widget.NewLabel(i.text.Translate(explainRegistration)), registrationForm, )) log.Debug().Msg("Asking user for registration details.") - i.mainWindow.Show() + w.Show() } // aboutWindow creates a window that will show some interesting information diff --git a/internal/device/helpers/polling.go b/internal/device/helpers/polling.go index fe42ff787..1d4dc2dbf 100644 --- a/internal/device/helpers/polling.go +++ b/internal/device/helpers/polling.go @@ -27,10 +27,10 @@ func PollSensors(ctx context.Context, updater func(), interval, stdev time.Durat var wg sync.WaitGroup wg.Add(1) go func() { + defer wg.Done() for { select { case <-ctx.Done(): - wg.Done() return case <-ticker.C: updater() diff --git a/internal/hass/api/websocket.go b/internal/hass/api/websocket.go index faba36436..487aa099b 100644 --- a/internal/hass/api/websocket.go +++ b/internal/hass/api/websocket.go @@ -48,7 +48,7 @@ type websocketResponse struct { Success bool `json:"success,omitempty"` } -func StartWebsocket(ctx context.Context, settings Agent, notifyCh chan [2]string, doneCh chan struct{}) { +func StartWebsocket(ctx context.Context, settings Agent, notifyCh chan [2]string) { var websocketURL string if err := settings.GetConfig(config.PrefWebsocketURL, &websocketURL); err != nil { log.Warn().Err(err).Msg("Could not retrieve websocket URL from config.") @@ -60,7 +60,7 @@ func StartWebsocket(ctx context.Context, settings Agent, notifyCh chan [2]string retryFunc := func() error { var resp *http.Response socket, resp, err = gws.NewClient( - newWebsocket(ctx, settings, notifyCh, doneCh), + newWebsocket(ctx, settings, notifyCh), &gws.ClientOption{Addr: websocketURL}) if err != nil { log.Error().Err(err). @@ -77,7 +77,11 @@ func StartWebsocket(ctx context.Context, settings Agent, notifyCh chan [2]string return } log.Trace().Caller().Msg("Websocket connection established.") - go socket.ReadLoop() + go func() { + <-ctx.Done() + socket.WriteClose(1000, nil) + }() + socket.ReadLoop() } type webSocketData struct { @@ -94,7 +98,7 @@ type WebSocket struct { nextID uint64 } -func newWebsocket(ctx context.Context, settings Agent, notifyCh chan [2]string, doneCh chan struct{}) *WebSocket { +func newWebsocket(ctx context.Context, settings Agent, notifyCh chan [2]string) *WebSocket { var token, webhookID string if err := settings.GetConfig(config.PrefToken, &token); err != nil { log.Warn().Err(err).Msg("Could not retrieve token from config.") @@ -110,23 +114,24 @@ func newWebsocket(ctx context.Context, settings Agent, notifyCh chan [2]string, WriteCh: make(chan *webSocketData), token: token, webhookID: webhookID, - doneCh: doneCh, + doneCh: make(chan struct{}), } - go ws.responseHandler(ctx, notifyCh) - go ws.requestHandler(ctx) + go func() { + <-ctx.Done() + close(ws.doneCh) + }() + go ws.responseHandler(notifyCh) + go ws.requestHandler() return ws } func (c *WebSocket) OnError(socket *gws.Conn, err error) { log.Error().Err(err). Msg("Error on websocket") - c.doneCh <- struct{}{} } func (c *WebSocket) OnClose(socket *gws.Conn, err error) { log.Debug().Err(err).Msg("Websocket connection closed.") - c.doneCh <- struct{}{} - close(c.doneCh) } func (c *WebSocket) OnPong(socket *gws.Conn, payload []byte) { @@ -137,20 +142,25 @@ func (c *WebSocket) OnOpen(socket *gws.Conn) { log.Trace().Caller().Msg("Websocket opened.") go func() { ticker := time.NewTicker(PingInterval) - for range ticker.C { - log.Trace().Caller(). - Msg("Sending ping on websocket") - if err := socket.SetDeadline(time.Now().Add(2 * PingInterval)); err != nil { - log.Error().Err(err). - Msg("Error setting deadline on websocket.") + for { + select { + case <-c.doneCh: return - } - c.WriteCh <- &webSocketData{ - conn: socket, - data: &websocketMsg{ - Type: "ping", - ID: atomic.LoadUint64(&c.nextID), - }, + case <-ticker.C: + log.Trace().Caller(). + Msg("Sending ping on websocket") + if err := socket.SetDeadline(time.Now().Add(2 * PingInterval)); err != nil { + log.Error().Err(err). + Msg("Error setting deadline on websocket.") + return + } + c.WriteCh <- &webSocketData{ + conn: socket, + data: &websocketMsg{ + Type: "ping", + ID: atomic.LoadUint64(&c.nextID), + }, + } } } }() @@ -175,10 +185,10 @@ func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) { } } -func (c *WebSocket) responseHandler(ctx context.Context, notifyCh chan [2]string) { +func (c *WebSocket) responseHandler(notifyCh chan [2]string) { for { select { - case <-ctx.Done(): + case <-c.doneCh: log.Trace().Caller().Msg("Stopping websocket response handler.") return case r := <-c.ReadCh: @@ -237,10 +247,10 @@ func (c *WebSocket) responseHandler(ctx context.Context, notifyCh chan [2]string } } -func (c *WebSocket) requestHandler(ctx context.Context) { +func (c *WebSocket) requestHandler() { for { select { - case <-ctx.Done(): + case <-c.doneCh: log.Trace().Caller().Msg("Stopping websocket request handler.") return case m := <-c.WriteCh: