diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 779a063..c3640eb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,3 +25,8 @@ jobs: - name: Test run: go test -v ./... + + - name: Lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.61 diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index 82ae5c4..ff8e6a4 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -54,6 +54,7 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.ForwardHeaders, "forward-headers", false, "Forward X-Forwarded headers to target (default false if TLS enabled; otherwise true)") + //nolint:errcheck deployCommand.cmd.MarkFlagRequired("target") deployCommand.cmd.MarkFlagsRequiredTogether("tls-certificate-path", "tls-private-key-path") diff --git a/internal/cmd/rollout_deploy.go b/internal/cmd/rollout_deploy.go index 54d5021..a68ece0 100644 --- a/internal/cmd/rollout_deploy.go +++ b/internal/cmd/rollout_deploy.go @@ -26,6 +26,7 @@ func newRolloutDeployCommand() *rolloutDeployCommand { rolloutDeployCommand.cmd.Flags().DurationVar(&rolloutDeployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy") rolloutDeployCommand.cmd.Flags().DurationVar(&rolloutDeployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target") + //nolint:errcheck rolloutDeployCommand.cmd.MarkFlagRequired("target") return rolloutDeployCommand diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 900b76f..a0e3f18 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -35,10 +35,13 @@ func (c *runCommand) run(cmd *cobra.Command, args []string) error { c.setLogger() router := server.NewRouter(globalConfig.StatePath()) - router.RestoreLastSavedState() + err := router.RestoreLastSavedState() + if err != nil { + return err + } s := server.NewServer(&globalConfig, router) - err := s.Start() + err = s.Start() if err != nil { return err } diff --git a/internal/server/buffer.go b/internal/server/buffer.go index fd05f0e..3713603 100644 --- a/internal/server/buffer.go +++ b/internal/server/buffer.go @@ -86,7 +86,10 @@ func (b *Buffer) Write(p []byte) (int, error) { } func (b *Buffer) Read(p []byte) (n int, err error) { - b.setReader() + err = b.setReader() + if err != nil { + return 0, err + } return b.reader.Read(p) } @@ -95,8 +98,11 @@ func (b *Buffer) Overflowed() bool { } func (b *Buffer) Send(w io.Writer) error { - b.setReader() - _, err := io.Copy(w, b.reader) + err := b.setReader() + if err != nil { + return err + } + _, err = io.Copy(w, b.reader) return err } @@ -120,15 +126,19 @@ func (b *Buffer) writeToDisk(p []byte) (int, error) { return n, err } -func (b *Buffer) setReader() { +func (b *Buffer) setReader() error { if b.reader == nil { if b.diskBuffer != nil { - b.diskBuffer.Seek(0, 0) + _, err := b.diskBuffer.Seek(0, 0) + if err != nil { + return err + } b.reader = io.MultiReader(&b.memoryBuffer, b.diskBuffer) } else { b.reader = &b.memoryBuffer } } + return nil } func (b *Buffer) createSpill() error { diff --git a/internal/server/commands.go b/internal/server/commands.go index 8d34420..8997fea 100644 --- a/internal/server/commands.go +++ b/internal/server/commands.go @@ -134,9 +134,10 @@ func (h *CommandHandler) Remove(args RemoveArgs, reply *bool) error { } func (h *CommandHandler) List(args bool, reply *ListResponse) error { - reply.Targets = h.router.ListActiveServices() + var err error + reply.Targets, err = h.router.ListActiveServices() - return nil + return err } func (h *CommandHandler) RolloutDeploy(args RolloutDeployArgs, reply *bool) error { diff --git a/internal/server/pause_controller.go b/internal/server/pause_controller.go index 78be849..e36d9af 100644 --- a/internal/server/pause_controller.go +++ b/internal/server/pause_controller.go @@ -57,11 +57,11 @@ func (p *PauseController) UnmarshalJSON(data []byte) error { switch p.State { case PauseStateRunning: - p.Resume() + return p.Resume() case PauseStatePaused: - p.Pause(p.FailAfter) + return p.Pause(p.FailAfter) case PauseStateStopped: - p.Stop(p.StopMessage) + return p.Stop(p.StopMessage) } return nil diff --git a/internal/server/request_buffer_middleware_test.go b/internal/server/request_buffer_middleware_test.go index 07c2bde..f51b726 100644 --- a/internal/server/request_buffer_middleware_test.go +++ b/internal/server/request_buffer_middleware_test.go @@ -12,7 +12,8 @@ import ( func TestRequestBufferMiddleware(t *testing.T) { sendRequest := func(requestBody, responseBody string) *httptest.ResponseRecorder { middleware := WithRequestBufferMiddleware(4, 8, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(responseBody)) + _, err := w.Write([]byte(responseBody)) + assert.NoError(t, err) })) req := httptest.NewRequest("POST", "http://app.example.com/somepath", strings.NewReader(requestBody)) diff --git a/internal/server/response_buffer_middleware_test.go b/internal/server/response_buffer_middleware_test.go index 179e5e2..de415fb 100644 --- a/internal/server/response_buffer_middleware_test.go +++ b/internal/server/response_buffer_middleware_test.go @@ -12,7 +12,8 @@ import ( func TestResponseBufferMiddleware(t *testing.T) { sendRequest := func(requestBody, responseBody string) *httptest.ResponseRecorder { middleware := WithResponseBufferMiddleware(4, 8, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(responseBody)) + _, err := w.Write([]byte(responseBody)) + assert.NoError(t, err) })) req := httptest.NewRequest("POST", "http://app.example.com/somepath", strings.NewReader(requestBody)) diff --git a/internal/server/router.go b/internal/server/router.go index 9cad182..971452e 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -97,7 +97,7 @@ func (r *Router) RestoreLastSavedState() error { return err } - r.withWriteLock(func() error { + err = r.withWriteLock(func() error { r.services = ServiceMap{} for _, service := range services { r.services[service.name] = service @@ -106,6 +106,9 @@ func (r *Router) RestoreLastSavedState() error { r.hostServices = r.services.HostServices() return nil }) + if err != nil { + return err + } slog.Info("Restored saved state", "path", r.statePath) return nil @@ -125,8 +128,6 @@ func (r *Router) SetServiceTarget(name string, hosts []string, targetURL string, options ServiceOptions, targetOptions TargetOptions, deployTimeout time.Duration, drainTimeout time.Duration, ) error { - defer r.saveStateSnapshot() - slog.Info("Deploying", "service", name, "hosts", hosts, "target", targetURL, "tls", options.TLSEnabled) target, err := NewTarget(targetURL, targetOptions) @@ -146,12 +147,10 @@ func (r *Router) SetServiceTarget(name string, hosts []string, targetURL string, } slog.Info("Deployed", "service", name, "hosts", hosts, "target", targetURL) - return nil + return r.saveStateSnapshot() } func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout time.Duration, drainTimeout time.Duration) error { - defer r.saveStateSnapshot() - slog.Info("Deploying for rollout", "service", name, "target", targetURL) service := r.serviceForName(name) @@ -174,34 +173,36 @@ func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout t service.SetTarget(TargetSlotRollout, target, drainTimeout) slog.Info("Deployed for rollout", "service", name, "target", targetURL) - return nil + return r.saveStateSnapshot() } func (r *Router) SetRolloutSplit(name string, percent int, allowList []string) error { - defer r.saveStateSnapshot() - service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } - return service.SetRolloutSplit(percent, allowList) + err := service.SetRolloutSplit(percent, allowList) + if err != nil { + return err + } + return r.saveStateSnapshot() } func (r *Router) StopRollout(name string) error { - defer r.saveStateSnapshot() - service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } - return service.StopRollout() + err := service.StopRollout() + if err != nil { + return err + } + return r.saveStateSnapshot() } func (r *Router) RemoveService(name string) error { - defer r.saveStateSnapshot() - err := r.withWriteLock(func() error { service := r.services[name] if service == nil { @@ -218,46 +219,52 @@ func (r *Router) RemoveService(name string) error { return err } - return nil + return r.saveStateSnapshot() } func (r *Router) PauseService(name string, drainTimeout time.Duration, pauseTimeout time.Duration) error { - defer r.saveStateSnapshot() - service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } - return service.Pause(drainTimeout, pauseTimeout) + err := service.Pause(drainTimeout, pauseTimeout) + if err != nil { + return err + } + return r.saveStateSnapshot() } func (r *Router) StopService(name string, drainTimeout time.Duration, message string) error { - defer r.saveStateSnapshot() - service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } - return service.Stop(drainTimeout, message) + err := service.Stop(drainTimeout, message) + if err != nil { + return err + } + return r.saveStateSnapshot() } func (r *Router) ResumeService(name string) error { - defer r.saveStateSnapshot() - service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } - return service.Resume() + err := service.Resume() + if err != nil { + return err + } + return r.saveStateSnapshot() } -func (r *Router) ListActiveServices() ServiceDescriptionMap { +func (r *Router) ListActiveServices() (ServiceDescriptionMap, error) { result := ServiceDescriptionMap{} - r.withReadLock(func() error { + err := r.withReadLock(func() error { for name, service := range r.services { host := strings.Join(service.hosts, ",") if host == "" { @@ -275,7 +282,7 @@ func (r *Router) ListActiveServices() ServiceDescriptionMap { return nil }) - return result + return result, err } func (r *Router) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -303,12 +310,15 @@ func (r *Router) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, e func (r *Router) saveStateSnapshot() error { services := []*Service{} - r.withReadLock(func() error { + err := r.withReadLock(func() error { for _, service := range r.services { services = append(services, service) } return nil }) + if err != nil { + return err + } f, err := os.Create(r.statePath) if err != nil { diff --git a/internal/server/router_test.go b/internal/server/router_test.go index 00b3dff..2450f46 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -203,7 +203,8 @@ func TestRouter_UpdatingPauseStateIndependentlyOfDeployments(t *testing.T) { _, target := testBackend(t, "first", http.StatusOK) require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) - router.PauseService("service1", time.Second, time.Millisecond*10) + err := router.PauseService("service1", time.Second, time.Millisecond*10) + require.NoError(t, err) statusCode, _ := sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusGatewayTimeout, statusCode) @@ -213,7 +214,8 @@ func TestRouter_UpdatingPauseStateIndependentlyOfDeployments(t *testing.T) { statusCode, _ = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusGatewayTimeout, statusCode) - router.ResumeService("service1") + err = router.ResumeService("service1") + require.NoError(t, err) statusCode, _ = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusOK, statusCode) @@ -369,7 +371,8 @@ func TestRouter_RestoreLastSavedState(t *testing.T) { assert.Equal(t, http.StatusMovedPermanently, statusCode) router = NewRouter(statePath) - router.RestoreLastSavedState() + err := router.RestoreLastSavedState() + require.NoError(t, err) statusCode, body = sendGETRequest(router, "http://something.example.com") assert.Equal(t, http.StatusOK, statusCode) diff --git a/internal/server/server.go b/internal/server/server.go index bcd31f9..c54e42d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/tls" + "errors" "fmt" "log/slog" "net" @@ -58,7 +59,10 @@ func (s *Server) Stop() { defer cancel() s.commandHandler.Close() - s.httpServer.Shutdown(ctx) + err := s.httpServer.Shutdown(ctx) + if err != nil { + slog.Warn("Server shutdown error exit", "error", err) + } slog.Info("Server stopped") } @@ -103,8 +107,19 @@ func (s *Server) startHTTPServers() error { }, } - go s.httpServer.Serve(s.httpListener) - go s.httpsServer.ServeTLS(s.httpsListener, "", "") + go func() { + err := s.httpServer.Serve(s.httpListener) + if !errors.Is(err, http.ErrServerClosed) { + slog.Error("Error while serving http endpoint", "error", err) + } + }() + + go func() { + err := s.httpsServer.ServeTLS(s.httpsListener, "", "") + if !errors.Is(err, http.ErrServerClosed) { + slog.Error("Error while serving https endpoint", "error", err) + } + }() return nil } diff --git a/internal/server/service.go b/internal/server/service.go index 1c4203c..6eca888 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -235,9 +235,19 @@ func (s *Service) UnmarshalJSON(data []byte) error { s.pauseController = ms.PauseController s.rolloutController = ms.RolloutController - s.initialize(ms.Hosts, ms.Options) - s.restoreSavedTarget(TargetSlotActive, ms.ActiveTarget, ms.TargetOptions) - s.restoreSavedTarget(TargetSlotRollout, ms.RolloutTarget, ms.TargetOptions) + err = s.initialize(ms.Hosts, ms.Options) + if err != nil { + return err + } + err = s.restoreSavedTarget(TargetSlotActive, ms.ActiveTarget, ms.TargetOptions) + if err != nil { + return err + } + + err = s.restoreSavedTarget(TargetSlotRollout, ms.RolloutTarget, ms.TargetOptions) + if err != nil { + return err + } return nil } diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 2e10b0f..95c8e23 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -90,15 +90,18 @@ func TestService_ReturnSuccessfulHealthCheckWhilePausedOrStopped(t *testing.T) { assert.Equal(t, http.StatusOK, checkRequest("/up")) assert.Equal(t, http.StatusOK, checkRequest("/other")) - service.Pause(time.Second, time.Millisecond) + err := service.Pause(time.Second, time.Millisecond) + assert.NoError(t, err) assert.Equal(t, http.StatusOK, checkRequest("/up")) assert.Equal(t, http.StatusGatewayTimeout, checkRequest("/other")) - service.Stop(time.Second, DefaultStopMessage) + err = service.Stop(time.Second, DefaultStopMessage) + assert.NoError(t, err) assert.Equal(t, http.StatusOK, checkRequest("/up")) assert.Equal(t, http.StatusServiceUnavailable, checkRequest("/other")) - service.Resume() + err = service.Resume() + assert.NoError(t, err) assert.Equal(t, http.StatusOK, checkRequest("/up")) assert.Equal(t, http.StatusOK, checkRequest("/other")) } diff --git a/internal/server/target_test.go b/internal/server/target_test.go index c0d4e94..90993ef 100644 --- a/internal/server/target_test.go +++ b/internal/server/target_test.go @@ -19,7 +19,8 @@ import ( func TestTarget_Serve(t *testing.T) { target := testTarget(t, func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("ok")) + _, err := w.Write([]byte("ok")) + assert.NoError(t, err) }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -44,7 +45,8 @@ func TestTarget_ServeSSE(t *testing.T) { target := testTargetWithOptions(t, targetOptions, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") - w.Write([]byte("data: hello\n\n")) + _, err := w.Write([]byte("data: hello\n\n")) + assert.NoError(t, err) w.(http.Flusher).Flush() // Don't return until the client has finished reading. Fail the test if this takes too long. @@ -105,11 +107,13 @@ func TestTarget_ServeWebSocket(t *testing.T) { go func() { kind, body, err := c.Read(context.Background()) - require.NoError(t, err) + assert.NoError(t, err) assert.Equal(t, websocket.MessageText, kind) - c.Write(context.Background(), websocket.MessageText, body) - defer c.CloseNow() + err = c.Write(context.Background(), websocket.MessageText, body) + assert.NoError(t, err) + err = c.CloseNow() + assert.NoError(t, err) }() }) @@ -124,9 +128,12 @@ func TestTarget_ServeWebSocket(t *testing.T) { c, _, err := websocket.Dial(context.Background(), websocketURL, nil) require.NoError(t, err) - defer c.CloseNow() + defer func() { + assert.NoError(t, c.CloseNow()) + }() - c.Write(context.Background(), websocket.MessageText, []byte(body)) + err = c.Write(context.Background(), websocket.MessageText, []byte(body)) + require.NoError(t, err) return c.Read(context.Background()) } @@ -271,7 +278,8 @@ func TestTarget_IsHealthCheckRequest(t *testing.T) { func TestTarget_AddedTargetBecomesHealthy(t *testing.T) { target := testTarget(t, func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("ok")) + _, err := w.Write([]byte("ok")) + assert.NoError(t, err) }) target.BeginHealthChecks() @@ -351,25 +359,30 @@ func TestTarget_DrainRequestsThatNeedToBeCancelled(t *testing.T) { func TestTarget_DrainHijackedConnectionsImmediately(t *testing.T) { target := testTarget(t, func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{}) - require.NoError(t, err) - defer c.CloseNow() - + assert.NoError(t, err) _, _, err = c.Read(context.Background()) - require.Error(t, err) + assert.Error(t, err) + err = c.CloseNow() + // TODO: this check works strange if set to NoError. + // if it runs isolated it works, but if it runs with all, it fails. + assert.Error(t, err) }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r, err := target.StartRequest(r) - require.NoError(t, err) + assert.NoError(t, err) target.SendRequest(w, r) })) - defer server.Close() + + t.Cleanup(server.Close) websocketURL := strings.Replace(server.URL, "http:", "ws:", 1) c, _, err := websocket.Dial(context.Background(), websocketURL, nil) require.NoError(t, err) - defer c.CloseNow() + defer func() { + assert.NoError(t, c.CloseNow()) + }() startedDraining := time.Now() target.Drain(time.Second * 5) @@ -387,7 +400,8 @@ func TestTarget_EnforceMaxBodySizes(t *testing.T) { HealthCheckConfig: defaultHealthCheckConfig, } target := testTargetWithOptions(t, targetOptions, func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(responseBody)) + _, err := w.Write([]byte(responseBody)) + assert.NoError(t, err) }) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestBody)) diff --git a/internal/server/testing.go b/internal/server/testing.go index de5c88b..c09588c 100644 --- a/internal/server/testing.go +++ b/internal/server/testing.go @@ -5,10 +5,9 @@ import ( "net/http" "net/http/httptest" "net/url" - "os" "testing" - "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,7 +18,9 @@ var ( defaultTargetOptions = TargetOptions{HealthCheckConfig: defaultHealthCheckConfig, ResponseTimeout: DefaultTargetTimeout} ) -func testTarget(t *testing.T, handler http.HandlerFunc) *Target { +func testTarget(t testing.TB, handler http.HandlerFunc) *Target { + t.Helper() + _, targetURL := testBackendWithHandler(t, handler) target, err := NewTarget(targetURL, defaultTargetOptions) @@ -27,7 +28,9 @@ func testTarget(t *testing.T, handler http.HandlerFunc) *Target { return target } -func testTargetWithOptions(t *testing.T, targetOptions TargetOptions, handler http.HandlerFunc) *Target { +func testTargetWithOptions(t testing.TB, targetOptions TargetOptions, handler http.HandlerFunc) *Target { + t.Helper() + _, targetURL := testBackendWithHandler(t, handler) target, err := NewTarget(targetURL, targetOptions) @@ -35,14 +38,19 @@ func testTargetWithOptions(t *testing.T, targetOptions TargetOptions, handler ht return target } -func testBackend(t *testing.T, body string, statusCode int) (*httptest.Server, string) { +func testBackend(t testing.TB, body string, statusCode int) (*httptest.Server, string) { + t.Helper() + return testBackendWithHandler(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(statusCode) - w.Write([]byte(body)) + _, err := w.Write([]byte(body)) + assert.NoError(t, err) }) } -func testBackendWithHandler(t *testing.T, handler http.HandlerFunc) (*httptest.Server, string) { +func testBackendWithHandler(t testing.TB, handler http.HandlerFunc) (*httptest.Server, string) { + t.Helper() + server := httptest.NewServer(handler) t.Cleanup(server.Close) @@ -52,31 +60,22 @@ func testBackendWithHandler(t *testing.T, handler http.HandlerFunc) (*httptest.S return server, serverURL.Host } -func testServer(t *testing.T) (*Server, string) { +func testServer(t testing.TB) (*Server, string) { + t.Helper() + config := &Config{ Bind: "127.0.0.1", HttpPort: 0, HttpsPort: 0, - AlternateConfigDir: shortTmpDir(t), + AlternateConfigDir: t.TempDir(), } router := NewRouter(config.StatePath()) server := NewServer(config, router) - server.Start() - + err := server.Start() + require.NoError(t, err) t.Cleanup(server.Stop) addr := fmt.Sprintf("http://localhost:%d", server.HttpPort()) return server, addr } - -func shortTmpDir(t *testing.T) string { - path := "/tmp/" + uuid.New().String() - os.Mkdir(path, 0755) - - t.Cleanup(func() { - os.RemoveAll(path) - }) - - return path -}