diff --git a/cmd/toast/main.go b/cmd/toast/main.go index 72f4962..3f6ff6d 100644 --- a/cmd/toast/main.go +++ b/cmd/toast/main.go @@ -14,7 +14,7 @@ import ( "github.com/containeroo/toast/pkg/runner" ) -const version = "0.0.3" +const version = "0.0.4" // run is the main function of the application func run(ctx context.Context, getenv func(string) string, output io.Writer) error { diff --git a/pkg/checker/checker.go b/pkg/checker/checker.go index fc262f4..cb13369 100644 --- a/pkg/checker/checker.go +++ b/pkg/checker/checker.go @@ -24,6 +24,11 @@ func NewChecker(checkType, name, address string, timeout time.Duration, getEnv f } } +// IsValidCheckType validates if the check type is supported. +func IsValidCheckType(checkType string) bool { + return checkType == "tcp" || checkType == "http" +} + // InferCheckType infers the check type based on the scheme of the target address. func InferCheckType(address string) (string, error) { scheme, _ := extractScheme(address) diff --git a/pkg/checker/checker_test.go b/pkg/checker/checker_test.go index 9480f61..7ea5ede 100644 --- a/pkg/checker/checker_test.go +++ b/pkg/checker/checker_test.go @@ -49,6 +49,45 @@ func TestNewChecker(t *testing.T) { }) } +func TestIsValidCheckType(t *testing.T) { + t.Parallel() + + t.Run("Valid TCP Check Type", func(t *testing.T) { + t.Parallel() + if isValid := IsValidCheckType("tcp"); !isValid { + t.Errorf("expected true for check type 'tcp', got false") + } + }) + + t.Run("Valid HTTP Check Type", func(t *testing.T) { + t.Parallel() + if isValid := IsValidCheckType("http"); !isValid { + t.Errorf("expected true for check type 'http', got false") + } + }) + + t.Run("Invalid Check Type", func(t *testing.T) { + t.Parallel() + if isValid := IsValidCheckType("invalid"); isValid { + t.Errorf("expected false for check type 'invalid', got true") + } + }) + + t.Run("Empty Check Type", func(t *testing.T) { + t.Parallel() + if isValid := IsValidCheckType(""); isValid { + t.Errorf("expected false for empty check type, got true") + } + }) + + t.Run("Random String Check Type", func(t *testing.T) { + t.Parallel() + if isValid := IsValidCheckType("random"); isValid { + t.Errorf("expected false for check type 'random', got true") + } + }) +} + func TestInferCheckType(t *testing.T) { t.Parallel() diff --git a/pkg/checker/http_checker_test.go b/pkg/checker/http_checker_test.go index 9626a61..b84b241 100644 --- a/pkg/checker/http_checker_test.go +++ b/pkg/checker/http_checker_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" "time" ) @@ -76,4 +77,42 @@ func TestHTTPChecker(t *testing.T) { t.Fatal("expected an error, got none") } }) + + t.Run("Test cancel HTTP check", func(t *testing.T) { + t.Parallel() + + // Set up a test HTTP server that deliberately delays the response + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) // Delay to ensure the context has time to be canceled + w.WriteHeader(http.StatusOK) + }) + server := httptest.NewServer(handler) + defer server.Close() + + // Mock environment variables + mockEnv := func(key string) string { + env := map[string]string{ + envMethod: "GET", + envHeaders: "Authorization=Bearer token", + envExpectedStatuses: "200", + } + return env[key] + } + + // Create the HTTP checker using the mock environment variables + checker, err := NewHTTPChecker("example", server.URL, 5*time.Second, mockEnv) + if err != nil { + t.Fatalf("failed to create HTTPChecker: %v", err) + } + + // Cancel the context after a very short time + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + // Perform the check, expecting a context canceled error + err = checker.Check(ctx) + if err == nil || !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) { + t.Errorf("expected context canceled or deadline exceeded error, got %v", err) + } + }) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 77722e6..f42b1a9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -102,14 +102,9 @@ func ParseConfig(getenv func(string) string) (Config, error) { cfg.CheckType = checkType } - if !isValidCheckType(cfg.CheckType) { + if !checker.IsValidCheckType(cfg.CheckType) { return Config{}, fmt.Errorf("unsupported check type: %s", cfg.CheckType) } return cfg, nil } - -// isValidCheckType validates if the check type is supported. -func isValidCheckType(checkType string) bool { - return checkType == "tcp" || checkType == "http" -}