Skip to content

Commit

Permalink
feat: validate --host flag for IP address (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
abuchanan-airbyte authored Sep 19, 2024
1 parent f5b5265 commit f4db29a
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 5 deletions.
11 changes: 11 additions & 0 deletions internal/cmd/local/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"os"
"regexp"
"runtime"
"strconv"
"syscall"
Expand Down Expand Up @@ -167,3 +168,13 @@ func (e InvalidPortError) Unwrap() error {
func (e InvalidPortError) Error() string {
return fmt.Sprintf("unable to convert host port %s to integer: %s", e.Port, e.Inner)
}

func validateHostFlag(host string) error {
if ip := net.ParseIP(host); ip != nil {
return localerr.ErrIpAddressForHostFlag
}
if !regexp.MustCompile(`^[a-z0-9](?:[-a-z0-9]*[a-z0-9])?(?:\.[a-z0-9](?:[-a-z0-9]*[a-z0-9])?)*$`).MatchString(host) {
return localerr.ErrInvalidHostFlag
}
return nil
}
25 changes: 25 additions & 0 deletions internal/cmd/local/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,31 @@ func TestGetPort_InpsectErr(t *testing.T) {
}
}

func TestValidateHostFlag(t *testing.T) {
expectErr := func(host string, expect error) {
err := validateHostFlag(host)
if !errors.Is(err, expect) {
t.Errorf("expected error %v for host %q but got %v", expect, host, err)
}
}
expectErr("1.2.3.4", localerr.ErrIpAddressForHostFlag)
expectErr("1.2.3.4:8000", localerr.ErrInvalidHostFlag)
expectErr("1.2.3.4:8000", localerr.ErrInvalidHostFlag)
expectErr("ABC-DEF-GHI.abcd.efgh", localerr.ErrInvalidHostFlag)
expectErr("http://airbyte.foo-data-platform-sbx.bar.cloud", localerr.ErrInvalidHostFlag)

expectOk := func(host string) {
err := validateHostFlag(host)
if err != nil {
t.Errorf("unexpected error for host %q: %s", host, err)
}
}
expectOk("foo")
expectOk("foo.bar")
expectOk("example.com")
expectOk("sub.example01.com")
}

// port returns the port from a string value in the format of "ipv4:port" or "ip::v6:port"
func port(s string) int {
vals := strings.Split(s, ":")
Expand Down
17 changes: 12 additions & 5 deletions internal/cmd/local/local_install.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/airbytehq/abctl/internal/cmd/local/k8s"
"github.com/airbytehq/abctl/internal/cmd/local/local"
"github.com/airbytehq/abctl/internal/maps"

"github.com/airbytehq/abctl/internal/telemetry"
"github.com/pterm/pterm"
)
Expand Down Expand Up @@ -45,6 +46,17 @@ func (i *InstallCmd) Run(ctx context.Context, provider k8s.Provider, telClient t
return err
}

extraVolumeMounts, err := parseVolumeMounts(i.Volume)
if err != nil {
return err
}

for _, host := range i.Host {
if err := validateHostFlag(host); err != nil {
return err
}
}

return telClient.Wrap(ctx, telemetry.Install, func() error {
spinner.UpdateText(fmt.Sprintf("Checking for existing Kubernetes cluster '%s'", provider.ClusterName))

Expand Down Expand Up @@ -77,11 +89,6 @@ func (i *InstallCmd) Run(ctx context.Context, provider k8s.Provider, telClient t
// no existing cluster, need to create one
pterm.Info.Println(fmt.Sprintf("No existing cluster found, cluster '%s' will be created", provider.ClusterName))

extraVolumeMounts, err := parseVolumeMounts(i.Volume)
if err != nil {
return err
}

spinner.UpdateText(fmt.Sprintf("Checking if port %d is available", i.Port))
if err := portAvailable(ctx, i.Port); err != nil {
return err
Expand Down
21 changes: 21 additions & 0 deletions internal/cmd/local/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ package local

import (
"context"
"errors"

"os"
"path/filepath"
"strings"
"testing"

"github.com/airbytehq/abctl/internal/cmd/local/k8s"
"github.com/airbytehq/abctl/internal/cmd/local/localerr"

"github.com/airbytehq/abctl/internal/cmd/local/paths"
"github.com/airbytehq/abctl/internal/telemetry"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -127,5 +131,22 @@ foo:

if !strings.HasPrefix(err.Error(), "failed to unmarshal file") {
t.Errorf("unexpected error: %v", err)

}
}

func TestInvalidHostFlag_IpAddr(t *testing.T) {
cmd := InstallCmd{Host: []string{"ok", "1.2.3.4"}}
err := cmd.Run(context.Background(), k8s.TestProvider, telemetry.NoopClient{})
if !errors.Is(err, localerr.ErrIpAddressForHostFlag) {
t.Errorf("expected ErrIpAddressForHostFlag but got %v", err)
}
}

func TestInvalidHostFlag_IpAddrWithPort(t *testing.T) {
cmd := InstallCmd{Host: []string{"ok", "1.2.3.4:8000"}}
err := cmd.Run(context.Background(), k8s.TestProvider, telemetry.NoopClient{})
if !errors.Is(err, localerr.ErrInvalidHostFlag) {
t.Errorf("expected ErrInvalidHostFlag but got %v", err)
}
}
16 changes: 16 additions & 0 deletions internal/cmd/local/localerr/localerr.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,20 @@ The ingress port can be changed by passing the flag --port.`,
This could be in indication that the ingress port is already in use by a different application.
The ingress port can be changed by passing the flag --port.`,
}

ErrIpAddressForHostFlag = &LocalError{
msg: "invalid host - can't use an IP address",
help: `Looks like you provided an IP address to the --host flag.
This won't work, because Kubernetes ingress rules require a lowercase domain name.
By default, abctl will allow access from any hostname or IP, so you might not need the --host flag.`,
}

ErrInvalidHostFlag = &LocalError{
msg: "invalid host",
help: `The --host flag expects a lowercase domain name, e.g. "example.com".
IP addresses won't work. Ports won't work (e.g. example:8000). URLs won't work (e.g. http://example.com).
By default, abctl will allow access from any hostname or IP, so you might not need the --host flag.`,
}
)

0 comments on commit f4db29a

Please sign in to comment.