diff --git a/go/acl/acl_test.go b/go/acl/acl_test.go index 680044c1461..641f5e14c6f 100644 --- a/go/acl/acl_test.go +++ b/go/acl/acl_test.go @@ -18,8 +18,15 @@ package acl import ( "errors" + "fmt" "net/http" + "net/http/httptest" + "os" + "os/exec" "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" ) type TestPolicy struct{} @@ -50,41 +57,92 @@ func TestSimplePolicy(t *testing.T) { currentPolicy = policies["test"] err := CheckAccessActor("", ADMIN) want := "not allowed" - if err == nil || err.Error() != want { - t.Errorf("got %v, want %s", err, want) - } + assert.Equalf(t, err.Error(), want, "got %v, want %s", err, want) + err = CheckAccessActor("", DEBUGGING) - if err != nil { - t.Errorf("got %v, want no error", err) - } + assert.Equalf(t, err, nil, "got %v, want no error", err) err = CheckAccessHTTP(nil, ADMIN) - if err == nil || err.Error() != want { - t.Errorf("got %v, want %s", err, want) - } + assert.Equalf(t, err.Error(), want, "got %v, want %s", err, want) + err = CheckAccessHTTP(nil, DEBUGGING) - if err != nil { - t.Errorf("got %v, want no error", err) - } + assert.Equalf(t, err, nil, "got %v, want no error", err) + } func TestEmptyPolicy(t *testing.T) { currentPolicy = nil err := CheckAccessActor("", ADMIN) - if err != nil { - t.Errorf("got %v, want no error", err) - } + assert.Equalf(t, err, nil, "got %v, want no error", err) + err = CheckAccessActor("", DEBUGGING) - if err != nil { - t.Errorf("got %v, want no error", err) - } + assert.Equalf(t, err, nil, "got %v, want no error", err) err = CheckAccessHTTP(nil, ADMIN) - if err != nil { - t.Errorf("got %v, want no error", err) - } + assert.Equalf(t, err, nil, "got %v, want no error", err) + err = CheckAccessHTTP(nil, DEBUGGING) - if err != nil { - t.Errorf("got %v, want no error", err) + assert.Equalf(t, err, nil, "got %v, want no error", err) +} + +func TestValidSecurityPolicy(t *testing.T) { + securityPolicy = "test" + savePolicy() + + assert.Equalf(t, TestPolicy{}, currentPolicy, "got %v, expected %v", currentPolicy, TestPolicy{}) +} + +func TestInvalidSecurityPolicy(t *testing.T) { + securityPolicy = "invalidSecurityPolicy" + savePolicy() + + assert.Equalf(t, denyAllPolicy{}, currentPolicy, "got %v, expected %v", currentPolicy, denyAllPolicy{}) +} + +func TestSendError(t *testing.T) { + testW := httptest.NewRecorder() + + testErr := errors.New("Testing error message") + SendError(testW, testErr) + + // Check the status code + assert.Equalf(t, testW.Code, http.StatusForbidden, "got %v; want %v", testW.Code, http.StatusForbidden) + + // Check the writer body + want := fmt.Sprintf("Access denied: %v\n", testErr) + got := testW.Body.String() + assert.Equalf(t, got, want, "got %v; want %v", got, want) +} + +func TestRegisterFlags(t *testing.T) { + testFs := pflag.NewFlagSet("test", pflag.ExitOnError) + securityPolicy = "test" + + RegisterFlags(testFs) + + securityPolicyFlag := testFs.Lookup("security_policy") + assert.NotNil(t, securityPolicyFlag, "no security_policy flag is registered") + + // Check the default value of the flag + want := "test" + got := securityPolicyFlag.DefValue + assert.Equalf(t, got, want, "got %v; want %v", got, want) +} + +func TestAlreadyRegisteredPolicy(t *testing.T) { + if os.Getenv("TEST_ACL") == "1" { + RegisterPolicy("test", nil) + return + } + + // Run subprocess to test os.Exit which is called by log.fatalf + // os.Exit should be called if we try to re-register a policy + cmd := exec.Command(os.Args[0], "-test.run=TestAlreadyRegisteredPolicy") + cmd.Env = append(os.Environ(), "TEST_ACL=1") + err := cmd.Run() + if e, ok := err.(*exec.ExitError); ok && !e.Success() { + return } + + t.Errorf("process ran with err %v, want exit status 1", err) } diff --git a/go/acl/deny_all_policy_test.go b/go/acl/deny_all_policy_test.go new file mode 100644 index 00000000000..b66a344af2e --- /dev/null +++ b/go/acl/deny_all_policy_test.go @@ -0,0 +1,46 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package acl + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDenyAllPolicy(t *testing.T) { + testDenyAllPolicy := denyAllPolicy{} + + want := errDenyAll + err := testDenyAllPolicy.CheckAccessActor("", ADMIN) + assert.Equalf(t, err, want, "got %v; want %v", err, want) + + err = testDenyAllPolicy.CheckAccessActor("", DEBUGGING) + assert.Equalf(t, err, want, "got %v; want %v", err, want) + + err = testDenyAllPolicy.CheckAccessActor("", MONITORING) + assert.Equalf(t, err, want, "got %v; want %v", err, want) + + err = testDenyAllPolicy.CheckAccessHTTP(nil, ADMIN) + assert.Equalf(t, err, want, "got %v; want %v", err, want) + + err = testDenyAllPolicy.CheckAccessHTTP(nil, DEBUGGING) + assert.Equalf(t, err, want, "got %v; want %v", err, want) + + err = testDenyAllPolicy.CheckAccessHTTP(nil, MONITORING) + assert.Equalf(t, err, want, "got %v; want %v", err, want) +} diff --git a/go/acl/read_only_policy_test.go b/go/acl/read_only_policy_test.go new file mode 100644 index 00000000000..c5e988a6734 --- /dev/null +++ b/go/acl/read_only_policy_test.go @@ -0,0 +1,46 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package acl + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReadOnlyPolicy(t *testing.T) { + testReadOnlyPolicy := readOnlyPolicy{} + + want := errReadOnly + err := testReadOnlyPolicy.CheckAccessActor("", ADMIN) + assert.Equalf(t, err, want, "got %v; want %v", err, want) + + err = testReadOnlyPolicy.CheckAccessActor("", DEBUGGING) + assert.Equalf(t, err, nil, "got %v; want no error", err) + + err = testReadOnlyPolicy.CheckAccessActor("", MONITORING) + assert.Equalf(t, err, nil, "got %v; want no error", err) + + err = testReadOnlyPolicy.CheckAccessHTTP(nil, ADMIN) + assert.Equalf(t, err, want, "got %v; want %v", err, want) + + err = testReadOnlyPolicy.CheckAccessHTTP(nil, DEBUGGING) + assert.Equalf(t, err, nil, "got %v; want no error", err) + + err = testReadOnlyPolicy.CheckAccessHTTP(nil, MONITORING) + assert.Equalf(t, err, nil, "got %v; want no error", err) +}