Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle all errors #33

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ jobs:

- name: Test
run: go test -v ./...

- name: Lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.61
1 change: 1 addition & 0 deletions internal/cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func newDeployCommand() *deployCommand {

deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.ForwardHeaders, "forward-headers", false, "Forward X-Forwarded headers to target (default false if TLS enabled; otherwise true)")

//nolint:errcheck
deployCommand.cmd.MarkFlagRequired("target")
deployCommand.cmd.MarkFlagsRequiredTogether("tls-certificate-path", "tls-private-key-path")

Expand Down
1 change: 1 addition & 0 deletions internal/cmd/rollout_deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func newRolloutDeployCommand() *rolloutDeployCommand {
rolloutDeployCommand.cmd.Flags().DurationVar(&rolloutDeployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy")
rolloutDeployCommand.cmd.Flags().DurationVar(&rolloutDeployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target")

//nolint:errcheck
rolloutDeployCommand.cmd.MarkFlagRequired("target")

return rolloutDeployCommand
Expand Down
7 changes: 5 additions & 2 deletions internal/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ func (c *runCommand) run(cmd *cobra.Command, args []string) error {
c.setLogger()

router := server.NewRouter(globalConfig.StatePath())
router.RestoreLastSavedState()
err := router.RestoreLastSavedState()
if err != nil {
return err
}

s := server.NewServer(&globalConfig, router)
err := s.Start()
err = s.Start()
if err != nil {
return err
}
Expand Down
20 changes: 15 additions & 5 deletions internal/server/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ func (b *Buffer) Write(p []byte) (int, error) {
}

func (b *Buffer) Read(p []byte) (n int, err error) {
b.setReader()
err = b.setReader()
if err != nil {
return 0, err
}
return b.reader.Read(p)
}

Expand All @@ -95,8 +98,11 @@ func (b *Buffer) Overflowed() bool {
}

func (b *Buffer) Send(w io.Writer) error {
b.setReader()
_, err := io.Copy(w, b.reader)
err := b.setReader()
if err != nil {
return err
}
_, err = io.Copy(w, b.reader)
return err
}

Expand All @@ -120,15 +126,19 @@ func (b *Buffer) writeToDisk(p []byte) (int, error) {
return n, err
}

func (b *Buffer) setReader() {
func (b *Buffer) setReader() error {
if b.reader == nil {
if b.diskBuffer != nil {
b.diskBuffer.Seek(0, 0)
_, err := b.diskBuffer.Seek(0, 0)
if err != nil {
return err
}
b.reader = io.MultiReader(&b.memoryBuffer, b.diskBuffer)
} else {
b.reader = &b.memoryBuffer
}
}
return nil
}

func (b *Buffer) createSpill() error {
Expand Down
5 changes: 3 additions & 2 deletions internal/server/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ func (h *CommandHandler) Remove(args RemoveArgs, reply *bool) error {
}

func (h *CommandHandler) List(args bool, reply *ListResponse) error {
reply.Targets = h.router.ListActiveServices()
var err error
reply.Targets, err = h.router.ListActiveServices()

return nil
return err
}

func (h *CommandHandler) RolloutDeploy(args RolloutDeployArgs, reply *bool) error {
Expand Down
6 changes: 3 additions & 3 deletions internal/server/pause_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ func (p *PauseController) UnmarshalJSON(data []byte) error {

switch p.State {
case PauseStateRunning:
p.Resume()
return p.Resume()
case PauseStatePaused:
p.Pause(p.FailAfter)
return p.Pause(p.FailAfter)
case PauseStateStopped:
p.Stop(p.StopMessage)
return p.Stop(p.StopMessage)
}

return nil
Expand Down
3 changes: 2 additions & 1 deletion internal/server/request_buffer_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (
func TestRequestBufferMiddleware(t *testing.T) {
sendRequest := func(requestBody, responseBody string) *httptest.ResponseRecorder {
middleware := WithRequestBufferMiddleware(4, 8, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(responseBody))
_, err := w.Write([]byte(responseBody))
assert.NoError(t, err)
}))

req := httptest.NewRequest("POST", "http://app.example.com/somepath", strings.NewReader(requestBody))
Expand Down
3 changes: 2 additions & 1 deletion internal/server/response_buffer_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (
func TestResponseBufferMiddleware(t *testing.T) {
sendRequest := func(requestBody, responseBody string) *httptest.ResponseRecorder {
middleware := WithResponseBufferMiddleware(4, 8, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(responseBody))
_, err := w.Write([]byte(responseBody))
assert.NoError(t, err)
}))

req := httptest.NewRequest("POST", "http://app.example.com/somepath", strings.NewReader(requestBody))
Expand Down
68 changes: 39 additions & 29 deletions internal/server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (r *Router) RestoreLastSavedState() error {
return err
}

r.withWriteLock(func() error {
err = r.withWriteLock(func() error {
r.services = ServiceMap{}
for _, service := range services {
r.services[service.name] = service
Expand All @@ -106,6 +106,9 @@ func (r *Router) RestoreLastSavedState() error {
r.hostServices = r.services.HostServices()
return nil
})
if err != nil {
return err
}

slog.Info("Restored saved state", "path", r.statePath)
return nil
Expand All @@ -125,8 +128,6 @@ func (r *Router) SetServiceTarget(name string, hosts []string, targetURL string,
options ServiceOptions, targetOptions TargetOptions,
deployTimeout time.Duration, drainTimeout time.Duration,
) error {
defer r.saveStateSnapshot()

slog.Info("Deploying", "service", name, "hosts", hosts, "target", targetURL, "tls", options.TLSEnabled)

target, err := NewTarget(targetURL, targetOptions)
Expand All @@ -146,12 +147,10 @@ func (r *Router) SetServiceTarget(name string, hosts []string, targetURL string,
}

slog.Info("Deployed", "service", name, "hosts", hosts, "target", targetURL)
return nil
return r.saveStateSnapshot()
}

func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout time.Duration, drainTimeout time.Duration) error {
defer r.saveStateSnapshot()

slog.Info("Deploying for rollout", "service", name, "target", targetURL)

service := r.serviceForName(name)
Expand All @@ -174,34 +173,36 @@ func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout t
service.SetTarget(TargetSlotRollout, target, drainTimeout)

slog.Info("Deployed for rollout", "service", name, "target", targetURL)
return nil
return r.saveStateSnapshot()
}

func (r *Router) SetRolloutSplit(name string, percent int, allowList []string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}

return service.SetRolloutSplit(percent, allowList)
err := service.SetRolloutSplit(percent, allowList)
if err != nil {
return err
}
return r.saveStateSnapshot()
}

func (r *Router) StopRollout(name string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}

return service.StopRollout()
err := service.StopRollout()
if err != nil {
return err
}
return r.saveStateSnapshot()
}

func (r *Router) RemoveService(name string) error {
defer r.saveStateSnapshot()

err := r.withWriteLock(func() error {
service := r.services[name]
if service == nil {
Expand All @@ -218,46 +219,52 @@ func (r *Router) RemoveService(name string) error {
return err
}

return nil
return r.saveStateSnapshot()
}

func (r *Router) PauseService(name string, drainTimeout time.Duration, pauseTimeout time.Duration) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}

return service.Pause(drainTimeout, pauseTimeout)
err := service.Pause(drainTimeout, pauseTimeout)
if err != nil {
return err
}
return r.saveStateSnapshot()
}

func (r *Router) StopService(name string, drainTimeout time.Duration, message string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}

return service.Stop(drainTimeout, message)
err := service.Stop(drainTimeout, message)
if err != nil {
return err
}
return r.saveStateSnapshot()
}

func (r *Router) ResumeService(name string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}

return service.Resume()
err := service.Resume()
if err != nil {
return err
}
return r.saveStateSnapshot()
}

func (r *Router) ListActiveServices() ServiceDescriptionMap {
func (r *Router) ListActiveServices() (ServiceDescriptionMap, error) {
result := ServiceDescriptionMap{}

r.withReadLock(func() error {
err := r.withReadLock(func() error {
for name, service := range r.services {
host := strings.Join(service.hosts, ",")
if host == "" {
Expand All @@ -275,7 +282,7 @@ func (r *Router) ListActiveServices() ServiceDescriptionMap {
return nil
})

return result
return result, err
}

func (r *Router) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expand Down Expand Up @@ -303,12 +310,15 @@ func (r *Router) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, e

func (r *Router) saveStateSnapshot() error {
services := []*Service{}
r.withReadLock(func() error {
err := r.withReadLock(func() error {
for _, service := range r.services {
services = append(services, service)
}
return nil
})
if err != nil {
return err
}

f, err := os.Create(r.statePath)
if err != nil {
Expand Down
9 changes: 6 additions & 3 deletions internal/server/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ func TestRouter_UpdatingPauseStateIndependentlyOfDeployments(t *testing.T) {
_, target := testBackend(t, "first", http.StatusOK)

require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout))
router.PauseService("service1", time.Second, time.Millisecond*10)
err := router.PauseService("service1", time.Second, time.Millisecond*10)
require.NoError(t, err)

statusCode, _ := sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10")))
assert.Equal(t, http.StatusGatewayTimeout, statusCode)
Expand All @@ -213,7 +214,8 @@ func TestRouter_UpdatingPauseStateIndependentlyOfDeployments(t *testing.T) {
statusCode, _ = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10")))
assert.Equal(t, http.StatusGatewayTimeout, statusCode)

router.ResumeService("service1")
err = router.ResumeService("service1")
require.NoError(t, err)

statusCode, _ = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10")))
assert.Equal(t, http.StatusOK, statusCode)
Expand Down Expand Up @@ -369,7 +371,8 @@ func TestRouter_RestoreLastSavedState(t *testing.T) {
assert.Equal(t, http.StatusMovedPermanently, statusCode)

router = NewRouter(statePath)
router.RestoreLastSavedState()
err := router.RestoreLastSavedState()
require.NoError(t, err)

statusCode, body = sendGETRequest(router, "http://something.example.com")
assert.Equal(t, http.StatusOK, statusCode)
Expand Down
21 changes: 18 additions & 3 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
Expand Down Expand Up @@ -58,7 +59,10 @@ func (s *Server) Stop() {
defer cancel()

s.commandHandler.Close()
s.httpServer.Shutdown(ctx)
err := s.httpServer.Shutdown(ctx)
if err != nil {
slog.Warn("Server shutdown error exit", "error", err)
}

slog.Info("Server stopped")
}
Expand Down Expand Up @@ -103,8 +107,19 @@ func (s *Server) startHTTPServers() error {
},
}

go s.httpServer.Serve(s.httpListener)
go s.httpsServer.ServeTLS(s.httpsListener, "", "")
go func() {
err := s.httpServer.Serve(s.httpListener)
if !errors.Is(err, http.ErrServerClosed) {
slog.Error("Error while serving http endpoint", "error", err)
}
}()

go func() {
err := s.httpsServer.ServeTLS(s.httpsListener, "", "")
if !errors.Is(err, http.ErrServerClosed) {
slog.Error("Error while serving https endpoint", "error", err)
}
}()

return nil
}
Expand Down
Loading