Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple hosts per service #24

Merged
merged 12 commits into from
Sep 25, 2024
21 changes: 9 additions & 12 deletions internal/cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import (
)

type deployCommand struct {
cmd *cobra.Command
args server.DeployArgs

tls bool
cmd *cobra.Command
args server.DeployArgs
tlsStaging bool
}

Expand All @@ -29,9 +27,9 @@ func newDeployCommand() *deployCommand {
}

deployCommand.cmd.Flags().StringVar(&deployCommand.args.TargetURL, "target", "", "Target host to deploy")
deployCommand.cmd.Flags().StringVar(&deployCommand.args.Host, "host", "", "Host to serve this target on (empty for wildcard)")
deployCommand.cmd.Flags().StringSliceVar(&deployCommand.args.Hosts, "host", []string{}, "Host(s) to serve this target on (empty for wildcard)")

deployCommand.cmd.Flags().BoolVar(&deployCommand.tls, "tls", false, "Configure TLS for this target (requires a non-empty host)")
deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)")
deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning")

deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy")
Expand Down Expand Up @@ -62,13 +60,12 @@ func newDeployCommand() *deployCommand {
func (c *deployCommand) run(cmd *cobra.Command, args []string) error {
c.args.Service = args[0]

if c.tls {
if c.args.ServiceOptions.TLSEnabled {
c.args.ServiceOptions.ACMECachePath = globalConfig.CertificatePath()
c.args.ServiceOptions.TLSHostname = c.args.Host
}

if c.tlsStaging {
c.args.ServiceOptions.ACMEDirectory = server.ACMEStagingDirectoryURL
if c.tlsStaging {
c.args.ServiceOptions.ACMEDirectory = server.ACMEStagingDirectoryURL
}
}

return withRPCClient(globalConfig.SocketPath(), func(client *rpc.Client) error {
Expand All @@ -91,7 +88,7 @@ func (c *deployCommand) preRun(cmd *cobra.Command, args []string) error {
}

if !cmd.Flags().Changed("forward-headers") {
c.args.TargetOptions.ForwardHeaders = !c.tls
c.args.TargetOptions.ForwardHeaders = !c.args.ServiceOptions.TLSEnabled
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions internal/server/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ type CommandHandler struct {

type DeployArgs struct {
Service string
Host string
TargetURL string
Hosts []string
DeployTimeout time.Duration
DrainTimeout time.Duration
ServiceOptions ServiceOptions
Expand Down Expand Up @@ -114,7 +114,7 @@ func (h *CommandHandler) Close() error {
}

func (h *CommandHandler) Deploy(args DeployArgs, reply *bool) error {
return h.router.SetServiceTarget(args.Service, args.Host, args.TargetURL, args.ServiceOptions, args.TargetOptions, args.DeployTimeout, args.DrainTimeout)
return h.router.SetServiceTarget(args.Service, args.Hosts, args.TargetURL, args.ServiceOptions, args.TargetOptions, args.DeployTimeout, args.DrainTimeout)
}

func (h *CommandHandler) Pause(args PauseArgs, reply *bool) error {
Expand Down
132 changes: 78 additions & 54 deletions internal/server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,55 @@ import (
"net"
"net/http"
"os"
"strings"
"sync"
"time"
)

var (
ErrorServiceNotFound = errors.New("service not found")
ErrorTargetFailedToBecomeHealthy = errors.New("target failed to become healthy")
ErrorHostInUse = errors.New("host is used by another service")
ErrorHostInUse = errors.New("host settings conflict with another service")
ErrorNoServerName = errors.New("no server name provided")
ErrorUnknownServerName = errors.New("unknown server name")
)

type ServiceMap map[string]*Service
type HostServiceMap map[string]*Service

func (m ServiceMap) HostServices() HostServiceMap {
hostServices := HostServiceMap{}
for _, service := range m {
if len(service.hosts) == 0 {
hostServices[""] = service
continue
}
for _, host := range service.hosts {
hostServices[host] = service
}
}
return hostServices
}

func (m HostServiceMap) CheckHostAvailability(name string, hosts []string) *Service {
if len(hosts) == 0 {
hosts = []string{""}
}

for _, host := range hosts {
service := m[host]
if service != nil && service.name != name {
return service
}
}
return nil
}

type Router struct {
statePath string
services HostServiceMap
serviceLock sync.RWMutex
statePath string
services ServiceMap
hostServices HostServiceMap
serviceLock sync.RWMutex
}

type ServiceDescription struct {
Expand All @@ -39,8 +70,9 @@ type ServiceDescriptionMap map[string]ServiceDescription

func NewRouter(statePath string) *Router {
return &Router{
statePath: statePath,
services: HostServiceMap{},
statePath: statePath,
services: ServiceMap{},
hostServices: HostServiceMap{},
}
}

Expand All @@ -64,10 +96,12 @@ func (r *Router) RestoreLastSavedState() error {
}

r.withWriteLock(func() error {
r.services = HostServiceMap{}
r.services = ServiceMap{}
for _, service := range services {
r.services[service.host] = service
r.services[service.name] = service
}

r.hostServices = r.services.HostServices()
return nil
})

Expand All @@ -85,13 +119,13 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
service.ServeHTTP(w, req)
}

func (r *Router) SetServiceTarget(name string, host string, targetURL string,
func (r *Router) SetServiceTarget(name string, hosts []string, targetURL string,
options ServiceOptions, targetOptions TargetOptions,
deployTimeout time.Duration, drainTimeout time.Duration,
) error {
defer r.saveStateSnapshot()

slog.Info("Deploying", "service", name, "host", host, "target", targetURL, "tls", options.RequireTLS())
slog.Info("Deploying", "service", name, "hosts", hosts, "target", targetURL, "tls", options.TLSEnabled)

target, err := NewTarget(targetURL, targetOptions)
if err != nil {
Expand All @@ -100,16 +134,16 @@ func (r *Router) SetServiceTarget(name string, host string, targetURL string,

becameHealthy := target.WaitUntilHealthy(deployTimeout)
if !becameHealthy {
slog.Info("Target failed to become healthy", "host", host, "target", targetURL)
slog.Info("Target failed to become healthy", "hosts", hosts, "target", targetURL)
return ErrorTargetFailedToBecomeHealthy
}

err = r.setActiveTarget(name, host, target, options, drainTimeout)
err = r.setActiveTarget(name, hosts, target, options, drainTimeout)
if err != nil {
return err
}

slog.Info("Deployed", "service", name, "host", host, "target", targetURL)
slog.Info("Deployed", "service", name, "hosts", hosts, "target", targetURL)
return nil
}

Expand All @@ -118,7 +152,7 @@ func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout t

slog.Info("Deploying for rollout", "service", name, "target", targetURL)

service := r.serviceForName(name, true)
service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}
Expand All @@ -144,7 +178,7 @@ func (r *Router) SetRolloutTarget(name string, targetURL string, deployTimeout t
func (r *Router) SetRolloutSplit(name string, percent int, allowList []string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name, true)
service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}
Expand All @@ -155,7 +189,7 @@ func (r *Router) SetRolloutSplit(name string, percent int, allowList []string) e
func (r *Router) StopRollout(name string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name, true)
service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}
Expand All @@ -167,13 +201,14 @@ func (r *Router) RemoveService(name string) error {
defer r.saveStateSnapshot()

err := r.withWriteLock(func() error {
service := r.serviceForName(name, false)
service := r.services[name]
if service == nil {
return ErrorServiceNotFound
}

service.SetTarget(TargetSlotActive, nil, DefaultDrainTimeout)
delete(r.services, service.host)
delete(r.services, service.name)
r.hostServices = r.services.HostServices()

return nil
})
Expand All @@ -187,7 +222,7 @@ func (r *Router) RemoveService(name string) error {
func (r *Router) PauseService(name string, drainTimeout time.Duration, pauseTimeout time.Duration) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name, true)
service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}
Expand All @@ -198,7 +233,7 @@ func (r *Router) PauseService(name string, drainTimeout time.Duration, pauseTime
func (r *Router) StopService(name string, drainTimeout time.Duration, message string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name, true)
service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}
Expand All @@ -209,7 +244,7 @@ func (r *Router) StopService(name string, drainTimeout time.Duration, message st
func (r *Router) ResumeService(name string) error {
defer r.saveStateSnapshot()

service := r.serviceForName(name, true)
service := r.serviceForName(name)
if service == nil {
return ErrorServiceNotFound
}
Expand All @@ -221,15 +256,16 @@ func (r *Router) ListActiveServices() ServiceDescriptionMap {
result := ServiceDescriptionMap{}

r.withReadLock(func() error {
for host, service := range r.services {
for name, service := range r.services {
host := strings.Join(service.hosts, ",")
if host == "" {
host = "*"
}
if service.active != nil {
result[service.name] = ServiceDescription{
result[name] = ServiceDescription{
Host: host,
Target: service.active.Target(),
TLS: service.options.RequireTLS(),
TLS: service.options.TLSEnabled,
State: service.pauseController.GetState().String(),
}
}
Expand Down Expand Up @@ -300,56 +336,44 @@ func (r *Router) serviceForHost(host string) *Service {
r.serviceLock.RLock()
defer r.serviceLock.RUnlock()

service, ok := r.services[host]
service, ok := r.hostServices[host]
if !ok {
service = r.services[""]
service = r.hostServices[""]
}

return service
}

func (r *Router) setActiveTarget(name string, host string, target *Target, options ServiceOptions, drainTimeout time.Duration) error {
func (r *Router) setActiveTarget(name string, hosts []string, target *Target, options ServiceOptions, drainTimeout time.Duration) error {
r.serviceLock.Lock()
defer r.serviceLock.Unlock()

service := r.serviceForName(name, false)
conflict := r.hostServices.CheckHostAvailability(name, hosts)
if conflict != nil {
slog.Error("Host settings conflict with another service", "service", conflict.name)
return ErrorHostInUse
}

service := r.services[name]
if service == nil {
service = NewService(name, host, options)
service = NewService(name, hosts, options)
} else {
service.UpdateOptions(options)
service.UpdateOptions(hosts, options)
}

hostService, ok := r.services[host]
if !ok {
if host != service.host {
delete(r.services, service.host)
service.host = host
}

r.services[host] = service
} else if hostService != service {
slog.Error("Host in use by another service", "service", hostService.name, "host", host)
return ErrorHostInUse
}
r.services[name] = service
r.hostServices = r.services.HostServices()

service.SetTarget(TargetSlotActive, target, drainTimeout)

return nil
}

func (r *Router) serviceForName(name string, readLock bool) *Service {
if readLock {
r.serviceLock.RLock()
defer r.serviceLock.RUnlock()
}

for _, service := range r.services {
if name == service.name {
return service
}
}
func (r *Router) serviceForName(name string) *Service {
r.serviceLock.RLock()
defer r.serviceLock.RUnlock()

return nil
return r.services[name]
}

func (r *Router) withReadLock(fn func() error) error {
Expand Down
Loading