diff --git a/cmd/portpatrol/main.go b/cmd/portpatrol/main.go index 678602c..9f1a475 100644 --- a/cmd/portpatrol/main.go +++ b/cmd/portpatrol/main.go @@ -14,7 +14,7 @@ import ( "github.com/containeroo/portpatrol/internal/runner" ) -const version = "0.4.4" +const version = "0.4.5" // run is the main function of the application func run(ctx context.Context, getEnv func(string) string, output io.Writer) error { diff --git a/cmd/portpatrol/main_test.go b/cmd/portpatrol/main_test.go index 294f3e8..60a551f 100644 --- a/cmd/portpatrol/main_test.go +++ b/cmd/portpatrol/main_test.go @@ -206,7 +206,7 @@ func TestRun(t *testing.T) { t.Error("Expected error, got none") } - expected := "configuration error: unsupported check type: invalid" + expected := "configuration error: invalid check type from environment: unsupported check type: invalid" if err.Error() != expected { t.Errorf("Expected error to contain %q, got %q", expected, err.Error()) } @@ -237,7 +237,7 @@ func TestRun(t *testing.T) { t.Error("Expected error, got none") } - expected := "configuration error: could not infer check type for address htp://localhost:8080: unsupported scheme: htp" + expected := "configuration error: could not infer check type from address htp://localhost:8080: unsupported check type: htp" if err.Error() != expected { t.Errorf("Expected error to contain %q, got %q", expected, err.Error()) } diff --git a/internal/checker/checker.go b/internal/checker/checker.go index b60531d..b6912b4 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -7,11 +7,18 @@ import ( "time" ) -// SupportedCheckTypes maps check types to their supported schemes. -var SupportedCheckTypes = map[string][]string{ - "http": {"http", "https"}, - "tcp": {"tcp"}, - "icmp": {"icmp"}, +// CheckType is an enumeration that represents the type of check being performed. +type CheckType int + +const ( + TCP CheckType = iota // TCP represents a check over the TCP protocol. + HTTP // HTTP represents a check over the HTTP protocol. + ICMP // ICMP represents a check using the ICMP protocol (ping). +) + +// String returns the string representation of the CheckType. +func (c CheckType) String() string { + return [...]string{"TCP", "HTTP", "ICMP"}[c] } // Checker is an interface that defines methods to perform a check. @@ -21,47 +28,29 @@ type Checker interface { } // Factory function that returns the appropriate Checker based on checkType -func NewChecker(checkType, name, address string, timeout time.Duration, getEnv func(string) string) (Checker, error) { +func NewChecker(checkType CheckType, name, address string, timeout time.Duration, getEnv func(string) string) (Checker, error) { switch checkType { - case "http", "https": - // HTTP and HTTPS checkers may need environment variables for proxy settings, etc. + case HTTP: // HTTP and HTTPS checkers may need environment variables for proxy settings, etc. return NewHTTPChecker(name, address, timeout, getEnv) - case "tcp": - // TCP checkers may not need environment variables + case TCP: // TCP checkers may not need environment variables return NewTCPChecker(name, address, timeout) - case "icmp": - // ICMP checkers may have a different timeout logic + case ICMP: // ICMP checkers may have a different timeout logic return NewICMPChecker(name, address, timeout, getEnv) default: - return nil, fmt.Errorf("unsupported check type: %s", checkType) + return nil, fmt.Errorf("unsupported check type: %d", checkType) } } -// IsValidCheckType validates if the check type is supported. -func IsValidCheckType(checkType string) bool { - _, exists := SupportedCheckTypes[checkType] - - return exists -} - -// InferCheckType infers the check type based on the scheme of the target address. -// It returns an empty string and no error if no scheme is provided. -// If an unsupported scheme is provided, it returns an error. -func InferCheckType(address string) (string, error) { - scheme, _ := extractScheme(address) - if scheme == "" { - return "", nil - } - - scheme = strings.ToLower(scheme) // Normalize the scheme to lowercase - - for checkType, schemes := range SupportedCheckTypes { - for _, s := range schemes { - if s == scheme { - return checkType, nil - } - } +// GetCheckTypeFromString converts a string to a CheckType enum. +func GetCheckTypeFromString(checkTypeStr string) (CheckType, error) { + switch strings.ToLower(checkTypeStr) { + case "http", "https": + return HTTP, nil + case "tcp": + return TCP, nil + case "icmp": + return ICMP, nil + default: + return -1, fmt.Errorf("unsupported check type: %s", checkTypeStr) } - - return "", fmt.Errorf("unsupported scheme: %s", scheme) } diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index f4ccdfe..63ae753 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -11,7 +11,7 @@ func TestNewChecker(t *testing.T) { t.Run("Valid HTTP checker", func(t *testing.T) { t.Parallel() - check, err := NewChecker("http", "example", "http://example.com", 5*time.Second, func(s string) string { + check, err := NewChecker(HTTP, "example", "http://example.com", 5*time.Second, func(s string) string { return "" }) if err != nil { @@ -27,7 +27,7 @@ func TestNewChecker(t *testing.T) { t.Run("Valid TCP checker", func(t *testing.T) { t.Parallel() - check, err := NewChecker("tcp", "example", "example.com:80", 5*time.Second, func(s string) string { + check, err := NewChecker(TCP, "example", "example.com:80", 5*time.Second, func(s string) string { return "" }) if err != nil { @@ -43,7 +43,7 @@ func TestNewChecker(t *testing.T) { t.Run("Valid ICMP checker", func(t *testing.T) { t.Parallel() - check, err := NewChecker("icmp", "example", "example.com", 5*time.Second, func(s string) string { + check, err := NewChecker(ICMP, "example", "example.com", 5*time.Second, func(s string) string { return "" }) if err != nil { @@ -59,130 +59,69 @@ func TestNewChecker(t *testing.T) { t.Run("Invalid checker type", func(t *testing.T) { t.Parallel() - _, err := NewChecker("invalid", "example", "example.com", 5*time.Second, func(s string) string { + _, err := NewChecker(8, "example", "example.com", 5*time.Second, func(s string) string { return "" }) if err == nil { t.Fatal("expected an error, got none") } - expected := "unsupported check type: invalid" + expected := "unsupported check type: 8" if err.Error() != expected { t.Errorf("expected error to be %q, got %q", expected, err.Error()) } }) } -func TestIsValidCheckType(t *testing.T) { +func TestGetCheckTypeString(t *testing.T) { t.Parallel() - t.Run("Valid TCP Check Type", func(t *testing.T) { + t.Run("Check type string (enum)", func(t *testing.T) { t.Parallel() - if isValid := IsValidCheckType("tcp"); !isValid { - t.Errorf("expected true for check type 'tcp', got false") + if HTTP.String() != "HTTP" { + t.Fatalf("expected 'HTTP', got %q", HTTP.String()) } - }) - - t.Run("Valid HTTP Check Type", func(t *testing.T) { - t.Parallel() - - if isValid := IsValidCheckType("http"); !isValid { - t.Errorf("expected true for check type 'http', got false") - } - }) - - t.Run("Invalid Check Type", func(t *testing.T) { - t.Parallel() - - if isValid := IsValidCheckType("invalid"); isValid { - t.Errorf("expected false for check type 'invalid', got true") - } - }) - - t.Run("Empty Check Type", func(t *testing.T) { - t.Parallel() - - if isValid := IsValidCheckType(""); isValid { - t.Errorf("expected false for empty check type, got true") - } - }) - - t.Run("Random String Check Type", func(t *testing.T) { - t.Parallel() - - if isValid := IsValidCheckType("random"); isValid { - t.Errorf("expected false for check type 'random', got true") - } - }) -} - -func TestInferCheckType(t *testing.T) { - t.Parallel() - - t.Run("HTTP scheme", func(t *testing.T) { - t.Parallel() - - checkType, err := InferCheckType("http://example.com") - if err != nil { - t.Fatalf("expected no error, got %q", err) + if TCP.String() != "TCP" { + t.Fatalf("expected 'TCP', got %q", TCP.String()) } - - if checkType != "http" { - t.Fatalf("expected 'http', got %q", checkType) + if ICMP.String() != "ICMP" { + t.Fatalf("expected 'ICMP', got %q", ICMP.String()) } }) - t.Run("TCP scheme", func(t *testing.T) { - t.Parallel() - - checkType, err := InferCheckType("tcp://example.com") + t.Run("Check type string (func)", func(t *testing.T) { + want := HTTP + got, err := GetCheckTypeFromString("http") if err != nil { t.Fatalf("expected no error, got %q", err) } - - if checkType != "tcp" { - t.Fatalf("expected 'tcp', got %q", checkType) + if want != got { + t.Fatalf("expected %q, got %q", want, got) } - }) - - t.Run("ICMP scheme", func(t *testing.T) { - t.Parallel() - checkType, err := InferCheckType("icmp://host.example.com") + want = TCP + got, err = GetCheckTypeFromString("tcp") if err != nil { t.Fatalf("expected no error, got %q", err) } - - if checkType != "icmp" { - t.Fatalf("expected 'http', got %q", checkType) + if want != got { + t.Fatalf("expected %q, got %q", want, got) } - }) - - t.Run("No scheme", func(t *testing.T) { - t.Parallel() - checkType, err := InferCheckType("example.com:80") + want = ICMP + got, err = GetCheckTypeFromString("icmp") if err != nil { t.Fatalf("expected no error, got %q", err) } - - if checkType != "" { - t.Fatalf("expected 'tcp', got %q", checkType) + if want != got { + t.Fatalf("expected %q, got %q", want, got) } - }) - t.Run("Unsupported scheme", func(t *testing.T) { - t.Parallel() - - _, err := InferCheckType("ftp://example.com") + want = -1 + got, err = GetCheckTypeFromString("invalid") if err == nil { t.Fatal("expected an error, got none") } - - expected := "unsupported scheme: ftp" - if err.Error() != expected { - t.Errorf("expected error to be %q, got %q", expected, err.Error()) - } }) } diff --git a/internal/checker/utils.go b/internal/checker/utils.go deleted file mode 100644 index 5b5e0f2..0000000 --- a/internal/checker/utils.go +++ /dev/null @@ -1,17 +0,0 @@ -package checker - -import ( - "fmt" - "strings" -) - -// extractScheme extracts the scheme from the address if it exists. -// If the address does not have a scheme, it returns an empty string. -func extractScheme(address string) (string, error) { - parts := strings.SplitN(address, "://", 2) - if len(parts) != 2 { - return "", fmt.Errorf("no scheme found in address: %s", address) - } - - return parts[0], nil -} diff --git a/internal/checker/utils_test.go b/internal/checker/utils_test.go deleted file mode 100644 index fcc1365..0000000 --- a/internal/checker/utils_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package checker - -import ( - "testing" -) - -func TestExtractScheme(t *testing.T) { - t.Parallel() - - t.Run("Valid address with http scheme", func(t *testing.T) { - t.Parallel() - - scheme, err := extractScheme("http://example.com") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - if scheme != "http" { - t.Fatalf("expected scheme 'http', got %q", scheme) - } - }) - - t.Run("Valid address with https scheme", func(t *testing.T) { - t.Parallel() - - scheme, err := extractScheme("https://example.com") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - if scheme != "https" { - t.Fatalf("expected scheme 'https', got %q", scheme) - } - }) - - t.Run("Invalid address without scheme", func(t *testing.T) { - t.Parallel() - - _, err := extractScheme("example.com") - - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "no scheme found in address: example.com" - if err.Error() != expected { - t.Errorf("expected error containing %q, got %q", expected, err) - } - }) - - t.Run("Invalid address with scheme only", func(t *testing.T) { - t.Parallel() - - scheme, err := extractScheme("ftp://") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - if scheme != "ftp" { - t.Fatalf("expected scheme 'ftp', got %q", scheme) - } - }) - - t.Run("Invalid address with missing colon", func(t *testing.T) { - t.Parallel() - - _, err := extractScheme("http//example.com") - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "no scheme found in address: http//example.com" - if err.Error() != expected { - t.Errorf("expected error containing %q, got %q", expected, err) - } - }) -} diff --git a/internal/config/config.go b/internal/config/config.go index 8ac624d..a8778e8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,21 +18,21 @@ const ( envDialTimeout string = "DIAL_TIMEOUT" envLogExtraFields string = "LOG_EXTRA_FIELDS" - defaultTargetCheckType string = "tcp" - defaultCheckInterval time.Duration = 2 * time.Second - defaultDialTimeout time.Duration = 1 * time.Second - defaultLogExtraFields bool = false + defaultTargetCheckType checker.CheckType = checker.TCP + defaultCheckInterval time.Duration = 2 * time.Second + defaultDialTimeout time.Duration = 1 * time.Second + defaultLogExtraFields bool = false ) // Config holds the required environment variables. type Config struct { - Version string // The version of the application. - TargetName string // The name of the target. - TargetAddress string // The address of the target. - TargetCheckType string // Type of check: "tcp" or "http" - CheckInterval time.Duration // The interval between connection attempts. - DialTimeout time.Duration // The timeout for dialing the target. - LogExtraFields bool // Whether to log the fields in the log message. + Version string // The version of the application. + TargetName string // The name of the target. + TargetAddress string // The address of the target. + TargetCheckType checker.CheckType // Type of check: "tcp", "http" or "icmp". + CheckInterval time.Duration // The interval between connection attempts. + DialTimeout time.Duration // The timeout for dialing the target. + LogExtraFields bool // Whether to log the fields in the log message. } // ParseConfig retrieves and parses the required environment variables. @@ -41,7 +41,7 @@ func ParseConfig(getEnv func(string) string) (Config, error) { cfg := Config{ TargetName: getEnv(envTargetName), TargetAddress: getEnv(envTargetAddress), - TargetCheckType: getEnv(envTargetCheckType), + TargetCheckType: defaultTargetCheckType, CheckInterval: defaultCheckInterval, DialTimeout: defaultDialTimeout, LogExtraFields: defaultLogExtraFields, @@ -98,22 +98,38 @@ func ParseConfig(getEnv func(string) string) (Config, error) { cfg.LogExtraFields = logExtraFields } - // Infer TargetCheckType if not provided - if cfg.TargetCheckType == "" { - checkType, err := checker.InferCheckType(cfg.TargetAddress) + // Resolve TargetCheckType + if err := resolveTargetCheckType(&cfg, getEnv); err != nil { + return Config{}, err + } + + return cfg, nil +} + +// resolveTargetCheckType handles the logic for determining the check type +func resolveTargetCheckType(cfg *Config, getEnv func(string) string) error { + // First, check if envTargetCheckType is explicitly set + if checkTypeStr := getEnv(envTargetCheckType); checkTypeStr != "" { + checkType, err := checker.GetCheckTypeFromString(checkTypeStr) if err != nil { - return Config{}, fmt.Errorf("could not infer check type for address %s: %w", cfg.TargetAddress, err) - } - if checkType == "" { - checkType = defaultTargetCheckType + return fmt.Errorf("invalid check type from environment: %w", err) } cfg.TargetCheckType = checkType + return nil } - // Validate the check type - if !checker.IsValidCheckType(cfg.TargetCheckType) { - return Config{}, fmt.Errorf("unsupported check type: %s", cfg.TargetCheckType) + // If not set, try to infer from the TargetAddress scheme + parts := strings.SplitN(cfg.TargetAddress, "://", 2) // parts[0] is the scheme, parts[1] is the address + if len(parts) == 2 { + checkType, err := checker.GetCheckTypeFromString(parts[0]) + if err != nil { + return fmt.Errorf("could not infer check type from address %s: %w", cfg.TargetAddress, err) + } + cfg.TargetCheckType = checkType + return nil } - return cfg, nil + // Fallback to default check type if neither is set or inferred + cfg.TargetCheckType = defaultTargetCheckType + return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index fbb4579..0fd9c4e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -5,6 +5,8 @@ import ( "reflect" "testing" "time" + + "github.com/containeroo/portpatrol/internal/checker" ) func TestParseConfig(t *testing.T) { @@ -28,7 +30,7 @@ func TestParseConfig(t *testing.T) { expected := Config{ TargetName: "example.com", // Extracted from TargetAddress TargetAddress: "http://example.com", - TargetCheckType: "http", + TargetCheckType: checker.HTTP, CheckInterval: 2 * time.Second, DialTimeout: 1 * time.Second, } @@ -58,7 +60,7 @@ func TestParseConfig(t *testing.T) { expected := Config{ TargetName: "www.example.com", // Extracted from TargetAddress TargetAddress: "www.example.com:80", - TargetCheckType: "http", + TargetCheckType: checker.HTTP, CheckInterval: 5 * time.Second, DialTimeout: 10 * time.Second, } @@ -88,7 +90,7 @@ func TestParseConfig(t *testing.T) { expected := Config{ TargetName: "postgres.postgres.svc.cluster.local", // Extracted from TargetAddress TargetAddress: "http://postgres.postgres.svc.cluster.local:80", - TargetCheckType: "http", + TargetCheckType: checker.HTTP, CheckInterval: 5 * time.Second, DialTimeout: 10 * time.Second, } @@ -118,7 +120,7 @@ func TestParseConfig(t *testing.T) { expected := Config{ TargetName: "example.com", // Extracted from TargetAddress TargetAddress: "tcp://example.com:80", - TargetCheckType: "tcp", + TargetCheckType: checker.TCP, CheckInterval: 5 * time.Second, DialTimeout: 10 * time.Second, } @@ -271,7 +273,7 @@ func TestParseConfig(t *testing.T) { expected := Config{ TargetName: "example.com", TargetAddress: "http://example.com", - TargetCheckType: "http", + TargetCheckType: checker.HTTP, CheckInterval: 2 * time.Second, DialTimeout: 1 * time.Second, LogExtraFields: true, @@ -321,7 +323,7 @@ func TestParseConfig(t *testing.T) { expected := Config{ TargetName: "example.com", TargetAddress: "example.com:80", - TargetCheckType: "tcp", + TargetCheckType: checker.TCP, CheckInterval: 2 * time.Second, DialTimeout: 1 * time.Second, LogExtraFields: false, @@ -347,7 +349,7 @@ func TestParseConfig(t *testing.T) { t.Fatal("expected an error, got none") } - expected := "unsupported check type: invalid" + expected := "invalid check type from environment: unsupported check type: invalid" if err.Error() != expected { t.Errorf("expected error to contain %q, got %q", expected, err) } @@ -368,7 +370,7 @@ func TestParseConfig(t *testing.T) { t.Fatal("expected an error, got none") } - expected := "could not infer check type for address htp://example.com: unsupported scheme: htp" + expected := "could not infer check type from address htp://example.com: unsupported check type: htp" if err.Error() != expected { t.Fatalf("expected error to contain %q, got %q", expected, err) } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 7d9a884..5fdfa1b 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -17,7 +17,7 @@ func SetupLogger(cfg config.Config, output io.Writer) *slog.Logger { slog.String("target_address", cfg.TargetAddress), slog.String("interval", cfg.CheckInterval.String()), slog.String("dial_timeout", cfg.DialTimeout.String()), - slog.String("checker_type", cfg.TargetCheckType), + slog.String("checker_type", cfg.TargetCheckType.String()), slog.String("version", cfg.Version), ) } diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index fe4e1ee..afa96d4 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/containeroo/portpatrol/internal/checker" "github.com/containeroo/portpatrol/internal/config" ) @@ -21,7 +22,7 @@ func TestSetupLogger(t *testing.T) { TargetAddress: "localhost:8080", CheckInterval: 1 * time.Second, DialTimeout: 2 * time.Second, - TargetCheckType: "http", + TargetCheckType: checker.HTTP, LogExtraFields: true, } var buf bytes.Buffer @@ -46,7 +47,7 @@ func TestSetupLogger(t *testing.T) { t.Errorf("Expected log output to contain %q, got %q", expected, logOutput) } - expected = "checker_type=http" + expected = "checker_type=HTTP" if !strings.Contains(logOutput, expected) { t.Errorf("Expected log output to contain %q, got %q", expected, logOutput) } diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index b0841f8..b46d340 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -133,7 +133,7 @@ func TestLoopUntilReadyHTTP(t *testing.T) { TargetAddress: "http://localhost:6081/success", CheckInterval: 500 * time.Millisecond, DialTimeout: 500 * time.Millisecond, - TargetCheckType: "http", + TargetCheckType: checker.HTTP, LogExtraFields: true, Version: "1.0.0", } @@ -257,7 +257,7 @@ func TestLoopUntilReadyHTTP(t *testing.T) { TargetAddress: "http://localhost:2081/wrong", CheckInterval: 500 * time.Millisecond, DialTimeout: 500 * time.Millisecond, - TargetCheckType: "http", + TargetCheckType: checker.HTTP, LogExtraFields: true, Version: "1.0.0", } @@ -369,7 +369,7 @@ func TestLoopUntilReadyHTTP(t *testing.T) { TargetAddress: "http://localhost:7083/fail", CheckInterval: 50 * time.Millisecond, DialTimeout: 50 * time.Millisecond, - TargetCheckType: "http", + TargetCheckType: checker.HTTP, } mockEnv := func(key string) string { @@ -430,7 +430,7 @@ func TestLoopUntilReadyTCP(t *testing.T) { TargetAddress: listener.Addr().String(), CheckInterval: 50 * time.Millisecond, DialTimeout: 50 * time.Millisecond, - TargetCheckType: "tcp", + TargetCheckType: checker.TCP, } checker, err := checker.NewTCPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout) @@ -501,7 +501,7 @@ func TestLoopUntilReadyTCP(t *testing.T) { TargetAddress: "localhost:5081", CheckInterval: 500 * time.Millisecond, DialTimeout: 500 * time.Millisecond, - TargetCheckType: "tcp", + TargetCheckType: checker.TCP, LogExtraFields: true, Version: "1.0.0", } @@ -640,7 +640,7 @@ func TestLoopUntilReadyTCP(t *testing.T) { TargetAddress: "localhost:7084", CheckInterval: 50 * time.Millisecond, DialTimeout: 50 * time.Millisecond, - TargetCheckType: "tcp", + TargetCheckType: checker.TCP, } checker, err := checker.NewTCPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout)