diff --git a/Makefile b/Makefile index 5999804..9eb60c9 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ test: # Generate and display test coverage cover: - sudo go test ./... -count=1 -coverprofile=coverage.out + sudo go test ./cmd/... ./pkg/... ./internal/... -count=1 -coverprofile=coverage.out go tool cover -html=coverage.out # Clean up generated files diff --git a/README.md b/README.md index 38933a5..8baddca 100644 --- a/README.md +++ b/README.md @@ -4,47 +4,131 @@ # PortPatrol -`PortPatrol` is a simple Go application that checks if a specified `TCP`, `HTTP` or `ICMP` target is available. It continuously attempts to connect to the specified target at regular intervals until the target becomes available or the program is terminated. Intended to run as a Kubernetes initContainer, `PortPatrol` helps verify whether a dependency is ready. Configuration is managed through environment variables; for more details, refer to the [Environment Variables](#EnvironmentVariables) section." +`PortPatrol` is a simple Go application that checks if a specified `TCP`, `HTTP` or `ICMP` target is available. It continuously attempts to connect to the specified target at regular intervals until the target becomes available or the program is terminated. Intended to run as a Kubernetes initContainer, `PortPatrol` helps verify whether a dependency is ready. The configuration is done through startup arguments. +You can check multiple targets at once. -## Environment Variables -`PortPatrol` accepts the following environment variables: +## Command-Line Flags -### Common Variables +`PortPatrol` accepts the following command-line flags: -- `TARGET_NAME`: Name assigned to the target (optional, default: inferred from `TARGET_ADDRESS`). If not specified, it's derived from the target address. For example, `http://postgres.default.svc.cluster.local:5432` is inferred as `postgres.default.svc.cluster.local`. -- `TARGET_ADDRESS`: The target's address in the following formats: +### Common Flags - - **TCP**: `host:port` (port is required). - - **HTTP**: `scheme://host[:port]` (scheme is required). - - **ICMP**: `host` (no scheme and port allowed). +| Flag | Type | Default | Description | +|-----------------------|----------|---------|-----------------------------------------------------------------------------------------------| +| `--default-interval` | duration | `2s` | Default interval between checks. Can be overridden for each target. | +| `--debug` | bool | `false` | Enable logging of additional fields. | +| `--version` | bool | `false` | Show version and exit. | +| `--help`, `-h` | bool | `false` | Show help. | - You can always specify a scheme (e.g., `http://`, `tcp://`, `icmp://`) in `TARGET_ADDRESS`, which automatically infers the `TARGET_CHECK_TYPE`, making the `TARGET_CHECK_TYPE` variable optional. +### Target Flags -- `TARGET_CHECK_TYPE`: Specifies the type of check (`tcp`, `http`, `https`, or `icmp`). If no scheme is provided in `TARGET_ADDRESS`, this variable determines the check type. If a scheme is provided, `TARGET_CHECK_TYPE` becomes obsolete. -- `CHECK_INTERVAL`: Time between connection attempts (optional, default: `2s`). -- `DIAL_TIMEOUT`: Maximum allowed time for each connection attempt (optional, default: `1s`). -- `LOG_EXTRA_FIELDS`: Enable logging of additional fields (optional, default: `false`). +`PortPatrol` accepts "dynamic" flags that can be defined in the startup arguments. +Use the `--..=` format to define targets. +Types are: `http`, `icmp` or `tcp`. -### HTTP-Specific Variables +#### HTTP-Flags -- `HTTP_METHOD`: HTTP method to use (optional, default: `GET`). -- `HTTP_HEADERS`: Comma-separated list of HTTP headers in `key=value` format (optional). Examples: - - `Authorization=Bearer token` - - `Content-Type=application/json,Accept=application/json` -- `HTTP_ALLOW_DUPLICATE_HEADERS`: Allow duplicate headers (optional, default: `false`). -- `HTTP_EXPECTED_STATUS_CODES`: Comma-separated list of expected HTTP status codes or ranges (optional, default: `200`). You can specify individual status codes or ranges: - - `200,301,404` - - `200,300-302` - - `200,301-302,404,500-502` -- `HTTP_SKIP_TLS_VERIFY`: Skip TLS verification (optional, default: `false`). -- `HTTP_PROXY`: HTTP proxy to use (optional). -- `HTTPS_PROXY`: HTTPS proxy to use (optional). -- `NO_PROXY`: Comma-separated list of domains to exclude from proxying (optional). +- **`--http..name`** = `string` + The name of the target. If not specified, it uses the `` as the name. -### ICMP-Specific Variables +- **`--http..address`** = `string` + The target's address. + **Resolvable:** `env:ENV_VAR`, `file:path/to/file.txt`. see below. -- `ICMP_READ_TIMEOUT`: Maximum allowed time for each ICMP echo reply (optional, default: `1s`). + - **`--http..interval`** = `duration` + The interval between HTTP requests (e.g., `1s`). Overwrites the global `--default-interval`. + +- **`--http..method`** = `string` + The HTTP method to use (e.g., `GET`, `POST`). Defaults to `GET`. + +- **`--http..header`** = `string` + A HTTP header in `key=value` format. Can be specified multiple times. + **Example:** `Authorization=Bearer token` + **Resolvable:** The value of the Header is resolvable: `env:ENV_VAR`, `file:path/to/file.txt`. see below. + +- **`--http..allow-duplicate-headers`** = `bool` + Allow duplicate headers. Defaults to `false`. + +- **`--http..expected-status-codes`** = `string` + A comma-separated list of expected HTTP status codes or ranges (e.g., `200,301-302`). Defaults to `200`. + +- **`--http..skip-tls-verify`** = `bool` + Whether to skip TLS verification. Defaults to `false`. + +- **`--http..timeout`** = `duration` + The timeout for the HTTP request (e.g., `5s`). Defaults to `1s`. + +#### ICMP Flags + +- **`--icmp..name`** = `string` + The name of the target. If not specified, it uses the `` as the name. + +- **`--icmp..address`** = `string` + The target's address. + **Resolvable:** The value of the Address is resolvable: `env:ENV_VAR`, `file:path/to/file.txt`. + +- **`--icmp..interval`** = `duration` + The interval between ICMP requests (e.g., `1s`). Overwrites the global `--default-interval`. + +- **`--icmp..read-timeout`** = `duration` + The read timeout for the ICMP connection (e.g., `1s`). Defaults to `1s`. + +- **`--icmp..write-timeout`** = `duration` + The write timeout for the ICMP connection (e.g., `1s`).Defaults to `1s`. + +### TCP Flags + +- **`--tcp..name`** = `string` + The name of the target. If not specified, it uses the `` as the name. + +- **`--tcp..address`** = `string` + The target's address. + **Resolvable:** `env:ENV_VAR`, `file:path/to/file.txt`. see below. + +- **`--tcp..interval`** = `duration` + The interval between ICMP requests (e.g., `1s`). Overwrites the global `--default-interval`. + +### Resolving variables + +Each `address` field can be resolved using environment variables, files, or plain text: + +- **Plain Text**: Simply input the credentials directly in plain text. +- **Environment Variable**: Use the `env:` prefix, followed by the name of the environment variable that stores the credentials. +- **File**: Use the `file:` prefix, followed by the path of the file that contains the credentials. The file should contain only the credentials. + +In case the file contains multiple key-value pairs, the specific key for the credentials can be selected by appending `//KEY` to the end of the path. Each key-value pair in the file must follow the `key = value` format. The system will use the value corresponding to the specified `//KEY`. + +HTTP headers values can also be resolved using the same mechanism, (`-- + +### Examples + +#### Define an HTTP Target + +```sh +portpatrol \ + --http.web.address=http://example.com:80 \ + --http.web.method=GET \ + --http.web.expected-status-codes=200,204 \ + --http.web.header="Authorization=Bearer token" \ + --http.web.header="Content-Type=application/json" \ + --http.web.skip-tls-verify=false \ + --default-interval=5s \ + --debug +``` + +#### Define Multiple Targets (HTTP and TCP) Running in Parallel + +```sh +portpatrol \ + --http.web.address=http://example.com:80 \ + --tcp.db.address=tcp://localhost:5432 \ + --default-interval=10s +``` + +#### Notes + +**Proxy Settings**: Proxy configurations (`HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY`) are managed via environment variables. ## Behavior Flowchart @@ -213,29 +297,6 @@ class MainFlow,RetryLoop transparent; -## Logging - -With the `LOG_EXTRA_FIELDS` environment variable set to true, additional fields will be logged. - -### With additional fields - -```text -ts=2024-07-05T13:08:20+02:00 level=INFO msg="Waiting for PostgreSQL to become ready..." dial_timeout="1s" interval="2s" target_address="postgres.default.svc.cluster.local:5432" target_name="PostgreSQL" version="0.0.22" -ts=2024-07-05T13:08:21+02:00 level=WARN msg="PostgreSQL is not ready ✗" dial_timeout="1s" error="dial tcp: lookup postgres.default.svc.cluster.local: i/o timeout" interval="2s" target_address="postgres.default.svc.cluster.local:5432" target_name="PostgreSQL" version="0.0.22" -ts=2024-07-05T13:08:24+02:00 level=WARN msg="PostgreSQL is not ready ✗" dial_timeout="1s" error="dial tcp: lookup postgres.default.svc.cluster.local: i/o timeout" interval="2s" target_address="postgres.default.svc.cluster.local:5432" target_name="PostgreSQL" version="0.0.22" -ts=2024-07-05T13:08:27+02:00 level=WARN msg="PostgreSQL is not ready ✗" dial_timeout="1s" error="dial tcp: lookup postgres.default.svc.cluster.local: i/o timeout" interval="2s" target_address="postgres.default.svc.cluster.local:5432" target_name="PostgreSQL" version="0.0.22" -ts=2024-07-05T13:08:27+02:00 level=INFO msg="PostgreSQL is ready ✓" dial_timeout="1s" error="dial tcp: lookup postgres.default.svc.cluster.local: i/o timeout" interval="2s" target_address="postgres.default.svc.cluster.local:5432" target_name="PostgreSQL" version="0.0.22" -``` - -### Without additional fields - -```text -time=2024-07-12T12:44:41.494Z level=INFO msg="Waiting for PostgreSQL to become ready..." -time=2024-07-12T12:44:41.512Z level=WARN msg="PostgreSQL is not ready ✗" -time=2024-07-12T12:44:43.532Z level=WARN msg="PostgreSQL is not ready ✗" -time=2024-07-12T12:44:45.552Z level=INFO msg="PostgreSQL is ready ✓" -``` - ## Kubernetes initContainer Configuration Configure your Kubernetes deployment to use this init container: @@ -244,63 +305,27 @@ Configure your Kubernetes deployment to use this init container: initContainers: - name: wait-for-vm image: ghcr.io/containeroo/portpatrol:latest - env: - - name: TARGET_ADDRESS - value: icmp://hostname.domain.tld + args: + - --icmp.vm.address=hostname.domain.tld securityContext: # icmp requires CAP_NET_RAW readOnlyRootFilesystem: true allowPrivilegeEscalation: false capabilities: add: ["CAP_NET_RAW"] - - name: wait-for-valkey + - name: wait-for-it image: ghcr.io/containeroo/portpatrol:latest - env: - - name: TARGET_ADDRESS - value: valkey.default.svc.cluster.local:6379 - - name: wait-for-valkey - image: ghcr.io/containeroo/portpatrol:latest - env: - - name: TARGET_NAME - value: Valkey - - name: TARGET_ADDRESS - value: valkey.default.svc.cluster.local:6379 - - name: TARGET_CHECK_TYPE - value: tcp # Specify the type of check - - name: CHECK_INTERVAL - value: "5s" # Specify the interval duration, e.g., 5 seconds - - name: DIAL_TIMEOUT - value: "5s" # Specify the dial timeout duration, e.g., 5 seconds - - name: LOG_EXTRA_FIELDS - value: "true" - - name: wait-for-postgres - image: ghcr.io/containeroo/portpatrol:latest - env: - - name: TARGET_ADDRESS - value: http://postgres.default.svc.cluster.local:9000/healthz # use healthz endpoint to check if postgres is ready - # TARGET_NAME will be inferred from TARGET_ADDRESS to postgres.default.svc.cluster.local - # TARGET_CHECK_TYPE is not not necessary, because TARGET_ADDRESS has a scheme (http://) - # HTTP_METHOD is not necessary, because the default is GET - # HTTP_EXPECTED_STATUS_CODES is not necessary, because the default is 200 and /healthz returns 200 if the service is ready - # CHECK_INTERVAL defaults to 2 seconds which is okay for a health check - # DIAL_TIMEOUT defaults to 1 second which is okay for a health check - - name: wait-for-webapp - image: ghcr.io/containeroo/portpatrol:latest - env: - - name: TARGET_NAME - value: webapp - - name: TARGET_ADDRESS - value: webapp.default.svc.cluster.local:8080 - - name: TARGET_CHECK_TYPE - value: http - - name: HTTP_METHOD - value: "POST" - - name: HTTP_HEADERS - value: "Authorization=Bearer token" - - name: HTTP_EXPECTED_STATUS_CODES - value: "200,202" - - name: CHECK_INTERVAL - value: "5s" # Specify the interval duration, e.g., 5 seconds - - name: DIAL_TIMEOUT - value: "2s" # Specify the dial timeout duration, e.g., 2 seconds -``` + args: + - --target.postgres.address=postgres.default.svc.cluster.local:9000/healthz # use healthz endpoint to check if postgres is ready + - --target.postgres.method=POST + - --target.postgres.header=Authorization=env:BEARER_TOKEN + - --target.postgres.expected-status-codes=200,202 + - --target.redis.name=redis + - --target.redis.address=redis.default.svc.cluster.local:6437 + - --tcp.vaultkey.address=valkey.default.svc.cluster.local:6379 + - --tcp.vaultkey.interval=5s + - --tcp.vaultkey.timeout=5s + envFrom: + - secretRef: + name: bearer-token +``` diff --git a/cmd/portpatrol/main.go b/cmd/portpatrol/main.go index 6ea3200..7c56fa7 100644 --- a/cmd/portpatrol/main.go +++ b/cmd/portpatrol/main.go @@ -2,50 +2,77 @@ package main import ( "context" + "errors" "fmt" "io" "os" "os/signal" "syscall" - "github.com/containeroo/portpatrol/internal/checker" "github.com/containeroo/portpatrol/internal/config" - "github.com/containeroo/portpatrol/internal/logger" - "github.com/containeroo/portpatrol/internal/runner" + "github.com/containeroo/portpatrol/internal/factory" + "github.com/containeroo/portpatrol/internal/logging" + "github.com/containeroo/portpatrol/internal/wait" + "golang.org/x/sync/errgroup" ) -const version = "0.4.7" +const version = "0.5.0" // run is the main function of the application. -func run(ctx context.Context, getEnv func(string) string, output io.Writer) error { +func run(ctx context.Context, args []string, output io.Writer) error { // Create a new context that listens for interrupt signals - // and cancels the context when received. Ensures proper resource cleanup. ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer cancel() - cfg, err := config.ParseConfig(getEnv) + // Parse command-line flags + f, err := config.ParseFlags(args, version, output) if err != nil { + if errors.Is(err, &config.HelpRequested{}) { + fmt.Fprint(output, err.Error()) + return nil + } return fmt.Errorf("configuration error: %w", err) } - cfg.Version = version - logger := logger.SetupLogger(cfg, output) - - targetChecker, err := checker.NewChecker(cfg.TargetCheckType, cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout, getEnv) + // Initialize target checkers + checkers, err := factory.BuildCheckers(f.DynFlags, f.DefaultCheckInterval) if err != nil { - return fmt.Errorf("failed to initialize checker: %w", err) + return fmt.Errorf("failed to initialize target checkers: %w", err) + } + + if len(checkers) == 0 { + return errors.New("configuration error: no checkers configured") + } + + logger := logging.SetupLogger(version, output) + + // Run checkers concurrently + eg, ctx := errgroup.WithContext(ctx) + for _, chk := range checkers { + checker := chk // Capture loop variable + eg.Go(func() error { + err := wait.WaitUntilReady(ctx, checker.Interval, checker.Checker, logger) + if err != nil { + return fmt.Errorf("checker '%s' failed: %w", checker.Checker.GetName(), err) + } + return nil + }) + } + + // Wait for all checkers to finish or return error + if err := eg.Wait(); err != nil { + return err } - return runner.LoopUntilReady(ctx, cfg.CheckInterval, targetChecker, logger) + return nil } func main() { - // Create a root context with no cancellation or deadline. This is the top-level context - // that all other contexts will derive from in the application. + // Create a root context ctx := context.Background() - if err := run(ctx, os.Getenv, os.Stdout); err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err) + if err := run(ctx, os.Args[1:], os.Stdout); err != nil { + fmt.Fprintln(os.Stderr, err) os.Exit(1) } } diff --git a/cmd/portpatrol/main_test.go b/cmd/portpatrol/main_test.go index 60a551f..420352d 100644 --- a/cmd/portpatrol/main_test.go +++ b/cmd/portpatrol/main_test.go @@ -3,274 +3,161 @@ package main import ( "bytes" "context" - "fmt" "net" "net/http" "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) -func TestRun(t *testing.T) { +func TestRunHTTPReady(t *testing.T) { t.Parallel() - const ( - envTargetName string = "TARGET_NAME" - envTargetAddress string = "TARGET_ADDRESS" - envTargetCheckType string = "TARGET_CHECK_TYPE" - envCheckInterval string = "CHECK_INTERVAL" - envDialTimeout string = "DIAL_TIMEOUT" - envLogAdditionalFields string = "LOG_EXTRA_FIELDS" - envHTTPHeaders string = "HTTP_HEADERS" - ) - - t.Run("HTTP Target is ready", func(t *testing.T) { - t.Parallel() - - env := map[string]string{ - envTargetAddress: "http://localhost:8081", - envCheckInterval: "1s", - envDialTimeout: "1s", - envTargetCheckType: "http", - } - - mockEnv := func(key string) string { - return env[key] - } - - server := &http.Server{Addr: ":8081"} - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - go func() { // make linter happy - _ = server.ListenAndServe() - }() - defer server.Close() - - var output strings.Builder - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - // cancel after 2 Seconds - go func() { - time.Sleep(2 * time.Second) - cancel() - }() - - err := run(ctx, mockEnv, &output) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - outputEntries := strings.Split(strings.TrimSpace(output.String()), "\n") - last := len(outputEntries) - 1 - - expected := "localhost is ready ✓" - if !strings.Contains(outputEntries[last], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, output.String()) - } - }) - - t.Run("TCP Target is ready", func(t *testing.T) { - t.Parallel() - - env := map[string]string{ - envTargetAddress: "localhost:8082", - envCheckInterval: "1s", - envDialTimeout: "1s", - } - - mockEnv := func(key string) string { - return env[key] - } - - listener, err := net.Listen("tcp", "localhost:8082") - if err != nil { - t.Fatalf("Failed to start TCP server: %q", err) - } - defer listener.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + args := []string{ + "--http.httpcheck.name=HTTPServer", + "--http.httpcheck.address=http://localhost:8081", + "--http.httpcheck.interval=1s", + "--http.httpcheck.timeout=1s", + } - // cancel after 2 Seconds - go func() { - time.Sleep(2 * time.Second) - cancel() - }() - - var output strings.Builder - - err = run(ctx, mockEnv, &output) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - outputEntries := strings.Split(strings.TrimSpace(output.String()), "\n") - last := len(outputEntries) - 1 - - expected := "localhost is ready ✓" - if !strings.Contains(outputEntries[last], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, output.String()) - } + server := &http.Server{Addr: ":8081"} + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) }) - t.Run("ICMP Target is ready", func(t *testing.T) { - t.Parallel() + go func() { _ = server.ListenAndServe() }() + defer server.Close() - env := map[string]string{ - envTargetAddress: "icmp://127.0.0.1", - envCheckInterval: "1s", - envDialTimeout: "1s", - envTargetCheckType: "icmp", - } + var output strings.Builder + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() - mockEnv := func(key string) string { - return env[key] - } + err := run(ctx, args, &output) + assert.NoError(t, err) - var output strings.Builder + outputEntries := strings.Split(strings.TrimSpace(output.String()), "\n") + last := len(outputEntries) - 1 - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + assert.Contains(t, outputEntries[last], "HTTPServer is ready ✓") +} - // cancel after 2 Seconds - go func() { - time.Sleep(2 * time.Second) - cancel() - }() +func TestRunTCPReady(t *testing.T) { + t.Parallel() - err := run(ctx, mockEnv, &output) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } + args := []string{ + "--tcp.tcptest.name=TCPServer", + "--tcp.tcptest.address=localhost:8082", + "--tcp.tcptest.interval=1s", + "--tcp.tcptest.timeout=1s", + } - outputEntries := strings.Split(strings.TrimSpace(output.String()), "\n") - last := len(outputEntries) - 1 + listener, err := net.Listen("tcp", "localhost:8082") + assert.NoError(t, err) + defer listener.Close() - expected := "127.0.0.1 is ready ✓" - if !strings.Contains(outputEntries[last], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, output.String()) - } - }) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() - t.Run("Config error: variable is required", func(t *testing.T) { - t.Parallel() + var output strings.Builder + err = run(ctx, args, &output) + assert.NoError(t, err) - env := map[string]string{} + outputEntries := strings.Split(strings.TrimSpace(output.String()), "\n") + last := len(outputEntries) - 1 - mockEnv := func(key string) string { - return env[key] - } + assert.Contains(t, outputEntries[last], "TCPServer is ready ✓") +} - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() +func TestRunConfigErrorMissingTarget(t *testing.T) { + t.Parallel() - var output bytes.Buffer + args := []string{} - err := run(ctx, mockEnv, &output) - if err == nil { - t.Fatalf("Expected configuration error, got none") - } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() - expected := fmt.Sprintf("configuration error: %s environment variable is required", envTargetAddress) - if err.Error() != expected { - t.Errorf("Expected configuration error, got %q", err) - } - }) + var output bytes.Buffer + err := run(ctx, args, &output) - t.Run("Config error: unsupported check type", func(t *testing.T) { - t.Parallel() + assert.Error(t, err) + assert.EqualError(t, err, "configuration error: no checkers configured") +} - env := map[string]string{ - envTargetName: "TestService", - envTargetAddress: "localhost:8080", - envCheckInterval: "1s", - envDialTimeout: "1s", - envTargetCheckType: "invalid", - } +func TestRunConfigErrorUnsupportedCheckType(t *testing.T) { + t.Parallel() - mockEnv := func(key string) string { - return env[key] - } + args := []string{ + "--target.unsupported.name=TestService", + "--target.unsupported.address=localhost:8080", + "--target.unsupported.interval=1s", + "--target.unsupported.timeout=1s", + } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() - var output bytes.Buffer + var output bytes.Buffer + err := run(ctx, args, &output) - err := run(ctx, mockEnv, &output) - if err == nil { - t.Error("Expected error, got none") - } + assert.Error(t, err) + assert.EqualError(t, err, "configuration error: no checkers configured") +} - expected := "configuration error: invalid check type from environment: unsupported check type: invalid" - if err.Error() != expected { - t.Errorf("Expected error to contain %q, got %q", expected, err.Error()) - } - }) +func TestRunConfigErrorInvalidHeaders(t *testing.T) { + t.Parallel() - t.Run("Inizalize error: unknown check type", func(t *testing.T) { - t.Parallel() + args := []string{ + "--http.invalidheaders.name=TestService", + "--http.invalidheaders.address=http://localhost:8080", + "--http.invalidheaders.interval=1s", + "--http.invalidheaders.timeout=1s", + "--http.invalidheaders.header=InvalidHeader", + } - env := map[string]string{ - envTargetName: "TestService", - envTargetAddress: "htp://localhost:8080", - envCheckInterval: "1s", - envDialTimeout: "1s", - envHTTPHeaders: "Auportpatrolization Bearer token", - } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() - mockEnv := func(key string) string { - return env[key] - } + var output bytes.Buffer + err := run(ctx, args, &output) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + assert.Error(t, err) + assert.EqualError(t, err, "failed to initialize target checkers: invalid \"--http.invalidheaders.header\": invalid header format: \"InvalidHeader\"") +} - var output bytes.Buffer +func TestRunParseError(t *testing.T) { + t.Parallel() - err := run(ctx, mockEnv, &output) - if err == nil { - t.Error("Expected error, got none") - } + args := []string{ + "--http.invalidheaders.name=TestService", + "--invalid", + } - expected := "configuration error: could not infer check type from address htp://localhost:8080: unsupported check type: htp" - if err.Error() != expected { - t.Errorf("Expected error to contain %q, got %q", expected, err.Error()) - } - }) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() - t.Run("Inizalize error: invalid headers", func(t *testing.T) { - t.Parallel() + var output bytes.Buffer + err := run(ctx, args, &output) - env := map[string]string{ - envTargetName: "TestService", - envTargetAddress: "http://localhost:8080", - envCheckInterval: "1s", - envDialTimeout: "1s", - envHTTPHeaders: "Auportpatrolization Bearer token", - } + assert.Error(t, err) + assert.EqualError(t, err, "configuration error: Flag parsing error: unknown flag: --invalid") +} - mockEnv := func(key string) string { - return env[key] - } +func TestRunShowVersion(t *testing.T) { + t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + args := []string{ + "--http.invalidheaders.name=TestService", + "--http.invalidheaders.address=http://localhost:8080", + "--version", + } - var output bytes.Buffer + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() - err := run(ctx, mockEnv, &output) - if err == nil { - t.Error("Expected error, got none") - } + var output bytes.Buffer + err := run(ctx, args, &output) - expected := fmt.Sprintf("failed to initialize checker: invalid %s value: invalid header format: Auportpatrolization Bearer token", envHTTPHeaders) - if err.Error() != expected { - t.Errorf("Expected error to contain %q, got %q", expected, err.Error()) - } - }) + assert.NoError(t, err) } diff --git a/examples/advanced.go b/examples/advanced.go new file mode 100644 index 0000000..b32aaed --- /dev/null +++ b/examples/advanced.go @@ -0,0 +1,105 @@ +package main + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/spf13/pflag" +) + +func main() { + // args := os.Args[1:] + args := []string{ + "--http.idenitfier1.method", "POST", + "--http.idenitfier1.address", "https://example.com", + "--tcp.idenitfier1.address", "127.0.0.1", + "--tcp.idenitfier1.timeout", "10s", + "--unknown.identifier2.name", "example 2", + } + + var output strings.Builder // create a io.Writer to capture output + + // Initialize pflag with ContinueOnError behavior + flagSet := pflag.NewFlagSet("advanced", pflag.ContinueOnError) + // Add some flags + flagSet.Bool("debug", false, "Set debug mode") + flagSet.SetOutput(&output) // Output to the io.Writer + + // Initialize DynFlags with ContinueOnError behavior + dynFlags := dynflags.New(dynflags.ParseUnknown) + + // Set the output for the DynFlags instance to the same io.Writer + dynFlags.SetOutput(&output) + + // Create a custom usage function for the flagSet instance + flagSet.Usage = func() { + fmt.Fprintln(&output, "Usage: advanced [FLAGS] [DYNAMIC FLAGS..]") + + fmt.Fprintln(&output, "\nGlobal Flags:") + flagSet.PrintDefaults() + + fmt.Fprintln(&output, "\nDynamic Flags:") + dynFlags.PrintDefaults() + } + + // Add a title and description for the usage output + dynFlags.Title("DynFlags Example Application") + dynFlags.Description("This application demonstrates the usage of DynFlags for managing hierarchical flags dynamically.") + dynFlags.Epilog("For more information, see https://github.com/containerish/portpatrol") + + // Register groups and flags + httpGroup := dynFlags.Group("http") + httpGroup.String("method", "GET", "HTTP method to use") + httpGroup.String("address", "", "HTTP target URL") + + tcpGroup := dynFlags.Group("tcp") + tcpGroup.String("address", "", "TCP target address") + tcpGroup.Duration("timeout", 10*time.Second, "TCP timeout") + + // Parse first with dynflags + if err := dynFlags.Parse(args); err != nil { + fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err) + os.Exit(1) + } + + // retrieve all unknown groups + unparsedArgs := dynFlags.UnparsedArgs() + + // Parse flags wich were not parsed by dynflags with pflag + if err := flagSet.Parse(unparsedArgs); err != nil { + fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err) + os.Exit(1) + } + + // Retrieve values from parsed group http + method := dynFlags.Parsed().Lookup("http").Lookup("idenitfier1").Lookup("method") + httpAddress := dynFlags.Parsed().Lookup("http").Lookup("idenitfier1").Lookup("address") + + fmt.Println("Method:", method) + fmt.Println("Address:", httpAddress) + + // Retrieve values from parsed group tcp + tcpAddress := dynFlags.Parsed().Lookup("tcp").Lookup("idenitfier1").Lookup("address") + tcpTimeout := dynFlags.Parsed().Lookup("tcp").Lookup("idenitfier1").Lookup("timeout") + + fmt.Println("TCP Address:", tcpAddress) + fmt.Println("TCP Timeout:", tcpTimeout) + + // Retrieve all unknown groups + unknownGroups := dynFlags.Unknown().Groups() + if len(unknownGroups) > 0 { + fmt.Println("\n=== Unknown Groups ===") + for groupName, groups := range unknownGroups { + fmt.Println("Group:", groupName) + for _, group := range groups { + fmt.Println(" Identifier:", group.Name) + for key, value := range group.Values { + fmt.Printf(" Flag: %s, Value: %v\n", key, value) + } + } + } + } +} diff --git a/examples/simple.go b/examples/simple.go new file mode 100644 index 0000000..8758d22 --- /dev/null +++ b/examples/simple.go @@ -0,0 +1,112 @@ +package main + +import ( + "fmt" + "time" + + "github.com/containeroo/portpatrol/pkg/dynflags" +) + +func main() { + // Initialize DynFlags + dynFlags := dynflags.New(dynflags.ContinueOnError) + + // Define configuration groups and flags + httpGroup := dynFlags.Group("http") + httpGroup.String("method", "GET", "HTTP method to use") + httpGroup.String("address", "", "HTTP target URL") + httpGroup.Bool("secure", true, "Use secure connection (HTTPS)") + httpGroup.Duration("timeout", 5*time.Second, "Request timeout") + + tcpGroup := dynFlags.Group("tcp") + tcpGroup.String("address", "", "TCP target address") + tcpGroup.Duration("timeout", 10*time.Second, "TCP timeout") + + // Simulate CLI arguments + args := []string{ + "--http.identifier1.method", "POST", + "--http.identifier1.address", "https://example.com", + "--tcp.identifier2.address", "127.0.0.1", + "--tcp.identifier2.timeout", "15s", + "--unknown.identifier3.flag", "unknownValue", + } + + // Parse arguments + if err := dynFlags.Parse(args); err != nil { + fmt.Printf("Error parsing flags: %v\n", err) + return + } + + // ITERATION: Iterate over all config groups + fmt.Println("=== Iterating over Config Groups ===") + for groupName, group := range dynFlags.Config().Groups() { + fmt.Printf("Group: %s\n", groupName) + for flagName, flag := range group.Flags { + fmt.Printf(" Flag: %s, Default: %v, Usage: %s\n", flagName, flag.Default, flag.Usage) + } + } + + // ITERATION: Iterate over all parsed groups + fmt.Println("\n=== Iterating over Parsed Groups ===") + for groupName, groups := range dynFlags.Parsed().Groups() { + fmt.Printf("Group: %s\n", groupName) + for _, group := range groups { + fmt.Printf(" Identifier: %s\n", group.Name) + for flagName, value := range group.Values { + fmt.Printf(" Flag: %s, Value: %v\n", flagName, value) + } + } + } + + // ITERATION: Iterate over all unknown groups + fmt.Println("\n=== Iterating over Unknown Groups ===") + for groupName, groups := range dynFlags.Unknown().Groups() { + fmt.Printf("Unknown Group: %s\n", groupName) + for _, group := range groups { + fmt.Printf(" Identifier: %s\n", group.Name) + for flagName, value := range group.Values { + fmt.Printf(" Flag: %s, Value: %v\n", flagName, value) + } + } + } + + // LOOKUP: Direct access using Lookup methods + fmt.Println("\n=== Lookup Example ===") + + // Lookup a config group + httpConfig := dynFlags.Config().Lookup("http") + if httpConfig != nil { + fmt.Printf("Config Group 'http' exists, Flags: %v\n", httpConfig.Flags) + } + + // Lookup the "http" group + httpGroups := dynFlags.Parsed().Lookup("http") + if httpGroups != nil { + // Lookup "identifier1" within the "http" group + httpIdentifier1 := httpGroups.Lookup("identifier1") + if httpIdentifier1 != nil { + method := httpIdentifier1.Lookup("method") + fmt.Printf("HTTP Method (Lookup): %s\n", method) + } + } + + // Lookup the "unknown" group + unknownGroups := dynFlags.Unknown().Lookup("unknown") + if unknownGroups != nil { + // Lookup "identifier3.flag" within the "unknown" group + unknownIdentifier3 := unknownGroups.Lookup("identifier3") + if unknownIdentifier3 != nil { + unknownValue := unknownIdentifier3.Lookup("flag") + fmt.Printf("Unknown Value (Lookup): %s\n", unknownValue) + } + } + + // LOOKUP: Direct flag retrieval from a config group + fmt.Println("\n=== Direct Flag Lookup ===") + if httpConfig != nil { + methodFlag := httpConfig.Lookup("method") + if methodFlag != nil { + fmt.Printf("HTTP Method Flag: Default = %v, Usage = %s\n", methodFlag.Default, methodFlag.Usage) + } + } +} diff --git a/go.mod b/go.mod index 82caf57..42cc717 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,16 @@ module github.com/containeroo/portpatrol go 1.23.2 -require golang.org/x/net v0.32.0 +require ( + github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.10.0 + golang.org/x/net v0.32.0 + golang.org/x/sync v0.8.0 +) -require golang.org/x/sys v0.28.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 405dbb2..cbdbeb5 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/checker/checker.go b/internal/checker/checker.go index cc0711e..968e1b8 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -4,10 +4,9 @@ import ( "context" "fmt" "strings" - "time" ) -// CheckType is an enumeration that represents the type of check being performed. +// CheckType represents the type of check to perform. type CheckType int const ( @@ -21,28 +20,33 @@ func (c CheckType) String() string { return [...]string{"TCP", "HTTP", "ICMP"}[c] } +// Option defines a functional option for configuring a Checker. +type Option interface { + apply(Checker) +} + +// OptionFunc is a function that applies an Option to a Checker. +type OptionFunc func(Checker) + +// apply calls the OptionFunc with the given Checker. +func (f OptionFunc) apply(c Checker) { + f(c) +} + // Checker defines an interface for performing various types of checks, such as TCP, HTTP, or ICMP. // It provides methods for executing the check and obtaining a string representation of the checker. type Checker interface { // Check performs a check and returns an error if the check fails. Check(ctx context.Context) error - // String returns the name of the checker. - String() string -} + // GetName returns the name of the checker. + GetName() string -// Factory function that returns the appropriate Checker based on checkType. -func NewChecker(checkType CheckType, name, address string, timeout time.Duration, getEnv func(string) string) (Checker, error) { - switch checkType { - case HTTP: // HTTP and HTTPS checkers may need environment variables for proxy settings, etc. - return NewHTTPChecker(name, address, timeout, getEnv) - case TCP: // TCP checkers may not need environment variables - return NewTCPChecker(name, address, timeout) - case ICMP: // ICMP checkers may have a different timeout logic - return NewICMPChecker(name, address, timeout, getEnv) - default: - return nil, fmt.Errorf("unsupported check type: %d", checkType) - } + // GetType returns the type of the checker. + GetType() string + + // GetAddress returns the address of the checker. + GetAddress() string } // GetCheckTypeFromString converts a string to a CheckType enum. @@ -58,3 +62,18 @@ func GetCheckTypeFromString(checkTypeStr string) (CheckType, error) { return -1, fmt.Errorf("unsupported check type: %s", checkTypeStr) } } + +// NewChecker creates a new Checker based on the specified CheckType, name, address, and options. +func NewChecker(checkType CheckType, name, address string, opts ...Option) (Checker, error) { + // Create the appropriate checker based on the type + switch checkType { + case HTTP: + return newHTTPChecker(name, address, opts...) + case TCP: + return newTCPChecker(name, address, opts...) + case ICMP: + return newICMPChecker(name, address, opts...) + default: + return nil, fmt.Errorf("unsupported check type: %d", checkType) + } +} diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index 63ae753..80f7903 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -2,7 +2,8 @@ package checker import ( "testing" - "time" + + "github.com/stretchr/testify/assert" ) func TestNewChecker(t *testing.T) { @@ -11,117 +12,106 @@ func TestNewChecker(t *testing.T) { t.Run("Valid HTTP checker", func(t *testing.T) { t.Parallel() - check, err := NewChecker(HTTP, "example", "http://example.com", 5*time.Second, func(s string) string { - return "" - }) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := "example" - if check.String() != expected { - t.Fatalf("expected name to be %q got %q", expected, check.String()) - } + check, err := NewChecker(HTTP, "example", "http://example.com") + + assert.NoError(t, err) + assert.Equal(t, check.GetName(), "example") + assert.Equal(t, check.GetType(), "HTTP") }) t.Run("Valid TCP checker", func(t *testing.T) { t.Parallel() - check, err := NewChecker(TCP, "example", "example.com:80", 5*time.Second, func(s string) string { - return "" - }) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := "example" - if check.String() != expected { - t.Fatalf("expected name to be %q got %q", expected, check.String()) - } + check, err := NewChecker(TCP, "example", "example.com:80") + + assert.NoError(t, err) + assert.Equal(t, check.GetName(), "example") + assert.Equal(t, check.GetType(), "TCP") }) t.Run("Valid ICMP checker", func(t *testing.T) { t.Parallel() - check, err := NewChecker(ICMP, "example", "example.com", 5*time.Second, func(s string) string { - return "" - }) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := "example" - if check.String() != expected { - t.Fatalf("expected name to be %q got %q", expected, check.String()) - } + check, err := NewChecker(ICMP, "example", "example.com") + + assert.NoError(t, err) + assert.Equal(t, check.GetName(), "example") + assert.Equal(t, check.GetType(), "ICMP") }) t.Run("Invalid checker type", func(t *testing.T) { t.Parallel() - _, err := NewChecker(8, "example", "example.com", 5*time.Second, func(s string) string { - return "" - }) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "unsupported check type: 8" - if err.Error() != expected { - t.Errorf("expected error to be %q, got %q", expected, err.Error()) - } + _, err := NewChecker(99, "example", "example.com") + + assert.Error(t, err) + assert.EqualError(t, err, "unsupported check type: 99") }) } -func TestGetCheckTypeString(t *testing.T) { +func TestGetCheckTypeFromString(t *testing.T) { t.Parallel() - t.Run("Check type string (enum)", func(t *testing.T) { + t.Run("Check type HTTP", func(t *testing.T) { + t.Parallel() + + result, err := GetCheckTypeFromString("HTTP") + + assert.NoError(t, err) + assert.Equal(t, result, HTTP) + }) + + t.Run("Check type http", func(t *testing.T) { + t.Parallel() + + result, err := GetCheckTypeFromString("http") + + assert.NoError(t, err) + assert.Equal(t, result, HTTP) + }) + + t.Run("Check type TCP", func(t *testing.T) { + t.Parallel() + + result, err := GetCheckTypeFromString("tcp") + + assert.NoError(t, err) + assert.Equal(t, result, TCP) + }) + + t.Run("Check type tcp", func(t *testing.T) { t.Parallel() - if HTTP.String() != "HTTP" { - t.Fatalf("expected 'HTTP', got %q", HTTP.String()) - } - if TCP.String() != "TCP" { - t.Fatalf("expected 'TCP', got %q", TCP.String()) - } - if ICMP.String() != "ICMP" { - t.Fatalf("expected 'ICMP', got %q", ICMP.String()) - } + result, err := GetCheckTypeFromString("tcp") + + assert.NoError(t, err) + assert.Equal(t, result, TCP) + }) + + t.Run("Check type ICMP", func(t *testing.T) { + t.Parallel() + + result, err := GetCheckTypeFromString("ICMP") + + assert.NoError(t, err) + assert.Equal(t, result, ICMP) }) - t.Run("Check type string (func)", func(t *testing.T) { - want := HTTP - got, err := GetCheckTypeFromString("http") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - if want != got { - t.Fatalf("expected %q, got %q", want, got) - } - - want = TCP - got, err = GetCheckTypeFromString("tcp") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - if want != got { - t.Fatalf("expected %q, got %q", want, got) - } - - want = ICMP - got, err = GetCheckTypeFromString("icmp") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - if want != got { - t.Fatalf("expected %q, got %q", want, got) - } - - want = -1 - got, err = GetCheckTypeFromString("invalid") - if err == nil { - t.Fatal("expected an error, got none") - } + t.Run("Check type icmp", func(t *testing.T) { + t.Parallel() + + result, err := GetCheckTypeFromString("icmp") + + assert.NoError(t, err) + assert.Equal(t, result, ICMP) + }) + + t.Run("Invalid check type", func(t *testing.T) { + t.Parallel() + + _, err := GetCheckTypeFromString("invalid") + + assert.Error(t, err) + assert.EqualError(t, err, "unsupported check type: invalid") }) } diff --git a/internal/checker/http_checker.go b/internal/checker/http_checker.go index 78e411a..7f8cf1b 100644 --- a/internal/checker/http_checker.go +++ b/internal/checker/http_checker.go @@ -5,131 +5,127 @@ import ( "crypto/tls" "fmt" "net/http" - "strconv" "time" - - "github.com/containeroo/portpatrol/pkg/httputils" ) const ( - envHTTPMethod string = "HTTP_METHOD" - envHTTPHeaders string = "HTTP_HEADERS" - envHTTPAllowDuplicateHeaders string = "HTTP_ALLOW_DUPLICATE_HEADERS" - envHTTPExpectedStatusCodes string = "HTTP_EXPECTED_STATUS_CODES" - envHTTPSkipTLSVerify string = "HTTP_SKIP_TLS_VERIFY" - - defaultHTTPMethod string = http.MethodGet - defaultHTTPAllowDuplicateHeaders bool = false - defaultHTTPSkipTLSVerify bool = false + defaultHTTPTimeout time.Duration = 1 * time.Second + defaultHTTPMethod string = http.MethodGet + defaultHTTPSkipTLSVerify bool = false ) -var defaultHTTPExpectedStatusCodes = []int{200} // Slice cannot be consts +var defaultHTTPExpectedStatusCodes = []int{200} // HTTPChecker implements the Checker interface for HTTP checks. type HTTPChecker struct { - Name string // The name of the checker. - Address string // The address of the target. - ExpectedStatusCodes []int // The expected status codes. - Method string // The HTTP method to use. - Headers map[string]string // The HTTP headers to include in the request. - client *http.Client // The HTTP client to use for the request. - DialTimeout time.Duration // The timeout for dialing the target. + name string + address string + method string + headers map[string]string + expectedStatusCodes []int + skipTLSVerify bool + timeout time.Duration + client *http.Client } -// String returns the name of the checker. -func (c *HTTPChecker) String() string { - return c.Name -} +func (c *HTTPChecker) GetAddress() string { return c.address } +func (c *HTTPChecker) GetName() string { return c.name } +func (c *HTTPChecker) GetType() string { return HTTP.String() } +func (c *HTTPChecker) Check(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, c.method, c.address, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } -// NewHTTPChecker creates a new HTTPChecker. -func NewHTTPChecker(name, address string, timeout time.Duration, getEnv func(string) string) (Checker, error) { - checker := HTTPChecker{ - Name: name, - Address: address, - Method: defaultHTTPMethod, - ExpectedStatusCodes: defaultHTTPExpectedStatusCodes, + for key, value := range c.headers { + req.Header.Add(key, value) } - // Override the default HTTP method if specified - if method := getEnv(envHTTPMethod); method != "" { - checker.Method = method + resp, err := c.client.Do(req) + if err != nil { + return fmt.Errorf("HTTP request failed: %w", err) } + defer resp.Body.Close() - // Determine if duplicate headers are allowed - var err error - allowDupHeaders := defaultHTTPAllowDuplicateHeaders - if allowDupHeaderStr := getEnv(envHTTPAllowDuplicateHeaders); allowDupHeaderStr != "" { - allowDupHeaders, err = strconv.ParseBool(allowDupHeaderStr) - if err != nil { - return nil, fmt.Errorf("invalid %s value: %w", envHTTPAllowDuplicateHeaders, err) + for _, code := range c.expectedStatusCodes { + if resp.StatusCode == code { + return nil } } - // Parse the headers string into a headers map - headers, err := httputils.ParseHeaders(getEnv(envHTTPHeaders), allowDupHeaders) - if err != nil { - return nil, fmt.Errorf("invalid %s value: %w", envHTTPHeaders, err) - } - checker.Headers = headers + return fmt.Errorf("unexpected status code: got %d, expected one of %v", resp.StatusCode, c.expectedStatusCodes) +} - // Override the default expected status codes if specified - if expectedStatusStr := getEnv(envHTTPExpectedStatusCodes); expectedStatusStr != "" { - expectedStatusCodes, err := httputils.ParseStatusCodes(expectedStatusStr) - if err != nil { - return nil, fmt.Errorf("invalid %s value: %w", envHTTPExpectedStatusCodes, err) - } - checker.ExpectedStatusCodes = expectedStatusCodes +// newHTTPChecker creates a new HTTPChecker with functional options. +func newHTTPChecker(name, address string, opts ...Option) (*HTTPChecker, error) { + checker := &HTTPChecker{ + name: name, + address: address, + method: defaultHTTPMethod, + headers: make(map[string]string), + expectedStatusCodes: defaultHTTPExpectedStatusCodes, + skipTLSVerify: defaultHTTPSkipTLSVerify, + timeout: defaultHTTPTimeout, } - // Determine if TLS verification should be skipped - skipTLSVerify := defaultHTTPSkipTLSVerify - if skipTLSVerifyStr := getEnv(envHTTPSkipTLSVerify); skipTLSVerifyStr != "" { - skipTLSVerify, err = strconv.ParseBool(skipTLSVerifyStr) - if err != nil { - return nil, fmt.Errorf("invalid %s value: %w", envHTTPSkipTLSVerify, err) - } + for _, opt := range opts { + opt.apply(checker) } - // Create the HTTP client with the given timeout and TLS configuration checker.client = &http.Client{ - Timeout: timeout, + Timeout: checker.timeout, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ - InsecureSkipVerify: skipTLSVerify, + InsecureSkipVerify: checker.skipTLSVerify, }, }, } - return &checker, nil + return checker, nil } -// Check performs an HTTP request and checks the response. -func (c *HTTPChecker) Check(ctx context.Context) error { - // Create the HTTP request - req, err := http.NewRequestWithContext(ctx, c.Method, c.Address, nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } +// WithHTTPMethod sets the HTTP method for the HTTPChecker. +func WithHTTPMethod(method string) Option { + return OptionFunc(func(c Checker) { + if httpChecker, ok := c.(*HTTPChecker); ok { + httpChecker.method = method + } + }) +} - // Add headers to the request - for key, value := range c.Headers { - req.Header.Add(key, value) - } +// WithHTTPHeaders sets the HTTP headers for the HTTPChecker. +func WithHTTPHeaders(headers map[string]string) Option { + return OptionFunc(func(c Checker) { + if httpChecker, ok := c.(*HTTPChecker); ok { + httpChecker.headers = headers + } + }) +} - // Perform the HTTP request - resp, err := c.client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() +// WithExpectedStatusCodes sets the expected status codes for the HTTPChecker. +func WithExpectedStatusCodes(codes []int) Option { + return OptionFunc(func(c Checker) { + if httpChecker, ok := c.(*HTTPChecker); ok { + httpChecker.expectedStatusCodes = codes + } + }) +} - // Check the response status code - for _, code := range c.ExpectedStatusCodes { - if resp.StatusCode == code { - return nil // Return nil if the status code matches +// WithHTTPSkipTLSVerify sets the TLS verification flag for the HTTPChecker. +func WithHTTPSkipTLSVerify(skip bool) Option { + return OptionFunc(func(c Checker) { + if httpChecker, ok := c.(*HTTPChecker); ok { + httpChecker.skipTLSVerify = skip } - } + }) +} - return fmt.Errorf("unexpected status code: got %d, expected one of %v", resp.StatusCode, c.ExpectedStatusCodes) +// WithHTTPTimeout sets the timeout for the HTTPChecker. +func WithHTTPTimeout(timeout time.Duration) Option { + return OptionFunc(func(c Checker) { + if httpChecker, ok := c.(*HTTPChecker); ok { + httpChecker.timeout = timeout + } + }) } diff --git a/internal/checker/http_checker_test.go b/internal/checker/http_checker_test.go index 60aab33..8f84ce0 100644 --- a/internal/checker/http_checker_test.go +++ b/internal/checker/http_checker_test.go @@ -5,77 +5,16 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" - "os" - "reflect" - "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestHTTPChecker(t *testing.T) { t.Parallel() - t.Run("Valid HTTP check config", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "POST", - envHTTPHeaders: "Authorization=Bearer token", - envHTTPExpectedStatusCodes: "201", - envHTTPSkipTLSVerify: "false", - envHTTPAllowDuplicateHeaders: "false", - } - return env[key] - } - - checker, err := NewHTTPChecker("example", "http://localhost:8080", 10*time.Second, mockEnv) - if err != nil { - t.Fatalf("failed to create HTTPChecker: %q", err) - } - - checkerConfig := checker.(*HTTPChecker) // Type assertion to *HTTPChecker - - expected := "example" - if checkerConfig.Name != expected { - t.Errorf("expected Name to be '%s', got %v", expected, checkerConfig.Name) - } - - expected = "http://localhost:8080" - if checkerConfig.Address != expected { - t.Errorf("expected Address to be '%s', got %v", expected, checkerConfig.Address) - } - - expected = "POST" - if checkerConfig.Method != expected { - t.Errorf("expected Method to be '%s', got %v", expected, checkerConfig.Method) - } - - expectedInsecureSkipVerify := false - if checkerConfig.client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify != expectedInsecureSkipVerify { - t.Errorf("expected client.Transport.TLSClientConfig.InsecureSkipVerify to be %v, got %v", expectedInsecureSkipVerify, checkerConfig.client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) - } - - expectedStatusCodes := []int{201} - if len(checkerConfig.ExpectedStatusCodes) != len(expectedStatusCodes) || checkerConfig.ExpectedStatusCodes[0] != expectedStatusCodes[0] { - t.Errorf("expected ExpectedStatusCodes to be %v, got %v", expectedStatusCodes, checkerConfig.ExpectedStatusCodes) - } - - expectedHeaders := map[string]string{"Authorization": "Bearer token"} - for key, value := range expectedHeaders { - if checkerConfig.Headers[key] != value { - t.Errorf("expected Headers[%s] to be '%s', got '%s'", key, value, checkerConfig.Headers[key]) - } - } - - expectedTimeout := 10 * time.Second - if checkerConfig.client.Timeout != expectedTimeout { - t.Errorf("expected client Timeout to be '%v', got %v", expectedTimeout, checkerConfig.client.Timeout) - } - }) - - t.Run("Valid HTTP check", func(t *testing.T) { + t.Run("Valid HTTP check with default configuration", func(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -84,357 +23,153 @@ func TestHTTPChecker(t *testing.T) { server := httptest.NewServer(handler) defer server.Close() - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Auportpatrolization=Bearer token", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } + checker, err := newHTTPChecker("example", server.URL) + assert.NoError(t, err) - checker, err := NewHTTPChecker("example", server.URL, 1*time.Second, mockEnv) - if err != nil { - t.Fatalf("failed to create HTTPChecker: %q", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err = checker.Check(ctx) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } + assert.NoError(t, err) + assert.Equal(t, checker.GetAddress(), server.URL) }) - t.Run("no scheme", func(t *testing.T) { + + t.Run("HTTP check with custom headers", func(t *testing.T) { t.Parallel() - // Set up a test HTTP server with a unexpected status code handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) + if r.Header.Get("Authorization") != "Bearer token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.WriteHeader(http.StatusOK) }) server := httptest.NewServer(handler) defer server.Close() - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Auportpatrolization=Bearer token", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } - - url, err := url.Parse(server.URL) - if err != nil { - t.Fatalf("failed to create HTTPChecker: %q", err) - } - - checker, err := NewHTTPChecker("example", fmt.Sprintf("%s:%s", url.Hostname(), url.Port()), 1*time.Second, mockEnv) - if err != nil { - t.Fatalf("failed to create HTTPChecker: %q", err) - } + checker, err := newHTTPChecker("example", server.URL, WithHTTPHeaders(map[string]string{"Authorization": "Bearer token"})) + assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err = checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := fmt.Sprintf("failed to create request: parse \"%s:%s\": first path segment in URL cannot contain colon", url.Hostname(), url.Port()) - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + assert.NoError(t, err) }) - t.Run("Unexpected status code", func(t *testing.T) { + t.Run("HTTP check with unexpected status code", func(t *testing.T) { t.Parallel() - // Set up a test HTTP server with a unexpected status code handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }) server := httptest.NewServer(handler) defer server.Close() - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Auportpatrolization=Bearer token", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } - - checker, err := NewHTTPChecker("example", server.URL, 1*time.Second, mockEnv) - if err != nil { - t.Fatalf("failed to create HTTPChecker: %q", err) - } + checker, err := newHTTPChecker("example", server.URL) + assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err = checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } + assert.Error(t, err) + assert.EqualError(t, err, "unexpected status code: got 404, expected one of [200]") + }) + + t.Run("Invalid URL for HTTP check", func(t *testing.T) { + t.Parallel() - expected := "unexpected status code: got 404, expected one of [200]" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) + checker, err := newHTTPChecker("example", "://invalid-url") + if err != nil { + t.Fatalf("unexpected error: %v", err) } + + err = checker.Check(context.Background()) // Run the check to trigger the error. + assert.Error(t, err) + assert.EqualError(t, err, "failed to create request: parse \"://invalid-url\": missing protocol scheme") }) - t.Run("Cancel HTTP check", func(t *testing.T) { + t.Run("Timeout during HTTP check", func(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(1 * time.Second) // Delay to ensure the context has time to be canceled + time.Sleep(2 * time.Second) // Simulate delay w.WriteHeader(http.StatusOK) }) server := httptest.NewServer(handler) defer server.Close() - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Auportpatrolization=Bearer token", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } + checker, err := newHTTPChecker("example", server.URL, WithHTTPTimeout(1*time.Second)) + assert.NoError(t, err) - checker, err := NewHTTPChecker("example", server.URL, 5*time.Second, mockEnv) - if err != nil { - t.Fatalf("failed to create HTTPChecker: %q", err) - } - - // Cancel the context after a very short time - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - // Perform the check, expecting a context canceled error err = checker.Check(ctx) - if err == nil { - t.Fatalf("expected an error, got none") - } - expected := "context deadline exceeded" - if !strings.Contains(err.Error(), expected) { - t.Errorf("expected error containing %q, got %q", expected, err) - } + assert.Error(t, err) + assert.EqualError(t, err, fmt.Sprintf("HTTP request failed: Get \"http://%s\": context deadline exceeded", server.Listener.Addr().String())) }) - t.Run("Invalid HTTP check (malformed URL)", func(t *testing.T) { + t.Run("Custom expected status codes", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - "METHOD": "GET", - "HEADERS": "Auportpatrolization=Bearer token", - "EXPECTED_STATUSES": "200", - } - return env[key] - } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + }) + server := httptest.NewServer(handler) + defer server.Close() - // Use a malformed URL to trigger an error in creating the request - checker, err := NewHTTPChecker("example", "://invalid-url", 5*time.Second, mockEnv) - if err != nil { - t.Fatalf("failed to create HTTPChecker: %q", err) - } + checker, err := newHTTPChecker("example", server.URL, WithExpectedStatusCodes([]int{202})) + assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err = checker.Check(ctx) - if err == nil { - t.Fatalf("expected an error, got none") - } - - expected := "failed to create request: parse \"://invalid-url\": missing protocol scheme" - if err.Error() != expected { - t.Errorf("expected error containing %q, got %q", expected, err) - } + assert.NoError(t, err) }) - t.Run("Valid HTTP check (duplicate headers)", func(t *testing.T) { + t.Run("Custom HTTP method", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPAllowDuplicateHeaders: "true", - envHTTPMethod: "GET", - envHTTPHeaders: "Content-Type=application/json,Content-Type=application/json", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } - - checker, err := NewHTTPChecker("example", "localhost:8080", 1*time.Second, mockEnv) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - c := checker.(*HTTPChecker) // Cast the checker to HTTPChecker - - expectedHeaders := map[string]string{ - "Content-Type": "application/json", - } - - if !reflect.DeepEqual(c.Headers, expectedHeaders) { - t.Fatalf("expected headers %v, got %v", expectedHeaders, c.Headers) - } - }) - - t.Run("Invalid HTTP check (duplicate headers)", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Content-Type=application/json,Content-Type=application/json", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } - - _, err := NewHTTPChecker("example", "localhost:8080", 1*time.Second, mockEnv) - if err == nil { - t.Fatalf("expected an error, got none") - } - - expected := "invalid HTTP_HEADERS value: duplicate header key found: Content-Type" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } - }) - - t.Run("Invalid HTTP check (malformed status range)", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Auportpatrolization=Bearer token", - envHTTPExpectedStatusCodes: "202-200", - } - return env[key] - } - - _, err := NewHTTPChecker("example", "localhost:7654", 1*time.Second, mockEnv) - if err == nil { - t.Fatalf("expected an error, got none") - } - - expected := fmt.Sprintf("invalid %s value: invalid status range: 202-200", envHTTPExpectedStatusCodes) - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } - }) - - t.Run("Invalid HTTP check (malformed header)", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Auportpatrolization Bearer token", // Missing '=' in the header - envHTTPExpectedStatusCodes: "200", + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return } - return env[key] - } - - _, err := NewHTTPChecker("example", "http://example.com", 1*time.Second, mockEnv) - if err == nil { - t.Errorf("expected an error, got none") - } - - expected := fmt.Sprintf("invalid %s value: invalid header format: Auportpatrolization Bearer token", envHTTPHeaders) - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } - }) - - t.Run("Invalid HTTP check (malformed HTTP_ALLOW_DUPLICATE_HEADERS)", func(t *testing.T) { - t.Parallel() + w.WriteHeader(http.StatusOK) + }) + server := httptest.NewServer(handler) + defer server.Close() - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPHeaders: "Content-Type=application/json,Content-Type=application/json", - envHTTPAllowDuplicateHeaders: "invalid", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } + checker, err := newHTTPChecker("example", server.URL, WithHTTPMethod(http.MethodPost)) + assert.NoError(t, err) - _, err := NewHTTPChecker("example", "localhost:8080", 1*time.Second, mockEnv) - if err == nil { - t.Fatalf("expected an error, got none") - } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() - expected := fmt.Sprintf("invalid %s value: strconv.ParseBool: parsing \"invalid\": invalid syntax", envHTTPAllowDuplicateHeaders) - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + err = checker.Check(ctx) + assert.NoError(t, err) }) - t.Run("Invalid HTTP check (malformed HTTP_SKIP_TLS_VERIFY)", func(t *testing.T) { + t.Run("Skip TLS verification", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envHTTPMethod: "GET", - envHTTPSkipTLSVerify: "invalid", - envHTTPExpectedStatusCodes: "200", - } - return env[key] - } - - _, err := NewHTTPChecker("example", "localhost:8080", 1*time.Second, mockEnv) - if err == nil { - t.Fatalf("expected an error, got none") - } - - expected := fmt.Sprintf("invalid %s value: strconv.ParseBool: parsing \"invalid\": invalid syntax", envHTTPSkipTLSVerify) - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } - }) -} - -func TestIsValidCheckTypeWithProxy(t *testing.T) { - t.Run("Invalid HTTP check (invalid proxy)", func(t *testing.T) { - // Do not use t.Parallel here since we're modifying global state (environment variables) - // t.Parallel() - - // Set the HTTP_PROXY environment variable to an invalid proxy - err := os.Setenv("HTTP_PROXY", "http://invalid-proxy:8080") - if err != nil { - t.Fatalf("Failed to set HTTP_PROXY environment variable: %v", err) - } - defer os.Unsetenv("HTTP_PROXY") // Clean up after the test + // Create a test server with a self-signed certificate + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() - // Create the HTTPChecker instance - checker, err := NewHTTPChecker("example", "http://example.com", 1*time.Second, os.Getenv) - if err != nil { - t.Errorf("expected no error, got %q", err) - } + checker, err := newHTTPChecker("example", server.URL, WithHTTPSkipTLSVerify(true)) + assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err = checker.Check(ctx) - if err == nil { - t.Fatalf("expected an error, got none") - } - - // Github throws a different error message than on my local machine, so check with contains - expected := "Get \"http://example.com\": proxyconnect tcp: dial tcp: lookup invalid-proxy" - if !strings.Contains(err.Error(), expected) { - t.Errorf("expected error containing %q, got %q", expected, err.Error()) - } + assert.NoError(t, err) }) } diff --git a/internal/checker/icmp_checker.go b/internal/checker/icmp_checker.go index c4afae2..effcc8e 100644 --- a/internal/checker/icmp_checker.go +++ b/internal/checker/icmp_checker.go @@ -5,72 +5,35 @@ import ( "fmt" "net" "os" - "strings" "sync/atomic" "time" ) const ( - envICMPReadTimeout string = "ICMP_READ_TIMEOUT" - - defaultICMPReadTimeout time.Duration = time.Second * 1 + defaultICMPReadTimeout time.Duration = 1 * time.Second + defaultICMPWriteTimeout time.Duration = 1 * time.Second ) -// ICMPChecker implements a basic ICMP ping checker. +// ICMPChecker implements the Checker interface for ICMP checks. type ICMPChecker struct { - Name string // The name of the checker. - Address string // The address of the target. - Protocol Protocol // The protocol to use for the connection. - ReadTimeout time.Duration // The timeout for reading the ICMP reply. - WriteTimeout time.Duration // The timeout for writing the ICMP request. -} - -// String returns the name of the checker. -func (c *ICMPChecker) String() string { - return c.Name + name string + address string + readTimeout time.Duration + writeTimeout time.Duration + protocol Protocol } -// NewICMPChecker initializes a new ICMPChecker with its specific configuration. -func NewICMPChecker(name, address string, dialTimeout time.Duration, getEnv func(string) string) (Checker, error) { - // The "icmp://" prefix is used to identify the check type and is not needed for further processing, - // so it must be removed before passing the address to other functions. - address = strings.TrimPrefix(address, "icmp://") - - checker := ICMPChecker{ - Name: name, - Address: address, - ReadTimeout: defaultICMPReadTimeout, - WriteTimeout: dialTimeout, - } - - protocol, err := newProtocol(checker.Address) - if err != nil { - return nil, fmt.Errorf("failed to create ICMP protocol: %w", err) - } - checker.Protocol = protocol - - // Determine the read timeout - if readTimeoutStr := getEnv(envICMPReadTimeout); readTimeoutStr != "" { - readTimeout, err := time.ParseDuration(readTimeoutStr) - if err != nil || readTimeout <= 0 { - return nil, fmt.Errorf("invalid %s value: %s", envICMPReadTimeout, readTimeoutStr) - } - checker.ReadTimeout = readTimeout - } - - return &checker, nil -} +func (c *ICMPChecker) GetAddress() string { return c.address } +func (c *ICMPChecker) GetName() string { return c.name } +func (c *ICMPChecker) GetType() string { return ICMP.String() } -// Check performs an ICMP check on the target. func (c *ICMPChecker) Check(ctx context.Context) error { - // Resolve the IP address - dst, err := net.ResolveIPAddr(c.Protocol.Network(), c.Address) + dst, err := net.ResolveIPAddr(c.protocol.Network(), c.address) if err != nil { - return fmt.Errorf("failed to resolve IP address: %w", err) + return fmt.Errorf("failed to resolve IP address '%s': %w", c.address, err) } - // Listen for ICMP packets - conn, err := c.Protocol.ListenPacket(ctx, c.Protocol.Network(), "0.0.0.0") + conn, err := c.protocol.ListenPacket(ctx, c.protocol.Network(), "") if err != nil { return fmt.Errorf("failed to listen for ICMP packets: %w", err) } @@ -79,99 +42,72 @@ func (c *ICMPChecker) Check(ctx context.Context) error { identifier := uint16(os.Getpid() & 0xffff) // Create a unique identifier sequence := uint16(atomic.AddUint32(new(uint32), 1) & 0xffff) // Create a unique sequence number - // Make the ICMP request - msg, err := c.Protocol.MakeRequest(identifier, sequence) + msg, err := c.protocol.MakeRequest(identifier, sequence) if err != nil { - return err + return fmt.Errorf("failed to create ICMP request: %w", err) + } + + if err := conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil { + return fmt.Errorf("failed to set write deadline: %w", err) } - // Write the ICMP request - if err := c.writeICMPRequest(ctx, conn, msg, dst); err != nil { - return err + if _, err := conn.WriteTo(msg, dst); err != nil { + return fmt.Errorf("failed to send ICMP request: %w", err) } - // Read the ICMP reply with context - reply, err := c.readICMPReply(ctx, conn) + if err := conn.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { + return fmt.Errorf("failed to set read deadline: %w", err) + } + + reply := make([]byte, 1500) + n, _, err := conn.ReadFrom(reply) if err != nil { - return err + return fmt.Errorf("failed to read ICMP reply: %w", err) } - // Validate the ICMP reply - if err := c.validateICMPReply(ctx, reply, identifier, sequence); err != nil { - return err + if err := c.protocol.ValidateReply(reply[:n], identifier, sequence); err != nil { + return fmt.Errorf("failed to validate ICMP reply: %w", err) } return nil } -// writeICMPRequest handles writing the ICMP request. -func (c *ICMPChecker) writeICMPRequest(ctx context.Context, conn net.PacketConn, msg []byte, dst net.Addr) error { - if err := conn.SetWriteDeadline(time.Now().Add(c.ReadTimeout)); err != nil { - return fmt.Errorf("failed to set write deadline: %w", err) +// newICMPChecker initializes a new ICMPChecker with functional options. +func newICMPChecker(name, address string, opts ...Option) (*ICMPChecker, error) { + checker := &ICMPChecker{ + name: name, + address: address, + readTimeout: defaultICMPReadTimeout, + writeTimeout: defaultICMPWriteTimeout, } - done := make(chan error, 1) - - go func() { - _, err := conn.WriteTo(msg, dst) - done <- err - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - if err != nil { - return fmt.Errorf("failed to send ICMP request to %s: %w", c.Address, err) - } - return nil + for _, opt := range opts { + opt.apply(checker) } -} -// readICMPReply handles reading the ICMP reply. -func (c *ICMPChecker) readICMPReply(ctx context.Context, conn net.PacketConn) ([]byte, error) { - // Set the read deadline - if err := conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil { - return nil, fmt.Errorf("failed to set read deadline: %w", err) + protocol, err := newProtocol(checker.address) + if err != nil { + return nil, fmt.Errorf("failed to create ICMP protocol: %w", err) } + checker.protocol = protocol - done := make(chan error, 1) - reply := make([]byte, 1500) - var n int - - go func() { - var err error - n, _, err = conn.ReadFrom(reply) - done <- err - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case err := <-done: - if err != nil { - return nil, fmt.Errorf("failed to read ICMP reply from %s: %w", c.Address, err) + return checker, nil +} + +// WithICMPReadTimeout sets the read timeout for the ICMPChecker. +func WithICMPReadTimeout(timeout time.Duration) Option { + return OptionFunc(func(c Checker) { + if icmpChecker, ok := c.(*ICMPChecker); ok { + icmpChecker.readTimeout = timeout } - return reply[:n], nil - } + }) } -// validateICMPReply handles validating the ICMP reply. -func (c *ICMPChecker) validateICMPReply(ctx context.Context, reply []byte, identifier, sequence uint16) error { - done := make(chan error, 1) - - go func() { - err := c.Protocol.ValidateReply(reply, identifier, sequence) - done <- err - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - if err != nil { - return err +// WithICMPWriteTimeout sets the write timeout for the ICMPChecker. +func WithICMPWriteTimeout(timeout time.Duration) Option { + return OptionFunc(func(c Checker) { + if icmpChecker, ok := c.(*ICMPChecker); ok { + icmpChecker.writeTimeout = timeout } - return nil - } + }) } diff --git a/internal/checker/icmp_checker_test.go b/internal/checker/icmp_checker_test.go index b6a9e4e..a57ddb5 100644 --- a/internal/checker/icmp_checker_test.go +++ b/internal/checker/icmp_checker_test.go @@ -2,809 +2,385 @@ package checker import ( "context" + "errors" "fmt" "net" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/containeroo/portpatrol/internal/testutils" + "golang.org/x/net/icmp" "golang.org/x/net/ipv4" ) -func TestNewICMPChecker(t *testing.T) { +// TestNewICMPCheckerValidIPv4 tests creating an ICMPChecker with a valid IPv4 address. +func TestNewICMPCheckerValidIPv4(t *testing.T) { t.Parallel() - t.Run("Valid IPv4 Configuration", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - return "2s" - } - - checker, err := NewICMPChecker("TestIPv4", "icmp://google.com", 1*time.Second, mockEnv) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if checker == nil { - t.Fatal("expected valid checker, got nil") - } - - expected := "TestIPv4" - if expected != checker.String() { - t.Fatalf("expected %q, got %q", expected, checker) - } - - icmpChecker := checker.(*ICMPChecker) - if icmpChecker.ReadTimeout != 2*time.Second { - t.Errorf("expected timeout of 2s, got %v", icmpChecker.ReadTimeout) - } - }) - - t.Run("Valid IPv6 Configuration", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - return "" // will fall back to default - } - - checker, err := NewICMPChecker("TestIPv6", "icmp://0:0:0:0:0:0:0:0", 1*time.Second, mockEnv) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if checker == nil { - t.Fatal("expected valid checker, got nil") - } - - icmpChecker := checker.(*ICMPChecker) - if icmpChecker.ReadTimeout != time.Second { - t.Errorf("expected default timeout of 1s, got %v", icmpChecker.ReadTimeout) - } - }) - - t.Run("Invalid IP Address", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - return "" - } - - _, err := NewICMPChecker("TestInvalidIP", "icmp://0.260.0.0", 1*time.Second, mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to create ICMP protocol: invalid or unresolvable address: 0.260.0.0" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } - }) - - t.Run("Invalid Read Timeout", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - return "invalid" - } - - _, err := NewICMPChecker("TestInvalidTimeout", "icmp://localhost", 1*time.Second, mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := fmt.Sprintf("invalid %s value: invalid", envICMPReadTimeout) - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } - }) - - t.Run("Negative Read Timeout", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - return "-1s" - } - - _, err := NewICMPChecker("TestNegativeTimeout", "icmp://localhost", 1*time.Second, mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := fmt.Sprintf("invalid %s value: -1s", envICMPReadTimeout) - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } - }) + r := WithICMPReadTimeout(2 * time.Second) + w := WithICMPWriteTimeout(2 * time.Second) + checker, err := newICMPChecker("ValidIPv4", "127.0.0.1", r, w) + + assert.NoError(t, err) + assert.Equal(t, checker.GetName(), "ValidIPv4") + assert.Equal(t, checker.GetAddress(), "127.0.0.1") } -func TestICMPChecker(t *testing.T) { +// TestNewICMPCheckerInvalidAddress tests creating an ICMPChecker with an invalid address. +func TestNewICMPCheckerInvalidAddress(t *testing.T) { t.Parallel() - t.Run("Successful ICMP Check", func(t *testing.T) { - t.Parallel() - - expectedIdentifier := uint16(1234) - expectedSequence := uint16(1) - - mockPacketConn := &testutils.MockPacketConn{ - WriteToFunc: func(b []byte, addr net.Addr) (int, error) { - return len(b), nil - }, - ReadFromFunc: func(b []byte) (int, net.Addr, error) { - msg := icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: int(expectedIdentifier), - Seq: int(expectedSequence), - Data: []byte("HELLO-R-U-THERE"), - }, - } - msgBytes, _ := msg.Marshal(nil) - copy(b, msgBytes) - return len(msgBytes), &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}, nil - }, - SetReadDeadlineFunc: func(t time.Time) error { - return nil - }, - CloseFunc: func() error { - return nil - }, - LocalAddrFunc: func() net.Addr { - return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} - }, - } - - mockProtocol := &testutils.MockProtocol{ - MakeRequestFunc: func(id, seq uint16) ([]byte, error) { - body := &icmp.Echo{ - ID: int(expectedIdentifier), - Seq: int(expectedSequence), - Data: []byte("HELLO-R-U-THERE"), - } - msg := icmp.Message{ - Type: ipv4.ICMPTypeEcho, - Code: 0, - Body: body, - } - return msg.Marshal(nil) - }, - ValidateReplyFunc: func(reply []byte, id, seq uint16) error { - parsedMsg, err := icmp.ParseMessage(1, reply) - if err != nil { - return err - } - body, ok := parsedMsg.Body.(*icmp.Echo) - if !ok || body.ID != int(expectedIdentifier) || body.Seq != int(expectedSequence) { - return fmt.Errorf("identifier or sequence mismatch") - } - return nil - }, - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - } - - checker := &ICMPChecker{ - Name: "TestChecker", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err != nil { - t.Errorf("expected no error, got %v", err) - } - }) - - t.Run("Error Listening for ICMP Packets", func(t *testing.T) { - t.Parallel() - - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return nil, fmt.Errorf("mock listen packet error") - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerListenError", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to listen for ICMP packets: mock listen packet error" - if err.Error() != expected { - t.Errorf("expected listen packet error, got %v", err) - } - }) - - t.Run("Error Setting Write Deadline", func(t *testing.T) { - t.Parallel() - - mockPacketConn := &testutils.MockPacketConn{ - SetWriteDeadlineFunc: func(t time.Time) error { - return fmt.Errorf("mock set write deadline error") - }, - CloseFunc: func() error { - return nil - }, - } - - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerWriteDeadlineError", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to set write deadline: mock set write deadline error" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } - }) - - t.Run("Error Setting Read Deadline", func(t *testing.T) { - t.Parallel() - - mockPacketConn := &testutils.MockPacketConn{ - SetReadDeadlineFunc: func(t time.Time) error { - return fmt.Errorf("mock set read deadline error") - }, - CloseFunc: func() error { - return nil - }, - } - - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerDeadlineError", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to set read deadline: mock set read deadline error" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } - }) - - t.Run("Error Creating ICMP Request", func(t *testing.T) { - t.Parallel() - - mockPacketConn := &testutils.MockPacketConn{ - CloseFunc: func() error { - return nil - }, - SetReadDeadlineFunc: func(t time.Time) error { - return nil - }, - } - - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - MakeRequestFunc: func(id, seq uint16) ([]byte, error) { - return nil, fmt.Errorf("mock make request error") - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerRequestError", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "mock make request error" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } - }) - - t.Run("Error Resolving IP Address", func(t *testing.T) { - t.Parallel() - - // You don't need to mock a PacketConn here because the error occurs before it is used. - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return nil, nil - }, - MakeRequestFunc: func(id, seq uint16) ([]byte, error) { - return []byte{}, nil - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerResolveError", - Address: "invalid-address", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to resolve IP address: lookup invalid-address: no such host" - if err.Error() != expected { - t.Fatalf("expected %q, got %q", expected, err) - } - }) - - t.Run("Error Sending ICMP Request", func(t *testing.T) { - t.Parallel() - - mockPacketConn := &testutils.MockPacketConn{ - WriteToFunc: func(b []byte, addr net.Addr) (int, error) { - return 0, fmt.Errorf("mock write to error") - }, - SetReadDeadlineFunc: func(t time.Time) error { - // Ensure this function is properly mocked to avoid nil pointer dereference - return nil - }, - CloseFunc: func() error { - return nil - }, - } - - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - MakeRequestFunc: func(id, seq uint16) ([]byte, error) { - return []byte{}, nil - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerWriteError", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to send ICMP request to 127.0.0.1: mock write to error" - if err.Error() != expected { - t.Errorf("expected write to error, got %v", err) - } - }) - - t.Run("Error Reading ICMP Reply", func(t *testing.T) { - t.Parallel() - - mockPacketConn := &testutils.MockPacketConn{ - WriteToFunc: func(b []byte, addr net.Addr) (int, error) { - return len(b), nil - }, - ReadFromFunc: func(b []byte) (int, net.Addr, error) { - return 0, nil, fmt.Errorf("mock read from error") - }, - SetReadDeadlineFunc: func(t time.Time) error { - return nil - }, - CloseFunc: func() error { - return nil - }, - } - - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - MakeRequestFunc: func(id, seq uint16) ([]byte, error) { - return []byte{}, nil - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerReadError", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to read ICMP reply from 127.0.0.1: mock read from error" - if err.Error() != expected { - t.Errorf("expected read from error, got %v", err) - } - }) - - t.Run("Error Validating ICMP Reply", func(t *testing.T) { - t.Parallel() - - mockPacketConn := &testutils.MockPacketConn{ - WriteToFunc: func(b []byte, addr net.Addr) (int, error) { - return len(b), nil - }, - ReadFromFunc: func(b []byte) (int, net.Addr, error) { - msg := icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: int(1234), // incorrect ID to force validation error - Seq: int(1), - Data: []byte("HELLO-R-U-THERE"), - }, - } - msgBytes, _ := msg.Marshal(nil) - copy(b, msgBytes) - return len(msgBytes), &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}, nil - }, - SetReadDeadlineFunc: func(t time.Time) error { - return nil - }, - CloseFunc: func() error { - return nil - }, - } - - mockProtocol := &testutils.MockProtocol{ - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - MakeRequestFunc: func(id, seq uint16) ([]byte, error) { - return []byte{}, nil - }, - ValidateReplyFunc: func(reply []byte, id, seq uint16) error { - return fmt.Errorf("mock validation error") - }, - } - - checker := &ICMPChecker{ - Name: "TestCheckerValidationError", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err == nil { - t.Error("expected an error, got none") - } - - expected := "mock validation error" - if err.Error() != expected { - t.Errorf("expected validation error, got %v", err) - } - }) + _, err := newICMPChecker("InvalidAddress", "invalid-address") + assert.Error(t, err) + assert.Equal(t, err.Error(), "failed to create ICMP protocol: invalid or unresolvable address: invalid-address") } -func TestMakeICMPRequest(t *testing.T) { +// TestICMPCheckerCheckSuccess tests successful ICMP checking. +func TestICMPCheckerCheckSuccess(t *testing.T) { t.Parallel() - c := &ICMPChecker{ - Protocol: &testutils.MockProtocol{ - MakeRequestFunc: func(id, seq uint16) ([]byte, error) { - body := &icmp.Echo{ + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + msg := icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ ID: int(id), Seq: int(seq), Data: []byte("HELLO-R-U-THERE"), - } - msg := icmp.Message{ - Type: ipv4.ICMPTypeEcho, - Code: 0, - Body: body, - } - msgBytes, err := msg.Marshal(nil) - if err != nil { - return nil, err - } - - t.Logf("Generated ICMP Request: %v", msgBytes) - - return msgBytes, nil - }, + }, + } + return msg.Marshal(nil) + }, + ValidateReplyFunc: func(reply []byte, id, seq uint16) error { + return nil + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{}, nil + }, + } + + checker := &ICMPChecker{ + name: "SuccessChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.NoError(t, err) +} + +// TestICMPCheckerCheckResolveError tests ICMP checking with an address resolution failure. +func TestICMPCheckerCheckResolveError(t *testing.T) { + t.Parallel() + + mockProtocol := &testutils.MockProtocol{ + NetworkFunc: func() string { + return "ip4:icmp" }, } - t.Run("Success", func(t *testing.T) { - t.Parallel() + checker := &ICMPChecker{ + name: "ResolveErrorChecker", + address: "invalid-host", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to resolve IP address 'invalid-host': lookup invalid-host: no such host") +} + +// TestICMPCheckerCheckWriteError tests ICMP checking with a failure to write to the connection. +func TestICMPCheckerCheckWriteError(t *testing.T) { + t.Parallel() + + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, nil + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{ + WriteToFunc: func(b []byte, addr net.Addr) (int, error) { + return 0, fmt.Errorf("mock write error") + }, + }, nil + }, + } - msg, err := c.Protocol.MakeRequest(1234, 1) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + checker := &ICMPChecker{ + name: "WriteErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + } - if len(msg) == 0 { - t.Fatalf("expected non-empty message, got empty") - } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() - t.Logf("Generated message in test: %v", msg) - }) + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to send ICMP request: mock write error") } -func TestWriteICMPRequest(t *testing.T) { +func TestICMPCheckerCheckListenPacketError(t *testing.T) { t.Parallel() - c := &ICMPChecker{ - Protocol: &testutils.MockProtocol{}, + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, nil + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return nil, fmt.Errorf("mock listen packet error") + }, } - mockConn := &testutils.MockPacketConn{ - WriteToFunc: func(b []byte, addr net.Addr) (int, error) { - return len(b), nil + checker := &ICMPChecker{ + name: "ListenPacketErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + writeTimeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to listen for ICMP packets: mock listen packet error") +} + +func TestICMPCheckerCheckMakeRequestError(t *testing.T) { + t.Parallel() + + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, errors.New("mock make request error") }, - SetWriteDeadlineFunc: func(t time.Time) error { - return nil + ValidateReplyFunc: func(reply []byte, id, seq uint16) error { + return fmt.Errorf("mock validation error") + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{}, nil + }, + } + + checker := &ICMPChecker{ + name: "WriteDeadlineErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + writeTimeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to create ICMP request: mock make request error") +} + +func TestICMPCheckerCheckWriteDeadlineError(t *testing.T) { + t.Parallel() + + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, nil + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + SetDeadlineFunc: func(t time.Time) error { + return fmt.Errorf("mock write deadline error") + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{ + WriteToFunc: func(b []byte, addr net.Addr) (int, error) { + return 0, fmt.Errorf("mock write error") + }, + }, nil + }, + } + + checker := &ICMPChecker{ + name: "WriteDeadlineErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + writeTimeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to send ICMP request: mock write error") +} + +// TestICMPCheckerCheckReadError tests ICMP checking with a failure to read from the connection. +func TestICMPCheckerCheckReadError(t *testing.T) { + t.Parallel() + + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, nil }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{ + ReadFromFunc: func(b []byte) (int, net.Addr, error) { + return 0, nil, fmt.Errorf("mock read error") + }, + }, nil + }, + } + + checker := &ICMPChecker{ + name: "ReadErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, } - t.Run("Success", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := c.writeICMPRequest(ctx, mockConn, []byte{0x01, 0x02, 0x03}, &net.IPAddr{}) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - }) - - t.Run("Context Canceled", func(t *testing.T) { - t.Parallel() - - mockConn := &testutils.MockPacketConn{ - ReadFromFunc: func(b []byte) (int, net.Addr, error) { - // Simulate a slow response - time.Sleep(3 * time.Second) - copy(b, []byte("valid")) - return 5, &net.IPAddr{}, nil - }, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - cancel() - - err := c.writeICMPRequest(ctx, mockConn, []byte{0x01, 0x02, 0x03}, &net.IPAddr{}) - if err == nil { - t.Fatalf("expected context canceled error, got nil") - } - }) - - t.Run("Write Deadline Error", func(t *testing.T) { - t.Parallel() - - mockConn.SetWriteDeadlineFunc = func(t time.Time) error { - return fmt.Errorf("mock set write deadline error") - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := c.writeICMPRequest(ctx, mockConn, []byte{0x01, 0x02, 0x03}, &net.IPAddr{}) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to set write deadline: mock set write deadline error" - if err.Error() != expected { - t.Fatalf("expected write deadline error, got %v", err) - } - }) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to read ICMP reply: mock read error") } -func TestReadICMPReply(t *testing.T) { +func TestICMPCheckerSetWriteDeadlineError(t *testing.T) { t.Parallel() - c := &ICMPChecker{ - Protocol: &testutils.MockProtocol{}, + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, nil + }, + ValidateReplyFunc: func(reply []byte, id, seq uint16) error { + return nil + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{ + SetWriteDeadlineFunc: func(t time.Time) error { + return fmt.Errorf("mock write deadline error") + }, + }, nil + }, + } + + checker := &ICMPChecker{ + name: "SetWriteDeadlineErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, } - mockConn := &testutils.MockPacketConn{ - ReadFromFunc: func(b []byte) (int, net.Addr, error) { - copy(b, []byte("valid")) - return 5, &net.IPAddr{}, nil // Return the exact number of bytes written. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to set write deadline: mock write deadline error") +} + +func TestICMPCheckerSetReadDeadlineError(t *testing.T) { + t.Parallel() + + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, nil + }, + ValidateReplyFunc: func(reply []byte, id, seq uint16) error { + return nil + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{ + SetReadDeadlineFunc: func(t time.Time) error { + return fmt.Errorf("mock write deadline error") + }, + }, nil }, } - t.Run("Success", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - reply, err := c.readICMPReply(ctx, mockConn) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if string(reply) != "valid" { - t.Fatalf("expected 'valid', got %v", string(reply)) - } - }) - - t.Run("Context Canceled", func(t *testing.T) { - t.Parallel() - - mockConn := &testutils.MockPacketConn{ - ReadFromFunc: func(b []byte) (int, net.Addr, error) { - // Simulate a slow response - time.Sleep(3 * time.Second) - copy(b, []byte("valid")) - return 5, &net.IPAddr{}, nil - }, - } - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := c.readICMPReply(ctx, mockConn) - if err == nil { - t.Fatalf("expected context canceled error, got nil") - } - }) - - t.Run("Read Deadline Error", func(t *testing.T) { - t.Parallel() - - mockConn.SetReadDeadlineFunc = func(t time.Time) error { - return fmt.Errorf("mock set read deadline error") - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - _, err := c.readICMPReply(ctx, mockConn) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to set read deadline: mock set read deadline error" - if err.Error() != expected { - t.Fatalf("expected write deadline error, got %v", err) - } - }) + checker := &ICMPChecker{ + name: "SetReadDeadlineErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to set read deadline: mock write deadline error") } -func TestValidateICMPReply(t *testing.T) { +func TestICMPCheckerValidateReplyError(t *testing.T) { t.Parallel() - c := &ICMPChecker{ - Protocol: &testutils.MockProtocol{ - ValidateReplyFunc: func(reply []byte, identifier, sequence uint16) error { - if string(reply) == "valid" { - return nil - } - return fmt.Errorf("identifier or sequence mismatch") - }, + mockProtocol := &testutils.MockProtocol{ + MakeRequestFunc: func(id, seq uint16) ([]byte, error) { + return []byte{}, nil + }, + ValidateReplyFunc: func(reply []byte, id, seq uint16) error { + return fmt.Errorf("mock validation error") + }, + NetworkFunc: func() string { + return "ip4:icmp" + }, + ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { + return &testutils.MockPacketConn{}, nil }, } - t.Run("Success", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - err := c.validateICMPReply(ctx, []byte("valid"), 1234, 1) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - }) - - t.Run("Validation Failure", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - err := c.validateICMPReply(ctx, []byte("invalid"), 1234, 1) - if err == nil { - t.Fatalf("expected validation error, got nil") - } - - expectedErr := "identifier or sequence mismatch" - if err.Error() != expectedErr { - t.Fatalf("expected error %v, got %v", expectedErr, err) - } - }) - - t.Run("Context Canceled", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - err := c.validateICMPReply(ctx, []byte("valid"), 1234, 1) - if err == nil { - t.Fatalf("expected context canceled error, got nil") - } - }) + checker := &ICMPChecker{ + name: "ValidateReplyErrorChecker", + address: "127.0.0.1", + protocol: mockProtocol, + readTimeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := checker.Check(ctx) + assert.Error(t, err) + assert.EqualError(t, err, "failed to validate ICMP reply: mock validation error") } diff --git a/internal/checker/icmp_factory.go b/internal/checker/icmp_factory.go index d9b9b37..26854fd 100644 --- a/internal/checker/icmp_factory.go +++ b/internal/checker/icmp_factory.go @@ -16,33 +16,26 @@ const ( icmpv6ProtocolNumber int = 58 ) -// Protocol defines the interface for an ICMP protocol, which allows for sending and receiving ICMP echo requests -// (ping) for network diagnostics and availability checks. This interface abstracts the details of ICMPv4 and ICMPv6 protocols -// and provides methods for constructing requests, validating responses, and managing packet connections. +// Protocol defines an interface for ICMP-based diagnostics, abstracting ICMPv4 and ICMPv6 behavior. type Protocol interface { // MakeRequest creates an ICMP echo request message with the specified identifier and sequence number. // Returns the serialized byte representation of the message or an error if message construction fails. MakeRequest(identifier, sequence uint16) ([]byte, error) - // ValidateReply verifies that an ICMP echo reply message matches the expected identifier and sequence number. // Returns an error if the reply is invalid, such as a mismatch in identifier, sequence number, or unexpected message type. ValidateReply(reply []byte, identifier, sequence uint16) error - // Network returns the network type string to be used for listening to ICMP packets, which typically indicates the IP // protocol version (e.g., "ip4:icmp" for IPv4 ICMP or "ip6:ipv6-icmp" for IPv6 ICMP). Network() string - // ListenPacket sets up a listener for ICMP packets on the specified network and address, using the provided context. // Returns a net.PacketConn for reading and writing packets, or an error if the listener cannot be established. ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) - // SetDeadline sets the read and write deadlines for the packet connection, affecting any I/O operations on the connection. // Returns an error if setting the deadline fails. SetDeadline(t time.Time) error } -// newProtocol creates a new ICMP protocol based on the given address. -// If the address is not an IP, it will be resolved as a domain name. +// newProtocol initializes a protocol based on the given address. func newProtocol(address string) (Protocol, error) { ip := net.ParseIP(address) if ip == nil { @@ -61,7 +54,7 @@ func newProtocol(address string) (Protocol, error) { return &ICMPv4{}, nil } -// ICMPv4 implements the ICMP protocol for IPv4. +// ICMPv4 implements the Protocol interface for IPv4 ICMP. type ICMPv4 struct { conn net.PacketConn } @@ -102,9 +95,7 @@ func (p *ICMPv4) ValidateReply(reply []byte, identifier, sequence uint16) error } // Network returns the network type for the ICMP protocol. -func (p *ICMPv4) Network() string { - return "ip4:icmp" -} +func (p *ICMPv4) Network() string { return "ip4:icmp" } // ListenPacket creates a new ICMPv4 packet connection. func (p *ICMPv4) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { @@ -114,16 +105,18 @@ func (p *ICMPv4) ListenPacket(ctx context.Context, network, address string) (net return nil, fmt.Errorf("failed to listen for ICMP packets: %w", err) } p.conn = conn - - return p.conn, nil + return conn, nil } // SetDeadline sets the read and write deadlines associated with the connection. It is equivalent to calling both SetReadDeadline and SetWriteDeadline. func (p *ICMPv4) SetDeadline(t time.Time) error { + if p.conn == nil { + return fmt.Errorf("connection not initialized") + } return p.conn.SetDeadline(t) } -// ICMPv6 implements the ICMP protocol for IPv6. +// ICMPv6 implements the Protocol interface for IPv6 ICMP. type ICMPv6 struct { conn net.PacketConn } @@ -164,9 +157,7 @@ func (p *ICMPv6) ValidateReply(reply []byte, identifier, sequence uint16) error } // Network returns the network type for the ICMP protocol. -func (p *ICMPv6) Network() string { - return "ip6:ipv6-icmp" -} +func (p *ICMPv6) Network() string { return "ip6:ipv6-icmp" } // ListenPacket creates a new ICMPv6 packet connection. func (p *ICMPv6) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { @@ -176,11 +167,13 @@ func (p *ICMPv6) ListenPacket(ctx context.Context, network, address string) (net return nil, fmt.Errorf("failed to listen for ICMP packets: %w", err) } p.conn = conn - return p.conn, nil } // SetDeadline sets the read and write deadlines associated with the connection. It is equivalent to calling both SetReadDeadline and SetWriteDeadline. func (p *ICMPv6) SetDeadline(t time.Time) error { + if p.conn == nil { + return fmt.Errorf("connection not initialized") + } return p.conn.SetDeadline(t) } diff --git a/internal/checker/icmp_factory_test.go b/internal/checker/icmp_factory_test.go index d6e322e..bf6bab3 100644 --- a/internal/checker/icmp_factory_test.go +++ b/internal/checker/icmp_factory_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/containeroo/portpatrol/internal/testutils" "golang.org/x/net/icmp" @@ -19,9 +21,7 @@ func TestNewProtocol(t *testing.T) { t.Parallel() protocol, err := newProtocol("192.168.1.1") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + assert.NoError(t, err) if _, ok := protocol.(*ICMPv4); !ok { t.Fatalf("expected ICMPv4 protocol, got %T", protocol) @@ -32,9 +32,7 @@ func TestNewProtocol(t *testing.T) { t.Parallel() protocol, err := newProtocol("2001:db8::1") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + assert.NoError(t, err) if _, ok := protocol.(*ICMPv6); !ok { t.Fatalf("expected ICMPv6 protocol, got %T", protocol) @@ -45,28 +43,18 @@ func TestNewProtocol(t *testing.T) { t.Parallel() _, err := newProtocol("invalid.domain") - if err == nil { - t.Fatal("expected an error, got none") - } - expected := "invalid or unresolvable address: invalid.domain" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "invalid or unresolvable address: invalid.domain") }) t.Run("Unsupported IP Address", func(t *testing.T) { t.Parallel() _, err := newProtocol("300.300.300.300") - if err == nil { - t.Fatal("expected an error, got none") - } - expected := "invalid or unresolvable address: 300.300.300.300" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "invalid or unresolvable address: 300.300.300.300") }) } @@ -78,13 +66,9 @@ func TestICMPv4MakeRequest(t *testing.T) { protocol := &ICMPv4{} msg, err := protocol.MakeRequest(1234, 1) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if len(msg) == 0 { - t.Fatal("expected non-empty ICMP message, got empty") - } + assert.NoError(t, err) + assert.Len(t, msg, 23) }) } @@ -95,10 +79,7 @@ func TestICMPv4_Network(t *testing.T) { t.Parallel() protocol := &ICMPv4{} - expected := "ip4:icmp" - if protocol.Network() != expected { - t.Errorf("expected %q, got %q", expected, protocol.Network()) - } + assert.Equal(t, protocol.Network(), "ip4:icmp") }) } @@ -111,9 +92,17 @@ func TestICMPv4_SetDeadline(t *testing.T) { mockConn := testutils.MockPacketConn{} protocol := &ICMPv4{conn: &mockConn} err := protocol.SetDeadline(time.Now().Add(1 * time.Second)) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + + assert.NoError(t, err) + }) + + t.Run("SetDeadline Error", func(t *testing.T) { + t.Parallel() + + protocol := &ICMPv4{conn: nil} + err := protocol.SetDeadline(time.Now().Add(1 * time.Second)) + + assert.Error(t, err) }) } @@ -130,14 +119,9 @@ func TestICMPv4_ValidateReply(t *testing.T) { request[4] = 0xFF err := protocol.ValidateReply(request, 1234, 1) - if err == nil { - t.Fatalf("expected an error, got none") - } - expected := "unexpected ICMPv4 message type: echo" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "unexpected ICMPv4 message type: echo") }) t.Run("ValidateReply Success", func(t *testing.T) { @@ -151,9 +135,7 @@ func TestICMPv4_ValidateReply(t *testing.T) { reply[0] = byte(ipv4.ICMPTypeEchoReply) err := protocol.ValidateReply(reply, 1234, 1) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + assert.NoError(t, err) }) t.Run("ValidateReply Identifier Mismatch", func(t *testing.T) { t.Parallel() @@ -166,9 +148,9 @@ func TestICMPv4_ValidateReply(t *testing.T) { reply[4] = 0xFF // Modify the identifier err := protocol.ValidateReply(reply, 1234, 1) - if err == nil { - t.Fatal("expected an identifier mismatch error, got none") - } + + assert.Error(t, err) + assert.EqualError(t, err, "unexpected ICMPv4 message type: echo") }) t.Run("Error Parsing Message", func(t *testing.T) { @@ -179,14 +161,8 @@ func TestICMPv4_ValidateReply(t *testing.T) { reply := []byte{0xff, 0xff, 0xff} err := protocol.ValidateReply(reply, 1234, 1) - if err == nil { - t.Fatalf("expected no error, got %v", err) - } - - expected := "failed to parse ICMPv4 message: message too short" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "failed to parse ICMPv4 message: message too short") }) t.Run("Unexpected Message Type", func(t *testing.T) { @@ -199,14 +175,8 @@ func TestICMPv4_ValidateReply(t *testing.T) { request[4] = 0xFF err := protocol.ValidateReply(request, 1234, 1) - if err == nil { - t.Fatalf("expected an error, got none") - } - - expected := "unexpected ICMPv4 message type: echo" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "unexpected ICMPv4 message type: echo") }) t.Run("IdentifierOrSequenceMismatch", func(t *testing.T) { @@ -218,9 +188,7 @@ func TestICMPv4_ValidateReply(t *testing.T) { identifier := uint16(1234) sequence := uint16(1) validRequest, err := protocol.MakeRequest(identifier, sequence) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + assert.NoError(t, err) // Modify the request to simulate an incorrect identifier or sequence in the reply replyMsg := icmp.Message{ @@ -233,20 +201,12 @@ func TestICMPv4_ValidateReply(t *testing.T) { }, } reply, err := replyMsg.Marshal(nil) - if err != nil { - t.Fatalf("failed to marshal reply message: %v", err) - } + assert.NoError(t, err) // Call ValidateReply with the modified reply err = protocol.ValidateReply(reply, identifier, sequence) - if err == nil { - t.Fatal("expected an identifier or sequence mismatch error, got none") - } - - expected := "identifier or sequence mismatch" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "identifier or sequence mismatch") }) } @@ -262,13 +222,9 @@ func TestICMPv4_ListenPacket(t *testing.T) { defer cancel() conn, err := protocol.ListenPacket(ctx, "ip4:icmp", "localhost") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if conn == nil { - t.Fatal("expected a valid PacketConn, got nil") - } + assert.NoError(t, err) + assert.NotNil(t, conn) // Clean up the connection defer conn.Close() @@ -283,14 +239,8 @@ func TestICMPv4_ListenPacket(t *testing.T) { defer cancel() _, err := protocol.ListenPacket(ctx, "invalid-network", "localhost") - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "failed to listen for ICMP packets: listen invalid-network: unknown network invalid-network" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "failed to listen for ICMP packets: listen invalid-network: unknown network invalid-network") }) t.Run("Invalid Address", func(t *testing.T) { @@ -302,19 +252,12 @@ func TestICMPv4_ListenPacket(t *testing.T) { defer cancel() _, err := protocol.ListenPacket(ctx, "ip4:icmp", "invalid-address") - if err == nil { - t.Fatal("expected an error, got none") - } - expected := "failed to listen for ICMP packets: listen ip4:icmp: lookup invalid-address: no such host" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "failed to listen for ICMP packets: listen ip4:icmp: lookup invalid-address: no such host") }) } -// HERE - func TestICMPv6MakeRequest(t *testing.T) { t.Parallel() @@ -323,13 +266,10 @@ func TestICMPv6MakeRequest(t *testing.T) { protocol := &ICMPv6{} msg, err := protocol.MakeRequest(1234, 1) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if len(msg) == 0 { - t.Fatal("expected non-empty ICMP message, got empty") - } + assert.NoError(t, err) + assert.NotNil(t, msg) + assert.Len(t, msg, 23) }) } @@ -340,10 +280,7 @@ func TestICMPv6_Network(t *testing.T) { t.Parallel() protocol := &ICMPv6{} - expected := "ip6:ipv6-icmp" - if protocol.Network() != expected { - t.Errorf("expected %q, got %q", expected, protocol.Network()) - } + assert.Equal(t, protocol.Network(), "ip6:ipv6-icmp") }) } @@ -356,9 +293,17 @@ func TestICMPv6_SetDeadline(t *testing.T) { mockConn := testutils.MockPacketConn{} protocol := &ICMPv6{conn: &mockConn} err := protocol.SetDeadline(time.Now().Add(1 * time.Second)) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + + assert.NoError(t, err) + }) + + t.Run("SetDeadline Error", func(t *testing.T) { + t.Parallel() + + protocol := &ICMPv6{conn: nil} + err := protocol.SetDeadline(time.Now().Add(1 * time.Second)) + + assert.Error(t, err) }) } @@ -375,14 +320,9 @@ func TestICMPv6_ValidateReply(t *testing.T) { request[4] = 0xFF err := protocol.ValidateReply(request, 1234, 1) - if err == nil { - t.Fatalf("expected an error, got none") - } - expected := "unexpected ICMPv6 message type: echo request" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "unexpected ICMPv6 message type: echo request") }) t.Run("ValidateReply Success", func(t *testing.T) { @@ -396,9 +336,7 @@ func TestICMPv6_ValidateReply(t *testing.T) { reply[0] = byte(ipv6.ICMPTypeEchoReply) err := protocol.ValidateReply(reply, 1234, 1) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + assert.NoError(t, err) }) t.Run("ValidateReply Identifier Mismatch", func(t *testing.T) { t.Parallel() @@ -411,9 +349,9 @@ func TestICMPv6_ValidateReply(t *testing.T) { reply[4] = 0xFF // Modify the identifier err := protocol.ValidateReply(reply, 1234, 1) - if err == nil { - t.Fatal("expected an identifier mismatch error, got none") - } + + assert.Error(t, err) + assert.EqualError(t, err, "unexpected ICMPv6 message type: echo request") }) t.Run("Error Parsing Message", func(t *testing.T) { @@ -424,14 +362,8 @@ func TestICMPv6_ValidateReply(t *testing.T) { reply := []byte{0xff, 0xff, 0xff} err := protocol.ValidateReply(reply, 1234, 1) - if err == nil { - t.Fatalf("expected no error, got %v", err) - } - - expected := "failed to parse ICMPv6 message: message too short" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "failed to parse ICMPv6 message: message too short") }) t.Run("Unexpected Message Type", func(t *testing.T) { @@ -444,14 +376,9 @@ func TestICMPv6_ValidateReply(t *testing.T) { request[4] = 0xFF err := protocol.ValidateReply(request, 1234, 1) - if err == nil { - t.Fatalf("expected an error, got none") - } - expected := "unexpected ICMPv6 message type: echo request" - if err.Error() != expected { - t.Fatalf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "unexpected ICMPv6 message type: echo request") }) t.Run("IdentifierOrSequenceMismatch", func(t *testing.T) { @@ -463,9 +390,8 @@ func TestICMPv6_ValidateReply(t *testing.T) { identifier := uint16(1234) sequence := uint16(1) validRequest, err := protocol.MakeRequest(identifier, sequence) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + + assert.NoError(t, err) // Modify the request to simulate an incorrect identifier or sequence in the reply replyMsg := icmp.Message{ @@ -478,20 +404,12 @@ func TestICMPv6_ValidateReply(t *testing.T) { }, } reply, err := replyMsg.Marshal(nil) - if err != nil { - t.Fatalf("failed to marshal reply message: %v", err) - } + assert.NoError(t, err) // Call ValidateReply with the modified reply err = protocol.ValidateReply(reply, identifier, sequence) - if err == nil { - t.Fatal("expected an identifier or sequence mismatch error, got none") - } - - expected := "identifier or sequence mismatch" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "identifier or sequence mismatch") }) } @@ -507,13 +425,9 @@ func TestICMPv6_ListenPacket(t *testing.T) { defer cancel() conn, err := protocol.ListenPacket(ctx, "ip6:ipv6-icmp", "localhost") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if conn == nil { - t.Fatal("expected a valid PacketConn, got nil") - } + assert.NoError(t, err) + assert.NotNil(t, conn) // Clean up the connection defer conn.Close() @@ -528,14 +442,9 @@ func TestICMPv6_ListenPacket(t *testing.T) { defer cancel() _, err := protocol.ListenPacket(ctx, "invalid-network", "localhost") - if err == nil { - t.Fatal("expected an error, got none") - } - expected := "failed to listen for ICMP packets: listen invalid-network: unknown network invalid-network" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "failed to listen for ICMP packets: listen invalid-network: unknown network invalid-network") }) t.Run("Invalid Address", func(t *testing.T) { @@ -547,13 +456,8 @@ func TestICMPv6_ListenPacket(t *testing.T) { defer cancel() _, err := protocol.ListenPacket(ctx, "ip6:ipv6-icmp", "invalid-address") - if err == nil { - t.Fatal("expected an error, got none") - } - expected := "failed to listen for ICMP packets: listen ip6:ipv6-icmp: lookup invalid-address: no such host" - if err.Error() != expected { - t.Errorf("expected error %q, got %q", expected, err.Error()) - } + assert.Error(t, err) + assert.EqualError(t, err, "failed to listen for ICMP packets: listen ip6:ipv6-icmp: lookup invalid-address: no such host") }) } diff --git a/internal/checker/tcp_checker.go b/internal/checker/tcp_checker.go index c49a614..3d91d9a 100644 --- a/internal/checker/tcp_checker.go +++ b/internal/checker/tcp_checker.go @@ -3,46 +3,52 @@ package checker import ( "context" "net" - "strings" "time" ) +const defaultTCPTimeout time.Duration = 1 * time.Second + // TCPChecker implements the Checker interface for TCP checks. type TCPChecker struct { - Name string // The name of the checker. - Address string // The address of the target. - dialer *net.Dialer // The dialer to use for the connection. + name string + address string + dialer *net.Dialer } -// String returns the name of the checker. -func (c *TCPChecker) String() string { - return c.Name +func (c *TCPChecker) GetAddress() string { return c.address } +func (c *TCPChecker) GetName() string { return c.name } +func (c *TCPChecker) GetType() string { return TCP.String() } +func (c *TCPChecker) Check(ctx context.Context) error { + conn, err := c.dialer.DialContext(ctx, "tcp", c.address) + if err != nil { + return err + } + defer conn.Close() + return nil } -// NewTCPChecker creates a new TCPChecker. -func NewTCPChecker(name, address string, timeout time.Duration) (Checker, error) { - // The "tcp://" prefix is used to identify the check type and is not needed for further processing, - // so it must be removed before passing the address to other functions. - address = strings.TrimPrefix(address, "tcp://") - - checker := TCPChecker{ - Address: address, - Name: name, +// newTCPChecker creates a new TCPChecker with functional options. +func newTCPChecker(name, address string, opts ...Option) (*TCPChecker, error) { + checker := &TCPChecker{ + name: name, + address: address, dialer: &net.Dialer{ - Timeout: timeout, + Timeout: defaultTCPTimeout, }, } - return &checker, nil -} - -// Check performs a TCP connection check. -func (c *TCPChecker) Check(ctx context.Context) error { - conn, err := c.dialer.DialContext(ctx, "tcp", c.Address) - if err != nil { - return err + for _, opt := range opts { + opt.apply(checker) } - defer conn.Close() - return nil + return checker, nil +} + +// WithTCPTimeout sets the timeout for the TCPChecker. +func WithTCPTimeout(timeout time.Duration) Option { + return OptionFunc(func(c Checker) { + if tcpChecker, ok := c.(*TCPChecker); ok { + tcpChecker.dialer.Timeout = timeout + } + }) } diff --git a/internal/checker/tcp_checker_test.go b/internal/checker/tcp_checker_test.go index 5fc08d3..cc6b1f3 100644 --- a/internal/checker/tcp_checker_test.go +++ b/internal/checker/tcp_checker_test.go @@ -5,48 +5,84 @@ import ( "net" "testing" "time" + + "github.com/stretchr/testify/assert" ) -func TestTCPChecker(t *testing.T) { +func TestNewTCPChecker_Valid(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:7080") + if err != nil { + t.Fatalf("failed to start TCP server: %q", err) + } + defer ln.Close() + + checker, err := newTCPChecker("example", ln.Addr().String(), WithTCPTimeout(1*time.Second)) + assert.NoError(t, err) + + assert.Equal(t, checker.GetName(), "example") + assert.Equal(t, checker.GetAddress(), ln.Addr().String()) + assert.Equal(t, checker.GetType(), TCP.String()) +} + +func TestTCPChecker_ValidConnection(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:7081") + if err != nil { + t.Fatalf("failed to start TCP server: %q", err) + } + defer ln.Close() + + checker, err := newTCPChecker("example", ln.Addr().String(), WithTCPTimeout(1*time.Second)) + assert.NoError(t, err) + + ctx := context.Background() + err = checker.Check(ctx) + assert.NoError(t, err) +} + +func TestTCPChecker_FailedConnection(t *testing.T) { + t.Parallel() + + checker, err := newTCPChecker("example", "127.0.0.1:7090", WithTCPTimeout(1*time.Second)) + assert.NoError(t, err) + + ctx := context.Background() + err = checker.Check(ctx) + + assert.Error(t, err) + assert.EqualError(t, err, "dial tcp 127.0.0.1:7090: connect: connection refused") +} + +func TestTCPChecker_InvalidAddress(t *testing.T) { + t.Parallel() + + checker, err := newTCPChecker("example", "invalid-address", WithTCPTimeout(1*time.Second)) + assert.NoError(t, err) + + ctx := context.Background() + err = checker.Check(ctx) + + assert.Error(t, err) + assert.EqualError(t, err, "dial tcp: address invalid-address: missing port in address") +} + +func TestTCPChecker_Timeout(t *testing.T) { t.Parallel() - t.Run("Valid TCP check", func(t *testing.T) { - t.Parallel() - - ln, err := net.Listen("tcp", "127.0.0.1:7080") - if err != nil { - t.Fatalf("failed to start TCP server: %q", err) - } - defer ln.Close() - - checker, err := NewTCPChecker("example", ln.Addr().String(), 1*time.Second) - if err != nil { - t.Fatalf("failed to create TCPChecker: %q", err) - } - - // Perform the check - err = checker.Check(context.Background()) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - }) - - t.Run("Failed TCP check", func(t *testing.T) { - t.Parallel() - - checker, err := NewTCPChecker("example", "localhost:7090", 1*time.Second) - if err != nil { - t.Fatalf("failed to create TCPChecker: %q", err) - } - - err = checker.Check(context.Background()) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "dial tcp [::1]:7090: connect: connection refused" - if err.Error() != expected { - t.Errorf("expected error containing %q, got %q", expected, err) - } - }) + ln, err := net.Listen("tcp", "127.0.0.1:7082") + defer ln.Close() + assert.NoError(t, err) + + // Simulate a timeout by setting an impossibly short timeout + checker, err := newTCPChecker("example", ln.Addr().String(), WithTCPTimeout(1*time.Nanosecond)) + assert.NoError(t, err) + + ctx := context.Background() + err = checker.Check(ctx) + + assert.Error(t, err) + assert.EqualError(t, err, "dial tcp 127.0.0.1:7082: i/o timeout") } diff --git a/internal/config/config.go b/internal/config/config.go index 27e7e15..098e9b3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,132 +2,172 @@ package config import ( "fmt" - "net/url" - "strconv" - "strings" + "io" "time" - "github.com/containeroo/portpatrol/internal/checker" + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/spf13/pflag" ) const ( - envTargetName string = "TARGET_NAME" - envTargetAddress string = "TARGET_ADDRESS" - envTargetCheckType string = "TARGET_CHECK_TYPE" - envCheckInterval string = "CHECK_INTERVAL" - envDialTimeout string = "DIAL_TIMEOUT" - envLogExtraFields string = "LOG_EXTRA_FIELDS" - - defaultTargetCheckType checker.CheckType = checker.TCP - defaultCheckInterval time.Duration = 2 * time.Second - defaultDialTimeout time.Duration = 1 * time.Second - defaultLogExtraFields bool = false + paramDefaultInterval string = "default-interval" + defaultCheckInterval time.Duration = 2 * time.Second + defaultHTTPAllowDuplicateHeaders bool = false + defaultHTTPSkipTLSVerify bool = false ) -// Config holds the required environment variables. -type Config struct { - Version string // The version of the application. - TargetName string // The name of the target. - TargetAddress string // The address of the target. - TargetCheckType checker.CheckType // Type of check: "tcp", "http" or "icmp". - CheckInterval time.Duration // The interval between connection attempts. - DialTimeout time.Duration // The timeout for dialing the target. - LogExtraFields bool // Whether to log the fields in the log message. +type HelpRequested struct { + Message string } -// ParseConfig retrieves and parses the required environment variables. -// Provides default values if the environment variables are not set. -func ParseConfig(getEnv func(string) string) (Config, error) { - cfg := Config{ - TargetName: getEnv(envTargetName), - TargetAddress: getEnv(envTargetAddress), - TargetCheckType: defaultTargetCheckType, - CheckInterval: defaultCheckInterval, - DialTimeout: defaultDialTimeout, - LogExtraFields: defaultLogExtraFields, - } +func (e *HelpRequested) Error() string { + return e.Message +} - if cfg.TargetAddress == "" { - return Config{}, fmt.Errorf("%s environment variable is required", envTargetAddress) - } +// Is returns true if the error is a HelpRequested error. +func (e *HelpRequested) Is(target error) bool { + _, ok := target.(*HelpRequested) + return ok +} + +// ParsedFlags holds the parsed command-line flags. +type ParsedFlags struct { + ShowHelp bool + ShowVersion bool + Version string + DefaultCheckInterval time.Duration + DynFlags *dynflags.DynFlags +} - if cfg.TargetName == "" { - address := cfg.TargetAddress - if !strings.Contains(address, "://") { - address = fmt.Sprintf("http://%s", address) // Prepend scheme if missing to avoid url.Parse error - } +// ParseFlags parses command-line arguments and returns the parsed flags. +func ParseFlags(args []string, version string, output io.Writer) (*ParsedFlags, error) { + // Create global flagSet and dynamic flags + flagSet := setupGlobalFlags() + dynFlags := setupDynamicFlags() - // Use url.Parse to handle both cases: with and without a port - parsedURL, err := url.Parse(address) - if err != nil { - return Config{}, fmt.Errorf("could not parse target address: %w", err) - } + // Set output for flagSet and dynFlags + flagSet.SetOutput(output) + dynFlags.SetOutput(output) - hostname := parsedURL.Hostname() // Extract the hostname - if hostname == "" { - return Config{}, fmt.Errorf("could not extract hostname from target address: %s", cfg.TargetAddress) - } + // Set up custom usage function + setupUsage(output, flagSet, dynFlags) - cfg.TargetName = hostname + // Parse unknown arguments with dynamic flags + if err := dynFlags.Parse(args); err != nil { + return nil, fmt.Errorf("error parsing dynamic flags: %w", err) } - // Parse the interval - if intervalStr := getEnv(envCheckInterval); intervalStr != "" { - interval, err := time.ParseDuration(intervalStr) - if err != nil || interval <= 0 { - return Config{}, fmt.Errorf("invalid %s value: %s", envCheckInterval, intervalStr) - } - cfg.CheckInterval = interval + unknownArgs := dynFlags.UnparsedArgs() + + // Parse known flags + if err := flagSet.Parse(unknownArgs); err != nil { + return nil, fmt.Errorf("Flag parsing error: %s", err.Error()) + } + // Handle special flags (e.g., --help or --version) + if err := handleSpecialFlags(flagSet, version); err != nil { + return nil, err } - // Parse the dial timeout - if dialTimeoutStr := getEnv(envDialTimeout); dialTimeoutStr != "" { - dialTimeout, err := time.ParseDuration(dialTimeoutStr) - if err != nil || dialTimeout <= 0 { - return Config{}, fmt.Errorf("invalid %s value: %s", envDialTimeout, dialTimeoutStr) - } - cfg.DialTimeout = dialTimeout + // Retrieve the default interval value + defaultInterval, err := getDurationFlag(flagSet, paramDefaultInterval, defaultCheckInterval) + if err != nil { + return nil, err } - // Parse the log additional fields - if logFieldsStr := getEnv(envLogExtraFields); logFieldsStr != "" { - logExtraFields, err := strconv.ParseBool(logFieldsStr) - if err != nil { - return Config{}, fmt.Errorf("invalid %s value: %s", envLogExtraFields, logFieldsStr) - } - cfg.LogExtraFields = logExtraFields + return &ParsedFlags{ + DefaultCheckInterval: defaultInterval, + DynFlags: dynFlags, + }, nil +} + +// setupGlobalFlags sets up global application flags. +func setupGlobalFlags() *pflag.FlagSet { + flagSet := pflag.NewFlagSet("portpatrol", pflag.ContinueOnError) + flagSet.SortFlags = false + + flagSet.Duration(paramDefaultInterval, defaultCheckInterval, "Default interval between checks. Can be overridden for each target.") + flagSet.Bool("version", false, "Show version and exit.") + flagSet.BoolP("help", "h", false, "Show help.") + + return flagSet +} + +// setupDynamicFlags sets up dynamic flags for HTTP, TCP, ICMP. +func setupDynamicFlags() *dynflags.DynFlags { + dynFlags := dynflags.New(dynflags.ContinueOnError) + dynFlags.Epilog("For more information, see https://github.com/containeroo/portpatrol") + dynFlags.SortGroups = true + dynFlags.SortFlags = true + + // HTTP flags + httpFlags := dynFlags.Group("http") + httpFlags.String("name", "", "Name of the HTTP checker") + httpFlags.String("method", "GET", "HTTP method to use") + httpFlags.String("address", "", "HTTP target URL") + httpFlags.Duration("interval", 1*time.Second, "Time between HTTP requests. Can be overwritten with --default-interval.") + httpFlags.StringSlices("header", nil, "HTTP headers to send") + httpFlags.Bool("allow-duplicate-headers", defaultHTTPAllowDuplicateHeaders, "Allow duplicate HTTP headers") + httpFlags.String("expected-status-codes", "200", "Expected HTTP status codes") + httpFlags.Bool("skip-tls-verify", defaultHTTPSkipTLSVerify, "Skip TLS verification") + httpFlags.Duration("timeout", 2*time.Second, "Timeout in seconds") + + // ICMP flags + icmpFlags := dynFlags.Group("icmp") + icmpFlags.String("name", "", "Name of the ICMP checker") + icmpFlags.String("address", "", "ICMP target address") + icmpFlags.Duration("interval", 1*time.Second, "Time between ICMP requests. Can be overwritten with --default-interval.") + icmpFlags.Duration("read-timeout", 2*time.Second, "Timeout for ICMP read") + icmpFlags.Duration("write-timeout", 2*time.Second, "Timeout for ICMP write") + + // TCP flags + tcpFlags := dynFlags.Group("tcp") + tcpFlags.String("name", "", "Name of the TCP checker") + tcpFlags.String("address", "", "TCP target address") + tcpFlags.Duration("timeout", 2*time.Second, "Timeout for TCP connection") + tcpFlags.Duration("interval", 1*time.Second, "Time between TCP requests. Can be overwritten with --default-interval.") + + return dynFlags +} + +// setupUsage sets the custom usage function. +func setupUsage(output io.Writer, flagSet *pflag.FlagSet, dynFlags *dynflags.DynFlags) { + flagSet.Usage = func() { + fmt.Fprintln(output, "Usage: portpatrol [FLAGS] [DYNAMIC FLAGS..]") + + fmt.Fprintln(output, "\nGlobal Flags:") + flagSet.PrintDefaults() + + fmt.Fprintln(output, "\nDynamic Flags:") + dynFlags.PrintDefaults() + } +} + +// handleSpecialFlags handles help and version flags. +func handleSpecialFlags(flagSet *pflag.FlagSet, version string) error { + if flagSet.Lookup("help").Value.String() == "true" { + flagSet.Usage() + return &HelpRequested{Message: ""} } - // Resolve TargetCheckType - if err := resolveTargetCheckType(&cfg, getEnv); err != nil { - return Config{}, err + if flagSet.Lookup("version").Value.String() == "true" { + return &HelpRequested{Message: fmt.Sprintf("PortPatrol version %s\n", version)} } - return cfg, nil + return nil } -// resolveTargetCheckType handles the logic for determining the check type -func resolveTargetCheckType(cfg *Config, getEnv func(string) string) error { - // First, check if envTargetCheckType is explicitly set - if checkTypeStr := getEnv(envTargetCheckType); checkTypeStr != "" { - checkType, err := checker.GetCheckTypeFromString(checkTypeStr) - if err != nil { - return fmt.Errorf("invalid check type from environment: %w", err) - } - cfg.TargetCheckType = checkType - return nil +// Example of getting a flag value as a time.Duration +func getDurationFlag(flagSet *pflag.FlagSet, name string, defaultValue time.Duration) (time.Duration, error) { + flag := flagSet.Lookup(name) + if flag == nil { + return defaultValue, nil } - // If envTargetCheckType is not set, try to infer it from the target address - parts := strings.SplitN(cfg.TargetAddress, "://", 2) // parts[0] is the scheme, parts[1] is the address - if len(parts) == 2 { - checkType, err := checker.GetCheckTypeFromString(parts[0]) - if err != nil { - return fmt.Errorf("could not infer check type from address %s: %w", cfg.TargetAddress, err) - } - cfg.TargetCheckType = checkType - return nil + // Parse the flag value from string to time.Duration + value, err := time.ParseDuration(flag.Value.String()) + if err != nil { + return defaultValue, fmt.Errorf("invalid duration for flag '%s'", flag.Value.String()) } - return nil + return value, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4343244..35a7475 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,396 +1,222 @@ package config import ( + "bytes" "fmt" - "reflect" + "strings" "testing" "time" - "github.com/containeroo/portpatrol/internal/checker" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" ) -func TestParseConfig(t *testing.T) { +func TestParseFlags(t *testing.T) { t.Parallel() - t.Run("Valid config with defaults", func(t *testing.T) { + t.Run("Successful Parsing", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "example.com:80", - } - return env[key] - } - - cfg, err := ParseConfig(mockEnv) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } + args := []string{"--default-interval=5s"} + var output strings.Builder - expected := Config{ - TargetName: "example.com", // Extracted from TargetAddress - TargetAddress: "example.com:80", - TargetCheckType: checker.TCP, - CheckInterval: 2 * time.Second, - DialTimeout: 1 * time.Second, - } - if !reflect.DeepEqual(cfg, expected) { - t.Fatalf("expected config %+v, got %+v", expected, cfg) - } + parsedFlags, err := ParseFlags(args, "1.0.0", &output) + assert.NoError(t, err) + assert.Equal(t, 5*time.Second, parsedFlags.DefaultCheckInterval) }) - t.Run("Valid config with www as scheme", func(t *testing.T) { + t.Run("Unknown Dynamic Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "www.example.com:80", - envTargetCheckType: "http", - envCheckInterval: "5s", - envDialTimeout: "10s", - } - return env[key] - } - - cfg, err := ParseConfig(mockEnv) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := Config{ - TargetName: "www.example.com", // Extracted from TargetAddress - TargetAddress: "www.example.com:80", - TargetCheckType: checker.HTTP, - CheckInterval: 5 * time.Second, - DialTimeout: 10 * time.Second, - } - if !reflect.DeepEqual(cfg, expected) { - t.Fatalf("expected config %+v, got %+v", expected, cfg) - } - }) - - t.Run("Valid config with kubernetes service", func(t *testing.T) { - t.Parallel() + args := []string{"--unknown.identifier.flag=value"} + var output bytes.Buffer - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://postgres.postgres.svc.cluster.local:80", - envTargetCheckType: "http", - envCheckInterval: "5s", - envDialTimeout: "10s", - } - return env[key] - } + parsedFlags, err := ParseFlags(args, "1.0.0", &output) + assert.NoError(t, err) - cfg, err := ParseConfig(mockEnv) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } + df := parsedFlags.DynFlags + g := df.Unknown() - expected := Config{ - TargetName: "postgres.postgres.svc.cluster.local", // Extracted from TargetAddress - TargetAddress: "http://postgres.postgres.svc.cluster.local:80", - TargetCheckType: checker.HTTP, - CheckInterval: 5 * time.Second, - DialTimeout: 10 * time.Second, - } - if !reflect.DeepEqual(cfg, expected) { - t.Fatalf("expected config %+v, got %+v", expected, cfg) - } + assert.NotNil(t, g) }) - t.Run("Valid config with custom values", func(t *testing.T) { + t.Run("Handle Help Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "tcp://example.com:80", - envTargetCheckType: "tcp", - envCheckInterval: "5s", - envDialTimeout: "10s", - } - return env[key] - } + var output strings.Builder + flagSet := setupGlobalFlags() + flagSet.SetOutput(&output) // Ensure output is properly set + _ = flagSet.Parse([]string{"--help"}) - cfg, err := ParseConfig(mockEnv) - if err != nil { - t.Fatalf("expected no error, got %q", err) + flagSet.Usage = func() { + fmt.Fprintln(&output, "Usage: portpatrol [FLAGS] [DYNAMIC FLAGS..]") } - expected := Config{ - TargetName: "example.com", // Extracted from TargetAddress - TargetAddress: "tcp://example.com:80", - TargetCheckType: checker.TCP, - CheckInterval: 5 * time.Second, - DialTimeout: 10 * time.Second, - } - if !reflect.DeepEqual(cfg, expected) { - t.Fatalf("expected config %+v, got %+v", expected, cfg) - } + err := handleSpecialFlags(flagSet, "1.0.0") + assert.Error(t, err) + assert.IsType(t, &HelpRequested{}, err) + assert.Contains(t, output.String(), "Usage: portpatrol [FLAGS] [DYNAMIC FLAGS..]") }) - t.Run("Invalid interval (invalid)", func(t *testing.T) { + t.Run("Show Version Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://example.com", - envCheckInterval: "invalid", - } - return env[key] - } + args := []string{"--version"} + var output bytes.Buffer - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := fmt.Sprintf("invalid %s value: invalid", envCheckInterval) - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } + _, err := ParseFlags(args, "1.0.0", &output) + assert.Error(t, err) + assert.IsType(t, &HelpRequested{}, err) + assert.Contains(t, err.Error(), "PortPatrol version 1.0.0") }) - t.Run("Invalid interval (zero)", func(t *testing.T) { + t.Run("Invalid Duration Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://example.com", - envCheckInterval: "0s", - } - return env[key] - } + args := []string{"--default-interval=invalid"} + var output bytes.Buffer - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } + _, err := ParseFlags(args, "1.0.0", &output) + assert.Error(t, err) - expected := fmt.Sprintf("invalid %s value: 0s", envCheckInterval) - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } + assert.EqualError(t, err, "Flag parsing error: invalid argument \"invalid\" for \"--default-interval\" flag: time: invalid duration \"invalid\"") }) +} - t.Run("Invalid dial timeout (invalid)", func(t *testing.T) { - t.Parallel() - - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://example.com", - envDialTimeout: "invalid", - } - return env[key] - } +func TestIsHelpRequested(t *testing.T) { + t.Parallel() - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } + t.Run("Help requested", func(t *testing.T) { + t.Parallel() - expected := fmt.Sprintf("invalid %s value: invalid", envDialTimeout) - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } + err := &HelpRequested{Message: "Help requested"} + assert.ErrorIs(t, err, &HelpRequested{}) }) +} - t.Run("Invalid dial timeout (zero)", func(t *testing.T) { - t.Parallel() +func TestSetupGlobalFlags(t *testing.T) { + t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://example.com", - envDialTimeout: "0s", - } - return env[key] - } + flagSet := setupGlobalFlags() + assert.NotNil(t, flagSet.Lookup("default-interval")) + assert.NotNil(t, flagSet.Lookup("version")) + assert.NotNil(t, flagSet.Lookup("help")) +} - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } - }) +func TestSetupDynamicFlags(t *testing.T) { + t.Parallel() - t.Run("Invalid address (invalid address)", func(t *testing.T) { - t.Parallel() + dynFlags := setupDynamicFlags() + assert.NotNil(t, dynFlags.Group("http")) + assert.NotNil(t, dynFlags.Group("tcp")) + assert.NotNil(t, dynFlags.Group("icmp")) - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://exam ple.com", - } - return env[key] - } + httpGroup := dynFlags.Group("http") + assert.NotNil(t, httpGroup.Lookup("name")) + assert.NotNil(t, httpGroup.Lookup("method")) + assert.NotNil(t, httpGroup.Lookup("address")) +} - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } +func TestSetupUsage(t *testing.T) { + t.Parallel() - expected := "could not parse target address: parse \"http://exam ple.com\": invalid character \" \" in host name" - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } - }) + var output strings.Builder + flagSet := setupGlobalFlags() + flagSet.SetOutput(&output) - t.Run("Invalid hostname (missing address)", func(t *testing.T) { - t.Parallel() + dynFlags := setupDynamicFlags() + dynFlags.SetOutput(&output) - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://:8080", - } - return env[key] - } + setupUsage(&output, flagSet, dynFlags) + flagSet.Usage() - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } + usageOutput := output.String() + assert.Contains(t, usageOutput, "Usage: portpatrol [FLAGS] [DYNAMIC FLAGS..]") + assert.Contains(t, usageOutput, "Global Flags:") + assert.Contains(t, usageOutput, "--default-interval") + assert.Contains(t, usageOutput, "Dynamic Flags:") + assert.Contains(t, usageOutput, "http") +} - expected := "could not extract hostname from target address: http://:8080" - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } - }) +func TestHandleSpecialFlags(t *testing.T) { + t.Parallel() - t.Run("Valid LOG_EXTRA_FIELDS", func(t *testing.T) { + t.Run("Handle Help Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://example.com", - envLogExtraFields: "true", - } - return env[key] - } + var output strings.Builder + flagSet := setupGlobalFlags() + flagSet.SetOutput(&output) - result, err := ParseConfig(mockEnv) - if err != nil { - t.Fatalf("expected no error, got %q", err) + flagSet.Usage = func() { + fmt.Fprintln(&output, "Usage: portpatrol [FLAGS] [DYNAMIC FLAGS..]") } - expected := Config{ - TargetName: "example.com", - TargetAddress: "http://example.com", - TargetCheckType: checker.HTTP, - CheckInterval: 2 * time.Second, - DialTimeout: 1 * time.Second, - LogExtraFields: true, - } - if !reflect.DeepEqual(result, expected) { - t.Fatalf("expected %v, got %v", expected, result) - } + args := []string{"--help"} + err := flagSet.Parse(args) + assert.NoError(t, err) + + err = handleSpecialFlags(flagSet, "1.0.0") + assert.Error(t, err) }) - t.Run("Invalid LOG_EXTRA_FIELDS (not boolean)", func(t *testing.T) { + t.Run("Handle Version Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://example.com", - envLogExtraFields: "invalid", - } - return env[key] - } - - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } + flagSet := setupGlobalFlags() + _ = flagSet.Parse([]string{"--version"}) - expected := fmt.Sprintf("invalid %s value: invalid", envLogExtraFields) - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } + err := handleSpecialFlags(flagSet, "1.0.0") + assert.Error(t, err) + assert.IsType(t, &HelpRequested{}, err) + assert.Contains(t, err.Error(), "PortPatrol version 1.0.0") }) - t.Run("Valid check type (defaults to tcp)", func(t *testing.T) { + t.Run("No Special Flags", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "example.com:80", - } - return env[key] - } - - result, err := ParseConfig(mockEnv) - if err != nil { - t.Fatalf("expected no error, got %q", err) - } + flagSet := setupGlobalFlags() + _ = flagSet.Parse([]string{}) - expected := Config{ - TargetName: "example.com", - TargetAddress: "example.com:80", - TargetCheckType: checker.TCP, - CheckInterval: 2 * time.Second, - DialTimeout: 1 * time.Second, - LogExtraFields: false, - } - if !reflect.DeepEqual(result, expected) { - t.Fatalf("expected %v, got %v", expected, result) - } + err := handleSpecialFlags(flagSet, "1.0.0") + assert.NoError(t, err) }) +} + +func TestGetDurationFlag(t *testing.T) { + t.Parallel() - t.Run("Invalid check type (invalid)", func(t *testing.T) { + t.Run("Valid Duration Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "http://example.com", - envTargetCheckType: "invalid", - } - return env[key] - } + flagSet := setupGlobalFlags() + _ = flagSet.Set("default-interval", "10s") - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "invalid check type from environment: unsupported check type: invalid" - if err.Error() != expected { - t.Errorf("expected error to contain %q, got %q", expected, err) - } + duration, err := getDurationFlag(flagSet, "default-interval", time.Second) + assert.NoError(t, err) + assert.Equal(t, 10*time.Second, duration) }) - t.Run("Invalid check type (infer invalid)", func(t *testing.T) { + t.Run("Invalid Duration", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - env := map[string]string{ - envTargetAddress: "htp://example.com", - } - return env[key] - } - - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } + flagSet := pflag.NewFlagSet("portpatrol", pflag.ContinueOnError) + flagSet.String("invalid-flag", "invalid", "Invalid flag") + err := flagSet.Set("invalid-flag", "invalid") + assert.NoError(t, err) - expected := "could not infer check type from address htp://example.com: unsupported check type: htp" - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } + _, err = getDurationFlag(flagSet, "invalid-flag", time.Second) + assert.Error(t, err) + assert.EqualError(t, err, "invalid duration for flag 'invalid'") }) - t.Run("Missing target address", func(t *testing.T) { + t.Run("Missing Duration Flag", func(t *testing.T) { t.Parallel() - mockEnv := func(key string) string { - return "" - } + flagSet := setupGlobalFlags() - _, err := ParseConfig(mockEnv) - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := fmt.Sprintf("%s environment variable is required", envTargetAddress) - if err.Error() != expected { - t.Fatalf("expected error to contain %q, got %q", expected, err) - } + duration, err := getDurationFlag(flagSet, "non-existent-flag", time.Second) + assert.NoError(t, err) + assert.Equal(t, time.Second, duration) }) } diff --git a/internal/factory/factory.go b/internal/factory/factory.go new file mode 100644 index 0000000..852c8d0 --- /dev/null +++ b/internal/factory/factory.go @@ -0,0 +1,151 @@ +package factory + +import ( + "fmt" + "strings" + "time" + + "github.com/containeroo/portpatrol/internal/checker" + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/containeroo/portpatrol/pkg/httputils" + "github.com/containeroo/portpatrol/pkg/resolver" +) + +// CheckerWithInterval represents a checker with its interval. +type CheckerWithInterval struct { + Interval time.Duration + Checker checker.Checker +} + +// BuildCheckers creates a list of CheckerWithInterval from the parsed dynflags configuration. +func BuildCheckers(dynFlags *dynflags.DynFlags, defaultInterval time.Duration) ([]CheckerWithInterval, error) { + var checkers []CheckerWithInterval + + // Iterate over all parsed groups + for parentName, childGroups := range dynFlags.Parsed().Groups() { + checkType, err := checker.GetCheckTypeFromString(parentName) + if err != nil { + return nil, fmt.Errorf("invalid check type '%s': %w", parentName, err) + } + + // Process each parsed group (child) under the parent group + for _, group := range childGroups { + address, err := group.GetString("address") + if err != nil { + return nil, fmt.Errorf("missing address for %s checker: %w", parentName, err) + } + + resolvedAddress, err := resolver.ResolveVariable(address) + if err != nil { + return nil, fmt.Errorf("failed to resolve variable in address: %w", err) + } + + // Default interval for the checker + interval := defaultInterval + if customInterval, err := group.GetDuration("interval"); err == nil { + interval = customInterval + } + + // Prepare options based on the checker type + var opts []checker.Option + + switch checkType { + case checker.HTTP: + if method, err := group.GetString("method"); err == nil { + opts = append(opts, checker.WithHTTPMethod(method)) + } + + allowDuplicateHeaders, _ := group.GetBool("allow-duplicate-headers") // Type is checked when parsing + if headers, err := group.GetStringSlices("header"); err == nil { + headersMap, err := createHTTPHeadersMap(headers, allowDuplicateHeaders) + if err != nil { + return nil, fmt.Errorf("invalid \"--%s.%s.header\": %w", parentName, group.Name, err) + } + opts = append(opts, checker.WithHTTPHeaders(headersMap)) + } + + if allowedStatusCodes, err := group.GetString("expected-status-codes"); err == nil { + statusCodes, err := httputils.ParseStatusCodes(allowedStatusCodes) + if err != nil { + return nil, fmt.Errorf("invalid \"--%s.%s.expected-status-codes\": %w", parentName, group.Name, err) + } + + opts = append(opts, checker.WithExpectedStatusCodes(statusCodes)) + } + + if skipTLS, err := group.GetBool("skip-tls-verify"); err == nil { + opts = append(opts, checker.WithHTTPSkipTLSVerify(skipTLS)) + } + + if timeout, err := group.GetDuration("timeout"); err == nil { + opts = append(opts, checker.WithHTTPTimeout(timeout)) + } + + case checker.TCP: + if timeout, err := group.GetDuration("timeout"); err == nil { + opts = append(opts, checker.WithHTTPTimeout(timeout)) // Could have a TCP-specific timeout option + } + + case checker.ICMP: + if readTimeout, err := group.GetDuration("read-timeout"); err == nil { + opts = append(opts, checker.WithICMPReadTimeout(readTimeout)) + } + if writeTimeout, err := group.GetDuration("write-timeout"); err == nil { + opts = append(opts, checker.WithICMPWriteTimeout(writeTimeout)) + } + } + + name, _ := group.GetString("name") + if name == "" { + name = group.Name + } + + instance, err := checker.NewChecker(checkType, name, resolvedAddress, opts...) + if err != nil { + return nil, fmt.Errorf("failed to create %s checker: %w", parentName, err) + } + + // Wrap the checker with its interval and add to the list + checkers = append(checkers, CheckerWithInterval{ + Interval: interval, + Checker: instance, + }) + } + } + + return checkers, nil +} + +// createHTTPHeadersMap creates a map or slice-based map of HTTP headers from a slice of strings. +// If allowDuplicateHeaders is true, headers with the same key will be overwritten. +func createHTTPHeadersMap(headers []string, allowDuplicateHeaders bool) (map[string]string, error) { + if headers == nil { + return nil, fmt.Errorf("headers cannot be nil") + } + + headersMap := make(map[string]string) + + for _, header := range headers { + parts := strings.SplitN(header, "=", 2) + + if len(parts) != 2 || parts[0] == "" { + return nil, fmt.Errorf("invalid header format: %q", header) + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + resolved, err := resolver.ResolveVariable(value) + if err != nil { + return nil, fmt.Errorf("failed to resolve variable in header: %w", err) + } + + if _, exists := headersMap[key]; exists && !allowDuplicateHeaders { + return nil, fmt.Errorf("duplicate header: %q", header) + } + + headersMap[key] = resolved + } + + return headersMap, nil +} diff --git a/internal/factory/factory_test.go b/internal/factory/factory_test.go new file mode 100644 index 0000000..39d681b --- /dev/null +++ b/internal/factory/factory_test.go @@ -0,0 +1,220 @@ +package factory_test + +import ( + "testing" + "time" + + "github.com/containeroo/portpatrol/internal/factory" + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestBuildCheckers(t *testing.T) { + t.Parallel() + + t.Run("Valid HTTP Checker", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + httpGroup := df.Group("http") + httpGroup.String("address", "http://example.com", "HTTP target address") + httpGroup.String("method", "GET", "HTTP method") + httpGroup.Duration("interval", 5*time.Second, "Request interval") + httpGroup.StringSlices("header", nil, "HTTP header") + httpGroup.Bool("skip-tls-verify", false, "Skip TLS verification") + httpGroup.Duration("timeout", 2*time.Second, "Timeout") + + args := []string{ + "--http.mygroup.address=http://example.com", + "--http.mygroup.method=GET", + "--http.mygroup.interval=5s", + "--http.mygroup.header=Content-Type=application/json", + "--http.mygroup.skip-tls-verify=true", + "--http.mygroup.timeout=2s", + } + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.NoError(t, err) + assert.Len(t, checkers, 1) + assert.Equal(t, "http://example.com", checkers[0].Checker.GetAddress()) + assert.Equal(t, 5*time.Second, checkers[0].Interval) + }) + + t.Run("Missing Address", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + httpGroup := df.Group("http") + httpGroup.String("method", "GET", "HTTP method") + + args := []string{"--http.mygroup.method=GET"} + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.Nil(t, checkers) + assert.ErrorContains(t, err, "missing address for http checker") + }) + + t.Run("Invalid Check Type", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + invalidGroup := df.Group("invalid") + invalidGroup.String("address", "invalid-address", "Invalid target address") + + args := []string{"--invalid.mygroup.address=invalid-address"} + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.Nil(t, checkers) + assert.ErrorContains(t, err, "invalid check type 'invalid'") + }) + + t.Run("Invalid Header Parsing", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + httpGroup := df.Group("http") + httpGroup.String("address", "http://example.com", "HTTP target address") + httpGroup.StringSlices("header", []string{}, "HTTP headers") + + args := []string{ + "--http.mygroup.address=http://example.com", + "--http.mygroup.header=InvalidHeaderFormat", + } + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + + assert.Error(t, err) + assert.EqualError(t, err, "invalid \"--http.mygroup.header\": invalid header format: \"InvalidHeaderFormat\"") + assert.Nil(t, checkers) + assert.ErrorContains(t, err, "invalid \"--http.mygroup.header\"") + }) + + t.Run("Inalid HTTP Status codes", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + httpGroup := df.Group("http") + httpGroup.String("address", "http://example.com", "HTTP target address") + httpGroup.String("expected-status-codes", "400,401", "HTTP expected status codes") + + args := []string{ + "--http.mygroup.address=http://example.com", + "--http.mygroup.expected-status-codes=201-200", + } + err := df.Parse(args) + assert.NoError(t, err) + + res := httpGroup.Lookup("expected-status-codes").GetValue() + assert.Equal(t, "201-200", res) + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.Error(t, err) + assert.Len(t, checkers, 0) + }) + + t.Run("Valid HTTP Status codes", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + httpGroup := df.Group("http") + httpGroup.String("address", "http://example.com", "HTTP target address") + httpGroup.String("expected-status-codes", "200,201", "HTTP expected status codes") + + args := []string{ + "--http.mygroup.address=http://example.com", + "--http.mygroup.expected-status-codes=200,201", + } + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.NoError(t, err) + assert.Len(t, checkers, 1) + }) + + t.Run("Valid TCP Checker", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + tcpGroup := df.Group("tcp") + tcpGroup.String("address", "127.0.0.1:8080", "TCP target address") + tcpGroup.Duration("timeout", 3*time.Second, "Timeout") + + args := []string{ + "--tcp.mygroup.address=127.0.0.1:8080", + "--tcp.mygroup.timeout=3s", + } + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.NoError(t, err) + assert.Len(t, checkers, 1) + assert.Equal(t, "127.0.0.1:8080", checkers[0].Checker.GetAddress()) + }) + + t.Run("Valid ICMP Checker", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + icmpGroup := df.Group("icmp") + icmpGroup.String("address", "8.8.8.8", "ICMP target address") + icmpGroup.Duration("read-timeout", 2*time.Second, "Read timeout") + icmpGroup.Duration("write-timeout", 2*time.Second, "Write timeout") + + args := []string{ + "--icmp.mygroup.address=8.8.8.8", + "--icmp.mygroup.read-timeout=2s", + "--icmp.mygroup.write-timeout=2s", + } + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.NoError(t, err) + assert.Len(t, checkers, 1) + assert.Equal(t, "8.8.8.8", checkers[0].Checker.GetAddress()) + }) + + t.Run("Invalid ICMP Checker", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + icmpGroup := df.Group("icmp") + icmpGroup.String("address", "8.8.8.8", "ICMP target address") + + args := []string{ + "--icmp.mygroup.address=://invalid-url", + } + + err := df.Parse(args) + assert.NoError(t, err) + + checker, err := factory.BuildCheckers(df, 2*time.Second) + assert.Nil(t, checker) + assert.Error(t, err) + }) + + t.Run("Checker Creation Failure", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + httpGroup := df.Group("http") + httpGroup.String("address", "", "HTTP target address") + + args := []string{"--http.mygroup.address="} + err := df.Parse(args) + assert.NoError(t, err) + + checkers, err := factory.BuildCheckers(df, 2*time.Second) + assert.NotNil(t, checkers) + assert.NoError(t, err) + }) +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go deleted file mode 100644 index 4418d03..0000000 --- a/internal/logger/logger.go +++ /dev/null @@ -1,36 +0,0 @@ -package logger - -import ( - "io" - "log/slog" - - "github.com/containeroo/portpatrol/internal/config" -) - -// SetupLogger configures the logger based on the configuration. -func SetupLogger(cfg config.Config, output io.Writer) *slog.Logger { - handlerOpts := &slog.HandlerOptions{} - - if cfg.LogExtraFields { - // Return a logger with the additional fields - return slog.New(slog.NewTextHandler(output, handlerOpts)).With( - slog.String("target_address", cfg.TargetAddress), - slog.String("interval", cfg.CheckInterval.String()), - slog.String("dial_timeout", cfg.DialTimeout.String()), - slog.String("checker_type", cfg.TargetCheckType.String()), - slog.String("version", cfg.Version), - ) - } - - // If logExtraFields is false, remove the error attribute from the handler. - // The error attribute is unwanted when no additional fields is set to true. - handlerOpts.ReplaceAttr = func(groups []string, a slog.Attr) slog.Attr { - if a.Key == "error" { - return slog.Attr{} - } - return a - } - - // Return a logger without the additional fields and with a function to remove the error attribute - return slog.New(slog.NewTextHandler(output, handlerOpts)) -} diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go deleted file mode 100644 index afa96d4..0000000 --- a/internal/logger/logger_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package logger - -import ( - "bytes" - "log/slog" - "strings" - "testing" - "time" - - "github.com/containeroo/portpatrol/internal/checker" - "github.com/containeroo/portpatrol/internal/config" -) - -func TestSetupLogger(t *testing.T) { - t.Parallel() - - t.Run("Log with additional fields", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - Version: "0.0.1", - TargetAddress: "localhost:8080", - CheckInterval: 1 * time.Second, - DialTimeout: 2 * time.Second, - TargetCheckType: checker.HTTP, - LogExtraFields: true, - } - var buf bytes.Buffer - - logger := SetupLogger(cfg, &buf) - logger.Info("Test log") - - logOutput := buf.String() - - expected := "target_address=localhost:8080" - if !strings.Contains(logOutput, expected) { - t.Errorf("Expected log output to contain %q, got %q", expected, logOutput) - } - - expected = "interval=1s" - if !strings.Contains(logOutput, expected) { - t.Errorf("Expected log output to contain %q, got %q", expected, logOutput) - } - - expected = "dial_timeout=2s" - if !strings.Contains(logOutput, expected) { - t.Errorf("Expected log output to contain %q, got %q", expected, logOutput) - } - - expected = "checker_type=HTTP" - if !strings.Contains(logOutput, expected) { - t.Errorf("Expected log output to contain %q, got %q", expected, logOutput) - } - - expected = "version=0.0.1" - if !strings.Contains(logOutput, expected) { - t.Errorf("Expected log output to contain %q, got %q", expected, logOutput) - } - }) - - t.Run("Log without additional fields", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - LogExtraFields: false, - } - var buf bytes.Buffer - - logger := SetupLogger(cfg, &buf) - logger.Error("Test error", slog.String("error", "some error")) - - logOutput := buf.String() - - expected := "error=some error" - if strings.Contains(logOutput, expected) { - t.Errorf("Expected error to contain %q, got %q", expected, logOutput) - } - }) -} diff --git a/internal/logging/logger.go b/internal/logging/logger.go new file mode 100644 index 0000000..ba0c292 --- /dev/null +++ b/internal/logging/logger.go @@ -0,0 +1,12 @@ +package logging + +import ( + "io" + "log/slog" +) + +// SetupLogger configures the application logger. +func SetupLogger(version string, output io.Writer) *slog.Logger { + logger := slog.New(slog.NewTextHandler(output, &slog.HandlerOptions{})) + return logger.With(slog.String("version", version)) +} diff --git a/internal/logging/logger_test.go b/internal/logging/logger_test.go new file mode 100644 index 0000000..785767f --- /dev/null +++ b/internal/logging/logger_test.go @@ -0,0 +1,54 @@ +package logging + +import ( + "strings" + "testing" +) + +func TestSetupLogger(t *testing.T) { + t.Parallel() + + // Test that the logger includes the version and outputs correctly. + t.Run("Logger includes version and writes to output", func(t *testing.T) { + t.Parallel() + + var output strings.Builder + version := "1.0.0" + + logger := SetupLogger(version, &output) + if logger == nil { + t.Fatalf("Expected a logger instance, got nil") + } + + logger.Info("Test log message") + + logOutput := output.String() + if !strings.Contains(logOutput, "Test log message") { + t.Errorf("Expected log output to contain 'Test log message', got %q", logOutput) + } + + if !strings.Contains(logOutput, "version=1.0.0") { + t.Errorf("Expected log output to contain 'version=1.0.0', got %q", logOutput) + } + }) + + // Test that the logger writes output to the correct writer. + t.Run("Logger writes to specified output", func(t *testing.T) { + t.Parallel() + + var output strings.Builder + version := "2.0.0" + + logger := SetupLogger(version, &output) + if logger == nil { + t.Fatalf("Expected a logger instance, got nil") + } + + logger.Warn("This is a warning") + + logOutput := output.String() + if !strings.Contains(logOutput, "This is a warning") { + t.Errorf("Expected log output to contain 'This is a warning', got %q", logOutput) + } + }) +} diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go deleted file mode 100644 index b46d340..0000000 --- a/internal/runner/runner_test.go +++ /dev/null @@ -1,755 +0,0 @@ -package runner - -import ( - "context" - "fmt" - "log/slog" - "net" - "net/http" - "net/url" - "strings" - "sync" - "testing" - "time" - - "github.com/containeroo/portpatrol/internal/checker" - "github.com/containeroo/portpatrol/internal/config" - "github.com/containeroo/portpatrol/internal/logger" - "github.com/containeroo/portpatrol/internal/testutils" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" -) - -func TestLoopUntilReadyHTTP(t *testing.T) { - t.Parallel() - - t.Run("HTTP target is ready", func(t *testing.T) { - t.Parallel() - - server := &http.Server{Addr: ":9082"} - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - go func() { - // Run the server in a goroutine so that it does not block the test - _ = server.ListenAndServe() - }() - - defer server.Close() - - cfg := config.Config{ - TargetName: "HTTPServer", - TargetAddress: "http://localhost:9082/", - CheckInterval: 50 * time.Millisecond, - DialTimeout: 50 * time.Millisecond, - } - - mockEnv := func(key string) string { - env := map[string]string{ - "METHOD": "GET", - "EXPECTED_STATUSES": "200", - } - return env[key] - } - - checker, err := checker.NewHTTPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout, mockEnv) - if err != nil { - t.Fatalf("Failed to create HTTPChecker: %q", err) - } - - var stdOut strings.Builder - logger := slog.New(slog.NewTextHandler(&stdOut, nil)) - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - defer cancel() - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := "HTTPServer is ready ✓" - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - }) - - t.Run("HTTP Target with path is ready", func(t *testing.T) { - t.Parallel() - - server := &http.Server{Addr: ":9081"} - http.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - go func() { - // Run the server in a goroutine so that it does not block the test - _ = server.ListenAndServe() - }() - defer server.Close() - - cfg := config.Config{ - TargetName: "HTTPServer", - TargetAddress: "http://localhost:9081/ping", - CheckInterval: 50 * time.Millisecond, - DialTimeout: 50 * time.Millisecond, - } - - mockEnv := func(key string) string { - env := map[string]string{ - "METHOD": "GET", - "EXPECTED_STATUSES": "200", - } - return env[key] - } - - checker, err := checker.NewHTTPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout, mockEnv) - if err != nil { - t.Fatalf("Failed to create HTTPChecker: %q", err) - } - - var stdOut strings.Builder - logger := slog.New(slog.NewTextHandler(&stdOut, nil)) - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - defer cancel() - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := "HTTPServer is ready ✓" - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - }) - - t.Run("Successful HTTP target run after 3 attempts", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - TargetName: "HTTPServer", - TargetAddress: "http://localhost:6081/success", - CheckInterval: 500 * time.Millisecond, - DialTimeout: 500 * time.Millisecond, - TargetCheckType: checker.HTTP, - LogExtraFields: true, - Version: "1.0.0", - } - - parsedURL, err := url.Parse(cfg.TargetAddress) - if err != nil { - t.Fatalf("Failed to parse URL: %q", err) - } - - host := parsedURL.Host - - _, addressPort, err := net.SplitHostPort(host) - if err != nil { - t.Fatalf("Failed to split host and port: %q", err) - } - - var wg sync.WaitGroup - wg.Add(1) - - server := &http.Server{Addr: fmt.Sprintf(":%s", addressPort)} - http.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - go func() { - // Run the server in a goroutine so that it does not block the test - // Wait 3 times the interval before starting the server - defer wg.Done() // Mark the WaitGroup as done when the goroutine completes - time.Sleep(cfg.CheckInterval * 3) - err := server.ListenAndServe() - - if err != nil && err != http.ErrServerClosed { // After Server.Shutdown the returned error is ErrServerClosed. - panic("failed to listen: " + err.Error()) - } - time.Sleep(200 * time.Millisecond) // Ensure runloop get a successful attempt - }() - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - defer cancel() - - go func() { - // Wait for the context to be canceled - <-ctx.Done() - _ = server.Shutdown(context.Background()) // Gracefully shutdown the server - }() - - mockEnv := func(key string) string { - env := map[string]string{ - "METHOD": "GET", - "EXPECTED_STATUSES": "200", - } - return env[key] - } - - checker, err := checker.NewHTTPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout, mockEnv) - if err != nil { - t.Fatalf("Failed to create HTTPChecker: %q", err) - } - - var stdOut strings.Builder - logger := logger.SetupLogger(cfg, &stdOut) - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - wg.Wait() // Ensure server is closed after the test - - stdOutEntries := strings.Split(strings.TrimSpace(stdOut.String()), "\n") - // output must be: - // 0: Waiting for HTTPServer to become ready... - // 1: HTTPServer is not ready ✗ - // 2: HTTPServer is not ready ✗ - // 3: HTTPServer is not ready ✗ - // 4: HTTPServer is ready ✓ - lenExpectedOuts := 5 - - if len(stdOutEntries) != lenExpectedOuts { - t.Errorf("Expected output to contain '%d' lines but got '%d'.", lenExpectedOuts, len(stdOutEntries)) - } - - // First log entry: "Waiting for HTTPServer to become ready..." - expected := fmt.Sprintf("Waiting for %s to become ready...", cfg.TargetName) - if !strings.Contains(stdOutEntries[0], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[0]) - } - - from := 1 - to := 3 - for i := from; i < to; i++ { - expected := fmt.Sprintf("%s is not ready ✗", cfg.TargetName) - if !strings.Contains(stdOutEntries[i], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[i]) - } - - expected = fmt.Sprintf("error=\"Get \\\"%s\\\": dial tcp [::1]:%s: connect: connection refused\"", cfg.TargetAddress, addressPort) - if !strings.Contains(stdOutEntries[i], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[i]) - } - } - - // Last log entry: "HTTPServer is ready ✓" - expected = fmt.Sprintf("%s is ready ✓", cfg.TargetName) - if !strings.Contains(stdOutEntries[lenExpectedOuts-1], expected) { // lenExpectedOuts -1 = last element - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[1]) - } - - // Check version in the last entry - expected = fmt.Sprintf("version=%s", cfg.Version) - if !strings.Contains(stdOutEntries[lenExpectedOuts-1], expected) { // lenExpectedOuts -1 = last element - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[1]) - } - }) - - t.Run("Successful HTTP target run after 3 wrong responses", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - TargetName: "HTTPServer", - TargetAddress: "http://localhost:2081/wrong", - CheckInterval: 500 * time.Millisecond, - DialTimeout: 500 * time.Millisecond, - TargetCheckType: checker.HTTP, - LogExtraFields: true, - Version: "1.0.0", - } - - parsedURL, err := url.Parse(cfg.TargetAddress) - if err != nil { - t.Fatalf("Failed to parse URL: %q", err) - } - - host := parsedURL.Host - - _, addressPort, err := net.SplitHostPort(host) - if err != nil { - t.Fatalf("Failed to split host and port: %q", err) - } - - counter := 0 - - server := &http.Server{Addr: fmt.Sprintf(":%s", addressPort)} - http.HandleFunc("/wrong", func(w http.ResponseWriter, r *http.Request) { - if counter < 3 { - counter++ - w.WriteHeader(http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) - }) - - go func() { - // Run the server in a goroutine so that it does not block the test - _ = server.ListenAndServe() - }() - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - defer cancel() - - mockEnv := func(key string) string { - env := map[string]string{ - "METHOD": "GET", - "EXPECTED_STATUSES": "200", - } - return env[key] - } - - checker, err := checker.NewHTTPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout, mockEnv) - if err != nil { - t.Fatalf("Failed to create HTTPChecker: %q", err) - } - - var stdOut strings.Builder - logger := logger.SetupLogger(cfg, &stdOut) - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - stdOutEntries := strings.Split(strings.TrimSpace(stdOut.String()), "\n") - // output must be: - // 0: Waiting for HTTPServer to become ready... - // 1: HTTPServer is not ready ✗ - // 2: HTTPServer is not ready ✗ - // 3: HTTPServer is not ready ✗ - // 4: HTTPServer is ready ✓ - lenExpectedOuts := 5 - - if len(stdOutEntries) != lenExpectedOuts { - t.Errorf("Expected output to contain '%d' lines but got '%d'.", lenExpectedOuts, len(stdOutEntries)) - } - - // First log entry: "Waiting for HTTPServer to become ready..." - expected := fmt.Sprintf("Waiting for %s to become ready...", cfg.TargetName) - if !strings.Contains(stdOutEntries[0], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[0]) - } - - from := 1 - to := 3 - for i := from; i < to; i++ { - expected := fmt.Sprintf("%s is not ready ✗", cfg.TargetName) - if !strings.Contains(stdOutEntries[i], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[i]) - } - - expected = "error=\"unexpected status code: got 500, expected one of [200]\"" - if !strings.Contains(stdOutEntries[i], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[i]) - } - } - - // Last log entry: "HTTPServer is ready ✓" - expected = fmt.Sprintf("%s is ready ✓", cfg.TargetName) - if !strings.Contains(stdOutEntries[lenExpectedOuts-1], expected) { // lenExpectedOuts -1 = last element - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[1]) - } - - // Check version in the last entry - expected = fmt.Sprintf("version=%s", cfg.Version) - if !strings.Contains(stdOutEntries[lenExpectedOuts-1], expected) { // lenExpectedOuts -1 = last element - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[1]) - } - }) - - t.Run("HTTP target context cancled", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - TargetName: "HTTPServer", - TargetAddress: "http://localhost:7083/fail", - CheckInterval: 50 * time.Millisecond, - DialTimeout: 50 * time.Millisecond, - TargetCheckType: checker.HTTP, - } - - mockEnv := func(key string) string { - env := map[string]string{ - "METHOD": "GET", - "EXPECTED_STATUSES": "200", - } - return env[key] - } - - checker, err := checker.NewHTTPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout, mockEnv) - if err != nil { - t.Fatalf("Failed to create HTTPChecker: %q", err) - } - - var stdOut strings.Builder - logger := slog.New(slog.NewTextHandler(&stdOut, nil)) - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - - go func() { - // Wait for the context to be canceled - time.Sleep(100 * time.Millisecond) - cancel() - }() - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil && err != context.Canceled { - t.Errorf("Expected context canceled error, got %q", err) - } - - expected := fmt.Sprintf("Waiting for %s to become ready...", cfg.TargetName) - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - - expected = fmt.Sprintf("%s is not ready ✗", cfg.TargetName) - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - }) -} - -func TestLoopUntilReadyTCP(t *testing.T) { - t.Parallel() - - t.Run("TCP Target is ready", func(t *testing.T) { - t.Parallel() - - listener, err := net.Listen("tcp", "localhost:5082") - if err != nil { - t.Fatalf("Failed to start TCP server: %q", err) - } - defer listener.Close() - - cfg := config.Config{ - TargetName: "TCPServer", - TargetAddress: listener.Addr().String(), - CheckInterval: 50 * time.Millisecond, - DialTimeout: 50 * time.Millisecond, - TargetCheckType: checker.TCP, - } - - checker, err := checker.NewTCPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout) - if err != nil { - t.Fatalf("Failed to create TCPChecker: %q", err) - } - - var stdOut strings.Builder - logger := slog.New(slog.NewTextHandler(&stdOut, nil)) - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - defer cancel() - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := "TCPServer is ready ✓" - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - }) - - t.Run("TCP Target is ready without type", func(t *testing.T) { - t.Parallel() - - listener, err := net.Listen("tcp", "localhost:7082") - if err != nil { - t.Fatalf("Failed to start TCP server: %q", err) - } - defer listener.Close() - - cfg := config.Config{ - TargetName: "TCPServer", - TargetAddress: fmt.Sprintf("tcp://%s", listener.Addr().String()), - CheckInterval: 50 * time.Millisecond, - DialTimeout: 50 * time.Millisecond, - } - - checker, err := checker.NewTCPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout) - if err != nil { - t.Fatalf("Failed to create TCPChecker: %q", err) - } - - var stdOut strings.Builder - logger := slog.New(slog.NewTextHandler(&stdOut, nil)) - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - defer cancel() - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := "TCPServer is ready ✓" - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - }) - - t.Run("Successful TCP target run after 3 attempts", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - TargetName: "TCPServer", - TargetAddress: "localhost:5081", - CheckInterval: 500 * time.Millisecond, - DialTimeout: 500 * time.Millisecond, - TargetCheckType: checker.TCP, - LogExtraFields: true, - Version: "1.0.0", - } - - addressPort := strings.Split(cfg.TargetAddress, ":")[1] - - var wg sync.WaitGroup - wg.Add(1) - - var lis net.Listener - - go func() { - // Run the server in a goroutine so that it does not block the test - // Wait 3 times the interval before starting the server - defer wg.Done() // Mark the WaitGroup as done when the goroutine completes - time.Sleep(cfg.CheckInterval * 3) - var err error - lis, err = net.Listen("tcp", cfg.TargetAddress) - if err != nil { - panic("failed to listen: " + err.Error()) - } - time.Sleep(200 * time.Millisecond) // Ensure runloop get a successful attempt - }() - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - defer cancel() - - checker, err := checker.NewTCPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout) - if err != nil { - t.Fatalf("Failed to create HTTPChecker: %q", err) - } - - var stdOut strings.Builder - logger := logger.SetupLogger(cfg, &stdOut) - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - wg.Wait() - defer lis.Close() // listener must be closed after waiting group is done - - stdOutEntries := strings.Split(strings.TrimSpace(stdOut.String()), "\n") - // output must be: - // 0: Waiting for TCPServer to become ready... - // 1: TCPServer is not ready ✗ - // 2: TCPServer is not ready ✗ - // 3: TCPServer is not ready ✗ - // 4: TCPServer is ready ✓ - lenExpectedOuts := 5 - - if len(stdOutEntries) != lenExpectedOuts { - t.Errorf("Expected output to contain '%d' lines but got '%d'.", lenExpectedOuts, len(stdOutEntries)) - } - - // First log entry: "Waiting for HTTPServer to become ready..." - expected := fmt.Sprintf("Waiting for %s to become ready...", cfg.TargetName) - if !strings.Contains(stdOutEntries[0], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[0]) - } - - from := 1 - to := 3 - for i := from; i < to; i++ { - expected := fmt.Sprintf("%s is not ready ✗", cfg.TargetName) - if !strings.Contains(stdOutEntries[i], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[i]) - } - - expected = fmt.Sprintf("error=\"dial tcp [::1]:%s: connect: connection refused\"", addressPort) - if !strings.Contains(stdOutEntries[i], expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[i]) - } - } - - // Last log entry: "HTTPServer is ready ✓" - expected = fmt.Sprintf("%s is ready ✓", cfg.TargetName) - if !strings.Contains(stdOutEntries[lenExpectedOuts-1], expected) { // lenExpectedOuts -1 = last element - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[1]) - } - - // Check version in the last entry - expected = fmt.Sprintf("version=%s", cfg.Version) - if !strings.Contains(stdOutEntries[lenExpectedOuts-1], expected) { // lenExpectedOuts -1 = last element - t.Errorf("Expected output to contain %q but got %q", expected, stdOutEntries[1]) - } - }) - - t.Run("TCP target context cancled", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - TargetName: "TCPServer", - TargetAddress: "localhost:7084", - CheckInterval: 50 * time.Millisecond, - DialTimeout: 50 * time.Millisecond, - } - - checker, err := checker.NewTCPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout) - if err != nil { - t.Fatalf("Failed to create TCPChecker: %q", err) - } - - var stdOut strings.Builder - logger := slog.New(slog.NewTextHandler(&stdOut, nil)) - - ctx, cancel := context.WithTimeout(context.Background(), cfg.CheckInterval*4) - - go func() { - time.Sleep(100 * time.Millisecond) - cancel() - }() - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != nil && err != context.Canceled { - t.Errorf("Expected context canceled error, got %q", err) - } - - expected := fmt.Sprintf("Waiting for %s to become ready...", cfg.TargetName) - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - - expected = fmt.Sprintf("%s is not ready ✗", cfg.TargetName) - if !strings.Contains(stdOut.String(), expected) { - t.Errorf("Expected output to contain %q but got %q", expected, stdOut.String()) - } - }) - - t.Run("TCP target context deadline exceeded", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - TargetName: "TCPServer", - TargetAddress: "localhost:7084", - CheckInterval: 50 * time.Millisecond, - DialTimeout: 50 * time.Millisecond, - TargetCheckType: checker.TCP, - } - - checker, err := checker.NewTCPChecker(cfg.TargetName, cfg.TargetAddress, cfg.DialTimeout) - if err != nil { - t.Fatalf("Failed to create TCPChecker: %q", err) - } - - var stdOut strings.Builder - logger := slog.New(slog.NewTextHandler(&stdOut, nil)) - - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) - defer cancel() // Ensure cancel is called to free resources - - err = LoopUntilReady(ctx, cfg.CheckInterval, checker, logger) - if err != context.DeadlineExceeded { - t.Errorf("Expected context canceled error, got %q", err) - } - }) -} - -func TestICMPChecker_Check_SuccessfulICMPCheck(t *testing.T) { - t.Run("Successful ICMP Check", func(t *testing.T) { - var generatedIdentifier uint16 - var generatedSequence uint16 - - mockPacketConn := &testutils.MockPacketConn{ - WriteToFunc: func(b []byte, addr net.Addr) (int, error) { - // Capture the identifier and sequence number generated by the ICMPChecker - msg, err := icmp.ParseMessage(1, b) - if err != nil { - return 0, err - } - echo, ok := msg.Body.(*icmp.Echo) - if !ok { - return 0, fmt.Errorf("invalid ICMP message body") - } - generatedIdentifier = uint16(echo.ID) - generatedSequence = uint16(echo.Seq) - return len(b), nil - }, - ReadFromFunc: func(b []byte) (int, net.Addr, error) { - // Create a response with the captured identifier and sequence number - msg := icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: int(generatedIdentifier), - Seq: int(generatedSequence), - Data: []byte("HELLO-R-U-THERE"), - }, - } - msgBytes, _ := msg.Marshal(nil) - copy(b, msgBytes) - return len(msgBytes), &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}, nil - }, - SetReadDeadlineFunc: func(t time.Time) error { - return nil - }, - CloseFunc: func() error { - return nil - }, - } - - mockProtocol := &testutils.MockProtocol{ - MakeRequestFunc: func(identifier, sequence uint16) ([]byte, error) { - body := &icmp.Echo{ - ID: int(identifier), - Seq: int(sequence), - Data: []byte("HELLO-R-U-THERE"), - } - msg := icmp.Message{ - Type: ipv4.ICMPTypeEcho, - Code: 0, - Body: body, - } - return msg.Marshal(nil) - }, - ValidateReplyFunc: func(reply []byte, identifier, sequence uint16) error { - parsedMsg, err := icmp.ParseMessage(1, reply) - if err != nil { - return err - } - body, ok := parsedMsg.Body.(*icmp.Echo) - if !ok || body.ID != int(identifier) || body.Seq != int(sequence) { - return fmt.Errorf("identifier or sequence mismatch") - } - return nil - }, - NetworkFunc: func() string { - return "ip4:icmp" - }, - ListenPacketFunc: func(ctx context.Context, network, address string) (net.PacketConn, error) { - return mockPacketConn, nil - }, - } - - checker := &checker.ICMPChecker{ - Name: "TestChecker", - Address: "127.0.0.1", - Protocol: mockProtocol, - ReadTimeout: 2 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - err := checker.Check(ctx) - if err != nil { - t.Errorf("expected no error, got %v", err) - } - }) -} diff --git a/internal/runner/runner.go b/internal/wait/wait.go similarity index 55% rename from internal/runner/runner.go rename to internal/wait/wait.go index e193485..2b09c21 100644 --- a/internal/runner/runner.go +++ b/internal/wait/wait.go @@ -1,4 +1,4 @@ -package runner +package wait import ( "context" @@ -9,18 +9,25 @@ import ( "github.com/containeroo/portpatrol/internal/checker" ) -// LoopUntilReady continuously attempts to connect to the specified target until it becomes available or the context is canceled. -func LoopUntilReady(ctx context.Context, interval time.Duration, checker checker.Checker, logger *slog.Logger) error { - logger.Info(fmt.Sprintf("Waiting for %s to become ready...", checker)) +// WaitUntilReady continuously attempts to connect to the specified target until it becomes available or the context is canceled. +func WaitUntilReady(ctx context.Context, interval time.Duration, checker checker.Checker, logger *slog.Logger) error { + logger = logger.With( + slog.String("target", checker.GetName()), + slog.String("type", checker.GetType()), + slog.String("address", checker.GetAddress()), + slog.Duration("interval", interval), + ) + + logger.Info(fmt.Sprintf("Waiting for %s to become ready...", checker.GetName())) for { err := checker.Check(ctx) if err == nil { - logger.Info(fmt.Sprintf("%s is ready ✓", checker)) + logger.Info(fmt.Sprintf("%s is ready ✓", checker.GetName())) return nil // Successfully connected to the target } - logger.Warn(fmt.Sprintf("%s is not ready ✗", checker), slog.String("error", err.Error())) + logger.Warn(fmt.Sprintf("%s is not ready ✗", checker.GetName()), slog.String("error", err.Error())) select { case <-time.After(interval): diff --git a/internal/wait/wait_test.go b/internal/wait/wait_test.go new file mode 100644 index 0000000..9dc71ed --- /dev/null +++ b/internal/wait/wait_test.go @@ -0,0 +1,213 @@ +package wait + +import ( + "context" + "log/slog" + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/containeroo/portpatrol/internal/checker" +) + +// TestWaitUntilReady_ReadyHTTP ensures WaitUntilReady returns success when the HTTP target is ready. +func TestWaitUntilReady_ReadyHTTP(t *testing.T) { + t.Parallel() + + server := &http.Server{Addr: ":9082"} + http.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + go func() { _ = server.ListenAndServe() }() + defer server.Close() + + checker, err := checker.NewChecker(checker.HTTP, "HTTPServer", "http://localhost:9082/ready") + if err != nil { + t.Fatalf("Failed to create HTTPChecker: %v", err) + } + + var output strings.Builder + logger := slog.New(slog.NewTextHandler(&output, nil)) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = WaitUntilReady(ctx, 100*time.Millisecond, checker, logger) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedLog := "HTTPServer is ready ✓" + if !strings.Contains(output.String(), expectedLog) { + t.Errorf("Expected log to contain %q, got %q", expectedLog, output.String()) + } +} + +// TestWaitUntilReady_HTTPFailsInitially tests HTTP target readiness after initial failures. +func TestWaitUntilReady_HTTPFailsInitially(t *testing.T) { + t.Parallel() + + server := &http.Server{Addr: ":9083"} + http.HandleFunc("/fail", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(500 * time.Millisecond) // Simulate a delayed start + w.WriteHeader(http.StatusOK) + }) + go func() { _ = server.ListenAndServe() }() + defer server.Close() + + checker, err := checker.NewChecker(checker.HTTP, "HTTPServer", "http://localhost:9083/fail") + if err != nil { + t.Fatalf("Failed to create HTTPChecker: %v", err) + } + + var output strings.Builder + logger := slog.New(slog.NewTextHandler(&output, nil)) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = WaitUntilReady(ctx, 100*time.Millisecond, checker, logger) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedLog := "HTTPServer is ready ✓" + if !strings.Contains(output.String(), expectedLog) { + t.Errorf("Expected log to contain %q, got %q", expectedLog, output.String()) + } +} + +// TestWaitUntilReady_HTTPContextCanceled tests behavior when the context is canceled. +func TestWaitUntilReady_HTTPContextCanceled(t *testing.T) { + t.Parallel() + + server := &http.Server{Addr: ":9084"} + http.HandleFunc("/canceled", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(500 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }) + go func() { _ = server.ListenAndServe() }() + defer server.Close() + + checker, err := checker.NewChecker(checker.HTTP, "HTTPServer", "http://localhost:9084/canceled") + if err != nil { + t.Fatalf("Failed to create HTTPChecker: %v", err) + } + + var output strings.Builder + logger := slog.New(slog.NewTextHandler(&output, nil)) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = WaitUntilReady(ctx, 50*time.Millisecond, checker, logger) + if err == nil { + t.Fatalf("Expected context cancellation error, got nil") + } + + expectedLog := "Waiting for HTTPServer to become ready..." + if !strings.Contains(output.String(), expectedLog) { + t.Errorf("Expected log to contain %q, got %q", expectedLog, output.String()) + } +} + +// TestWaitUntilReady_ReadyTCP ensures WaitUntilReady succeeds for a ready TCP target. +func TestWaitUntilReady_ReadyTCP(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "localhost:9085") + if err != nil { + t.Fatalf("Failed to create TCP server: %v", err) + } + defer ln.Close() + + checker, err := checker.NewChecker(checker.TCP, "TCPServer", "localhost:9085") + if err != nil { + t.Fatalf("Failed to create TCPChecker: %v", err) + } + + var output strings.Builder + logger := slog.New(slog.NewTextHandler(&output, nil)) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = WaitUntilReady(ctx, 100*time.Millisecond, checker, logger) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedLog := "TCPServer is ready ✓" + if !strings.Contains(output.String(), expectedLog) { + t.Errorf("Expected log to contain %q, got %q", expectedLog, output.String()) + } +} + +// TestWaitUntilReady_TCPFailsInitially tests TCP readiness after initial failures. +func TestWaitUntilReady_TCPFailsInitially(t *testing.T) { + t.Parallel() + + var ln net.Listener + go func() { + time.Sleep(500 * time.Millisecond) // Simulate a delayed server start + var err error + ln, err = net.Listen("tcp", "localhost:9086") + if err != nil { + panic("Failed to start TCP server") + } + }() + defer func() { + if ln != nil { + ln.Close() + } + }() + + checker, err := checker.NewChecker(checker.TCP, "TCPServer", "localhost:9086") + if err != nil { + t.Fatalf("Failed to create TCPChecker: %v", err) + } + + var output strings.Builder + logger := slog.New(slog.NewTextHandler(&output, nil)) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = WaitUntilReady(ctx, 100*time.Millisecond, checker, logger) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + expectedLog := "TCPServer is ready ✓" + if !strings.Contains(output.String(), expectedLog) { + t.Errorf("Expected log to contain %q, got %q", expectedLog, output.String()) + } +} + +// TestWaitUntilReady_TCPContextCanceled tests behavior when the TCP target's context is canceled. +func TestWaitUntilReady_TCPContextCanceled(t *testing.T) { + t.Parallel() + + checker, err := checker.NewChecker(checker.TCP, "TCPServer", "localhost:9087") + if err != nil { + t.Fatalf("Failed to create TCPChecker: %v", err) + } + + var output strings.Builder + logger := slog.New(slog.NewTextHandler(&output, nil)) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = WaitUntilReady(ctx, 50*time.Millisecond, checker, logger) + if err == nil { + t.Fatalf("Expected context cancellation error, got nil") + } + + expectedLog := "Waiting for TCPServer to become ready..." + if !strings.Contains(output.String(), expectedLog) { + t.Errorf("Expected log to contain %q, got %q", expectedLog, output.String()) + } +} diff --git a/pkg/dynflags/README.md b/pkg/dynflags/README.md new file mode 100644 index 0000000..d3841a0 --- /dev/null +++ b/pkg/dynflags/README.md @@ -0,0 +1,230 @@ +# DynFlags + +**DynFlags** is a Go package designed for dynamically managing hierarchical command-line flags. It supports parsing flags with a structure like `--group.identifier.flag=value` while allowing dynamic group and flag registration at runtime. For "POSIX/GNU-style --flags" use the library [pflag](https://github.com/spf13/pflag). + +## Features + +- Dynamically register groups and flags at runtime. +- Hierarchical structure for flags (`group.identifier.flag`). +- Supports multiple data types: `string`, `int`, `bool`, `float64`, `time.Duration`, etc. +- Handles unknown groups and flags with configurable behavior. +- Provides a customizable usage output. +- Designed with testability in mind by accepting `io.Writer` for output. + +## Installation + +Install the package using: + +```bash +go get github.com/containeroo/portpatrol/pkg/dynflags +``` + +## Usage + +`dynflags` is a Go package that provides a simple way to manage hierarchical command-line flags. +It supports parsing flags with a structure like `--group.identifier.flag=value` while allowing dynamic group and flag registration at runtime. +For POSIX/GNU-style `--flags` use the library [pflag](https://github.com/spf13/pflag). `dynflags` can be used together with `pflag`. + +```go +import "github.com/containeroo/portpatrol/pkg/dynflags" +``` + +Create a new `DynFlags` instance: + +```go +dynFlags := dynflags.New(dynflags.ContinueOnError) +``` + +Add groups to the `DynFlags` instance: + +```go +httpGroup := dynFlags.Group("http") +``` + +Add flags to the `DynFlags` instance: + +```go +httpGroup.String("method", "GET", "HTTP method to use") +httpGroup.Int("timeout", 5, "Timeout for HTTP requests") +// httpGroup.Bool, httpGroup.Float64, httpGroup.Duration, etc. +``` + +After all flags are defined, call + +```go +args := os.Args[1:] // Skip the first argument (the executable name) +dynflags.Parse(args) +``` + +to parse the command line into the defined flags. `args` are the command-line arguments to parse. +When using `pflag`, the the ParseBehavior is set to `dynflags.ContinueOnError` and parse first `dynflags` and then `pflag`. +Unparsed arguments are stored in `dynflags.UnparsedArgs()`. + +```go +args := os.Args[1:] // Skip the first argument (the executable name) + +// Separate known and unknown flags +if err := dynFlags.Parse(args); err != nil { + return err +} + +unknownArgs := dynFlags.UnparsedArgs() + +// Parse known flags +if err := flagSet.Parse(unknownArgs); err != nil { + return err +} +``` + +`dynflags` provides 3 Groups: + +- `dynflags.Config()` returns a `ConfigGroups` instance that provides direct access to the static configuration of the `DynFlags` instance. +- `dynflags.Parsed()` returns a `ParsedGroups` instance that provides direct access to the parsed configuration of the `DynFlags` instance. +- `dynflags.Unknown()` returns a `UnknownGroups` instance that provides direct access to the unknown configuration of the `DynFlags` instance. + +Each of these Groups provides a `Lookup("SEARCH")` method that can be used to retrieve a specific group or flag. + +```go +// Retrieve the "http" group +httpGroups := dynFlags.Parsed().Lookup("http") +// Retrieve "identifier1" object within "http" +httpIdentifier1 := httpGroups.Lookup("identifier1") +// Retrieve "method" object within "identifier1" +method := httpIdentifier1.Lookup("method") +// Show value of "method" within "identifier1" +value := method.Value() +fmt.Printf("Method: %s\n", value) +``` + +and each of these Groups provides a `Groups()` method that can be used to iterate over all groups. + +```go +for groupName, groups := range dynFlags.Parsed().Groups() { + fmt.Printf("Group: %s\n", groupName) + for _, group := range groups { + fmt.Printf(" Identifier: %s\n", group.Name) + for flagName, value := range group.Values { + fmt.Printf(" Flag: %s, Value: %v\n", flagName, value) + } + } +} +``` + +## Title, Description, and Epilog + +`dynflags` allows you to set a title, description, and epilog for the help message. +You can also change the default usage output by setting the `Usage` field of a group. If not set, it uses the Group name in uppercase. + +**Example:** + +```go +dynFlags := dynflags.New(dynflags.ContinueOnError) +dynFlags.Title("DynFlags Example Application") +dynFlags.Description("This application demonstrates the usage of DynFlags for managing hierarchical flags dynamically.") +dynFlags.Epilog("For more information, see https://github.com/containerish/portpatrol") + +tcpGroup := dynFlags.Group("tcp") +tcGroup.Usage("TCP flags") +tcpGroup.String("Timeout", "10s", "TCP timeout") +tcpGroup.String("address", "127.0.0.1:8080", "TCP target address") + +httpGroup := dynFlags.Group("http") +httpGroup.Usage("HTTP flags") +httpGroup.String("method", "GET", "HTTP method to use") +httpGroup.String("address", "https://example.com", "HTTP target URL") + +dynFlags.PrintDefaults() +``` + +**Output:** + +```text +DynFlags Example Application + +This application demonstrates the usage of DynFlags for managing hierarchical flags dynamically. + +TCP flags + Flag Usage + --tcp..Timeout STRING TCP timeout (default: 10s) + --tcp..address STRING TCP target address (default: 127.0.0.1:8080) + +HTTP flags + Flag Usage + --http..method STRING HTTP method to use (default: GET) + --http..address STRING HTTP target URL (default: https://example.com) + + +For more information, see https://github.com/containerish/portpatrol +``` + +## Disable sorting of flags + +`dynflags` allows you to disable sorting of groups and flags for help and usage message. Sort is disabled by default. + +**Example:** + +```go +dynFlags := dynflags.New(dynflags.ContinueOnError) +tcpGroup := dynFlags.Group("tcp") +tcpGroup.String("Timeout", "10s", "TCP timeout") +tcpGroup.String("address", "127.0.0.1:8080", "TCP target address") + +httpGroup := dynFlags.Group("http") +httpGroup.String("method", "GET", "HTTP method to use") +httpGroup.String("address", "https://example.com", "HTTP target URL") + +dynFlags.SortGroups = true +dynFlags.SortFlags = true +dynFlags.PrintDefaults() +``` + +**Output:** + +```text +HTTP + Flag Usage + --http..address STRING HTTP target URL (default: https://example.com) + --http..method STRING HTTP method to use (default: GET) + +TCP + Flag Usage + --tcp..Timeout STRING TCP timeout (default: 10s) + --tcp..address STRING TCP target address (default: 127.0.0.1:8080) +``` + +## MetaVar + +`MetaVar` is a string that is used to represent the flag in the usage message. It defaults to the flag type in uppercase. + +- String flags: `--.. STRING` +- Boolean flags: `--.. BOOL` + +Slices will have a `MetaVar` with the base type in uppdercase, followed by a small `s`. + +- String Slices: `--.. STRINGs` +- Boolean Slices: `--.. BOOLs` + +To change the `MetaVar` for a flag, set the `MetaVar` field on the flag. + +**Example:** + +```go +dynFlags := dynflags.New(dynflags.ContinueOnError) +tcpGroup := dynFlags.Group("tcp") +timeout := tcpGroup.String("Timeout", "10s", "TCP timeout") +timeout.MetaVar("CUSTOM") +dynFlags.PrintDefaults() +``` + +**Output:** + +```text +TCP + Flag Usage + --tcp..Timeout CUSTOM TCP timeout (default: 10s) +``` + +## Examples + +The `examples` directory contains a simple example that demonstrates the usage of `dynflags`, as well as an advanced example that shows how to use `dynflags` with `pflag`. + diff --git a/pkg/dynflags/bool.go b/pkg/dynflags/bool.go new file mode 100644 index 0000000..30c1159 --- /dev/null +++ b/pkg/dynflags/bool.go @@ -0,0 +1,57 @@ +package dynflags + +import ( + "fmt" + "strconv" +) + +// BoolValue implementation for boolean flags +type BoolValue struct { + Bound *bool +} + +func (b *BoolValue) GetBound() interface{} { + if b.Bound == nil { + return nil + } + return *b.Bound +} + +func (b *BoolValue) Parse(value string) (interface{}, error) { + return strconv.ParseBool(value) +} + +func (b *BoolValue) Set(value interface{}) error { + if val, ok := value.(bool); ok { + *b.Bound = val + return nil + } + return fmt.Errorf("invalid value type: expected bool") +} + +// Bool defines a bool flag with specified name, default value, and usage string. +// The return value is the address of a bool variable that stores the value of the flag. +func (g *ConfigGroup) Bool(name string, value bool, usage string) *Flag { + bound := &value + flag := &Flag{ + Type: FlagTypeBool, + Default: value, + Usage: usage, + value: &BoolValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetBool returns the bool value of a flag with the given name +func (pg *ParsedGroup) GetBool(flagName string) (bool, error) { + value, exists := pg.Values[flagName] + if !exists { + return false, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if boolVal, ok := value.(bool); ok { + return boolVal, nil + } + return false, fmt.Errorf("flag '%s' is not a bool", flagName) +} diff --git a/pkg/dynflags/bool_slice.go b/pkg/dynflags/bool_slice.go new file mode 100644 index 0000000..f0aafd1 --- /dev/null +++ b/pkg/dynflags/bool_slice.go @@ -0,0 +1,72 @@ +package dynflags + +import ( + "fmt" + "strconv" + "strings" +) + +// BoolSlicesValue implementation for bool slice flags +type BoolSlicesValue struct { + Bound *[]bool +} + +func (b *BoolSlicesValue) GetBound() interface{} { + if b.Bound == nil { + return nil + } + return *b.Bound +} + +func (b *BoolSlicesValue) Parse(value string) (interface{}, error) { + parsed, err := strconv.ParseBool(value) + if err != nil { + return nil, fmt.Errorf("invalid boolean value: %s, error: %w", value, err) + } + return parsed, nil +} + +func (b *BoolSlicesValue) Set(value interface{}) error { + if parsedBool, ok := value.(bool); ok { + *b.Bound = append(*b.Bound, parsedBool) + return nil + } + return fmt.Errorf("invalid value type: expected bool") +} + +// BoolSlices defines a bool slice flag with specified name, default value, and usage string. +// The return value is the address of a slice of bool that stores the value of the flag. +func (g *ConfigGroup) BoolSlices(name string, value []bool, usage string) *Flag { + bound := &value + defaultValue := make([]string, len(value)) + for i, v := range value { + defaultValue[i] = strconv.FormatBool(v) + } + flag := &Flag{ + Type: FlagTypeBoolSlice, + Default: strings.Join(defaultValue, ","), + Usage: usage, + value: &BoolSlicesValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + + return flag +} + +// GetBoolSlices returns the []bool value of a flag with the given name +func (pg *ParsedGroup) GetBoolSlices(flagName string) ([]bool, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if slice, ok := value.([]bool); ok { + return slice, nil + } + + if b, ok := value.(bool); ok { + return []bool{b}, nil + } + + return nil, fmt.Errorf("flag '%s' is not a []bool", flagName) +} diff --git a/pkg/dynflags/bool_slice_test.go b/pkg/dynflags/bool_slice_test.go new file mode 100644 index 0000000..757a95e --- /dev/null +++ b/pkg/dynflags/bool_slice_test.go @@ -0,0 +1,146 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestBoolSlicesValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid bool value", func(t *testing.T) { + t.Parallel() + + boolSlicesValue := dynflags.BoolSlicesValue{Bound: &[]bool{}} + parsed, err := boolSlicesValue.Parse("true") + assert.NoError(t, err) + assert.Equal(t, true, parsed) + }) + + t.Run("Parse invalid bool value", func(t *testing.T) { + t.Parallel() + + boolSlicesValue := dynflags.BoolSlicesValue{Bound: &[]bool{}} + parsed, err := boolSlicesValue.Parse("invalid") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid bool value", func(t *testing.T) { + t.Parallel() + + bound := []bool{true} + boolSlicesValue := dynflags.BoolSlicesValue{Bound: &bound} + + err := boolSlicesValue.Set(false) + assert.NoError(t, err) + assert.Equal(t, []bool{true, false}, bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := []bool{} + boolSlicesValue := dynflags.BoolSlicesValue{Bound: &bound} + + err := boolSlicesValue.Set("invalid") + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected bool") + }) +} + +func TestGroupConfigBoolSlices(t *testing.T) { + t.Parallel() + + t.Run("Define bool slices flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := []bool{true, false} + group.BoolSlices("boolSliceFlag", defaultValue, "A bool slices flag") + + assert.Contains(t, group.Flags, "boolSliceFlag") + assert.Equal(t, "A bool slices flag", group.Flags["boolSliceFlag"].Usage) + assert.Equal(t, "true,false", group.Flags["boolSliceFlag"].Default) + }) +} + +func TestGetBoolSlices(t *testing.T) { + t.Parallel() + + t.Run("Retrieve []bool value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{ + "flag1": []bool{true, false, true}, + }, + } + + result, err := parsedGroup.GetBoolSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []bool{true, false, true}, result) + }) + + t.Run("Retrieve single bool value as []bool", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{ + "flag1": true, + }, + } + + result, err := parsedGroup.GetBoolSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []bool{true}, result) + }) + + t.Run("Flag not found", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + result, err := parsedGroup.GetBoolSlices("nonExistentFlag") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'nonExistentFlag' not found in group 'testGroup'") + }) + + t.Run("Flag value is invalid type", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{ + "flag1": "invalid", + }, + } + + result, err := parsedGroup.GetBoolSlices("flag1") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'flag1' is not a []bool") + }) +} + +func TestBoolSlicesGetBound(t *testing.T) { + t.Run("BoolSlicesValue - GetBound", func(t *testing.T) { + var slices *[]bool + val := []bool{true, false, true} + slices = &val + + boolSlicesValue := dynflags.BoolSlicesValue{Bound: slices} + assert.Equal(t, val, boolSlicesValue.GetBound()) + + boolSlicesValue = dynflags.BoolSlicesValue{Bound: nil} + assert.Nil(t, boolSlicesValue.GetBound()) + }) +} diff --git a/pkg/dynflags/bool_test.go b/pkg/dynflags/bool_test.go new file mode 100644 index 0000000..5bf6d7a --- /dev/null +++ b/pkg/dynflags/bool_test.go @@ -0,0 +1,146 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestBoolValue_Parse(t *testing.T) { + t.Parallel() + + t.Run("ValidTrueValue", func(t *testing.T) { + t.Parallel() + + b := &dynflags.BoolValue{Bound: new(bool)} + value, err := b.Parse("true") + assert.NoError(t, err) + assert.Equal(t, true, value) + }) + + t.Run("ValidFalseValue", func(t *testing.T) { + t.Parallel() + + b := &dynflags.BoolValue{Bound: new(bool)} + value, err := b.Parse("false") + assert.NoError(t, err) + assert.Equal(t, false, value) + }) + + t.Run("InvalidValue", func(t *testing.T) { + t.Parallel() + + b := &dynflags.BoolValue{Bound: new(bool)} + value, err := b.Parse("invalid") + assert.Error(t, err) + assert.Equal(t, value, false) + }) +} + +func TestBoolValue_Set(t *testing.T) { + t.Parallel() + + t.Run("SetValidTrue", func(t *testing.T) { + t.Parallel() + + bound := new(bool) + b := &dynflags.BoolValue{Bound: bound} + err := b.Set(true) + assert.NoError(t, err) + assert.Equal(t, true, *bound) + }) + + t.Run("SetValidFalse", func(t *testing.T) { + t.Parallel() + + bound := new(bool) + b := &dynflags.BoolValue{Bound: bound} + err := b.Set(false) + assert.NoError(t, err) + assert.Equal(t, false, *bound) + }) + + t.Run("SetInvalidValue", func(t *testing.T) { + t.Parallel() + + bound := new(bool) + b := &dynflags.BoolValue{Bound: bound} + err := b.Set(123) // Invalid type + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected bool") + }) +} + +func TestGroupConfig_Bool(t *testing.T) { + t.Parallel() + + t.Run("DefaultBool", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{ + Flags: make(map[string]*dynflags.Flag), + } + group.Bool("testBool", true, "Test boolean flag") + flag := group.Flags["testBool"] + assert.NotNil(t, flag) + assert.Equal(t, dynflags.FlagTypeBool, flag.Type) + assert.Equal(t, true, flag.Default) + }) +} + +func TestParsedGroup_GetBool(t *testing.T) { + t.Parallel() + + t.Run("GetExistingBool", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"testBool": true}, + } + value, err := parsedGroup.GetBool("testBool") + assert.NoError(t, err) + assert.Equal(t, true, value) + }) + + t.Run("GetNonExistentBool", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + value, err := parsedGroup.GetBool("nonExistent") + assert.Error(t, err) + assert.Equal(t, false, value) + assert.EqualError(t, err, "flag 'nonExistent' not found in group 'testGroup'") + }) + + t.Run("GetInvalidBoolType", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"invalidBool": "notABool"}, + } + value, err := parsedGroup.GetBool("invalidBool") + assert.Error(t, err) + assert.Equal(t, false, value) + assert.EqualError(t, err, "flag 'invalidBool' is not a bool") + }) +} + +func TestBoolGetBound(t *testing.T) { + t.Run("BoolValue - GetBound", func(t *testing.T) { + var b *bool + val := true + b = &val + + boolValue := dynflags.BoolValue{Bound: b} + assert.Equal(t, true, boolValue.GetBound()) + + boolValue = dynflags.BoolValue{Bound: nil} + assert.Nil(t, boolValue.GetBound()) + }) +} diff --git a/pkg/dynflags/defaults.go b/pkg/dynflags/defaults.go new file mode 100644 index 0000000..51e465a --- /dev/null +++ b/pkg/dynflags/defaults.go @@ -0,0 +1,73 @@ +package dynflags + +import ( + "fmt" + "sort" + "strings" + "text/tabwriter" +) + +// PrintDefaults prints all registered flags +func (df *DynFlags) PrintDefaults() { + w := tabwriter.NewWriter(df.output, 0, 8, 2, ' ', 0) + defer w.Flush() + + // Print title if present + if df.title != "" { + fmt.Fprintln(df.output, df.title) + fmt.Fprintln(df.output) + } + + // Print description if present + if df.description != "" { + fmt.Fprintln(df.output, df.description) + fmt.Fprintln(df.output) + } + + // Sort group names + if df.SortGroups { + sort.Strings(df.groupOrder) + } + + // Iterate over groups in the order they were added + for _, groupName := range df.groupOrder { + group := df.configGroups[groupName] + + // Print group usage or fallback to uppercase group name + if group.usage != "" { + fmt.Fprintln(w, group.usage) + } else { + fmt.Fprintln(w, strings.ToUpper(groupName)) + } + + // Sort flag names + if df.SortFlags { + sort.Strings(group.flagOrder) + } + + // Print flags for the group + if len(group.flagOrder) > 0 { + fmt.Fprintln(w, " Flag\tUsage") + for _, flagName := range group.flagOrder { + flag := group.Flags[flagName] + usage := flag.Usage + if flag.Default != nil && flag.Default != "" { + usage = fmt.Sprintf("%s (default: %v)", flag.Usage, flag.Default) + } + metavar := string(flag.Type) + if flag.metaVar != "" { + metavar = flag.metaVar + } + + fmt.Fprintf(w, " --%s..%s %s\t%s\n", groupName, flagName, metavar, usage) + } + fmt.Fprintln(w, "") + } + } + + // Print epilog if present + if df.epilog != "" { + fmt.Fprintln(df.output) + fmt.Fprintln(df.output, df.epilog) + } +} diff --git a/pkg/dynflags/defaults_test.go b/pkg/dynflags/defaults_test.go new file mode 100644 index 0000000..2874806 --- /dev/null +++ b/pkg/dynflags/defaults_test.go @@ -0,0 +1,199 @@ +package dynflags_test + +import ( + "bytes" + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestPrintDefaults(t *testing.T) { + t.Parallel() + + t.Run("No groups, title, description, or epilog", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + + df.PrintDefaults() + + output := buf.String() + assert.Empty(t, output) + }) + + t.Run("Only title is present", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + df.Title("Test Title") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "Test Title") + assert.NotContains(t, output, "Usage:") + }) + + t.Run("Only description is present", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + df.Description("Test Description") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "Test Description") + assert.NotContains(t, output, "Usage:") + }) + + t.Run("Only epilog is present", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + df.Epilog("Test Epilog") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "Test Epilog") + }) + + t.Run("Title, description, and epilog are all present", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + df.Title("Test Title") + df.Description("Test Description") + df.Epilog("Test Epilog") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "Test Title") + assert.Contains(t, output, "Test Description") + assert.Contains(t, output, "Test Epilog") + }) + + t.Run("Single group with unsorted flags", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + group := df.Group("test") + group.String("flag2", "", "Second flag") + group.String("flag1", "", "First flag") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "TEST") + assert.Contains(t, output, "--test..flag2") + assert.Contains(t, output, "--test..flag1") + }) + + t.Run("Multiple groups with sorted flags", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + df.SortFlags = true + df.SortGroups = true + group1 := df.Group("test1") + group1.String("flagA", "", "Flag A") + group1.String("flagB", "", "Flag B") + + group2 := df.Group("test2") + group2.String("flagX", "", "Flag X") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "TEST1") + assert.Contains(t, output, "--test1..flagA") + assert.Contains(t, output, "--test1..flagB") + assert.Contains(t, output, "TEST2") + assert.Contains(t, output, "--test2..flagX") + }) + + t.Run("Group with usage text", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + group := df.Group("test") + group.Usage("Test Group Usage") + group.String("flag", "", "Test flag") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "Test Group Usage") + assert.Contains(t, output, "--test..flag") + }) + + t.Run("Flags with and without default values", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + group := df.Group("test") + group.String("flag1", "default1", "Flag with default") + group.String("flag2", "", "Flag without default") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "--test..flag1") + assert.Contains(t, output, "(default: default1)") + assert.Contains(t, output, "--test..flag2") + assert.NotContains(t, output, "(default: )") + }) + + t.Run("Empty group with no flags", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + df.Group("test") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "TEST") + assert.NotContains(t, output, "Flag\tUsage") + }) + + t.Run("Metavar", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + g := df.Group("test") + s := g.String("test", "", "Test flag") + s.MetaVar("CUSTOM") + + df.PrintDefaults() + + output := buf.String() + assert.Contains(t, output, "CUSTOM") + assert.NotContains(t, output, "Flag\tUsage") + }) +} diff --git a/pkg/dynflags/duration.go b/pkg/dynflags/duration.go new file mode 100644 index 0000000..35dbcec --- /dev/null +++ b/pkg/dynflags/duration.go @@ -0,0 +1,57 @@ +package dynflags + +import ( + "fmt" + "time" +) + +// DurationValue implementation for duration flags +type DurationValue struct { + Bound *time.Duration +} + +func (d *DurationValue) GetBound() interface{} { + if d.Bound == nil { + return nil + } + return *d.Bound +} + +func (d *DurationValue) Parse(value string) (interface{}, error) { + return time.ParseDuration(value) +} + +func (d *DurationValue) Set(value interface{}) error { + if dur, ok := value.(time.Duration); ok { + *d.Bound = dur + return nil + } + return fmt.Errorf("invalid value type: expected duration") +} + +// Duration defines a duration flag with specified name, default value, and usage string. +// The return value is the address of a time.Duration variable that stores the value of the flag. +func (g *ConfigGroup) Duration(name string, value time.Duration, usage string) *Flag { + bound := &value + flag := &Flag{ + Type: FlagTypeDuration, + Default: value, + Usage: usage, + value: &DurationValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetDuration returns the time.Duration value of a flag with the given name +func (pg *ParsedGroup) GetDuration(flagName string) (time.Duration, error) { + vaue, exists := pg.Values[flagName] + if !exists { + return 0, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if durationVal, ok := vaue.(time.Duration); ok { + return durationVal, nil + } + return 0, fmt.Errorf("flag '%s' is not a time.Duration", flagName) +} diff --git a/pkg/dynflags/duration_slice.go b/pkg/dynflags/duration_slice.go new file mode 100644 index 0000000..84ba82d --- /dev/null +++ b/pkg/dynflags/duration_slice.go @@ -0,0 +1,73 @@ +package dynflags + +import ( + "fmt" + "strings" + "time" +) + +// DurationSlicesValue implementation for duration slice flags +type DurationSlicesValue struct { + Bound *[]time.Duration +} + +func (d *DurationSlicesValue) GetBound() interface{} { + if d.Bound == nil { + return nil + } + return *d.Bound +} + +func (d *DurationSlicesValue) Parse(value string) (interface{}, error) { + parsed, err := time.ParseDuration(value) + if err != nil { + return nil, fmt.Errorf("invalid duration value: %s, error: %w", value, err) + } + return parsed, nil +} + +func (d *DurationSlicesValue) Set(value interface{}) error { + if parsedDuration, ok := value.(time.Duration); ok { + *d.Bound = append(*d.Bound, parsedDuration) + return nil + } + return fmt.Errorf("invalid value type: expected time.Duration") +} + +// DurationSlices defines a duration slice flag with specified name, default value, and usage string. +// The return value is the address of a slice of durations that stores the value of the flag. +func (g *ConfigGroup) DurationSlices(name string, value []time.Duration, usage string) *Flag { + bound := &value + defaultValue := make([]string, len(value)) + for i, v := range value { + defaultValue[i] = v.String() + } + + flag := &Flag{ + Type: FlagTypeDurationSlice, + Default: strings.Join(defaultValue, ","), + Usage: usage, + value: &DurationSlicesValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetDurationSlices returns the []time.Duration value of a flag with the given name +func (pg *ParsedGroup) GetDurationSlices(flagName string) ([]time.Duration, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + + if slice, ok := value.([]time.Duration); ok { + return slice, nil + } + + if d, ok := value.(time.Duration); ok { + return []time.Duration{d}, nil + } + + return nil, fmt.Errorf("flag '%s' is not a []time.Duration", flagName) +} diff --git a/pkg/dynflags/duration_slice_test.go b/pkg/dynflags/duration_slice_test.go new file mode 100644 index 0000000..3403708 --- /dev/null +++ b/pkg/dynflags/duration_slice_test.go @@ -0,0 +1,147 @@ +package dynflags_test + +import ( + "testing" + "time" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestDurationSlicesValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid duration slice value", func(t *testing.T) { + t.Parallel() + + durationSlicesValue := dynflags.DurationSlicesValue{Bound: &[]time.Duration{}} + parsed, err := durationSlicesValue.Parse("5s") + assert.NoError(t, err) + assert.Equal(t, 5*time.Second, parsed) + }) + + t.Run("Parse invalid duration value", func(t *testing.T) { + t.Parallel() + + durationSlicesValue := dynflags.DurationSlicesValue{Bound: &[]time.Duration{}} + parsed, err := durationSlicesValue.Parse("invalid") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid duration value", func(t *testing.T) { + t.Parallel() + + bound := []time.Duration{1 * time.Second} + durationSlicesValue := dynflags.DurationSlicesValue{Bound: &bound} + + err := durationSlicesValue.Set(2 * time.Second) + assert.NoError(t, err) + assert.Equal(t, []time.Duration{1 * time.Second, 2 * time.Second}, bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := []time.Duration{} + durationSlicesValue := dynflags.DurationSlicesValue{Bound: &bound} + + err := durationSlicesValue.Set("invalid") + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected time.Duration") + }) +} + +func TestGroupConfigDurationSlices(t *testing.T) { + t.Parallel() + + t.Run("Define duration slices flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := []time.Duration{1 * time.Second, 2 * time.Second} + group.DurationSlices("durationSliceFlag", defaultValue, "A duration slices flag") + + assert.Contains(t, group.Flags, "durationSliceFlag") + assert.Equal(t, "A duration slices flag", group.Flags["durationSliceFlag"].Usage) + assert.Equal(t, "1s,2s", group.Flags["durationSliceFlag"].Default) + }) +} + +func TestGetDurationSlices(t *testing.T) { + t.Parallel() + + t.Run("Retrieve []time.Duration value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{ + "flag1": []time.Duration{1 * time.Second, 2 * time.Second, 3 * time.Second}, + }, + } + + result, err := parsedGroup.GetDurationSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []time.Duration{1 * time.Second, 2 * time.Second, 3 * time.Second}, result) + }) + + t.Run("Retrieve single time.Duration value as []time.Duration", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{ + "flag1": 5 * time.Second, + }, + } + + result, err := parsedGroup.GetDurationSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []time.Duration{5 * time.Second}, result) + }) + + t.Run("Flag not found", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + result, err := parsedGroup.GetDurationSlices("nonExistentFlag") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'nonExistentFlag' not found in group 'testGroup'") + }) + + t.Run("Flag value is invalid type", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{ + "flag1": "invalid", + }, + } + + result, err := parsedGroup.GetDurationSlices("flag1") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'flag1' is not a []time.Duration") + }) +} + +func TestDurationSlicesGetBound(t *testing.T) { + t.Run("DurationSlicesValue - GetBound", func(t *testing.T) { + var slices *[]time.Duration + val := []time.Duration{1 * time.Second, 2 * time.Second} + slices = &val + + durationSlicesValue := dynflags.DurationSlicesValue{Bound: slices} + assert.Equal(t, val, durationSlicesValue.GetBound()) + + durationSlicesValue = dynflags.DurationSlicesValue{Bound: nil} + assert.Nil(t, durationSlicesValue.GetBound()) + }) +} diff --git a/pkg/dynflags/duration_test.go b/pkg/dynflags/duration_test.go new file mode 100644 index 0000000..9c92e41 --- /dev/null +++ b/pkg/dynflags/duration_test.go @@ -0,0 +1,123 @@ +package dynflags_test + +import ( + "testing" + "time" + + "github.com/containeroo/portpatrol/pkg/dynflags" + + "github.com/stretchr/testify/assert" +) + +func TestDurationValue_Parse(t *testing.T) { + t.Parallel() + + t.Run("ValidDuration", func(t *testing.T) { + t.Parallel() + + d := &dynflags.DurationValue{} + value, err := d.Parse("2h") + assert.NoError(t, err) + assert.Equal(t, 2*time.Hour, value) + }) + + t.Run("InvalidDuration", func(t *testing.T) { + t.Parallel() + + d := &dynflags.DurationValue{} + _, err := d.Parse("invalid") + assert.Error(t, err) + }) +} + +func TestDurationValue_Set(t *testing.T) { + t.Parallel() + + t.Run("SetValidDuration", func(t *testing.T) { + t.Parallel() + + var bound time.Duration + d := &dynflags.DurationValue{Bound: &bound} + err := d.Set(1 * time.Minute) + assert.NoError(t, err) + assert.Equal(t, 1*time.Minute, bound) + }) + + t.Run("SetInvalidType", func(t *testing.T) { + t.Parallel() + + var bound time.Duration + d := &dynflags.DurationValue{Bound: &bound} + err := d.Set("not a duration") + assert.Error(t, err) + assert.Equal(t, time.Duration(0), bound) + }) +} + +func TestGroupConfig_Duration(t *testing.T) { + t.Parallel() + + t.Run("DurationDefault", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := 5 * time.Second + bound := group.Duration("timeout", defaultValue, "Timeout duration") + assert.Equal(t, defaultValue, bound.Default) + assert.Contains(t, group.Flags, "timeout") + assert.Equal(t, defaultValue, group.Flags["timeout"].Default) + assert.Equal(t, dynflags.FlagTypeDuration, group.Flags["timeout"].Type) + }) +} + +func TestParsedGroup_GetDuration(t *testing.T) { + t.Parallel() + + t.Run("GetValidDuration", func(t *testing.T) { + t.Parallel() + + parsed := &dynflags.ParsedGroup{ + Name: "test", + Values: map[string]interface{}{"timeout": 30 * time.Second}, + } + dur, err := parsed.GetDuration("timeout") + assert.NoError(t, err) + assert.Equal(t, 30*time.Second, dur) + }) + + t.Run("GetDurationNotFound", func(t *testing.T) { + t.Parallel() + + parsed := &dynflags.ParsedGroup{ + Name: "test", + Values: map[string]interface{}{}, + } + _, err := parsed.GetDuration("missing") + assert.Error(t, err) + }) + + t.Run("GetDurationWrongType", func(t *testing.T) { + t.Parallel() + + parsed := &dynflags.ParsedGroup{ + Name: "test", + Values: map[string]interface{}{"timeout": "not a duration"}, + } + _, err := parsed.GetDuration("timeout") + assert.Error(t, err) + }) +} + +func TestDurationGetBound(t *testing.T) { + t.Run("DurationValue - GetBound", func(t *testing.T) { + var d *time.Duration + val := 2 * time.Second + d = &val + + durationValue := dynflags.DurationValue{Bound: d} + assert.Equal(t, val, durationValue.GetBound()) + + durationValue = dynflags.DurationValue{Bound: nil} + assert.Nil(t, durationValue.GetBound()) + }) +} diff --git a/pkg/dynflags/dynflags.go b/pkg/dynflags/dynflags.go new file mode 100644 index 0000000..ca186fb --- /dev/null +++ b/pkg/dynflags/dynflags.go @@ -0,0 +1,91 @@ +package dynflags + +import ( + "fmt" + "io" + "os" +) + +// ParseBehavior defines how the parser handles errors +type ParseBehavior int + +const ( + // Continue parsing on error + ContinueOnError ParseBehavior = iota + // Try to parse unknown flags. Unknown flags can be retrived with the method Unknown() on the DynFlags instance + ParseUnknown + // Exit on error + ExitOnError +) + +// DynFlags manages configuration and parsed values +type DynFlags struct { + configGroups map[string]*ConfigGroup // Static parent groups + groupOrder []string // Order of group names + SortGroups bool // Sort groups in help message + SortFlags bool // Sort flags in help message + parsedGroups map[string][]*ParsedGroup // Parsed child groups organized by parent group + unknownGroups map[string][]*UnknownGroup // Unknown parent groups and their parsed values + parseBehavior ParseBehavior // Parsing behavior + unparsedArgs []string // Arguments that couldn't be parsed + output io.Writer // Output for usage/help + usage func() // Customizable usage function + title string // Title in the help message + description string // Description after the title in the help message + epilog string // Epilog in the help message +} + +// New initializes a new DynFlags instance +func New(behavior ParseBehavior) *DynFlags { + df := &DynFlags{ + configGroups: make(map[string]*ConfigGroup), + parsedGroups: make(map[string][]*ParsedGroup), + unknownGroups: make(map[string][]*UnknownGroup), + parseBehavior: behavior, + output: os.Stdout, + } + df.usage = func() { df.Usage() } + return df +} + +// Title adds a title to the help message +func (df *DynFlags) Title(title string) { + df.title = title +} + +// Description adds a descripton after the Title +func (df *DynFlags) Description(description string) { + df.description = description +} + +// Epilog adds an epilog after the description of the dynamic flags to the help message +func (df *DynFlags) Epilog(epilog string) { + df.epilog = epilog +} + +// Group defines a new group or retrieves an existing one +func (df *DynFlags) Group(name string) *ConfigGroup { + if _, exists := df.configGroups[name]; exists { + return df.configGroups[name] + } + + df.groupOrder = append(df.groupOrder, name) + + group := &ConfigGroup{ + Name: name, + Flags: make(map[string]*Flag), + } + df.configGroups[name] = group + return group +} + +// DefaultUsage provides the default usage output +func (df *DynFlags) Usage() { + fmt.Fprintf(df.output, "Usage: [--.. value]\n\n") + df.PrintDefaults() +} + +// SetOutput sets the output writer +func (df *DynFlags) SetOutput(buf io.Writer) { + df.output = buf +} diff --git a/pkg/dynflags/dynflags_test.go b/pkg/dynflags/dynflags_test.go new file mode 100644 index 0000000..20d009c --- /dev/null +++ b/pkg/dynflags/dynflags_test.go @@ -0,0 +1,156 @@ +package dynflags_test + +import ( + "bytes" + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestDynFlagsInitialization(t *testing.T) { + t.Parallel() + + t.Run("New initializes correctly", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + assert.NotNil(t, df) + assert.NotNil(t, df.Config()) + assert.NotNil(t, df.Parsed()) + assert.NotNil(t, df.Unknown()) + }) +} + +func TestDynFlagsGroupManagement(t *testing.T) { + t.Parallel() + + t.Run("Create new group", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + + // Create Group + group := df.Group("group1") + assert.NotNil(t, group) + assert.Contains(t, df.Config().Groups(), "group1") + assert.Equal(t, group, df.Config().Lookup("group1")) + assert.Equal(t, "group1", group.Name) + assert.NotNil(t, group.Flags) + + // Get Group again + group = df.Group("group1") + assert.NotNil(t, group) + assert.Contains(t, df.Config().Groups(), "group1") + }) +} + +func TestDynFlagsUsageOutput(t *testing.T) { + t.Parallel() + + t.Run("Generate usage with title, description, and epilog", func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + df := dynflags.New(dynflags.ContinueOnError) + df.SetOutput(&buf) + + df.Title("Test Application") + df.Description("This application demonstrates usage of dynamic flags.") + df.Epilog("For more information, visit https://example.com.") + + df.Usage() + + output := buf.String() + assert.Contains(t, output, "Test Application") + assert.Contains(t, output, "This application demonstrates usage of dynamic flags.") + assert.Contains(t, output, "For more information, visit https://example.com.") + }) +} + +func TestDynFlagsParsedAndUnknown(t *testing.T) { + t.Parallel() + + t.Run("Empty parsed and unknown groups", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + + assert.Empty(t, df.Parsed().Groups()) + assert.Empty(t, df.Unknown().Groups()) + }) +} + +func TestParsedGroupMethods(t *testing.T) { + t.Parallel() + + t.Run("Retrieve unknown values", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ParseUnknown) + df.Group("known") + args := []string{"--unknown.identifier.value", "value1"} + err := df.Parse(args) + assert.NoError(t, err) + + unknownGroups := df.Unknown() + group := unknownGroups.Lookup("unknown") + assert.NotNil(t, group) + + identifier := group.Lookup("identifier") + assert.NotNil(t, identifier) + assert.Equal(t, "value1", identifier.Lookup("value")) + }) + + t.Run("Retrieve parsed group values", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + df.Group("testGroup").String("flag1", "defaultValue", "Test flag") + args := []string{"--testGroup.identifier1.flag1", "value1"} + err := df.Parse(args) + assert.NoError(t, err) + + parsedGroups := df.Parsed() + group := parsedGroups.Lookup("testGroup") + assert.NotNil(t, group) + + identifier := group.Lookup("identifier1") + assert.NotNil(t, identifier) + assert.Equal(t, "value1", identifier.Lookup("flag1")) + }) + + t.Run("Non-existent flag in parsed group", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + df.Group("testGroup") + args := []string{"--testGroup.identifier1.flag1", "value1"} + err := df.Parse(args) + assert.NoError(t, err) + + parsedGroups := df.Parsed() + assert.Len(t, parsedGroups.Groups(), 0) + + unknownGroup := df.Unknown() + assert.NotNil(t, unknownGroup) + }) +} + +func TestDynFlagsUnparsedArgs(t *testing.T) { + t.Parallel() + + t.Run("Retrieve unparsed arguments", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + args := []string{ + "--unparsable", "value1", + } + err := df.Parse(args) + assert.NoError(t, err) + + unparsedArgs := df.UnparsedArgs() + assert.Contains(t, unparsedArgs, "--unparsable") + }) +} diff --git a/pkg/dynflags/flag.go b/pkg/dynflags/flag.go new file mode 100644 index 0000000..a72e445 --- /dev/null +++ b/pkg/dynflags/flag.go @@ -0,0 +1,51 @@ +package dynflags + +type FlagType string + +const ( + FlagTypeStringSlice FlagType = "..STRINGs" + FlagTypeString FlagType = "STRING" + FlagTypeInt FlagType = "INT" + FlagTypeIntSlice FlagType = "..INTs" + FlagTypeBool FlagType = "BOOL" + FlagTypeBoolSlice FlagType = "..BOOLs" + FlagTypeDuration FlagType = "DURATION" + FlagTypeDurationSlice FlagType = "..DURATIONs" + FlagTypeFloat FlagType = "FLOAT" + FlagTypeFloatSlice FlagType = "..FLOATs" + FlagTypeIP FlagType = "IP" + FlagTypeIPSlice FlagType = "..IPs" + FlagTypeURL FlagType = "URL" + FlagTypeURLSlice FlagType = "..URLs" +) + +// Flag represents a single configuration flag +type Flag struct { + Default interface{} // Default value for the flag + Type FlagType // Type of the flag + Usage string // Description for usage + metaVar string // MetaVar for flag + value FlagValue // Encapsulated parsing and value-setting logic +} + +func (f *Flag) MetaVar(metaVar string) { + f.metaVar = metaVar +} + +// FlagValue interface encapsulates parsing and value-setting logic +type FlagValue interface { + // Parse parses the given string value into the flag's value type + Parse(value string) (interface{}, error) + // Set sets the flag's value to the given value + Set(value interface{}) error + // GetBound returns the bound value of the flag. + GetBound() interface{} +} + +// Value returns the current value of the flag. +func (f *Flag) GetValue() interface{} { + if f == nil || f.value == nil { + return nil + } + return f.value.GetBound() +} diff --git a/pkg/dynflags/flag_test.go b/pkg/dynflags/flag_test.go new file mode 100644 index 0000000..a4cbba5 --- /dev/null +++ b/pkg/dynflags/flag_test.go @@ -0,0 +1,30 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestFlagGetValue(t *testing.T) { + t.Parallel() + + t.Run("Nil Flag - GetValue", func(t *testing.T) { + t.Parallel() + var flag *dynflags.Flag + assert.Nil(t, flag.GetValue(), "Expected nil when flag is nil") + }) + + t.Run("String Flag - GetValue", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{ + Name: "testGroup", + Flags: make(map[string]*dynflags.Flag), + } + + flag := group.String("example", "default-value", "An example string flag") + assert.Equal(t, "default-value", flag.GetValue(), "Expected GetValue() to return the default value") + }) +} diff --git a/pkg/dynflags/flags_config.go b/pkg/dynflags/flags_config.go new file mode 100644 index 0000000..c1de7e3 --- /dev/null +++ b/pkg/dynflags/flags_config.go @@ -0,0 +1,55 @@ +package dynflags + +// ConfigGroup represents the static configuration for a group. +type ConfigGroup struct { + Name string // Name of the group. + usage string // Title for usage. If not set it takes the name of the group in Uppercase. + Flags map[string]*Flag // Flags within the group. + flagOrder []string // Order of flags. +} + +// Usage sets the usage for the group. +func (cg *ConfigGroup) Usage(usage string) { + cg.usage = usage +} + +// Lookup retrieves a flag in the group by its name. +func (gc *ConfigGroup) Lookup(flagName string) *Flag { + if gc == nil { + return nil + } + + return gc.Flags[flagName] +} + +// ConfigGroups represents all configuration groups with lookup and iteration support. +type ConfigGroups struct { + groups map[string]*ConfigGroup +} + +// Lookup retrieves a configuration group by its name. +func (cg *ConfigGroups) Lookup(groupName string) *ConfigGroup { + if cg == nil { + return nil + } + + return cg.groups[groupName] +} + +// Groups returns the underlying map for direct iteration. +func (cg *ConfigGroups) Groups() map[string]*ConfigGroup { + if cg == nil { + return nil + } + + return cg.groups +} + +// Config returns a ConfigGroups instance for the dynflags instance. +func (df *DynFlags) Config() *ConfigGroups { + if df == nil { + return nil + } + + return &ConfigGroups{groups: df.configGroups} +} diff --git a/pkg/dynflags/flags_config_test.go b/pkg/dynflags/flags_config_test.go new file mode 100644 index 0000000..f1858b8 --- /dev/null +++ b/pkg/dynflags/flags_config_test.go @@ -0,0 +1,135 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestConfigGroup(t *testing.T) { + t.Parallel() + + t.Run("Lookup existing flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{ + Name: "testGroup", + Flags: map[string]*dynflags.Flag{"flag1": {Usage: "Test Flag"}}, + } + flag := group.Lookup("flag1") + assert.NotNil(t, flag) + assert.Equal(t, "Test Flag", flag.Usage) + }) + + t.Run("Lookup non-existing flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{ + Name: "testGroup", + Flags: map[string]*dynflags.Flag{}, + } + flag := group.Lookup("flag1") + assert.Nil(t, flag) + }) +} + +func TestConfigGroups(t *testing.T) { + t.Parallel() + + t.Run("Lookup existing group", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + df.Group("http") + + groups := df.Config() + group := groups.Lookup("http") + + assert.NotNil(t, group) + assert.Equal(t, "http", group.Name) + }) + + t.Run("Iterate over groups", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + df.Group("http") + df.Group("tcp") + + groups := df.Config().Groups() + + assert.Contains(t, groups, "http") + assert.Contains(t, groups, "tcp") + assert.Equal(t, "http", groups["http"].Name) + assert.Equal(t, "tcp", groups["tcp"].Name) + }) +} + +func TestConfigGroup_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil ConfigGroup returns nil", func(t *testing.T) { + t.Parallel() + + var groupConfig *dynflags.ConfigGroup + result := groupConfig.Lookup("flag1") + assert.Nil(t, result, "Expected Lookup on nil ConfigGroup to return nil") + }) + + t.Run("Lookup non-existing flag returns nil", func(t *testing.T) { + t.Parallel() + + groupConfig := &dynflags.ConfigGroup{ + Name: "testGroup", + Flags: map[string]*dynflags.Flag{}, + } + + result := groupConfig.Lookup("nonExistingFlag") + assert.Nil(t, result, "Expected Lookup for non-existing flag to return nil") + }) +} + +func TestConfigGroups_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil ConfigGroups returns nil", func(t *testing.T) { + t.Parallel() + + var configGroups *dynflags.ConfigGroups + result := configGroups.Lookup("group1") + + assert.Nil(t, result, "Expected Lookup on nil ConfigGroups to return nil") + }) + + t.Run("Lookup non-existing group returns nil", func(t *testing.T) { + t.Parallel() + + configGroups := &dynflags.ConfigGroups{} + result := configGroups.Lookup("nonExistingGroup") + + assert.Nil(t, result, "Expected Lookup for non-existing group to return nil") + }) +} + +func TestConfigGroups_Groups_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Groups on nil ConfigGroups returns nil", func(t *testing.T) { + var configGroups *dynflags.ConfigGroups + result := configGroups.Groups() + + assert.Nil(t, result, "Expected Groups on nil ConfigGroups to return nil") + }) +} + +func TestDynFlags_Config_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Config on nil DynFlags returns nil", func(t *testing.T) { + var dynFlags *dynflags.DynFlags + result := dynFlags.Config() + + assert.Nil(t, result, "Expected Config on nil DynFlags to return nil") + }) +} diff --git a/pkg/dynflags/flags_parsed.go b/pkg/dynflags/flags_parsed.go new file mode 100644 index 0000000..dbec8c4 --- /dev/null +++ b/pkg/dynflags/flags_parsed.go @@ -0,0 +1,66 @@ +package dynflags + +// ParsedGroup represents a runtime group with parsed values. +type ParsedGroup struct { + Parent *ConfigGroup // Reference to the parent static group. + Name string // Identifier for the child group (e.g., "IDENTIFIER1"). + Values map[string]interface{} // Parsed values for the group's flags. +} + +// Lookup retrieves the value of a flag in the parsed group. +func (pg *ParsedGroup) Lookup(flagName string) interface{} { + if pg == nil { + return nil + } + + return pg.Values[flagName] +} + +// ParsedGroups represents all parsed groups with lookup and iteration support. +type ParsedGroups struct { + groups map[string]map[string]*ParsedGroup // Nested map of group name -> identifier -> ParsedGroup. +} + +// Lookup retrieves a group by its name. +func (pg *ParsedGroups) Lookup(groupName string) *ParsedIdentifiers { + if pg == nil { + return nil + } + if identifiers, exists := pg.groups[groupName]; exists { + return &ParsedIdentifiers{Name: groupName, identifiers: identifiers} + } + return nil +} + +// Groups returns the underlying map for direct iteration. +func (pg *ParsedGroups) Groups() map[string]map[string]*ParsedGroup { + return pg.groups +} + +// ParsedIdentifiers provides lookup for identifiers within a group. +type ParsedIdentifiers struct { + Name string + identifiers map[string]*ParsedGroup +} + +// Lookup retrieves a specific identifier within a group. +func (gi *ParsedIdentifiers) Lookup(identifier string) *ParsedGroup { + if gi == nil { + return nil + } + + return gi.identifiers[identifier] +} + +// Parsed returns a ParsedGroups instance for the dynflags instance. +func (df *DynFlags) Parsed() *ParsedGroups { + parsed := make(map[string]map[string]*ParsedGroup) + for groupName, groups := range df.parsedGroups { + identifierMap := make(map[string]*ParsedGroup) + for _, group := range groups { + identifierMap[group.Name] = group + } + parsed[groupName] = identifierMap + } + return &ParsedGroups{groups: parsed} +} diff --git a/pkg/dynflags/flags_parsed_test.go b/pkg/dynflags/flags_parsed_test.go new file mode 100644 index 0000000..0e8b169 --- /dev/null +++ b/pkg/dynflags/flags_parsed_test.go @@ -0,0 +1,167 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestParsedGroup(t *testing.T) { + t.Parallel() + + t.Run("Lookup existing parsed flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": "value1"}, + } + + value := group.Lookup("flag1") + assert.Equal(t, "value1", value) + }) + + t.Run("Lookup non-existing parsed flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + value := group.Lookup("flag1") + assert.Nil(t, value) + }) +} + +func TestParsedGroups(t *testing.T) { + t.Parallel() + + t.Run("Lookup existing parsed group", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + args := []string{"--testgroup.identifier1.flag1", "value1"} + err := df.Parse(args) + assert.NoError(t, err) + + group := df.Group("testGroup") + assert.NotNil(t, group) + assert.Equal(t, "testGroup", group.Name) + }) + + t.Run("Lookup non-existing parsed group", func(t *testing.T) { + t.Parallel() + + parsedGroups := &dynflags.ParsedGroups{} + + group := parsedGroups.Lookup("nonExistentGroup") + assert.Nil(t, group) + }) +} + +func TestDynFlagsParsed(t *testing.T) { + t.Parallel() + + t.Run("Combine parsed groups", func(t *testing.T) { + t.Parallel() + + args := []string{ + "--group1.identifier1.flag1", "value1", + "--group1.identifier2.flag2", "value2", + } + + df := dynflags.New(dynflags.ContinueOnError) + g1 := df.Group("group1") + g1.String("flag1", "", "Description flag1") + g1.String("flag2", "", "Description flag2") + + err := df.Parse(args) + assert.NoError(t, err) + + parsedGroups := df.Parsed() + + group := parsedGroups.Lookup("group1") + assert.NotNil(t, group) + assert.Equal(t, "group1", group.Name) + assert.Equal(t, "value1", group.Lookup("identifier1").Lookup("flag1")) + assert.Equal(t, "value2", group.Lookup("identifier2").Lookup("flag2")) + }) + + t.Run("Handle no parsed groups", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + parsedGroups := df.Parsed() + + group := parsedGroups.Lookup("nonExistentGroup") + assert.Nil(t, group) + }) +} + +func TestParsedGroups_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil ParsedGroups returns nil", func(t *testing.T) { + t.Parallel() + + var parsedGroups *dynflags.ParsedGroups + result := parsedGroups.Lookup("http") + assert.Nil(t, result, "Expected Lookup on nil ParsedGroups to return nil") + }) + + t.Run("Lookup non-existing group returns nil", func(t *testing.T) { + t.Parallel() + + parsedGroups := &dynflags.ParsedGroups{} + + result := parsedGroups.Lookup("nonExistingGroup") + assert.Nil(t, result, "Expected Lookup for non-existing group to return nil") + }) +} + +func TestParsedIdentifiers_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil ParsedIdentifiers returns nil", func(t *testing.T) { + t.Parallel() + + var parsedIdentifiers *dynflags.ParsedIdentifiers + result := parsedIdentifiers.Lookup("identifier1") + assert.Nil(t, result, "Expected Lookup on nil ParsedIdentifiers to return nil") + }) + + t.Run("Lookup non-existing identifier returns nil", func(t *testing.T) { + t.Parallel() + + parsedIdentifiers := &dynflags.ParsedIdentifiers{} + + result := parsedIdentifiers.Lookup("nonExistingIdentifier") + assert.Nil(t, result, "Expected Lookup for non-existing identifier to return nil") + }) +} + +func TestParsedGroup_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil ParsedGroup returns nil", func(t *testing.T) { + t.Parallel() + + var parsedGroup *dynflags.ParsedGroup + result := parsedGroup.Lookup("flag1") + assert.Nil(t, result, "Expected Lookup on nil ParsedGroup to return nil") + }) + + t.Run("Lookup non-existing flag returns nil", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "identifier1", + Values: map[string]interface{}{}, + } + + result := parsedGroup.Lookup("nonExistingFlag") + assert.Nil(t, result, "Expected Lookup for non-existing flag to return nil") + }) +} diff --git a/pkg/dynflags/flags_unknown.go b/pkg/dynflags/flags_unknown.go new file mode 100644 index 0000000..a7b9073 --- /dev/null +++ b/pkg/dynflags/flags_unknown.go @@ -0,0 +1,75 @@ +package dynflags + +// UnknownGroup represents a runtime group with unrecognized values. +type UnknownGroup struct { + Name string // Identifier for the child group (e.g., "IDENTIFIER1"). + Values map[string]interface{} // Unrecognized flags and their parsed values. +} + +// Lookup retrieves the value of a flag in the unknown group. +func (ug *UnknownGroup) Lookup(flagName string) interface{} { + if ug == nil { + return nil + } + + return ug.Values[flagName] +} + +// UnknownGroups represents all unknown groups with lookup and iteration support. +type UnknownGroups struct { + groups map[string]map[string]*UnknownGroup // Nested map of group name -> identifier -> UnknownGroup. + unparsedArgs []string // List of arguments that couldn't be parsed into groups or flags. +} + +// Lookup retrieves unknown groups by name. +func (ug *UnknownGroups) Lookup(groupName string) *UnknownIdentifiers { + if ug == nil { + return nil + } + + if identifiers, exists := ug.groups[groupName]; exists { + return &UnknownIdentifiers{Name: groupName, identifiers: identifiers} + } + return nil +} + +// Groups returns the underlying map for direct iteration. +func (ug *UnknownGroups) Groups() map[string]map[string]*UnknownGroup { + return ug.groups +} + +// UnknownIdentifiers provides lookup for identifiers within a group. +type UnknownIdentifiers struct { + Name string + identifiers map[string]*UnknownGroup +} + +// Lookup retrieves a specific identifier within a group. +func (ui *UnknownIdentifiers) Lookup(identifier string) *UnknownGroup { + if ui == nil { + return nil + } + + return ui.identifiers[identifier] +} + +// Unknown returns an UnknownGroups instance for the DynFlags instance. +func (df *DynFlags) Unknown() *UnknownGroups { + parsed := make(map[string]map[string]*UnknownGroup) + for groupName, groups := range df.unknownGroups { + identifierMap := make(map[string]*UnknownGroup) + for _, group := range groups { + identifierMap[group.Name] = group + } + parsed[groupName] = identifierMap + } + return &UnknownGroups{ + groups: parsed, + unparsedArgs: df.unparsedArgs, + } +} + +// UnparsedArgs returns the list of unparseable arguments. +func (df *DynFlags) UnparsedArgs() []string { + return df.unparsedArgs +} diff --git a/pkg/dynflags/flags_unknwon_test.go b/pkg/dynflags/flags_unknwon_test.go new file mode 100644 index 0000000..e526d9a --- /dev/null +++ b/pkg/dynflags/flags_unknwon_test.go @@ -0,0 +1,175 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestUnknownGroup(t *testing.T) { + t.Parallel() + + t.Run("Lookup existing unknown flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.UnknownGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": "value1"}, + } + + value := group.Lookup("flag1") + assert.Equal(t, "value1", value) + }) + + t.Run("Lookup non-existing unknown flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.UnknownGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + value := group.Lookup("flag1") + assert.Nil(t, value) + }) +} + +func TestUnknownGroups(t *testing.T) { + t.Parallel() + + t.Run("Lookup existing unknown group", func(t *testing.T) { + t.Parallel() + + args := []string{ + "--unknown.identifier1.flag1", "value1", + } + + df := dynflags.New(dynflags.ParseUnknown) + err := df.Parse(args) + assert.NoError(t, err) + + unknownGroups := df.Unknown() + + group := unknownGroups.Lookup("unknown") + assert.NotNil(t, group) + assert.Equal(t, "unknown", group.Name) + }) + + t.Run("Lookup non-existing unknown group", func(t *testing.T) { + t.Parallel() + + args := []string{ + "--unknown.identifier1.flag1", "value1", + } + + df := dynflags.New(dynflags.ContinueOnError) + err := df.Parse(args) + assert.NoError(t, err) + + unknownGroups := df.Unknown() + group := unknownGroups.Lookup("unknown") + assert.Nil(t, group) + }) +} + +func TestDynFlagsUnknown(t *testing.T) { + t.Parallel() + + t.Run("Combine unknown groups", func(t *testing.T) { + t.Parallel() + + args := []string{ + "--group1.identifier1.flag1", "value1", + "--group1.identifier2.flag2", "value2", + } + + df := dynflags.New(dynflags.ParseUnknown) + err := df.Parse(args) + assert.NoError(t, err) + + unknownGroups := df.Unknown() + + group := unknownGroups.Lookup("group1") + assert.NotNil(t, group) + assert.Equal(t, "group1", group.Name) + assert.Equal(t, "value1", group.Lookup("identifier1").Lookup("flag1")) + assert.Equal(t, "value2", group.Lookup("identifier2").Lookup("flag2")) + }) + + t.Run("Handle no unknown groups", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ParseUnknown) + unknownGroups := df.Unknown() + + group := unknownGroups.Lookup("nonExistentGroup") + assert.Nil(t, group) + }) +} + +func TestUnknownGroups_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil UnknownGroups returns nil", func(t *testing.T) { + t.Parallel() + + var unknownGroups *dynflags.UnknownGroups + result := unknownGroups.Lookup("http") + assert.Nil(t, result, "Expected Lookup on nil UnknownGroups to return nil") + }) + + t.Run("Lookup non-existing group returns nil", func(t *testing.T) { + t.Parallel() + + unknownGroups := &dynflags.UnknownGroups{} + + result := unknownGroups.Lookup("nonExistingGroup") + assert.Nil(t, result, "Expected Lookup for non-existing group to return nil") + }) +} + +func TestUnknownIdentifiers_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil UnknownIdentifiers returns nil", func(t *testing.T) { + t.Parallel() + + var unknownIdentifiers *dynflags.UnknownIdentifiers + result := unknownIdentifiers.Lookup("identifier1") + assert.Nil(t, result, "Expected Lookup on nil UnknownIdentifiers to return nil") + }) + + t.Run("Lookup non-existing identifier returns nil", func(t *testing.T) { + t.Parallel() + + unknownIdentifiers := &dynflags.UnknownIdentifiers{} + + result := unknownIdentifiers.Lookup("nonExistingIdentifier") + assert.Nil(t, result, "Expected Lookup for non-existing identifier to return nil") + }) +} + +func TestUnknownGroup_Lookup_NilHandling(t *testing.T) { + t.Parallel() + + t.Run("Lookup on nil UnknownGroup returns nil", func(t *testing.T) { + t.Parallel() + + var unknownGroup *dynflags.UnknownGroup + result := unknownGroup.Lookup("flag1") + assert.Nil(t, result, "Expected Lookup on nil UnknownGroup to return nil") + }) + + t.Run("Lookup non-existing flag returns nil", func(t *testing.T) { + t.Parallel() + + unknownGroup := &dynflags.UnknownGroup{ + Name: "identifier1", + Values: map[string]interface{}{}, + } + + result := unknownGroup.Lookup("nonExistingFlag") + assert.Nil(t, result, "Expected Lookup for non-existing flag to return nil") + }) +} diff --git a/pkg/dynflags/float64.go b/pkg/dynflags/float64.go new file mode 100644 index 0000000..c65eea5 --- /dev/null +++ b/pkg/dynflags/float64.go @@ -0,0 +1,57 @@ +package dynflags + +import ( + "fmt" + "strconv" +) + +// IntValue implementation for integer flags +type Float64Value struct { + Bound *float64 +} + +func (f *Float64Value) GetBound() interface{} { + if f.Bound == nil { + return nil + } + return *f.Bound +} + +func (i *Float64Value) Parse(value string) (interface{}, error) { + return strconv.ParseFloat(value, 64) +} + +func (i *Float64Value) Set(value interface{}) error { + if num, ok := value.(float64); ok { + *i.Bound = num + return nil + } + return fmt.Errorf("invalid value type: expected float64") +} + +// Float64 defines a float64 flag with specified name, default value, and usage string. +// The return value is the address of a float64 variable that stores the value of the flag. +func (g *ConfigGroup) Float64(name string, value float64, usage string) *Flag { + bound := &value + flag := &Flag{ + Type: FlagTypeInt, + Default: value, + Usage: usage, + value: &Float64Value{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetFloat64 returns the float64 value of a flag with the given name +func (pg *ParsedGroup) GetFloat64(flagName string) (float64, error) { + value, exists := pg.Values[flagName] + if !exists { + return 0, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if floatVal, ok := value.(float64); ok { + return floatVal, nil + } + return 0, fmt.Errorf("flag '%s' is not a float64", flagName) +} diff --git a/pkg/dynflags/float64_slice.go b/pkg/dynflags/float64_slice.go new file mode 100644 index 0000000..9b9df2e --- /dev/null +++ b/pkg/dynflags/float64_slice.go @@ -0,0 +1,73 @@ +package dynflags + +import ( + "fmt" + "strconv" + "strings" +) + +// Float64SlicesValue implementation for float64 slice flags +type Float64SlicesValue struct { + Bound *[]float64 +} + +func (f *Float64SlicesValue) GetBound() interface{} { + if f.Bound == nil { + return nil + } + return *f.Bound +} + +func (f *Float64SlicesValue) Parse(value string) (interface{}, error) { + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return nil, fmt.Errorf("invalid float64 value: %s, error: %w", value, err) + } + return parsed, nil +} + +func (f *Float64SlicesValue) Set(value interface{}) error { + if parsedFloat, ok := value.(float64); ok { + *f.Bound = append(*f.Bound, parsedFloat) + return nil + } + return fmt.Errorf("invalid value type: expected float64") +} + +// Float64Slices defines a float64 slice flag with specified name, default value, and usage string. +// The return value is the address of a slice of float64 that stores the value of the flag. +func (g *ConfigGroup) Float64Slices(name string, value []float64, usage string) *Flag { + bound := &value + defaultValue := make([]string, len(value)) + for i, v := range value { + defaultValue[i] = strconv.FormatFloat(v, 'f', -1, 64) + } + + flag := &Flag{ + Type: FlagTypeFloatSlice, + Default: strings.Join(defaultValue, ","), + Usage: usage, + value: &Float64SlicesValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetFloat64Slices returns the []float64 value of a flag with the given name +func (pg *ParsedGroup) GetFloat64Slices(flagName string) ([]float64, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + + if slice, ok := value.([]float64); ok { + return slice, nil + } + + if f, ok := value.(float64); ok { + return []float64{f}, nil + } + + return nil, fmt.Errorf("flag '%s' is not a []float64", flagName) +} diff --git a/pkg/dynflags/float64_slice_test.go b/pkg/dynflags/float64_slice_test.go new file mode 100644 index 0000000..3882f83 --- /dev/null +++ b/pkg/dynflags/float64_slice_test.go @@ -0,0 +1,140 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestFloat64SlicesValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid float64 value", func(t *testing.T) { + t.Parallel() + + float64SlicesValue := dynflags.Float64SlicesValue{Bound: &[]float64{}} + parsed, err := float64SlicesValue.Parse("3.14159") + assert.NoError(t, err) + assert.Equal(t, 3.14159, parsed) + }) + + t.Run("Parse invalid float64 value", func(t *testing.T) { + t.Parallel() + + float64SlicesValue := dynflags.Float64SlicesValue{Bound: &[]float64{}} + parsed, err := float64SlicesValue.Parse("invalid") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid float64 value", func(t *testing.T) { + t.Parallel() + + bound := []float64{1.23} + float64SlicesValue := dynflags.Float64SlicesValue{Bound: &bound} + + err := float64SlicesValue.Set(4.56) + assert.NoError(t, err) + assert.Equal(t, []float64{1.23, 4.56}, bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := []float64{} + float64SlicesValue := dynflags.Float64SlicesValue{Bound: &bound} + + err := float64SlicesValue.Set("invalid") + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected float64") + }) +} + +func TestGroupConfigFloat64Slices(t *testing.T) { + t.Parallel() + + t.Run("Define float64 slices flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := []float64{1.23, 4.56} + group.Float64Slices("float64SliceFlag", defaultValue, "A float64 slices flag") + + assert.Contains(t, group.Flags, "float64SliceFlag") + assert.Equal(t, "A float64 slices flag", group.Flags["float64SliceFlag"].Usage) + assert.Equal(t, "1.23,4.56", group.Flags["float64SliceFlag"].Default) + }) +} + +func TestGetFloat64Slices(t *testing.T) { + t.Parallel() + + t.Run("Retrieve []float64 value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": []float64{1.1, 2.2, 3.3}}, + } + + result, err := parsedGroup.GetFloat64Slices("flag1") + assert.NoError(t, err) + assert.Equal(t, []float64{1.1, 2.2, 3.3}, result) + }) + + t.Run("Retrieve single float64 value as []float64", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": 42.42}, + } + + result, err := parsedGroup.GetFloat64Slices("flag1") + assert.NoError(t, err) + assert.Equal(t, []float64{42.42}, result) + }) + + t.Run("Flag not found", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + result, err := parsedGroup.GetFloat64Slices("nonExistentFlag") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'nonExistentFlag' not found in group 'testGroup'") + }) + + t.Run("Flag value is invalid type", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": "invalid"}, + } + + result, err := parsedGroup.GetFloat64Slices("flag1") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'flag1' is not a []float64") + }) +} + +func TestFloat64SlicesGetBound(t *testing.T) { + t.Run("Float64SlicesValue - GetBound", func(t *testing.T) { + var slices *[]float64 + val := []float64{1.1, 2.2, 3.3} + slices = &val + + floatSlicesValue := dynflags.Float64SlicesValue{Bound: slices} + assert.Equal(t, val, floatSlicesValue.GetBound()) + + floatSlicesValue = dynflags.Float64SlicesValue{Bound: nil} + assert.Nil(t, floatSlicesValue.GetBound()) + }) +} diff --git a/pkg/dynflags/float64_test.go b/pkg/dynflags/float64_test.go new file mode 100644 index 0000000..6466a6c --- /dev/null +++ b/pkg/dynflags/float64_test.go @@ -0,0 +1,130 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestFloat64Value_Parse(t *testing.T) { + t.Parallel() + + t.Run("Valid Float64", func(t *testing.T) { + t.Parallel() + + bound := new(float64) + fv := &dynflags.Float64Value{Bound: bound} + + parsedValue, err := fv.Parse("123.456") + assert.NoError(t, err) + assert.Equal(t, 123.456, parsedValue) + }) + + t.Run("Invalid Float64", func(t *testing.T) { + t.Parallel() + + bound := new(float64) + fv := &dynflags.Float64Value{Bound: bound} + + _, err := fv.Parse("invalid") + assert.Error(t, err) + }) +} + +func TestFloat64Value_Set(t *testing.T) { + t.Parallel() + + t.Run("Set Valid Float64", func(t *testing.T) { + t.Parallel() + + bound := new(float64) + fv := &dynflags.Float64Value{Bound: bound} + + err := fv.Set(123.456) + assert.NoError(t, err) + assert.Equal(t, 123.456, *bound) + }) + + t.Run("Set Invalid Float64", func(t *testing.T) { + t.Parallel() + + bound := new(float64) + fv := &dynflags.Float64Value{Bound: bound} + + err := fv.Set("invalid") + assert.Error(t, err) + }) +} + +func TestGroupConfig_Float64(t *testing.T) { + t.Parallel() + + t.Run("Define Float64 Flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{ + Flags: make(map[string]*dynflags.Flag), + } + value := group.Float64("float64-test", 123.456, "test float64 flag") + + assert.NotNil(t, value) + assert.Equal(t, 123.456, value.Default) + assert.Contains(t, group.Flags, "float64-test") + }) +} + +func TestParsedGroup_GetFloat64(t *testing.T) { + t.Parallel() + + t.Run("Get Existing Float64 Value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "test-group", + Values: map[string]interface{}{"float64-test": 123.456}, + } + + value, err := parsedGroup.GetFloat64("float64-test") + assert.NoError(t, err) + assert.Equal(t, 123.456, value) + }) + + t.Run("Get Non-Existing Float64 Value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "test-group", + Values: map[string]interface{}{}, + } + + _, err := parsedGroup.GetFloat64("non-existing") + assert.Error(t, err) + }) + + t.Run("Get Invalid Float64 Value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "test-group", + Values: map[string]interface{}{"invalid-test": "not-a-float"}, + } + + _, err := parsedGroup.GetFloat64("invalid-test") + assert.Error(t, err) + }) +} + +func TestFloat64GetBound(t *testing.T) { + t.Run("Float64Value - GetBound", func(t *testing.T) { + var f *float64 + val := 3.14 + f = &val + + floatValue := dynflags.Float64Value{Bound: f} + assert.Equal(t, 3.14, floatValue.GetBound()) + + floatValue = dynflags.Float64Value{Bound: nil} + assert.Nil(t, floatValue.GetBound()) + }) +} diff --git a/pkg/dynflags/int.go b/pkg/dynflags/int.go new file mode 100644 index 0000000..39350a3 --- /dev/null +++ b/pkg/dynflags/int.go @@ -0,0 +1,57 @@ +package dynflags + +import ( + "fmt" + "strconv" +) + +// IntValue implementation for integer flags +type IntValue struct { + Bound *int +} + +func (i *IntValue) GetBound() interface{} { + if i.Bound == nil { + return nil + } + return *i.Bound +} + +func (i *IntValue) Parse(value string) (interface{}, error) { + return strconv.Atoi(value) +} + +func (i *IntValue) Set(value interface{}) error { + if num, ok := value.(int); ok { + *i.Bound = num + return nil + } + return fmt.Errorf("invalid value type: expected int") +} + +// Int defines an int flag with specified name, default value, and usage string. +// The return value is the address of an int variable that stores the value of the flag. +func (g *ConfigGroup) Int(name string, value int, usage string) *Flag { + bound := &value + flag := &Flag{ + Type: FlagTypeInt, + Default: value, + Usage: usage, + value: &IntValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetInt returns the int value of a flag with the given name +func (pg *ParsedGroup) GetInt(flagName string) (int, error) { + value, exists := pg.Values[flagName] + if !exists { + return 0, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if intVal, ok := value.(int); ok { + return intVal, nil + } + return 0, fmt.Errorf("flag '%s' is not an int", flagName) +} diff --git a/pkg/dynflags/int_slice.go b/pkg/dynflags/int_slice.go new file mode 100644 index 0000000..1d89ba2 --- /dev/null +++ b/pkg/dynflags/int_slice.go @@ -0,0 +1,72 @@ +package dynflags + +import ( + "fmt" + "strconv" + "strings" +) + +// IntSlicesValue implementation for int slice flags +type IntSlicesValue struct { + Bound *[]int +} + +func (i *IntSlicesValue) GetBound() interface{} { + if i.Bound == nil { + return nil + } + return *i.Bound +} + +func (s *IntSlicesValue) Parse(value string) (interface{}, error) { + parsedValue, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("invalid integer value: %s", value) + } + return parsedValue, nil +} + +func (s *IntSlicesValue) Set(value interface{}) error { + if num, ok := value.(int); ok { + *s.Bound = append(*s.Bound, num) + return nil + } + return fmt.Errorf("invalid value type: expected int") +} + +// IntSlices defines an int slice flag with specified name, default value, and usage string. +// The return value is the address of a slice of integers that stores the value of the flag. +func (g *ConfigGroup) IntSlices(name string, value []int, usage string) *Flag { + bound := &value + defaults := make([]string, len(value)) + for i, v := range value { + defaults[i] = strconv.Itoa(v) + } + flag := &Flag{ + Type: FlagTypeIntSlice, + Default: strings.Join(defaults, ","), + Usage: usage, + value: &IntSlicesValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetIntSlices returns the []int value of a flag with the given name +func (pg *ParsedGroup) GetIntSlices(flagName string) ([]int, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + + if slice, ok := value.([]int); ok { + return slice, nil + } + + if i, ok := value.(int); ok { + return []int{i}, nil + } + + return nil, fmt.Errorf("flag '%s' is not a []int", flagName) +} diff --git a/pkg/dynflags/int_slice_test.go b/pkg/dynflags/int_slice_test.go new file mode 100644 index 0000000..1143cb4 --- /dev/null +++ b/pkg/dynflags/int_slice_test.go @@ -0,0 +1,140 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestIntSlicesValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid int slice value", func(t *testing.T) { + t.Parallel() + + intSlicesValue := dynflags.IntSlicesValue{Bound: &[]int{}} + parsed, err := intSlicesValue.Parse("123") + assert.NoError(t, err) + assert.Equal(t, 123, parsed) + }) + + t.Run("Parse invalid int slice value", func(t *testing.T) { + t.Parallel() + + intSlicesValue := dynflags.IntSlicesValue{Bound: &[]int{}} + parsed, err := intSlicesValue.Parse("invalid") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid int slice value", func(t *testing.T) { + t.Parallel() + + bound := []int{1} + intSlicesValue := dynflags.IntSlicesValue{Bound: &bound} + + err := intSlicesValue.Set(2) + assert.NoError(t, err) + assert.Equal(t, []int{1, 2}, bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := []int{1} + intSlicesValue := dynflags.IntSlicesValue{Bound: &bound} + + err := intSlicesValue.Set("invalid") // Invalid type + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected int") + }) +} + +func TestGroupConfigIntSlices(t *testing.T) { + t.Parallel() + + t.Run("Define int slices flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := []int{1, 2} + group.IntSlices("intSliceFlag", defaultValue, "An int slices flag") + + assert.Contains(t, group.Flags, "intSliceFlag") + assert.Equal(t, "An int slices flag", group.Flags["intSliceFlag"].Usage) + assert.Equal(t, "1,2", group.Flags["intSliceFlag"].Default) + }) +} + +func TestGetIntSlices(t *testing.T) { + t.Parallel() + + t.Run("Retrieve []int value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": []int{1, 2, 3}}, + } + + result, err := parsedGroup.GetIntSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, result) + }) + + t.Run("Retrieve single int value as []int", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": 42}, + } + + result, err := parsedGroup.GetIntSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []int{42}, result) + }) + + t.Run("Flag not found", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + result, err := parsedGroup.GetIntSlices("nonExistentFlag") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'nonExistentFlag' not found in group 'testGroup'") + }) + + t.Run("Flag value is invalid type", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": "invalid"}, + } + + result, err := parsedGroup.GetIntSlices("flag1") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'flag1' is not a []int") + }) +} + +func TestIntSlicesGetBound(t *testing.T) { + t.Run("IntSlicesValue - GetBound", func(t *testing.T) { + var slices *[]int + val := []int{1, 2, 3} + slices = &val + + intSlicesValue := dynflags.IntSlicesValue{Bound: slices} + assert.Equal(t, val, intSlicesValue.GetBound()) + + intSlicesValue = dynflags.IntSlicesValue{Bound: nil} + assert.Nil(t, intSlicesValue.GetBound()) + }) +} diff --git a/pkg/dynflags/int_test.go b/pkg/dynflags/int_test.go new file mode 100644 index 0000000..f05fc78 --- /dev/null +++ b/pkg/dynflags/int_test.go @@ -0,0 +1,124 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestIntValue_Parse(t *testing.T) { + t.Parallel() + + t.Run("ValidInt", func(t *testing.T) { + t.Parallel() + + var bound int + val := &dynflags.IntValue{Bound: &bound} + parsed, err := val.Parse("42") + assert.NoError(t, err) + assert.Equal(t, 42, parsed) + }) + + t.Run("InvalidInt", func(t *testing.T) { + t.Parallel() + + var bound int + val := &dynflags.IntValue{Bound: &bound} + _, err := val.Parse("invalid") + assert.Error(t, err) + }) +} + +func TestIntValue_Set(t *testing.T) { + t.Parallel() + + t.Run("ValidInt", func(t *testing.T) { + t.Parallel() + + var bound int + val := &dynflags.IntValue{Bound: &bound} + assert.NoError(t, val.Set(42)) + assert.Equal(t, 42, bound) + }) + + t.Run("InvalidType", func(t *testing.T) { + t.Parallel() + + var bound int + val := &dynflags.IntValue{Bound: &bound} + assert.Error(t, val.Set("not an int")) + }) +} + +func TestGroupConfig_Int(t *testing.T) { + t.Parallel() + + t.Run("DefineAndRetrieveInt", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + bound := group.Int("test-int", 100, "Test integer flag") + assert.NotNil(t, bound) + + flag, exists := group.Flags["test-int"] + assert.True(t, exists) + assert.NotNil(t, flag) + assert.Equal(t, dynflags.FlagTypeInt, flag.Type) + assert.Equal(t, 100, flag.Default) + assert.Equal(t, "Test integer flag", flag.Usage) + }) +} + +func TestParsedGroup_GetInt(t *testing.T) { + t.Parallel() + + t.Run("ValidIntRetrieval", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "test-int": 42, + }, + } + val, err := parsedGroup.GetInt("test-int") + assert.NoError(t, err) + assert.Equal(t, 42, val) + }) + + t.Run("FlagNotFound", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Values: make(map[string]interface{}), + } + _, err := parsedGroup.GetInt("non-existent") + assert.Error(t, err) + }) + + t.Run("InvalidType", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "test-int": "not an int", + }, + } + _, err := parsedGroup.GetInt("test-int") + assert.Error(t, err) + }) +} + +func TestIntGetBound(t *testing.T) { + t.Run("IntValue - GetBound", func(t *testing.T) { + var i *int + val := 42 + i = &val + + intValue := dynflags.IntValue{Bound: i} + assert.Equal(t, 42, intValue.GetBound()) + + intValue = dynflags.IntValue{Bound: nil} + assert.Nil(t, intValue.GetBound()) + }) +} diff --git a/pkg/dynflags/ip.go b/pkg/dynflags/ip.go new file mode 100644 index 0000000..0f5001d --- /dev/null +++ b/pkg/dynflags/ip.go @@ -0,0 +1,69 @@ +package dynflags + +import ( + "fmt" + "net" +) + +// IPValue implementation for URL flags +type IPValue struct { + Bound *net.IP +} + +func (i *IPValue) GetBound() interface{} { + if i.Bound == nil { + return nil + } + return *i.Bound +} + +func (u *IPValue) Parse(value string) (interface{}, error) { + result := net.ParseIP(value) + if result == nil { + return nil, fmt.Errorf("invalid IP address: %s", value) + } + return &result, nil +} + +func (u *IPValue) Set(value interface{}) error { + if parsedIP, ok := value.(*net.IP); ok { + *u.Bound = *parsedIP + return nil + } + return fmt.Errorf("invalid value type: expected IP") +} + +// IP defines an net.IP flag with specified name, default value, and usage string. +// The return value is the address of an net.IP variable that stores the value of the flag. +func (g *ConfigGroup) IP(name, value, usage string) *Flag { + bound := new(*net.IP) + if value != "" { + parsed := net.ParseIP(value) + if parsed == nil { + panic(fmt.Sprintf("%s has a invalid default IP flag '%s'", name, value)) + } + *bound = &parsed // Copy the parsed URL into bound + } + flag := &Flag{ + Type: FlagTypeIP, + Default: value, + Usage: usage, + value: &IPValue{Bound: *bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetIP returns the net.IP value of a flag with the given name +func (pg *ParsedGroup) GetIP(flagName string) (net.IP, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if ip, ok := value.(net.IP); ok { + return ip, nil + } + + return nil, fmt.Errorf("flag '%s' is not a IP", flagName) +} diff --git a/pkg/dynflags/ip_slice.go b/pkg/dynflags/ip_slice.go new file mode 100644 index 0000000..3d18306 --- /dev/null +++ b/pkg/dynflags/ip_slice.go @@ -0,0 +1,73 @@ +package dynflags + +import ( + "fmt" + "net" + "strings" +) + +// IPSlicesValue implementation for IP slice flags +type IPSlicesValue struct { + Bound *[]net.IP +} + +func (i *IPSlicesValue) GetBound() interface{} { + if i.Bound == nil { + return nil + } + return *i.Bound +} + +func (s *IPSlicesValue) Parse(value string) (interface{}, error) { + ip := net.ParseIP(value) + if ip == nil { + return nil, fmt.Errorf("invalid IP address: %s", value) + } + return ip, nil +} + +func (s *IPSlicesValue) Set(value interface{}) error { + if ip, ok := value.(net.IP); ok { + *s.Bound = append(*s.Bound, ip) + return nil + } + return fmt.Errorf("invalid value type: expected net.IP") +} + +// IPSlices defines an IP slice flag with specified name, default value, and usage string. +// The return value is the address of a slice of IPs that stores the value of the flag. +func (g *ConfigGroup) IPSlices(name string, value []net.IP, usage string) *Flag { + bound := &value + defaultValue := make([]string, len(value)) + for i, ip := range value { + defaultValue[i] = ip.String() + } + + flag := &Flag{ + Type: FlagTypeIPSlice, + Default: strings.Join(defaultValue, ","), + Usage: usage, + value: &IPSlicesValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetIPSlices returns the []net.IP value of a flag with the given name +func (pg *ParsedGroup) GetIPSlices(flagName string) ([]net.IP, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + + if ipSlice, ok := value.([]net.IP); ok { + return ipSlice, nil + } + + if i, ok := value.(net.IP); ok { + return []net.IP{i}, nil + } + + return nil, fmt.Errorf("flag '%s' is not a []net.IP", flagName) +} diff --git a/pkg/dynflags/ip_slice_test.go b/pkg/dynflags/ip_slice_test.go new file mode 100644 index 0000000..c44f2be --- /dev/null +++ b/pkg/dynflags/ip_slice_test.go @@ -0,0 +1,146 @@ +package dynflags_test + +import ( + "net" + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestIPSlicesValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid IP", func(t *testing.T) { + t.Parallel() + + ipSlicesValue := dynflags.IPSlicesValue{Bound: &[]net.IP{}} + parsed, err := ipSlicesValue.Parse("192.168.0.1") + assert.NoError(t, err) + assert.Equal(t, net.ParseIP("192.168.0.1"), parsed) + }) + + t.Run("Parse invalid IP", func(t *testing.T) { + t.Parallel() + + ipSlicesValue := dynflags.IPSlicesValue{Bound: &[]net.IP{}} + parsed, err := ipSlicesValue.Parse("invalid-ip") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid IP", func(t *testing.T) { + t.Parallel() + + bound := []net.IP{net.ParseIP("192.168.0.1")} + ipSlicesValue := dynflags.IPSlicesValue{Bound: &bound} + + err := ipSlicesValue.Set(net.ParseIP("10.0.0.1")) + assert.NoError(t, err) + assert.Equal(t, []net.IP{net.ParseIP("192.168.0.1"), net.ParseIP("10.0.0.1")}, bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := []net.IP{} + ipSlicesValue := dynflags.IPSlicesValue{Bound: &bound} + + err := ipSlicesValue.Set("invalid-ip-type") + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected net.IP") + }) +} + +func TestGroupConfigIPSlices(t *testing.T) { + t.Parallel() + + t.Run("Define IP slices flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := []net.IP{net.ParseIP("192.168.0.1"), net.ParseIP("10.0.0.1")} + group.IPSlices("ipSliceFlag", defaultValue, "An IP slices flag") + + assert.Contains(t, group.Flags, "ipSliceFlag") + assert.Equal(t, "An IP slices flag", group.Flags["ipSliceFlag"].Usage) + assert.Equal(t, "192.168.0.1,10.0.0.1", group.Flags["ipSliceFlag"].Default) + }) +} + +func TestGetIPSlices(t *testing.T) { + t.Parallel() + + t.Run("Retrieve []net.IP value", func(t *testing.T) { + t.Parallel() + + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("10.0.0.1") + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": []net.IP{ip1, ip2}}, + } + + result, err := parsedGroup.GetIPSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []net.IP{ip1, ip2}, result) + }) + + t.Run("Retrieve single net.IP value as []net.IP", func(t *testing.T) { + t.Parallel() + + ip := net.ParseIP("127.0.0.1") + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": ip}, + } + + result, err := parsedGroup.GetIPSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []net.IP{ip}, result) + }) + + t.Run("Flag not found", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + result, err := parsedGroup.GetIPSlices("nonExistentFlag") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'nonExistentFlag' not found in group 'testGroup'") + }) + + t.Run("Flag value is invalid type", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": "invalid"}, + } + + result, err := parsedGroup.GetIPSlices("flag1") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'flag1' is not a []net.IP") + }) +} + +func TestIPSlicesGetBound(t *testing.T) { + t.Run("IPSlicesValue - GetBound", func(t *testing.T) { + var slices *[]net.IP + val := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("10.0.0.1")} + slices = &val + + ipSlicesValue := dynflags.IPSlicesValue{Bound: slices} + assert.Equal(t, val, ipSlicesValue.GetBound()) + + ipSlicesValue = dynflags.IPSlicesValue{Bound: nil} + assert.Nil(t, ipSlicesValue.GetBound()) + }) +} diff --git a/pkg/dynflags/ip_test.go b/pkg/dynflags/ip_test.go new file mode 100644 index 0000000..ab5a37a --- /dev/null +++ b/pkg/dynflags/ip_test.go @@ -0,0 +1,140 @@ +package dynflags_test + +import ( + "net" + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestIPValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid IP address", func(t *testing.T) { + t.Parallel() + + ipValue := dynflags.IPValue{} + parsed, err := ipValue.Parse("192.168.1.1") + assert.NoError(t, err) + assert.NotNil(t, parsed) + assert.Equal(t, "192.168.1.1", parsed.(*net.IP).String()) + }) + + t.Run("Parse invalid IP address", func(t *testing.T) { + t.Parallel() + + ipValue := dynflags.IPValue{} + parsed, err := ipValue.Parse("invalid-ip") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid IP value", func(t *testing.T) { + t.Parallel() + + bound := net.ParseIP("0.0.0.0") + ipValue := dynflags.IPValue{Bound: &bound} + + parsed := net.ParseIP("192.168.1.1") + err := ipValue.Set(&parsed) + assert.NoError(t, err) + assert.Equal(t, "192.168.1.1", ipValue.Bound.String()) + }) + + t.Run("Set invalid value type", func(t *testing.T) { + t.Parallel() + + bound := net.ParseIP("0.0.0.0") + ipValue := dynflags.IPValue{Bound: &bound} + + err := ipValue.Set("invalid-type") + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected IP") + }) +} + +func TestGroupConfigIP(t *testing.T) { + t.Parallel() + + t.Run("Define IP flag with valid default", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultIP := "192.168.1.1" + group.IP("ipFlag", defaultIP, "An example IP flag") + + assert.Contains(t, group.Flags, "ipFlag") + assert.Equal(t, "An example IP flag", group.Flags["ipFlag"].Usage) + assert.Equal(t, defaultIP, group.Flags["ipFlag"].Default) + }) + + t.Run("Define IP flag with invalid default", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + + assert.PanicsWithValue(t, + "ipFlag has a invalid default IP flag 'invalid-ip'", + func() { + group.IP("ipFlag", "invalid-ip", "Invalid IP flag") + }) + }) +} + +func TestParsedGroupGetIP(t *testing.T) { + t.Parallel() + + t.Run("Get existing IP flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "ipFlag": net.ParseIP("192.168.1.1"), + }, + } + ip, err := group.GetIP("ipFlag") + assert.NoError(t, err) + assert.Equal(t, "192.168.1.1", ip.String()) + }) + + t.Run("Get non-existent IP flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{}, + } + ip, err := group.GetIP("ipFlag") + assert.Error(t, err) + assert.Nil(t, ip) + assert.EqualError(t, err, "flag 'ipFlag' not found in group ''") + }) + + t.Run("Get IP flag with invalid type", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "ipFlag": "not-an-ip", + }, + } + ip, err := group.GetIP("ipFlag") + assert.Error(t, err) + assert.Nil(t, ip) + assert.EqualError(t, err, "flag 'ipFlag' is not a IP") + }) +} + +func TestIPGetBound(t *testing.T) { + t.Run("IPValue - GetBound", func(t *testing.T) { + var ip *net.IP + val := net.ParseIP("127.0.0.1") + ip = &val + + ipValue := dynflags.IPValue{Bound: ip} + assert.Equal(t, val, ipValue.GetBound()) + + ipValue = dynflags.IPValue{Bound: nil} + assert.Nil(t, ipValue.GetBound()) + }) +} diff --git a/pkg/dynflags/parser.go b/pkg/dynflags/parser.go new file mode 100644 index 0000000..2171126 --- /dev/null +++ b/pkg/dynflags/parser.go @@ -0,0 +1,152 @@ +package dynflags + +import ( + "fmt" + "strings" +) + +// Parse parses the CLI arguments and populates parsed and unknown groups. +func (df *DynFlags) Parse(args []string) error { + for i := 0; i < len(args); i++ { + arg := args[i] + + // Extract the key and value + fullKey, value, err := df.extractKeyValue(arg, args, &i) + if err != nil { + // Handle unparseable arguments + if df.parseBehavior == ExitOnError { + return err + } + df.unparsedArgs = append(df.unparsedArgs, arg) + continue + } + + // Validate and split the key + parentName, identifier, flagName, err := df.splitKey(fullKey) + if err != nil { + // Handle invalid keys + if df.parseBehavior == ExitOnError { + return err + } + df.unparsedArgs = append(df.unparsedArgs, arg) + continue + } + + // Handle the flag + if err := df.handleFlag(parentName, identifier, flagName, value); err != nil { + if df.parseBehavior == ExitOnError { + return err + } + df.unparsedArgs = append(df.unparsedArgs, arg) + } + } + return nil +} + +// extractKeyValue extracts the key and value from an argument. +func (df *DynFlags) extractKeyValue(arg string, args []string, index *int) (string, string, error) { + if !strings.HasPrefix(arg, "--") { + // Invalid argument format + return "", "", fmt.Errorf("invalid argument format: %s", arg) + } + + arg = strings.TrimPrefix(arg, "--") + + // Handle "--key=value" format + if strings.Contains(arg, "=") { + parts := strings.SplitN(arg, "=", 2) + return parts[0], parts[1], nil + } + + // Handle "--key value" format + if *index+1 < len(args) && !strings.HasPrefix(args[*index+1], "--") { + *index++ + return arg, args[*index], nil + } + + // Missing value for the key + return "", "", fmt.Errorf("missing value for flag: --%s", arg) +} + +// splitKey validates and splits a key into its components. +func (df *DynFlags) splitKey(fullKey string) (string, string, string, error) { + parts := strings.Split(fullKey, ".") + if len(parts) != 3 { + return "", "", "", fmt.Errorf("flag must follow the pattern: --..") + } + return parts[0], parts[1], parts[2], nil +} + +// handleFlag processes a known or unknown flag. +func (df *DynFlags) handleFlag(parentName, identifier, flagName, value string) error { + if parentGroup, exists := df.configGroups[parentName]; exists { + if flag := parentGroup.Lookup(flagName); flag != nil { + // Known flag + parsedGroup := df.createOrGetParsedGroup(parentGroup, identifier) + return df.setFlagValue(parsedGroup, flagName, flag, value) + } + } + + // Unknown flag + return df.handleUnknownFlag(parentName, identifier, flagName, value) +} + +// setFlagValue sets the value of a known flag in the parsed group. +func (df *DynFlags) setFlagValue(parsedGroup *ParsedGroup, flagName string, flag *Flag, value string) error { + parsedValue, err := flag.value.Parse(value) + if err != nil { + return fmt.Errorf("failed to parse value for flag '%s': %v", flagName, err) + } + + if err := flag.value.Set(parsedValue); err != nil { + return fmt.Errorf("failed to set value for flag '%s': %v", flagName, err) + } + + parsedGroup.Values[flagName] = parsedValue + return nil +} + +// handleUnknownFlag handles unknown flags based on the parse behavior. +func (df *DynFlags) handleUnknownFlag(parentName, identifier, flagName, value string) error { + switch df.parseBehavior { + case ExitOnError: + return fmt.Errorf("unknown flag '%s' in group '%s'", flagName, parentName) + case ParseUnknown: + unknownGroup := df.createOrGetUnknownGroup(parentName, identifier) + unknownGroup.Values[flagName] = value + } + return nil +} + +// createOrGetParsedGroup retrieves or initializes a parsed group. +func (df *DynFlags) createOrGetParsedGroup(parentGroup *ConfigGroup, identifier string) *ParsedGroup { + for _, group := range df.parsedGroups[parentGroup.Name] { + if group.Name == identifier { + return group + } + } + + newGroup := &ParsedGroup{ + Parent: parentGroup, + Name: identifier, + Values: make(map[string]interface{}), + } + df.parsedGroups[parentGroup.Name] = append(df.parsedGroups[parentGroup.Name], newGroup) + return newGroup +} + +// createOrGetUnknownGroup retrieves or initializes an unknown group. +func (df *DynFlags) createOrGetUnknownGroup(parentName, identifier string) *UnknownGroup { + for _, group := range df.unknownGroups[parentName] { + if group.Name == identifier { + return group + } + } + + newGroup := &UnknownGroup{ + Name: identifier, + Values: make(map[string]interface{}), + } + df.unknownGroups[parentName] = append(df.unknownGroups[parentName], newGroup) + return newGroup +} diff --git a/pkg/dynflags/parser_test.go b/pkg/dynflags/parser_test.go new file mode 100644 index 0000000..51743ae --- /dev/null +++ b/pkg/dynflags/parser_test.go @@ -0,0 +1,223 @@ +package dynflags_test + +import ( + "testing" + "time" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestDynFlagsParse(t *testing.T) { + t.Parallel() + + t.Run("Parse valid arguments", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + group := df.Group("http") + group.String("method", "GET", "HTTP method to use") + group.String("url", "", "Target URL") + + args := []string{ + "--http.identifier1.method", "POST", + "--http.identifier1.url=https://example.com", + } + err := df.Parse(args) + assert.NoError(t, err) + + parsedGroups := df.Parsed() + httpGroup := parsedGroups.Lookup("http") + assert.NotNil(t, httpGroup) + + identifier1 := httpGroup.Lookup("identifier1") + assert.NotNil(t, identifier1) + assert.Equal(t, "POST", identifier1.Lookup("method")) + assert.Equal(t, "https://example.com", identifier1.Lookup("url")) + }) + + t.Run("Exit on missing key", func(t *testing.T) { + df := dynflags.New(dynflags.ExitOnError) + group := df.Group("http") + group.String("method", "GET", "HTTP method to use") + + args := []string{ + "-http.identifier1", "https://example.com", + } + err := df.Parse(args) + assert.Error(t, err) + }) + + t.Run("Parse with missing value", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + group := df.Group("http") + group.String("method", "GET", "HTTP method to use") + + args := []string{ + "--http.identifier1.method", + } + err := df.Parse(args) + assert.NoError(t, err) + + unparsedArgs := df.UnparsedArgs() + assert.Contains(t, unparsedArgs, "--http.identifier1.method") + }) + + t.Run("Parse with wrong value type and continue", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + group := df.Group("http") + group.Duration("timeout", 10*time.Second, "HTTP timeout") + + args := []string{ + "--http.identifier1.timeout", "1", + } + err := df.Parse(args) + assert.NoError(t, err) + + unparsedArgs := df.UnparsedArgs() + assert.Contains(t, unparsedArgs, "--http.identifier1.timeout") + }) + + t.Run("Parse with invalid flag format", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + + args := []string{ + "-invalidFlag", + } + err := df.Parse(args) + assert.NoError(t, err) + + unparsedArgs := df.UnparsedArgs() + assert.Contains(t, unparsedArgs, "-invalidFlag") + }) + + t.Run("Parse with no identifier and exit", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ExitOnError) + group1 := df.Group("http") + group1.Duration("timeout", 10*time.Second, "HTTP timeout") + + args := []string{ + "--http.duration", "10s", + } + err := df.Parse(args) + assert.Error(t, err) + }) + + t.Run("Parse with unknown group and continue on error", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ParseUnknown) + + args := []string{ + "--unknown.identifier1.flag1", "value1", + } + err := df.Parse(args) + assert.NoError(t, err) + + unknownGroups := df.Unknown() + unknownGroup := unknownGroups.Lookup("unknown") + assert.NotNil(t, unknownGroup) + + identifier1 := unknownGroup.Lookup("identifier1") + assert.NotNil(t, identifier1) + assert.Equal(t, "value1", identifier1.Lookup("flag1")) + }) + + t.Run("Parse with unknown group and parse unknown behavior", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ParseUnknown) + + args := []string{ + "--unknown.identifier1.flag1", "value1", + } + err := df.Parse(args) + assert.NoError(t, err) + + unknownGroups := df.Unknown() + unknownGroup := unknownGroups.Lookup("unknown") + assert.NotNil(t, unknownGroup) + + identifier1 := unknownGroup.Lookup("identifier1") + assert.NotNil(t, identifier1) + assert.Equal(t, "value1", identifier1.Lookup("flag1")) + }) + + t.Run("Parse with unknown group and exit on error", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ExitOnError) + + args := []string{ + "--unknown.identifier1.flag1", "value1", + } + err := df.Parse(args) + assert.Error(t, err) + assert.EqualError(t, err, "unknown flag 'flag1' in group 'unknown'") + }) + + t.Run("Handle invalid key format", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + + args := []string{ + "--invalidformat", + } + err := df.Parse(args) + assert.NoError(t, err) + + unparsedArgs := df.UnparsedArgs() + assert.Contains(t, unparsedArgs, "--invalidformat") + }) + + t.Run("Handle missing flag value", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ContinueOnError) + group := df.Group("http") + group.String("method", "GET", "HTTP method to use") + + args := []string{ + "--http.identifier1.method", + } + err := df.Parse(args) + assert.NoError(t, err) + + unparsedArgs := df.UnparsedArgs() + assert.Contains(t, unparsedArgs, "--http.identifier1.method") + }) + + t.Run("Same unknown group", func(t *testing.T) { + t.Parallel() + + df := dynflags.New(dynflags.ParseUnknown) + + args := []string{ + "--unknownGroup.identifier.flag=value1", + "--unknownGroup.identifier2.flag2=value2", + } + + err := df.Parse(args) + assert.NoError(t, err) + + unknownGroups := df.Unknown() + + group1 := unknownGroups.Lookup("unknownGroup").Lookup("identifier") + assert.NotNil(t, group1, "Expected group 'unknownGroup.identifier' to exist") + assert.Equal(t, "value1", group1.Values["flag"], "Expected 'flag' to have the correct value") + + group2 := unknownGroups.Lookup("unknownGroup").Lookup("identifier2") + assert.NotNil(t, group2, "Expected group 'unknownGroup.identifier2' to exist") + assert.Equal(t, "value2", group2.Values["flag2"], "Expected 'flag2' to have the correct value") + assert.Len(t, unknownGroups.Groups()["unknownGroup"], 2, "Expected exactly 2 identifiers in 'unknownGroup'") + }) +} diff --git a/pkg/dynflags/string.go b/pkg/dynflags/string.go new file mode 100644 index 0000000..c129c47 --- /dev/null +++ b/pkg/dynflags/string.go @@ -0,0 +1,55 @@ +package dynflags + +import "fmt" + +// StringValue implementation for string flags +type StringValue struct { + Bound *string +} + +func (s *StringValue) GetBound() interface{} { + if s.Bound == nil { + return nil + } + return *s.Bound +} + +func (s *StringValue) Parse(value string) (interface{}, error) { + return value, nil +} + +func (s *StringValue) Set(value interface{}) error { + if str, ok := value.(string); ok { + *s.Bound = str + return nil + } + return fmt.Errorf("invalid value type: expected string") +} + +// String defines a string flag with specified name, default value, and usage string. +// It returns the *Flag for further customization. +func (g *ConfigGroup) String(name, value, usage string) *Flag { + bound := &value + flag := &Flag{ + Type: FlagTypeString, + Default: value, + Usage: usage, + value: &StringValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetString returns the string value of a flag with the given name +func (pg *ParsedGroup) GetString(flagName string) (string, error) { + value, exists := pg.Values[flagName] + if !exists { + return "", fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if str, ok := value.(string); ok { + return str, nil + } + + return "", fmt.Errorf("flag '%s' is not a string", flagName) +} diff --git a/pkg/dynflags/string_slice.go b/pkg/dynflags/string_slice.go new file mode 100644 index 0000000..14fb68f --- /dev/null +++ b/pkg/dynflags/string_slice.go @@ -0,0 +1,63 @@ +package dynflags + +import ( + "fmt" + "strings" +) + +// StringSlicesValue implementation for string slice flags +type StringSlicesValue struct { + Bound *[]string +} + +func (s *StringSlicesValue) GetBound() interface{} { + if s.Bound == nil { + return nil + } + return *s.Bound +} + +func (s *StringSlicesValue) Parse(value string) (interface{}, error) { + return value, nil +} + +func (s *StringSlicesValue) Set(value interface{}) error { + if str, ok := value.(string); ok { + *s.Bound = append(*s.Bound, str) + return nil + } + return fmt.Errorf("invalid value type: expected string") +} + +// StringSlices defines a string slice flag with specified name, default value, and usage string. +// The return value is the address of a slice of strings that stores the value of the flag. +func (g *ConfigGroup) StringSlices(name string, value []string, usage string) *Flag { + bound := &value + flag := &Flag{ + Type: FlagTypeStringSlice, + Default: strings.Join(value, ","), + Usage: usage, + value: &StringSlicesValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetStringSlices returns the []string value of a flag with the given name +func (pg *ParsedGroup) GetStringSlices(flagName string) ([]string, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + + if strSlice, ok := value.([]string); ok { + return strSlice, nil + } + + if str, ok := value.(string); ok { + return []string{str}, nil + } + + return nil, fmt.Errorf("flag '%s' is not a []string", flagName) +} diff --git a/pkg/dynflags/string_slice_test.go b/pkg/dynflags/string_slice_test.go new file mode 100644 index 0000000..e9e72e4 --- /dev/null +++ b/pkg/dynflags/string_slice_test.go @@ -0,0 +1,169 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestStringSlicesValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid string slice value", func(t *testing.T) { + t.Parallel() + + stringSlicesValue := dynflags.StringSlicesValue{Bound: &[]string{}} + parsed, err := stringSlicesValue.Parse("example") + assert.NoError(t, err) + assert.Equal(t, "example", parsed) + }) + + t.Run("Set valid string slice value", func(t *testing.T) { + t.Parallel() + + bound := []string{"initial"} + stringSlicesValue := dynflags.StringSlicesValue{Bound: &bound} + + err := stringSlicesValue.Set("updated") + assert.NoError(t, err) + assert.Equal(t, []string{"initial", "updated"}, bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := []string{"initial"} + stringSlicesValue := dynflags.StringSlicesValue{Bound: &bound} + + err := stringSlicesValue.Set(123) // Invalid type + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected string") + }) + + t.Run("Multiple Occurrences Append Correctly", func(t *testing.T) { + t.Parallel() + + var bound []string + value := &dynflags.StringSlicesValue{Bound: &bound} + + assert.NoError(t, value.Set("Content-Type=application/json")) + assert.NoError(t, value.Set("MyHeader=header1")) + assert.NoError(t, value.Set("Header1=value1,Header2=value2")) + + assert.Equal(t, []string{ + "Content-Type=application/json", + "MyHeader=header1", + "Header1=value1,Header2=value2", + }, bound) + }) + + t.Run("Single Value Append", func(t *testing.T) { + t.Parallel() + + var bound []string + value := &dynflags.StringSlicesValue{Bound: &bound} + + assert.NoError(t, value.Set("Content-Type=application/json")) + assert.Equal(t, []string{"Content-Type=application/json"}, bound) + }) + + t.Run("Invalid Value Type", func(t *testing.T) { + t.Parallel() + + var bound []string + value := &dynflags.StringSlicesValue{Bound: &bound} + + err := value.Set(123) // Invalid type + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid value type") + }) +} + +func TestGroupConfigStringSlices(t *testing.T) { + t.Parallel() + + t.Run("Define string slices flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := []string{"default1", "default2"} + group.StringSlices("stringSliceFlag", defaultValue, "A string slices flag") + + assert.Contains(t, group.Flags, "stringSliceFlag") + assert.Equal(t, "A string slices flag", group.Flags["stringSliceFlag"].Usage) + assert.Equal(t, "default1,default2", group.Flags["stringSliceFlag"].Default) + }) +} + +func TestGetStringSlices(t *testing.T) { + t.Parallel() + + t.Run("Retrieve []string value", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": []string{"value1", "value2"}}, + } + + result, err := parsedGroup.GetStringSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []string{"value1", "value2"}, result) + }) + + t.Run("Retrieve single string value as []string", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": "singleValue"}, + } + + result, err := parsedGroup.GetStringSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []string{"singleValue"}, result) + }) + + t.Run("Flag not found", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + result, err := parsedGroup.GetStringSlices("nonExistentFlag") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'nonExistentFlag' not found in group 'testGroup'") + }) + + t.Run("Flag value is invalid type", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": 123}, // Invalid type (int) + } + + result, err := parsedGroup.GetStringSlices("flag1") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'flag1' is not a []string") + }) +} + +func TestGetStringSlicesGetBound(t *testing.T) { + t.Run("StringSlicesValue - GetBound", func(t *testing.T) { + var slices *[]string + val := []string{"a", "b", "c"} + slices = &val + + stringSlicesValue := dynflags.StringSlicesValue{Bound: slices} + assert.Equal(t, val, stringSlicesValue.GetBound()) + + stringSlicesValue = dynflags.StringSlicesValue{Bound: nil} + assert.Nil(t, stringSlicesValue.GetBound()) + }) +} diff --git a/pkg/dynflags/string_test.go b/pkg/dynflags/string_test.go new file mode 100644 index 0000000..b5929e6 --- /dev/null +++ b/pkg/dynflags/string_test.go @@ -0,0 +1,116 @@ +package dynflags_test + +import ( + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestStringValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid string", func(t *testing.T) { + t.Parallel() + + stringValue := dynflags.StringValue{} + parsed, err := stringValue.Parse("example") + assert.NoError(t, err) + assert.Equal(t, "example", parsed) + }) + + t.Run("Set valid string", func(t *testing.T) { + t.Parallel() + + bound := "initial" + stringValue := dynflags.StringValue{Bound: &bound} + + err := stringValue.Set("updated") + assert.NoError(t, err) + assert.Equal(t, "updated", bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := "initial" + stringValue := dynflags.StringValue{Bound: &bound} + + err := stringValue.Set(123) // Invalid type + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected string") + }) +} + +func TestGroupConfigString(t *testing.T) { + t.Parallel() + + t.Run("Define string flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := "default" + group.String("stringFlag", defaultValue, "A string flag") + + assert.Contains(t, group.Flags, "stringFlag") + assert.Equal(t, "A string flag", group.Flags["stringFlag"].Usage) + assert.Equal(t, defaultValue, group.Flags["stringFlag"].Default) + }) +} + +func TestParsedGroupGetString(t *testing.T) { + t.Parallel() + + t.Run("Get existing string flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "stringFlag": "value", + }, + } + str, err := group.GetString("stringFlag") + assert.NoError(t, err) + assert.Equal(t, "value", str) + }) + + t.Run("Get non-existent string flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{}, + } + str, err := group.GetString("stringFlag") + assert.Error(t, err) + assert.Equal(t, "", str) + assert.EqualError(t, err, "flag 'stringFlag' not found in group ''") + }) + + t.Run("Get string flag with invalid type", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "stringFlag": 123, // Invalid type + }, + } + str, err := group.GetString("stringFlag") + assert.Error(t, err) + assert.Equal(t, "", str) + assert.EqualError(t, err, "flag 'stringFlag' is not a string") + }) +} + +func TestStringGetBound(t *testing.T) { + t.Run("StringValue - GetBound", func(t *testing.T) { + var str *string + value := "test" + str = &value + + stringValue := dynflags.StringValue{Bound: str} + assert.Equal(t, "test", stringValue.GetBound()) + + stringValue = dynflags.StringValue{Bound: nil} + assert.Nil(t, stringValue.GetBound()) + }) +} diff --git a/pkg/dynflags/url.go b/pkg/dynflags/url.go new file mode 100644 index 0000000..af49ee7 --- /dev/null +++ b/pkg/dynflags/url.go @@ -0,0 +1,65 @@ +package dynflags + +import ( + "fmt" + "net/url" +) + +// URLValue implementation for URL flags +type URLValue struct { + Bound *url.URL +} + +func (u *URLValue) GetBound() interface{} { + if u.Bound == nil { + return nil + } + return *u.Bound +} + +func (u *URLValue) Parse(value string) (interface{}, error) { + return url.Parse(value) +} + +func (u *URLValue) Set(value interface{}) error { + if parsedURL, ok := value.(*url.URL); ok { + *u.Bound = *parsedURL + return nil + } + return fmt.Errorf("invalid value type: expected URL") +} + +// URL defines a URL flag with specified name, default value, and usage string. +// The return value is the address of a url.URL variable that stores the value of the flag. +func (g *ConfigGroup) URL(name, value, usage string) *Flag { + bound := new(url.URL) + if value != "" { + parsed, err := url.Parse(value) + if err != nil { + panic(fmt.Sprintf("invalid default URL for flag '%s': %s", name, err)) + } + *bound = *parsed // Copy the parsed URL into bound + } + flag := &Flag{ + Type: FlagTypeURL, + Default: value, + Usage: usage, + value: &URLValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetURL returns the url.URL value of a flag with the given name +func (pg *ParsedGroup) GetURL(flagName string) (*url.URL, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + if url, ok := value.(url.URL); ok { + return &url, nil + } + + return nil, fmt.Errorf("flag '%s' is not a URL", flagName) +} diff --git a/pkg/dynflags/url_slice.go b/pkg/dynflags/url_slice.go new file mode 100644 index 0000000..a3fbf48 --- /dev/null +++ b/pkg/dynflags/url_slice.go @@ -0,0 +1,73 @@ +package dynflags + +import ( + "fmt" + "net/url" + "strings" +) + +// URLSlicesValue implementation for URL slice flags +type URLSlicesValue struct { + Bound *[]*url.URL +} + +func (u *URLSlicesValue) GetBound() interface{} { + if u.Bound == nil { + return nil + } + return *u.Bound +} + +func (u *URLSlicesValue) Parse(value string) (interface{}, error) { + parsedURL, err := url.Parse(value) + if err != nil { + return nil, fmt.Errorf("invalid URL: %s, error: %w", value, err) + } + return parsedURL, nil +} + +func (u *URLSlicesValue) Set(value interface{}) error { + if parsedURL, ok := value.(*url.URL); ok { + *u.Bound = append(*u.Bound, parsedURL) + return nil + } + return fmt.Errorf("invalid value type: expected *url.URL") +} + +// URLSlices defines a URL slice flag with specified name, default value, and usage string. +// The return value is the address of a slice of URLs that stores the value of the flag. +func (g *ConfigGroup) URLSlices(name string, value []*url.URL, usage string) *Flag { + bound := &value + defaultValue := make([]string, len(value)) + for i, u := range value { + defaultValue[i] = u.String() + } + + flag := &Flag{ + Type: FlagTypeURLSlice, + Default: strings.Join(defaultValue, ","), + Usage: usage, + value: &URLSlicesValue{Bound: bound}, + } + g.Flags[name] = flag + g.flagOrder = append(g.flagOrder, name) + return flag +} + +// GetURLSlices returns the []*url.URL value of a flag with the given name +func (pg *ParsedGroup) GetURLSlices(flagName string) ([]*url.URL, error) { + value, exists := pg.Values[flagName] + if !exists { + return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name) + } + + if urlSlice, ok := value.([]*url.URL); ok { + return urlSlice, nil + } + + if u, ok := value.(*url.URL); ok { + return []*url.URL{u}, nil + } + + return nil, fmt.Errorf("flag '%s' is not a []*url.URL", flagName) +} diff --git a/pkg/dynflags/url_slice_test.go b/pkg/dynflags/url_slice_test.go new file mode 100644 index 0000000..3b9727c --- /dev/null +++ b/pkg/dynflags/url_slice_test.go @@ -0,0 +1,154 @@ +package dynflags_test + +import ( + "net/url" + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestURLSlicesValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid URL", func(t *testing.T) { + t.Parallel() + + urlSlicesValue := dynflags.URLSlicesValue{Bound: &[]*url.URL{}} + parsed, err := urlSlicesValue.Parse("https://example.com") + assert.NoError(t, err) + assert.Equal(t, "https://example.com", parsed.(*url.URL).String()) + }) + + t.Run("Parse invalid URL", func(t *testing.T) { + t.Parallel() + + urlSlicesValue := dynflags.URLSlicesValue{Bound: &[]*url.URL{}} + parsed, err := urlSlicesValue.Parse("://invalid-url") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid URL", func(t *testing.T) { + t.Parallel() + + bound := []*url.URL{{Scheme: "https", Host: "example.com"}} + urlSlicesValue := dynflags.URLSlicesValue{Bound: &bound} + + err := urlSlicesValue.Set(&url.URL{Scheme: "http", Host: "localhost"}) + assert.NoError(t, err) + assert.Equal(t, []*url.URL{ + {Scheme: "https", Host: "example.com"}, + {Scheme: "http", Host: "localhost"}, + }, bound) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := []*url.URL{} + urlSlicesValue := dynflags.URLSlicesValue{Bound: &bound} + + err := urlSlicesValue.Set("invalid-type") + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected *url.URL") + }) +} + +func TestGroupConfigURLSlices(t *testing.T) { + t.Parallel() + + t.Run("Define URL slices flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := []*url.URL{ + {Scheme: "https", Host: "example.com"}, + {Scheme: "http", Host: "localhost"}, + } + group.URLSlices("urlSliceFlag", defaultValue, "A URL slices flag") + + assert.Contains(t, group.Flags, "urlSliceFlag") + assert.Equal(t, "A URL slices flag", group.Flags["urlSliceFlag"].Usage) + assert.Equal(t, "https://example.com,http://localhost", group.Flags["urlSliceFlag"].Default) + }) +} + +func TestGetURLSlices(t *testing.T) { + t.Parallel() + + t.Run("Retrieve []*url.URL value", func(t *testing.T) { + t.Parallel() + + parsedURL1, _ := url.Parse("https://example.com") + parsedURL2, _ := url.Parse("https://example.org") + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": []*url.URL{parsedURL1, parsedURL2}}, + } + + result, err := parsedGroup.GetURLSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []*url.URL{parsedURL1, parsedURL2}, result) + }) + + t.Run("Retrieve single *url.URL value as []*url.URL", func(t *testing.T) { + t.Parallel() + + parsedURL, _ := url.Parse("https://example.com") + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": parsedURL}, + } + + result, err := parsedGroup.GetURLSlices("flag1") + assert.NoError(t, err) + assert.Equal(t, []*url.URL{parsedURL}, result) + }) + + t.Run("Flag not found", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{}, + } + + result, err := parsedGroup.GetURLSlices("nonExistentFlag") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'nonExistentFlag' not found in group 'testGroup'") + }) + + t.Run("Flag value is invalid type", func(t *testing.T) { + t.Parallel() + + parsedGroup := &dynflags.ParsedGroup{ + Name: "testGroup", + Values: map[string]interface{}{"flag1": 123}, // Invalid type (int) + } + + result, err := parsedGroup.GetURLSlices("flag1") + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "flag 'flag1' is not a []*url.URL") + }) +} + +func TestURLSlicesGetBound(t *testing.T) { + t.Run("URLSlicesValue - GetBound", func(t *testing.T) { + var slices *[]*url.URL + u1, _ := url.Parse("http://example.com") + u2, _ := url.Parse("http://example.org") + val := []*url.URL{u1, u2} + slices = &val + + urlSlicesValue := dynflags.URLSlicesValue{Bound: slices} + assert.Equal(t, val, urlSlicesValue.GetBound()) + + urlSlicesValue = dynflags.URLSlicesValue{Bound: nil} + assert.Nil(t, urlSlicesValue.GetBound()) + }) +} diff --git a/pkg/dynflags/url_test.go b/pkg/dynflags/url_test.go new file mode 100644 index 0000000..7215d83 --- /dev/null +++ b/pkg/dynflags/url_test.go @@ -0,0 +1,146 @@ +package dynflags_test + +import ( + "net/url" + "testing" + + "github.com/containeroo/portpatrol/pkg/dynflags" + "github.com/stretchr/testify/assert" +) + +func TestURLValue(t *testing.T) { + t.Parallel() + + t.Run("Parse valid URL", func(t *testing.T) { + t.Parallel() + + urlValue := dynflags.URLValue{} + parsed, err := urlValue.Parse("https://example.com") + assert.NoError(t, err) + assert.NotNil(t, parsed) + + parsedURL, ok := parsed.(*url.URL) + assert.True(t, ok) + assert.Equal(t, "https://example.com", parsedURL.String()) + }) + + t.Run("Parse invalid URL", func(t *testing.T) { + t.Parallel() + + urlValue := dynflags.URLValue{} + parsed, err := urlValue.Parse("https://invalid-url^") + assert.Error(t, err) + assert.Nil(t, parsed) + }) + + t.Run("Set valid URL", func(t *testing.T) { + t.Parallel() + + bound := &url.URL{} + urlValue := dynflags.URLValue{Bound: bound} + + parsedURL, _ := url.Parse("https://example.com") + err := urlValue.Set(parsedURL) + assert.NoError(t, err) + assert.Equal(t, "https://example.com", bound.String()) + }) + + t.Run("Set invalid type", func(t *testing.T) { + t.Parallel() + + bound := &url.URL{} + urlValue := dynflags.URLValue{Bound: bound} + + err := urlValue.Set("not-a-url") // Invalid type + assert.Error(t, err) + assert.EqualError(t, err, "invalid value type: expected URL") + }) +} + +func TestGroupConfigURL(t *testing.T) { + t.Parallel() + + t.Run("Define URL flag with default value", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + defaultValue := "https://default.com" + urlFlag := group.URL("urlFlag", defaultValue, "A URL flag") + + assert.Equal(t, "https://default.com", urlFlag.Default) + assert.Contains(t, group.Flags, "urlFlag") + assert.Equal(t, "A URL flag", group.Flags["urlFlag"].Usage) + assert.Equal(t, defaultValue, group.Flags["urlFlag"].Default) + }) + + t.Run("Define URL flag with invalid default", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)} + + assert.PanicsWithValue(t, + "invalid default URL for flag 'urlFlag': parse \"http://i nvalid-url\": invalid character \" \" in host name", + func() { + group.URL("urlFlag", "http://i nvalid-url", "Invalid URL flag") + }) + }) +} + +func TestParsedGroupGetURL(t *testing.T) { + t.Parallel() + + t.Run("Get existing URL flag", func(t *testing.T) { + t.Parallel() + + parsedURL, _ := url.Parse("https://example.com") + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "urlFlag": *parsedURL, + }, + } + retrievedURL, err := group.GetURL("urlFlag") + assert.NoError(t, err) + assert.NotNil(t, retrievedURL) + assert.Equal(t, "https://example.com", retrievedURL.String()) + }) + + t.Run("Get non-existent URL flag", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{}, + } + retrievedURL, err := group.GetURL("urlFlag") + assert.Error(t, err) + assert.Nil(t, retrievedURL) + assert.EqualError(t, err, "flag 'urlFlag' not found in group ''") + }) + + t.Run("Get URL flag with invalid type", func(t *testing.T) { + t.Parallel() + + group := &dynflags.ParsedGroup{ + Values: map[string]interface{}{ + "urlFlag": "not-a-url", // Invalid type + }, + } + retrievedURL, err := group.GetURL("urlFlag") + assert.Error(t, err) + assert.Nil(t, retrievedURL) + assert.EqualError(t, err, "flag 'urlFlag' is not a URL") + }) +} + +func TestURLGetBound(t *testing.T) { + t.Run("URLValue - GetBound", func(t *testing.T) { + var u *url.URL + val, _ := url.Parse("http://example.com") + u = val + + urlValue := dynflags.URLValue{Bound: u} + assert.Equal(t, *val, urlValue.GetBound()) + + urlValue = dynflags.URLValue{Bound: nil} + assert.Nil(t, urlValue.GetBound()) + }) +} diff --git a/pkg/httputils/README.md b/pkg/httputils/README.md new file mode 100644 index 0000000..00589f5 --- /dev/null +++ b/pkg/httputils/README.md @@ -0,0 +1,110 @@ +# httputils + +The `httputils` package provides utility functions for parsing HTTP headers and status codes from strings. These functions are designed to facilitate working with HTTP-related configurations that are passed as strings, such as environment variables or configuration files. + +## Features + +- Parse HTTP status codes and ranges from a string. +- Parse HTTP headers into a key-value map. +- Support for handling duplicate headers. + +## Installation + +To use the `httputils` package, add it to your Go project: + +```sh +go get github.com/containerish/portpatrol/pkg/httputils +``` + +## Usage + +### ParseStatusCodes + +Parses a comma-separated string of HTTP status codes and ranges into a slice of integers. + +__Example:__ + +```go +package main + +import ( + "fmt" + "log" + "httputils" +) + +func main() { + statusString := "200,300-302,404" + statusCodes, err := httputils.ParseStatusCodes(statusString) + if err != nil { + log.Fatalf("Error parsing status codes: %v", err) + } + fmt.Println("Parsed Status Codes:", statusCodes) +} +``` + +__Parameters:__ + +- `statusRanges` (string): Comma-separated string of single status codes (e.g., `200`) and/or ranges (e.g., `200-204`). + +__Returns:__ + +- `[]int`: A slice of status codes. +- `error`: An error if the parsing fails. + +__Output:__ + +```bash +Parsed Status Codes: [200 300 301 302 404] +``` + +## ParseHeaders + +Parses a comma-separated string of HTTP headers into a key-value map. + +__Example:__ + +```go +package main + +import ( + "fmt" + "log" + "httputils" +) + +func main() { + headerString := "Content-Type=application/json,Authorization=Bearer token,X-Custom-Header=" + headers, err := httputils.ParseHeaders(headerString, false) + if err != nil { + log.Fatalf("Error parsing headers: %v", err) + } + fmt.Println("Parsed Headers:", headers) +} +``` + +__Parameters:__ + +- `headers` (string): Comma-separated string of headers in `Key=Value` format. Keys must not be empty. +- `allowDuplicates` (bool): If `true`, overrides previous values for duplicate keys. If `false`, returns an error on duplicate keys. + +__Returns:__ + +- `map[string]string`: A map of header names to values. +- `error`: An error if the parsing fails. + +__Output:__ + +```bash +Parsed Headers: map[Content-Type:application/json Authorization:Bearer token X-Custom-Header:] + +``` + +## Error Handling + +Both `ParseStatusCodes` and `ParseHeaders` return descriptive errors for invalid input, such as: + +- Invalid HTTP status codes or ranges. +- Empty or malformed header keys. +- Duplicate header keys (when `allowDuplicates` is `false`). + diff --git a/pkg/httputils/http_parsing_test.go b/pkg/httputils/http_parsing_test.go index bf8c3d3..e880da7 100644 --- a/pkg/httputils/http_parsing_test.go +++ b/pkg/httputils/http_parsing_test.go @@ -1,8 +1,9 @@ package httputils import ( - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestParseHTTPHeaders(t *testing.T) { @@ -13,14 +14,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "Content-Type=application/json,Auportpatrolization=Bearer token" result, err := ParseHeaders(headers, true) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := map[string]string{"Content-Type": "application/json", "Auportpatrolization": "Bearer token"} - if !reflect.DeepEqual(result, expected) { - t.Errorf("Expected result: %q, got: %q", expected, result) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual(map[string]string{"Content-Type": "application/json", "Auportpatrolization": "Bearer token"}, result) }) t.Run("Single header", func(t *testing.T) { @@ -28,14 +24,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "Content-Type=application/json" result, err := ParseHeaders(headers, true) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := map[string]string{"Content-Type": "application/json"} - if !reflect.DeepEqual(result, expected) { - t.Errorf("Expected result: %q, got: %q", expected, result) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual(map[string]string{"Content-Type": "application/json"}, result) }) t.Run("Empty headers string", func(t *testing.T) { @@ -43,14 +34,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "" result, err := ParseHeaders(headers, true) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := map[string]string{} - if !reflect.DeepEqual(result, expected) { - t.Errorf("Expected result: %q, got: %q", expected, result) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual(map[string]string{}, result) }) t.Run("Malformed header (missing =)", func(t *testing.T) { @@ -58,14 +44,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "Content-Type=application/json,AuportpatrolizationBearer token" _, err := ParseHeaders(headers, true) - if err == nil { - t.Error("Expected error, got nil") - } - - expected := "invalid header format: AuportpatrolizationBearer token" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + + assert.Error(t, err) + assert.EqualError(t, err, "invalid header format: AuportpatrolizationBearer token") }) t.Run("Header with spaces", func(t *testing.T) { @@ -73,14 +54,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := " Content-Type = application/json , Auportpatrolization = Bearer token " result, err := ParseHeaders(headers, true) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := map[string]string{"Content-Type": "application/json", "Auportpatrolization": "Bearer token"} - if !reflect.DeepEqual(result, expected) { - t.Errorf("Expected result: %q, got: %q", expected, result) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual(map[string]string{"Content-Type": "application/json", "Auportpatrolization": "Bearer token"}, result) }) t.Run("Header with empty key", func(t *testing.T) { @@ -88,14 +64,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "=value" _, err := ParseHeaders(headers, true) - if err == nil { - t.Error("Expected error, got nil") - } - - expected := "header key cannot be empty: =value" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + + assert.Error(t, err) + assert.EqualError(t, err, "header key cannot be empty: =value") }) t.Run("Header with empty value", func(t *testing.T) { @@ -103,14 +74,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "key=" result, err := ParseHeaders(headers, true) - if err != nil { - t.Errorf("Unexpected error: %q", err) - } - - expected := map[string]string{"key": ""} - if !reflect.DeepEqual(result, expected) { - t.Errorf("Expected result: %q, got: %q", expected, result) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual(map[string]string{"key": ""}, result) }) t.Run("Trailing comma", func(t *testing.T) { @@ -118,14 +84,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "Content-Type=application/json," result, err := ParseHeaders(headers, true) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - expected := map[string]string{"Content-Type": "application/json"} - if !reflect.DeepEqual(result, expected) { - t.Errorf("expected %v, got %v", expected, result) - } + assert.NoError(t, err) + + assert.ObjectsAreEqual(map[string]string{"Content-Type": "application/json"}, result) }) t.Run("Valid header with duplicate headers (allowDuplicates=true)", func(t *testing.T) { @@ -133,15 +94,8 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "Content-Type=application/json,Content-Type=application/json" h, err := ParseHeaders(headers, true) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - expected := map[string]string{"Content-Type": "application/json"} - - if !reflect.DeepEqual(h, expected) { - t.Fatalf("expected %v, got %v", expected, h) - } + assert.NoError(t, err) + assert.ObjectsAreEqual(map[string]string{"Content-Type": "application/json"}, h) }) t.Run("Invalid header with duplicate headers (allowDuplicates=false)", func(t *testing.T) { @@ -149,14 +103,9 @@ func TestParseHTTPHeaders(t *testing.T) { headers := "Content-Type=application/json,Content-Type=application/json" _, err := ParseHeaders(headers, false) - if err == nil { - t.Fatalf("expected an error, got none") - } - - expected := "duplicate header key found: Content-Type" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + + assert.Error(t, err) + assert.EqualError(t, err, "duplicate header key found: Content-Type") }) } @@ -167,97 +116,62 @@ func TestParseHTTPStatusCodes(t *testing.T) { t.Parallel() statuses, err := ParseStatusCodes("200") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := []int{200} - if !reflect.DeepEqual(statuses, expected) { - t.Fatalf("expected %q, got %q", expected, statuses) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual([]int{200}, statuses) }) t.Run("Valid multiple status codes", func(t *testing.T) { t.Parallel() statuses, err := ParseStatusCodes("200,404,500") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := []int{200, 404, 500} - if !reflect.DeepEqual(statuses, expected) { - t.Fatalf("expected %q, got %q", expected, statuses) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual([]int{200, 404, 500}, statuses) }) t.Run("Valid status code range", func(t *testing.T) { t.Parallel() statuses, err := ParseStatusCodes("200-202") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := []int{200, 201, 202} - if !reflect.DeepEqual(statuses, expected) { - t.Fatalf("expected %q, got %q", expected, statuses) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual([]int{200, 201, 202}, statuses) }) t.Run("Valid multiple status code ranges", func(t *testing.T) { t.Parallel() statuses, err := ParseStatusCodes("200-202,300-301,500") - if err != nil { - t.Fatalf("expected no error, got %q", err) - } - - expected := []int{200, 201, 202, 300, 301, 500} - if !reflect.DeepEqual(statuses, expected) { - t.Fatalf("expected %q, got %q", expected, statuses) - } + + assert.NoError(t, err) + assert.ObjectsAreEqual([]int{200, 201, 202, 300, 301, 500}, statuses) }) t.Run("Invalid status code", func(t *testing.T) { t.Parallel() _, err := ParseStatusCodes("abc") - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "invalid status code: abc" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + + assert.Error(t, err) + assert.EqualError(t, err, "invalid status code: abc") }) t.Run("Invalid status range double dash", func(t *testing.T) { t.Parallel() _, err := ParseStatusCodes("200--202") - if err == nil { - t.Fatal("expected an error, got none") - } - - expected := "invalid status range: 200--202" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + + assert.Error(t, err) + assert.EqualError(t, err, "invalid status range: 200--202") }) t.Run("Invalid status range (start > end)", func(t *testing.T) { t.Parallel() - _, err := ParseStatusCodes("202-200") - if err == nil { - t.Fatal("expected an error, got none") - } + _, err := ParseStatusCodes("201-200") - expected := "invalid status range: 202-200" - if err.Error() != expected { - t.Fatalf("expected error containing %q, got %q", expected, err) - } + assert.Error(t, err) + assert.EqualError(t, err, "invalid status range: 201-200") }) } diff --git a/pkg/resolver/resolver.go b/pkg/resolver/resolver.go new file mode 100644 index 0000000..f0128ee --- /dev/null +++ b/pkg/resolver/resolver.go @@ -0,0 +1,117 @@ +package resolver + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" +) + +// Constants for variable resolution. +const ( + envPrefix = "env:" // Prefix to identify environment variable references + filePrefix = "file:" // Prefix to identify file references + keyDelim = "//" // Delimiter to identify a key in a file +) + +// ResolveVariable resolves a string value based on its prefix. +// +// - "env:": Treated as an environment variable and resolved accordingly. +// - "file:": Treated as a file path, optionally followed by a key to retrieve a specific line in "key = value" format. +// - No prefix: The string is returned as is. +// +// Parameters: +// - value: The string to resolve. +// +// Returns: +// - The resolved value of the input string. +// - An error if the resolution fails. +func ResolveVariable(value string) (string, error) { + if strings.HasPrefix(value, envPrefix) { + return resolveEnvVariable(value[len(envPrefix):]) + } + + if strings.HasPrefix(value, filePrefix) { + return resolveFileVariable(value[len(filePrefix):]) + } + + return value, nil +} + +// resolveEnvVariable resolves an environment variable. +// +// Parameters: +// - envVar: The name of the environment variable to resolve. +// +// Returns: +// - The value of the environment variable. +// - An error if the environment variable is not found. +func resolveEnvVariable(envVar string) (string, error) { + resolvedVariable, found := os.LookupEnv(envVar) + if !found { + return "", fmt.Errorf("environment variable '%s' not found.", envVar) + } + return resolvedVariable, nil +} + +// resolveFileVariable resolves a file path with an optional key. +// The key should be in the format "key = value". +// +// Parameters: +// - filePathWithKey: The string containing the file path and optional key. +// +// Returns: +// - The resolved value based on the file and optional key. +// - An error if resolving the file or key fails. +func resolveFileVariable(filePathWithKey string) (string, error) { + lastSeparatorIndex := strings.LastIndex(filePathWithKey, keyDelim) + filePath := filePathWithKey // default filePath (whole value) + key := "" // default key (no key) + + // Check for key specification + if lastSeparatorIndex != -1 { + filePath = filePathWithKey[:lastSeparatorIndex] + key = filePathWithKey[lastSeparatorIndex+len(keyDelim):] + } + + filePath = os.ExpandEnv(filePath) + file, err := os.Open(filePath) + if err != nil { + return "", fmt.Errorf("Failed to open file '%s'. %v", filePath, err) + } + defer file.Close() + + if key != "" { + return searchKeyInFile(file, key) + } + + // No key specified, read the whole file + data, err := io.ReadAll(file) + if err != nil { + return "", fmt.Errorf("Failed to read file '%s'. %v", filePath, err) + } + return strings.TrimSpace(string(data)), nil +} + +// searchKeyInFile searches for a specified key in a file and returns its associated value. +// The key should be in the format "key = value". +// +// Parameters: +// - file: The opened file to search for the key. +// - key: The key to search for in the file. +// +// Returns: +// - The value associated with the specified key. +// - An error if the key is not found in the file. +func searchKeyInFile(file *os.File, key string) (string, error) { + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + pair := strings.SplitN(line, "=", 2) + if len(pair) == 2 && strings.TrimSpace(pair[0]) == key { + return strings.TrimSpace(pair[1]), nil + } + } + return "", fmt.Errorf("Key '%s' not found in file '%s'.", key, file.Name()) +} diff --git a/pkg/resolver/resolver_test.go b/pkg/resolver/resolver_test.go new file mode 100644 index 0000000..46bd42e --- /dev/null +++ b/pkg/resolver/resolver_test.go @@ -0,0 +1,78 @@ +package resolver + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResolveVariable(t *testing.T) { + t.Run("Resolve environment variable", func(t *testing.T) { + os.Setenv("TEST_ENV_VAR", "test_value") + defer os.Unsetenv("TEST_ENV_VAR") + + result, err := ResolveVariable("env:TEST_ENV_VAR") + assert.NoError(t, err) + assert.Equal(t, "test_value", result) + }) + + t.Run("Resolve non-existing environment variable", func(t *testing.T) { + result, err := ResolveVariable("env:NON_EXISTING_ENV_VAR") + assert.Error(t, err) + assert.Equal(t, "", result) + assert.Contains(t, err.Error(), "environment variable 'NON_EXISTING_ENV_VAR' not found") + }) + + t.Run("Resolve file variable", func(t *testing.T) { + fileContent := "key1=value1\nkey2=value2\n" + file, err := os.CreateTemp("", "testfile") + assert.NoError(t, err) + defer os.Remove(file.Name()) + + _, err = file.WriteString(fileContent) + assert.NoError(t, err) + file.Close() + + result, err := ResolveVariable("file:" + file.Name()) + assert.NoError(t, err) + assert.Equal(t, fileContent, result+"\n") + }) + + t.Run("Resolve file with key", func(t *testing.T) { + fileContent := "key1=value1\nkey2=value2\n" + file, err := os.CreateTemp("", "testfile") + assert.NoError(t, err) + defer os.Remove(file.Name()) + + _, err = file.WriteString(fileContent) + assert.NoError(t, err) + file.Close() + + result, err := ResolveVariable("file:" + file.Name() + "//key1") + assert.NoError(t, err) + assert.Equal(t, "value1", result) + }) + + t.Run("Resolve file with non-existing key", func(t *testing.T) { + fileContent := "key1=value1\nkey2=value2\n" + file, err := os.CreateTemp("", "testfile") + assert.NoError(t, err) + defer os.Remove(file.Name()) + + _, err = file.WriteString(fileContent) + assert.NoError(t, err) + file.Close() + + result, err := ResolveVariable("file:" + file.Name() + "//non_existing_key") + assert.Error(t, err) + assert.Equal(t, "", result) + assert.Contains(t, err.Error(), "Key 'non_existing_key' not found in file") + }) + + t.Run("Resolve plain string", func(t *testing.T) { + result, err := ResolveVariable("plain_string") + assert.NoError(t, err) + assert.Equal(t, "plain_string", result) + }) +}