From 25d106b43400bea2220788d84956dde6ba12cf0c Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Mon, 23 Sep 2024 17:02:17 +0100 Subject: [PATCH 01/12] Add separate host routing table --- internal/server/router.go | 97 +++++++++++++++++++++----------------- internal/server/service.go | 3 +- 2 files changed, 56 insertions(+), 44 deletions(-) diff --git a/internal/server/router.go b/internal/server/router.go index f279b61..41bf385 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -20,12 +20,31 @@ var ( ErrorUnknownServerName = errors.New("unknown server name") ) +type ServiceMap map[string]*Service type HostServiceMap map[string]*Service +func (m *ServiceMap) HostServices() HostServiceMap { + hostServices := HostServiceMap{} + for _, service := range *m { + hostServices[service.host] = service + } + return hostServices +} + +func (m *ServiceMap) CheckHostAvailability(service *Service, host string) *Service { + for _, s := range *m { + if s.host == host && s != service { + return s + } + } + return nil +} + type Router struct { - statePath string - services HostServiceMap - serviceLock sync.RWMutex + statePath string + services ServiceMap + hostServices HostServiceMap + serviceLock sync.RWMutex } type ServiceDescription struct { @@ -39,8 +58,9 @@ type ServiceDescriptionMap map[string]ServiceDescription func NewRouter(statePath string) *Router { return &Router{ - statePath: statePath, - services: HostServiceMap{}, + statePath: statePath, + services: ServiceMap{}, + hostServices: HostServiceMap{}, } } @@ -64,10 +84,12 @@ func (r *Router) RestoreLastSavedState() error { } r.withWriteLock(func() error { - r.services = HostServiceMap{} + r.services = ServiceMap{} for _, service := range services { - r.services[service.host] = service + r.services[service.name] = service } + + r.hostServices = r.services.HostServices() return nil }) @@ -118,7 +140,7 @@ func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout t slog.Info("Deploying for rollout", "service", name, "target", targetURL) - service := r.serviceForName(name, true) + service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } @@ -144,7 +166,7 @@ func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout t func (r *Router) SetRolloutSplit(name string, percent int, allowList []string) error { defer r.saveStateSnapshot() - service := r.serviceForName(name, true) + service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } @@ -155,7 +177,7 @@ func (r *Router) SetRolloutSplit(name string, percent int, allowList []string) e func (r *Router) StopRollout(name string) error { defer r.saveStateSnapshot() - service := r.serviceForName(name, true) + service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } @@ -167,13 +189,14 @@ func (r *Router) RemoveService(name string) error { defer r.saveStateSnapshot() err := r.withWriteLock(func() error { - service := r.serviceForName(name, false) + service := r.services[name] if service == nil { return ErrorServiceNotFound } service.SetTarget(TargetSlotActive, nil, DefaultDrainTimeout) - delete(r.services, service.host) + delete(r.services, service.name) + r.hostServices = r.services.HostServices() return nil }) @@ -187,7 +210,7 @@ func (r *Router) RemoveService(name string) error { func (r *Router) PauseService(name string, drainTimeout time.Duration, pauseTimeout time.Duration) error { defer r.saveStateSnapshot() - service := r.serviceForName(name, true) + service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } @@ -198,7 +221,7 @@ func (r *Router) PauseService(name string, drainTimeout time.Duration, pauseTime func (r *Router) StopService(name string, drainTimeout time.Duration, message string) error { defer r.saveStateSnapshot() - service := r.serviceForName(name, true) + service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } @@ -209,7 +232,7 @@ func (r *Router) StopService(name string, drainTimeout time.Duration, message st func (r *Router) ResumeService(name string) error { defer r.saveStateSnapshot() - service := r.serviceForName(name, true) + service := r.serviceForName(name) if service == nil { return ErrorServiceNotFound } @@ -221,12 +244,13 @@ func (r *Router) ListActiveServices() ServiceDescriptionMap { result := ServiceDescriptionMap{} r.withReadLock(func() error { - for host, service := range r.services { + for name, service := range r.services { + host := service.host if host == "" { host = "*" } if service.active != nil { - result[service.name] = ServiceDescription{ + result[name] = ServiceDescription{ Host: host, Target: service.active.Target(), TLS: service.options.RequireTLS(), @@ -300,9 +324,9 @@ func (r *Router) serviceForHost(host string) *Service { r.serviceLock.RLock() defer r.serviceLock.RUnlock() - service, ok := r.services[host] + service, ok := r.hostServices[host] if !ok { - service = r.services[""] + service = r.hostServices[""] } return service @@ -312,44 +336,31 @@ func (r *Router) setActiveTarget(name string, host string, target *Target, optio r.serviceLock.Lock() defer r.serviceLock.Unlock() - service := r.serviceForName(name, false) + service := r.services[name] if service == nil { service = NewService(name, host, options) } else { - service.UpdateOptions(options) + service.UpdateOptions(host, options) } - hostService, ok := r.services[host] - if !ok { - if host != service.host { - delete(r.services, service.host) - service.host = host - } - - r.services[host] = service - } else if hostService != service { - slog.Error("Host in use by another service", "service", hostService.name, "host", host) + conflict := r.services.CheckHostAvailability(service, host) + if conflict != nil { + slog.Error("Host in use by another service", "service", conflict.name, "host", host) return ErrorHostInUse } + r.services[name] = service + r.hostServices = r.services.HostServices() service.SetTarget(TargetSlotActive, target, drainTimeout) return nil } -func (r *Router) serviceForName(name string, readLock bool) *Service { - if readLock { - r.serviceLock.RLock() - defer r.serviceLock.RUnlock() - } - - for _, service := range r.services { - if name == service.name { - return service - } - } +func (r *Router) serviceForName(name string) *Service { + r.serviceLock.RLock() + defer r.serviceLock.RUnlock() - return nil + return r.services[name] } func (r *Router) withReadLock(fn func() error) error { diff --git a/internal/server/service.go b/internal/server/service.go index cf1b18f..544a427 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -110,7 +110,8 @@ func NewService(name, host string, options ServiceOptions) *Service { return service } -func (s *Service) UpdateOptions(options ServiceOptions) { +func (s *Service) UpdateOptions(host string, options ServiceOptions) { + s.host = host s.options = options s.certManager = s.createCertManager() s.middleware = s.createMiddleware() From 3d6f0d6b5ee2fdf6a2944e8b1a2f75211158d910 Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 08:29:35 +0100 Subject: [PATCH 02/12] Allow `Service` to contain multiple hosts --- internal/server/router.go | 21 ++++++++++++++++----- internal/server/router_test.go | 2 +- internal/server/service.go | 16 ++++++++-------- internal/server/service_test.go | 2 +- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/internal/server/router.go b/internal/server/router.go index 41bf385..fd3c569 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -8,6 +8,8 @@ import ( "net" "net/http" "os" + "slices" + "strings" "sync" "time" ) @@ -26,14 +28,20 @@ type HostServiceMap map[string]*Service func (m *ServiceMap) HostServices() HostServiceMap { hostServices := HostServiceMap{} for _, service := range *m { - hostServices[service.host] = service + if len(service.hosts) == 0 { + hostServices[""] = service + continue + } + for _, host := range service.hosts { + hostServices[host] = service + } } return hostServices } func (m *ServiceMap) CheckHostAvailability(service *Service, host string) *Service { for _, s := range *m { - if s.host == host && s != service { + if s != service && slices.Contains(s.hosts, host) { return s } } @@ -245,7 +253,7 @@ func (r *Router) ListActiveServices() ServiceDescriptionMap { r.withReadLock(func() error { for name, service := range r.services { - host := service.host + host := strings.Join(service.hosts, ",") if host == "" { host = "*" } @@ -336,11 +344,14 @@ func (r *Router) setActiveTarget(name string, host string, target *Target, optio r.serviceLock.Lock() defer r.serviceLock.Unlock() + // TODO: allow setting multiple hosts here + hosts := []string{host} + service := r.services[name] if service == nil { - service = NewService(name, host, options) + service = NewService(name, hosts, options) } else { - service.UpdateOptions(host, options) + service.UpdateOptions(hosts, options) } conflict := r.services.CheckHostAvailability(service, host) diff --git a/internal/server/router_test.go b/internal/server/router_test.go index e1a079b..3baa9ad 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -36,7 +36,7 @@ func TestRouter_Removing(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", defaultEmptyHosts, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://dummy.example.com/") assert.Equal(t, http.StatusOK, statusCode) diff --git a/internal/server/service.go b/internal/server/service.go index 544a427..09656fb 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -85,7 +85,7 @@ func (so ServiceOptions) ScopedCachePath() string { type Service struct { name string - host string + hosts []string options ServiceOptions active *Target @@ -98,10 +98,10 @@ type Service struct { middleware http.Handler } -func NewService(name, host string, options ServiceOptions) *Service { +func NewService(name string, hosts []string, options ServiceOptions) *Service { service := &Service{ name: name, - host: host, + hosts: hosts, options: options, } @@ -110,8 +110,8 @@ func NewService(name, host string, options ServiceOptions) *Service { return service } -func (s *Service) UpdateOptions(host string, options ServiceOptions) { - s.host = host +func (s *Service) UpdateOptions(hosts []string, options ServiceOptions) { + s.hosts = hosts s.options = options s.certManager = s.createCertManager() s.middleware = s.createMiddleware() @@ -195,7 +195,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { type marshalledService struct { Name string `json:"name"` - Host string `json:"host"` + Hosts []string `json:"hosts"` ActiveTarget string `json:"active_target"` RolloutTarget string `json:"rollout_target"` Options ServiceOptions `json:"options"` @@ -214,7 +214,7 @@ func (s *Service) MarshalJSON() ([]byte, error) { return json.Marshal(marshalledService{ Name: s.name, - Host: s.host, + Hosts: s.hosts, ActiveTarget: activeTarget, RolloutTarget: rolloutTarget, Options: s.options, @@ -232,7 +232,7 @@ func (s *Service) UnmarshalJSON(data []byte) error { } s.name = ms.Name - s.host = ms.Host + s.hosts = ms.Hosts s.options = ms.Options s.initialize() diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 0eb91ee..0d3020b 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -127,7 +127,7 @@ func testCreateService(t *testing.T, options ServiceOptions, targetOptions Targe target, err := NewTarget(serverURL.Host, targetOptions) require.NoError(t, err) - service := NewService("test", "", options) + service := NewService("test", []string{""}, options) service.active = target return service From f06c473bcc79ed09da7d5ffa68b2d6c29189b20f Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 08:37:06 +0100 Subject: [PATCH 03/12] Allow multiple `TLSHostnames` --- internal/cmd/deploy.go | 2 +- internal/server/router_test.go | 2 +- internal/server/service.go | 14 +++++++------- internal/server/service_test.go | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index d22c905..2f0b9c3 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -64,7 +64,7 @@ func (c *deployCommand) run(cmd *cobra.Command, args []string) error { if c.tls { c.args.ServiceOptions.ACMECachePath = globalConfig.CertificatePath() - c.args.ServiceOptions.TLSHostname = c.args.Host + c.args.ServiceOptions.TLSHostnames = []string{c.args.Host} } if c.tlsStaging { diff --git a/internal/server/router_test.go b/internal/server/router_test.go index 3baa9ad..b216589 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -124,7 +124,7 @@ func TestRouter_UpdatingOptions(t *testing.T) { assert.Equal(t, http.StatusOK, statusCode) assert.Equal(t, "first", body) - serviceOptions.TLSHostname = "dummy.example.com" + serviceOptions.TLSHostnames = []string{"dummy.example.com"} require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) diff --git a/internal/server/service.go b/internal/server/service.go index 09656fb..9bba75c 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -60,14 +60,14 @@ type HealthCheckConfig struct { } type ServiceOptions struct { - TLSHostname string `json:"tls_hostname"` - ACMEDirectory string `json:"acme_directory"` - ACMECachePath string `json:"acme_cache_path"` - ErrorPagePath string `json:"error_page_path"` + TLSHostnames []string `json:"tls_hostnames"` + ACMEDirectory string `json:"acme_directory"` + ACMECachePath string `json:"acme_cache_path"` + ErrorPagePath string `json:"error_page_path"` } func (so ServiceOptions) RequireTLS() bool { - return so.TLSHostname != "" + return len(so.TLSHostnames) > 0 } func (so ServiceOptions) ScopedCachePath() string { @@ -289,14 +289,14 @@ func (s *Service) initialize() { } func (s *Service) createCertManager() *autocert.Manager { - if s.options.TLSHostname == "" { + if !s.options.RequireTLS() { return nil } return &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(s.options.ScopedCachePath()), - HostPolicy: autocert.HostWhitelist(s.options.TLSHostname), + HostPolicy: autocert.HostWhitelist(s.options.TLSHostnames...), Client: &acme.Client{DirectoryURL: s.options.ACMEDirectory}, } } diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 0d3020b..9e3a9c6 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -25,7 +25,7 @@ func TestService_ServeRequest(t *testing.T) { } func TestService_RedirectToHTTPWhenTLSRequired(t *testing.T) { - service := testCreateService(t, ServiceOptions{TLSHostname: "example.com"}, defaultTargetOptions) + service := testCreateService(t, ServiceOptions{TLSHostnames: []string{"example.com"}}, defaultTargetOptions) require.True(t, service.options.RequireTLS()) From ffff5ed49f4c56a4eab0e4bfbb4fce1c521053d8 Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 08:45:11 +0100 Subject: [PATCH 04/12] Expose ability to set multiple hosts --- internal/cmd/deploy.go | 4 +-- internal/server/commands.go | 4 +-- internal/server/router.go | 26 +++++++++---------- internal/server/router_test.go | 46 +++++++++++++++++----------------- 4 files changed, 40 insertions(+), 40 deletions(-) diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index 2f0b9c3..1f7035f 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -29,7 +29,7 @@ func newDeployCommand() *deployCommand { } deployCommand.cmd.Flags().StringVar(&deployCommand.args.TargetURL, "target", "", "Target host to deploy") - deployCommand.cmd.Flags().StringVar(&deployCommand.args.Host, "host", "", "Host to serve this target on (empty for wildcard)") + deployCommand.cmd.Flags().StringSliceVar(&deployCommand.args.Hosts, "host", []string{}, "Host(s) to serve this target on (empty for wildcard)") deployCommand.cmd.Flags().BoolVar(&deployCommand.tls, "tls", false, "Configure TLS for this target (requires a non-empty host)") deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning") @@ -64,7 +64,7 @@ func (c *deployCommand) run(cmd *cobra.Command, args []string) error { if c.tls { c.args.ServiceOptions.ACMECachePath = globalConfig.CertificatePath() - c.args.ServiceOptions.TLSHostnames = []string{c.args.Host} + c.args.ServiceOptions.TLSHostnames = c.args.Hosts } if c.tlsStaging { diff --git a/internal/server/commands.go b/internal/server/commands.go index c6ca729..8d34420 100644 --- a/internal/server/commands.go +++ b/internal/server/commands.go @@ -18,8 +18,8 @@ type CommandHandler struct { type DeployArgs struct { Service string - Host string TargetURL string + Hosts []string DeployTimeout time.Duration DrainTimeout time.Duration ServiceOptions ServiceOptions @@ -114,7 +114,7 @@ func (h *CommandHandler) Close() error { } func (h *CommandHandler) Deploy(args DeployArgs, reply *bool) error { - return h.router.SetServiceTarget(args.Service, args.Host, args.TargetURL, args.ServiceOptions, args.TargetOptions, args.DeployTimeout, args.DrainTimeout) + return h.router.SetServiceTarget(args.Service, args.Hosts, args.TargetURL, args.ServiceOptions, args.TargetOptions, args.DeployTimeout, args.DrainTimeout) } func (h *CommandHandler) Pause(args PauseArgs, reply *bool) error { diff --git a/internal/server/router.go b/internal/server/router.go index fd3c569..adce8cc 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -115,13 +115,13 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { service.ServeHTTP(w, req) } -func (r *Router) SetServiceTarget(name string, host string, targetURL string, +func (r *Router) SetServiceTarget(name string, hosts []string, targetURL string, options ServiceOptions, targetOptions TargetOptions, deployTimeout time.Duration, drainTimeout time.Duration, ) error { defer r.saveStateSnapshot() - slog.Info("Deploying", "service", name, "host", host, "target", targetURL, "tls", options.RequireTLS()) + slog.Info("Deploying", "service", name, "hosts", hosts, "target", targetURL, "tls", options.RequireTLS()) target, err := NewTarget(targetURL, targetOptions) if err != nil { @@ -130,16 +130,16 @@ func (r *Router) SetServiceTarget(name string, host string, targetURL string, becameHealthy := target.WaitUntilHealthy(deployTimeout) if !becameHealthy { - slog.Info("Target failed to become healthy", "host", host, "target", targetURL) + slog.Info("Target failed to become healthy", "hosts", hosts, "target", targetURL) return ErrorTargetFailedToBecomeHealthy } - err = r.setActiveTarget(name, host, target, options, drainTimeout) + err = r.setActiveTarget(name, hosts, target, options, drainTimeout) if err != nil { return err } - slog.Info("Deployed", "service", name, "host", host, "target", targetURL) + slog.Info("Deployed", "service", name, "hosts", hosts, "target", targetURL) return nil } @@ -340,13 +340,10 @@ func (r *Router) serviceForHost(host string) *Service { return service } -func (r *Router) setActiveTarget(name string, host string, target *Target, options ServiceOptions, drainTimeout time.Duration) error { +func (r *Router) setActiveTarget(name string, hosts []string, target *Target, options ServiceOptions, drainTimeout time.Duration) error { r.serviceLock.Lock() defer r.serviceLock.Unlock() - // TODO: allow setting multiple hosts here - hosts := []string{host} - service := r.services[name] if service == nil { service = NewService(name, hosts, options) @@ -354,11 +351,14 @@ func (r *Router) setActiveTarget(name string, host string, target *Target, optio service.UpdateOptions(hosts, options) } - conflict := r.services.CheckHostAvailability(service, host) - if conflict != nil { - slog.Error("Host in use by another service", "service", conflict.name, "host", host) - return ErrorHostInUse + for _, host := range hosts { + conflict := r.services.CheckHostAvailability(service, host) + if conflict != nil { + slog.Error("Host in use by another service", "service", conflict.name, "host", host) + return ErrorHostInUse + } } + r.services[name] = service r.hostServices = r.services.HostServices() diff --git a/internal/server/router_test.go b/internal/server/router_test.go index b216589..516c726 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -24,7 +24,7 @@ func TestRouter_ActiveServiceForHost(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://dummy.example.com/") @@ -51,7 +51,7 @@ func TestRouter_ActiveServiceForUnknownHost(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, _ := sendGETRequest(router, "http://other.example.com/") @@ -62,7 +62,7 @@ func TestRouter_ActiveServiceForHostContainingPort(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://dummy.example.com:80/") @@ -74,7 +74,7 @@ func TestRouter_ActiveServiceWithoutHost(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "", target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{""}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://dummy.example.com/") @@ -87,14 +87,14 @@ func TestRouter_ReplacingActiveService(t *testing.T) { _, first := testBackend(t, "first", http.StatusOK) _, second := testBackend(t, "second", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://dummy.example.com/") assert.Equal(t, http.StatusOK, statusCode) assert.Equal(t, "first", body) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body = sendGETRequest(router, "http://dummy.example.com/") @@ -111,21 +111,21 @@ func TestRouter_UpdatingOptions(t *testing.T) { targetOptions.BufferRequests = true targetOptions.MaxRequestBodySize = 10 - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, _ := sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusRequestEntityTooLarge, statusCode) targetOptions.BufferRequests = false targetOptions.MaxRequestBodySize = 0 - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusOK, statusCode) assert.Equal(t, "first", body) serviceOptions.TLSHostnames = []string{"dummy.example.com"} - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusMovedPermanently, statusCode) @@ -136,13 +136,13 @@ func TestRouter_UpdatingPauseStateIndependentlyOfDeployments(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) router.PauseService("service1", time.Second, time.Millisecond*10) statusCode, _ := sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusGatewayTimeout, statusCode) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, _ = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) assert.Equal(t, http.StatusGatewayTimeout, statusCode) @@ -158,14 +158,14 @@ func TestRouter_ChangingHostForService(t *testing.T) { _, first := testBackend(t, "first", http.StatusOK) _, second := testBackend(t, "second", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://dummy.example.com/") assert.Equal(t, http.StatusOK, statusCode) assert.Equal(t, "first", body) - require.NoError(t, router.SetServiceTarget("service1", "dummy2.example.com", second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy2.example.com"}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body = sendGETRequest(router, "http://dummy2.example.com/") @@ -181,8 +181,8 @@ func TestRouter_ReusingHost(t *testing.T) { _, first := testBackend(t, "first", http.StatusOK) _, second := testBackend(t, "second", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "dummy.example.com", first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) - err := router.SetServiceTarget("service12", "dummy.example.com", second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout) + require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + err := router.SetServiceTarget("service12", []string{"dummy.example.com"}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout) require.EqualError(t, err, "host is used by another service", "Error message does not match expected one") @@ -197,8 +197,8 @@ func TestRouter_RoutingMultipleHosts(t *testing.T) { _, first := testBackend(t, "first", http.StatusOK) _, second := testBackend(t, "second", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "s1.example.com", first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) - require.NoError(t, router.SetServiceTarget("service2", "s2.example.com", second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"s1.example.com"}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service2", []string{"s2.example.com"}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://s1.example.com/") assert.Equal(t, http.StatusOK, statusCode) @@ -214,8 +214,8 @@ func TestRouter_TargetWithoutHostActsAsWildcard(t *testing.T) { _, first := testBackend(t, "first", http.StatusOK) _, second := testBackend(t, "second", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "s1.example.com", first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) - require.NoError(t, router.SetServiceTarget("default", "", second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{"s1.example.com"}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("default", []string{""}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://s1.example.com/") assert.Equal(t, http.StatusOK, statusCode) @@ -234,7 +234,7 @@ func TestRouter_ServiceFailingToBecomeHealthy(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "", http.StatusInternalServerError) - err := router.SetServiceTarget("example", "example.com", target, defaultServiceOptions, defaultTargetOptions, time.Millisecond*20, DefaultDrainTimeout) + err := router.SetServiceTarget("example", []string{"example.com"}, target, defaultServiceOptions, defaultTargetOptions, time.Millisecond*20, DefaultDrainTimeout) assert.Equal(t, ErrorTargetFailedToBecomeHealthy, err) statusCode, _ := sendGETRequest(router, "http://example.com/") @@ -247,7 +247,7 @@ func TestRouter_EnablingRollout(t *testing.T) { _, first := testBackend(t, "first", http.StatusOK) _, second := testBackend(t, "second", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", "", first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", []string{""}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) require.NoError(t, router.SetRolloutTarget("service1", second, DefaultDeployTimeout, DefaultDrainTimeout)) checkResponse := func(expected string) { @@ -277,8 +277,8 @@ func TestRouter_RestoreLastSavedState(t *testing.T) { _, second := testBackend(t, "second", http.StatusOK) router := NewRouter(statePath) - require.NoError(t, router.SetServiceTarget("default", "", first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) - require.NoError(t, router.SetServiceTarget("other", "other.example.com", second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("default", []string{""}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("other", []string{"other.example.com"}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://something.example.com") assert.Equal(t, http.StatusOK, statusCode) From b05085782277b020e9163dcdc909502445ec1887 Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 09:09:27 +0100 Subject: [PATCH 05/12] Add test for multiple-host routing --- internal/server/router_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/internal/server/router_test.go b/internal/server/router_test.go index 516c726..1f2fd0b 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -47,6 +47,24 @@ func TestRouter_Removing(t *testing.T) { assert.Equal(t, http.StatusNotFound, statusCode) } +func TestRouter_ActiveServiceForMultipleHosts(t *testing.T) { + router := testRouter(t) + _, target := testBackend(t, "first", http.StatusOK) + + require.NoError(t, router.SetServiceTarget("service1", []string{"1.example.com", "2.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + + statusCode, body := sendGETRequest(router, "http://1.example.com/") + assert.Equal(t, http.StatusOK, statusCode) + assert.Equal(t, "first", body) + + statusCode, body = sendGETRequest(router, "http://2.example.com/") + assert.Equal(t, http.StatusOK, statusCode) + assert.Equal(t, "first", body) + + statusCode, _ = sendGETRequest(router, "http://3.example.com/") + assert.Equal(t, http.StatusNotFound, statusCode) +} + func TestRouter_ActiveServiceForUnknownHost(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) From c29e8ba596bd5ebc738c9479ba6b32051b8af1bd Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 11:04:35 +0100 Subject: [PATCH 06/12] Don't duplicate TLS hosts in config When TLS is enabled, the hostnames it uses always match the service's hosts. So there's no need to duplicate them. We can instead just have a flag to note whether TLS is enabled for the service. --- internal/cmd/deploy.go | 14 ++++++-------- internal/server/router.go | 4 ++-- internal/server/router_test.go | 2 +- internal/server/service.go | 20 ++++++++------------ internal/server/service_test.go | 18 +++++++++--------- internal/server/testing.go | 1 + 6 files changed, 27 insertions(+), 32 deletions(-) diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index 1f7035f..8524ded 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -13,7 +13,6 @@ type deployCommand struct { cmd *cobra.Command args server.DeployArgs - tls bool tlsStaging bool } @@ -31,7 +30,7 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().StringVar(&deployCommand.args.TargetURL, "target", "", "Target host to deploy") deployCommand.cmd.Flags().StringSliceVar(&deployCommand.args.Hosts, "host", []string{}, "Host(s) to serve this target on (empty for wildcard)") - deployCommand.cmd.Flags().BoolVar(&deployCommand.tls, "tls", false, "Configure TLS for this target (requires a non-empty host)") + deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)") deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning") deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy") @@ -62,13 +61,12 @@ func newDeployCommand() *deployCommand { func (c *deployCommand) run(cmd *cobra.Command, args []string) error { c.args.Service = args[0] - if c.tls { + if c.args.ServiceOptions.TLSEnabled { c.args.ServiceOptions.ACMECachePath = globalConfig.CertificatePath() - c.args.ServiceOptions.TLSHostnames = c.args.Hosts - } - if c.tlsStaging { - c.args.ServiceOptions.ACMEDirectory = server.ACMEStagingDirectoryURL + if c.tlsStaging { + c.args.ServiceOptions.ACMEDirectory = server.ACMEStagingDirectoryURL + } } return withRPCClient(globalConfig.SocketPath(), func(client *rpc.Client) error { @@ -91,7 +89,7 @@ func (c *deployCommand) preRun(cmd *cobra.Command, args []string) error { } if !cmd.Flags().Changed("forward-headers") { - c.args.TargetOptions.ForwardHeaders = !c.tls + c.args.TargetOptions.ForwardHeaders = !c.args.ServiceOptions.TLSEnabled } return nil diff --git a/internal/server/router.go b/internal/server/router.go index adce8cc..8f3ef3c 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -121,7 +121,7 @@ func (r *Router) SetServiceTarget(name string, hosts []string, targetURL string, ) error { defer r.saveStateSnapshot() - slog.Info("Deploying", "service", name, "hosts", hosts, "target", targetURL, "tls", options.RequireTLS()) + slog.Info("Deploying", "service", name, "hosts", hosts, "target", targetURL, "tls", options.TLSEnabled) target, err := NewTarget(targetURL, targetOptions) if err != nil { @@ -261,7 +261,7 @@ func (r *Router) ListActiveServices() ServiceDescriptionMap { result[name] = ServiceDescription{ Host: host, Target: service.active.Target(), - TLS: service.options.RequireTLS(), + TLS: service.options.TLSEnabled, State: service.pauseController.GetState().String(), } } diff --git a/internal/server/router_test.go b/internal/server/router_test.go index 1f2fd0b..508ec53 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -142,7 +142,7 @@ func TestRouter_UpdatingOptions(t *testing.T) { assert.Equal(t, http.StatusOK, statusCode) assert.Equal(t, "first", body) - serviceOptions.TLSHostnames = []string{"dummy.example.com"} + serviceOptions.TLSEnabled = true require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, target, serviceOptions, targetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body = sendRequest(router, httptest.NewRequest(http.MethodPost, "http://dummy.example.com", strings.NewReader("Something longer than 10"))) diff --git a/internal/server/service.go b/internal/server/service.go index 9bba75c..cce833b 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -60,14 +60,10 @@ type HealthCheckConfig struct { } type ServiceOptions struct { - TLSHostnames []string `json:"tls_hostnames"` - ACMEDirectory string `json:"acme_directory"` - ACMECachePath string `json:"acme_cache_path"` - ErrorPagePath string `json:"error_page_path"` -} - -func (so ServiceOptions) RequireTLS() bool { - return len(so.TLSHostnames) > 0 + TLSEnabled bool `json:"tls_enabled"` + ACMEDirectory string `json:"acme_directory"` + ACMECachePath string `json:"acme_cache_path"` + ErrorPagePath string `json:"error_page_path"` } func (so ServiceOptions) ScopedCachePath() string { @@ -289,14 +285,14 @@ func (s *Service) initialize() { } func (s *Service) createCertManager() *autocert.Manager { - if !s.options.RequireTLS() { + if !s.options.TLSEnabled { return nil } return &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(s.options.ScopedCachePath()), - HostPolicy: autocert.HostWhitelist(s.options.TLSHostnames...), + HostPolicy: autocert.HostWhitelist(s.hosts...), Client: &acme.Client{DirectoryURL: s.options.ACMEDirectory}, } } @@ -314,12 +310,12 @@ func (s *Service) createMiddleware() http.Handler { func (s *Service) serviceRequestWithTarget(w http.ResponseWriter, r *http.Request) { LoggingRequestContext(r).Service = s.name - if s.options.RequireTLS() && r.TLS == nil { + if s.options.TLSEnabled && r.TLS == nil { s.redirectToHTTPS(w, r) return } - if !s.options.RequireTLS() && r.TLS != nil { + if !s.options.TLSEnabled && r.TLS != nil { SetErrorResponse(w, r, http.StatusServiceUnavailable, nil) return } diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 9e3a9c6..8ae8b0c 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -15,7 +15,7 @@ import ( ) func TestService_ServeRequest(t *testing.T) { - service := testCreateService(t, defaultServiceOptions, defaultTargetOptions) + service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, defaultTargetOptions) req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader("")) w := httptest.NewRecorder() @@ -25,9 +25,9 @@ func TestService_ServeRequest(t *testing.T) { } func TestService_RedirectToHTTPWhenTLSRequired(t *testing.T) { - service := testCreateService(t, ServiceOptions{TLSHostnames: []string{"example.com"}}, defaultTargetOptions) + service := testCreateService(t, []string{"example.com"}, ServiceOptions{TLSEnabled: true}, defaultTargetOptions) - require.True(t, service.options.RequireTLS()) + require.True(t, service.options.TLSEnabled) req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) w := httptest.NewRecorder() @@ -43,9 +43,9 @@ func TestService_RedirectToHTTPWhenTLSRequired(t *testing.T) { } func TestService_RejectTLSRequestsWhenNotConfigured(t *testing.T) { - service := testCreateService(t, defaultServiceOptions, defaultTargetOptions) + service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, defaultTargetOptions) - require.False(t, service.options.RequireTLS()) + require.False(t, service.options.TLSEnabled) req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) w := httptest.NewRecorder() @@ -61,7 +61,7 @@ func TestService_RejectTLSRequestsWhenNotConfigured(t *testing.T) { } func TestService_ReturnSuccessfulHealthCheckWhilePausedOrStopped(t *testing.T) { - service := testCreateService(t, defaultServiceOptions, defaultTargetOptions) + service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, defaultTargetOptions) checkRequest := func(path string) int { req := httptest.NewRequest(http.MethodGet, path, nil) @@ -93,7 +93,7 @@ func TestService_MarshallingState(t *testing.T) { MaxMemoryBufferSize: 123, } - service := testCreateService(t, defaultServiceOptions, targetOptions) + service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, targetOptions) require.NoError(t, service.Stop(time.Second, DefaultStopMessage)) service.SetTarget(TargetSlotRollout, service.active, time.Millisecond) require.NoError(t, service.SetRolloutSplit(20, []string{"first"})) @@ -117,7 +117,7 @@ func TestService_MarshallingState(t *testing.T) { assert.Equal(t, []string{"first"}, service2.rolloutController.Allowlist) } -func testCreateService(t *testing.T, options ServiceOptions, targetOptions TargetOptions) *Service { +func testCreateService(t *testing.T, hosts []string, options ServiceOptions, targetOptions TargetOptions) *Service { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) t.Cleanup(server.Close) @@ -127,7 +127,7 @@ func testCreateService(t *testing.T, options ServiceOptions, targetOptions Targe target, err := NewTarget(serverURL.Host, targetOptions) require.NoError(t, err) - service := NewService("test", []string{""}, options) + service := NewService("test", hosts, options) service.active = target return service diff --git a/internal/server/testing.go b/internal/server/testing.go index 84183bb..de5c88b 100644 --- a/internal/server/testing.go +++ b/internal/server/testing.go @@ -14,6 +14,7 @@ import ( var ( defaultHealthCheckConfig = HealthCheckConfig{Path: DefaultHealthCheckPath, Interval: DefaultHealthCheckInterval, Timeout: DefaultHealthCheckTimeout} + defaultEmptyHosts = []string{} defaultServiceOptions = ServiceOptions{} defaultTargetOptions = TargetOptions{HealthCheckConfig: defaultHealthCheckConfig, ResponseTimeout: DefaultTargetTimeout} ) From 95c22bee9ba766230b6116c1d543cc3e4b0f169a Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 12:49:25 +0100 Subject: [PATCH 07/12] Include TLS in saved state test --- internal/server/router_test.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/internal/server/router_test.go b/internal/server/router_test.go index 508ec53..d84b8f3 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -296,15 +296,14 @@ func TestRouter_RestoreLastSavedState(t *testing.T) { router := NewRouter(statePath) require.NoError(t, router.SetServiceTarget("default", []string{""}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) - require.NoError(t, router.SetServiceTarget("other", []string{"other.example.com"}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("other", []string{"other.example.com"}, second, ServiceOptions{TLSEnabled: true}, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://something.example.com") assert.Equal(t, http.StatusOK, statusCode) assert.Equal(t, "first", body) - statusCode, body = sendGETRequest(router, "http://other.example.com/") - assert.Equal(t, http.StatusOK, statusCode) - assert.Equal(t, "second", body) + statusCode, _ = sendGETRequest(router, "http://other.example.com/") + assert.Equal(t, http.StatusMovedPermanently, statusCode) router = NewRouter(statePath) router.RestoreLastSavedState() @@ -313,9 +312,8 @@ func TestRouter_RestoreLastSavedState(t *testing.T) { assert.Equal(t, http.StatusOK, statusCode) assert.Equal(t, "first", body) - statusCode, body = sendGETRequest(router, "http://other.example.com/") - assert.Equal(t, http.StatusOK, statusCode) - assert.Equal(t, "second", body) + statusCode, _ = sendGETRequest(router, "http://other.example.com/") + assert.Equal(t, http.StatusMovedPermanently, statusCode) } // Helpers From 5f2961ac2d0edc2e9a910b5f9dffb40f9577f47f Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 15:58:46 -0400 Subject: [PATCH 08/12] Remove unnecessary whitespace --- internal/cmd/deploy.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index 8524ded..bb23f56 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -10,9 +10,8 @@ import ( ) type deployCommand struct { - cmd *cobra.Command - args server.DeployArgs - + cmd *cobra.Command + args server.DeployArgs tlsStaging bool } From 6a2770f9e2ec133a2eaf67ec4512a874f0c7beca Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 16:16:30 -0400 Subject: [PATCH 09/12] Add test for updating hosts on a service --- internal/server/router_test.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/internal/server/router_test.go b/internal/server/router_test.go index d84b8f3..7ce295c 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -65,6 +65,26 @@ func TestRouter_ActiveServiceForMultipleHosts(t *testing.T) { assert.Equal(t, http.StatusNotFound, statusCode) } +func TestRouter_UpdatingHostsOfActiveService(t *testing.T) { + router := testRouter(t) + _, target := testBackend(t, "first", http.StatusOK) + + require.NoError(t, router.SetServiceTarget("service1", []string{"1.example.com", "2.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + + require.NoError(t, router.SetServiceTarget("service1", []string{"3.example.com", "2.example.com"}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + + statusCode, _ := sendGETRequest(router, "http://1.example.com/") + assert.Equal(t, http.StatusNotFound, statusCode) + + statusCode, body := sendGETRequest(router, "http://2.example.com/") + assert.Equal(t, http.StatusOK, statusCode) + assert.Equal(t, "first", body) + + statusCode, body = sendGETRequest(router, "http://3.example.com/") + assert.Equal(t, http.StatusOK, statusCode) + assert.Equal(t, "first", body) +} + func TestRouter_ActiveServiceForUnknownHost(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) From 026bb6264ac549c84faf1b0cd5ca2ab7d38405b7 Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 16:35:13 -0400 Subject: [PATCH 10/12] Check for host conflict before updating service --- internal/server/router.go | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/internal/server/router.go b/internal/server/router.go index 8f3ef3c..dfcb74b 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "os" - "slices" "strings" "sync" "time" @@ -25,9 +24,9 @@ var ( type ServiceMap map[string]*Service type HostServiceMap map[string]*Service -func (m *ServiceMap) HostServices() HostServiceMap { +func (m ServiceMap) HostServices() HostServiceMap { hostServices := HostServiceMap{} - for _, service := range *m { + for _, service := range m { if len(service.hosts) == 0 { hostServices[""] = service continue @@ -39,10 +38,11 @@ func (m *ServiceMap) HostServices() HostServiceMap { return hostServices } -func (m *ServiceMap) CheckHostAvailability(service *Service, host string) *Service { - for _, s := range *m { - if s != service && slices.Contains(s.hosts, host) { - return s +func (m HostServiceMap) CheckHostAvailability(name string, hosts []string) *Service { + for _, host := range hosts { + service := m[host] + if service != nil && service.name != name { + return service } } return nil @@ -344,6 +344,12 @@ func (r *Router) setActiveTarget(name string, hosts []string, target *Target, op r.serviceLock.Lock() defer r.serviceLock.Unlock() + conflict := r.hostServices.CheckHostAvailability(name, hosts) + if conflict != nil { + slog.Error("Host in use by another service", "service", conflict.name) + return ErrorHostInUse + } + service := r.services[name] if service == nil { service = NewService(name, hosts, options) @@ -351,14 +357,6 @@ func (r *Router) setActiveTarget(name string, hosts []string, target *Target, op service.UpdateOptions(hosts, options) } - for _, host := range hosts { - conflict := r.services.CheckHostAvailability(service, host) - if conflict != nil { - slog.Error("Host in use by another service", "service", conflict.name, "host", host) - return ErrorHostInUse - } - } - r.services[name] = service r.hostServices = r.services.HostServices() From 2e8e753036ea0fbe09649ec6e045f2d858cc9ce0 Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 17:06:29 -0400 Subject: [PATCH 11/12] Ensure wildcard hosts don't conflict --- internal/server/router.go | 8 ++++++-- internal/server/router_test.go | 17 ++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/internal/server/router.go b/internal/server/router.go index dfcb74b..87a1bcd 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -16,7 +16,7 @@ import ( var ( ErrorServiceNotFound = errors.New("service not found") ErrorTargetFailedToBecomeHealthy = errors.New("target failed to become healthy") - ErrorHostInUse = errors.New("host is used by another service") + ErrorHostInUse = errors.New("host settings conflict with another service") ErrorNoServerName = errors.New("no server name provided") ErrorUnknownServerName = errors.New("unknown server name") ) @@ -39,6 +39,10 @@ func (m ServiceMap) HostServices() HostServiceMap { } func (m HostServiceMap) CheckHostAvailability(name string, hosts []string) *Service { + if len(hosts) == 0 { + hosts = []string{""} + } + for _, host := range hosts { service := m[host] if service != nil && service.name != name { @@ -346,7 +350,7 @@ func (r *Router) setActiveTarget(name string, hosts []string, target *Target, op conflict := r.hostServices.CheckHostAvailability(name, hosts) if conflict != nil { - slog.Error("Host in use by another service", "service", conflict.name) + slog.Error("Host settings conflict with another service", "service", conflict.name) return ErrorHostInUse } diff --git a/internal/server/router_test.go b/internal/server/router_test.go index 7ce295c..b8e0066 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -222,7 +222,7 @@ func TestRouter_ReusingHost(t *testing.T) { require.NoError(t, router.SetServiceTarget("service1", []string{"dummy.example.com"}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) err := router.SetServiceTarget("service12", []string{"dummy.example.com"}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout) - require.EqualError(t, err, "host is used by another service", "Error message does not match expected one") + require.Equal(t, ErrorHostInUse, err) statusCode, body := sendGETRequest(router, "http://dummy.example.com/") @@ -230,6 +230,21 @@ func TestRouter_ReusingHost(t *testing.T) { assert.Equal(t, "first", body) } +func TestRouter_ReusingEmptyHost(t *testing.T) { + router := testRouter(t) + _, first := testBackend(t, "first", http.StatusOK) + _, second := testBackend(t, "second", http.StatusOK) + + require.NoError(t, router.SetServiceTarget("service1", defaultEmptyHosts, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + err := router.SetServiceTarget("service12", defaultEmptyHosts, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout) + + require.Equal(t, ErrorHostInUse, err) + + statusCode, body := sendGETRequest(router, "http://anything.example.com/") + assert.Equal(t, http.StatusOK, statusCode) + assert.Equal(t, "first", body) +} + func TestRouter_RoutingMultipleHosts(t *testing.T) { router := testRouter(t) _, first := testBackend(t, "first", http.StatusOK) From 12d99151b88cead89e1a2e9f200bdc45b92c0bd9 Mon Sep 17 00:00:00 2001 From: Kevin McConnell Date: Wed, 25 Sep 2024 17:10:17 -0400 Subject: [PATCH 12/12] Tidy test argument --- internal/server/router_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/server/router_test.go b/internal/server/router_test.go index b8e0066..980b40e 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -112,7 +112,7 @@ func TestRouter_ActiveServiceWithoutHost(t *testing.T) { router := testRouter(t) _, target := testBackend(t, "first", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", []string{""}, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", defaultEmptyHosts, target, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://dummy.example.com/") @@ -268,7 +268,7 @@ func TestRouter_TargetWithoutHostActsAsWildcard(t *testing.T) { _, second := testBackend(t, "second", http.StatusOK) require.NoError(t, router.SetServiceTarget("service1", []string{"s1.example.com"}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) - require.NoError(t, router.SetServiceTarget("default", []string{""}, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("default", defaultEmptyHosts, second, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://s1.example.com/") assert.Equal(t, http.StatusOK, statusCode) @@ -300,7 +300,7 @@ func TestRouter_EnablingRollout(t *testing.T) { _, first := testBackend(t, "first", http.StatusOK) _, second := testBackend(t, "second", http.StatusOK) - require.NoError(t, router.SetServiceTarget("service1", []string{""}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("service1", defaultEmptyHosts, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) require.NoError(t, router.SetRolloutTarget("service1", second, DefaultDeployTimeout, DefaultDrainTimeout)) checkResponse := func(expected string) { @@ -330,7 +330,7 @@ func TestRouter_RestoreLastSavedState(t *testing.T) { _, second := testBackend(t, "second", http.StatusOK) router := NewRouter(statePath) - require.NoError(t, router.SetServiceTarget("default", []string{""}, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) + require.NoError(t, router.SetServiceTarget("default", defaultEmptyHosts, first, defaultServiceOptions, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) require.NoError(t, router.SetServiceTarget("other", []string{"other.example.com"}, second, ServiceOptions{TLSEnabled: true}, defaultTargetOptions, DefaultDeployTimeout, DefaultDrainTimeout)) statusCode, body := sendGETRequest(router, "http://something.example.com")