From b4ea6960881bae1d7e83fd6c41d80f32db6ca839 Mon Sep 17 00:00:00 2001 From: Travis Raines <571832+rainest@users.noreply.github.com> Date: Wed, 8 Nov 2023 16:51:19 -0800 Subject: [PATCH] chore: import relevant deck packages Copy the konnect, cprint, crud, diff, dump, file, scripts, state, types, and utils packages from deck into this repository. The source was imported from deck@v1.29.2: https://github.com/Kong/deck/releases/tag/v1.29.2 --- pkg/cprint/color.go | 77 + pkg/cprint/color_test.go | 95 + pkg/crud/registry.go | 128 + pkg/crud/registry_test.go | 228 ++ pkg/crud/types.go | 51 + pkg/crud/types_test.go | 15 + pkg/diff/diff.go | 559 ++++ pkg/diff/diff_helpers.go | 136 + pkg/diff/diff_helpers_test.go | 176 ++ pkg/diff/order.go | 131 + pkg/diff/order_test.go | 91 + pkg/dump/dump.go | 888 ++++++ pkg/dump/dump_konnect.go | 195 ++ pkg/dump/dump_konnect_test.go | 373 +++ pkg/dump/dump_test.go | 63 + pkg/file/builder.go | 1235 ++++++++ pkg/file/builder_test.go | 2757 +++++++++++++++++ pkg/file/codegen/.gitignore | 1 + pkg/file/codegen/main.go | 121 + pkg/file/kong_json_schema.json | 1808 +++++++++++ pkg/file/konnect.go | 64 + pkg/file/reader.go | 137 + pkg/file/reader_test.go | 169 + pkg/file/readfile.go | 236 ++ pkg/file/readfile_test.go | 636 ++++ pkg/file/schema.go | 6 + pkg/file/testdata/bad-env-var/file.yaml | 9 + pkg/file/testdata/badjson/foo.json | 13 + pkg/file/testdata/badyaml/bar.yml | 8 + pkg/file/testdata/badyamlwithspace/bar.yml | 12 + pkg/file/testdata/config.yaml | 4 + pkg/file/testdata/defaults/bar.yaml | 7 + .../testdata/differentruntimegroup/bar.yaml | 6 + .../testdata/differentruntimegroup/foo.yaml | 6 + pkg/file/testdata/differentworkspace/bar.yaml | 11 + pkg/file/testdata/differentworkspace/foo.yaml | 11 + pkg/file/testdata/emptydir/README | 1 + pkg/file/testdata/emptyfiles/Baz.YamL | 0 pkg/file/testdata/emptyfiles/bar.yaml | 0 pkg/file/testdata/emptyfiles/foo.notyaml | 0 pkg/file/testdata/emptyfiles/foo.yaml.pdf | Bin pkg/file/testdata/emptyfiles/foo.yml | 0 pkg/file/testdata/emptyfiles/foobar.json | 0 .../emptyfiles/not-a-file.yaml/info.txt | 0 pkg/file/testdata/file.json | 10 + pkg/file/testdata/file.yaml | 15 + pkg/file/testdata/parsebool/file.yaml | 4 + pkg/file/testdata/parsefloat/file.yaml | 4 + pkg/file/testdata/parseint/file.yaml | 4 + pkg/file/testdata/sameworkspace/bar.yaml | 11 + pkg/file/testdata/sameworkspace/foo.yaml | 11 + pkg/file/testdata/sharedworkspace/foo.yaml | 9 + pkg/file/testdata/sharedworkspace/meta.yaml | 2 + pkg/file/testdata/valid/bar.yml | 9 + pkg/file/testdata/valid/consumers.json | 10 + pkg/file/testdata/valid/foo.yaml | 14 + pkg/file/types.go | 743 +++++ pkg/file/types_test.go | 496 +++ pkg/file/validate.go | 68 + pkg/file/writer.go | 842 +++++ pkg/file/writer_test.go | 405 +++ pkg/file/zz_generated.deepcopy.go | 823 +++++ pkg/konnect/client.go | 179 ++ pkg/konnect/consumer_group.go | 472 +++ .../control_plane_relations_service.go | 137 + pkg/konnect/control_plane_service.go | 36 + pkg/konnect/document_service.go | 150 + pkg/konnect/error.go | 62 + pkg/konnect/list.go | 60 + pkg/konnect/login_service.go | 146 + pkg/konnect/login_service_test.go | 42 + pkg/konnect/request.go | 53 + pkg/konnect/runtime_group_service.go | 38 + pkg/konnect/service_package_service.go | 118 + pkg/konnect/service_version_service.go | 107 + pkg/konnect/types.go | 148 + pkg/konnect/utils.go | 10 + pkg/konnect/zz_generated.deepcopy.go | 200 ++ pkg/scripts/header-template.go.tmpl | 15 + pkg/scripts/update-deepcopy-gen.sh | 23 + pkg/scripts/verify-codegen.sh | 7 + pkg/scripts/verify-deepcopy-gen.sh | 23 + pkg/state/aclgroup.go | 262 ++ pkg/state/aclgroup_test.go | 240 ++ pkg/state/basicauth.go | 85 + pkg/state/basicauth_test.go | 219 ++ pkg/state/builder.go | 385 +++ pkg/state/cacert.go | 173 ++ pkg/state/cacert_test.go | 144 + pkg/state/certificate.go | 235 ++ pkg/state/certificate_test.go | 220 ++ pkg/state/consumer.go | 204 ++ pkg/state/consumer_group.go | 176 ++ pkg/state/consumer_group_consumers.go | 244 ++ pkg/state/consumer_group_plugin.go | 242 ++ pkg/state/consumer_test.go | 161 + pkg/state/credentials.go | 201 ++ pkg/state/document.go | 219 ++ pkg/state/document_test.go | 438 +++ pkg/state/hmacauth.go | 85 + pkg/state/hmacauth_test.go | 218 ++ pkg/state/indexers/md5Indexer.go | 55 + pkg/state/indexers/md5Indexer_test.go | 60 + pkg/state/indexers/methodIndexer.go | 48 + pkg/state/indexers/methodIndexer_test.go | 71 + pkg/state/indexers/subFieldIndexer.go | 66 + pkg/state/indexers/subFieldIndexer_test.go | 100 + pkg/state/jwtauth.go | 85 + pkg/state/jwtauth_test.go | 214 ++ pkg/state/keyauth.go | 85 + pkg/state/keyauth_test.go | 256 ++ pkg/state/konnect_types.go | 143 + pkg/state/mtlsauth.go | 85 + pkg/state/mtlsauth_test.go | 204 ++ pkg/state/oauth2.go | 85 + pkg/state/oauth2_test.go | 211 ++ pkg/state/plugin.go | 380 +++ pkg/state/plugin_test.go | 574 ++++ pkg/state/rbac_endpoint_permission.go | 210 ++ pkg/state/rbac_endpoint_permission_test.go | 333 ++ pkg/state/rbac_role.go | 176 ++ pkg/state/rbac_role_test.go | 298 ++ pkg/state/route.go | 213 ++ pkg/state/route_test.go | 439 +++ pkg/state/service.go | 172 + pkg/state/service_package.go | 171 + pkg/state/service_package_test.go | 390 +++ pkg/state/service_test.go | 426 +++ pkg/state/service_version.go | 226 ++ pkg/state/service_version_test.go | 415 +++ pkg/state/sni.go | 226 ++ pkg/state/sni_test.go | 481 +++ pkg/state/state.go | 131 + pkg/state/state_test.go | 22 + pkg/state/target.go | 236 ++ pkg/state/target_test.go | 263 ++ pkg/state/types.go | 1601 ++++++++++ pkg/state/types_test.go | 502 +++ pkg/state/upstream.go | 178 ++ pkg/state/upstream_test.go | 172 + pkg/state/utils.go | 56 + pkg/state/vault.go | 176 ++ pkg/types/aclgroup.go | 185 ++ pkg/types/basicauth.go | 199 ++ pkg/types/ca_cert.go | 165 + pkg/types/certificate.go | 202 ++ pkg/types/consumer.go | 228 ++ pkg/types/consumer_group.go | 183 ++ pkg/types/consumer_group_consumer.go | 220 ++ pkg/types/consumer_group_plugin.go | 177 ++ pkg/types/core.go | 535 ++++ pkg/types/document.go | 183 ++ pkg/types/hmacauth.go | 184 ++ pkg/types/jwtauth.go | 183 ++ pkg/types/keyauth.go | 170 + pkg/types/mtlsauth.go | 170 + pkg/types/oauth2.go | 187 ++ pkg/types/plugin.go | 206 ++ pkg/types/postProcess.go | 457 +++ pkg/types/rbac_endpoint_permission.go | 174 ++ pkg/types/rbac_role.go | 161 + pkg/types/route.go | 213 ++ pkg/types/service.go | 222 ++ pkg/types/service_package.go | 160 + pkg/types/service_version.go | 232 ++ pkg/types/sni.go | 159 + pkg/types/target.go | 169 + pkg/types/upstream.go | 162 + pkg/types/vault.go | 166 + pkg/utils/analytics.go | 76 + pkg/utils/constants.go | 88 + pkg/utils/counter.go | 20 + pkg/utils/counter_test.go | 23 + pkg/utils/defaulter.go | 410 +++ pkg/utils/defaulter_test.go | 737 +++++ pkg/utils/prompt.go | 49 + pkg/utils/tags.go | 84 + pkg/utils/tags_test.go | 92 + pkg/utils/types.go | 347 +++ pkg/utils/types_test.go | 138 + pkg/utils/utils.go | 252 ++ pkg/utils/utils_test.go | 328 ++ pkg/utils/uuid.go | 10 + pkg/utils/uuid_test.go | 17 + pkg/utils/zero.go | 38 + pkg/utils/zero_test.go | 138 + 186 files changed, 38855 insertions(+) create mode 100644 pkg/cprint/color.go create mode 100644 pkg/cprint/color_test.go create mode 100644 pkg/crud/registry.go create mode 100644 pkg/crud/registry_test.go create mode 100644 pkg/crud/types.go create mode 100644 pkg/crud/types_test.go create mode 100644 pkg/diff/diff.go create mode 100644 pkg/diff/diff_helpers.go create mode 100644 pkg/diff/diff_helpers_test.go create mode 100644 pkg/diff/order.go create mode 100644 pkg/diff/order_test.go create mode 100644 pkg/dump/dump.go create mode 100644 pkg/dump/dump_konnect.go create mode 100644 pkg/dump/dump_konnect_test.go create mode 100644 pkg/dump/dump_test.go create mode 100644 pkg/file/builder.go create mode 100644 pkg/file/builder_test.go create mode 100644 pkg/file/codegen/.gitignore create mode 100644 pkg/file/codegen/main.go create mode 100644 pkg/file/kong_json_schema.json create mode 100644 pkg/file/konnect.go create mode 100644 pkg/file/reader.go create mode 100644 pkg/file/reader_test.go create mode 100644 pkg/file/readfile.go create mode 100644 pkg/file/readfile_test.go create mode 100644 pkg/file/schema.go create mode 100644 pkg/file/testdata/bad-env-var/file.yaml create mode 100644 pkg/file/testdata/badjson/foo.json create mode 100644 pkg/file/testdata/badyaml/bar.yml create mode 100644 pkg/file/testdata/badyamlwithspace/bar.yml create mode 100644 pkg/file/testdata/config.yaml create mode 100644 pkg/file/testdata/defaults/bar.yaml create mode 100644 pkg/file/testdata/differentruntimegroup/bar.yaml create mode 100644 pkg/file/testdata/differentruntimegroup/foo.yaml create mode 100644 pkg/file/testdata/differentworkspace/bar.yaml create mode 100644 pkg/file/testdata/differentworkspace/foo.yaml create mode 100644 pkg/file/testdata/emptydir/README create mode 100644 pkg/file/testdata/emptyfiles/Baz.YamL create mode 100644 pkg/file/testdata/emptyfiles/bar.yaml create mode 100644 pkg/file/testdata/emptyfiles/foo.notyaml create mode 100644 pkg/file/testdata/emptyfiles/foo.yaml.pdf create mode 100644 pkg/file/testdata/emptyfiles/foo.yml create mode 100644 pkg/file/testdata/emptyfiles/foobar.json create mode 100644 pkg/file/testdata/emptyfiles/not-a-file.yaml/info.txt create mode 100644 pkg/file/testdata/file.json create mode 100644 pkg/file/testdata/file.yaml create mode 100644 pkg/file/testdata/parsebool/file.yaml create mode 100644 pkg/file/testdata/parsefloat/file.yaml create mode 100644 pkg/file/testdata/parseint/file.yaml create mode 100644 pkg/file/testdata/sameworkspace/bar.yaml create mode 100644 pkg/file/testdata/sameworkspace/foo.yaml create mode 100644 pkg/file/testdata/sharedworkspace/foo.yaml create mode 100644 pkg/file/testdata/sharedworkspace/meta.yaml create mode 100644 pkg/file/testdata/valid/bar.yml create mode 100644 pkg/file/testdata/valid/consumers.json create mode 100644 pkg/file/testdata/valid/foo.yaml create mode 100644 pkg/file/types.go create mode 100644 pkg/file/types_test.go create mode 100644 pkg/file/validate.go create mode 100644 pkg/file/writer.go create mode 100644 pkg/file/writer_test.go create mode 100644 pkg/file/zz_generated.deepcopy.go create mode 100644 pkg/konnect/client.go create mode 100644 pkg/konnect/consumer_group.go create mode 100644 pkg/konnect/control_plane_relations_service.go create mode 100644 pkg/konnect/control_plane_service.go create mode 100644 pkg/konnect/document_service.go create mode 100644 pkg/konnect/error.go create mode 100644 pkg/konnect/list.go create mode 100644 pkg/konnect/login_service.go create mode 100644 pkg/konnect/login_service_test.go create mode 100644 pkg/konnect/request.go create mode 100644 pkg/konnect/runtime_group_service.go create mode 100644 pkg/konnect/service_package_service.go create mode 100644 pkg/konnect/service_version_service.go create mode 100644 pkg/konnect/types.go create mode 100644 pkg/konnect/utils.go create mode 100644 pkg/konnect/zz_generated.deepcopy.go create mode 100644 pkg/scripts/header-template.go.tmpl create mode 100755 pkg/scripts/update-deepcopy-gen.sh create mode 100755 pkg/scripts/verify-codegen.sh create mode 100755 pkg/scripts/verify-deepcopy-gen.sh create mode 100644 pkg/state/aclgroup.go create mode 100644 pkg/state/aclgroup_test.go create mode 100644 pkg/state/basicauth.go create mode 100644 pkg/state/basicauth_test.go create mode 100644 pkg/state/builder.go create mode 100644 pkg/state/cacert.go create mode 100644 pkg/state/cacert_test.go create mode 100644 pkg/state/certificate.go create mode 100644 pkg/state/certificate_test.go create mode 100644 pkg/state/consumer.go create mode 100644 pkg/state/consumer_group.go create mode 100644 pkg/state/consumer_group_consumers.go create mode 100644 pkg/state/consumer_group_plugin.go create mode 100644 pkg/state/consumer_test.go create mode 100644 pkg/state/credentials.go create mode 100644 pkg/state/document.go create mode 100644 pkg/state/document_test.go create mode 100644 pkg/state/hmacauth.go create mode 100644 pkg/state/hmacauth_test.go create mode 100644 pkg/state/indexers/md5Indexer.go create mode 100644 pkg/state/indexers/md5Indexer_test.go create mode 100644 pkg/state/indexers/methodIndexer.go create mode 100644 pkg/state/indexers/methodIndexer_test.go create mode 100644 pkg/state/indexers/subFieldIndexer.go create mode 100644 pkg/state/indexers/subFieldIndexer_test.go create mode 100644 pkg/state/jwtauth.go create mode 100644 pkg/state/jwtauth_test.go create mode 100644 pkg/state/keyauth.go create mode 100644 pkg/state/keyauth_test.go create mode 100644 pkg/state/konnect_types.go create mode 100644 pkg/state/mtlsauth.go create mode 100644 pkg/state/mtlsauth_test.go create mode 100644 pkg/state/oauth2.go create mode 100644 pkg/state/oauth2_test.go create mode 100644 pkg/state/plugin.go create mode 100644 pkg/state/plugin_test.go create mode 100644 pkg/state/rbac_endpoint_permission.go create mode 100644 pkg/state/rbac_endpoint_permission_test.go create mode 100644 pkg/state/rbac_role.go create mode 100644 pkg/state/rbac_role_test.go create mode 100644 pkg/state/route.go create mode 100644 pkg/state/route_test.go create mode 100644 pkg/state/service.go create mode 100644 pkg/state/service_package.go create mode 100644 pkg/state/service_package_test.go create mode 100644 pkg/state/service_test.go create mode 100644 pkg/state/service_version.go create mode 100644 pkg/state/service_version_test.go create mode 100644 pkg/state/sni.go create mode 100644 pkg/state/sni_test.go create mode 100644 pkg/state/state.go create mode 100644 pkg/state/state_test.go create mode 100644 pkg/state/target.go create mode 100644 pkg/state/target_test.go create mode 100644 pkg/state/types.go create mode 100644 pkg/state/types_test.go create mode 100644 pkg/state/upstream.go create mode 100644 pkg/state/upstream_test.go create mode 100644 pkg/state/utils.go create mode 100644 pkg/state/vault.go create mode 100644 pkg/types/aclgroup.go create mode 100644 pkg/types/basicauth.go create mode 100644 pkg/types/ca_cert.go create mode 100644 pkg/types/certificate.go create mode 100644 pkg/types/consumer.go create mode 100644 pkg/types/consumer_group.go create mode 100644 pkg/types/consumer_group_consumer.go create mode 100644 pkg/types/consumer_group_plugin.go create mode 100644 pkg/types/core.go create mode 100644 pkg/types/document.go create mode 100644 pkg/types/hmacauth.go create mode 100644 pkg/types/jwtauth.go create mode 100644 pkg/types/keyauth.go create mode 100644 pkg/types/mtlsauth.go create mode 100644 pkg/types/oauth2.go create mode 100644 pkg/types/plugin.go create mode 100644 pkg/types/postProcess.go create mode 100644 pkg/types/rbac_endpoint_permission.go create mode 100644 pkg/types/rbac_role.go create mode 100644 pkg/types/route.go create mode 100644 pkg/types/service.go create mode 100644 pkg/types/service_package.go create mode 100644 pkg/types/service_version.go create mode 100644 pkg/types/sni.go create mode 100644 pkg/types/target.go create mode 100644 pkg/types/upstream.go create mode 100644 pkg/types/vault.go create mode 100644 pkg/utils/analytics.go create mode 100644 pkg/utils/constants.go create mode 100644 pkg/utils/counter.go create mode 100644 pkg/utils/counter_test.go create mode 100644 pkg/utils/defaulter.go create mode 100644 pkg/utils/defaulter_test.go create mode 100644 pkg/utils/prompt.go create mode 100644 pkg/utils/tags.go create mode 100644 pkg/utils/tags_test.go create mode 100644 pkg/utils/types.go create mode 100644 pkg/utils/types_test.go create mode 100644 pkg/utils/utils.go create mode 100644 pkg/utils/utils_test.go create mode 100644 pkg/utils/uuid.go create mode 100644 pkg/utils/uuid_test.go create mode 100644 pkg/utils/zero.go create mode 100644 pkg/utils/zero_test.go diff --git a/pkg/cprint/color.go b/pkg/cprint/color.go new file mode 100644 index 0000000..e5c9d8d --- /dev/null +++ b/pkg/cprint/color.go @@ -0,0 +1,77 @@ +package cprint + +import ( + "sync" + + "github.com/fatih/color" +) + +var ( + // mu is used to synchronize writes from multiple goroutines. + mu sync.Mutex + // DisableOutput disables all output. + DisableOutput bool +) + +func conditionalPrintf(fn func(string, ...interface{}), format string, a ...interface{}) { + if DisableOutput { + return + } + mu.Lock() + defer mu.Unlock() + fn(format, a...) +} + +func conditionalPrintln(fn func(...interface{}), a ...interface{}) { + if DisableOutput { + return + } + mu.Lock() + defer mu.Unlock() + fn(a...) +} + +var ( + createPrintf = color.New(color.FgGreen).PrintfFunc() + deletePrintf = color.New(color.FgRed).PrintfFunc() + updatePrintf = color.New(color.FgYellow).PrintfFunc() + + // CreatePrintf is fmt.Printf with red as foreground color. + CreatePrintf = func(format string, a ...interface{}) { + conditionalPrintf(createPrintf, format, a...) + } + + // DeletePrintf is fmt.Printf with green as foreground color. + DeletePrintf = func(format string, a ...interface{}) { + conditionalPrintf(deletePrintf, format, a...) + } + + // UpdatePrintf is fmt.Printf with yellow as foreground color. + UpdatePrintf = func(format string, a ...interface{}) { + conditionalPrintf(updatePrintf, format, a...) + } + + createPrintln = color.New(color.FgGreen).PrintlnFunc() + deletePrintln = color.New(color.FgRed).PrintlnFunc() + updatePrintln = color.New(color.FgYellow).PrintlnFunc() + bluePrintln = color.New(color.BgBlue).PrintlnFunc() + + // CreatePrintln is fmt.Println with red as foreground color. + CreatePrintln = func(a ...interface{}) { + conditionalPrintln(createPrintln, a...) + } + + // DeletePrintln is fmt.Println with green as foreground color. + DeletePrintln = func(a ...interface{}) { + conditionalPrintln(deletePrintln, a...) + } + + // UpdatePrintln is fmt.Println with yellow as foreground color. + UpdatePrintln = func(a ...interface{}) { + conditionalPrintln(updatePrintln, a...) + } + + BluePrintLn = func(a ...interface{}) { + conditionalPrintln(bluePrintln, a...) + } +) diff --git a/pkg/cprint/color_test.go b/pkg/cprint/color_test.go new file mode 100644 index 0000000..cbec6c2 --- /dev/null +++ b/pkg/cprint/color_test.go @@ -0,0 +1,95 @@ +package cprint + +import ( + "bytes" + "os" + "testing" + + "github.com/fatih/color" + "github.com/stretchr/testify/assert" +) + +// captureOutput captures color.Output and returns the recorded output as +// f runs. +// It is not thread-safe. +func captureOutput(f func()) string { + backupOutput := color.Output + defer func() { + color.Output = backupOutput + }() + var out bytes.Buffer + color.Output = &out + f() + return out.String() +} + +func TestMain(m *testing.M) { + backup := color.NoColor + color.NoColor = false + exitVal := m.Run() + color.NoColor = backup + os.Exit(exitVal) +} + +func TestPrint(t *testing.T) { + tests := []struct { + name string + DisableOutput bool + Run func() + Expected string + }{ + { + name: "println prints colored output", + DisableOutput: false, + Run: func() { + CreatePrintln("foo") + UpdatePrintln("bar") + DeletePrintln("fubaz") + }, + Expected: "\x1b[32mfoo\n\x1b[0m\x1b[33mbar\n\x1b[0m\x1b[31mfubaz\n\x1b[0m", + }, + { + name: "println doesn't output anything when disabled", + DisableOutput: true, + Run: func() { + CreatePrintln("foo") + UpdatePrintln("bar") + DeletePrintln("fubaz") + }, + Expected: "", + }, + { + name: "printf prints colored output", + DisableOutput: false, + Run: func() { + CreatePrintf("%s", "foo") + UpdatePrintf("%s", "bar") + DeletePrintf("%s", "fubaz") + }, + Expected: "\x1b[32mfoo\x1b[0m\x1b[33mbar\x1b[0m\x1b[31mfubaz\x1b[0m", + }, + { + name: "printf doesn't output anything when disabled", + DisableOutput: true, + Run: func() { + CreatePrintln("foo") + UpdatePrintln("bar") + DeletePrintln("fubaz") + }, + Expected: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + DisableOutput = tt.DisableOutput + defer func() { + DisableOutput = false + }() + + output := captureOutput(func() { + tt.Run() + }) + assert.Equal(t, tt.Expected, output) + }) + } +} diff --git a/pkg/crud/registry.go b/pkg/crud/registry.go new file mode 100644 index 0000000..3d04ebf --- /dev/null +++ b/pkg/crud/registry.go @@ -0,0 +1,128 @@ +package crud + +import ( + "context" + "fmt" +) + +// Kind represents Kind of an entity or object. +type Kind string + +// Registry can hold Kinds and their respective CRUD operations. +type Registry struct { + types map[Kind]Actions +} + +func (r *Registry) typesMap() map[Kind]Actions { + if r.types == nil { + r.types = make(map[Kind]Actions) + } + return r.types +} + +// Register a kind with actions. +// An error will be returned if kind was previously registered. +func (r *Registry) Register(kind Kind, a Actions) error { + if kind == "" { + return fmt.Errorf("kind cannot be empty") + } + m := r.typesMap() + if _, ok := m[kind]; ok { + return fmt.Errorf("kind %q already registered", kind) + } + m[kind] = a + return nil +} + +// MustRegister is same as Register but panics on error. +func (r *Registry) MustRegister(kind Kind, a Actions) { + err := r.Register(kind, a) + if err != nil { + panic(err) + } +} + +// Get returns actions associated with kind. +// An error will be returned if kind was never registered. +func (r *Registry) Get(kind Kind) (Actions, error) { + if kind == "" { + return nil, fmt.Errorf("kind cannot be empty") + } + m := r.typesMap() + a, ok := m[kind] + if !ok { + return nil, fmt.Errorf("kind %q is not registered", kind) + } + return a, nil +} + +// Create calls the registered create action of kind with args +// and returns the result and error (if any). +func (r *Registry) Create(ctx context.Context, kind Kind, args ...Arg) (Arg, error) { + a, err := r.Get(kind) + if err != nil { + return nil, fmt.Errorf("create failed: %w", err) + } + + res, err := a.Create(ctx, args...) + if err != nil { + return nil, fmt.Errorf("create failed: %w", err) + } + return res, nil +} + +// Update calls the registered update action of kind with args +// and returns the result and error (if any). +func (r *Registry) Update(ctx context.Context, kind Kind, args ...Arg) (Arg, error) { + a, err := r.Get(kind) + if err != nil { + return nil, fmt.Errorf("update failed: %w", err) + } + + res, err := a.Update(ctx, args...) + if err != nil { + return nil, fmt.Errorf("update failed: %w", err) + } + return res, nil +} + +// Delete calls the registered delete action of kind with args +// and returns the result and error (if any). +func (r *Registry) Delete(ctx context.Context, kind Kind, args ...Arg) (Arg, error) { + a, err := r.Get(kind) + if err != nil { + return nil, fmt.Errorf("delete failed: %w", err) + } + + res, err := a.Delete(ctx, args...) + if err != nil { + return nil, fmt.Errorf("delete failed: %w", err) + } + return res, nil +} + +// Do calls an action based on op with args and returns the result and error. +func (r *Registry) Do(ctx context.Context, kind Kind, op Op, args ...Arg) (Arg, error) { + a, err := r.Get(kind) + if err != nil { + return nil, fmt.Errorf("%v failed: %w", op, err) + } + + var res Arg + + switch op.name { + case Create.name: + res, err = a.Create(ctx, args...) + case Update.name: + res, err = a.Update(ctx, args...) + case Delete.name: + res, err = a.Delete(ctx, args...) + default: + return nil, fmt.Errorf("unknown operation: %s", op.name) + } + + if err != nil { + return nil, err + } + return res, nil +} diff --git a/pkg/crud/registry_test.go b/pkg/crud/registry_test.go new file mode 100644 index 0000000..1b20e74 --- /dev/null +++ b/pkg/crud/registry_test.go @@ -0,0 +1,228 @@ +package crud + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testActionFixture struct { + state string +} + +func newTestActionFixture(state string) testActionFixture { + return testActionFixture{state: state} +} + +func (t testActionFixture) invoke(op string, inputs ...Arg) (Arg, error) { + res := t.state + " " + op + + for _, input := range inputs { + iString, ok := input.(string) + if !ok { + return nil, fmt.Errorf("input is not a string") + } + res += " " + iString + } + return res, nil +} + +func (t testActionFixture) Create(_ context.Context, input ...Arg) (Arg, error) { + return t.invoke("create", input...) +} + +func (t testActionFixture) Delete(_ context.Context, input ...Arg) (Arg, error) { + return t.invoke("delete", input...) +} + +func (t testActionFixture) Update(_ context.Context, input ...Arg) (Arg, error) { + return t.invoke("update", input...) +} + +func TestRegistryRegister(t *testing.T) { + assert := assert.New(t) + var r Registry + var a Actions = newTestActionFixture("yolo") + + err := r.Register("", nil) + assert.NotNil(err) + + err = r.Register("foo", a) + assert.Nil(err) + + err = r.Register("foo", a) + assert.NotNil(err) +} + +func TestRegistryMustRegister(t *testing.T) { + assert := assert.New(t) + var r Registry + var a Actions = newTestActionFixture("yolo") + + assert.Panics(func() { + r.MustRegister("", nil) + }) + + assert.NotPanics(func() { + r.MustRegister("foo", a) + }) + + assert.Panics(func() { + r.MustRegister("foo", a) + }) +} + +func TestRegistryGet(t *testing.T) { + assert := assert.New(t) + var r Registry + var a Actions = newTestActionFixture("foo") + + err := r.Register("foo", a) + assert.Nil(err) + + a, err = r.Get("foo") + assert.Nil(err) + assert.NotNil(a) + + a, err = r.Get("bar") + assert.NotNil(err) + assert.Nil(a) + + a, err = r.Get("") + assert.NotNil(err) + assert.Nil(a) +} + +func TestRegistryCreate(t *testing.T) { + assert := assert.New(t) + var r Registry + var a Actions = newTestActionFixture("foo") + + err := r.Register("foo", a) + assert.Nil(err) + + res, err := r.Create(context.Background(), "foo", "yolo") + assert.Nil(err) + assert.NotNil(res) + result, ok := res.(string) + assert.True(ok) + assert.Equal("foo create yolo", result) + + // make sure it takes multiple arguments + res, err = r.Create(context.Background(), "foo", "yolo", "always") + assert.Nil(err) + assert.NotNil(res) + result, ok = res.(string) + assert.True(ok) + assert.Equal("foo create yolo always", result) + + res, err = r.Create(context.Background(), "foo", 42) + assert.NotNil(err) + assert.Nil(res) + + res, err = r.Create(context.Background(), "bar", 42) + assert.NotNil(err) + assert.Nil(res) +} + +func TestRegistryUpdate(t *testing.T) { + assert := assert.New(t) + var r Registry + var a Actions = newTestActionFixture("foo") + + err := r.Register("foo", a) + assert.Nil(err) + + res, err := r.Update(context.Background(), "foo", "yolo") + assert.Nil(err) + assert.NotNil(res) + result, ok := res.(string) + assert.True(ok) + assert.Equal("foo update yolo", result) + + // make sure it takes multiple arguments + res, err = r.Update(context.Background(), "foo", "yolo", "always") + assert.Nil(err) + assert.NotNil(res) + result, ok = res.(string) + assert.True(ok) + assert.Equal("foo update yolo always", result) + + res, err = r.Update(context.Background(), "foo", 42) + assert.NotNil(err) + assert.Nil(res) + + res, err = r.Update(context.Background(), "bar", 42) + assert.NotNil(err) + assert.Nil(res) +} + +func TestRegistryDelete(t *testing.T) { + assert := assert.New(t) + var r Registry + var a Actions = newTestActionFixture("foo") + + err := r.Register("foo", a) + assert.Nil(err) + + res, err := r.Delete(context.Background(), "foo", "yolo") + assert.Nil(err) + assert.NotNil(res) + result, ok := res.(string) + assert.True(ok) + assert.Equal("foo delete yolo", result) + + // make sure it takes multiple arguments + res, err = r.Delete(context.Background(), "foo", "yolo", "always") + assert.Nil(err) + assert.NotNil(res) + result, ok = res.(string) + assert.True(ok) + assert.Equal("foo delete yolo always", result) + + res, err = r.Delete(context.Background(), "foo", 42) + assert.NotNil(err) + assert.Nil(res) + + res, err = r.Delete(context.Background(), "bar", 42) + assert.NotNil(err) + assert.Nil(res) +} + +func TestRegistryDo(t *testing.T) { + assert := assert.New(t) + var r Registry + var a Actions = newTestActionFixture("foo") + + err := r.Register("foo", a) + assert.Nil(err) + + res, err := r.Do(context.Background(), "foo", Create, "yolo") + assert.Nil(err) + assert.NotNil(res) + result, ok := res.(string) + assert.True(ok) + assert.Equal("foo create yolo", result) + + // make sure it takes multiple arguments + res, err = r.Do(context.Background(), "foo", Update, "yolo", "always") + assert.Nil(err) + assert.NotNil(res) + result, ok = res.(string) + assert.True(ok) + assert.Equal("foo update yolo always", result) + + res, err = r.Do(context.Background(), "foo", Delete, 42) + assert.NotNil(err) + assert.Nil(res) + + res, err = r.Do(context.Background(), "foo", Op{"unknown-op"}, 42) + assert.NotNil(err) + assert.Nil(res) + + res, err = r.Do(context.Background(), "bar", Create, "yolo") + assert.NotNil(err) + assert.Nil(res) +} diff --git a/pkg/crud/types.go b/pkg/crud/types.go new file mode 100644 index 0000000..0853944 --- /dev/null +++ b/pkg/crud/types.go @@ -0,0 +1,51 @@ +package crud + +import "context" + +// Op represents +type Op struct { + name string +} + +func (op *Op) String() string { + return op.name +} + +var ( + // Create is a constant representing create operations. + Create = Op{"Create"} + // Update is a constant representing update operations. + Update = Op{"Update"} + // Delete is a constant representing delete operations. + Delete = Op{"Delete"} +) + +// Arg is an argument to a callback function. +type Arg interface{} + +// Actions is an interface for CRUD operations on any entity +type Actions interface { + Create(context.Context, ...Arg) (Arg, error) + Delete(context.Context, ...Arg) (Arg, error) + Update(context.Context, ...Arg) (Arg, error) +} + +// Event represents an event to perform +// an imperative operation +// that gets Kong closer to the target state. +type Event struct { + Op Op + Kind Kind + Obj interface{} + OldObj interface{} +} + +// EventFromArg converts arg into Event. +// It panics if the type of arg is not Event. +func EventFromArg(arg Arg) Event { + event, ok := arg.(Event) + if !ok { + panic("unexpected type, expected diff.Event") + } + return event +} diff --git a/pkg/crud/types_test.go b/pkg/crud/types_test.go new file mode 100644 index 0000000..987cc39 --- /dev/null +++ b/pkg/crud/types_test.go @@ -0,0 +1,15 @@ +package crud + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOpString(t *testing.T) { + assert := assert.New(t) + op := Op{"foo"} + var op2 Op + assert.Equal("foo", op.String()) + assert.Equal("", op2.String()) +} diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go new file mode 100644 index 0000000..48f7307 --- /dev/null +++ b/pkg/diff/diff.go @@ -0,0 +1,559 @@ +package diff + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/kong/deck/cprint" + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/deck/types" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +type EntityState struct { + Name string `json:"name"` + Kind string `json:"kind"` + Body any `json:"body"` +} + +type Summary struct { + Creating int32 `json:"creating"` + Updating int32 `json:"updating"` + Deleting int32 `json:"deleting"` + Total int32 `json:"total"` +} + +type JSONOutputObject struct { + Changes EntityChanges `json:"changes"` + Summary Summary `json:"summary"` + Warnings []string `json:"warnings"` + Errors []string `json:"errors"` +} + +type EntityChanges struct { + Creating []EntityState `json:"creating"` + Updating []EntityState `json:"updating"` + Deleting []EntityState `json:"deleting"` +} + +var errEnqueueFailed = errors.New("failed to queue event") + +func defaultBackOff() backoff.BackOff { + // For various reasons, Kong can temporarily fail to process + // a valid request (e.g. when the database is under heavy load). + // We retry each request up to 3 times on failure, after around + // 1 second, 3 seconds, and 9 seconds (randomized exponential backoff). + exponentialBackoff := backoff.NewExponentialBackOff() + exponentialBackoff.InitialInterval = 1 * time.Second + exponentialBackoff.Multiplier = 3 + return backoff.WithMaxRetries(exponentialBackoff, 4) +} + +// Syncer takes in a current and target state of Kong, +// diffs them, generating a Graph to get Kong from current +// to target state. +type Syncer struct { + currentState *state.KongState + targetState *state.KongState + + processor crud.Registry + postProcessor crud.Registry + + eventChan chan crud.Event + errChan chan error + stopChan chan struct{} + + inFlightOps int32 + + silenceWarnings bool + stageDelaySec int + + createPrintln func(a ...interface{}) + updatePrintln func(a ...interface{}) + deletePrintln func(a ...interface{}) + + kongClient *kong.Client + konnectClient *konnect.Client + + entityDiffers map[types.EntityType]types.Differ + + noMaskValues bool + + isKonnect bool +} + +type SyncerOpts struct { + CurrentState *state.KongState + TargetState *state.KongState + + KongClient *kong.Client + KonnectClient *konnect.Client + + SilenceWarnings bool + StageDelaySec int + + NoMaskValues bool + + IsKonnect bool + + CreatePrintln func(a ...interface{}) + UpdatePrintln func(a ...interface{}) + DeletePrintln func(a ...interface{}) +} + +// NewSyncer constructs a Syncer. +func NewSyncer(opts SyncerOpts) (*Syncer, error) { + s := &Syncer{ + currentState: opts.CurrentState, + targetState: opts.TargetState, + + kongClient: opts.KongClient, + konnectClient: opts.KonnectClient, + + silenceWarnings: opts.SilenceWarnings, + stageDelaySec: opts.StageDelaySec, + + noMaskValues: opts.NoMaskValues, + + createPrintln: opts.CreatePrintln, + updatePrintln: opts.UpdatePrintln, + deletePrintln: opts.DeletePrintln, + isKonnect: opts.IsKonnect, + } + + if s.createPrintln == nil { + s.createPrintln = cprint.CreatePrintln + } + if s.updatePrintln == nil { + s.updatePrintln = cprint.UpdatePrintln + } + if s.deletePrintln == nil { + s.deletePrintln = cprint.DeletePrintln + } + + err := s.init() + if err != nil { + return nil, err + } + + return s, nil +} + +func (sc *Syncer) init() error { + opts := types.EntityOpts{ + CurrentState: sc.currentState, + TargetState: sc.targetState, + + KongClient: sc.kongClient, + KonnectClient: sc.konnectClient, + + IsKonnect: sc.isKonnect, + } + + entities := []types.EntityType{ + types.Service, types.Route, types.Plugin, + + types.Certificate, types.SNI, types.CACertificate, + + types.Upstream, types.Target, + + types.Consumer, + types.ConsumerGroup, types.ConsumerGroupConsumer, types.ConsumerGroupPlugin, + types.ACLGroup, types.BasicAuth, types.KeyAuth, + types.HMACAuth, types.JWTAuth, types.OAuth2Cred, + types.MTLSAuth, + + types.Vault, + + types.RBACRole, types.RBACEndpointPermission, + + types.ServicePackage, types.ServiceVersion, types.Document, + } + sc.entityDiffers = map[types.EntityType]types.Differ{} + for _, entityType := range entities { + entity, err := types.NewEntity(entityType, opts) + if err != nil { + return err + } + sc.postProcessor.MustRegister(crud.Kind(entityType), entity.PostProcessActions()) + sc.processor.MustRegister(crud.Kind(entityType), entity.CRUDActions()) + sc.entityDiffers[entityType] = entity.Differ() + } + return nil +} + +func (sc *Syncer) diff() error { + for _, operation := range []func() error{ + sc.deleteDuplicates, + sc.createUpdate, + sc.delete, + } { + err := operation() + if err != nil { + return err + } + } + return nil +} + +func (sc *Syncer) deleteDuplicates() error { + var events []crud.Event + for _, ts := range reverseOrder() { + for _, entityType := range ts { + entityDiffer, ok := sc.entityDiffers[entityType].(types.DuplicatesDeleter) + if !ok { + continue + } + entityEvents, err := entityDiffer.DuplicatesDeletes() + if err != nil { + return err + } + events = append(events, entityEvents...) + } + } + + return sc.processDeleteDuplicates(eventsInOrder(events, reverseOrder())) +} + +func (sc *Syncer) processDeleteDuplicates(eventsByLevel [][]crud.Event) error { + // All entities implement this interface. We'll use it to index delete events by (kind, identifier) tuple to prevent + // deleting a single object twice. + type identifier interface { + Identifier() string + } + var ( + alreadyDeleted = map[string]struct{}{} + keyForEvent = func(event crud.Event) (string, error) { + obj, ok := event.Obj.(identifier) + if !ok { + return "", fmt.Errorf("unexpected type %T in event", event.Obj) + } + return fmt.Sprintf("%s-%s", event.Kind, obj.Identifier()), nil + } + ) + + for _, events := range eventsByLevel { + for _, event := range events { + key, err := keyForEvent(event) + if err != nil { + return err + } + if _, ok := alreadyDeleted[key]; ok { + continue + } + if err := sc.queueEvent(event); err != nil { + return err + } + alreadyDeleted[key] = struct{}{} + } + + // Wait for all the deletes to finish before moving to the next level to avoid conflicts. + sc.wait() + } + + return nil +} + +func (sc *Syncer) delete() error { + for _, types := range reverseOrder() { + for _, entityType := range types { + err := sc.entityDiffers[entityType].Deletes(sc.queueEvent) + if err != nil { + return err + } + sc.wait() + } + } + return nil +} + +func (sc *Syncer) createUpdate() error { + for _, types := range order() { + for _, entityType := range types { + err := sc.entityDiffers[entityType].CreateAndUpdates(sc.queueEvent) + if err != nil { + return err + } + sc.wait() + } + } + return nil +} + +func (sc *Syncer) queueEvent(e crud.Event) error { + atomic.AddInt32(&sc.inFlightOps, 1) + select { + case sc.eventChan <- e: + return nil + case <-sc.stopChan: + atomic.AddInt32(&sc.inFlightOps, -1) + return errEnqueueFailed + } +} + +func (sc *Syncer) eventCompleted() { + atomic.AddInt32(&sc.inFlightOps, -1) +} + +func (sc *Syncer) wait() { + time.Sleep(time.Duration(sc.stageDelaySec) * time.Second) + for atomic.LoadInt32(&sc.inFlightOps) != 0 { + select { + case <-sc.stopChan: + return + default: + time.Sleep(1 * time.Millisecond) + } + } +} + +// Run starts a diff and invokes d for every diff. +func (sc *Syncer) Run(ctx context.Context, parallelism int, d Do) []error { + if parallelism < 1 { + return append([]error{}, fmt.Errorf("parallelism can not be negative")) + } + + var wg sync.WaitGroup + const eventBuffer = 10 + + sc.eventChan = make(chan crud.Event, eventBuffer) + sc.stopChan = make(chan struct{}) + sc.errChan = make(chan error) + + // run rabbit run + // start the consumers + wg.Add(parallelism) + for i := 0; i < parallelism; i++ { + go func() { + err := sc.eventLoop(ctx, d) + if err != nil { + sc.errChan <- err + } + wg.Done() + }() + } + + // start the producer + wg.Add(1) + go func() { + err := sc.diff() + if err != nil { + sc.errChan <- err + } + close(sc.eventChan) + wg.Done() + }() + + // close the error chan once all done + go func() { + wg.Wait() + close(sc.errChan) + }() + + var errs []error + select { + case <-ctx.Done(): + errs = append(errs, fmt.Errorf("failed to sync all entities: %w", ctx.Err())) + case err, ok := <-sc.errChan: + if ok && err != nil { + if !errors.Is(err, errEnqueueFailed) { + errs = append(errs, err) + } + } + } + + // stop the producer + close(sc.stopChan) + + // collect errors + for err := range sc.errChan { + if !errors.Is(err, errEnqueueFailed) { + errs = append(errs, err) + } + } + + return errs +} + +// Do is the worker function to sync the diff +type Do func(a crud.Event) (crud.Arg, error) + +func (sc *Syncer) eventLoop(ctx context.Context, d Do) error { + for event := range sc.eventChan { + // Stop if program is terminated + select { + case <-sc.stopChan: + return nil + default: + } + + err := sc.handleEvent(ctx, d, event) + sc.eventCompleted() + if err != nil { + return err + } + } + return nil +} + +func (sc *Syncer) handleEvent(ctx context.Context, d Do, event crud.Event) error { + err := backoff.Retry(func() error { + res, err := d(event) + if err != nil { + err = fmt.Errorf("while processing event: %w", err) + + var kongAPIError *kong.APIError + if errors.As(err, &kongAPIError) && + kongAPIError.Code() == http.StatusInternalServerError { + // Only retry if the request to Kong returned a 500 status code + return err + } + + // Do not retry on other status codes + return backoff.Permanent(err) + } + if res == nil { + // Do not retry empty responses + return backoff.Permanent(fmt.Errorf("result of event is nil")) + } + _, err = sc.postProcessor.Do(ctx, event.Kind, event.Op, res) + if err != nil { + // Do not retry program errors + return backoff.Permanent(fmt.Errorf("while post processing event: %w", err)) + } + return nil + }, defaultBackOff()) + + return err +} + +// Stats holds the stats related to a Solve. +type Stats struct { + CreateOps *utils.AtomicInt32Counter + UpdateOps *utils.AtomicInt32Counter + DeleteOps *utils.AtomicInt32Counter +} + +// Generete Diff output for 'sync' and 'diff' commands +func generateDiffString(e crud.Event, isDelete bool, noMaskValues bool) (string, error) { + var diffString string + var err error + if oldObj, ok := e.OldObj.(*state.Document); ok { + if !isDelete { + diffString, err = getDocumentDiff(oldObj, e.Obj.(*state.Document)) + } else { + diffString, err = getDocumentDiff(e.Obj.(*state.Document), oldObj) + } + } else { + if !isDelete { + diffString, err = getDiff(e.OldObj, e.Obj) + } else { + diffString, err = getDiff(e.Obj, e.OldObj) + } + } + if err != nil { + return "", err + } + if !noMaskValues { + diffString = MaskEnvVarValue(diffString) + } + return diffString, err +} + +// Solve generates a diff and walks the graph. +func (sc *Syncer) Solve(ctx context.Context, parallelism int, dry bool, isJSONOut bool) (Stats, + []error, EntityChanges, +) { + stats := Stats{ + CreateOps: &utils.AtomicInt32Counter{}, + UpdateOps: &utils.AtomicInt32Counter{}, + DeleteOps: &utils.AtomicInt32Counter{}, + } + recordOp := func(op crud.Op) { + switch op { + case crud.Create: + stats.CreateOps.Increment(1) + case crud.Update: + stats.UpdateOps.Increment(1) + case crud.Delete: + stats.DeleteOps.Increment(1) + } + } + + output := EntityChanges{ + Creating: []EntityState{}, + Updating: []EntityState{}, + Deleting: []EntityState{}, + } + + errs := sc.Run(ctx, parallelism, func(e crud.Event) (crud.Arg, error) { + var err error + var result crud.Arg + + c := e.Obj.(state.ConsoleString) + objDiff := map[string]interface{}{ + "old": e.OldObj, + "new": e.Obj, + } + item := EntityState{ + Body: objDiff, + Name: c.Console(), + Kind: string(e.Kind), + } + switch e.Op { + case crud.Create: + if isJSONOut { + output.Creating = append(output.Creating, item) + } else { + sc.createPrintln("creating", e.Kind, c.Console()) + } + case crud.Update: + diffString, err := generateDiffString(e, false, sc.noMaskValues) + if err != nil { + return nil, err + } + if isJSONOut { + output.Updating = append(output.Updating, item) + } else { + sc.updatePrintln("updating", e.Kind, c.Console(), diffString) + } + case crud.Delete: + if isJSONOut { + output.Deleting = append(output.Deleting, item) + } else { + sc.deletePrintln("deleting", e.Kind, c.Console()) + } + default: + panic("unknown operation " + e.Op.String()) + } + + if !dry { + // sync mode + // fire the request to Kong + result, err = sc.processor.Do(ctx, e.Kind, e.Op, e) + if err != nil { + return nil, fmt.Errorf("%v %v %v failed: %w", e.Op, e.Kind, c.Console(), err) + } + } else { + // diff mode + // return the new obj as is but with timestamps zeroed out + utils.ZeroOutTimestamps(e.Obj) + utils.ZeroOutTimestamps(e.OldObj) + result = e.Obj + } + // record operation in both: diff and sync commands + recordOp(e.Op) + + return result, nil + }) + return stats, errs, output +} diff --git a/pkg/diff/diff_helpers.go b/pkg/diff/diff_helpers.go new file mode 100644 index 0000000..9dbdf4c --- /dev/null +++ b/pkg/diff/diff_helpers.go @@ -0,0 +1,136 @@ +package diff + +import ( + "encoding/json" + "fmt" + "os" + "sort" + "strings" + + "github.com/Kong/gojsondiff" + "github.com/Kong/gojsondiff/formatter" + "github.com/hexops/gotextdiff" + "github.com/hexops/gotextdiff/myers" + "github.com/hexops/gotextdiff/span" + "github.com/kong/deck/state" +) + +var differ = gojsondiff.New() + +func getDocumentDiff(a, b *state.Document) (string, error) { + aCopy := a.ShallowCopy() + bCopy := a.ShallowCopy() + aContent := *a.Content + bContent := *b.Content + aCopy.Content = nil + bCopy.Content = nil + objDiff, err := getDiff(aCopy, bCopy) + if err != nil { + return "", err + } + var contentDiff string + if json.Valid([]byte(aContent)) && json.Valid([]byte(bContent)) { + aContent, err = prettyPrintJSONString(aContent) + if err != nil { + return "", err + } + bContent, err = prettyPrintJSONString(bContent) + if err != nil { + return "", err + } + } + edits := myers.ComputeEdits(span.URIFromPath("old"), aContent, bContent) + contentDiff = fmt.Sprint(gotextdiff.ToUnified("old", "new", aContent, edits)) + + return objDiff + contentDiff, nil +} + +func prettyPrintJSONString(JSONString string) (string, error) { + jBlob := []byte(JSONString) + var obj interface{} + err := json.Unmarshal(jBlob, &obj) + if err != nil { + return "", err + } + bytes, err := json.MarshalIndent(obj, "", "\t") + if err != nil { + return "", err + } + return string(bytes), nil +} + +func getDiff(a, b interface{}) (string, error) { + aJSON, err := json.Marshal(a) + if err != nil { + return "", err + } + bJSON, err := json.Marshal(b) + if err != nil { + return "", err + } + + // remove timestamps from JSON data without modifying the original data + aJSON = removeTimestamps(aJSON) + bJSON = removeTimestamps(bJSON) + + d, err := differ.Compare(aJSON, bJSON) + if err != nil { + return "", err + } + var leftObject map[string]interface{} + err = json.Unmarshal(aJSON, &leftObject) + if err != nil { + return "", err + } + + formatter := formatter.NewAsciiFormatter(leftObject, + formatter.AsciiFormatterConfig{}) + diffString, err := formatter.Format(d) + return diffString, err +} + +func removeTimestamps(jsonData []byte) []byte { + var dataMap map[string]interface{} + if err := json.Unmarshal(jsonData, &dataMap); err != nil { + return jsonData + } + delete(dataMap, "created_at") + delete(dataMap, "updated_at") + modifiedJSON, err := json.Marshal(dataMap) + if err != nil { + return jsonData + } + return modifiedJSON +} + +type EnvVar struct { + Key string + Value string +} + +func parseDeckEnvVars() []EnvVar { + const envVarPrefix = "DECK_" + var parsedEnvVars []EnvVar + + for _, envVarStr := range os.Environ() { + envPair := strings.SplitN(envVarStr, "=", 2) + if strings.HasPrefix(envPair[0], envVarPrefix) { + envVar := EnvVar{} + envVar.Key = envPair[0] + envVar.Value = envPair[1] + parsedEnvVars = append(parsedEnvVars, envVar) + } + } + + sort.Slice(parsedEnvVars, func(i, j int) bool { + return len(parsedEnvVars[i].Value) > len(parsedEnvVars[j].Value) + }) + return parsedEnvVars +} + +func MaskEnvVarValue(diffString string) string { + for _, envVar := range parseDeckEnvVars() { + diffString = strings.Replace(diffString, envVar.Value, "[masked]", -1) + } + return diffString +} diff --git a/pkg/diff/diff_helpers_test.go b/pkg/diff/diff_helpers_test.go new file mode 100644 index 0000000..82f26c1 --- /dev/null +++ b/pkg/diff/diff_helpers_test.go @@ -0,0 +1,176 @@ +package diff + +import ( + "testing" + + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +func Test_PrettyPrintJSONString(t *testing.T) { + type args struct { + jstring string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "basic JSON string", + args: args{ + jstring: `{"foo":"foo","bar":{"a": 1, "b": 2}}`, + }, + want: `{ + "bar": { + "a": 1, + "b": 2 + }, + "foo": "foo" +}`, + wantErr: false, + }, + { + name: "invalid JSON string", + args: args{ + jstring: "a large swarm of bees", + }, + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := prettyPrintJSONString(tt.args.jstring) + if (err != nil) != tt.wantErr { + t.Errorf("prettyPrintJSONString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("prettyPrintJSONString() = %v\nwant %v", got, tt.want) + } + }) + } +} + +func Test_GetDocumentDiff(t *testing.T) { + type args struct { + docA *state.Document + docB *state.Document + } + tests := []struct { + name string + args args + want string + }{ + { + name: "JSON", + args: args{ + docA: &state.Document{ + Document: konnect.Document{ + Path: kong.String("foo"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("abc"), + }, + Content: kong.String(`{"foo":"foo","bar":"bar"}`), + }, + }, + docB: &state.Document{ + Document: konnect.Document{ + Path: kong.String("foo"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("abc"), + }, + Content: kong.String(`{"foo":"foo","bar":"bar","baz":"baz"}`), + }, + }, + }, + want: ` { + "path": "foo" + } +--- old ++++ new +@@ -1,4 +1,5 @@ + { + "bar": "bar", ++ "baz": "baz", + "foo": "foo" + } +\ No newline at end of file +`, + }, + { + name: "not JSON", + args: args{ + docA: &state.Document{ + Document: konnect.Document{ + Path: kong.String("foo"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("abc"), + }, + Content: kong.String(`foo +`), + }, + }, + docB: &state.Document{ + Document: konnect.Document{ + Path: kong.String("foo"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("abc"), + }, + Content: kong.String(`foo +bar +`), + }, + }, + }, + want: ` { + "path": "foo" + } +--- old ++++ new +@@ -1 +1,2 @@ + foo ++bar +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, _ := getDocumentDiff(tt.args.docA, tt.args.docB); got != tt.want { + t.Errorf("getDocumentDiff() = %v\nwant %v", got, tt.want) + } + }) + } +} + +func Test_MaskEnvVarsValues(t *testing.T) { + tests := []struct { + name string + args string + want string + envVars map[string]string + }{ + { + name: "JSON", + envVars: map[string]string{ + "DECK_BAR": "barbar", + "DECK_BAZ": "bazbaz", + }, + args: `{"foo":"foo","bar":"barbar","baz":"bazbaz"}`, + want: `{"foo":"foo","bar":"[masked]","baz":"[masked]"}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.envVars { + t.Setenv(k, v) + } + if got := MaskEnvVarValue(tt.args); got != tt.want { + t.Errorf("maskEnvVarValue() = %v\nwant %v", got, tt.want) + } + }) + } +} diff --git a/pkg/diff/order.go b/pkg/diff/order.go new file mode 100644 index 0000000..064ed5d --- /dev/null +++ b/pkg/diff/order.go @@ -0,0 +1,131 @@ +package diff + +import ( + "github.com/kong/deck/crud" + "github.com/kong/deck/types" +) + +/* + Root + | + +----------+----------+---------+------------+---------------+ + | | | | | | + v v v v v v +L1 Service RbacRole Upstream Certificate CACertificate Consumer ---+ + Package | | | | | | | + | v v v | v v | +L2 | RBACRole Target SNI +-> Service Credentials | + | Endpoint | | (7) | + | | | | + | | | | +L3 +---------------------------> Service <---+ +-> Route | + | Version | | | + | | | | | + | | | v | +L4 +----------> Document <---------+ +-> Plugins <----------+ +*/ + +// dependencyOrder defines the order in which entities will be synced by decK. +// Entities at the same level are processed concurrently. +// Entities at level n will only be processed after all entities at level n-1 +// have been processed. +// The processing order for create and update stage is top-down while that +// for delete stage is bottom-up. +var dependencyOrder = [][]types.EntityType{ + { + types.ServicePackage, + types.RBACRole, + types.Certificate, + types.CACertificate, + types.Consumer, + types.Vault, + }, + { + types.ConsumerGroup, + types.RBACEndpointPermission, + types.SNI, + types.Service, + types.Upstream, + + types.KeyAuth, types.HMACAuth, types.JWTAuth, + types.BasicAuth, types.OAuth2Cred, types.ACLGroup, + types.MTLSAuth, + }, + { + types.ServiceVersion, + types.Route, + types.Target, + types.ConsumerGroupConsumer, + types.ConsumerGroupPlugin, + }, + { + types.Plugin, + types.Document, + }, +} + +func order() [][]types.EntityType { + return deepCopy(dependencyOrder) +} + +func reverseOrder() [][]types.EntityType { + order := deepCopy(dependencyOrder) + return reverse(order) +} + +func reverse(src [][]types.EntityType) [][]types.EntityType { + src = deepCopy(src) + i := 0 + j := len(src) - 1 + for i < j { + temp := src[i] + src[i] = src[j] + src[j] = temp + i++ + j-- + } + return src +} + +func deepCopy(src [][]types.EntityType) [][]types.EntityType { + res := make([][]types.EntityType, len(src)) + for i := range src { + res[i] = make([]types.EntityType, len(src[i])) + copy(res[i], src[i]) + } + return res +} + +func eventsInOrder(events []crud.Event, order [][]types.EntityType) [][]crud.Event { + // kindToLevel maps a Kind to its level in the order to avoid repeated lookups. + kindToLevel := make(map[crud.Kind]int) + + // eventsByLevel is a slice of slices of events, where each slice of events is at the same level and can be + // processed concurrently. + eventsByLevel := make([][]crud.Event, len(order)) + + for _, event := range events { + level, ok := kindToLevel[event.Kind] + if !ok { + level = levelForEvent(event, order) + kindToLevel[event.Kind] = level + } + + eventsByLevel[level] = append(eventsByLevel[level], event) + } + + return eventsByLevel +} + +func levelForEvent(event crud.Event, order [][]types.EntityType) int { + for i, level := range order { + for _, entityType := range level { + if event.Kind == crud.Kind(entityType) { + return i + } + } + } + + // This should never happen. + return -1 +} diff --git a/pkg/diff/order_test.go b/pkg/diff/order_test.go new file mode 100644 index 0000000..a923981 --- /dev/null +++ b/pkg/diff/order_test.go @@ -0,0 +1,91 @@ +package diff + +import ( + "reflect" + "testing" + + "github.com/kong/deck/crud" + "github.com/kong/deck/types" + "github.com/stretchr/testify/require" +) + +func Test_reverse(t *testing.T) { + type args struct { + src [][]types.EntityType + } + tests := []struct { + name string + args args + want [][]types.EntityType + }{ + { + name: "doesn't panic on empty slice", + args: args{ + src: nil, + }, + want: [][]types.EntityType{}, + }, + { + name: "doesn't panic on empty slice", + args: args{ + src: [][]types.EntityType{ + {"foo"}, + {"bar"}, + {"baz", "fubar"}, + }, + }, + want: [][]types.EntityType{ + {"baz", "fubar"}, + {"bar"}, + {"foo"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := reverse(tt.args.src); !reflect.DeepEqual(got, tt.want) { + t.Errorf("reverse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEventsInOrder(t *testing.T) { + e := func(entityType types.EntityType) crud.Event { + return crud.Event{Kind: crud.Kind(entityType)} + } + + eventsOutOfOrder := []crud.Event{ + e(types.Consumer), + e(types.Service), + e(types.KeyAuth), + e(types.Route), + e(types.ServicePackage), + e(types.ConsumerGroup), + e(types.ServiceVersion), + e(types.Plugin), + } + + order := reverseOrder() + result := eventsInOrder(eventsOutOfOrder, order) + + require.Equal(t, [][]crud.Event{ + { + e(types.Plugin), + }, + { + e(types.Route), + e(types.ServiceVersion), + }, + { + e(types.Service), + e(types.KeyAuth), + e(types.ConsumerGroup), + }, + { + e(types.Consumer), + e(types.ServicePackage), + }, + }, result) +} diff --git a/pkg/dump/dump.go b/pkg/dump/dump.go new file mode 100644 index 0000000..eeb59d4 --- /dev/null +++ b/pkg/dump/dump.go @@ -0,0 +1,888 @@ +package dump + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" + "golang.org/x/sync/errgroup" +) + +// Config can be used to skip exporting certain entities +type Config struct { + // If true, only RBAC resources are exported. + // SkipConsumers and SelectorTags should be falsy when this is set. + RBACResourcesOnly bool + + // If true, consumers and any plugins associated with it + // are not exported. + SkipConsumers bool + + // If true, CA certificates are not exported. + SkipCACerts bool + + // SelectorTags can be used to export entities tagged with only specific + // tags. + SelectorTags []string + + // KonnectControlPlane + KonnectControlPlane string + + // IsConsumerGroupScopedPluginSupported + IsConsumerGroupScopedPluginSupported bool +} + +func deduplicate(stringSlice []string) []string { + existing := map[string]struct{}{} + result := []string{} + + for _, s := range stringSlice { + if _, exists := existing[s]; !exists { + existing[s] = struct{}{} + result = append(result, s) + } + } + + return result +} + +func newOpt(tags []string) *kong.ListOpt { + opt := new(kong.ListOpt) + opt.Size = 1000 + opt.Tags = kong.StringSlice(deduplicate(tags)...) + opt.MatchAllTags = true + return opt +} + +func validateConfig(config Config) error { + if config.RBACResourcesOnly { + if config.SkipConsumers { + return fmt.Errorf("dump: config: SkipConsumer cannot be set when RBACResourcesOnly is set") + } + if len(config.SelectorTags) != 0 { + return fmt.Errorf("dump: config: SelectorTags cannot be set when RBACResourcesOnly is set") + } + } + return nil +} + +func getConsumerGroupsConfiguration(ctx context.Context, group *errgroup.Group, + client *kong.Client, config Config, state *utils.KongRawState, +) { + group.Go(func() error { + consumerGroups, err := GetAllConsumerGroups(ctx, client, config.SelectorTags) + if err != nil { + if kong.IsNotFoundErr(err) || kong.IsForbiddenErr(err) { + return nil + } + return fmt.Errorf("consumer_groups: %w", err) + } + state.ConsumerGroups = consumerGroups + return nil + }) +} + +func getConsumerConfiguration(ctx context.Context, group *errgroup.Group, + client *kong.Client, config Config, state *utils.KongRawState, +) { + group.Go(func() error { + consumers, err := GetAllConsumers(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("consumers: %w", err) + } + state.Consumers = consumers + return nil + }) + + group.Go(func() error { + keyAuths, err := GetAllKeyAuths(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("key-auths: %w", err) + } + state.KeyAuths = keyAuths + return nil + }) + + group.Go(func() error { + hmacAuths, err := GetAllHMACAuths(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("hmac-auths: %w", err) + } + state.HMACAuths = hmacAuths + return nil + }) + + group.Go(func() error { + jwtAuths, err := GetAllJWTAuths(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("jwts: %w", err) + } + state.JWTAuths = jwtAuths + return nil + }) + + group.Go(func() error { + basicAuths, err := GetAllBasicAuths(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("basic-auths: %w", err) + } + state.BasicAuths = basicAuths + return nil + }) + + group.Go(func() error { + oauth2Creds, err := GetAllOauth2Creds(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("oauth2: %w", err) + } + state.Oauth2Creds = oauth2Creds + return nil + }) + + group.Go(func() error { + aclGroups, err := GetAllACLGroups(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("acls: %w", err) + } + state.ACLGroups = aclGroups + return nil + }) + + group.Go(func() error { + // XXX Select-tags based filtering is not performed for mTLS-auth credentials + // because of the following problems: + // - We currently do not already tag these credentials, filtering these + // credentials with tags will break any existing user + // - this is not a big issue since only mTLS-auth credentials for tagged + // consumers are exported anyway + // This feature would only benefit a user who uses tagged consumers but + // then managed mtls-auth credentials out-of-band. We expect such users + // to be rare or non-existent. + mtlsAuths, err := GetAllMTLSAuths(ctx, client, nil) + if err != nil { + return fmt.Errorf("mtls-auths: %w", err) + } + state.MTLSAuths = mtlsAuths + return nil + }) +} + +func getProxyConfiguration(ctx context.Context, group *errgroup.Group, + client *kong.Client, config Config, state *utils.KongRawState, +) { + group.Go(func() error { + services, err := GetAllServices(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("services: %w", err) + } + state.Services = services + return nil + }) + + group.Go(func() error { + routes, err := GetAllRoutes(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("routes: %w", err) + } + state.Routes = routes + return nil + }) + + group.Go(func() error { + plugins, err := GetAllPlugins(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("plugins: %w", err) + } + plugins = excludeKonnectManagedPlugins(plugins) + if config.SkipConsumers { + plugins = excludeConsumersPlugins(plugins) + plugins = excludeConsumerGroupsPlugins(plugins) + } + state.Plugins = plugins + return nil + }) + + group.Go(func() error { + certificates, err := GetAllCertificates(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("certificates: %w", err) + } + state.Certificates = certificates + return nil + }) + + if !config.SkipCACerts { + group.Go(func() error { + caCerts, err := GetAllCACertificates(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("ca-certificates: %w", err) + } + state.CACertificates = caCerts + return nil + }) + } + + group.Go(func() error { + snis, err := GetAllSNIs(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("snis: %w", err) + } + state.SNIs = snis + return nil + }) + + group.Go(func() error { + upstreams, err := GetAllUpstreams(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("upstreams: %w", err) + } + state.Upstreams = upstreams + targets, err := GetAllTargets(ctx, client, upstreams, config.SelectorTags) + if err != nil { + return fmt.Errorf("targets: %w", err) + } + state.Targets = targets + return nil + }) + + group.Go(func() error { + vaults, err := GetAllVaults(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("vaults: %w", err) + } + state.Vaults = vaults + return nil + }) +} + +func getEnterpriseRBACConfiguration(ctx context.Context, group *errgroup.Group, + client *kong.Client, state *utils.KongRawState, +) { + group.Go(func() error { + roles, err := GetAllRBACRoles(ctx, client) + if err != nil { + return fmt.Errorf("roles: %w", err) + } + state.RBACRoles = roles + return nil + }) + + group.Go(func() error { + eps, err := GetAllRBACREndpointPermissions(ctx, client) + if err != nil { + return fmt.Errorf("eps: %w", err) + } + state.RBACEndpointPermissions = eps + return nil + }) +} + +// Get queries all the entities using client and returns +// all the entities in KongRawState. +func Get(ctx context.Context, client *kong.Client, config Config) (*utils.KongRawState, error) { + var state utils.KongRawState + + if err := validateConfig(config); err != nil { + return nil, err + } + + group, ctx := errgroup.WithContext(ctx) + + // dump only rbac resources + if config.RBACResourcesOnly { + getEnterpriseRBACConfiguration(ctx, group, client, &state) + } else { + // regular case + getProxyConfiguration(ctx, group, client, config, &state) + if !config.SkipConsumers { + getConsumerGroupsConfiguration(ctx, group, client, config, &state) + getConsumerConfiguration(ctx, group, client, config, &state) + } + } + + err := group.Wait() + if err != nil { + return nil, err + } + + return &state, nil +} + +// GetAllServices queries Kong for all the services using client. +func GetAllServices(ctx context.Context, client *kong.Client, + tags []string, +) ([]*kong.Service, error) { + var services []*kong.Service + opt := newOpt(tags) + + for { + s, nextopt, err := client.Services.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + services = append(services, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return services, nil +} + +// GetAllRoutes queries Kong for all the routes using client. +func GetAllRoutes(ctx context.Context, client *kong.Client, + tags []string, +) ([]*kong.Route, error) { + var routes []*kong.Route + opt := newOpt(tags) + + for { + s, nextopt, err := client.Routes.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + routes = append(routes, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return routes, nil +} + +// GetAllPlugins queries Kong for all the plugins using client. +func GetAllPlugins(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.Plugin, error) { + var plugins []*kong.Plugin + opt := newOpt(tags) + + for { + s, nextopt, err := client.Plugins.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + plugins = append(plugins, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return plugins, nil +} + +// GetAllCertificates queries Kong for all the certificates using client. +func GetAllCertificates(ctx context.Context, client *kong.Client, + tags []string, +) ([]*kong.Certificate, error) { + var certificates []*kong.Certificate + opt := newOpt(tags) + + for { + s, nextopt, err := client.Certificates.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + for _, cert := range s { + c := cert + c.SNIs = nil + certificates = append(certificates, cert) + } + if nextopt == nil { + break + } + opt = nextopt + } + return certificates, nil +} + +// GetAllCACertificates queries Kong for all the CACertificates using client. +func GetAllCACertificates(ctx context.Context, + client *kong.Client, + tags []string, +) ([]*kong.CACertificate, error) { + var caCertificates []*kong.CACertificate + opt := newOpt(tags) + + for { + s, nextopt, err := client.CACertificates.List(ctx, opt) + // Compatibility for Kong < 1.3 + // This core entitiy was not present in the past + // and the Admin API request will error with 404 Not Found + // If we do get the error, we return back an empty array of + // CACertificates, effectively disabling the entity for versions + // which don't have it. + // A better solution would be to have a version check, and based + // on the version, the entities are loaded and synced. + if err != nil { + if kong.IsNotFoundErr(err) { + return caCertificates, nil + } + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + caCertificates = append(caCertificates, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return caCertificates, nil +} + +// GetAllSNIs queries Kong for all the SNIs using client. +func GetAllSNIs(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.SNI, error) { + var snis []*kong.SNI + opt := newOpt(tags) + + for { + s, nextopt, err := client.SNIs.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + snis = append(snis, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return snis, nil +} + +// GetAllConsumers queries Kong for all the consumers using client. +// Please use this method with caution if you have a lot of consumers. +func GetAllConsumers(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.Consumer, error) { + var consumers []*kong.Consumer + opt := newOpt(tags) + + for { + s, nextopt, err := client.Consumers.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + consumers = append(consumers, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return consumers, nil +} + +// GetAllUpstreams queries Kong for all the Upstreams using client. +func GetAllUpstreams(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.Upstream, error) { + var upstreams []*kong.Upstream + opt := newOpt(tags) + + for { + s, nextopt, err := client.Upstreams.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + upstreams = append(upstreams, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return upstreams, nil +} + +// GetAllConsumerGroups queries Kong for all the ConsumerGroups using client. +func GetAllConsumerGroups(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.ConsumerGroupObject, error) { + var consumerGroupObjects []*kong.ConsumerGroupObject + opt := newOpt(tags) + + for { + cgs, nextopt, err := client.ConsumerGroups.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + + for _, cg := range cgs { + r, err := client.ConsumerGroups.Get(ctx, cg.Name) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + group := &kong.ConsumerGroupObject{ + ConsumerGroup: r.ConsumerGroup, + Consumers: r.Consumers, + Plugins: r.Plugins, + } + consumerGroupObjects = append(consumerGroupObjects, group) + } + if nextopt == nil { + break + } + opt = nextopt + } + return consumerGroupObjects, nil +} + +// GetAllTargets queries Kong for all the Targets of upstreams using client. +// Targets are queries per upstream as there exists no endpoint in Kong +// to list all targets of all upstreams. +func GetAllTargets(ctx context.Context, client *kong.Client, + upstreams []*kong.Upstream, tags []string, +) ([]*kong.Target, error) { + var targets []*kong.Target + opt := newOpt(tags) + + for _, upstream := range upstreams { + for { + t, nextopt, err := client.Targets.List(ctx, upstream.ID, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + targets = append(targets, t...) + if nextopt == nil { + break + } + opt = nextopt + } + } + + return targets, nil +} + +// GetAllVaults queries Kong for all the Vaults using client. +func GetAllVaults( + ctx context.Context, client *kong.Client, tags []string, +) ([]*kong.Vault, error) { + var vaults []*kong.Vault + opt := newOpt(tags) + + for { + s, nextopt, err := client.Vaults.List(ctx, opt) + if kong.IsNotFoundErr(err) || kong.IsForbiddenErr(err) { + return vaults, nil + } + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + vaults = append(vaults, s...) + if nextopt == nil { + break + } + opt = nextopt + } + + return vaults, nil +} + +// GetAllKeyAuths queries Kong for all key-auth credentials using client. +func GetAllKeyAuths(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.KeyAuth, error) { + var keyAuths []*kong.KeyAuth + opt := newOpt(tags) + + for { + s, nextopt, err := client.KeyAuths.List(ctx, opt) + if kong.IsNotFoundErr(err) { + return keyAuths, nil + } + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + keyAuths = append(keyAuths, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return keyAuths, nil +} + +// GetAllHMACAuths queries Kong for all hmac-auth credentials using client. +func GetAllHMACAuths(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.HMACAuth, error) { + var hmacAuths []*kong.HMACAuth + opt := newOpt(tags) + + for { + s, nextopt, err := client.HMACAuths.List(ctx, opt) + if kong.IsNotFoundErr(err) { + return hmacAuths, nil + } + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + hmacAuths = append(hmacAuths, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return hmacAuths, nil +} + +// GetAllJWTAuths queries Kong for all jwt credentials using client. +func GetAllJWTAuths(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.JWTAuth, error) { + var jwtAuths []*kong.JWTAuth + opt := newOpt(tags) + + for { + s, nextopt, err := client.JWTAuths.List(ctx, opt) + if kong.IsNotFoundErr(err) { + return jwtAuths, nil + } + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + jwtAuths = append(jwtAuths, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return jwtAuths, nil +} + +// GetAllBasicAuths queries Kong for all basic-auth credentials using client. +func GetAllBasicAuths(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.BasicAuth, error) { + var basicAuths []*kong.BasicAuth + opt := newOpt(tags) + + for { + s, nextopt, err := client.BasicAuths.List(ctx, opt) + if kong.IsNotFoundErr(err) { + return basicAuths, nil + } + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + basicAuths = append(basicAuths, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return basicAuths, nil +} + +// GetAllOauth2Creds queries Kong for all oauth2 credentials using client. +func GetAllOauth2Creds(ctx context.Context, client *kong.Client, + tags []string, +) ([]*kong.Oauth2Credential, error) { + var oauth2Creds []*kong.Oauth2Credential + opt := newOpt(tags) + + for { + s, nextopt, err := client.Oauth2Credentials.List(ctx, opt) + if kong.IsNotFoundErr(err) { + return oauth2Creds, nil + } + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + oauth2Creds = append(oauth2Creds, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return oauth2Creds, nil +} + +// GetAllACLGroups queries Kong for all ACL groups using client. +func GetAllACLGroups(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.ACLGroup, error) { + var aclGroups []*kong.ACLGroup + opt := newOpt(tags) + + for { + s, nextopt, err := client.ACLs.List(ctx, opt) + if kong.IsNotFoundErr(err) { + return aclGroups, nil + } + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + aclGroups = append(aclGroups, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return aclGroups, nil +} + +// GetAllMTLSAuths queries Kong for all basic-auth credentials using client. +func GetAllMTLSAuths(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.MTLSAuth, error) { + var mtlsAuths []*kong.MTLSAuth + opt := newOpt(tags) + + for { + s, nextopt, err := client.MTLSAuths.List(ctx, opt) + if kong.IsNotFoundErr(err) { + return mtlsAuths, nil + } + if err != nil { + // TODO figure out a better way to handle unauthorized endpoints + // per https://github.com/Kong/deck/issues/274 we can't dump these resources + // from an Enterprise instance running in free mode, and the 403 results in a + // fatal error when running "deck dump". We don't want to just treat 403s the + // same as 404s because Kong also uses them to indicate missing RBAC permissions, + // but this is currently necessary for compatibility. We need a better approach + // before adding other Enterprise resources that decK handles by default (versus, + // for example, RBAC roles, which require the --rbac-resources-only flag). + var kongErr *kong.APIError + if errors.As(err, &kongErr) { + if kongErr.Code() == http.StatusForbidden { + return mtlsAuths, nil + } + } + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + mtlsAuths = append(mtlsAuths, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return mtlsAuths, nil +} + +// GetAllRBACRoles queries Kong for all the RBACRoles using client. +func GetAllRBACRoles(ctx context.Context, + client *kong.Client, +) ([]*kong.RBACRole, error) { + roles, err := client.RBACRoles.ListAll(ctx) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + + return roles, nil +} + +func GetAllRBACREndpointPermissions(ctx context.Context, + client *kong.Client, +) ([]*kong.RBACEndpointPermission, error) { + eps := []*kong.RBACEndpointPermission{} + roles, err := client.RBACRoles.ListAll(ctx) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + // retrieve all permissions for the role + for _, r := range roles { + reps, err := client.RBACEndpointPermissions.ListAllForRole(ctx, r.ID) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + eps = append(eps, reps...) + } + + return eps, nil +} + +// excludeConsumersPlugins filter out consumer plugins +func excludeConsumersPlugins(plugins []*kong.Plugin) []*kong.Plugin { + var filtered []*kong.Plugin + for _, p := range plugins { + if p.Consumer != nil && !utils.Empty(p.Consumer.ID) { + continue + } + filtered = append(filtered, p) + } + return filtered +} + +// excludeConsumerGroupsPlugins filter out consumer-groups plugins +func excludeConsumerGroupsPlugins(plugins []*kong.Plugin) []*kong.Plugin { + var filtered []*kong.Plugin + for _, p := range plugins { + if p.ConsumerGroup != nil && !utils.Empty(p.ConsumerGroup.ID) { + continue + } + filtered = append(filtered, p) + } + return filtered +} diff --git a/pkg/dump/dump_konnect.go b/pkg/dump/dump_konnect.go new file mode 100644 index 0000000..3e8b50d --- /dev/null +++ b/pkg/dump/dump_konnect.go @@ -0,0 +1,195 @@ +package dump + +import ( + "context" + "fmt" + "sync" + + "github.com/kong/deck/konnect" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" +) + +type KonnectConfig struct { + // ID of the Kong Control Plane being managed. + ControlPlaneID string +} + +func GetFromKonnect(ctx context.Context, konnectClient *konnect.Client, + config KonnectConfig, +) (*utils.KonnectRawState, error) { + var res utils.KonnectRawState + var servicePackages []*konnect.ServicePackage + var relations []*konnect.ControlPlaneServiceRelation + + group, ctx := errgroup.WithContext(ctx) + // group1 fetches service packages and their versions + group.Go(func() error { + var err error + // fetch service packages + servicePackages, err = konnectClient.ServicePackages.ListAll(ctx) + if err != nil { + return err + } + + // versions of service packages are fetched concurrently + errChan := make(chan error) + var err2 error + + m := &sync.Mutex{} + m.Lock() + go func() { + defer m.Unlock() + // only the last error matters + for err := range errChan { + err2 = err + } + }() + + const concurrency = 10 + + semaphore := semaphore.NewWeighted(concurrency) + for i := 0; i < len(servicePackages); i++ { + // control the number of outstanding go routines, also controlling + // the number of parallel requests + err := semaphore.Acquire(ctx, 2) + if err != nil { + return fmt.Errorf("acquire semaphore: %w", err) + } + go func(i int) { + defer semaphore.Release(1) + versions, err := konnectClient.ServiceVersions.ListForPackage(ctx, servicePackages[i].ID) + if err != nil { + errChan <- err + return + } + servicePackages[i].Versions = versions + }(i) + go func(i int) { + defer semaphore.Release(1) + documents, err := konnectClient.Documents.ListAllForParent(ctx, servicePackages[i]) + if err != nil { + errChan <- err + return + } + res.Documents = append(res.Documents, documents...) + }(i) + } + for i := 0; i < len(servicePackages); i++ { + for _, version := range servicePackages[i].Versions { + err := semaphore.Acquire(ctx, 1) + if err != nil { + return fmt.Errorf("acquire semaphore: %w", err) + } + go func(version konnect.ServiceVersion) { + defer semaphore.Release(1) + documents, err := konnectClient.Documents.ListAllForParent(ctx, &version) + if err != nil { + errChan <- err + return + } + res.Documents = append(res.Documents, documents...) + }(version) + } + } + err = semaphore.Acquire(ctx, concurrency) + if err != nil { + return fmt.Errorf("acquire semaphore: %w", err) + } + close(errChan) + semaphore.Release(concurrency) + m.Lock() + defer m.Unlock() + if err2 != nil { + return err2 + } + return nil + }) + + // group2 fetches CP-service relations + group.Go(func() error { + var err error + relations, err = konnectClient.ControlPlaneRelations.ListAll(ctx) + return err + }) + + err := group.Wait() + if err != nil { + return nil, err + } + + res.ServicePackages = filterNonKongPackages(config.ControlPlaneID, + servicePackages, relations) + return &res, nil +} + +func filterNonKongPackages(controlPlaneID string, packages []*konnect.ServicePackage, + relations []*konnect.ControlPlaneServiceRelation, +) []*konnect.ServicePackage { + kongServiceIDs := kongServiceIDs(controlPlaneID, relations) + var res []*konnect.ServicePackage + for _, p := range packages { + // if a package has no versions, decK will manage it + switch len(p.Versions) { + case 0: + res = append(res, p) + default: + // decK will manage two types of versions: + // - either versions that don't have any implementation + // - versions which have a Kong Service as an implementation + pCopy := p.DeepCopy() + pCopy.Versions = nil + for _, v := range p.Versions { + if v.ControlPlaneServiceRelation == nil { + pCopy.Versions = append(pCopy.Versions, v) + } else if !utils.Empty(v.ControlPlaneServiceRelation.ControlPlaneEntityID) && + kongServiceIDs[*v.ControlPlaneServiceRelation.ControlPlaneEntityID] { + pCopy.Versions = append(pCopy.Versions, v) + } + } + // manage only if at least one version satisfies the above criteria + if len(pCopy.Versions) >= 1 { + res = append(res, pCopy) + } + } + } + return res +} + +func kongServiceIDs(cpID string, + relations []*konnect.ControlPlaneServiceRelation, +) map[string]bool { + res := map[string]bool{} + for _, relation := range relations { + if !utils.Empty(relation.ControlPlaneEntityID) && + relation.ControlPlane != nil && + !utils.Empty(relation.ControlPlane.ID) && + cpID == *relation.ControlPlane.ID { + res[*relation.ControlPlaneEntityID] = true + } + } + return res +} + +// excludeKonnectManagedPlugins filters out konnect-managed plugins +func excludeKonnectManagedPlugins(plugins []*kong.Plugin) []*kong.Plugin { + filtered := []*kong.Plugin{} + for _, p := range plugins { + if isManagedByKonnect(p) { + continue + } + filtered = append(filtered, p) + } + return filtered +} + +func isManagedByKonnect(plugin *kong.Plugin) bool { + for _, t := range plugin.Tags { + if *t == konnect.KonnectManagedPluginTag { + return true + } + } + return false +} diff --git a/pkg/dump/dump_konnect_test.go b/pkg/dump/dump_konnect_test.go new file mode 100644 index 0000000..ad609f2 --- /dev/null +++ b/pkg/dump/dump_konnect_test.go @@ -0,0 +1,373 @@ +package dump + +import ( + "reflect" + "testing" + + "github.com/kong/deck/konnect" + "github.com/kong/go-kong/kong" +) + +func Test_kongServiceIDs(t *testing.T) { + type args struct { + cpID string + relations []*konnect.ControlPlaneServiceRelation + } + tests := []struct { + name string + args args + want map[string]bool + }{ + { + name: "returns services belonging to the same control plane", + args: args{ + cpID: "cp1", + relations: []*konnect.ControlPlaneServiceRelation{ + { + ID: kong.String("id1"), + ControlPlaneEntityID: kong.String("kong-svc-1"), + ControlPlane: &konnect.ControlPlane{ + ID: kong.String("cp1"), + }, + }, + { + ID: kong.String("id2"), + ControlPlaneEntityID: kong.String("kong-svc-2"), + ControlPlane: &konnect.ControlPlane{ + ID: kong.String("cp1"), + }, + }, + }, + }, + want: map[string]bool{ + "kong-svc-1": true, + "kong-svc-2": true, + }, + }, + { + name: "doesn't panic if relation.ControlPlaneEntityID is nil", + args: args{ + cpID: "cp1", + relations: []*konnect.ControlPlaneServiceRelation{ + { + ID: kong.String("id1"), + ControlPlane: &konnect.ControlPlane{ + ID: kong.String("cp2"), + }, + }, + }, + }, + want: map[string]bool{}, + }, + { + name: "doesn't include a service belonging to a different control plane", + args: args{ + cpID: "cp1", + relations: []*konnect.ControlPlaneServiceRelation{ + { + ID: kong.String("id1"), + ControlPlaneEntityID: kong.String("kong-svc-1"), + ControlPlane: &konnect.ControlPlane{ + ID: kong.String("cp2"), + }, + }, + }, + }, + want: map[string]bool{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := kongServiceIDs(tt.args.cpID, tt.args.relations) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("kongServiceIDs() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_filterNonKongPackages(t *testing.T) { + type args struct { + controlPlaneID string + packages []*konnect.ServicePackage + relations []*konnect.ControlPlaneServiceRelation + } + tests := []struct { + name string + args args + want []*konnect.ServicePackage + }{ + { + name: "empty packages and relations returns nil", + args: args{ + controlPlaneID: "cp1", + packages: []*konnect.ServicePackage{}, + relations: []*konnect.ControlPlaneServiceRelation{}, + }, + want: nil, + }, + { + name: "package with no versions is returned in output", + args: args{ + controlPlaneID: "cp1", + packages: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + }, + }, + }, + want: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + }, + }, + }, + { + name: "package with version that belong to a different control-plane is not included in output", + args: args{ + controlPlaneID: "cp1", + packages: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + Versions: []konnect.ServiceVersion{ + { + ID: kong.String("sv-id1"), + Version: kong.String("sv-v1"), + ControlPlaneServiceRelation: &konnect.ControlPlaneServiceRelation{ + ControlPlaneEntityID: kong.String("kong-svc-1"), + }, + }, + }, + }, + }, + relations: []*konnect.ControlPlaneServiceRelation{ + { + ID: kong.String("id1"), + ControlPlaneEntityID: kong.String("kong-svc-1"), + ControlPlane: &konnect.ControlPlane{ + ID: kong.String("cp2"), + }, + }, + }, + }, + want: nil, + }, + { + name: "package with version that belong to same control-plane is included in output", + args: args{ + controlPlaneID: "cp1", + packages: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + Versions: []konnect.ServiceVersion{ + { + ID: kong.String("sv-id1"), + Version: kong.String("sv-v1"), + ControlPlaneServiceRelation: &konnect.ControlPlaneServiceRelation{ + ControlPlaneEntityID: kong.String("kong-svc-1"), + }, + }, + }, + }, + }, + relations: []*konnect.ControlPlaneServiceRelation{ + { + ID: kong.String("id1"), + ControlPlaneEntityID: kong.String("kong-svc-1"), + ControlPlane: &konnect.ControlPlane{ + ID: kong.String("cp1"), + }, + }, + }, + }, + want: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + Versions: []konnect.ServiceVersion{ + { + ID: kong.String("sv-id1"), + Version: kong.String("sv-v1"), + ControlPlaneServiceRelation: &konnect.ControlPlaneServiceRelation{ + ControlPlaneEntityID: kong.String("kong-svc-1"), + }, + }, + }, + }, + }, + }, + { + name: "package with versions without any implementations is not included", + args: args{ + controlPlaneID: "cp1", + packages: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + Versions: []konnect.ServiceVersion{ + { + ID: kong.String("sv-id1"), + Version: kong.String("sv-v1"), + }, + { + ID: kong.String("sv-id2"), + Version: kong.String("sv-v2"), + }, + }, + }, + }, + relations: []*konnect.ControlPlaneServiceRelation{}, + }, + want: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + Versions: []konnect.ServiceVersion{ + { + ID: kong.String("sv-id1"), + Version: kong.String("sv-v1"), + }, + { + ID: kong.String("sv-id2"), + Version: kong.String("sv-v2"), + }, + }, + }, + }, + }, + { + name: "package with version's implementation absent from relations is not included", + args: args{ + controlPlaneID: "cp1", + packages: []*konnect.ServicePackage{ + { + ID: kong.String("sp-id1"), + Name: kong.String("sp-name1"), + Versions: []konnect.ServiceVersion{ + { + ID: kong.String("sv-id1"), + Version: kong.String("sv-v1"), + ControlPlaneServiceRelation: &konnect.ControlPlaneServiceRelation{ + ControlPlaneEntityID: kong.String("kong-svc-1"), + }, + }, + }, + }, + }, + relations: []*konnect.ControlPlaneServiceRelation{ + { + ID: kong.String("id1"), + ControlPlaneEntityID: kong.String("kong-svc-42"), + ControlPlane: &konnect.ControlPlane{ + ID: kong.String("cp1"), + }, + }, + }, + }, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := filterNonKongPackages(tt.args.controlPlaneID, tt.args.packages, tt.args.relations) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("filterNonKongPackages() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_excludeKonnectManagedPlugins(t *testing.T) { + tests := []struct { + name string + plugins []*kong.Plugin + want []*kong.Plugin + }{ + { + name: "exclude konnect tags", + plugins: []*kong.Plugin{ + { + Name: kong.String("rate-limiting"), + Tags: []*string{kong.String("tag1")}, + }, + { + Name: kong.String("basic-auth"), + Tags: []*string{}, + }, + { + Name: kong.String("key-auth"), + Tags: []*string{ + kong.String("konnect-app-registration"), + kong.String("konnect-managed-plugin"), + }, + }, + { + Name: kong.String("acl"), + Tags: []*string{ + kong.String("konnect-app-registration"), + kong.String("konnect-managed-plugin"), + }, + }, + { + Name: kong.String("prometheus"), + Tags: []*string{ + kong.String("konnect-managed-plugin"), + }, + }, + }, + want: []*kong.Plugin{ + { + Name: kong.String("rate-limiting"), + Tags: []*string{kong.String("tag1")}, + }, + { + Name: kong.String("basic-auth"), + Tags: []*string{}, + }, + }, + }, + { + name: "empty input", + plugins: []*kong.Plugin{}, + want: []*kong.Plugin{}, + }, + { + name: "all konnect managed", + plugins: []*kong.Plugin{ + { + Name: kong.String("key-auth"), + Tags: []*string{ + kong.String("konnect-app-registration"), + kong.String("konnect-managed-plugin"), + }, + }, + { + Name: kong.String("acl"), + Tags: []*string{ + kong.String("konnect-app-registration"), + kong.String("konnect-managed-plugin"), + }, + }, + { + Name: kong.String("prometheus"), + Tags: []*string{ + kong.String("konnect-managed-plugin"), + }, + }, + }, + want: []*kong.Plugin{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := excludeKonnectManagedPlugins(tt.plugins) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("excludeKonnectManagedPlugins() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/dump/dump_test.go b/pkg/dump/dump_test.go new file mode 100644 index 0000000..e322baf --- /dev/null +++ b/pkg/dump/dump_test.go @@ -0,0 +1,63 @@ +package dump + +import ( + "testing" +) + +func Test_validateConfig(t *testing.T) { + type args struct { + config Config + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid config for RBAC resources", + args: args{ + config: Config{ + RBACResourcesOnly: true, + }, + }, + wantErr: false, + }, + { + name: "valid config for proxy resources", + args: args{ + config: Config{ + SkipConsumers: true, + SelectorTags: []string{"foo", "bar"}, + }, + }, + wantErr: false, + }, + { + name: "invalid config mixing RBAC and selector tags", + args: args{ + config: Config{ + SelectorTags: []string{"foo", "bar"}, + RBACResourcesOnly: true, + }, + }, + wantErr: true, + }, + { + name: "invalid config mixing RBAC and SkipConsumers", + args: args{ + config: Config{ + SkipConsumers: true, + RBACResourcesOnly: true, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validateConfig(tt.args.config); (err != nil) != tt.wantErr { + t.Errorf("validateConfig() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/file/builder.go b/pkg/file/builder.go new file mode 100644 index 0000000..f70cc79 --- /dev/null +++ b/pkg/file/builder.go @@ -0,0 +1,1235 @@ +package file + +import ( + "context" + "errors" + "fmt" + + "github.com/blang/semver/v4" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +const ratelimitingAdvancedPluginName = "rate-limiting-advanced" + +type stateBuilder struct { + targetContent *Content + rawState *utils.KongRawState + konnectRawState *utils.KonnectRawState + currentState *state.KongState + defaulter *utils.Defaulter + kongVersion semver.Version + + selectTags []string + skipCACerts bool + intermediate *state.KongState + + client *kong.Client + ctx context.Context + + schemasCache map[string]map[string]interface{} + + disableDynamicDefaults bool + + isKonnect bool + + checkRoutePaths bool + + isConsumerGroupScopedPluginSupported bool + + err error +} + +// uuid generates a UUID string and returns a pointer to it. +// It is a variable for testing purpose, to override and supply +// a deterministic UUID generator. +var uuid = func() *string { + return kong.String(utils.UUID()) +} + +var ErrWorkspaceNotFound = fmt.Errorf("workspace not found") + +func (b *stateBuilder) build() (*utils.KongRawState, *utils.KonnectRawState, error) { + // setup + var err error + b.rawState = &utils.KongRawState{} + b.konnectRawState = &utils.KonnectRawState{} + b.schemasCache = make(map[string]map[string]interface{}) + + b.intermediate, err = state.NewKongState() + if err != nil { + return nil, nil, err + } + + defaulter, err := defaulter(b.ctx, b.client, b.targetContent, b.disableDynamicDefaults, b.isKonnect) + if err != nil { + return nil, nil, err + } + b.defaulter = defaulter + + if utils.Kong300Version.LTE(b.kongVersion) { + b.checkRoutePaths = true + } + + if utils.Kong340Version.LTE(b.kongVersion) || b.isKonnect { + b.isConsumerGroupScopedPluginSupported = true + } + + // build + b.certificates() + if !b.skipCACerts { + b.caCertificates() + } + b.services() + b.routes() + b.upstreams() + b.consumerGroups() + b.consumers() + b.plugins() + b.enterprise() + + // konnect + b.konnect() + + // result + if b.err != nil { + return nil, nil, b.err + } + return b.rawState, b.konnectRawState, nil +} + +func (b *stateBuilder) ingestConsumerGroupScopedPlugins(cg FConsumerGroupObject) error { + var plugins []FPlugin + for _, plugin := range cg.Plugins { + plugin.ConsumerGroup = utils.GetConsumerGroupReference(cg.ConsumerGroup) + plugins = append(plugins, FPlugin{ + Plugin: kong.Plugin{ + ID: plugin.ID, + Name: plugin.Name, + Config: plugin.Config, + ConsumerGroup: &kong.ConsumerGroup{ + ID: cg.ID, + }, + }, + }) + } + return b.ingestPlugins(plugins) +} + +func (b *stateBuilder) addConsumerGroupPlugins( + cg FConsumerGroupObject, cgo *kong.ConsumerGroupObject, +) error { + for _, plugin := range cg.Plugins { + if utils.Empty(plugin.ID) { + current, err := b.currentState.ConsumerGroupPlugins.Get( + *plugin.Name, *cg.ConsumerGroup.ID, + ) + if errors.Is(err, state.ErrNotFound) { + plugin.ID = uuid() + } else if err != nil { + return err + } else { + plugin.ID = kong.String(*current.ID) + } + } + b.defaulter.MustSet(plugin) + cgo.Plugins = append(cgo.Plugins, plugin) + } + return nil +} + +func (b *stateBuilder) consumerGroups() { + if b.err != nil { + return + } + + for _, cg := range b.targetContent.ConsumerGroups { + cg := cg + current, err := b.currentState.ConsumerGroups.Get(*cg.Name) + if utils.Empty(cg.ID) { + if errors.Is(err, state.ErrNotFound) { + cg.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + cg.ID = kong.String(*current.ID) + } + } + utils.MustMergeTags(&cg.ConsumerGroup, b.selectTags) + + cgo := kong.ConsumerGroupObject{ + ConsumerGroup: &cg.ConsumerGroup, + } + + err = b.intermediate.ConsumerGroups.Add(state.ConsumerGroup{ConsumerGroup: cg.ConsumerGroup}) + if err != nil { + b.err = err + return + } + + // Plugins and Consumer Groups can be handled in two ways: + // 1. directly in the ConsumerGroup object + // 2. by scoping the plugin to the ConsumerGroup (Kong >= 3.4.0) + // + // The first method is deprecated and will be removed in the future, but + // we still need to support it for now. The isConsumerGroupScopedPluginSupported + // flag is used to determine which method to use based on the Kong version. + if b.isConsumerGroupScopedPluginSupported { + if err := b.ingestConsumerGroupScopedPlugins(cg); err != nil { + b.err = err + return + } + } else { + if err := b.addConsumerGroupPlugins(cg, &cgo); err != nil { + b.err = err + return + } + } + if current != nil { + cgo.ConsumerGroup.CreatedAt = current.CreatedAt + } + b.rawState.ConsumerGroups = append(b.rawState.ConsumerGroups, &cgo) + } +} + +func (b *stateBuilder) certificates() { + if b.err != nil { + return + } + + for i := range b.targetContent.Certificates { + c := b.targetContent.Certificates[i] + if utils.Empty(c.ID) { + cert, err := b.currentState.Certificates.GetByCertKey(*c.Cert, + *c.Key) + if errors.Is(err, state.ErrNotFound) { + c.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + c.ID = kong.String(*cert.ID) + } + } + utils.MustMergeTags(&c, b.selectTags) + + snisFromCert := c.SNIs + + kongCert := kong.Certificate{ + ID: c.ID, + Key: c.Key, + Cert: c.Cert, + Tags: c.Tags, + CreatedAt: c.CreatedAt, + } + b.rawState.Certificates = append(b.rawState.Certificates, &kongCert) + + // snis associated with the certificate + var snis []kong.SNI + for _, sni := range snisFromCert { + sni.Certificate = &kong.Certificate{ID: kong.String(*c.ID)} + snis = append(snis, sni) + } + if err := b.ingestSNIs(snis); err != nil { + b.err = err + return + } + } +} + +func (b *stateBuilder) ingestSNIs(snis []kong.SNI) error { + for _, sni := range snis { + sni := sni + currentSNI, err := b.currentState.SNIs.Get(*sni.Name) + if utils.Empty(sni.ID) { + if errors.Is(err, state.ErrNotFound) { + sni.ID = uuid() + } else if err != nil { + return err + } else { + sni.ID = kong.String(*currentSNI.ID) + } + } + utils.MustMergeTags(&sni, b.selectTags) + if currentSNI != nil { + sni.CreatedAt = currentSNI.CreatedAt + } + b.rawState.SNIs = append(b.rawState.SNIs, &sni) + } + return nil +} + +func (b *stateBuilder) caCertificates() { + if b.err != nil { + return + } + + for _, c := range b.targetContent.CACertificates { + c := c + cert, err := b.currentState.CACertificates.Get(*c.Cert) + if utils.Empty(c.ID) { + if errors.Is(err, state.ErrNotFound) { + c.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + c.ID = kong.String(*cert.ID) + } + } + utils.MustMergeTags(&c.CACertificate, b.selectTags) + if cert != nil { + c.CACertificate.CreatedAt = cert.CreatedAt + } + + b.rawState.CACertificates = append(b.rawState.CACertificates, + &c.CACertificate) + } +} + +func (b *stateBuilder) consumers() { + if b.err != nil { + return + } + + for _, c := range b.targetContent.Consumers { + c := c + + var ( + consumer *state.Consumer + err error + ) + if c.Username != nil { + consumer, err = b.currentState.Consumers.GetByIDOrUsername(*c.Username) + } + if errors.Is(err, state.ErrNotFound) || consumer == nil { + if c.CustomID != nil { + consumer, err = b.currentState.Consumers.GetByCustomID(*c.CustomID) + } + } + + if utils.Empty(c.ID) { + if errors.Is(err, state.ErrNotFound) { + c.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + c.ID = kong.String(*consumer.ID) + } + } + utils.MustMergeTags(&c.Consumer, b.selectTags) + if consumer != nil { + c.Consumer.CreatedAt = consumer.CreatedAt + } + b.rawState.Consumers = append(b.rawState.Consumers, &c.Consumer) + err = b.intermediate.Consumers.Add(state.Consumer{Consumer: c.Consumer}) + if err != nil { + b.err = err + return + } + + // ingest consumer into consumer group + if err := b.ingestIntoConsumerGroup(c); err != nil { + b.err = err + return + } + + // plugins for the Consumer + var plugins []FPlugin + for _, p := range c.Plugins { + p.Consumer = utils.GetConsumerReference(c.Consumer) + plugins = append(plugins, *p) + } + if err := b.ingestPlugins(plugins); err != nil { + b.err = err + return + } + + var keyAuths []kong.KeyAuth + for _, cred := range c.KeyAuths { + cred.Consumer = utils.GetConsumerReference(c.Consumer) + keyAuths = append(keyAuths, *cred) + } + if err := b.ingestKeyAuths(keyAuths); err != nil { + b.err = err + return + } + + var basicAuths []kong.BasicAuth + for _, cred := range c.BasicAuths { + cred.Consumer = utils.GetConsumerReference(c.Consumer) + basicAuths = append(basicAuths, *cred) + } + if err := b.ingestBasicAuths(basicAuths); err != nil { + b.err = err + return + } + + var hmacAuths []kong.HMACAuth + for _, cred := range c.HMACAuths { + cred.Consumer = utils.GetConsumerReference(c.Consumer) + hmacAuths = append(hmacAuths, *cred) + } + if err := b.ingestHMACAuths(hmacAuths); err != nil { + b.err = err + return + } + + var jwtAuths []kong.JWTAuth + for _, cred := range c.JWTAuths { + cred.Consumer = utils.GetConsumerReference(c.Consumer) + jwtAuths = append(jwtAuths, *cred) + } + if err := b.ingestJWTAuths(jwtAuths); err != nil { + b.err = err + return + } + + var oauth2Creds []kong.Oauth2Credential + for _, cred := range c.Oauth2Creds { + cred.Consumer = utils.GetConsumerReference(c.Consumer) + oauth2Creds = append(oauth2Creds, *cred) + } + if err := b.ingestOauth2Creds(oauth2Creds); err != nil { + b.err = err + return + } + + var aclGroups []kong.ACLGroup + for _, cred := range c.ACLGroups { + cred.Consumer = utils.GetConsumerReference(c.Consumer) + aclGroups = append(aclGroups, *cred) + } + if err := b.ingestACLGroups(aclGroups); err != nil { + b.err = err + return + } + + var mtlsAuths []kong.MTLSAuth + for _, cred := range c.MTLSAuths { + cred.Consumer = utils.GetConsumerReference(c.Consumer) + mtlsAuths = append(mtlsAuths, *cred) + } + + b.ingestMTLSAuths(mtlsAuths) + } +} + +func (b *stateBuilder) ingestIntoConsumerGroup(consumer FConsumer) error { + for _, group := range consumer.Groups { + found := false + for _, cg := range b.rawState.ConsumerGroups { + if group.ID != nil && *cg.ConsumerGroup.ID == *group.ID { + cg.Consumers = append(cg.Consumers, &consumer.Consumer) + found = true + break + + } + if group.Name != nil && *cg.ConsumerGroup.Name == *group.Name { + cg.Consumers = append(cg.Consumers, &consumer.Consumer) + found = true + break + } + } + if !found { + var groupIdentifier string + if group.Name != nil { + groupIdentifier = *group.Name + } else { + groupIdentifier = *group.ID + } + return fmt.Errorf( + "consumer-group '%s' not found for consumer '%s'", groupIdentifier, *consumer.ID, + ) + } + } + return nil +} + +func (b *stateBuilder) ingestKeyAuths(creds []kong.KeyAuth) error { + for _, cred := range creds { + cred := cred + existingCred, err := b.currentState.KeyAuths.Get(*cred.Key) + if utils.Empty(cred.ID) { + if errors.Is(err, state.ErrNotFound) { + cred.ID = uuid() + } else if err != nil { + return err + } else { + cred.ID = kong.String(*existingCred.ID) + } + } + if b.kongVersion.GTE(utils.Kong140Version) { + utils.MustMergeTags(&cred, b.selectTags) + } + if existingCred != nil { + cred.CreatedAt = existingCred.CreatedAt + } + b.rawState.KeyAuths = append(b.rawState.KeyAuths, &cred) + } + return nil +} + +func (b *stateBuilder) ingestBasicAuths(creds []kong.BasicAuth) error { + for _, cred := range creds { + cred := cred + existingCred, err := b.currentState.BasicAuths.Get(*cred.Username) + if utils.Empty(cred.ID) { + if errors.Is(err, state.ErrNotFound) { + cred.ID = uuid() + } else if err != nil { + return err + } else { + cred.ID = kong.String(*existingCred.ID) + } + } + if b.kongVersion.GTE(utils.Kong140Version) { + utils.MustMergeTags(&cred, b.selectTags) + } + if existingCred != nil { + cred.CreatedAt = existingCred.CreatedAt + } + b.rawState.BasicAuths = append(b.rawState.BasicAuths, &cred) + } + return nil +} + +func (b *stateBuilder) ingestHMACAuths(creds []kong.HMACAuth) error { + for _, cred := range creds { + cred := cred + existingCred, err := b.currentState.HMACAuths.Get(*cred.Username) + if utils.Empty(cred.ID) { + if errors.Is(err, state.ErrNotFound) { + cred.ID = uuid() + } else if err != nil { + return err + } else { + cred.ID = kong.String(*existingCred.ID) + } + } + if b.kongVersion.GTE(utils.Kong140Version) { + utils.MustMergeTags(&cred, b.selectTags) + } + if existingCred != nil { + cred.CreatedAt = existingCred.CreatedAt + } + b.rawState.HMACAuths = append(b.rawState.HMACAuths, &cred) + } + return nil +} + +func (b *stateBuilder) ingestJWTAuths(creds []kong.JWTAuth) error { + for _, cred := range creds { + cred := cred + existingCred, err := b.currentState.JWTAuths.Get(*cred.Key) + if utils.Empty(cred.ID) { + if errors.Is(err, state.ErrNotFound) { + cred.ID = uuid() + } else if err != nil { + return err + } else { + cred.ID = kong.String(*existingCred.ID) + } + } + if b.kongVersion.GTE(utils.Kong140Version) { + utils.MustMergeTags(&cred, b.selectTags) + } + if existingCred != nil { + cred.CreatedAt = existingCred.CreatedAt + } + b.rawState.JWTAuths = append(b.rawState.JWTAuths, &cred) + } + return nil +} + +func (b *stateBuilder) ingestOauth2Creds(creds []kong.Oauth2Credential) error { + for _, cred := range creds { + cred := cred + existingCred, err := b.currentState.Oauth2Creds.Get(*cred.ClientID) + if utils.Empty(cred.ID) { + if errors.Is(err, state.ErrNotFound) { + cred.ID = uuid() + } else if err != nil { + return err + } else { + cred.ID = kong.String(*existingCred.ID) + } + } + if b.kongVersion.GTE(utils.Kong140Version) { + utils.MustMergeTags(&cred, b.selectTags) + } + if existingCred != nil { + cred.CreatedAt = existingCred.CreatedAt + } + b.rawState.Oauth2Creds = append(b.rawState.Oauth2Creds, &cred) + } + return nil +} + +func (b *stateBuilder) ingestACLGroups(creds []kong.ACLGroup) error { + for _, cred := range creds { + cred := cred + if utils.Empty(cred.ID) { + existingCred, err := b.currentState.ACLGroups.Get( + *cred.Consumer.ID, + *cred.Group) + if errors.Is(err, state.ErrNotFound) { + cred.ID = uuid() + } else if err != nil { + return err + } else { + cred.ID = kong.String(*existingCred.ID) + } + } + if b.kongVersion.GTE(utils.Kong140Version) { + utils.MustMergeTags(&cred, b.selectTags) + } + b.rawState.ACLGroups = append(b.rawState.ACLGroups, &cred) + } + return nil +} + +func (b *stateBuilder) ingestMTLSAuths(creds []kong.MTLSAuth) { + kong230Version := semver.MustParse("2.3.0") + for _, cred := range creds { + cred := cred + // normally, we'd want to look up existing resources in this case + // however, this is impossible here: mtls-auth simply has no unique fields other than ID, + // so we don't--schema validation requires the ID + // there's nothing more to do here + + if b.kongVersion.GTE(kong230Version) { + utils.MustMergeTags(&cred, b.selectTags) + } + b.rawState.MTLSAuths = append(b.rawState.MTLSAuths, &cred) + } +} + +func (b *stateBuilder) konnect() { + if b.err != nil { + return + } + + for i := range b.targetContent.ServicePackages { + targetSP := b.targetContent.ServicePackages[i] + if utils.Empty(targetSP.ID) { + currentSP, err := b.currentState.ServicePackages.Get(*targetSP.Name) + if errors.Is(err, state.ErrNotFound) { + targetSP.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + targetSP.ID = kong.String(*currentSP.ID) + } + } + + targetKonnectSP := konnect.ServicePackage{ + ID: targetSP.ID, + Name: targetSP.Name, + Description: targetSP.Description, + } + + if targetSP.Document != nil { + targetKonnectDoc := konnect.Document{ + ID: targetSP.Document.ID, + Path: targetSP.Document.Path, + Published: targetSP.Document.Published, + Content: targetSP.Document.Content, + Parent: &targetKonnectSP, + } + if utils.Empty(targetKonnectDoc.ID) { + currentDoc, err := b.currentState.Documents.GetByParent(&targetKonnectSP, *targetKonnectDoc.Path) + if errors.Is(err, state.ErrNotFound) { + targetKonnectDoc.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + targetKonnectDoc.ID = kong.String(*currentDoc.ID) + } + } + b.konnectRawState.Documents = append(b.konnectRawState.Documents, &targetKonnectDoc) + } + + // versions associated with the package + for _, targetSV := range targetSP.Versions { + targetKonnectSV := konnect.ServiceVersion{ + ID: targetSV.ID, + Version: targetSV.Version, + } + targetRelationID := "" + if utils.Empty(targetKonnectSV.ID) { + currentSV, err := b.currentState.ServiceVersions.Get(*targetKonnectSP.ID, *targetKonnectSV.Version) + if errors.Is(err, state.ErrNotFound) { + targetKonnectSV.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + targetKonnectSV.ID = kong.String(*currentSV.ID) + if currentSV.ControlPlaneServiceRelation != nil { + targetRelationID = *currentSV.ControlPlaneServiceRelation.ID + } + } + } + if targetSV.Implementation != nil && + targetSV.Implementation.Kong != nil { + err := b.ingestService(targetSV.Implementation.Kong.Service) + if err != nil { + b.err = err + return + } + targetKonnectSV.ControlPlaneServiceRelation = &konnect.ControlPlaneServiceRelation{ + ControlPlaneEntityID: targetSV.Implementation.Kong.Service.ID, + } + if targetRelationID != "" { + targetKonnectSV.ControlPlaneServiceRelation.ID = &targetRelationID + } + } + if targetSV.Document != nil { + targetKonnectDoc := konnect.Document{ + ID: targetSV.Document.ID, + Path: targetSV.Document.Path, + Published: targetSV.Document.Published, + Content: targetSV.Document.Content, + Parent: &targetKonnectSV, + } + if utils.Empty(targetKonnectDoc.ID) { + currentDoc, err := b.currentState.Documents.GetByParent(&targetKonnectSV, *targetKonnectDoc.Path) + if errors.Is(err, state.ErrNotFound) { + targetKonnectDoc.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + targetKonnectDoc.ID = kong.String(*currentDoc.ID) + } + } + b.konnectRawState.Documents = append(b.konnectRawState.Documents, &targetKonnectDoc) + } + targetKonnectSP.Versions = append(targetKonnectSP.Versions, targetKonnectSV) + } + + b.konnectRawState.ServicePackages = append(b.konnectRawState.ServicePackages, + &targetKonnectSP) + } +} + +func (b *stateBuilder) services() { + if b.err != nil { + return + } + + for _, s := range b.targetContent.Services { + s := s + err := b.ingestService(&s) + if err != nil { + b.err = err + return + } + } +} + +func (b *stateBuilder) ingestService(s *FService) error { + var ( + svc *state.Service + err error + ) + if !utils.Empty(s.Name) { + svc, err = b.currentState.Services.Get(*s.Name) + } + if utils.Empty(s.ID) { + if errors.Is(err, state.ErrNotFound) { + s.ID = uuid() + } else if err != nil { + return err + } else { + s.ID = kong.String(*svc.ID) + } + } + utils.MustMergeTags(&s.Service, b.selectTags) + b.defaulter.MustSet(&s.Service) + if svc != nil { + s.Service.CreatedAt = svc.CreatedAt + } + b.rawState.Services = append(b.rawState.Services, &s.Service) + err = b.intermediate.Services.Add(state.Service{Service: s.Service}) + if err != nil { + return err + } + + // plugins for the service + var plugins []FPlugin + for _, p := range s.Plugins { + p.Service = utils.GetServiceReference(s.Service) + plugins = append(plugins, *p) + } + if err := b.ingestPlugins(plugins); err != nil { + return err + } + + // routes for the service + for _, r := range s.Routes { + r := r + r.Service = utils.GetServiceReference(s.Service) + if err := b.ingestRoute(*r); err != nil { + return err + } + } + return nil +} + +func (b *stateBuilder) routes() { + if b.err != nil { + return + } + + for _, r := range b.targetContent.Routes { + r := r + if err := b.ingestRoute(r); err != nil { + b.err = err + return + } + } + + // check routes' paths format + if b.checkRoutePaths { + unsupportedRoutes := []string{} + allRoutes, err := b.intermediate.Routes.GetAll() + if err != nil { + b.err = err + return + } + for _, r := range allRoutes { + if utils.HasPathsWithRegex300AndAbove(r.Route) { + unsupportedRoutes = append(unsupportedRoutes, *r.Route.ID) + } + } + if len(unsupportedRoutes) > 0 { + utils.PrintRouteRegexWarning(unsupportedRoutes) + } + } +} + +func (b *stateBuilder) enterprise() { + b.rbacRoles() + b.vaults() +} + +func (b *stateBuilder) vaults() { + if b.err != nil { + return + } + + for _, v := range b.targetContent.Vaults { + v := v + vault, err := b.currentState.Vaults.Get(*v.Prefix) + if utils.Empty(v.ID) { + if errors.Is(err, state.ErrNotFound) { + v.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + v.ID = kong.String(*vault.ID) + } + } + utils.MustMergeTags(&v.Vault, b.selectTags) + if vault != nil { + v.Vault.CreatedAt = vault.CreatedAt + } + + b.rawState.Vaults = append(b.rawState.Vaults, &v.Vault) + } +} + +func (b *stateBuilder) rbacRoles() { + if b.err != nil { + return + } + + for _, r := range b.targetContent.RBACRoles { + r := r + role, err := b.currentState.RBACRoles.Get(*r.Name) + if utils.Empty(r.ID) { + if errors.Is(err, state.ErrNotFound) { + r.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + r.ID = kong.String(*role.ID) + } + } + if role != nil { + r.RBACRole.CreatedAt = role.CreatedAt + } + b.rawState.RBACRoles = append(b.rawState.RBACRoles, &r.RBACRole) + // rbac endpoint permissions for the role + for _, ep := range r.EndpointPermissions { + ep := ep + ep.Role = &kong.RBACRole{ID: kong.String(*r.ID)} + b.rawState.RBACEndpointPermissions = append(b.rawState.RBACEndpointPermissions, &ep.RBACEndpointPermission) + } + } +} + +func (b *stateBuilder) upstreams() { + if b.err != nil { + return + } + + for _, u := range b.targetContent.Upstreams { + u := u + ups, err := b.currentState.Upstreams.Get(*u.Name) + if utils.Empty(u.ID) { + if errors.Is(err, state.ErrNotFound) { + u.ID = uuid() + } else if err != nil { + b.err = err + return + } else { + u.ID = kong.String(*ups.ID) + } + } + utils.MustMergeTags(&u.Upstream, b.selectTags) + b.defaulter.MustSet(&u.Upstream) + if ups != nil { + u.Upstream.CreatedAt = ups.CreatedAt + } + + b.rawState.Upstreams = append(b.rawState.Upstreams, &u.Upstream) + + // targets for the upstream + var targets []kong.Target + for _, t := range u.Targets { + t.Upstream = &kong.Upstream{ID: kong.String(*u.ID)} + targets = append(targets, t.Target) + } + if err := b.ingestTargets(targets); err != nil { + b.err = err + return + } + } +} + +func (b *stateBuilder) ingestTargets(targets []kong.Target) error { + for _, t := range targets { + t := t + if utils.Empty(t.ID) { + target, err := b.currentState.Targets.Get(*t.Upstream.ID, *t.Target) + if errors.Is(err, state.ErrNotFound) { + t.ID = uuid() + } else if err != nil { + return err + } else { + t.ID = kong.String(*target.ID) + } + } + utils.MustMergeTags(&t, b.selectTags) + b.defaulter.MustSet(&t) + b.rawState.Targets = append(b.rawState.Targets, &t) + } + return nil +} + +func (b *stateBuilder) plugins() { + if b.err != nil { + return + } + + var plugins []FPlugin + for _, p := range b.targetContent.Plugins { + p := p + if p.Consumer != nil && !utils.Empty(p.Consumer.ID) { + c, err := b.intermediate.Consumers.GetByIDOrUsername(*p.Consumer.ID) + if errors.Is(err, state.ErrNotFound) { + b.err = fmt.Errorf("consumer %v for plugin %v: %w", + p.Consumer.FriendlyName(), *p.Name, err) + + return + } else if err != nil { + b.err = err + return + } + p.Consumer = utils.GetConsumerReference(c.Consumer) + } + if p.Service != nil && !utils.Empty(p.Service.ID) { + s, err := b.intermediate.Services.Get(*p.Service.ID) + if errors.Is(err, state.ErrNotFound) { + b.err = fmt.Errorf("service %v for plugin %v: %w", + p.Service.FriendlyName(), *p.Name, err) + + return + } else if err != nil { + b.err = err + return + } + p.Service = utils.GetServiceReference(s.Service) + } + if p.Route != nil && !utils.Empty(p.Route.ID) { + r, err := b.intermediate.Routes.Get(*p.Route.ID) + if errors.Is(err, state.ErrNotFound) { + b.err = fmt.Errorf("route %v for plugin %v: %w", + p.Route.FriendlyName(), *p.Name, err) + + return + } else if err != nil { + b.err = err + return + } + p.Route = utils.GetRouteReference(r.Route) + } + if p.ConsumerGroup != nil && !utils.Empty(p.ConsumerGroup.ID) { + cg, err := b.intermediate.ConsumerGroups.Get(*p.ConsumerGroup.ID) + if errors.Is(err, state.ErrNotFound) { + b.err = fmt.Errorf("consumer-group %v for plugin %v: %w", + p.ConsumerGroup.FriendlyName(), *p.Name, err) + return + } else if err != nil { + b.err = err + return + } + p.ConsumerGroup = utils.GetConsumerGroupReference(cg.ConsumerGroup) + } + + if err := b.validatePlugin(p); err != nil { + b.err = err + return + } + plugins = append(plugins, p) + } + if err := b.ingestPlugins(plugins); err != nil { + b.err = err + return + } +} + +func (b *stateBuilder) validatePlugin(p FPlugin) error { + if b.isConsumerGroupScopedPluginSupported && *p.Name == ratelimitingAdvancedPluginName { + // check if deprecated consumer-groups configuration is present in the config + var consumerGroupsFound bool + if groups, ok := p.Config["consumer_groups"]; ok { + // if groups is an array of length > 0, then consumer_groups is set + if groupsArray, ok := groups.([]interface{}); ok && len(groupsArray) > 0 { + consumerGroupsFound = true + } + } + var enforceConsumerGroupsFound bool + if enforceConsumerGroups, ok := p.Config["enforce_consumer_groups"]; ok { + if enforceConsumerGroupsBool, ok := enforceConsumerGroups.(bool); ok && enforceConsumerGroupsBool { + enforceConsumerGroupsFound = true + } + } + if consumerGroupsFound || enforceConsumerGroupsFound { + return utils.ErrorConsumerGroupUpgrade + } + } + return nil +} + +// strip_path schema default value is 'true', but it cannot be set when +// protocols include 'grpc' and/or 'grpcs'. When users explicitly set +// strip_path to 'true' with grpc/s protocols, deck returns a schema violation error. +// When strip_path is not set and protocols include grpc/s, deck sets strip_path to 'false', +// despite its default value would be 'true' under normal circumstances. +func getStripPathBasedOnProtocols(route kong.Route) (*bool, error) { + for _, p := range route.Protocols { + if *p == "grpc" || *p == "grpcs" { + if route.StripPath != nil && *route.StripPath { + return nil, fmt.Errorf("schema violation (strip_path: cannot set " + + "'strip_path' when 'protocols' is 'grpc' or 'grpcs')") + } + return kong.Bool(false), nil + } + } + return route.StripPath, nil +} + +func (b *stateBuilder) ingestRoute(r FRoute) error { + var ( + route *state.Route + err error + ) + if !utils.Empty(r.Name) { + route, err = b.currentState.Routes.Get(*r.Name) + } + if utils.Empty(r.ID) { + if errors.Is(err, state.ErrNotFound) { + r.ID = uuid() + } else if err != nil { + return err + } else { + r.ID = kong.String(*route.ID) + } + } + + utils.MustMergeTags(&r, b.selectTags) + stripPath, err := getStripPathBasedOnProtocols(r.Route) + if err != nil { + return err + } + r.Route.StripPath = stripPath + b.defaulter.MustSet(&r.Route) + if route != nil { + r.Route.CreatedAt = route.CreatedAt + } + + b.rawState.Routes = append(b.rawState.Routes, &r.Route) + err = b.intermediate.Routes.Add(state.Route{Route: r.Route}) + if err != nil { + return err + } + + // plugins for the route + var plugins []FPlugin + for _, p := range r.Plugins { + p.Route = utils.GetRouteReference(r.Route) + plugins = append(plugins, *p) + } + if err := b.ingestPlugins(plugins); err != nil { + return err + } + if r.Service != nil && utils.Empty(r.Service.ID) && !utils.Empty(r.Service.Name) { + s, err := b.intermediate.Services.Get(*r.Service.Name) + if err != nil { + return fmt.Errorf("retrieve intermediate services (%s): %w", *r.Service.Name, err) + } + r.Service.ID = s.ID + r.Service.Name = nil + } + return nil +} + +func (b *stateBuilder) getPluginSchema(pluginName string) (map[string]interface{}, error) { + var schema map[string]interface{} + + // lookup in cache + if schema, ok := b.schemasCache[pluginName]; ok { + return schema, nil + } + + exists, err := utils.WorkspaceExists(b.ctx, b.client) + if err != nil { + return nil, fmt.Errorf("ensure workspace exists: %w", err) + } + if !exists { + return schema, ErrWorkspaceNotFound + } + + schema, err = b.client.Plugins.GetFullSchema(b.ctx, &pluginName) + if err != nil { + return schema, err + } + b.schemasCache[pluginName] = schema + return schema, nil +} + +func (b *stateBuilder) addPluginDefaults(plugin *FPlugin) error { + if b.client == nil { + return nil + } + schema, err := b.getPluginSchema(*plugin.Name) + if err != nil { + if errors.Is(err, ErrWorkspaceNotFound) { + return nil + } + return fmt.Errorf("retrieve schema for %v from Kong: %w", *plugin.Name, err) + } + return kong.FillPluginsDefaults(&plugin.Plugin, schema) +} + +func (b *stateBuilder) ingestPlugins(plugins []FPlugin) error { + for _, p := range plugins { + p := p + cID, rID, sID, cgID := pluginRelations(&p.Plugin) + plugin, err := b.currentState.Plugins.GetByProp(*p.Name, + sID, rID, cID, cgID) + if utils.Empty(p.ID) { + if errors.Is(err, state.ErrNotFound) { + p.ID = uuid() + } else if err != nil { + return err + } else { + p.ID = kong.String(*plugin.ID) + } + } + if p.Config == nil { + p.Config = make(map[string]interface{}) + } + p.Config = ensureJSON(p.Config) + err = b.fillPluginConfig(&p) + if err != nil { + return err + } + if err := b.addPluginDefaults(&p); err != nil { + return fmt.Errorf("add defaults to plugin '%v': %w", *p.Name, err) + } + utils.MustMergeTags(&p, b.selectTags) + if plugin != nil { + p.Plugin.CreatedAt = plugin.CreatedAt + } + b.rawState.Plugins = append(b.rawState.Plugins, &p.Plugin) + } + return nil +} + +func (b *stateBuilder) fillPluginConfig(plugin *FPlugin) error { + if plugin == nil { + return fmt.Errorf("plugin is nil") + } + if !utils.Empty(plugin.ConfigSource) { + conf, ok := b.targetContent.PluginConfigs[*plugin.ConfigSource] + if !ok { + return fmt.Errorf("_plugin_config %q not found", + *plugin.ConfigSource) + } + for k, v := range conf { + if _, ok := plugin.Config[k]; !ok { + plugin.Config[k] = v + } + } + } + return nil +} + +func pluginRelations(plugin *kong.Plugin) (cID, rID, sID, cgID string) { + if plugin.Consumer != nil && !utils.Empty(plugin.Consumer.ID) { + cID = *plugin.Consumer.ID + } + if plugin.Route != nil && !utils.Empty(plugin.Route.ID) { + rID = *plugin.Route.ID + } + if plugin.Service != nil && !utils.Empty(plugin.Service.ID) { + sID = *plugin.Service.ID + } + if plugin.ConsumerGroup != nil && !utils.Empty(plugin.ConsumerGroup.ID) { + cgID = *plugin.ConsumerGroup.ID + } + return +} + +func defaulter( + ctx context.Context, client *kong.Client, fileContent *Content, disableDynamicDefaults, isKonnect bool, +) (*utils.Defaulter, error) { + var kongDefaults KongDefaults + if fileContent.Info != nil { + kongDefaults = fileContent.Info.Defaults + } + opts := utils.DefaulterOpts{ + Client: client, + KongDefaults: kongDefaults, + DisableDynamicDefaults: disableDynamicDefaults, + IsKonnect: isKonnect, + } + defaulter, err := utils.GetDefaulter(ctx, opts) + if err != nil { + return nil, fmt.Errorf("creating defaulter: %w", err) + } + return defaulter, nil +} diff --git a/pkg/file/builder_test.go b/pkg/file/builder_test.go new file mode 100644 index 0000000..1b79c28 --- /dev/null +++ b/pkg/file/builder_test.go @@ -0,0 +1,2757 @@ +package file + +import ( + "context" + "encoding/hex" + "math/rand" + "os" + "reflect" + "testing" + + "github.com/blang/semver/v4" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +const ( + defaultTimeout = 60000 + defaultSlots = 10000 + defaultWeight = 100 + defaultConcurrency = 10 +) + +var kong130Version = semver.MustParse("1.3.0") + +var kongDefaults = KongDefaults{ + Service: &kong.Service{ + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(defaultTimeout), + WriteTimeout: kong.Int(defaultTimeout), + ReadTimeout: kong.Int(defaultTimeout), + }, + Route: &kong.Route{ + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + }, + Upstream: &kong.Upstream{ + Slots: kong.Int(defaultSlots), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(defaultConcurrency), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + Target: &kong.Target{ + Weight: kong.Int(defaultWeight), + }, +} + +var defaulterTestOpts = utils.DefaulterOpts{ + KongDefaults: kongDefaults, + DisableDynamicDefaults: false, +} + +func emptyState() *state.KongState { + s, _ := state.NewKongState() + return s +} + +func existingRouteState() *state.KongState { + s, _ := state.NewKongState() + s.Routes.Add(state.Route{ + Route: kong.Route{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + }, + }) + return s +} + +func existingServiceState() *state.KongState { + s, _ := state.NewKongState() + s.Services.Add(state.Service{ + Service: kong.Service{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + }, + }) + return s +} + +func existingConsumerCredState() *state.KongState { + s, _ := state.NewKongState() + s.Consumers.Add(state.Consumer{ + Consumer: kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }) + s.KeyAuths.Add(state.KeyAuth{ + KeyAuth: kong.KeyAuth{ + ID: kong.String("5f1ef1ea-a2a5-4a1b-adbb-b0d3434013e5"), + Key: kong.String("foo-apikey"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }) + s.BasicAuths.Add(state.BasicAuth{ + BasicAuth: kong.BasicAuth{ + ID: kong.String("92f4c849-960b-43af-aad3-f307051408d3"), + Username: kong.String("basic-username"), + Password: kong.String("basic-password"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }) + s.JWTAuths.Add(state.JWTAuth{ + JWTAuth: kong.JWTAuth{ + ID: kong.String("917b9402-1be0-49d2-b482-ca4dccc2054e"), + Key: kong.String("jwt-key"), + Secret: kong.String("jwt-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }) + s.HMACAuths.Add(state.HMACAuth{ + HMACAuth: kong.HMACAuth{ + ID: kong.String("e5d81b73-bf9e-42b0-9d68-30a1d791b9c9"), + Username: kong.String("hmac-username"), + Secret: kong.String("hmac-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }) + s.ACLGroups.Add(state.ACLGroup{ + ACLGroup: kong.ACLGroup{ + ID: kong.String("b7c9352a-775a-4ba5-9869-98e926a3e6cb"), + Group: kong.String("foo-group"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }) + s.Oauth2Creds.Add(state.Oauth2Credential{ + Oauth2Credential: kong.Oauth2Credential{ + ID: kong.String("4eef5285-3d6a-4f6b-b659-8957a940e2ca"), + ClientID: kong.String("oauth2-clientid"), + Name: kong.String("oauth2-name"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }) + s.MTLSAuths.Add(state.MTLSAuth{ + MTLSAuth: kong.MTLSAuth{ + ID: kong.String("92f4c829-968b-42af-afd3-f337051508d3"), + SubjectName: kong.String("test@example.com"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }) + return s +} + +func existingUpstreamState() *state.KongState { + s, _ := state.NewKongState() + s.Upstreams.Add(state.Upstream{ + Upstream: kong.Upstream{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + }, + }) + return s +} + +func existingCertificateState() *state.KongState { + s, _ := state.NewKongState() + s.Certificates.Add(state.Certificate{ + Certificate: kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }) + return s +} + +func existingCertificateAndSNIState() *state.KongState { + s, _ := state.NewKongState() + s.Certificates.Add(state.Certificate{ + Certificate: kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }) + s.SNIs.Add(state.SNI{ + SNI: kong.SNI{ + ID: kong.String("a53e9598-3a5e-4c12-a672-71a4cdcf7a47"), + Name: kong.String("foo.example.com"), + Certificate: &kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + }, + }, + }) + s.SNIs.Add(state.SNI{ + SNI: kong.SNI{ + ID: kong.String("5f8e6848-4cb9-479a-a27e-860e1a77f875"), + Name: kong.String("bar.example.com"), + Certificate: &kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + }, + }, + }) + return s +} + +func existingCACertificateState() *state.KongState { + s, _ := state.NewKongState() + s.CACertificates.Add(state.CACertificate{ + CACertificate: kong.CACertificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Cert: kong.String("foo"), + }, + }) + return s +} + +func existingPluginState() *state.KongState { + s, _ := state.NewKongState() + s.Plugins.Add(state.Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + }, + }) + s.Plugins.Add(state.Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("f7e64af5-e438-4a9b-8ff8-ec6f5f06dccb"), + Name: kong.String("bar"), + Consumer: &kong.Consumer{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + }, + }) + s.Plugins.Add(state.Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("53ce0a9c-d518-40ee-b8ab-1ee83a20d382"), + Name: kong.String("foo"), + Consumer: &kong.Consumer{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + Route: &kong.Route{ + ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), + }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("69ed4618-a653-4b54-8bb6-dc33bd6fe048"), + }, + }, + }) + return s +} + +func existingTargetsState() *state.KongState { + s, _ := state.NewKongState() + s.Targets.Add(state.Target{ + Target: kong.Target{ + ID: kong.String("f7e64af5-e438-4a9b-8ff8-ec6f5f06dccb"), + Target: kong.String("bar"), + Upstream: &kong.Upstream{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + }, + }) + s.Targets.Add(state.Target{ + Target: kong.Target{ + ID: kong.String("53ce0a9c-d518-40ee-b8ab-1ee83a20d382"), + Target: kong.String("foo"), + Upstream: &kong.Upstream{ + ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), + }, + }, + }) + return s +} + +func existingDocumentState() *state.KongState { + s, _ := state.NewKongState() + s.ServicePackages.Add(state.ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + }, + }) + parent, _ := s.ServicePackages.Get("4bfcb11f-c962-4817-83e5-9433cf20b663") + s.Documents.Add(state.Document{ + Document: konnect.Document{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Path: kong.String("/foo.md"), + Published: kong.Bool(true), + Content: kong.String("foo"), + Parent: parent, + }, + }) + return s +} + +var deterministicUUID = func() *string { + version := byte(4) + uuid := make([]byte, 16) + rand.Read(uuid) + + // Set version + uuid[6] = (uuid[6] & 0x0f) | (version << 4) + + // Set variant + uuid[8] = (uuid[8] & 0xbf) | 0x80 + + buf := make([]byte, 36) + var dash byte = '-' + hex.Encode(buf[0:8], uuid[0:4]) + buf[8] = dash + hex.Encode(buf[9:13], uuid[4:6]) + buf[13] = dash + hex.Encode(buf[14:18], uuid[6:8]) + buf[18] = dash + hex.Encode(buf[19:23], uuid[8:10]) + buf[23] = dash + hex.Encode(buf[24:], uuid[10:]) + s := string(buf) + return &s +} + +func TestMain(m *testing.M) { + uuid = deterministicUUID + os.Exit(m.Run()) +} + +func Test_stateBuilder_services(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + targetContent *Content + currentState *state.KongState + } + tests := []struct { + name string + fields fields + want *utils.KongRawState + }{ + { + name: "matches ID of an existing service", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: kongDefaults, + }, + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("foo"), + }, + }, + }, + }, + currentState: existingServiceState(), + }, + want: &utils.KongRawState{ + Services: []*kong.Service{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + Tags: kong.StringSlice("tag1"), + }, + }, + }, + }, + { + name: "process a non-existent service", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: kongDefaults, + }, + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("foo"), + }, + }, + }, + }, + currentState: emptyState(), + }, + want: &utils.KongRawState{ + Services: []*kong.Service{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo"), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + Tags: kong.StringSlice("tag1"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + currentState: tt.fields.currentState, + selectTags: []string{"tag1"}, + } + b.build() + assert.Equal(tt.want, b.rawState) + }) + } +} + +func Test_stateBuilder_ingestRoute(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + currentState *state.KongState + } + type args struct { + route FRoute + } + tests := []struct { + name string + fields fields + args args + wantErr bool + wantState *utils.KongRawState + }{ + { + name: "generates ID for a non-existing route", + fields: fields{ + currentState: emptyState(), + }, + args: args{ + route: FRoute{ + Route: kong.Route{ + Name: kong.String("foo"), + }, + }, + }, + wantErr: false, + wantState: &utils.KongRawState{ + Routes: []*kong.Route{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + }, + }, + }, + }, + { + name: "matches up IDs of routes correctly", + fields: fields{ + currentState: existingRouteState(), + }, + args: args{ + route: FRoute{ + Route: kong.Route{ + Name: kong.String("foo"), + }, + }, + }, + wantErr: false, + wantState: &utils.KongRawState{ + Routes: []*kong.Route{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + }, + }, + }, + }, + { + name: "grpc route has strip_path=false", + fields: fields{ + currentState: existingRouteState(), + }, + args: args{ + route: FRoute{ + Route: kong.Route{ + Name: kong.String("foo"), + Protocols: kong.StringSlice("grpc"), + }, + }, + }, + wantErr: false, + wantState: &utils.KongRawState{ + Routes: []*kong.Route{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("grpc"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + b := &stateBuilder{ + currentState: tt.fields.currentState, + } + b.rawState = &utils.KongRawState{} + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + b.intermediate, _ = state.NewKongState() + if err := b.ingestRoute(tt.args.route); (err != nil) != tt.wantErr { + t.Errorf("stateBuilder.ingestPlugins() error = %v, wantErr %v", err, tt.wantErr) + } + assert.Equal(tt.wantState, b.rawState) + }) + } +} + +func Test_stateBuilder_ingestTargets(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + currentState *state.KongState + } + type args struct { + targets []kong.Target + } + tests := []struct { + name string + fields fields + args args + wantErr bool + wantState *utils.KongRawState + }{ + { + name: "generates ID for a non-existing target", + fields: fields{ + currentState: emptyState(), + }, + args: args{ + targets: []kong.Target{ + { + Target: kong.String("foo"), + Upstream: &kong.Upstream{ + ID: kong.String("952ddf37-e815-40b6-b119-5379a3b1f7be"), + }, + }, + }, + }, + wantErr: false, + wantState: &utils.KongRawState{ + Targets: []*kong.Target{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Target: kong.String("foo"), + Weight: kong.Int(100), + Upstream: &kong.Upstream{ + ID: kong.String("952ddf37-e815-40b6-b119-5379a3b1f7be"), + }, + }, + }, + }, + }, + { + name: "matches up IDs of Targets correctly", + fields: fields{ + currentState: existingTargetsState(), + }, + args: args{ + targets: []kong.Target{ + { + Target: kong.String("bar"), + Upstream: &kong.Upstream{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + }, + { + Target: kong.String("foo"), + Upstream: &kong.Upstream{ + ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), + }, + }, + }, + }, + wantErr: false, + wantState: &utils.KongRawState{ + Targets: []*kong.Target{ + { + ID: kong.String("f7e64af5-e438-4a9b-8ff8-ec6f5f06dccb"), + Target: kong.String("bar"), + Weight: kong.Int(100), + Upstream: &kong.Upstream{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + }, + { + ID: kong.String("53ce0a9c-d518-40ee-b8ab-1ee83a20d382"), + Target: kong.String("foo"), + Weight: kong.Int(100), + Upstream: &kong.Upstream{ + ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + b := &stateBuilder{ + currentState: tt.fields.currentState, + } + b.rawState = &utils.KongRawState{} + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + if err := b.ingestTargets(tt.args.targets); (err != nil) != tt.wantErr { + t.Errorf("stateBuilder.ingestPlugins() error = %v, wantErr %v", err, tt.wantErr) + } + assert.Equal(tt.wantState, b.rawState) + }) + } +} + +func Test_stateBuilder_ingestPlugins(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + currentState *state.KongState + } + type args struct { + plugins []FPlugin + } + tests := []struct { + name string + fields fields + args args + wantErr bool + wantState *utils.KongRawState + }{ + { + name: "generates ID for a non-existing plugin", + fields: fields{ + currentState: emptyState(), + }, + args: args{ + plugins: []FPlugin{ + { + Plugin: kong.Plugin{ + Name: kong.String("foo"), + }, + }, + }, + }, + wantErr: false, + wantState: &utils.KongRawState{ + Plugins: []*kong.Plugin{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo"), + Config: kong.Configuration{}, + }, + }, + }, + }, + { + name: "matches up IDs of plugins correctly", + fields: fields{ + currentState: existingPluginState(), + }, + args: args{ + plugins: []FPlugin{ + { + Plugin: kong.Plugin{ + Name: kong.String("foo"), + }, + }, + { + Plugin: kong.Plugin{ + Name: kong.String("bar"), + Consumer: &kong.Consumer{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + }, + }, + { + Plugin: kong.Plugin{ + Name: kong.String("foo"), + Consumer: &kong.Consumer{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + Route: &kong.Route{ + ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), + }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("69ed4618-a653-4b54-8bb6-dc33bd6fe048"), + }, + }, + }, + }, + }, + wantErr: false, + wantState: &utils.KongRawState{ + Plugins: []*kong.Plugin{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + Config: kong.Configuration{}, + }, + { + ID: kong.String("f7e64af5-e438-4a9b-8ff8-ec6f5f06dccb"), + Name: kong.String("bar"), + Consumer: &kong.Consumer{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + Config: kong.Configuration{}, + }, + { + ID: kong.String("53ce0a9c-d518-40ee-b8ab-1ee83a20d382"), + Name: kong.String("foo"), + Consumer: &kong.Consumer{ + ID: kong.String("f77ca8c7-581d-45a4-a42c-c003234228e1"), + }, + Route: &kong.Route{ + ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), + }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("69ed4618-a653-4b54-8bb6-dc33bd6fe048"), + }, + Config: kong.Configuration{}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &stateBuilder{ + currentState: tt.fields.currentState, + } + b.rawState = &utils.KongRawState{} + if err := b.ingestPlugins(tt.args.plugins); (err != nil) != tt.wantErr { + t.Errorf("stateBuilder.ingestPlugins() error = %v, wantErr %v", err, tt.wantErr) + } + assert.Equal(tt.wantState, b.rawState) + }) + } +} + +func Test_pluginRelations(t *testing.T) { + type args struct { + plugin *kong.Plugin + } + tests := []struct { + name string + args args + wantCID string + wantRID string + wantSID string + wantCGID string + }{ + { + args: args{ + plugin: &kong.Plugin{ + Name: kong.String("foo"), + }, + }, + wantCID: "", + wantRID: "", + wantSID: "", + wantCGID: "", + }, + { + args: args{ + plugin: &kong.Plugin{ + Name: kong.String("foo"), + Consumer: &kong.Consumer{ + ID: kong.String("cID"), + }, + Route: &kong.Route{ + ID: kong.String("rID"), + }, + Service: &kong.Service{ + ID: kong.String("sID"), + }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("cgID"), + }, + }, + }, + wantCID: "cID", + wantRID: "rID", + wantSID: "sID", + wantCGID: "cgID", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCID, gotRID, gotSID, gotCGID := pluginRelations(tt.args.plugin) + if gotCID != tt.wantCID { + t.Errorf("pluginRelations() gotCID = %v, want %v", gotCID, tt.wantCID) + } + if gotRID != tt.wantRID { + t.Errorf("pluginRelations() gotRID = %v, want %v", gotRID, tt.wantRID) + } + if gotSID != tt.wantSID { + t.Errorf("pluginRelations() gotSID = %v, want %v", gotSID, tt.wantSID) + } + if gotCGID != tt.wantCGID { + t.Errorf("pluginRelations() gotCGID = %v, want %v", gotCGID, tt.wantCGID) + } + }) + } +} + +func Test_stateBuilder_consumers(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + currentState *state.KongState + targetContent *Content + kongVersion *semver.Version + } + tests := []struct { + name string + fields fields + want *utils.KongRawState + }{ + { + name: "generates ID for a non-existing consumer", + fields: fields{ + targetContent: &Content{ + Consumers: []FConsumer{ + { + Consumer: kong.Consumer{ + Username: kong.String("foo"), + }, + }, + }, + Info: &Info{}, + }, + currentState: emptyState(), + }, + want: &utils.KongRawState{ + Consumers: []*kong.Consumer{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Username: kong.String("foo"), + }, + }, + }, + }, + { + name: "generates ID for a non-existing credential", + fields: fields{ + targetContent: &Content{ + Consumers: []FConsumer{ + { + Consumer: kong.Consumer{ + Username: kong.String("foo"), + }, + KeyAuths: []*kong.KeyAuth{ + { + Key: kong.String("foo-key"), + }, + }, + BasicAuths: []*kong.BasicAuth{ + { + Username: kong.String("basic-username"), + Password: kong.String("basic-password"), + }, + }, + HMACAuths: []*kong.HMACAuth{ + { + Username: kong.String("hmac-username"), + Secret: kong.String("hmac-secret"), + }, + }, + JWTAuths: []*kong.JWTAuth{ + { + Key: kong.String("jwt-key"), + Secret: kong.String("jwt-secret"), + }, + }, + Oauth2Creds: []*kong.Oauth2Credential{ + { + ClientID: kong.String("oauth2-clientid"), + Name: kong.String("oauth2-name"), + }, + }, + ACLGroups: []*kong.ACLGroup{ + { + Group: kong.String("foo-group"), + }, + }, + }, + }, + Info: &Info{}, + }, + currentState: emptyState(), + }, + want: &utils.KongRawState{ + Consumers: []*kong.Consumer{ + { + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Username: kong.String("foo"), + }, + }, + KeyAuths: []*kong.KeyAuth{ + { + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Key: kong.String("foo-key"), + Consumer: &kong.Consumer{ + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Username: kong.String("foo"), + }, + }, + }, + BasicAuths: []*kong.BasicAuth{ + { + ID: kong.String("0cc0d614-4c88-4535-841a-cbe0709b0758"), + Username: kong.String("basic-username"), + Password: kong.String("basic-password"), + Consumer: &kong.Consumer{ + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Username: kong.String("foo"), + }, + }, + }, + HMACAuths: []*kong.HMACAuth{ + { + ID: kong.String("083f61d3-75bc-42b4-9df4-f91929e18fda"), + Username: kong.String("hmac-username"), + Secret: kong.String("hmac-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Username: kong.String("foo"), + }, + }, + }, + JWTAuths: []*kong.JWTAuth{ + { + ID: kong.String("9e6f82e5-4e74-4e81-a79e-4bbd6fe34cdc"), + Key: kong.String("jwt-key"), + Secret: kong.String("jwt-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Username: kong.String("foo"), + }, + }, + }, + Oauth2Creds: []*kong.Oauth2Credential{ + { + ID: kong.String("ba843ee8-d63e-4c4f-be1c-ebea546d8fac"), + ClientID: kong.String("oauth2-clientid"), + Name: kong.String("oauth2-name"), + Consumer: &kong.Consumer{ + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Username: kong.String("foo"), + }, + }, + }, + ACLGroups: []*kong.ACLGroup{ + { + ID: kong.String("13dd1aac-04ce-4ea2-877c-5579cfa2c78e"), + Group: kong.String("foo-group"), + Consumer: &kong.Consumer{ + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Username: kong.String("foo"), + }, + }, + }, + MTLSAuths: nil, + }, + }, + { + name: "matches ID of an existing consumer", + fields: fields{ + targetContent: &Content{ + Consumers: []FConsumer{ + { + Consumer: kong.Consumer{ + Username: kong.String("foo"), + }, + }, + }, + }, + currentState: existingConsumerCredState(), + }, + want: &utils.KongRawState{ + Consumers: []*kong.Consumer{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + }, + { + name: "matches ID of an existing credential", + fields: fields{ + targetContent: &Content{ + Consumers: []FConsumer{ + { + Consumer: kong.Consumer{ + Username: kong.String("foo"), + }, + KeyAuths: []*kong.KeyAuth{ + { + Key: kong.String("foo-apikey"), + }, + }, + BasicAuths: []*kong.BasicAuth{ + { + Username: kong.String("basic-username"), + Password: kong.String("basic-password"), + }, + }, + HMACAuths: []*kong.HMACAuth{ + { + Username: kong.String("hmac-username"), + Secret: kong.String("hmac-secret"), + }, + }, + JWTAuths: []*kong.JWTAuth{ + { + Key: kong.String("jwt-key"), + Secret: kong.String("jwt-secret"), + }, + }, + Oauth2Creds: []*kong.Oauth2Credential{ + { + ClientID: kong.String("oauth2-clientid"), + Name: kong.String("oauth2-name"), + }, + }, + ACLGroups: []*kong.ACLGroup{ + { + Group: kong.String("foo-group"), + }, + }, + MTLSAuths: []*kong.MTLSAuth{ + { + ID: kong.String("533c259e-bf71-4d77-99d2-97944c70a6a4"), + SubjectName: kong.String("test@example.com"), + }, + }, + }, + }, + Info: &Info{}, + }, + currentState: existingConsumerCredState(), + }, + want: &utils.KongRawState{ + Consumers: []*kong.Consumer{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + KeyAuths: []*kong.KeyAuth{ + { + ID: kong.String("5f1ef1ea-a2a5-4a1b-adbb-b0d3434013e5"), + Key: kong.String("foo-apikey"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + BasicAuths: []*kong.BasicAuth{ + { + ID: kong.String("92f4c849-960b-43af-aad3-f307051408d3"), + Username: kong.String("basic-username"), + Password: kong.String("basic-password"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + HMACAuths: []*kong.HMACAuth{ + { + ID: kong.String("e5d81b73-bf9e-42b0-9d68-30a1d791b9c9"), + Username: kong.String("hmac-username"), + Secret: kong.String("hmac-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + JWTAuths: []*kong.JWTAuth{ + { + ID: kong.String("917b9402-1be0-49d2-b482-ca4dccc2054e"), + Key: kong.String("jwt-key"), + Secret: kong.String("jwt-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + Oauth2Creds: []*kong.Oauth2Credential{ + { + ID: kong.String("4eef5285-3d6a-4f6b-b659-8957a940e2ca"), + ClientID: kong.String("oauth2-clientid"), + Name: kong.String("oauth2-name"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + ACLGroups: []*kong.ACLGroup{ + { + ID: kong.String("b7c9352a-775a-4ba5-9869-98e926a3e6cb"), + Group: kong.String("foo-group"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + MTLSAuths: []*kong.MTLSAuth{ + { + ID: kong.String("533c259e-bf71-4d77-99d2-97944c70a6a4"), + SubjectName: kong.String("test@example.com"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + }, + }, + { + name: "does not inject tags if Kong version is older than 1.4", + fields: fields{ + targetContent: &Content{ + Consumers: []FConsumer{ + { + Consumer: kong.Consumer{ + Username: kong.String("foo"), + }, + KeyAuths: []*kong.KeyAuth{ + { + Key: kong.String("foo-apikey"), + }, + }, + BasicAuths: []*kong.BasicAuth{ + { + Username: kong.String("basic-username"), + Password: kong.String("basic-password"), + }, + }, + HMACAuths: []*kong.HMACAuth{ + { + Username: kong.String("hmac-username"), + Secret: kong.String("hmac-secret"), + }, + }, + JWTAuths: []*kong.JWTAuth{ + { + Key: kong.String("jwt-key"), + Secret: kong.String("jwt-secret"), + }, + }, + Oauth2Creds: []*kong.Oauth2Credential{ + { + ClientID: kong.String("oauth2-clientid"), + Name: kong.String("oauth2-name"), + }, + }, + ACLGroups: []*kong.ACLGroup{ + { + Group: kong.String("foo-group"), + }, + }, + MTLSAuths: []*kong.MTLSAuth{ + { + ID: kong.String("533c259e-bf71-4d77-99d2-97944c70a6a4"), + SubjectName: kong.String("test@example.com"), + }, + }, + }, + }, + Info: &Info{}, + }, + currentState: existingConsumerCredState(), + kongVersion: &kong130Version, + }, + want: &utils.KongRawState{ + Consumers: []*kong.Consumer{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + KeyAuths: []*kong.KeyAuth{ + { + ID: kong.String("5f1ef1ea-a2a5-4a1b-adbb-b0d3434013e5"), + Key: kong.String("foo-apikey"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + BasicAuths: []*kong.BasicAuth{ + { + ID: kong.String("92f4c849-960b-43af-aad3-f307051408d3"), + Username: kong.String("basic-username"), + Password: kong.String("basic-password"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + HMACAuths: []*kong.HMACAuth{ + { + ID: kong.String("e5d81b73-bf9e-42b0-9d68-30a1d791b9c9"), + Username: kong.String("hmac-username"), + Secret: kong.String("hmac-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + JWTAuths: []*kong.JWTAuth{ + { + ID: kong.String("917b9402-1be0-49d2-b482-ca4dccc2054e"), + Key: kong.String("jwt-key"), + Secret: kong.String("jwt-secret"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + Oauth2Creds: []*kong.Oauth2Credential{ + { + ID: kong.String("4eef5285-3d6a-4f6b-b659-8957a940e2ca"), + ClientID: kong.String("oauth2-clientid"), + Name: kong.String("oauth2-name"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + ACLGroups: []*kong.ACLGroup{ + { + ID: kong.String("b7c9352a-775a-4ba5-9869-98e926a3e6cb"), + Group: kong.String("foo-group"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + MTLSAuths: []*kong.MTLSAuth{ + { + ID: kong.String("533c259e-bf71-4d77-99d2-97944c70a6a4"), + SubjectName: kong.String("test@example.com"), + Consumer: &kong.Consumer{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Username: kong.String("foo"), + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + currentState: tt.fields.currentState, + kongVersion: utils.Kong140Version, + } + if tt.fields.kongVersion != nil { + b.kongVersion = *tt.fields.kongVersion + } + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + b.build() + assert.Equal(tt.want, b.rawState) + }) + } +} + +func Test_stateBuilder_certificates(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + currentState *state.KongState + targetContent *Content + } + tests := []struct { + name string + fields fields + want *utils.KongRawState + }{ + { + name: "generates ID for a non-existing certificate", + fields: fields{ + targetContent: &Content{ + Certificates: []FCertificate{ + { + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }, + }, + currentState: emptyState(), + }, + want: &utils.KongRawState{ + Certificates: []*kong.Certificate{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }, + }, + }, + { + name: "matches ID of an existing certificate", + fields: fields{ + targetContent: &Content{ + Certificates: []FCertificate{ + { + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }, + }, + currentState: existingCertificateState(), + }, + want: &utils.KongRawState{ + Certificates: []*kong.Certificate{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }, + }, + }, + { + name: "generates ID for SNIs", + fields: fields{ + targetContent: &Content{ + Certificates: []FCertificate{ + { + Cert: kong.String("foo"), + Key: kong.String("bar"), + SNIs: []kong.SNI{ + { + Name: kong.String("foo.example.com"), + }, + { + Name: kong.String("bar.example.com"), + }, + }, + }, + }, + }, + currentState: existingCertificateState(), + }, + want: &utils.KongRawState{ + Certificates: []*kong.Certificate{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }, + SNIs: []*kong.SNI{ + { + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Name: kong.String("foo.example.com"), + Certificate: &kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + }, + }, + { + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar.example.com"), + Certificate: &kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + }, + }, + }, + }, + }, + { + name: "matches ID for SNIs", + fields: fields{ + targetContent: &Content{ + Certificates: []FCertificate{ + { + Cert: kong.String("foo"), + Key: kong.String("bar"), + SNIs: []kong.SNI{ + { + Name: kong.String("foo.example.com"), + }, + { + Name: kong.String("bar.example.com"), + }, + }, + }, + }, + }, + currentState: existingCertificateAndSNIState(), + }, + want: &utils.KongRawState{ + Certificates: []*kong.Certificate{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Cert: kong.String("foo"), + Key: kong.String("bar"), + }, + }, + SNIs: []*kong.SNI{ + { + ID: kong.String("a53e9598-3a5e-4c12-a672-71a4cdcf7a47"), + Name: kong.String("foo.example.com"), + Certificate: &kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + }, + }, + { + ID: kong.String("5f8e6848-4cb9-479a-a27e-860e1a77f875"), + Name: kong.String("bar.example.com"), + Certificate: &kong.Certificate{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + currentState: tt.fields.currentState, + } + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + b.build() + assert.Equal(tt.want, b.rawState) + }) + } +} + +func Test_stateBuilder_caCertificates(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + currentState *state.KongState + targetContent *Content + } + tests := []struct { + name string + fields fields + want *utils.KongRawState + }{ + { + name: "generates ID for a non-existing CACertificate", + fields: fields{ + targetContent: &Content{ + CACertificates: []FCACertificate{ + { + CACertificate: kong.CACertificate{ + Cert: kong.String("foo"), + }, + }, + }, + }, + currentState: emptyState(), + }, + want: &utils.KongRawState{ + CACertificates: []*kong.CACertificate{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Cert: kong.String("foo"), + }, + }, + }, + }, + { + name: "matches ID of an existing CACertificate", + fields: fields{ + targetContent: &Content{ + CACertificates: []FCACertificate{ + { + CACertificate: kong.CACertificate{ + Cert: kong.String("foo"), + }, + }, + }, + }, + currentState: existingCACertificateState(), + }, + want: &utils.KongRawState{ + CACertificates: []*kong.CACertificate{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Cert: kong.String("foo"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + currentState: tt.fields.currentState, + } + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + b.build() + assert.Equal(tt.want, b.rawState) + }) + } +} + +func Test_stateBuilder_upstream(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + targetContent *Content + currentState *state.KongState + } + tests := []struct { + name string + fields fields + want *utils.KongRawState + }{ + { + name: "process a non-existent upstream", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: kongDefaults, + }, + Upstreams: []FUpstream{ + { + Upstream: kong.Upstream{ + Name: kong.String("foo"), + Slots: kong.Int(42), + }, + }, + }, + }, + currentState: existingServiceState(), + }, + want: &utils.KongRawState{ + Upstreams: []*kong.Upstream{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo"), + Slots: kong.Int(42), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + }, + }, + { + name: "matches ID of an existing service", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: kongDefaults, + }, + Upstreams: []FUpstream{ + { + Upstream: kong.Upstream{ + Name: kong.String("foo"), + }, + }, + }, + }, + currentState: existingUpstreamState(), + }, + want: &utils.KongRawState{ + Upstreams: []*kong.Upstream{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + Slots: kong.Int(10000), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + }, + }, + { + name: "multiple upstreams are handled correctly", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: kongDefaults, + }, + Upstreams: []FUpstream{ + { + Upstream: kong.Upstream{ + Name: kong.String("foo"), + }, + }, + { + Upstream: kong.Upstream{ + Name: kong.String("bar"), + }, + }, + }, + }, + currentState: emptyState(), + }, + want: &utils.KongRawState{ + Upstreams: []*kong.Upstream{ + { + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Name: kong.String("foo"), + Slots: kong.Int(10000), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + { + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar"), + Slots: kong.Int(10000), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + }, + }, + { + name: "upstream with new 3.0 fields", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: kongDefaults, + }, + Upstreams: []FUpstream{ + { + Upstream: kong.Upstream{ + Name: kong.String("foo"), + Slots: kong.Int(42), + // not actually valid configuration, but this only needs to check that these translate + // into the raw state + HashOnQueryArg: kong.String("foo"), + HashFallbackQueryArg: kong.String("foo"), + HashOnURICapture: kong.String("foo"), + HashFallbackURICapture: kong.String("foo"), + }, + }, + }, + }, + currentState: existingServiceState(), + }, + want: &utils.KongRawState{ + Upstreams: []*kong.Upstream{ + { + ID: kong.String("0cc0d614-4c88-4535-841a-cbe0709b0758"), + Name: kong.String("foo"), + Slots: kong.Int(42), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + HashOnQueryArg: kong.String("foo"), + HashFallbackQueryArg: kong.String("foo"), + HashOnURICapture: kong.String("foo"), + HashFallbackURICapture: kong.String("foo"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + currentState: tt.fields.currentState, + } + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + b.build() + assert.Equal(tt.want, b.rawState) + }) + } +} + +func Test_stateBuilder_documents(t *testing.T) { + assert := assert.New(t) + rand.Seed(42) + type fields struct { + targetContent *Content + currentState *state.KongState + } + tests := []struct { + name string + fields fields + want *utils.KonnectRawState + }{ + { + name: "matches ID of an existing document", + fields: fields{ + targetContent: &Content{ + ServicePackages: []FServicePackage{ + { + Name: kong.String("foo"), + Document: &FDocument{ + Path: kong.String("/foo.md"), + Published: kong.Bool(true), + Content: kong.String("foo"), + }, + }, + }, + }, + currentState: existingDocumentState(), + }, + want: &utils.KonnectRawState{ + Documents: []*konnect.Document{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Path: kong.String("/foo.md"), + Published: kong.Bool(true), + Content: kong.String("foo"), + Parent: &konnect.ServicePackage{ + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + }, + }, + }, + ServicePackages: []*konnect.ServicePackage{ + { + ID: kong.String("4bfcb11f-c962-4817-83e5-9433cf20b663"), + Name: kong.String("foo"), + }, + }, + }, + }, + { + name: "process a non-existent document", + fields: fields{ + targetContent: &Content{ + ServicePackages: []FServicePackage{ + { + Name: kong.String("bar"), + Document: &FDocument{ + Path: kong.String("/bar.md"), + Published: kong.Bool(true), + Content: kong.String("bar"), + }, + }, + }, + }, + currentState: existingDocumentState(), + }, + want: &utils.KonnectRawState{ + Documents: []*konnect.Document{ + { + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Path: kong.String("/bar.md"), + Published: kong.Bool(true), + Content: kong.String("bar"), + Parent: &konnect.ServicePackage{ + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("bar"), + }, + }, + }, + ServicePackages: []*konnect.ServicePackage{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("bar"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + currentState: tt.fields.currentState, + } + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + b.build() + assert.Equal(tt.want, b.konnectRawState) + }) + } +} + +func Test_stateBuilder(t *testing.T) { + assert := assert.New(t) + type fields struct { + targetContent *Content + currentState *state.KongState + } + tests := []struct { + name string + fields fields + want *utils.KongRawState + }{ + { + name: "end to end test with all entities", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: kongDefaults, + }, + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("foo-service"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("foo-route1"), + }, + }, + { + Route: kong.Route{ + ID: kong.String("d125e79a-297c-414b-bc00-ad3a87be6c2b"), + Name: kong.String("foo-route2"), + }, + }, + }, + }, + { + Service: kong.Service{ + Name: kong.String("bar-service"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("bar-route1"), + }, + }, + { + Route: kong.Route{ + Name: kong.String("bar-route2"), + }, + }, + }, + }, + { + Service: kong.Service{ + Name: kong.String("large-payload-service"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("dont-buffer-these"), + RequestBuffering: kong.Bool(false), + ResponseBuffering: kong.Bool(false), + }, + }, + { + Route: kong.Route{ + Name: kong.String("buffer-these"), + RequestBuffering: kong.Bool(true), + ResponseBuffering: kong.Bool(true), + }, + }, + }, + }, + }, + Upstreams: []FUpstream{ + { + Upstream: kong.Upstream{ + Name: kong.String("foo"), + Slots: kong.Int(42), + }, + }, + }, + }, + currentState: existingServiceState(), + }, + want: &utils.KongRawState{ + Services: []*kong.Service{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo-service"), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + }, + { + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar-service"), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + }, + { + ID: kong.String("9e6f82e5-4e74-4e81-a79e-4bbd6fe34cdc"), + Name: kong.String("large-payload-service"), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + }, + }, + Routes: []*kong.Route{ + { + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Name: kong.String("foo-route1"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + Service: &kong.Service{ + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo-service"), + }, + }, + { + ID: kong.String("d125e79a-297c-414b-bc00-ad3a87be6c2b"), + Name: kong.String("foo-route2"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + Service: &kong.Service{ + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo-service"), + }, + }, + { + ID: kong.String("0cc0d614-4c88-4535-841a-cbe0709b0758"), + Name: kong.String("bar-route1"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + Service: &kong.Service{ + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar-service"), + }, + }, + { + ID: kong.String("083f61d3-75bc-42b4-9df4-f91929e18fda"), + Name: kong.String("bar-route2"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + Service: &kong.Service{ + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar-service"), + }, + }, + { + ID: kong.String("ba843ee8-d63e-4c4f-be1c-ebea546d8fac"), + Name: kong.String("dont-buffer-these"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + Service: &kong.Service{ + ID: kong.String("9e6f82e5-4e74-4e81-a79e-4bbd6fe34cdc"), + Name: kong.String("large-payload-service"), + }, + RequestBuffering: kong.Bool(false), + ResponseBuffering: kong.Bool(false), + }, + { + ID: kong.String("13dd1aac-04ce-4ea2-877c-5579cfa2c78e"), + Name: kong.String("buffer-these"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + Service: &kong.Service{ + ID: kong.String("9e6f82e5-4e74-4e81-a79e-4bbd6fe34cdc"), + Name: kong.String("large-payload-service"), + }, + RequestBuffering: kong.Bool(true), + ResponseBuffering: kong.Bool(true), + }, + }, + Upstreams: []*kong.Upstream{ + { + ID: kong.String("1b0bafae-881b-42a7-9110-8a42ed3c903c"), + Name: kong.String("foo"), + Slots: kong.Int(42), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + }, + }, + { + name: "entities with configurable defaults", + fields: fields{ + targetContent: &Content{ + Info: &Info{ + Defaults: KongDefaults{ + Route: &kong.Route{ + PathHandling: kong.String("v0"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + RequestBuffering: kong.Bool(false), + }, + Service: &kong.Service{ + Protocol: kong.String("https"), + ConnectTimeout: kong.Int(5000), + WriteTimeout: kong.Int(5000), + ReadTimeout: kong.Int(5000), + }, + Upstream: &kong.Upstream{ + Slots: kong.Int(100), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(5), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + }, + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("foo-service"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("foo-route1"), + }, + }, + { + Route: kong.Route{ + ID: kong.String("d125e79a-297c-414b-bc00-ad3a87be6c2b"), + Name: kong.String("foo-route2"), + }, + }, + }, + }, + { + Service: kong.Service{ + Name: kong.String("bar-service"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("bar-route1"), + }, + }, + { + Route: kong.Route{ + Name: kong.String("bar-route2"), + }, + }, + }, + }, + { + Service: kong.Service{ + Name: kong.String("large-payload-service"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("dont-buffer-these"), + RequestBuffering: kong.Bool(false), + ResponseBuffering: kong.Bool(false), + }, + }, + { + Route: kong.Route{ + Name: kong.String("buffer-these"), + RequestBuffering: kong.Bool(true), + ResponseBuffering: kong.Bool(true), + }, + }, + }, + }, + }, + Upstreams: []FUpstream{ + { + Upstream: kong.Upstream{ + Name: kong.String("foo"), + Slots: kong.Int(42), + }, + }, + }, + }, + currentState: existingServiceState(), + }, + want: &utils.KongRawState{ + Services: []*kong.Service{ + { + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo-service"), + Protocol: kong.String("https"), + ConnectTimeout: kong.Int(5000), + WriteTimeout: kong.Int(5000), + ReadTimeout: kong.Int(5000), + }, + { + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar-service"), + Protocol: kong.String("https"), + ConnectTimeout: kong.Int(5000), + WriteTimeout: kong.Int(5000), + ReadTimeout: kong.Int(5000), + }, + { + ID: kong.String("9e6f82e5-4e74-4e81-a79e-4bbd6fe34cdc"), + Name: kong.String("large-payload-service"), + Protocol: kong.String("https"), + ConnectTimeout: kong.Int(5000), + WriteTimeout: kong.Int(5000), + ReadTimeout: kong.Int(5000), + }, + }, + Routes: []*kong.Route{ + { + ID: kong.String("5b1484f2-5209-49d9-b43e-92ba09dd9d52"), + Name: kong.String("foo-route1"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + RequestBuffering: kong.Bool(false), + PathHandling: kong.String("v0"), + Service: &kong.Service{ + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo-service"), + }, + }, + { + ID: kong.String("d125e79a-297c-414b-bc00-ad3a87be6c2b"), + Name: kong.String("foo-route2"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + RequestBuffering: kong.Bool(false), + PathHandling: kong.String("v0"), + Service: &kong.Service{ + ID: kong.String("538c7f96-b164-4f1b-97bb-9f4bb472e89f"), + Name: kong.String("foo-service"), + }, + }, + { + ID: kong.String("0cc0d614-4c88-4535-841a-cbe0709b0758"), + Name: kong.String("bar-route1"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + RequestBuffering: kong.Bool(false), + PathHandling: kong.String("v0"), + Service: &kong.Service{ + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar-service"), + }, + }, + { + ID: kong.String("083f61d3-75bc-42b4-9df4-f91929e18fda"), + Name: kong.String("bar-route2"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + RequestBuffering: kong.Bool(false), + PathHandling: kong.String("v0"), + Service: &kong.Service{ + ID: kong.String("dfd79b4d-7642-4b61-ba0c-9f9f0d3ba55b"), + Name: kong.String("bar-service"), + }, + }, + { + ID: kong.String("ba843ee8-d63e-4c4f-be1c-ebea546d8fac"), + Name: kong.String("dont-buffer-these"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + PathHandling: kong.String("v0"), + Service: &kong.Service{ + ID: kong.String("9e6f82e5-4e74-4e81-a79e-4bbd6fe34cdc"), + Name: kong.String("large-payload-service"), + }, + RequestBuffering: kong.Bool(false), + ResponseBuffering: kong.Bool(false), + }, + { + ID: kong.String("13dd1aac-04ce-4ea2-877c-5579cfa2c78e"), + Name: kong.String("buffer-these"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + PathHandling: kong.String("v0"), + Service: &kong.Service{ + ID: kong.String("9e6f82e5-4e74-4e81-a79e-4bbd6fe34cdc"), + Name: kong.String("large-payload-service"), + }, + RequestBuffering: kong.Bool(true), + ResponseBuffering: kong.Bool(true), + }, + }, + Upstreams: []*kong.Upstream{ + { + ID: kong.String("1b0bafae-881b-42a7-9110-8a42ed3c903c"), + Name: kong.String("foo"), + Slots: kong.Int(42), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(5), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + rand.Seed(42) + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + currentState: tt.fields.currentState, + } + d, _ := utils.GetDefaulter(ctx, defaulterTestOpts) + b.defaulter = d + b.build() + assert.Equal(tt.want, b.rawState) + }) + } +} + +func Test_stateBuilder_fillPluginConfig(t *testing.T) { + type fields struct { + targetContent *Content + } + type args struct { + plugin *FPlugin + } + tests := []struct { + name string + fields fields + args args + wantErr bool + result FPlugin + }{ + { + name: "nil arg throws an error", + wantErr: true, + }, + { + name: "no _plugin_config throws an error", + fields: fields{ + targetContent: &Content{}, + }, + args: args{ + plugin: &FPlugin{ + ConfigSource: kong.String("foo"), + }, + }, + wantErr: true, + }, + { + name: "no _plugin_config throws an error", + fields: fields{ + targetContent: &Content{ + PluginConfigs: map[string]kong.Configuration{ + "foo": { + "k2": "v3", + "k3:": "v3", + }, + }, + }, + }, + args: args{ + plugin: &FPlugin{ + ConfigSource: kong.String("foo"), + Plugin: kong.Plugin{ + Config: kong.Configuration{ + "k1": "v1", + "k2": "v2", + }, + }, + }, + }, + result: FPlugin{ + ConfigSource: kong.String("foo"), + Plugin: kong.Plugin{ + Config: kong.Configuration{ + "k1": "v1", + "k2": "v2", + "k3:": "v3", + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &stateBuilder{ + targetContent: tt.fields.targetContent, + } + if err := b.fillPluginConfig(tt.args.plugin); (err != nil) != tt.wantErr { + t.Errorf("stateBuilder.fillPluginConfig() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && !reflect.DeepEqual(tt.result, tt.args.plugin) { + assert.Equal(t, tt.result, *tt.args.plugin) + } + }) + } +} + +func Test_getStripPathBasedOnProtocols(t *testing.T) { + tests := []struct { + name string + route kong.Route + wantErr bool + expectedStripPath *bool + }{ + { + name: "true strip_path and grpc protocols", + route: kong.Route{ + Protocols: []*string{kong.String("grpc")}, + StripPath: kong.Bool(true), + }, + wantErr: true, + }, + { + name: "true strip_path and grpcs protocol", + route: kong.Route{ + Protocols: []*string{kong.String("grpcs")}, + StripPath: kong.Bool(true), + }, + wantErr: true, + }, + { + name: "no strip_path and http protocol", + route: kong.Route{ + Protocols: []*string{kong.String("http")}, + }, + expectedStripPath: nil, + }, + { + name: "no strip_path and grpc protocol", + route: kong.Route{ + Protocols: []*string{kong.String("grpc")}, + }, + expectedStripPath: kong.Bool(false), + }, + { + name: "no strip_path and grpcs protocol", + route: kong.Route{ + Protocols: []*string{kong.String("grpcs")}, + }, + expectedStripPath: kong.Bool(false), + }, + { + name: "false strip_path and grpc protocol", + route: kong.Route{ + Protocols: []*string{kong.String("grpc")}, + StripPath: kong.Bool(false), + }, + expectedStripPath: kong.Bool(false), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stripPath, err := getStripPathBasedOnProtocols(tt.route) + if (err != nil) != tt.wantErr { + t.Errorf("getStripPathBasedOnProtocols() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && tt.expectedStripPath != nil { + assert.Equal(t, *tt.expectedStripPath, *stripPath) + } else { + assert.Equal(t, tt.expectedStripPath, stripPath) + } + }) + } +} diff --git a/pkg/file/codegen/.gitignore b/pkg/file/codegen/.gitignore new file mode 100644 index 0000000..e3a1140 --- /dev/null +++ b/pkg/file/codegen/.gitignore @@ -0,0 +1 @@ +codegen diff --git a/pkg/file/codegen/main.go b/pkg/file/codegen/main.go new file mode 100644 index 0000000..b9d5cac --- /dev/null +++ b/pkg/file/codegen/main.go @@ -0,0 +1,121 @@ +package main + +import ( + "encoding/json" + "io/ioutil" + "log" + "reflect" + + "github.com/alecthomas/jsonschema" + "github.com/kong/deck/file" + "github.com/kong/go-kong/kong" +) + +var ( + // routes and services + anyOfNameOrID = []*jsonschema.Type{ + { + Required: []string{"name"}, + }, + { + Required: []string{"id"}, + }, + } + + // consumers + anyOfUsernameOrCustomID = []*jsonschema.Type{ + { + Description: "at least one of custom_id or username must be set", + Required: []string{"username"}, + }, + { + Description: "at least one of custom_id or username must be set", + Required: []string{"custom_id"}, + }, + } +) + +func main() { + var reflector jsonschema.Reflector + reflector.ExpandedStruct = true + reflector.TypeMapper = func(typ reflect.Type) *jsonschema.Type { + // plugin configuration + if typ == reflect.TypeOf(kong.Configuration{}) { + return &jsonschema.Type{ + Type: "object", + Properties: map[string]*jsonschema.Type{}, + AdditionalProperties: []byte("true"), + } + } + return nil + } + schema := reflector.Reflect(file.Content{}) + + schema.Definitions["FService"].AnyOf = anyOfNameOrID + + schema.Definitions["FRoute"].AnyOf = anyOfNameOrID + + schema.Definitions["FConsumer"].AnyOf = anyOfUsernameOrCustomID + + schema.Definitions["FUpstream"].Required = []string{"name"} + + schema.Definitions["FTarget"].Required = []string{"target"} + schema.Definitions["FCACertificate"].Required = []string{"cert"} + schema.Definitions["FPlugin"].Required = []string{"name"} + + schema.Definitions["FCertificate"].Required = []string{"id", "cert", "key"} + schema.Definitions["FCertificate"].Properties["snis"] = &jsonschema.Type{ + Type: "array", + Items: &jsonschema.Type{ + Type: "object", + Properties: map[string]*jsonschema.Type{ + "name": { + Type: "string", + }, + }, + }, + } + + // creds + schema.Definitions["ACLGroup"].Required = []string{"group"} + schema.Definitions["BasicAuth"].Required = []string{"username", "password"} + schema.Definitions["HMACAuth"].Required = []string{"username", "secret"} + schema.Definitions["JWTAuth"].Required = []string{ + "algorithm", "key", + "secret", + } + schema.Definitions["KeyAuth"].Required = []string{"key"} + schema.Definitions["Oauth2Credential"].Required = []string{ + "name", + "client_id", "client_secret", + } + schema.Definitions["MTLSAuth"].Required = []string{"id", "subject_name"} + + // RBAC resources + schema.Definitions["FRBACRole"].Required = []string{"name"} + schema.Definitions["FRBACEndpointPermission"].Required = []string{"workspace", "endpoint"} + + // Foreign references + stringType := &jsonschema.Type{Type: "string"} + schema.Definitions["FPlugin"].Properties["consumer"] = stringType + schema.Definitions["FPlugin"].Properties["service"] = stringType + schema.Definitions["FPlugin"].Properties["route"] = stringType + schema.Definitions["FPlugin"].Properties["consumer_group"] = stringType + + schema.Definitions["FService"].Properties["client_certificate"] = stringType + + // konnect resources + schema.Definitions["FServicePackage"].Required = []string{"name"} + schema.Definitions["FServiceVersion"].Required = []string{"version"} + schema.Definitions["Implementation"].Required = []string{"type", "kong"} + + jsonSchema, err := json.MarshalIndent(schema, "", " ") + if err != nil { + log.Fatalln(err) + } + + err = ioutil.WriteFile("kong_json_schema.json", jsonSchema, 0o644) + if err != nil { + log.Fatalln(err) + } +} diff --git a/pkg/file/kong_json_schema.json b/pkg/file/kong_json_schema.json new file mode 100644 index 0000000..5f9d7b8 --- /dev/null +++ b/pkg/file/kong_json_schema.json @@ -0,0 +1,1808 @@ +{ + "$schema": "http://json-schema.org/draft-04/schema#", + "properties": { + "_format_version": { + "type": "string" + }, + "_info": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Info" + }, + "_konnect": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Konnect" + }, + "_plugin_configs": { + "patternProperties": { + ".*": { + "additionalProperties": true, + "type": "object" + } + }, + "type": "object" + }, + "_transform": { + "type": "boolean" + }, + "_workspace": { + "type": "string" + }, + "ca_certificates": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FCACertificate" + }, + "type": "array" + }, + "certificates": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FCertificate" + }, + "type": "array" + }, + "consumer_groups": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FConsumerGroupObject" + }, + "type": "array" + }, + "consumers": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FConsumer" + }, + "type": "array" + }, + "licenses": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FLicense" + }, + "type": "array" + }, + "plugins": { + "items": { + "$ref": "#/definitions/FPlugin" + }, + "type": "array" + }, + "rbac_roles": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FRBACRole" + }, + "type": "array" + }, + "routes": { + "items": { + "$ref": "#/definitions/FRoute" + }, + "type": "array" + }, + "service_packages": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FServicePackage" + }, + "type": "array" + }, + "services": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FService" + }, + "type": "array" + }, + "upstreams": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FUpstream" + }, + "type": "array" + }, + "vaults": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FVault" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object", + "definitions": { + "ACLGroup": { + "required": [ + "group" + ], + "properties": { + "consumer": { + "$ref": "#/definitions/Consumer" + }, + "created_at": { + "type": "integer" + }, + "group": { + "type": "string" + }, + "id": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "ActiveHealthcheck": { + "properties": { + "concurrency": { + "type": "integer" + }, + "headers": { + "patternProperties": { + ".*": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, + "healthy": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Healthy" + }, + "http_path": { + "type": "string" + }, + "https_sni": { + "type": "string" + }, + "https_verify_certificate": { + "type": "boolean" + }, + "timeout": { + "type": "integer" + }, + "type": { + "type": "string" + }, + "unhealthy": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Unhealthy" + } + }, + "additionalProperties": false, + "type": "object" + }, + "BasicAuth": { + "required": [ + "username", + "password" + ], + "properties": { + "consumer": { + "$ref": "#/definitions/Consumer" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "password": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "username": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "CACertificate": { + "properties": { + "cert": { + "type": "string" + }, + "cert_digest": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "CIDRPort": { + "properties": { + "ip": { + "type": "string" + }, + "port": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Certificate": { + "properties": { + "cert": { + "type": "string" + }, + "cert_alt": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "key": { + "type": "string" + }, + "key_alt": { + "type": "string" + }, + "snis": { + "items": { + "type": "string" + }, + "type": "array" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Consumer": { + "properties": { + "created_at": { + "type": "integer" + }, + "custom_id": { + "type": "string" + }, + "id": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "username": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "ConsumerGroup": { + "properties": { + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "ConsumerGroupPlugin": { + "properties": { + "config": { + "additionalProperties": true, + "type": "object" + }, + "consumer_group": { + "$ref": "#/definitions/ConsumerGroup" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FCACertificate": { + "required": [ + "cert" + ], + "properties": { + "cert": { + "type": "string" + }, + "cert_digest": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FCertificate": { + "required": [ + "id", + "cert", + "key" + ], + "properties": { + "cert": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "key": { + "type": "string" + }, + "snis": { + "items": { + "properties": { + "name": { + "type": "string" + } + }, + "type": "object" + }, + "type": "array" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FConsumer": { + "properties": { + "acls": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/ACLGroup" + }, + "type": "array" + }, + "basicauth_credentials": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/BasicAuth" + }, + "type": "array" + }, + "created_at": { + "type": "integer" + }, + "custom_id": { + "type": "string" + }, + "groups": { + "items": { + "$ref": "#/definitions/ConsumerGroup" + }, + "type": "array" + }, + "hmacauth_credentials": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/HMACAuth" + }, + "type": "array" + }, + "id": { + "type": "string" + }, + "jwt_secrets": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/JWTAuth" + }, + "type": "array" + }, + "keyauth_credentials": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/KeyAuth" + }, + "type": "array" + }, + "mtls_auth_credentials": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/MTLSAuth" + }, + "type": "array" + }, + "oauth2_credentials": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Oauth2Credential" + }, + "type": "array" + }, + "plugins": { + "items": { + "$ref": "#/definitions/FPlugin" + }, + "type": "array" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "username": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object", + "anyOf": [ + { + "required": [ + "username" + ], + "description": "at least one of custom_id or username must be set" + }, + { + "required": [ + "custom_id" + ], + "description": "at least one of custom_id or username must be set" + } + ] + }, + "FConsumerGroupObject": { + "properties": { + "consumers": { + "items": { + "$ref": "#/definitions/Consumer" + }, + "type": "array" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "plugins": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/ConsumerGroupPlugin" + }, + "type": "array" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FDocument": { + "properties": { + "id": { + "type": "string" + }, + "path": { + "type": "string" + }, + "published": { + "type": "boolean" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FLicense": { + "properties": { + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "payload": { + "type": "string" + }, + "updated_at": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FPlugin": { + "required": [ + "name" + ], + "properties": { + "_config": { + "type": "string" + }, + "config": { + "additionalProperties": true, + "type": "object" + }, + "consumer": { + "type": "string" + }, + "consumer_group": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "enabled": { + "type": "boolean" + }, + "id": { + "type": "string" + }, + "instance_name": { + "type": "string" + }, + "name": { + "type": "string" + }, + "ordering": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/PluginOrdering" + }, + "protocols": { + "items": { + "type": "string" + }, + "type": "array" + }, + "route": { + "type": "string" + }, + "run_on": { + "type": "string" + }, + "service": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FRBACEndpointPermission": { + "required": [ + "workspace", + "endpoint" + ], + "properties": { + "actions": { + "items": { + "type": "string" + }, + "type": "array" + }, + "comment": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "endpoint": { + "type": "string" + }, + "negative": { + "type": "boolean" + }, + "role": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/RBACRole" + }, + "workspace": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FRBACRole": { + "required": [ + "name" + ], + "properties": { + "comment": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "endpoint_permissions": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FRBACEndpointPermission" + }, + "type": "array" + }, + "id": { + "type": "string" + }, + "is_default": { + "type": "boolean" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FRoute": { + "properties": { + "created_at": { + "type": "integer" + }, + "destinations": { + "items": { + "$ref": "#/definitions/CIDRPort" + }, + "type": "array" + }, + "expression": { + "type": "string" + }, + "headers": { + "patternProperties": { + ".*": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, + "hosts": { + "items": { + "type": "string" + }, + "type": "array" + }, + "https_redirect_status_code": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "methods": { + "items": { + "type": "string" + }, + "type": "array" + }, + "name": { + "type": "string" + }, + "path_handling": { + "type": "string" + }, + "paths": { + "items": { + "type": "string" + }, + "type": "array" + }, + "plugins": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FPlugin" + }, + "type": "array" + }, + "preserve_host": { + "type": "boolean" + }, + "priority": { + "type": "integer" + }, + "protocols": { + "items": { + "type": "string" + }, + "type": "array" + }, + "regex_priority": { + "type": "integer" + }, + "request_buffering": { + "type": "boolean" + }, + "response_buffering": { + "type": "boolean" + }, + "service": { + "$ref": "#/definitions/Service" + }, + "snis": { + "items": { + "type": "string" + }, + "type": "array" + }, + "sources": { + "items": { + "$ref": "#/definitions/CIDRPort" + }, + "type": "array" + }, + "strip_path": { + "type": "boolean" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "updated_at": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object", + "anyOf": [ + { + "required": [ + "name" + ] + }, + { + "required": [ + "id" + ] + } + ] + }, + "FService": { + "properties": { + "ca_certificates": { + "items": { + "type": "string" + }, + "type": "array" + }, + "client_certificate": { + "type": "string" + }, + "connect_timeout": { + "type": "integer" + }, + "created_at": { + "type": "integer" + }, + "enabled": { + "type": "boolean" + }, + "host": { + "type": "string" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "path": { + "type": "string" + }, + "plugins": { + "items": { + "$ref": "#/definitions/FPlugin" + }, + "type": "array" + }, + "port": { + "type": "integer" + }, + "protocol": { + "type": "string" + }, + "read_timeout": { + "type": "integer" + }, + "retries": { + "type": "integer" + }, + "routes": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FRoute" + }, + "type": "array" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "tls_verify": { + "type": "boolean" + }, + "tls_verify_depth": { + "type": "integer" + }, + "updated_at": { + "type": "integer" + }, + "url": { + "type": "string" + }, + "write_timeout": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object", + "anyOf": [ + { + "required": [ + "name" + ] + }, + { + "required": [ + "id" + ] + } + ] + }, + "FServicePackage": { + "required": [ + "name" + ], + "properties": { + "description": { + "type": "string" + }, + "document": { + "$ref": "#/definitions/FDocument" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "versions": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FServiceVersion" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FServiceVersion": { + "required": [ + "version" + ], + "properties": { + "document": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FDocument" + }, + "id": { + "type": "string" + }, + "implementation": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Implementation" + }, + "version": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FTarget": { + "required": [ + "target" + ], + "properties": { + "created_at": { + "type": "number" + }, + "id": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "target": { + "type": "string" + }, + "upstream": { + "$ref": "#/definitions/Upstream" + }, + "weight": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FUpstream": { + "required": [ + "name" + ], + "properties": { + "algorithm": { + "type": "string" + }, + "client_certificate": { + "$ref": "#/definitions/Certificate" + }, + "created_at": { + "type": "integer" + }, + "hash_fallback": { + "type": "string" + }, + "hash_fallback_header": { + "type": "string" + }, + "hash_fallback_query_arg": { + "type": "string" + }, + "hash_fallback_uri_capture": { + "type": "string" + }, + "hash_on": { + "type": "string" + }, + "hash_on_cookie": { + "type": "string" + }, + "hash_on_cookie_path": { + "type": "string" + }, + "hash_on_header": { + "type": "string" + }, + "hash_on_query_arg": { + "type": "string" + }, + "hash_on_uri_capture": { + "type": "string" + }, + "healthchecks": { + "$ref": "#/definitions/Healthcheck" + }, + "host_header": { + "type": "string" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "slots": { + "type": "integer" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "targets": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FTarget" + }, + "type": "array" + }, + "use_srv_name": { + "type": "boolean" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FVault": { + "properties": { + "config": { + "additionalProperties": true, + "type": "object" + }, + "created_at": { + "type": "integer" + }, + "description": { + "type": "string" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "prefix": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "updated_at": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "HMACAuth": { + "required": [ + "username", + "secret" + ], + "properties": { + "consumer": { + "$ref": "#/definitions/Consumer" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "secret": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "username": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Healthcheck": { + "properties": { + "active": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/ActiveHealthcheck" + }, + "passive": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/PassiveHealthcheck" + }, + "threshold": { + "type": "number" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Healthy": { + "properties": { + "http_statuses": { + "items": { + "type": "integer" + }, + "type": "array" + }, + "interval": { + "type": "integer" + }, + "successes": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Implementation": { + "required": [ + "type", + "kong" + ], + "properties": { + "kong": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Kong" + }, + "type": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Info": { + "properties": { + "defaults": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/KongDefaults" + }, + "select_tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "JWTAuth": { + "required": [ + "algorithm", + "key", + "secret" + ], + "properties": { + "algorithm": { + "type": "string" + }, + "consumer": { + "$ref": "#/definitions/Consumer" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "key": { + "type": "string" + }, + "rsa_public_key": { + "type": "string" + }, + "secret": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "KeyAuth": { + "required": [ + "key" + ], + "properties": { + "consumer": { + "$ref": "#/definitions/Consumer" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "key": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "ttl": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Kong": { + "properties": { + "service": { + "$ref": "#/definitions/FService" + } + }, + "additionalProperties": false, + "type": "object" + }, + "KongDefaults": { + "properties": { + "route": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Route" + }, + "service": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Service" + }, + "target": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Target" + }, + "upstream": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Upstream" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Konnect": { + "properties": { + "control_plane_name": { + "type": "string" + }, + "runtime_group_name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "MTLSAuth": { + "required": [ + "id", + "subject_name" + ], + "properties": { + "ca_certificate": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/CACertificate" + }, + "consumer": { + "$ref": "#/definitions/Consumer" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "subject_name": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Oauth2Credential": { + "required": [ + "name", + "client_id", + "client_secret" + ], + "properties": { + "client_id": { + "type": "string" + }, + "client_secret": { + "type": "string" + }, + "client_type": { + "type": "string" + }, + "consumer": { + "$ref": "#/definitions/Consumer" + }, + "created_at": { + "type": "integer" + }, + "hash_secret": { + "type": "boolean" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "redirect_uris": { + "items": { + "type": "string" + }, + "type": "array" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "PassiveHealthcheck": { + "properties": { + "healthy": { + "$ref": "#/definitions/Healthy" + }, + "type": { + "type": "string" + }, + "unhealthy": { + "$ref": "#/definitions/Unhealthy" + } + }, + "additionalProperties": false, + "type": "object" + }, + "PluginOrdering": { + "properties": { + "after": { + "patternProperties": { + ".*": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, + "before": { + "patternProperties": { + ".*": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + } + }, + "additionalProperties": false, + "type": "object" + }, + "RBACRole": { + "properties": { + "comment": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "is_default": { + "type": "boolean" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Route": { + "properties": { + "created_at": { + "type": "integer" + }, + "destinations": { + "items": { + "$ref": "#/definitions/CIDRPort" + }, + "type": "array" + }, + "expression": { + "type": "string" + }, + "headers": { + "patternProperties": { + ".*": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + }, + "hosts": { + "items": { + "type": "string" + }, + "type": "array" + }, + "https_redirect_status_code": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "methods": { + "items": { + "type": "string" + }, + "type": "array" + }, + "name": { + "type": "string" + }, + "path_handling": { + "type": "string" + }, + "paths": { + "items": { + "type": "string" + }, + "type": "array" + }, + "preserve_host": { + "type": "boolean" + }, + "priority": { + "type": "integer" + }, + "protocols": { + "items": { + "type": "string" + }, + "type": "array" + }, + "regex_priority": { + "type": "integer" + }, + "request_buffering": { + "type": "boolean" + }, + "response_buffering": { + "type": "boolean" + }, + "service": { + "$ref": "#/definitions/Service" + }, + "snis": { + "items": { + "type": "string" + }, + "type": "array" + }, + "sources": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/CIDRPort" + }, + "type": "array" + }, + "strip_path": { + "type": "boolean" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "updated_at": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "SNI": { + "properties": { + "certificate": { + "$ref": "#/definitions/Certificate" + }, + "created_at": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Service": { + "properties": { + "ca_certificates": { + "items": { + "type": "string" + }, + "type": "array" + }, + "client_certificate": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Certificate" + }, + "connect_timeout": { + "type": "integer" + }, + "created_at": { + "type": "integer" + }, + "enabled": { + "type": "boolean" + }, + "host": { + "type": "string" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "path": { + "type": "string" + }, + "port": { + "type": "integer" + }, + "protocol": { + "type": "string" + }, + "read_timeout": { + "type": "integer" + }, + "retries": { + "type": "integer" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "tls_verify": { + "type": "boolean" + }, + "tls_verify_depth": { + "type": "integer" + }, + "updated_at": { + "type": "integer" + }, + "url": { + "type": "string" + }, + "write_timeout": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Target": { + "properties": { + "created_at": { + "type": "number" + }, + "id": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "target": { + "type": "string" + }, + "upstream": { + "$ref": "#/definitions/Upstream" + }, + "weight": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Unhealthy": { + "properties": { + "http_failures": { + "type": "integer" + }, + "http_statuses": { + "items": { + "type": "integer" + }, + "type": "array" + }, + "interval": { + "type": "integer" + }, + "tcp_failures": { + "type": "integer" + }, + "timeouts": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object" + }, + "Upstream": { + "properties": { + "algorithm": { + "type": "string" + }, + "client_certificate": { + "$ref": "#/definitions/Certificate" + }, + "created_at": { + "type": "integer" + }, + "hash_fallback": { + "type": "string" + }, + "hash_fallback_header": { + "type": "string" + }, + "hash_fallback_query_arg": { + "type": "string" + }, + "hash_fallback_uri_capture": { + "type": "string" + }, + "hash_on": { + "type": "string" + }, + "hash_on_cookie": { + "type": "string" + }, + "hash_on_cookie_path": { + "type": "string" + }, + "hash_on_header": { + "type": "string" + }, + "hash_on_query_arg": { + "type": "string" + }, + "hash_on_uri_capture": { + "type": "string" + }, + "healthchecks": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/Healthcheck" + }, + "host_header": { + "type": "string" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "slots": { + "type": "integer" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "use_srv_name": { + "type": "boolean" + } + }, + "additionalProperties": false, + "type": "object" + } + } +} \ No newline at end of file diff --git a/pkg/file/konnect.go b/pkg/file/konnect.go new file mode 100644 index 0000000..46b4abe --- /dev/null +++ b/pkg/file/konnect.go @@ -0,0 +1,64 @@ +package file + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// PopulateDocumentContent updates the Documents contained within a Content with the +// contents of their files on disk. Document files are stored at +// /ServicePackage.Name/Document.Path and /ServicePackage.Name/ServiceVersion.Version/Document.Path, +// where is the directory containing the first state file. +func (c Content) PopulateDocumentContent(filenames []string) error { + if len(filenames) == 0 { + return fmt.Errorf("cannot populate documents without a location") + } + // TODO decK actually allows you to use _multiple_ state files + // We currently choose the first arbitrarily and assume document content is under its directory + // Future plans are to rework the multiple state file functionality to require all state files + // be in the same directory. + root := filepath.Dir(filenames[0]) + for _, sp := range c.ServicePackages { + if sp.Document != nil { + path := filepath.Join(root, utils.FilenameToName(*sp.Document.Path)) + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("error reading document file: %w", err) + } + sp.Document.Content = kong.String(string(content)) + } + for _, sv := range sp.Versions { + if sv.Document != nil { + path := filepath.Join(root, utils.FilenameToName(*sv.Document.Path)) + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("error reading document file: %w", err) + } + sv.Document.Content = kong.String(string(content)) + } + } + } + return nil +} + +// StripLocalDocumentPath removes local path information from a target state document, returning the base path with a +// prepended slash. These path values match typical path values for documents created in the Konnect GUI, whereas path +// values in decK state files are local relative paths with service package and service version directories. +func (c Content) StripLocalDocumentPath() { + for _, sp := range c.ServicePackages { + if sp.Document != nil { + trunc := "/" + filepath.Base(utils.FilenameToName(*sp.Document.Path)) + sp.Document.Path = &trunc + } + for _, sv := range sp.Versions { + if sv.Document != nil { + trunc := "/" + filepath.Base(utils.FilenameToName(*sv.Document.Path)) + sv.Document.Path = &trunc + } + } + } +} diff --git a/pkg/file/reader.go b/pkg/file/reader.go new file mode 100644 index 0000000..ad0c2b8 --- /dev/null +++ b/pkg/file/reader.go @@ -0,0 +1,137 @@ +package file + +import ( + "context" + "fmt" + + "github.com/blang/semver/v4" + "github.com/kong/deck/dump" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +var ( + // ErrorTransformFalseNotSupported indicates that no transform mode is not supported + ErrorTransformFalseNotSupported = fmt.Errorf("_transform: false is not supported") + // ErrorFilenameEmpty indicates that you must provide a filename + ErrorFilenameEmpty = fmt.Errorf("filename cannot be empty") +) + +// RenderConfig contains necessary information to render a correct +// KongConfig from a file. +type RenderConfig struct { + CurrentState *state.KongState + KongVersion semver.Version +} + +// GetContentFromFiles reads in a file with a slice of filenames and constructs +// a state. If filename is `-`, then it will read from os.Stdin. +// If filename represents a directory, it will traverse the tree +// rooted at filename, read all the files with .yaml, .yml and .json extensions +// and generate a content after a merge of the content from all the files. +// +// It will return an error if the file representation is invalid +// or if there is any error during processing. +func GetContentFromFiles(filenames []string, mockEnvVars bool) (*Content, error) { + if len(filenames) == 0 { + return nil, ErrorFilenameEmpty + } + + return getContent(filenames, mockEnvVars) +} + +// GetForKonnect processes the fileContent and renders a RawState and KonnectRawState +func GetForKonnect(ctx context.Context, fileContent *Content, + opt RenderConfig, client *kong.Client, +) (*utils.KongRawState, *utils.KonnectRawState, error) { + var builder stateBuilder + // setup + builder.targetContent = fileContent + builder.currentState = opt.CurrentState + builder.kongVersion = opt.KongVersion + builder.client = client + builder.ctx = ctx + builder.disableDynamicDefaults = true + + if fileContent.Transform != nil && !*fileContent.Transform { + return nil, nil, ErrorTransformFalseNotSupported + } + + kongState, konnectState, err := builder.build() + if err != nil { + return nil, nil, fmt.Errorf("building state: %w", err) + } + return kongState, konnectState, nil +} + +// Get process the fileContent and renders a RawState. +// IDs of entities are matches based on currentState. +func Get(ctx context.Context, fileContent *Content, opt RenderConfig, dumpConfig dump.Config, wsClient *kong.Client) ( + *utils.KongRawState, error, +) { + var builder stateBuilder + // setup + builder.targetContent = fileContent + builder.currentState = opt.CurrentState + builder.kongVersion = opt.KongVersion + builder.client = wsClient + builder.ctx = ctx + builder.skipCACerts = dumpConfig.SkipCACerts + builder.isKonnect = dumpConfig.KonnectControlPlane != "" + + if len(dumpConfig.SelectorTags) > 0 { + builder.selectTags = dumpConfig.SelectorTags + } + + if fileContent.Transform != nil && !*fileContent.Transform { + return nil, ErrorTransformFalseNotSupported + } + + state, _, err := builder.build() + if err != nil { + return nil, fmt.Errorf("building state: %w", err) + } + return state, nil +} + +func ensureJSON(m map[string]interface{}) map[string]interface{} { + res := map[string]interface{}{} + for k, v := range m { + switch v2 := v.(type) { + case map[interface{}]interface{}: + res[fmt.Sprint(k)] = yamlToJSON(v2) + case []interface{}: + var array []interface{} + for _, element := range v2 { + switch el := element.(type) { + case map[interface{}]interface{}: + array = append(array, yamlToJSON(el)) + default: + array = append(array, el) + } + } + if array != nil { + res[fmt.Sprint(k)] = array + } else { + res[fmt.Sprint(k)] = v + } + default: + res[fmt.Sprint(k)] = v + } + } + return res +} + +func yamlToJSON(m map[interface{}]interface{}) map[string]interface{} { + res := map[string]interface{}{} + for k, v := range m { + switch v2 := v.(type) { + case map[interface{}]interface{}: + res[fmt.Sprint(k)] = yamlToJSON(v2) + default: + res[fmt.Sprint(k)] = v + } + } + return res +} diff --git a/pkg/file/reader_test.go b/pkg/file/reader_test.go new file mode 100644 index 0000000..aae4861 --- /dev/null +++ b/pkg/file/reader_test.go @@ -0,0 +1,169 @@ +package file + +import ( + "bytes" + "context" + "io/ioutil" + "os" + "reflect" + "testing" + + "github.com/kong/deck/dump" + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func Test_ensureJSON(t *testing.T) { + type args struct { + m map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + }{ + { + "empty array is kept as is", + args{map[string]interface{}{ + "foo": []interface{}{}, + }}, + map[string]interface{}{ + "foo": []interface{}{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ensureJSON(tt.args.m); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ensureJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestReadKongStateFromStdinFailsToParseText(t *testing.T) { + filenames := []string{"-"} + assert := assert.New(t) + assert.Equal("-", filenames[0]) + + var content bytes.Buffer + content.Write([]byte("hunter2\n")) + + tmpfile, err := ioutil.TempFile("", "example") + if err != nil { + panic(err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(content.Bytes()); err != nil { + panic(err) + } + + if _, err := tmpfile.Seek(0, 0); err != nil { + panic(err) + } + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() // Restore original Stdin + + os.Stdin = tmpfile + + c, err := GetContentFromFiles(filenames, false) + assert.NotNil(err) + assert.Nil(c) +} + +func TestTransformNotFalse(t *testing.T) { + filenames := []string{"-"} + assert := assert.New(t) + + tmpfile, err := ioutil.TempFile("", "example") + if err != nil { + panic(err) + } + defer os.Remove(tmpfile.Name()) + + _, err = tmpfile.WriteString("_transform: false\nservices:\n- host: test.com\n name: test service\n") + if err != nil { + panic(err) + } + + if _, err := tmpfile.Seek(0, 0); err != nil { + panic(err) + } + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() // Restore original Stdin + + os.Stdin = tmpfile + + c, err := GetContentFromFiles(filenames, false) + if err != nil { + panic(err) + } + + ctx := context.Background() + parsed, err := Get(ctx, c, RenderConfig{}, dump.Config{}, nil) + assert.Equal(err, ErrorTransformFalseNotSupported) + assert.Nil(parsed) + + parsed, _, err = GetForKonnect(ctx, c, RenderConfig{}, nil) + assert.Equal(err, ErrorTransformFalseNotSupported) + assert.Nil(parsed) +} + +func TestReadKongStateFromStdin(t *testing.T) { + filenames := []string{"-"} + assert := assert.New(t) + assert.Equal("-", filenames[0]) + + var content bytes.Buffer + content.Write([]byte("services:\n- host: test.com\n name: test service\n")) + + tmpfile, err := ioutil.TempFile("", "example") + if err != nil { + panic(err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(content.Bytes()); err != nil { + panic(err) + } + + if _, err := tmpfile.Seek(0, 0); err != nil { + panic(err) + } + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() // Restore original Stdin + + os.Stdin = tmpfile + + c, err := GetContentFromFiles(filenames, false) + assert.NotNil(c) + assert.Nil(err) + + assert.Equal(kong.Service{ + Name: kong.String("test service"), + Host: kong.String("test.com"), + }, + c.Services[0].Service) +} + +func TestReadKongStateFromFile(t *testing.T) { + filenames := []string{"testdata/config.yaml"} + assert := assert.New(t) + assert.Equal("testdata/config.yaml", filenames[0]) + + c, err := GetContentFromFiles(filenames, false) + assert.NotNil(c) + assert.Nil(err) + + t.Run("enabled field for service is read", func(t *testing.T) { + assert.Equal(kong.Service{ + Name: kong.String("svc1"), + Host: kong.String("mockbin.org"), + Enabled: kong.Bool(true), + }, c.Services[0].Service) + }) +} diff --git a/pkg/file/readfile.go b/pkg/file/readfile.go new file mode 100644 index 0000000..165f7b6 --- /dev/null +++ b/pkg/file/readfile.go @@ -0,0 +1,236 @@ +package file + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "strconv" + "strings" + "text/template" + + "github.com/imdario/mergo" + "github.com/kong/deck/utils" + "sigs.k8s.io/yaml" +) + +// getContent reads all the YAML and JSON files in the directory or the +// file, depending on the type of each item in filenames, merges the content of +// these files and renders a Content. +func getContent(filenames []string, mockEnvVars bool) (*Content, error) { + var workspaces, runtimeGroups []string + var res Content + var errs []error + for _, fileOrDir := range filenames { + readers, err := getReaders(fileOrDir) + if err != nil { + return nil, err + } + + for filename, r := range readers { + content, err := readContent(r, mockEnvVars) + if err != nil { + errs = append(errs, fmt.Errorf("reading file %s: %w", filename, err)) + continue + } + if content.Workspace != "" { + workspaces = append(workspaces, content.Workspace) + } + if content.Konnect != nil && len(content.Konnect.RuntimeGroupName) > 0 { + runtimeGroups = append(runtimeGroups, content.Konnect.RuntimeGroupName) + } + err = mergo.Merge(&res, content, mergo.WithAppendSlice) + if err != nil { + return nil, fmt.Errorf("merging file contents: %w", err) + } + } + } + if len(errs) > 0 { + return nil, utils.ErrArray{Errors: errs} + } + if err := validateWorkspaces(workspaces); err != nil { + return nil, err + } + if err := validateRuntimeGroups(runtimeGroups); err != nil { + return nil, err + } + return &res, nil +} + +// getReaders returns back a map of filename:io.Reader representing all the +// YAML and JSON files in a directory. If fileOrDir is a single file, then it +// returns back the reader for the file. +// If fileOrDir is equal to "-" string, then it returns back a io.Reader +// for the os.Stdin file descriptor. +func getReaders(fileOrDir string) (map[string]io.Reader, error) { + // special case where `-` means stdin + if fileOrDir == "-" { + return map[string]io.Reader{"STDIN": os.Stdin}, nil + } + + finfo, err := os.Stat(fileOrDir) + if err != nil { + return nil, fmt.Errorf("reading state file: %w", err) + } + + var files []string + if finfo.IsDir() { + files, err = utils.ConfigFilesInDir(fileOrDir) + if err != nil { + return nil, fmt.Errorf("getting files from directory: %w", err) + } + } else { + files = append(files, fileOrDir) + } + + res := make(map[string]io.Reader, len(files)) + for _, file := range files { + f, err := os.Open(file) + if err != nil { + return nil, fmt.Errorf("opening file: %w", err) + } + res[file] = bufio.NewReader(f) + } + return res, nil +} + +func hasLeadingSpace(fileContent string) bool { + if fileContent != "" && string(fileContent[0]) == " " { + return true + } + return false +} + +// readContent reads all the byes until io.EOF and unmarshals the read +// bytes into Content. +func readContent(reader io.Reader, mockEnvVars bool) (*Content, error) { + var err error + contentBytes, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + renderedContent, err := renderTemplate(string(contentBytes), mockEnvVars) + if err != nil { + return nil, fmt.Errorf("parsing file: %w", err) + } + // go-yaml implementation fails at correctly parsing a file whose first + // character is a space, as shown in https://github.com/Kong/deck/issues/578 + // If that is the case here, raise an error. + if hasLeadingSpace(renderedContent) { + return nil, fmt.Errorf("file must not begin with a whitespace") + } + renderedContentBytes := []byte(renderedContent) + err = validate(renderedContentBytes) + if err != nil { + return nil, fmt.Errorf("validating file content: %w", err) + } + var result Content + err = yamlUnmarshal(renderedContentBytes, &result) + if err != nil { + return nil, err + } + return &result, nil +} + +// yamlUnmarshal is a wrapper around yaml.Unmarshal to ensure that the right +// yaml package is in use. Using ghodss/yaml ensures that no +// `map[interface{}]interface{}` is present in go-kong.Plugin.Configuration. +// If it is present, then it leads to a silent error. See Github Issue #144. +// The verification for this is done using a test. +func yamlUnmarshal(bytes []byte, v interface{}) error { + return yaml.Unmarshal(bytes, v) +} + +func getPrefixedEnvVar(key string) (string, error) { + const envVarPrefix = "DECK_" + if !strings.HasPrefix(key, envVarPrefix) { + return "", fmt.Errorf("environment variables in the state file must "+ + "be prefixed with 'DECK_', found: '%s'", key) + } + value, exists := os.LookupEnv(key) + if !exists { + return "", fmt.Errorf("environment variable '%s' present in state file but not set", key) + } + return value, nil +} + +// getPrefixedEnvVarMocked is used when we mock the env variables while rendering a template. +// It will always return the name of the environment variable in this case. +func getPrefixedEnvVarMocked(key string) (string, error) { + const envVarPrefix = "DECK_" + if !strings.HasPrefix(key, envVarPrefix) { + return "", fmt.Errorf("environment variables in the state file must "+ + "be prefixed with 'DECK_', found: '%s'", key) + } + return key, nil +} + +func toBool(key string) (bool, error) { + return strconv.ParseBool(key) +} + +// toBoolMocked is used when we mock the env variables while rendering a template. +// It will always return false in this case. +func toBoolMocked(_ string) (bool, error) { + return false, nil +} + +func toInt(key string) (int, error) { + return strconv.Atoi(key) +} + +// toIntMocked is used when we mock the env variables while rendering a template. +// It will always return 42 in this case. +func toIntMocked(_ string) (int, error) { + return 42, nil +} + +func toFloat(key string) (float64, error) { + return strconv.ParseFloat(key, 64) +} + +// toFloatMocked is used when we mock the env variables while rendering a template. +// It will always return 42 in this case. +func toFloatMocked(_ string) (float64, error) { + return 42, nil +} + +func indent(spaces int, v string) string { + pad := strings.Repeat(" ", spaces) + return strings.Replace(v, "\n", "\n"+pad, -1) +} + +func renderTemplate(content string, mockEnvVars bool) (string, error) { + var templateFuncs template.FuncMap + if mockEnvVars { + templateFuncs = template.FuncMap{ + "env": getPrefixedEnvVarMocked, + "toBool": toBoolMocked, + "toInt": toIntMocked, + "toFloat": toFloatMocked, + "indent": indent, + } + } else { + templateFuncs = template.FuncMap{ + "env": getPrefixedEnvVar, + "toBool": toBool, + "toInt": toInt, + "toFloat": toFloat, + "indent": indent, + } + } + t := template.New("state").Funcs(templateFuncs).Delims("${{", "}}") + + t, err := t.Parse(content) + if err != nil { + return "", err + } + var buffer bytes.Buffer + err = t.Execute(&buffer, nil) + if err != nil { + return "", err + } + return buffer.String(), nil +} diff --git a/pkg/file/readfile_test.go b/pkg/file/readfile_test.go new file mode 100644 index 0000000..573f0aa --- /dev/null +++ b/pkg/file/readfile_test.go @@ -0,0 +1,636 @@ +package file + +import ( + "io" + "os" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +func Test_configFilesInDir(t *testing.T) { + type args struct { + dir string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "empty directory", + args: args{"testdata/emptydir"}, + want: nil, + wantErr: false, + }, + { + name: "directory does not exist", + args: args{"testdata/does-not-exist"}, + want: nil, + wantErr: true, + }, + { + name: "valid directory", + args: args{"testdata/emptyfiles"}, + want: []string{ + "testdata/emptyfiles/Baz.YamL", + "testdata/emptyfiles/bar.yaml", + "testdata/emptyfiles/foo.yml", + "testdata/emptyfiles/foobar.json", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := utils.ConfigFilesInDir(tt.args.dir) + if (err != nil) != tt.wantErr { + t.Errorf("configFilesInDir() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("configFilesInDir() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getReaders(t *testing.T) { + type args struct { + fileOrDir string + } + tests := []struct { + name string + args args + want map[string]io.Reader + // length of returned array + wantLen int + wantErr bool + }{ + { + name: "read from standard input", + args: args{"-"}, + want: map[string]io.Reader{ + "STDIN": os.Stdin, + }, + wantLen: 1, + wantErr: false, + }, + { + name: "directory does not exist", + args: args{"testdata/does-not-exist"}, + want: nil, + wantLen: 0, + wantErr: true, + }, + { + name: "valid directory", + args: args{"testdata/emptyfiles"}, + want: nil, + wantLen: 4, + wantErr: false, + }, + { + name: "valid file", + args: args{"testdata/file.yaml"}, + want: nil, + wantLen: 1, + wantErr: false, + }, + { + name: "valid JSON file", + args: args{"testdata/file.json"}, + want: nil, + wantLen: 1, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getReaders(tt.args.fileOrDir) + if (err != nil) != tt.wantErr { + t.Errorf("getReaders() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantLen != len(got) { + t.Errorf("getReaders() mismatch in returned length: "+ + "want = %v, got = %v", tt.wantLen, len(got)) + return + } + if tt.want != nil && !reflect.DeepEqual(got, tt.want) { + t.Errorf("getReaders() = %v, want %v", got, tt.want) + } + }) + } +} + +func sortSlices(x, y interface{}) bool { + var xName, yName string + switch xEntity := x.(type) { + case FService: + yEntity := y.(FService) + xName = *xEntity.Name + yName = *yEntity.Name + case FRoute: + yEntity := y.(FRoute) + xName = *xEntity.Name + yName = *yEntity.Name + case FConsumer: + yEntity := y.(FConsumer) + xName = *xEntity.Username + yName = *yEntity.Username + case FPlugin: + yEntity := y.(FPlugin) + xName = *xEntity.Name + yName = *yEntity.Name + } + return xName < yName +} + +func Test_getContent(t *testing.T) { + type args struct { + filenames []string + } + tests := []struct { + name string + args args + envVars map[string]string + want *Content + wantErr bool + }{ + { + name: "directory does not exist", + args: args{[]string{"testdata/does-not-exist"}}, + want: nil, + wantErr: true, + }, + { + name: "empty directory", + args: args{[]string{"testdata/emptydir"}}, + want: &Content{}, + wantErr: false, + }, + { + name: "directory with empty files", + args: args{[]string{"testdata/emptyfiles"}}, + want: &Content{}, + wantErr: false, + }, + { + name: "bad yaml", + args: args{[]string{"testdata/badyaml"}}, + want: nil, + wantErr: true, + }, + { + name: "bad JSON", + args: args{[]string{"testdata/badjson"}}, + want: nil, + wantErr: true, + }, + { + name: "single file", + args: args{[]string{"testdata/file.yaml"}}, + envVars: map[string]string{ + "DECK_SVC2_HOST": "2.example.com", + "DECK_FILE_LOG_FUNCTION": ` +function parse_traceid(str)str = string.sub(str,1,8) + local uint = 0 + for i = 1, #str do + uint = uint + str:byte(i) * 0x100^(i-1) + end + return string.format("%.0f", uint) +end + +kong.log.set_serialize_value("trace_id", parse_traceid(ngx.ctx.KONG_SPANS[1].trace_id)) +kong.log.set_serialize_value("span_id", parse_traceid(ngx.ctx.KONG_SPANS[1].span_id))`, + }, + want: &Content{ + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("svc2"), + Host: kong.String("2.example.com"), + Tags: kong.StringSlice("<"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("r2"), + Paths: kong.StringSlice("/r2"), + }, + }, + }, + }, + }, + Plugins: []FPlugin{ + { + Plugin: kong.Plugin{ + Name: kong.String("prometheus"), + }, + }, + { + Plugin: kong.Plugin{ + Name: kong.String("pre-function"), + Config: kong.Configuration{ + "log": ` +function parse_traceid(str)str = string.sub(str,1,8) + local uint = 0 + for i = 1, #str do + uint = uint + str:byte(i) * 0x100^(i-1) + end + return string.format("%.0f", uint) +end + +kong.log.set_serialize_value("trace_id", parse_traceid(ngx.ctx.KONG_SPANS[1].trace_id)) +kong.log.set_serialize_value("span_id", parse_traceid(ngx.ctx.KONG_SPANS[1].span_id)) +`, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "environment variable present in file but not set", + args: args{[]string{"testdata/file.yaml"}}, + wantErr: true, + }, + { + name: "file with bad environment variable", + args: args{[]string{"testdata/bad-env-var/file.yaml"}}, + wantErr: true, + }, + { + name: "invalid file due to leading space", + args: args{[]string{"testdata/badyamlwithspace/bar.yml"}}, + wantErr: true, + }, + { + name: "multiple files", + args: args{[]string{"testdata/file.yaml", "testdata/file.json"}}, + envVars: map[string]string{ + "DECK_SVC2_HOST": "2.example.com", + "DECK_FILE_LOG_FUNCTION": "kong.log.set_serialize_value('trace_id', 1))", + }, + want: &Content{ + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("svc2"), + Host: kong.String("2.example.com"), + Tags: kong.StringSlice("<"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("r2"), + Paths: kong.StringSlice("/r2"), + }, + }, + }, + }, + }, + Plugins: []FPlugin{ + { + Plugin: kong.Plugin{ + Name: kong.String("prometheus"), + }, + }, + { + Plugin: kong.Plugin{ + Name: kong.String("pre-function"), + Config: kong.Configuration{ + "log": "kong.log.set_serialize_value('trace_id', 1))\n", + }, + }, + }, + }, + Consumers: []FConsumer{ + { + Consumer: kong.Consumer{ + Username: kong.String("foo"), + }, + }, + { + Consumer: kong.Consumer{ + Username: kong.String("bar"), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "valid directory", + args: args{[]string{"testdata/valid"}}, + want: &Content{ + Info: &Info{ + SelectorTags: []string{"tag1"}, + }, + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("svc2"), + Host: kong.String("2.example.com"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("r2"), + Paths: kong.StringSlice("/r2"), + }, + }, + }, + }, + { + Service: kong.Service{ + Name: kong.String("svc1"), + Host: kong.String("1.example.com"), + Tags: kong.StringSlice("team-svc1"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("r1"), + Paths: kong.StringSlice("/r1"), + }, + }, + }, + }, + }, + Consumers: []FConsumer{ + { + Consumer: kong.Consumer{ + Username: kong.String("foo"), + }, + }, + { + Consumer: kong.Consumer{ + Username: kong.String("bar"), + }, + }, + { + Consumer: kong.Consumer{ + Username: kong.String("harry"), + }, + }, + }, + Plugins: []FPlugin{ + { + Plugin: kong.Plugin{ + Name: kong.String("prometheus"), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "different workspaces", + args: args{[]string{"testdata/differentworkspace"}}, + want: nil, + wantErr: true, + }, + { + name: "different runtime groups", + args: args{[]string{"testdata/differentruntimegroup"}}, + want: nil, + wantErr: true, + }, + { + name: "same workspaces", + args: args{[]string{"testdata/sameworkspace"}}, + want: &Content{ + FormatVersion: *kong.String("1.1"), + Workspace: *kong.String("bar"), + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("svc2"), + Host: kong.String("2.example.com"), + Tags: kong.StringSlice("team-svc2"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("r2"), + Paths: kong.StringSlice("/r2"), + }, + }, + }, + }, + { + Service: kong.Service{ + Name: kong.String("svc1"), + Host: kong.String("1.example.com"), + Tags: kong.StringSlice("team-svc1"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("r1"), + Paths: kong.StringSlice("/r1"), + }, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "defaults", + args: args{[]string{"testdata/defaults"}}, + want: &Content{ + FormatVersion: "1.1", + Upstreams: []FUpstream{ + { + Upstream: kong.Upstream{ + Name: kong.String("upstream1"), + Algorithm: kong.String("round-robin"), + }, + Targets: []*FTarget{ + { + Target: kong.Target{ + Target: kong.String("198.51.100.11:80"), + Weight: kong.Int(0), + }, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "shared workspace", + args: args{[]string{"testdata/sharedworkspace"}}, + want: &Content{ + FormatVersion: *kong.String("1.1"), + Workspace: *kong.String("bar"), + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("svc1"), + Host: kong.String("1.example.com"), + Tags: kong.StringSlice("team-svc1"), + }, + Routes: []*FRoute{ + { + Route: kong.Route{ + Name: kong.String("r1"), + Paths: kong.StringSlice("/r1"), + }, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "file with env var and parse bool", + args: args{[]string{"testdata/parsebool/file.yaml"}}, + envVars: map[string]string{ + "DECK_MOCKBIN_ENABLED": "true", + }, + want: &Content{ + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("svc1"), + Host: kong.String("mockbin.org"), + Enabled: kong.Bool(true), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "file with env var and parse bool - err on bad value", + args: args{[]string{"testdata/parsebool/file.yaml"}}, + envVars: map[string]string{ + "DECK_MOCKBIN_ENABLED": "RIP", + }, + wantErr: true, + }, + { + name: "file with env var and parse Int", + args: args{[]string{"testdata/parseint/file.yaml"}}, + envVars: map[string]string{ + "DECK_WRITE_TIMEOUT": "1337", + }, + want: &Content{ + Services: []FService{ + { + Service: kong.Service{ + Name: kong.String("svc1"), + Host: kong.String("mockbin.org"), + WriteTimeout: kong.Int(1337), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "file with env var and parse Int - err on bad value", + args: args{[]string{"testdata/parseint/file.yaml"}}, + envVars: map[string]string{ + "DECK_WRITE_TIMEOUT": "RIP", + }, + wantErr: true, + }, + { + name: "file with env var and parse Float64", + args: args{[]string{"testdata/parsefloat/file.yaml"}}, + envVars: map[string]string{ + "DECK_FOO_FLOAT": "1337", + }, + want: &Content{ + Plugins: []FPlugin{ + { + Plugin: kong.Plugin{ + Name: kong.String("foofloat"), + Config: kong.Configuration{ + "foo": float64(1337), + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "file with env var and parse Int - err on bad value", + args: args{[]string{"testdata/parsefloat/file.yaml"}}, + envVars: map[string]string{ + "DECK_FOO_FLOAT": "RIP", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.envVars { + t.Setenv(k, v) + } + got, err := getContent(tt.args.filenames, false) + if (err != nil) != tt.wantErr { + t.Errorf("getContent() error = %v, wantErr %v", err, tt.wantErr) + return + } + + opt := []cmp.Option{ + cmpopts.SortSlices(sortSlices), + cmpopts.SortSlices(func(a, b *string) bool { return *a < *b }), + cmpopts.EquateEmpty(), + } + if diff := cmp.Diff(got, tt.want, opt...); diff != "" { + t.Errorf(diff) + } + }) + } +} + +func Test_yamlUnmarshal(t *testing.T) { + stringToInterfaceMap := map[string]interface{}{} + bytes1 := ` +versions: + v1: + enabled: false +` + mapOfMap := map[string]interface{}{} + err := yamlUnmarshal([]byte(bytes1), &mapOfMap) + if err != nil { + t.Errorf("yamlUnmarshal() error = %v (should be nil)", err) + } + subMap := mapOfMap["versions"] + if reflect.TypeOf(subMap) != reflect.TypeOf(stringToInterfaceMap) { + t.Errorf("yamlUnmarshal() expected type: %T, got: %T", stringToInterfaceMap, subMap) + } + + bytes2 := ` +versions: +- enabled: false + version: 1 +` + mapOfArrayOfMap := map[string]interface{}{} + err = yamlUnmarshal([]byte(bytes2), &mapOfArrayOfMap) + if err != nil { + t.Errorf("yamlUnmarshal() error = %v (should be nil)", err) + } + array := mapOfArrayOfMap["versions"].([]interface{}) + element := array[0] + if reflect.TypeOf(element) != reflect.TypeOf(stringToInterfaceMap) { + t.Errorf("yamlUnmarshal() expected type: %T, got: %T", stringToInterfaceMap, element) + } +} diff --git a/pkg/file/schema.go b/pkg/file/schema.go new file mode 100644 index 0000000..e099c87 --- /dev/null +++ b/pkg/file/schema.go @@ -0,0 +1,6 @@ +package file + +import _ "embed" // for embedding only + +//go:embed kong_json_schema.json +var kongJSONSchema string diff --git a/pkg/file/testdata/bad-env-var/file.yaml b/pkg/file/testdata/bad-env-var/file.yaml new file mode 100644 index 0000000..9795333 --- /dev/null +++ b/pkg/file/testdata/bad-env-var/file.yaml @@ -0,0 +1,9 @@ +services: +- name: svc2 + host: ${{ env "SVC2_HOST" }} + routes: + - name: r2 + paths: + - /r2 +plugins: +- name: prometheus diff --git a/pkg/file/testdata/badjson/foo.json b/pkg/file/testdata/badjson/foo.json new file mode 100644 index 0000000..e13d4e8 --- /dev/null +++ b/pkg/file/testdata/badjson/foo.json @@ -0,0 +1,13 @@ +{ + "services": { + "foo": "bar" + }, + "consumers": [ + { + "username": "foo" + }, + { + "username": "bar" + } + ] +} diff --git a/pkg/file/testdata/badyaml/bar.yml b/pkg/file/testdata/badyaml/bar.yml new file mode 100644 index 0000000..49478f3 --- /dev/null +++ b/pkg/file/testdata/badyaml/bar.yml @@ -0,0 +1,8 @@ +- name: svc2 + host: 2.example.com + routes: + - name: r2 + paths: + - /r2 +plugins: +- name: prometheus diff --git a/pkg/file/testdata/badyamlwithspace/bar.yml b/pkg/file/testdata/badyamlwithspace/bar.yml new file mode 100644 index 0000000..683011b --- /dev/null +++ b/pkg/file/testdata/badyamlwithspace/bar.yml @@ -0,0 +1,12 @@ + _info: + select_tags: + - test +services: +- name: svc2 + host: 2.example.com + routes: + - name: r2 + paths: + - /r2 +plugins: +- name: prometheus diff --git a/pkg/file/testdata/config.yaml b/pkg/file/testdata/config.yaml new file mode 100644 index 0000000..0b654c6 --- /dev/null +++ b/pkg/file/testdata/config.yaml @@ -0,0 +1,4 @@ +services: +- name: svc1 + host: mockbin.org + enabled: true \ No newline at end of file diff --git a/pkg/file/testdata/defaults/bar.yaml b/pkg/file/testdata/defaults/bar.yaml new file mode 100644 index 0000000..92ae244 --- /dev/null +++ b/pkg/file/testdata/defaults/bar.yaml @@ -0,0 +1,7 @@ +_format_version: "1.1" +upstreams: +- name: upstream1 + algorithm: round-robin + targets: + - target: 198.51.100.11:80 + weight: 0 # default being 100 diff --git a/pkg/file/testdata/differentruntimegroup/bar.yaml b/pkg/file/testdata/differentruntimegroup/bar.yaml new file mode 100644 index 0000000..a613369 --- /dev/null +++ b/pkg/file/testdata/differentruntimegroup/bar.yaml @@ -0,0 +1,6 @@ +_format_version: "3.0" +_konnect: + runtime_group_name: bar +services: +- name: svc2 + host: 2.example.com \ No newline at end of file diff --git a/pkg/file/testdata/differentruntimegroup/foo.yaml b/pkg/file/testdata/differentruntimegroup/foo.yaml new file mode 100644 index 0000000..8a91952 --- /dev/null +++ b/pkg/file/testdata/differentruntimegroup/foo.yaml @@ -0,0 +1,6 @@ +_format_version: "3.0" +_konnect: + runtime_group_name: foo +services: +- name: svc1 + host: 1.example.com \ No newline at end of file diff --git a/pkg/file/testdata/differentworkspace/bar.yaml b/pkg/file/testdata/differentworkspace/bar.yaml new file mode 100644 index 0000000..b5469cb --- /dev/null +++ b/pkg/file/testdata/differentworkspace/bar.yaml @@ -0,0 +1,11 @@ +_format_version: "1.1" +_workspace: bar +services: +- name: svc2 + host: 2.example.com + tags: + - team-svc2 + routes: + - name: r2 + paths: + - /r2 \ No newline at end of file diff --git a/pkg/file/testdata/differentworkspace/foo.yaml b/pkg/file/testdata/differentworkspace/foo.yaml new file mode 100644 index 0000000..966c4da --- /dev/null +++ b/pkg/file/testdata/differentworkspace/foo.yaml @@ -0,0 +1,11 @@ +_format_version: "1.1" +_workspace: foo +services: +- name: svc1 + host: 1.example.com + tags: + - team-svc1 + routes: + - name: r1 + paths: + - /r1 \ No newline at end of file diff --git a/pkg/file/testdata/emptydir/README b/pkg/file/testdata/emptydir/README new file mode 100644 index 0000000..17fdc81 --- /dev/null +++ b/pkg/file/testdata/emptydir/README @@ -0,0 +1 @@ +Keep this directory empty. diff --git a/pkg/file/testdata/emptyfiles/Baz.YamL b/pkg/file/testdata/emptyfiles/Baz.YamL new file mode 100644 index 0000000..e69de29 diff --git a/pkg/file/testdata/emptyfiles/bar.yaml b/pkg/file/testdata/emptyfiles/bar.yaml new file mode 100644 index 0000000..e69de29 diff --git a/pkg/file/testdata/emptyfiles/foo.notyaml b/pkg/file/testdata/emptyfiles/foo.notyaml new file mode 100644 index 0000000..e69de29 diff --git a/pkg/file/testdata/emptyfiles/foo.yaml.pdf b/pkg/file/testdata/emptyfiles/foo.yaml.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pkg/file/testdata/emptyfiles/foo.yml b/pkg/file/testdata/emptyfiles/foo.yml new file mode 100644 index 0000000..e69de29 diff --git a/pkg/file/testdata/emptyfiles/foobar.json b/pkg/file/testdata/emptyfiles/foobar.json new file mode 100644 index 0000000..e69de29 diff --git a/pkg/file/testdata/emptyfiles/not-a-file.yaml/info.txt b/pkg/file/testdata/emptyfiles/not-a-file.yaml/info.txt new file mode 100644 index 0000000..e69de29 diff --git a/pkg/file/testdata/file.json b/pkg/file/testdata/file.json new file mode 100644 index 0000000..247efe2 --- /dev/null +++ b/pkg/file/testdata/file.json @@ -0,0 +1,10 @@ +{ + "consumers": [ + { + "username": "foo" + }, + { + "username": "bar" + } + ] +} diff --git a/pkg/file/testdata/file.yaml b/pkg/file/testdata/file.yaml new file mode 100644 index 0000000..05d6fa8 --- /dev/null +++ b/pkg/file/testdata/file.yaml @@ -0,0 +1,15 @@ +services: +- name: svc2 + host: ${{ env "DECK_SVC2_HOST" }} + routes: + - name: r2 + paths: + - /r2 + tags: + - '<' # verifies that the templating engine does not perform character escaping +plugins: +- name: prometheus +- name: pre-function + config: + log: | + ${{ env "DECK_FILE_LOG_FUNCTION" | indent 8 }} diff --git a/pkg/file/testdata/parsebool/file.yaml b/pkg/file/testdata/parsebool/file.yaml new file mode 100644 index 0000000..209a934 --- /dev/null +++ b/pkg/file/testdata/parsebool/file.yaml @@ -0,0 +1,4 @@ +services: +- name: svc1 + host: mockbin.org + enabled: ${{ env "DECK_MOCKBIN_ENABLED" | toBool }} diff --git a/pkg/file/testdata/parsefloat/file.yaml b/pkg/file/testdata/parsefloat/file.yaml new file mode 100644 index 0000000..dec0fc9 --- /dev/null +++ b/pkg/file/testdata/parsefloat/file.yaml @@ -0,0 +1,4 @@ +plugins: +- config: + foo: ${{ env "DECK_FOO_FLOAT" | toFloat }} + name: foofloat diff --git a/pkg/file/testdata/parseint/file.yaml b/pkg/file/testdata/parseint/file.yaml new file mode 100644 index 0000000..71e69d3 --- /dev/null +++ b/pkg/file/testdata/parseint/file.yaml @@ -0,0 +1,4 @@ +services: +- name: svc1 + host: mockbin.org + write_timeout: ${{ env "DECK_WRITE_TIMEOUT" | toInt }} diff --git a/pkg/file/testdata/sameworkspace/bar.yaml b/pkg/file/testdata/sameworkspace/bar.yaml new file mode 100644 index 0000000..b5469cb --- /dev/null +++ b/pkg/file/testdata/sameworkspace/bar.yaml @@ -0,0 +1,11 @@ +_format_version: "1.1" +_workspace: bar +services: +- name: svc2 + host: 2.example.com + tags: + - team-svc2 + routes: + - name: r2 + paths: + - /r2 \ No newline at end of file diff --git a/pkg/file/testdata/sameworkspace/foo.yaml b/pkg/file/testdata/sameworkspace/foo.yaml new file mode 100644 index 0000000..98db7c7 --- /dev/null +++ b/pkg/file/testdata/sameworkspace/foo.yaml @@ -0,0 +1,11 @@ +_format_version: "1.1" +_workspace: bar +services: +- name: svc1 + host: 1.example.com + tags: + - team-svc1 + routes: + - name: r1 + paths: + - /r1 \ No newline at end of file diff --git a/pkg/file/testdata/sharedworkspace/foo.yaml b/pkg/file/testdata/sharedworkspace/foo.yaml new file mode 100644 index 0000000..135b88a --- /dev/null +++ b/pkg/file/testdata/sharedworkspace/foo.yaml @@ -0,0 +1,9 @@ +services: +- name: svc1 + host: 1.example.com + tags: + - team-svc1 + routes: + - name: r1 + paths: + - /r1 diff --git a/pkg/file/testdata/sharedworkspace/meta.yaml b/pkg/file/testdata/sharedworkspace/meta.yaml new file mode 100644 index 0000000..1b1f7d7 --- /dev/null +++ b/pkg/file/testdata/sharedworkspace/meta.yaml @@ -0,0 +1,2 @@ +_format_version: "1.1" +_workspace: bar diff --git a/pkg/file/testdata/valid/bar.yml b/pkg/file/testdata/valid/bar.yml new file mode 100644 index 0000000..e58ba59 --- /dev/null +++ b/pkg/file/testdata/valid/bar.yml @@ -0,0 +1,9 @@ +services: +- name: svc2 + host: 2.example.com + routes: + - name: r2 + paths: + - /r2 +plugins: +- name: prometheus diff --git a/pkg/file/testdata/valid/consumers.json b/pkg/file/testdata/valid/consumers.json new file mode 100644 index 0000000..247efe2 --- /dev/null +++ b/pkg/file/testdata/valid/consumers.json @@ -0,0 +1,10 @@ +{ + "consumers": [ + { + "username": "foo" + }, + { + "username": "bar" + } + ] +} diff --git a/pkg/file/testdata/valid/foo.yaml b/pkg/file/testdata/valid/foo.yaml new file mode 100644 index 0000000..4d4b68b --- /dev/null +++ b/pkg/file/testdata/valid/foo.yaml @@ -0,0 +1,14 @@ +_info: + select_tags: + - tag1 +services: +- name: svc1 + host: 1.example.com + tags: + - team-svc1 + routes: + - name: r1 + paths: + - /r1 +consumers: +- username: harry diff --git a/pkg/file/types.go b/pkg/file/types.go new file mode 100644 index 0000000..d70b60e --- /dev/null +++ b/pkg/file/types.go @@ -0,0 +1,743 @@ +package file + +import ( + "encoding/json" + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// Format is a file format for Kong's configuration. +type Format string + +type sortable interface { + sortKey() string +} + +const ( + // JSON is JSON file format. + JSON = "JSON" + // YAML if YAML file format. + YAML = "YAML" +) + +const ( + httpPort = 80 + httpsPort = 443 +) + +// FService represents a Kong Service and it's associated routes and plugins. +// +k8s:deepcopy-gen=true +type FService struct { + kong.Service + Routes []*FRoute `json:"routes,omitempty" yaml:",omitempty"` + Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + + // sugar property + URL *string `json:"url,omitempty" yaml:",omitempty"` +} + +// sortKey is used for sorting. +func (s FService) sortKey() string { + if s.Name != nil { + return *s.Name + } + if s.ID != nil { + return *s.ID + } + return "" +} + +type service struct { + ClientCertificate *string `json:"client_certificate,omitempty" yaml:"client_certificate,omitempty"` + ConnectTimeout *int `json:"connect_timeout,omitempty" yaml:"connect_timeout,omitempty"` + CreatedAt *int `json:"created_at,omitempty" yaml:"created_at,omitempty"` + Host *string `json:"host,omitempty" yaml:"host,omitempty"` + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Name *string `json:"name,omitempty" yaml:"name,omitempty"` + Path *string `json:"path,omitempty" yaml:"path,omitempty"` + Port *int `json:"port,omitempty" yaml:"port,omitempty"` + Protocol *string `json:"protocol,omitempty" yaml:"protocol,omitempty"` + ReadTimeout *int `json:"read_timeout,omitempty" yaml:"read_timeout,omitempty"` + Retries *int `json:"retries,omitempty" yaml:"retries,omitempty"` + UpdatedAt *int `json:"updated_at,omitempty" yaml:"updated_at,omitempty"` + WriteTimeout *int `json:"write_timeout,omitempty" yaml:"write_timeout,omitempty"` + Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` + TLSVerify *bool `json:"tls_verify,omitempty" yaml:"tls_verify,omitempty"` + TLSVerifyDepth *int `json:"tls_verify_depth,omitempty" yaml:"tls_verify_depth,omitempty"` + CACertificates []*string `json:"ca_certificates,omitempty" yaml:"ca_certificates,omitempty"` + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + Routes []*FRoute `json:"routes,omitempty" yaml:",omitempty"` + Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + + // sugar property + URL *string `json:"url,omitempty" yaml:",omitempty"` +} + +func copyToService(fService FService) service { + s := service{} + if fService.ClientCertificate != nil && + !utils.Empty(fService.ClientCertificate.ID) { + s.ClientCertificate = kong.String(*fService.ClientCertificate.ID) + } + s.CACertificates = fService.CACertificates + s.TLSVerify = fService.TLSVerify + s.TLSVerifyDepth = fService.TLSVerifyDepth + s.ConnectTimeout = fService.ConnectTimeout + s.CreatedAt = fService.CreatedAt + s.Host = fService.Host + s.ID = fService.ID + s.Name = fService.Name + s.Path = fService.Path + s.Port = fService.Port + s.Protocol = fService.Protocol + s.ReadTimeout = fService.ReadTimeout + s.Retries = fService.Retries + s.UpdatedAt = fService.UpdatedAt + s.WriteTimeout = fService.WriteTimeout + s.Tags = fService.Tags + s.Routes = fService.Routes + s.Plugins = fService.Plugins + s.Enabled = fService.Enabled + + return s +} + +func unwrapURL(urlString string, fService *FService) error { + parsed, err := url.Parse(urlString) + if err != nil { + return fmt.Errorf("invalid url: " + urlString) + } + if parsed.Scheme == "" { + return fmt.Errorf("invalid url:" + urlString) + } + + fService.Protocol = kong.String(parsed.Scheme) + + fService.Port = kong.Int(httpPort) + if parsed.Scheme == "https" { + fService.Port = kong.Int(httpsPort) + } + + if parsed.Host != "" { + hostPort := strings.Split(parsed.Host, ":") + fService.Host = kong.String(hostPort[0]) + if len(hostPort) > 1 { + port, err := strconv.Atoi(hostPort[1]) + if err == nil { + fService.Port = kong.Int(port) + } + } + } + if parsed.Path != "" { + // make sure that decoded whitespaces are encoded back + encodedParsedPath := strings.ReplaceAll(parsed.Path, " ", "%20") + fService.Path = kong.String(encodedParsedPath) + } + return nil +} + +func copyFromService(service service, fService *FService) error { + if service.ClientCertificate != nil && + !utils.Empty(service.ClientCertificate) { + fService.ClientCertificate = &kong.Certificate{ + ID: kong.String(*service.ClientCertificate), + } + } + if !utils.Empty(service.URL) { + err := unwrapURL(*service.URL, fService) + if err != nil { + return err + } + } + fService.ConnectTimeout = service.ConnectTimeout + fService.CreatedAt = service.CreatedAt + fService.ID = service.ID + fService.Name = service.Name + if service.Protocol != nil { + fService.Protocol = service.Protocol + } + if service.Host != nil { + fService.Host = service.Host + } + if service.Port != nil { + fService.Port = service.Port + } + if service.Path != nil { + fService.Path = service.Path + } + fService.ReadTimeout = service.ReadTimeout + fService.Retries = service.Retries + fService.UpdatedAt = service.UpdatedAt + fService.WriteTimeout = service.WriteTimeout + fService.Tags = service.Tags + fService.CACertificates = service.CACertificates + fService.TLSVerify = service.TLSVerify + fService.TLSVerifyDepth = service.TLSVerifyDepth + fService.Routes = service.Routes + fService.Plugins = service.Plugins + fService.Enabled = service.Enabled + return nil +} + +// MarshalYAML is a custom marshal to handle +// SNI. +func (s FService) MarshalYAML() (interface{}, error) { + return copyToService(s), nil +} + +// UnmarshalYAML is a custom marshal method to handle +// foreign references. +func (s *FService) UnmarshalYAML(unmarshal func(interface{}) error) error { + var service service + if err := unmarshal(&service); err != nil { + return err + } + return copyFromService(service, s) +} + +// MarshalJSON is a custom marshal method to handle +// foreign references. +func (s FService) MarshalJSON() ([]byte, error) { + service := copyToService(s) + return json.Marshal(service) +} + +// UnmarshalJSON is a custom marshal method to handle +// foreign references. +func (s *FService) UnmarshalJSON(b []byte) error { + var service service + err := json.Unmarshal(b, &service) + if err != nil { + return err + } + return copyFromService(service, s) +} + +// FRoute represents a Kong Route and it's associated plugins. +// +k8s:deepcopy-gen=true +type FRoute struct { + kong.Route `yaml:",inline,omitempty"` + Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` +} + +// sortKey is used for sorting. +func (r FRoute) sortKey() string { + if r.Name != nil { + return *r.Name + } + if r.ID != nil { + return *r.ID + } + return "" +} + +// FUpstream represents a Kong Upstream and it's associated targets. +// +k8s:deepcopy-gen=true +type FUpstream struct { + kong.Upstream `yaml:",inline,omitempty"` + Targets []*FTarget `json:"targets,omitempty" yaml:",omitempty"` +} + +// sortKey is used for sorting. +func (u FUpstream) sortKey() string { + if u.Name != nil { + return *u.Name + } + if u.ID != nil { + return *u.ID + } + return "" +} + +// FTarget represents a Kong Target. +// +k8s:deepcopy-gen=true +type FTarget struct { + kong.Target `yaml:",inline,omitempty"` +} + +// sortKey is used for sorting. +func (t FTarget) sortKey() string { + if t.Target.Target != nil { + return *t.Target.Target + } + if t.ID != nil { + return *t.ID + } + return "" +} + +// FCertificate represents a Kong Certificate. +// +k8s:deepcopy-gen=true +type FCertificate struct { + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Cert *string `json:"cert,omitempty" yaml:"cert,omitempty"` + Key *string `json:"key,omitempty" yaml:"key,omitempty"` + CreatedAt *int64 `json:"created_at,omitempty" yaml:"created_at,omitempty"` + Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` + SNIs []kong.SNI `json:"snis,omitempty" yaml:"snis,omitempty"` +} + +// sortKey is used for sorting. +func (c FCertificate) sortKey() string { + if c.Cert != nil { + return *c.Cert + } + if c.ID != nil { + return *c.ID + } + return "" +} + +// FCACertificate represents a Kong CACertificate. +// +k8s:deepcopy-gen=true +type FCACertificate struct { + kong.CACertificate `yaml:",inline,omitempty"` +} + +// sortKey is used for sorting. +func (c FCACertificate) sortKey() string { + if c.Cert != nil { + return *c.Cert + } + if c.ID != nil { + return *c.ID + } + return "" +} + +// FPlugin represents a plugin in Kong. +// +k8s:deepcopy-gen=true +type FPlugin struct { + kong.Plugin `yaml:",inline,omitempty"` + + ConfigSource *string `json:"_config,omitempty" yaml:"_config,omitempty"` +} + +// foo is a shadow type of Plugin. +// It is used for custom marshalling of plugin. +type foo struct { + CreatedAt *int `json:"created_at,omitempty" yaml:"created_at,omitempty"` + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Name *string `json:"name,omitempty" yaml:"name,omitempty"` + InstanceName *string `json:"instance_name,omitempty" yaml:"instance_name,omitempty"` + Config kong.Configuration `json:"config,omitempty" yaml:"config,omitempty"` + Service string `json:"service,omitempty" yaml:",omitempty"` + Consumer string `json:"consumer,omitempty" yaml:",omitempty"` + ConsumerGroup string `json:"consumer_group,omitempty" yaml:",omitempty"` + Route string `json:"route,omitempty" yaml:",omitempty"` + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + RunOn *string `json:"run_on,omitempty" yaml:"run_on,omitempty"` + Ordering *kong.PluginOrdering `json:"ordering,omitempty" yaml:"ordering,omitempty"` + Protocols []*string `json:"protocols,omitempty" yaml:"protocols,omitempty"` + Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` + + ConfigSource *string `json:"_config,omitempty" yaml:"_config,omitempty"` +} + +func copyToFoo(p FPlugin) foo { + f := foo{} + if p.ID != nil { + f.ID = p.ID + } + if p.Name != nil { + f.Name = p.Name + } + if p.InstanceName != nil { + f.InstanceName = p.InstanceName + } + if p.Enabled != nil { + f.Enabled = p.Enabled + } + if p.RunOn != nil { + f.RunOn = p.RunOn + } + if p.Protocols != nil { + f.Protocols = p.Protocols + } + if p.Ordering != nil { + f.Ordering = p.Ordering + } + if p.Tags != nil { + f.Tags = p.Tags + } + if p.Config != nil { + f.Config = p.Config + } + if p.ConfigSource != nil { + f.ConfigSource = p.ConfigSource + } + if p.Plugin.Consumer != nil { + f.Consumer = *p.Plugin.Consumer.ID + } + if p.Plugin.Route != nil { + f.Route = *p.Plugin.Route.ID + } + if p.Plugin.Service != nil { + f.Service = *p.Plugin.Service.ID + } + if p.Plugin.ConsumerGroup != nil { + f.ConsumerGroup = *p.Plugin.ConsumerGroup.ID + } + return f +} + +func copyFromFoo(f foo, p *FPlugin) { + if f.ID != nil { + p.ID = f.ID + } + if f.Name != nil { + p.Name = f.Name + } + if f.InstanceName != nil { + p.InstanceName = f.InstanceName + } + if f.Enabled != nil { + p.Enabled = f.Enabled + } + if f.RunOn != nil { + p.RunOn = f.RunOn + } + if f.Protocols != nil { + p.Protocols = f.Protocols + } + if f.Ordering != nil { + p.Ordering = f.Ordering + } + if f.Tags != nil { + p.Tags = f.Tags + } + if f.Config != nil { + p.Config = f.Config + } + if f.ConfigSource != nil { + p.ConfigSource = f.ConfigSource + } + if f.Consumer != "" { + p.Consumer = &kong.Consumer{ + ID: kong.String(f.Consumer), + } + } + if f.Route != "" { + p.Route = &kong.Route{ + ID: kong.String(f.Route), + } + } + if f.Service != "" { + p.Service = &kong.Service{ + ID: kong.String(f.Service), + } + } + if f.ConsumerGroup != "" { + p.ConsumerGroup = &kong.ConsumerGroup{ + ID: kong.String(f.ConsumerGroup), + } + } +} + +// MarshalYAML is a custom marshal method to handle +// foreign references. +func (p FPlugin) MarshalYAML() (interface{}, error) { + return copyToFoo(p), nil +} + +// UnmarshalYAML is a custom marshal method to handle +// foreign references. +func (p *FPlugin) UnmarshalYAML(unmarshal func(interface{}) error) error { + var f foo + if err := unmarshal(&f); err != nil { + return err + } + copyFromFoo(f, p) + return nil +} + +// MarshalJSON is a custom marshal method to handle +// foreign references. +func (p FPlugin) MarshalJSON() ([]byte, error) { + f := copyToFoo(p) + return json.Marshal(f) +} + +// UnmarshalJSON is a custom marshal method to handle +// foreign references. +func (p *FPlugin) UnmarshalJSON(b []byte) error { + var f foo + err := json.Unmarshal(b, &f) + if err != nil { + return err + } + copyFromFoo(f, p) + return nil +} + +// sortKey is used for sorting. +func (p FPlugin) sortKey() string { + // concat plugin name and foreign relations + if p.Name != nil { + key := *p.Name + if p.Consumer != nil { + key += *p.Consumer.ID + } + if p.Route != nil { + key += *p.Route.ID + } + if p.Service != nil { + key += *p.Service.ID + } + if p.ConsumerGroup != nil { + key += *p.ConsumerGroup.ID + } + return key + } + if p.ID != nil { + return *p.ID + } + return "" +} + +// FConsumer represents a consumer in Kong. +// +k8s:deepcopy-gen=true +type FConsumer struct { + kong.Consumer `yaml:",inline,omitempty"` + Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + KeyAuths []*kong.KeyAuth `json:"keyauth_credentials,omitempty" yaml:"keyauth_credentials,omitempty"` + HMACAuths []*kong.HMACAuth `json:"hmacauth_credentials,omitempty" yaml:"hmacauth_credentials,omitempty"` + JWTAuths []*kong.JWTAuth `json:"jwt_secrets,omitempty" yaml:"jwt_secrets,omitempty"` + BasicAuths []*kong.BasicAuth `json:"basicauth_credentials,omitempty" yaml:"basicauth_credentials,omitempty"` + Oauth2Creds []*kong.Oauth2Credential `json:"oauth2_credentials,omitempty" yaml:"oauth2_credentials,omitempty"` + ACLGroups []*kong.ACLGroup `json:"acls,omitempty" yaml:"acls,omitempty"` + MTLSAuths []*kong.MTLSAuth `json:"mtls_auth_credentials,omitempty" yaml:"mtls_auth_credentials,omitempty"` + Groups []*kong.ConsumerGroup `json:"groups,omitempty" yaml:"groups,omitempty"` +} + +// sortKey is used for sorting. +func (c FConsumer) sortKey() string { + if c.Username != nil { + return *c.Username + } + if c.ID != nil { + return *c.ID + } + return "" +} + +// FConsumerGroupObject represents a Kong ConsumerGroup and its associated consumers and plugins. +// +k8s:deepcopy-gen=true +type FConsumerGroupObject struct { + kong.ConsumerGroup `yaml:",inline,omitempty"` + Consumers []*kong.Consumer `json:"consumers,omitempty" yaml:",omitempty"` + Plugins []*kong.ConsumerGroupPlugin `json:"plugins,omitempty" yaml:",omitempty"` +} + +// sortKey is used for sorting. +func (u FConsumerGroupObject) sortKey() string { + if u.Name != nil { + return *u.Name + } + if u.ID != nil { + return *u.ID + } + return "" +} + +// FRBACRole represents an RBACRole in Kong +// +k8s:deepcopy-gen=true +type FRBACRole struct { + kong.RBACRole `yaml:",inline,omitempty"` + EndpointPermissions []*FRBACEndpointPermission `json:"endpoint_permissions,omitempty" yaml:"endpoint_permissions,omitempty"` //nolint +} + +// FRBACEndpointPermission is a wrapper type for RBACEndpointPermission. +// +k8s:deepcopy-gen=true +type FRBACEndpointPermission struct { + kong.RBACEndpointPermission `yaml:",inline,omitempty"` +} + +func (frbac FRBACEndpointPermission) MarshalJSON() ([]byte, error) { + m := map[string]interface{}{} + if frbac.Workspace != nil { + m["workspace"] = frbac.Workspace + } + if frbac.Actions != nil { + m["actions"] = frbac.Actions + } + if frbac.CreatedAt != nil { + m["created_at"] = frbac.CreatedAt + } + if frbac.Endpoint != nil { + m["endpoint"] = frbac.Endpoint + } + if frbac.Negative != nil { + m["negative"] = frbac.Negative + } + if frbac.Role != nil { + m["role"] = frbac.Role + } + if frbac.Comment != nil { + m["comment"] = frbac.Comment + } + return json.Marshal(m) +} + +// KongDefaults represents default values that are filled in +// for entities with corresponding missing properties. +// +k8s:deepcopy-gen=true +type KongDefaults struct { + Service *kong.Service `json:"service,omitempty" yaml:"service,omitempty"` + Route *kong.Route `json:"route,omitempty" yaml:"route,omitempty"` + Upstream *kong.Upstream `json:"upstream,omitempty" yaml:"upstream,omitempty"` + Target *kong.Target `json:"target,omitempty" yaml:"target,omitempty"` +} + +// Info contains meta-data of the file. +// +k8s:deepcopy-gen=true +type Info struct { + SelectorTags []string `json:"select_tags,omitempty" yaml:"select_tags,omitempty"` + Defaults KongDefaults `json:"defaults,omitempty" yaml:"defaults,omitempty"` +} + +// Konnect contains configuration specific to Konnect. +// +k8s:deepcopy-gen=true +type Konnect struct { + RuntimeGroupName string `json:"runtime_group_name,omitempty" yaml:"runtime_group_name,omitempty"` + ControlPlaneName string `json:"control_plane_name,omitempty" yaml:"control_plane_name,omitempty"` +} + +// Kong represents Kong implementation of a Service in Konnect. +// +k8s:deepcopy-gen=true +type Kong struct { + Service *FService `json:"service,omitempty" yaml:"service,omitempty"` +} + +// Implementation represents an implementation of a Service version in Konnect. +// +k8s:deepcopy-gen=true +type Implementation struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Kong *Kong `json:"kong,omitempty" yaml:"kong,omitempty"` +} + +// FServiceVersion represents a Service version in Konnect. +// The type is duplicated because only a single document is +// exported in file while the API allows for multiple documents. +// +k8s:deepcopy-gen=true +type FServiceVersion struct { + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Version *string `json:"version,omitempty" yaml:"version,omitempty"` + Implementation *Implementation `json:"implementation,omitempty" yaml:"implementation,omitempty"` + Document *FDocument `json:"document,omitempty" yaml:"document,omitempty"` +} + +// FServicePackage represents a Service package and its associated entities. +// +k8s:deepcopy-gen=true +type FServicePackage struct { + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Name *string `json:"name,omitempty" yaml:"name,omitempty"` + Description *string `json:"description,omitempty" yaml:"description,omitempty"` + Versions []FServiceVersion `json:"versions,omitempty" yaml:"versions,omitempty"` + Document *FDocument `json:"document,omitempty" yaml:"document,omitempty"` +} + +// FDocument represents a document in Konnect. +// The type has been duplicated because the documents are altered +// before they are exported to the state file +// for better user experience. +// +k8s:deepcopy-gen=true +type FDocument struct { + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Path *string `json:"path,omitempty" yaml:"path,omitempty"` + Published *bool `json:"published,omitempty" yaml:"published,omitempty"` + Content *string `json:"-" yaml:"-"` +} + +// sortKey is used for sorting. +func (s FServiceVersion) sortKey() string { + if s.Version != nil { + return *s.Version + } + if s.ID != nil { + return *s.ID + } + return "" +} + +// sortKey is used for sorting. +func (s FServicePackage) sortKey() string { + if s.Name != nil { + return *s.Name + } + if s.ID != nil { + return *s.ID + } + return "" +} + +// FVault represents a vault in Kong. +// +k8s:deepcopy-gen=true +type FVault struct { + kong.Vault `yaml:",inline,omitempty"` +} + +// sortKey is used for sorting. +func (c FVault) sortKey() string { + if c.Prefix != nil { + return *c.Prefix + } + if c.ID != nil { + return *c.ID + } + return "" +} + +// FLicense exists as a file type _only_ +// This is a compatibility layer for KIC library usage. deck cannot interact with the license entity. +// Ref https://github.com/Kong/deck/pull/882 if we need to support this entity throughout deck. + +// FLicense represents a Kong License. +// +k8s:deepcopy-gen=true +type FLicense struct { + kong.License `yaml:",inline,omitempty"` +} + +// sortKey is used for sorting. +func (c FLicense) sortKey() string { + if c.ID != nil { + return *c.ID + } + return "" +} + +//go:generate go run ./codegen/main.go + +// Content represents a serialized Kong state. +// +k8s:deepcopy-gen=true +type Content struct { + FormatVersion string `json:"_format_version,omitempty" yaml:"_format_version,omitempty"` + Transform *bool `json:"_transform,omitempty" yaml:"_transform,omitempty"` + Info *Info `json:"_info,omitempty" yaml:"_info,omitempty"` + Workspace string `json:"_workspace,omitempty" yaml:"_workspace,omitempty"` + Konnect *Konnect `json:"_konnect,omitempty" yaml:"_konnect,omitempty"` + + Services []FService `json:"services,omitempty" yaml:",omitempty"` + Routes []FRoute `json:"routes,omitempty" yaml:",omitempty"` + Consumers []FConsumer `json:"consumers,omitempty" yaml:",omitempty"` + ConsumerGroups []FConsumerGroupObject `json:"consumer_groups,omitempty" yaml:",omitempty"` + Plugins []FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + Upstreams []FUpstream `json:"upstreams,omitempty" yaml:",omitempty"` + Certificates []FCertificate `json:"certificates,omitempty" yaml:",omitempty"` + CACertificates []FCACertificate `json:"ca_certificates,omitempty" yaml:"ca_certificates,omitempty"` + + RBACRoles []FRBACRole `json:"rbac_roles,omitempty" yaml:"rbac_roles,omitempty"` + + PluginConfigs map[string]kong.Configuration `json:"_plugin_configs,omitempty" yaml:"_plugin_configs,omitempty"` + + ServicePackages []FServicePackage `json:"service_packages,omitempty" yaml:"service_packages,omitempty"` + + Vaults []FVault `json:"vaults,omitempty" yaml:"vaults,omitempty"` + + Licenses []FLicense `json:"licenses,omitempty" yaml:"licenses,omitempty"` +} diff --git a/pkg/file/types_test.go b/pkg/file/types_test.go new file mode 100644 index 0000000..030d860 --- /dev/null +++ b/pkg/file/types_test.go @@ -0,0 +1,496 @@ +package file + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/yaml" +) + +var ( + jsonString = `{ + "name": "rate-limiting", + "config": { + "minute": 10 + }, + "service": "foo", + "route": "bar", + "consumer": "baz", + "enabled": true, + "run_on": "first", + "protocols": [ + "http" + ] +}` + yamlString = ` +name: rate-limiting +config: + minute: 10 +service: foo +consumer: baz +route: bar +enabled: true +run_on: first +protocols: +- http +` +) + +func Test_sortKey(t *testing.T) { + tests := []struct { + name string + sortable sortable + expectedKey string + }{ + { + sortable: &FService{ + Service: kong.Service{ + Name: kong.String("my-service"), + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-service", + }, + { + sortable: &FService{ + Service: kong.Service{ + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-id", + }, + { + sortable: FService{}, + expectedKey: "", + }, + { + sortable: &FRoute{ + Route: kong.Route{ + Name: kong.String("my-route"), + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-route", + }, + { + sortable: FRoute{ + Route: kong.Route{ + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-id", + }, + { + sortable: FRoute{}, + expectedKey: "", + }, + { + sortable: FUpstream{ + Upstream: kong.Upstream{ + Name: kong.String("my-upstream"), + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-upstream", + }, + { + sortable: FUpstream{ + Upstream: kong.Upstream{ + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-id", + }, + { + sortable: FUpstream{}, + expectedKey: "", + }, + { + sortable: FTarget{ + Target: kong.Target{ + Target: kong.String("my-target"), + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-target", + }, + { + sortable: FTarget{ + Target: kong.Target{ + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-id", + }, + { + sortable: FTarget{}, + expectedKey: "", + }, + { + sortable: FCertificate{ + Cert: kong.String("my-certificate"), + ID: kong.String("my-id"), + }, + expectedKey: "my-certificate", + }, + { + sortable: FCertificate{ + ID: kong.String("my-id"), + }, + expectedKey: "my-id", + }, + { + sortable: FCertificate{}, + expectedKey: "", + }, + { + sortable: FCACertificate{ + CACertificate: kong.CACertificate{ + Cert: kong.String("my-ca-certificate"), + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-ca-certificate", + }, + { + sortable: FCACertificate{ + CACertificate: kong.CACertificate{ + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-id", + }, + { + sortable: FCACertificate{}, + expectedKey: "", + }, + { + sortable: FPlugin{ + Plugin: kong.Plugin{ + Name: kong.String("my-plugin"), + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-plugin", + }, + { + sortable: FPlugin{ + Plugin: kong.Plugin{ + Name: kong.String("my-plugin"), + ID: kong.String("my-id"), + Consumer: &kong.Consumer{ + ID: kong.String("my-consumer-id"), + }, + }, + }, + expectedKey: "my-pluginmy-consumer-id", + }, + { + sortable: FPlugin{ + Plugin: kong.Plugin{ + Name: kong.String("my-plugin"), + ID: kong.String("my-id"), + Route: &kong.Route{ + ID: kong.String("my-route-id"), + }, + }, + }, + expectedKey: "my-pluginmy-route-id", + }, + { + sortable: FPlugin{ + Plugin: kong.Plugin{ + Name: kong.String("my-plugin"), + ID: kong.String("my-id"), + Service: &kong.Service{ + ID: kong.String("my-service-id"), + }, + }, + }, + expectedKey: "my-pluginmy-service-id", + }, + + { + sortable: FPlugin{ + Plugin: kong.Plugin{ + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-id", + }, + { + sortable: FPlugin{}, + expectedKey: "", + }, + { + sortable: &FConsumer{ + Consumer: kong.Consumer{ + Username: kong.String("my-consumer"), + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-consumer", + }, + { + sortable: &FConsumer{ + Consumer: kong.Consumer{ + ID: kong.String("my-id"), + }, + }, + expectedKey: "my-id", + }, + { + sortable: FConsumer{}, + expectedKey: "", + }, + { + sortable: &FServicePackage{ + Name: kong.String("my-service-package"), + ID: kong.String("my-id"), + }, + expectedKey: "my-service-package", + }, + { + sortable: &FServicePackage{ + ID: kong.String("my-id"), + }, + expectedKey: "my-id", + }, + { + sortable: FServicePackage{}, + expectedKey: "", + }, + { + sortable: &FServiceVersion{ + Version: kong.String("my-service-version"), + ID: kong.String("my-id"), + }, + expectedKey: "my-service-version", + }, + { + sortable: &FServiceVersion{ + ID: kong.String("my-id"), + }, + expectedKey: "my-id", + }, + { + sortable: FServiceVersion{}, + expectedKey: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := tt.sortable.sortKey() + if key != tt.expectedKey { + t.Errorf("Expected %v, but is %v", tt.expectedKey, key) + } + }) + } +} + +func TestPluginUnmarshalYAML(t *testing.T) { + var p FPlugin + assert := assert.New(t) + assert.Nil(yaml.Unmarshal([]byte(yamlString), &p)) + assert.Equal(kong.Plugin{ + Name: p.Name, + Config: p.Config, + Enabled: p.Enabled, + RunOn: p.RunOn, + Protocols: p.Protocols, + Service: &kong.Service{ + ID: kong.String("foo"), + }, + Consumer: &kong.Consumer{ + ID: kong.String("baz"), + }, + Route: &kong.Route{ + ID: kong.String("bar"), + }, + }, p.Plugin) +} + +func TestPluginUnmarshalJSON(t *testing.T) { + var p FPlugin + assert := assert.New(t) + assert.Nil(json.Unmarshal([]byte(jsonString), &p)) + assert.Equal(kong.Plugin{ + Name: p.Name, + Config: p.Config, + Enabled: p.Enabled, + RunOn: p.RunOn, + Protocols: p.Protocols, + Service: &kong.Service{ + ID: kong.String("foo"), + }, + Consumer: &kong.Consumer{ + ID: kong.String("baz"), + }, + Route: &kong.Route{ + ID: kong.String("bar"), + }, + }, p.Plugin) +} + +func Test_unwrapURL(t *testing.T) { + type args struct { + urlString string + fService *FService + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + args: args{ + urlString: "https://foo.com:8008/bar", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foo.com"), + Port: kong.Int(8008), + Protocol: kong.String("https"), + Path: kong.String("/bar"), + }, + }, + }, + wantErr: false, + }, + { + args: args{ + urlString: "https://foo.com/bar", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foo.com"), + Protocol: kong.String("https"), + Path: kong.String("/bar"), + Port: kong.Int(443), + }, + }, + }, + wantErr: false, + }, + { + args: args{ + urlString: "https://foo.com:4224/", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foo.com"), + Protocol: kong.String("https"), + Path: kong.String("/"), + Port: kong.Int(4224), + }, + }, + }, + wantErr: false, + }, + { + args: args{ + urlString: "https://foo.com/", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foo.com"), + Protocol: kong.String("https"), + Path: kong.String("/"), + Port: kong.Int(443), + }, + }, + }, + wantErr: false, + }, + { + args: args{ + urlString: "http://foo.com:4242", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foo.com"), + Protocol: kong.String("http"), + Port: kong.Int(4242), + }, + }, + }, + wantErr: false, + }, + { + args: args{ + urlString: "http://foo.com", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foo.com"), + Protocol: kong.String("http"), + Port: kong.Int(80), + }, + }, + }, + wantErr: false, + }, + { + args: args{ + urlString: "grpc://foocom", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foocom"), + Protocol: kong.String("grpc"), + Port: kong.Int(80), + }, + }, + }, + wantErr: false, + }, + { + args: args{ + urlString: "foo.com/sdf", + fService: &FService{ + Service: kong.Service{}, + }, + }, + wantErr: true, + }, + { + args: args{ + urlString: "foo.com", + fService: &FService{ + Service: kong.Service{}, + }, + }, + wantErr: true, + }, + { + args: args{ + urlString: "42:", + fService: &FService{ + Service: kong.Service{}, + }, + }, + wantErr: true, + }, + { + args: args{ + urlString: "http://foo.com/Spaced%20Test/bar", + fService: &FService{ + Service: kong.Service{ + Host: kong.String("foo.com"), + Protocol: kong.String("http"), + Port: kong.Int(80), + Path: kong.String("/Spaced%20Test/bar"), + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := FService{} + if err := unwrapURL(tt.args.urlString, &in); (err != nil) != tt.wantErr { + t.Errorf("unwrapURL() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(tt.args.fService, &in) { + t.Errorf("unwrapURL() got = %v, want = %v", &in, tt.args.fService) + } + }) + } +} diff --git a/pkg/file/validate.go b/pkg/file/validate.go new file mode 100644 index 0000000..ac2c325 --- /dev/null +++ b/pkg/file/validate.go @@ -0,0 +1,68 @@ +package file + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/kong/deck/utils" + "github.com/xeipuuv/gojsonschema" + "sigs.k8s.io/yaml" +) + +type ValidationError struct { + Object string `json:"object"` + Err error `json:"error"` +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation error: object=%s, err=%v", e.Object, e.Err) +} + +func validate(content []byte) error { + var c map[string]interface{} + err := yaml.Unmarshal(content, &c) + if err != nil { + return fmt.Errorf("unmarshaling file content: %w", err) + } + c = ensureJSON(c) + schemaLoader := gojsonschema.NewStringLoader(kongJSONSchema) + documentLoader := gojsonschema.NewGoLoader(c) + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return err + } + if result.Valid() { + return nil + } + + var errs utils.ErrArray + for _, desc := range result.Errors() { + jsonString, err := json.Marshal(desc.Value()) + if err != nil { + return err + } + errs.Errors = append(errs.Errors, &ValidationError{Object: string(jsonString), Err: errors.New(desc.String())}) + } + return errs +} + +func validateWorkspaces(workspaces []string) error { + utils.RemoveDuplicates(&workspaces) + if len(workspaces) > 1 { + return fmt.Errorf("it seems like you are trying to sync multiple workspaces "+ + "at the same time (%v).\ndecK doesn't support syncing multiple workspaces at the same time, "+ + "please sync one workspace at a time", workspaces) + } + return nil +} + +func validateRuntimeGroups(names []string) error { + utils.RemoveDuplicates(&names) + if len(names) > 1 { + return fmt.Errorf("it seems like you are trying to sync multiple Konnect Runtime Groups "+ + "at the same time (%v).\ndecK doesn't support syncing multiple Runtime Groups at the same time, "+ + "please sync one Runtime Group at a time", names) + } + return nil +} diff --git a/pkg/file/writer.go b/pkg/file/writer.go new file mode 100644 index 0000000..d33632d --- /dev/null +++ b/pkg/file/writer.go @@ -0,0 +1,842 @@ +package file + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" + "sigs.k8s.io/yaml" +) + +// WriteConfig holds settings to use to write the state file. +type WriteConfig struct { + Workspace string + SelectTags []string + Filename string + FileFormat Format + WithID bool + ControlPlaneName string + KongVersion string +} + +func compareOrder(obj1, obj2 sortable) bool { + return strings.Compare(obj1.sortKey(), obj2.sortKey()) < 0 +} + +func getFormatVersion(kongVersion string) (string, error) { + parsedKongVersion, err := utils.ParseKongVersion(kongVersion) + if err != nil { + return "", fmt.Errorf("parsing Kong version: %w", err) + } + formatVersion := "1.1" + if parsedKongVersion.GTE(utils.Kong300Version) { + formatVersion = "3.0" + } + return formatVersion, nil +} + +// KongStateToFile generates a state object to file.Content. +// It will omit timestamps and IDs while writing. +func KongStateToContent(kongState *state.KongState, config WriteConfig) (*Content, error) { + file := &Content{} + var err error + + file.Workspace = config.Workspace + formatVersion, err := getFormatVersion(config.KongVersion) + if err != nil { + return nil, fmt.Errorf("get format version: %w", err) + } + file.FormatVersion = formatVersion + if config.ControlPlaneName != "" { + file.Konnect = &Konnect{ + ControlPlaneName: config.ControlPlaneName, + } + } + + selectTags := config.SelectTags + if len(selectTags) > 0 { + file.Info = &Info{ + SelectorTags: selectTags, + } + } + + err = populateServices(kongState, file, config) + if err != nil { + return nil, err + } + + err = populateServicelessRoutes(kongState, file, config) + if err != nil { + return nil, err + } + + err = populatePlugins(kongState, file, config) + if err != nil { + return nil, err + } + + err = populateUpstreams(kongState, file, config) + if err != nil { + return nil, err + } + + err = populateCertificates(kongState, file, config) + if err != nil { + return nil, err + } + + err = populateCACertificates(kongState, file, config) + if err != nil { + return nil, err + } + + err = populateConsumers(kongState, file, config) + if err != nil { + return nil, err + } + + err = populateVaults(kongState, file, config) + if err != nil { + return nil, err + } + + err = populateConsumerGroups(kongState, file, config) + if err != nil { + return nil, err + } + return file, nil +} + +// KongStateToFile writes a state object to file with filename. +// See KongStateToContent for the State generation +func KongStateToFile(kongState *state.KongState, config WriteConfig) error { + file, err := KongStateToContent(kongState, config) + if err != nil { + return err + } + return WriteContentToFile(file, config.Filename, config.FileFormat) +} + +func KonnectStateToFile(kongState *state.KongState, config WriteConfig) error { + file := &Content{} + file.FormatVersion = "0.1" + var err error + + err = populateServicePackages(kongState, file, config) + if err != nil { + return err + } + + // do not populate service-less routes + // we do not know if konnect supports these or not + + err = populatePlugins(kongState, file, config) + if err != nil { + return err + } + + err = populateUpstreams(kongState, file, config) + if err != nil { + return err + } + + err = populateCertificates(kongState, file, config) + if err != nil { + return err + } + + err = populateCACertificates(kongState, file, config) + if err != nil { + return err + } + + err = populateConsumers(kongState, file, config) + if err != nil { + return err + } + + return WriteContentToFile(file, config.Filename, config.FileFormat) +} + +func populateServicePackages(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + packages, err := kongState.ServicePackages.GetAll() + if err != nil { + return err + } + + for _, sp := range packages { + safePackageName := utils.NameToFilename(*sp.Name) + p := FServicePackage{ + ID: sp.ID, + Name: sp.Name, + Description: sp.Description, + } + versions, err := kongState.ServiceVersions.GetAllByServicePackageID(*p.ID) + if err != nil { + return err + } + documents, err := kongState.Documents.GetAllByParent(sp) + if err != nil { + return err + } + + for _, d := range documents { + safeDocPath := utils.NameToFilename(*d.Path) + fDocument := FDocument{ + ID: d.ID, + Path: kong.String(filepath.Join(safePackageName, safeDocPath)), + Published: d.Published, + Content: d.Content, + } + utils.ZeroOutID(&fDocument, fDocument.Path, config.WithID) + p.Document = &fDocument + // Although the documents API returns a list of documents and does support multiple documents, + // we pretend there's only one because that's all the web UI allows. + break + } + + for _, v := range versions { + safeVersionName := utils.NameToFilename(*v.Version) + fVersion := FServiceVersion{ + ID: v.ID, + Version: v.Version, + } + if v.ControlPlaneServiceRelation != nil && + !utils.Empty(v.ControlPlaneServiceRelation.ControlPlaneEntityID) { + kongServiceID := *v.ControlPlaneServiceRelation.ControlPlaneEntityID + + s, err := fetchService(kongServiceID, kongState, config) + if err != nil { + return err + } + fVersion.Implementation = &Implementation{ + Type: utils.ImplementationTypeKongGateway, + Kong: &Kong{ + Service: s, + }, + } + } + documents, err := kongState.Documents.GetAllByParent(v) + if err != nil { + return err + } + + for _, d := range documents { + safeDocPath := utils.NameToFilename(*d.Path) + fDocument := FDocument{ + ID: d.ID, + Path: kong.String(filepath.Join(safePackageName, safeVersionName, safeDocPath)), + Published: d.Published, + Content: d.Content, + } + utils.ZeroOutID(&fDocument, fDocument.Path, config.WithID) + fVersion.Document = &fDocument + break + } + utils.ZeroOutID(&fVersion, fVersion.Version, config.WithID) + p.Versions = append(p.Versions, fVersion) + } + sort.SliceStable(p.Versions, func(i, j int) bool { + return compareOrder(p.Versions[i], p.Versions[j]) + }) + utils.ZeroOutID(&p, p.Name, config.WithID) + file.ServicePackages = append(file.ServicePackages, p) + } + sort.SliceStable(file.ServicePackages, func(i, j int) bool { + return compareOrder(file.ServicePackages[i], file.ServicePackages[j]) + }) + return nil +} + +func populateServices(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + services, err := kongState.Services.GetAll() + if err != nil { + return err + } + for _, s := range services { + s, err := fetchService(*s.ID, kongState, config) + if err != nil { + return err + } + file.Services = append(file.Services, *s) + } + sort.SliceStable(file.Services, func(i, j int) bool { + return compareOrder(file.Services[i], file.Services[j]) + }) + return nil +} + +func fetchService(id string, kongState *state.KongState, config WriteConfig) (*FService, error) { + kongService, err := kongState.Services.Get(id) + if err != nil { + return nil, err + } + s := FService{Service: kongService.Service} + routes, err := kongState.Routes.GetAllByServiceID(*s.ID) + if err != nil { + return nil, err + } + plugins, err := kongState.Plugins.GetAllByServiceID(*s.ID) + if err != nil { + return nil, err + } + for _, p := range plugins { + p := p + if p.Route != nil || p.Consumer != nil { + continue + } + p.Service = nil + utils.ZeroOutID(p, p.Name, config.WithID) + utils.ZeroOutTimestamps(p) + utils.MustRemoveTags(&p.Plugin, config.SelectTags) + s.Plugins = append(s.Plugins, &FPlugin{Plugin: p.Plugin}) + } + sort.SliceStable(s.Plugins, func(i, j int) bool { + return compareOrder(s.Plugins[i], s.Plugins[j]) + }) + for _, r := range routes { + r := r + plugins, err := kongState.Plugins.GetAllByRouteID(*r.ID) + if err != nil { + return nil, err + } + r.Service = nil + utils.ZeroOutID(r, r.Name, config.WithID) + utils.ZeroOutTimestamps(r) + utils.MustRemoveTags(&r.Route, config.SelectTags) + route := &FRoute{Route: r.Route} + for _, p := range plugins { + p := p + if p.Service != nil || p.Consumer != nil { + continue + } + p.Route = nil + utils.ZeroOutID(p, p.Name, config.WithID) + utils.ZeroOutTimestamps(p) + utils.MustRemoveTags(&p.Plugin, config.SelectTags) + route.Plugins = append(route.Plugins, &FPlugin{Plugin: p.Plugin}) + } + sort.SliceStable(route.Plugins, func(i, j int) bool { + return compareOrder(route.Plugins[i], route.Plugins[j]) + }) + s.Routes = append(s.Routes, route) + } + sort.SliceStable(s.Routes, func(i, j int) bool { + return compareOrder(s.Routes[i], s.Routes[j]) + }) + utils.ZeroOutID(&s, s.Name, config.WithID) + utils.ZeroOutTimestamps(&s) + utils.MustRemoveTags(&s, config.SelectTags) + return &s, nil +} + +func populateServicelessRoutes(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + routes, err := kongState.Routes.GetAll() + if err != nil { + return err + } + for _, r := range routes { + r := r + if r.Service != nil { + continue + } + plugins, err := kongState.Plugins.GetAllByRouteID(*r.ID) + if err != nil { + return err + } + utils.ZeroOutID(r, r.Name, config.WithID) + utils.ZeroOutTimestamps(r) + utils.MustRemoveTags(&r.Route, config.SelectTags) + route := &FRoute{Route: r.Route} + for _, p := range plugins { + p := p + if p.Service != nil || p.Consumer != nil { + continue + } + p.Route = nil + utils.ZeroOutID(p, p.Name, config.WithID) + utils.ZeroOutTimestamps(p) + utils.MustRemoveTags(&p.Plugin, config.SelectTags) + route.Plugins = append(route.Plugins, &FPlugin{Plugin: p.Plugin}) + } + sort.SliceStable(route.Plugins, func(i, j int) bool { + return compareOrder(route.Plugins[i], route.Plugins[j]) + }) + file.Routes = append(file.Routes, *route) + } + sort.SliceStable(file.Routes, func(i, j int) bool { + return compareOrder(file.Routes[i], file.Routes[j]) + }) + return nil +} + +func populatePlugins(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + plugins, err := kongState.Plugins.GetAll() + if err != nil { + return err + } + for _, p := range plugins { + p := p + associations := 0 + if p.Consumer != nil { + associations++ + cID := *p.Consumer.ID + consumer, err := kongState.Consumers.GetByIDOrUsername(cID) + if err != nil { + return fmt.Errorf("unable to get consumer %s for plugin %s [%s]: %w", cID, *p.Name, *p.ID, err) + } + if !utils.Empty(consumer.Username) { + cID = *consumer.Username + } + p.Consumer.ID = &cID + } + if p.Service != nil { + associations++ + sID := *p.Service.ID + service, err := kongState.Services.Get(sID) + if err != nil { + return fmt.Errorf("unable to get service %s for plugin %s [%s]: %w", sID, *p.Name, *p.ID, err) + } + if !utils.Empty(service.Name) { + sID = *service.Name + } + p.Service.ID = &sID + } + if p.Route != nil { + associations++ + rID := *p.Route.ID + route, err := kongState.Routes.Get(rID) + if err != nil { + return fmt.Errorf("unable to get route %s for plugin %s [%s]: %w", rID, *p.Name, *p.ID, err) + } + if !utils.Empty(route.Name) { + rID = *route.Name + } + p.Route.ID = &rID + } + if p.ConsumerGroup != nil { + associations++ + cgID := *p.ConsumerGroup.ID + cg, err := kongState.ConsumerGroups.Get(cgID) + if err != nil { + return fmt.Errorf("unable to get consumer-group %s for plugin %s [%s]: %w", cgID, *p.Name, *p.ID, err) + } + if !utils.Empty(cg.Name) { + cgID = *cg.Name + } + p.ConsumerGroup.ID = &cgID + } + if associations == 0 || associations > 1 { + utils.ZeroOutID(p, p.Name, config.WithID) + utils.ZeroOutTimestamps(p) + utils.MustRemoveTags(&p.Plugin, config.SelectTags) + p := FPlugin{Plugin: p.Plugin} + file.Plugins = append(file.Plugins, p) + } + } + sort.SliceStable(file.Plugins, func(i, j int) bool { + return compareOrder(file.Plugins[i], file.Plugins[j]) + }) + return nil +} + +func populateUpstreams(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + upstreams, err := kongState.Upstreams.GetAll() + if err != nil { + return err + } + for _, u := range upstreams { + u := FUpstream{Upstream: u.Upstream} + targets, err := kongState.Targets.GetAllByUpstreamID(*u.ID) + if err != nil { + return err + } + for _, t := range targets { + t := t + t.Upstream = nil + utils.ZeroOutID(t, t.Target.Target, config.WithID) + utils.ZeroOutTimestamps(t) + utils.MustRemoveTags(&t.Target, config.SelectTags) + u.Targets = append(u.Targets, &FTarget{Target: t.Target}) + } + sort.SliceStable(u.Targets, func(i, j int) bool { + return compareOrder(u.Targets[i], u.Targets[j]) + }) + utils.ZeroOutID(&u, u.Name, config.WithID) + utils.ZeroOutTimestamps(&u) + utils.MustRemoveTags(&u.Upstream, config.SelectTags) + file.Upstreams = append(file.Upstreams, u) + } + sort.SliceStable(file.Upstreams, func(i, j int) bool { + return compareOrder(file.Upstreams[i], file.Upstreams[j]) + }) + return nil +} + +func populateVaults(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + vaults, err := kongState.Vaults.GetAll() + if err != nil { + return err + } + for _, v := range vaults { + v := FVault{Vault: v.Vault} + utils.ZeroOutID(&v, v.Prefix, config.WithID) + utils.ZeroOutTimestamps(&v) + utils.MustRemoveTags(&v.Vault, config.SelectTags) + file.Vaults = append(file.Vaults, v) + } + sort.SliceStable(file.Vaults, func(i, j int) bool { + return compareOrder(file.Vaults[i], file.Vaults[j]) + }) + return nil +} + +func populateCertificates(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + certificates, err := kongState.Certificates.GetAll() + if err != nil { + return err + } + for _, c := range certificates { + c := FCertificate{ + ID: c.ID, + Cert: c.Cert, + Key: c.Key, + Tags: c.Tags, + } + snis, err := kongState.SNIs.GetAllByCertID(*c.ID) + if err != nil { + return err + } + for _, s := range snis { + s := s + s.Certificate = nil + utils.ZeroOutID(s, s.Name, config.WithID) + utils.ZeroOutTimestamps(s) + utils.MustRemoveTags(&s.SNI, config.SelectTags) + c.SNIs = append(c.SNIs, s.SNI) + } + sort.SliceStable(c.SNIs, func(i, j int) bool { + return strings.Compare(*c.SNIs[i].Name, *c.SNIs[j].Name) < 0 + }) + utils.ZeroOutTimestamps(&c) + utils.MustRemoveTags(&c, config.SelectTags) + file.Certificates = append(file.Certificates, c) + } + sort.SliceStable(file.Certificates, func(i, j int) bool { + return compareOrder(file.Certificates[i], file.Certificates[j]) + }) + return nil +} + +func populateCACertificates(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + caCertificates, err := kongState.CACertificates.GetAll() + if err != nil { + return err + } + for _, c := range caCertificates { + c := FCACertificate{CACertificate: c.CACertificate} + utils.ZeroOutTimestamps(&c) + utils.MustRemoveTags(&c.CACertificate, config.SelectTags) + file.CACertificates = append(file.CACertificates, c) + } + sort.SliceStable(file.CACertificates, func(i, j int) bool { + return compareOrder(file.CACertificates[i], file.CACertificates[j]) + }) + return nil +} + +func populateConsumers(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + consumers, err := kongState.Consumers.GetAll() + if err != nil { + return err + } + consumerGroups, err := kongState.ConsumerGroups.GetAll() + if err != nil { + return err + } + for _, c := range consumers { + c := FConsumer{Consumer: c.Consumer} + plugins, err := kongState.Plugins.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, p := range plugins { + p := p + if p.Service != nil || p.Route != nil { + continue + } + utils.ZeroOutID(p, p.Name, config.WithID) + utils.ZeroOutTimestamps(p) + p.Consumer = nil + utils.MustRemoveTags(&p.Plugin, config.SelectTags) + c.Plugins = append(c.Plugins, &FPlugin{Plugin: p.Plugin}) + } + sort.SliceStable(c.Plugins, func(i, j int) bool { + return compareOrder(c.Plugins[i], c.Plugins[j]) + }) + // custom-entities associated with Consumer + keyAuths, err := kongState.KeyAuths.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, k := range keyAuths { + k := k + utils.ZeroOutID(k, k.Key, config.WithID) + utils.ZeroOutTimestamps(k) + utils.MustRemoveTags(k, config.SelectTags) + k.Consumer = nil + c.KeyAuths = append(c.KeyAuths, &k.KeyAuth) + } + hmacAuth, err := kongState.HMACAuths.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, k := range hmacAuth { + k := k + k.Consumer = nil + utils.ZeroOutID(k, k.Username, config.WithID) + utils.ZeroOutTimestamps(k) + utils.MustRemoveTags(k, config.SelectTags) + c.HMACAuths = append(c.HMACAuths, &k.HMACAuth) + } + jwtSecrets, err := kongState.JWTAuths.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, k := range jwtSecrets { + k := k + k.Consumer = nil + utils.ZeroOutID(k, k.Key, config.WithID) + utils.ZeroOutTimestamps(k) + utils.MustRemoveTags(k, config.SelectTags) + c.JWTAuths = append(c.JWTAuths, &k.JWTAuth) + } + basicAuths, err := kongState.BasicAuths.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, k := range basicAuths { + k := k + k.Consumer = nil + utils.ZeroOutID(k, k.Username, config.WithID) + utils.ZeroOutTimestamps(k) + utils.MustRemoveTags(k, config.SelectTags) + c.BasicAuths = append(c.BasicAuths, &k.BasicAuth) + } + oauth2Creds, err := kongState.Oauth2Creds.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, k := range oauth2Creds { + k := k + k.Consumer = nil + utils.ZeroOutID(k, k.ClientID, config.WithID) + utils.ZeroOutTimestamps(k) + utils.MustRemoveTags(k, config.SelectTags) + c.Oauth2Creds = append(c.Oauth2Creds, &k.Oauth2Credential) + } + aclGroups, err := kongState.ACLGroups.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, k := range aclGroups { + k := k + k.Consumer = nil + utils.ZeroOutID(k, k.Group, config.WithID) + utils.ZeroOutTimestamps(k) + utils.MustRemoveTags(k, config.SelectTags) + c.ACLGroups = append(c.ACLGroups, &k.ACLGroup) + } + mtlsAuths, err := kongState.MTLSAuths.GetAllByConsumerID(*c.ID) + if err != nil { + return err + } + for _, k := range mtlsAuths { + k := k + utils.ZeroOutTimestamps(k) + utils.MustRemoveTags(k, config.SelectTags) + k.Consumer = nil + c.MTLSAuths = append(c.MTLSAuths, &k.MTLSAuth) + } + // populate groups + for _, cg := range consumerGroups { + cg := *cg + _, err := kongState.ConsumerGroupConsumers.Get(*c.ID, *cg.ID) + if err != nil { + if !errors.Is(err, state.ErrNotFound) { + return err + } + continue + } + utils.ZeroOutID(&cg, cg.Name, config.WithID) + utils.ZeroOutTimestamps(&cg) + utils.MustRemoveTags(&cg.ConsumerGroup, config.SelectTags) + c.Groups = append(c.Groups, cg.DeepCopy()) + } + sort.SliceStable(c.Plugins, func(i, j int) bool { + return compareOrder(c.Plugins[i], c.Plugins[j]) + }) + utils.ZeroOutID(&c, c.Username, config.WithID) + utils.ZeroOutTimestamps(&c) + utils.MustRemoveTags(&c.Consumer, config.SelectTags) + file.Consumers = append(file.Consumers, c) + } + rbacRoles, err := kongState.RBACRoles.GetAll() + if err != nil { + return err + } + for _, r := range rbacRoles { + r := FRBACRole{RBACRole: r.RBACRole} + eps, err := kongState.RBACEndpointPermissions.GetAllByRoleID(*r.ID) + if err != nil { + return err + } + for _, ep := range eps { + ep.Role = nil + utils.ZeroOutTimestamps(ep) + r.EndpointPermissions = append( + r.EndpointPermissions, &FRBACEndpointPermission{RBACEndpointPermission: ep.RBACEndpointPermission}) + } + utils.ZeroOutID(&r, r.Name, config.WithID) + utils.ZeroOutTimestamps(&r) + file.RBACRoles = append(file.RBACRoles, r) + } + sort.SliceStable(file.Consumers, func(i, j int) bool { + return compareOrder(file.Consumers[i], file.Consumers[j]) + }) + return nil +} + +func populateConsumerGroups(kongState *state.KongState, file *Content, + config WriteConfig, +) error { + consumerGroups, err := kongState.ConsumerGroups.GetAll() + if err != nil { + return err + } + cgPlugins, err := kongState.ConsumerGroupPlugins.GetAll() + if err != nil { + return err + } + for _, cg := range consumerGroups { + group := FConsumerGroupObject{ConsumerGroup: cg.ConsumerGroup} + for _, plugin := range cgPlugins { + if plugin.ID != nil && cg.ID != nil { + plugin := plugin + if plugin.ConsumerGroup != nil && *plugin.ConsumerGroup.ID == *cg.ID { + utils.ZeroOutID(plugin, plugin.Name, config.WithID) + utils.ZeroOutID(plugin.ConsumerGroup, plugin.ConsumerGroup.Name, config.WithID) + utils.ZeroOutTimestamps(plugin.ConsumerGroupPlugin.ConsumerGroup) + utils.ZeroOutField(&plugin.ConsumerGroupPlugin, "ConsumerGroup") + group.Plugins = append(group.Plugins, &plugin.ConsumerGroupPlugin) + } + } + } + + plugins, err := kongState.Plugins.GetAllByConsumerGroupID(*cg.ID) + if err != nil { + return err + } + for _, plugin := range plugins { + if plugin.ID != nil && cg.ID != nil { + if plugin.ConsumerGroup != nil && *plugin.ConsumerGroup.ID == *cg.ID { + utils.ZeroOutID(plugin, plugin.Name, config.WithID) + utils.ZeroOutID(plugin.ConsumerGroup, plugin.ConsumerGroup.Name, config.WithID) + group.Plugins = append(group.Plugins, &kong.ConsumerGroupPlugin{ + ID: plugin.ID, + Name: plugin.Name, + Config: plugin.Config, + }) + } + } + } + + utils.ZeroOutID(&group, group.Name, config.WithID) + utils.ZeroOutTimestamps(&group) + file.ConsumerGroups = append(file.ConsumerGroups, group) + } + sort.SliceStable(file.ConsumerGroups, func(i, j int) bool { + return compareOrder(file.ConsumerGroups[i], file.ConsumerGroups[j]) + }) + return nil +} + +func WriteContentToFile(content *Content, filename string, format Format) error { + var c []byte + var err error + switch format { + case YAML: + c, err = yaml.Marshal(content) + if err != nil { + return err + } + case JSON: + c, err = json.MarshalIndent(content, "", " ") + if err != nil { + return err + } + default: + return fmt.Errorf("unknown file format: " + string(format)) + } + + if filename == "-" { + if _, err := fmt.Print(string(c)); err != nil { + return fmt.Errorf("writing file: %w", err) + } + } else { + filename = utils.AddExtToFilename(filename, strings.ToLower(string(format))) + prefix, _ := filepath.Split(filename) + if err := ioutil.WriteFile(filename, c, 0o600); err != nil { + return fmt.Errorf("writing file: %w", err) + } + for _, sp := range content.ServicePackages { + if sp.Document != nil { + if err := os.MkdirAll(filepath.Join(prefix, filepath.Dir(*sp.Document.Path)), 0o700); err != nil { + return fmt.Errorf("creating document directory: %w", err) + } + if err := os.WriteFile(filepath.Join(prefix, *sp.Document.Path), + []byte(*sp.Document.Content), 0o600); err != nil { + return fmt.Errorf("writing document file: %w", err) + } + } + for _, v := range sp.Versions { + if v.Document != nil { + if err := os.MkdirAll(filepath.Join(prefix, filepath.Dir(*v.Document.Path)), 0o700); err != nil { + return fmt.Errorf("creating document directory: %w", err) + } + if err := os.WriteFile(filepath.Join(prefix, *v.Document.Path), + []byte(*v.Document.Content), 0o600); err != nil { + return fmt.Errorf("writing document file: %w", err) + } + } + } + } + } + return nil +} diff --git a/pkg/file/writer_test.go b/pkg/file/writer_test.go new file mode 100644 index 0000000..fea1cb0 --- /dev/null +++ b/pkg/file/writer_test.go @@ -0,0 +1,405 @@ +package file + +import ( + "bytes" + "fmt" + "io" + "os" + "sync" + "testing" + + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func captureOutput(f func()) string { + reader, writer, err := os.Pipe() + if err != nil { + panic(err) + } + stdout := os.Stdout + stderr := os.Stderr + defer func() { + os.Stdout = stdout + os.Stderr = stderr + }() + os.Stdout = writer + os.Stderr = writer + + out := make(chan string) + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + var buf bytes.Buffer + wg.Done() + io.Copy(&buf, reader) + out <- buf.String() + }() + wg.Wait() + f() + writer.Close() + return <-out +} + +func Test_compareOrder(t *testing.T) { + tests := []struct { + name string + sortable1 sortable + sortable2 sortable + expected bool + }{ + { + sortable1: &FService{ + Service: kong.Service{ + Name: kong.String("my-service-1"), + ID: kong.String("my-id-1"), + }, + }, + sortable2: &FService{ + Service: kong.Service{ + Name: kong.String("my-service-2"), + ID: kong.String("my-id-2"), + }, + }, + expected: true, + }, + + { + sortable1: &FRoute{ + Route: kong.Route{ + Name: kong.String("my-route-1"), + ID: kong.String("my-id-1"), + }, + }, + sortable2: &FRoute{ + Route: kong.Route{ + Name: kong.String("my-route-2"), + ID: kong.String("my-id-2"), + }, + }, + expected: true, + }, + + { + sortable1: FUpstream{ + Upstream: kong.Upstream{ + Name: kong.String("my-upstream-1"), + ID: kong.String("my-id-1"), + }, + }, + sortable2: FUpstream{ + Upstream: kong.Upstream{ + Name: kong.String("my-upstream-2"), + ID: kong.String("my-id-2"), + }, + }, + expected: true, + }, + + { + sortable1: FTarget{ + Target: kong.Target{ + Target: kong.String("my-target-1"), + ID: kong.String("my-id-1"), + }, + }, + sortable2: FTarget{ + Target: kong.Target{ + Target: kong.String("my-target-2"), + ID: kong.String("my-id-2"), + }, + }, + expected: true, + }, + + { + sortable1: FCertificate{ + Cert: kong.String("my-certificate-1"), + ID: kong.String("my-id-1"), + }, + sortable2: FCertificate{ + Cert: kong.String("my-certificate-2"), + ID: kong.String("my-id-2"), + }, + expected: true, + }, + + { + sortable1: FCACertificate{ + CACertificate: kong.CACertificate{ + Cert: kong.String("my-ca-certificate-1"), + ID: kong.String("my-id-1"), + }, + }, + sortable2: FCACertificate{ + CACertificate: kong.CACertificate{ + Cert: kong.String("my-ca-certificate-2"), + ID: kong.String("my-id-2"), + }, + }, + expected: true, + }, + + { + sortable1: FPlugin{ + Plugin: kong.Plugin{ + Name: kong.String("my-plugin-1"), + ID: kong.String("my-id-1"), + }, + }, + sortable2: FPlugin{ + Plugin: kong.Plugin{ + Name: kong.String("my-plugin-2"), + ID: kong.String("my-id-2"), + }, + }, + expected: true, + }, + + { + sortable1: &FConsumer{ + Consumer: kong.Consumer{ + Username: kong.String("my-consumer-1"), + ID: kong.String("my-id-2"), + }, + }, + sortable2: &FConsumer{ + Consumer: kong.Consumer{ + Username: kong.String("my-consumer-2"), + ID: kong.String("my-id-2"), + }, + }, + expected: true, + }, + + { + sortable1: &FServicePackage{ + Name: kong.String("my-service-package-1"), + ID: kong.String("my-id-1"), + }, + sortable2: &FServicePackage{ + Name: kong.String("my-service-package-2"), + ID: kong.String("my-id-2"), + }, + expected: true, + }, + { + sortable1: &FServiceVersion{ + Version: kong.String("my-service-version-1"), + ID: kong.String("my-id-1"), + }, + sortable2: &FServiceVersion{ + Version: kong.String("my-service-version-2"), + ID: kong.String("my-id-2"), + }, + expected: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if compareOrder(tt.sortable1, tt.sortable2) != tt.expected { + t.Errorf("Expected %v, but isn't", tt.expected) + } + }) + } +} + +func TestWriteKongStateToStdoutEmptyState(t *testing.T) { + ks, _ := state.NewKongState() + filename := "-" + assert := assert.New(t) + assert.Equal("-", filename) + assert.NotEmpty(t, ks) + // YAML + output := captureOutput(func() { + KongStateToFile(ks, WriteConfig{ + Workspace: "foo", + Filename: filename, + FileFormat: YAML, + KongVersion: "2.8.0", + }) + }) + assert.Equal("_format_version: \"1.1\"\n_workspace: foo\n", output) + // JSON + output = captureOutput(func() { + KongStateToFile(ks, WriteConfig{ + Workspace: "foo", + Filename: filename, + FileFormat: JSON, + KongVersion: "2.8.0", + }) + }) + expected := `{ + "_format_version": "1.1", + "_workspace": "foo" +}` + assert.Equal(expected, output) +} + +func TestWriteKongStateToStdoutStateWithOneService(t *testing.T) { + ks, _ := state.NewKongState() + filename := "-" + assert := assert.New(t) + var service state.Service + service.ID = kong.String("first") + service.Host = kong.String("example.com") + service.Name = kong.String("my-service") + ks.Services.Add(service) + // YAML + output := captureOutput(func() { + KongStateToFile(ks, WriteConfig{ + Filename: filename, + FileFormat: YAML, + KongVersion: "3.0.0", + }) + }) + expected := fmt.Sprintf("_format_version: \"3.0\"\nservices:\n- host: %s\n name: %s\n", *service.Host, *service.Name) + assert.Equal(expected, output) + // JSON + output = captureOutput(func() { + KongStateToFile(ks, WriteConfig{ + Workspace: "foo", + Filename: filename, + FileFormat: JSON, + KongVersion: "3.0.0", + }) + }) + expected = `{ + "_format_version": "3.0", + "_workspace": "foo", + "services": [ + { + "host": "example.com", + "name": "my-service" + } + ] +}` + assert.Equal(expected, output) +} + +func TestWriteKongStateToStdoutStateWithOneServiceOneRoute(t *testing.T) { + ks, _ := state.NewKongState() + filename := "-" + assert := assert.New(t) + var service state.Service + service.ID = kong.String("first") + service.Host = kong.String("example.com") + service.Name = kong.String("my-service") + ks.Services.Add(service) + + var route state.Route + route.Name = kong.String("my-route") + route.ID = kong.String("first") + route.Hosts = kong.StringSlice("example.com", "demo.example.com") + route.Service = &kong.Service{ + ID: kong.String(*service.ID), + Name: kong.String(*service.Name), + } + + ks.Routes.Add(route) + // YAML + output := captureOutput(func() { + KongStateToFile(ks, WriteConfig{ + Filename: filename, + FileFormat: YAML, + KongVersion: "2.8.0", + }) + }) + expected := fmt.Sprintf(`_format_version: "1.1" +services: +- host: %s + name: %s + routes: + - hosts: + - %s + - %s + name: %s +`, *service.Host, *service.Name, *route.Hosts[0], *route.Hosts[1], *route.Name) + assert.Equal(expected, output) + // JSON + output = captureOutput(func() { + KongStateToFile(ks, WriteConfig{ + Workspace: "foo", + Filename: filename, + FileFormat: JSON, + KongVersion: "2.8.0", + }) + }) + expected = `{ + "_format_version": "1.1", + "_workspace": "foo", + "services": [ + { + "host": "example.com", + "name": "my-service", + "routes": [ + { + "hosts": [ + "example.com", + "demo.example.com" + ], + "name": "my-route" + } + ] + } + ] +}` + assert.Equal(expected, output) +} + +func Test_getFormatVersion(t *testing.T) { + tests := []struct { + name string + kongVersion string + expected string + expectedErr string + wantErr bool + }{ + { + name: "3.0.0 version", + kongVersion: "3.0.0", + expected: "3.0", + }, + { + name: "3.0.0.0 version", + kongVersion: "3.0.0.0", + expected: "3.0", + }, + { + name: "2.8.0 version", + kongVersion: "2.8.0", + expected: "1.1", + }, + { + name: "2.8.0.0 version", + kongVersion: "2.8.0.0", + expected: "1.1", + }, + { + name: "2.8.0.1-enterprise-edition version", + kongVersion: "2.8.0.1-enterprise-edition", + expected: "1.1", + }, + { + name: "unsupported version", + kongVersion: "test", + wantErr: true, + expectedErr: "parsing Kong version: unknown Kong version", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res, err := getFormatVersion(tt.kongVersion) + if (err != nil) != tt.wantErr { + t.Errorf("got error = %v, expected error = %v", err, tt.wantErr) + } + if tt.expectedErr != "" { + assert.Equal(t, err.Error(), tt.expectedErr) + } + if res != tt.expected { + t.Errorf("Expected %v, but isn't: %v", tt.expected, res) + } + }) + } +} diff --git a/pkg/file/zz_generated.deepcopy.go b/pkg/file/zz_generated.deepcopy.go new file mode 100644 index 0000000..7185893 --- /dev/null +++ b/pkg/file/zz_generated.deepcopy.go @@ -0,0 +1,823 @@ +//go:build !ignore_autogenerated +// +build !ignore_autogenerated + +/* +Copyright 2021 Kong Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by deepcopy-gen. DO NOT EDIT. + +package file + +import ( + kong "github.com/kong/go-kong/kong" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Content) DeepCopyInto(out *Content) { + *out = *in + if in.Transform != nil { + in, out := &in.Transform, &out.Transform + *out = new(bool) + **out = **in + } + if in.Info != nil { + in, out := &in.Info, &out.Info + *out = new(Info) + (*in).DeepCopyInto(*out) + } + if in.Konnect != nil { + in, out := &in.Konnect, &out.Konnect + *out = new(Konnect) + **out = **in + } + if in.Services != nil { + in, out := &in.Services, &out.Services + *out = make([]FService, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Routes != nil { + in, out := &in.Routes, &out.Routes + *out = make([]FRoute, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Consumers != nil { + in, out := &in.Consumers, &out.Consumers + *out = make([]FConsumer, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.ConsumerGroups != nil { + in, out := &in.ConsumerGroups, &out.ConsumerGroups + *out = make([]FConsumerGroupObject, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Plugins != nil { + in, out := &in.Plugins, &out.Plugins + *out = make([]FPlugin, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Upstreams != nil { + in, out := &in.Upstreams, &out.Upstreams + *out = make([]FUpstream, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Certificates != nil { + in, out := &in.Certificates, &out.Certificates + *out = make([]FCertificate, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.CACertificates != nil { + in, out := &in.CACertificates, &out.CACertificates + *out = make([]FCACertificate, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.RBACRoles != nil { + in, out := &in.RBACRoles, &out.RBACRoles + *out = make([]FRBACRole, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.PluginConfigs != nil { + in, out := &in.PluginConfigs, &out.PluginConfigs + *out = make(map[string]kong.Configuration, len(*in)) + for key, val := range *in { + (*out)[key] = val.DeepCopy() + } + } + if in.ServicePackages != nil { + in, out := &in.ServicePackages, &out.ServicePackages + *out = make([]FServicePackage, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Vaults != nil { + in, out := &in.Vaults, &out.Vaults + *out = make([]FVault, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Licenses != nil { + in, out := &in.Licenses, &out.Licenses + *out = make([]FLicense, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Content. +func (in *Content) DeepCopy() *Content { + if in == nil { + return nil + } + out := new(Content) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FCACertificate) DeepCopyInto(out *FCACertificate) { + *out = *in + in.CACertificate.DeepCopyInto(&out.CACertificate) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FCACertificate. +func (in *FCACertificate) DeepCopy() *FCACertificate { + if in == nil { + return nil + } + out := new(FCACertificate) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FCertificate) DeepCopyInto(out *FCertificate) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Cert != nil { + in, out := &in.Cert, &out.Cert + *out = new(string) + **out = **in + } + if in.Key != nil { + in, out := &in.Key, &out.Key + *out = new(string) + **out = **in + } + if in.CreatedAt != nil { + in, out := &in.CreatedAt, &out.CreatedAt + *out = new(int64) + **out = **in + } + if in.Tags != nil { + in, out := &in.Tags, &out.Tags + *out = make([]*string, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(string) + **out = **in + } + } + } + if in.SNIs != nil { + in, out := &in.SNIs, &out.SNIs + *out = make([]kong.SNI, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FCertificate. +func (in *FCertificate) DeepCopy() *FCertificate { + if in == nil { + return nil + } + out := new(FCertificate) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FConsumer) DeepCopyInto(out *FConsumer) { + *out = *in + in.Consumer.DeepCopyInto(&out.Consumer) + if in.Plugins != nil { + in, out := &in.Plugins, &out.Plugins + *out = make([]*FPlugin, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FPlugin) + (*in).DeepCopyInto(*out) + } + } + } + if in.KeyAuths != nil { + in, out := &in.KeyAuths, &out.KeyAuths + *out = make([]*kong.KeyAuth, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.KeyAuth) + (*in).DeepCopyInto(*out) + } + } + } + if in.HMACAuths != nil { + in, out := &in.HMACAuths, &out.HMACAuths + *out = make([]*kong.HMACAuth, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.HMACAuth) + (*in).DeepCopyInto(*out) + } + } + } + if in.JWTAuths != nil { + in, out := &in.JWTAuths, &out.JWTAuths + *out = make([]*kong.JWTAuth, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.JWTAuth) + (*in).DeepCopyInto(*out) + } + } + } + if in.BasicAuths != nil { + in, out := &in.BasicAuths, &out.BasicAuths + *out = make([]*kong.BasicAuth, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.BasicAuth) + (*in).DeepCopyInto(*out) + } + } + } + if in.Oauth2Creds != nil { + in, out := &in.Oauth2Creds, &out.Oauth2Creds + *out = make([]*kong.Oauth2Credential, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.Oauth2Credential) + (*in).DeepCopyInto(*out) + } + } + } + if in.ACLGroups != nil { + in, out := &in.ACLGroups, &out.ACLGroups + *out = make([]*kong.ACLGroup, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.ACLGroup) + (*in).DeepCopyInto(*out) + } + } + } + if in.MTLSAuths != nil { + in, out := &in.MTLSAuths, &out.MTLSAuths + *out = make([]*kong.MTLSAuth, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.MTLSAuth) + (*in).DeepCopyInto(*out) + } + } + } + if in.Groups != nil { + in, out := &in.Groups, &out.Groups + *out = make([]*kong.ConsumerGroup, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.ConsumerGroup) + (*in).DeepCopyInto(*out) + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FConsumer. +func (in *FConsumer) DeepCopy() *FConsumer { + if in == nil { + return nil + } + out := new(FConsumer) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FConsumerGroupObject) DeepCopyInto(out *FConsumerGroupObject) { + *out = *in + in.ConsumerGroup.DeepCopyInto(&out.ConsumerGroup) + if in.Consumers != nil { + in, out := &in.Consumers, &out.Consumers + *out = make([]*kong.Consumer, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.Consumer) + (*in).DeepCopyInto(*out) + } + } + } + if in.Plugins != nil { + in, out := &in.Plugins, &out.Plugins + *out = make([]*kong.ConsumerGroupPlugin, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(kong.ConsumerGroupPlugin) + (*in).DeepCopyInto(*out) + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FConsumerGroupObject. +func (in *FConsumerGroupObject) DeepCopy() *FConsumerGroupObject { + if in == nil { + return nil + } + out := new(FConsumerGroupObject) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FDocument) DeepCopyInto(out *FDocument) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Path != nil { + in, out := &in.Path, &out.Path + *out = new(string) + **out = **in + } + if in.Published != nil { + in, out := &in.Published, &out.Published + *out = new(bool) + **out = **in + } + if in.Content != nil { + in, out := &in.Content, &out.Content + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FDocument. +func (in *FDocument) DeepCopy() *FDocument { + if in == nil { + return nil + } + out := new(FDocument) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FLicense) DeepCopyInto(out *FLicense) { + *out = *in + in.License.DeepCopyInto(&out.License) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FLicense. +func (in *FLicense) DeepCopy() *FLicense { + if in == nil { + return nil + } + out := new(FLicense) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FPlugin) DeepCopyInto(out *FPlugin) { + *out = *in + in.Plugin.DeepCopyInto(&out.Plugin) + if in.ConfigSource != nil { + in, out := &in.ConfigSource, &out.ConfigSource + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FPlugin. +func (in *FPlugin) DeepCopy() *FPlugin { + if in == nil { + return nil + } + out := new(FPlugin) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FRBACEndpointPermission) DeepCopyInto(out *FRBACEndpointPermission) { + *out = *in + in.RBACEndpointPermission.DeepCopyInto(&out.RBACEndpointPermission) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FRBACEndpointPermission. +func (in *FRBACEndpointPermission) DeepCopy() *FRBACEndpointPermission { + if in == nil { + return nil + } + out := new(FRBACEndpointPermission) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FRBACRole) DeepCopyInto(out *FRBACRole) { + *out = *in + in.RBACRole.DeepCopyInto(&out.RBACRole) + if in.EndpointPermissions != nil { + in, out := &in.EndpointPermissions, &out.EndpointPermissions + *out = make([]*FRBACEndpointPermission, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FRBACEndpointPermission) + (*in).DeepCopyInto(*out) + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FRBACRole. +func (in *FRBACRole) DeepCopy() *FRBACRole { + if in == nil { + return nil + } + out := new(FRBACRole) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FRoute) DeepCopyInto(out *FRoute) { + *out = *in + in.Route.DeepCopyInto(&out.Route) + if in.Plugins != nil { + in, out := &in.Plugins, &out.Plugins + *out = make([]*FPlugin, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FPlugin) + (*in).DeepCopyInto(*out) + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FRoute. +func (in *FRoute) DeepCopy() *FRoute { + if in == nil { + return nil + } + out := new(FRoute) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FService) DeepCopyInto(out *FService) { + *out = *in + in.Service.DeepCopyInto(&out.Service) + if in.Routes != nil { + in, out := &in.Routes, &out.Routes + *out = make([]*FRoute, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FRoute) + (*in).DeepCopyInto(*out) + } + } + } + if in.Plugins != nil { + in, out := &in.Plugins, &out.Plugins + *out = make([]*FPlugin, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FPlugin) + (*in).DeepCopyInto(*out) + } + } + } + if in.URL != nil { + in, out := &in.URL, &out.URL + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FService. +func (in *FService) DeepCopy() *FService { + if in == nil { + return nil + } + out := new(FService) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FServicePackage) DeepCopyInto(out *FServicePackage) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Name != nil { + in, out := &in.Name, &out.Name + *out = new(string) + **out = **in + } + if in.Description != nil { + in, out := &in.Description, &out.Description + *out = new(string) + **out = **in + } + if in.Versions != nil { + in, out := &in.Versions, &out.Versions + *out = make([]FServiceVersion, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.Document != nil { + in, out := &in.Document, &out.Document + *out = new(FDocument) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FServicePackage. +func (in *FServicePackage) DeepCopy() *FServicePackage { + if in == nil { + return nil + } + out := new(FServicePackage) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FServiceVersion) DeepCopyInto(out *FServiceVersion) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Version != nil { + in, out := &in.Version, &out.Version + *out = new(string) + **out = **in + } + if in.Implementation != nil { + in, out := &in.Implementation, &out.Implementation + *out = new(Implementation) + (*in).DeepCopyInto(*out) + } + if in.Document != nil { + in, out := &in.Document, &out.Document + *out = new(FDocument) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FServiceVersion. +func (in *FServiceVersion) DeepCopy() *FServiceVersion { + if in == nil { + return nil + } + out := new(FServiceVersion) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FTarget) DeepCopyInto(out *FTarget) { + *out = *in + in.Target.DeepCopyInto(&out.Target) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FTarget. +func (in *FTarget) DeepCopy() *FTarget { + if in == nil { + return nil + } + out := new(FTarget) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FUpstream) DeepCopyInto(out *FUpstream) { + *out = *in + in.Upstream.DeepCopyInto(&out.Upstream) + if in.Targets != nil { + in, out := &in.Targets, &out.Targets + *out = make([]*FTarget, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FTarget) + (*in).DeepCopyInto(*out) + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FUpstream. +func (in *FUpstream) DeepCopy() *FUpstream { + if in == nil { + return nil + } + out := new(FUpstream) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FVault) DeepCopyInto(out *FVault) { + *out = *in + in.Vault.DeepCopyInto(&out.Vault) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FVault. +func (in *FVault) DeepCopy() *FVault { + if in == nil { + return nil + } + out := new(FVault) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Implementation) DeepCopyInto(out *Implementation) { + *out = *in + if in.Kong != nil { + in, out := &in.Kong, &out.Kong + *out = new(Kong) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Implementation. +func (in *Implementation) DeepCopy() *Implementation { + if in == nil { + return nil + } + out := new(Implementation) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Info) DeepCopyInto(out *Info) { + *out = *in + if in.SelectorTags != nil { + in, out := &in.SelectorTags, &out.SelectorTags + *out = make([]string, len(*in)) + copy(*out, *in) + } + in.Defaults.DeepCopyInto(&out.Defaults) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Info. +func (in *Info) DeepCopy() *Info { + if in == nil { + return nil + } + out := new(Info) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Kong) DeepCopyInto(out *Kong) { + *out = *in + if in.Service != nil { + in, out := &in.Service, &out.Service + *out = new(FService) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Kong. +func (in *Kong) DeepCopy() *Kong { + if in == nil { + return nil + } + out := new(Kong) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *KongDefaults) DeepCopyInto(out *KongDefaults) { + *out = *in + if in.Service != nil { + in, out := &in.Service, &out.Service + *out = new(kong.Service) + (*in).DeepCopyInto(*out) + } + if in.Route != nil { + in, out := &in.Route, &out.Route + *out = new(kong.Route) + (*in).DeepCopyInto(*out) + } + if in.Upstream != nil { + in, out := &in.Upstream, &out.Upstream + *out = new(kong.Upstream) + (*in).DeepCopyInto(*out) + } + if in.Target != nil { + in, out := &in.Target, &out.Target + *out = new(kong.Target) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new KongDefaults. +func (in *KongDefaults) DeepCopy() *KongDefaults { + if in == nil { + return nil + } + out := new(KongDefaults) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Konnect) DeepCopyInto(out *Konnect) { + *out = *in + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Konnect. +func (in *Konnect) DeepCopy() *Konnect { + if in == nil { + return nil + } + out := new(Konnect) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/konnect/client.go b/pkg/konnect/client.go new file mode 100644 index 0000000..06a2117 --- /dev/null +++ b/pkg/konnect/client.go @@ -0,0 +1,179 @@ +package konnect + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httputil" + "net/url" + "os" +) + +var defaultCtx = context.Background() + +type service struct { + client *Client + controlPlaneID string + + runtimeGroupID string +} + +// Client talks to the Konnect API. +type Client struct { + client *http.Client + baseURL string + common service + Auth *AuthService + ServicePackages *ServicePackageService + ServiceVersions *ServiceVersionService + Documents *DocumentService + ControlPlanes *ControlPlaneService + ControlPlaneRelations *ControlPlaneRelationsService + logger io.Writer + debug bool + token string + + RuntimeGroups *RuntimeGroupService +} + +// ClientOpts contains configuration options for a new Client. +type ClientOpts struct { + BaseURL string +} + +// NewClient returns a Client which talks to Konnect's API. +func NewClient(httpClient *http.Client, opts ClientOpts) (*Client, error) { + if httpClient == nil { + httpClient = http.DefaultClient + } + client := new(Client) + client.client = httpClient + url, err := url.ParseRequestURI(opts.BaseURL) + if err != nil { + return nil, fmt.Errorf("parsing URL: %w", err) + } + client.baseURL = url.String() + + client.common.client = client + client.Auth = (*AuthService)(&client.common) + client.ServicePackages = (*ServicePackageService)(&client.common) + client.ServiceVersions = (*ServiceVersionService)(&client.common) + client.Documents = (*DocumentService)(&client.common) + client.ControlPlanes = (*ControlPlaneService)(&client.common) + client.ControlPlaneRelations = (*ControlPlaneRelationsService)(&client.common) + client.logger = os.Stderr + + client.RuntimeGroups = (*RuntimeGroupService)(&client.common) + return client, nil +} + +// SetControlPlaneID sets the kong control-plane ID in the client. +// This is used to inject the control-plane ID in requests as needed. +func (c *Client) SetControlPlaneID(cpID string) { + c.common.controlPlaneID = cpID +} + +// SetControlPlaneID sets the konnect runtime-group ID in the client. +func (c *Client) SetRuntimeGroupID(rgID string) { + c.common.runtimeGroupID = rgID +} + +// Do executes a HTTP request and returns a response +func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*http.Response, error) { + var err error + if req == nil { + return nil, fmt.Errorf("request cannot be nil") + } + if ctx == nil { + ctx = defaultCtx + } + req = req.WithContext(ctx) + + // log the request + err = c.logRequest(req) + if err != nil { + return nil, err + } + + // Make the request + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("making HTTP request: %w", err) + } + + // log the response + err = c.logResponse(resp) + if err != nil { + return nil, err + } + + // check for API errors + if err = hasError(resp); err != nil { + return resp, err + } + + // Call Close on exit + defer func() { + e := resp.Body.Close() + if e != nil { + err = e + } + }() + + // response + if v != nil { + if writer, ok := v.(io.Writer); ok { + _, err = io.Copy(writer, resp.Body) + if err != nil { + return nil, err + } + } else { + err = json.NewDecoder(resp.Body).Decode(v) + if err != nil { + return nil, err + } + } + } + return resp, err +} + +// SetDebugMode enables or disables logging of +// the request to the logger set by SetLogger(). +// By default, debug logging is disabled. +func (c *Client) SetDebugMode(enableDebug bool) { + c.debug = enableDebug +} + +func (c *Client) logRequest(r *http.Request) error { + if !c.debug { + return nil + } + dump, err := httputil.DumpRequestOut(r, true) + if err != nil { + return err + } + _, err = c.logger.Write(append(dump, '\n')) + return err +} + +func (c *Client) logResponse(r *http.Response) error { + if !c.debug { + return nil + } + dump, err := httputil.DumpResponse(r, true) + if err != nil { + return err + } + _, err = c.logger.Write(append(dump, '\n')) + return err +} + +// SetLogger sets the debug logger, defaults to os.StdErr +func (c *Client) SetLogger(w io.Writer) { + if w == nil { + return + } + c.logger = w +} diff --git a/pkg/konnect/consumer_group.go b/pkg/konnect/consumer_group.go new file mode 100644 index 0000000..f379c9d --- /dev/null +++ b/pkg/konnect/consumer_group.go @@ -0,0 +1,472 @@ +package konnect + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/kong/go-kong/kong" +) + +type PageOpt struct { + // request + Size int `url:"size,omitempty"` + Number *int `url:"number,omitempty"` + + // response + NextPageNum *int `url:"next_page_num,omitempty"` + TotalCount *int `url:"total_count,omitempty"` +} + +type KonnectListOpt struct { //nolint:revive + // Size of the page + Page *PageOpt + + // Tags to use for filtering the list. + Tags []*string `url:"tags,omitempty"` +} + +type RLAOverride struct { + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Value kong.Configuration `json:"value,omitempty" yaml:"value,omitempty"` +} + +type konnectResponseObj struct { + Item kong.ConsumerGroup `json:"item,omitempty" yaml:"item,omitempty"` + Page *PageOpt +} + +type konnectRLAObj struct { + ConsumerGroupID *string `json:"consumer_group_id,omitempty" yaml:"consumer_group_id,omitempty"` + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Limit []*int32 `json:"limit,omitempty" yaml:"limit,omitempty"` + WindowSize []*int32 `json:"window_size,omitempty" yaml:"window_size,omitempty"` + WindowType *string `json:"window_type,omitempty" yaml:"window_type,omitempty"` + RetryAfterJitterMax *int32 `json:"retry_after_jitter_max,omitempty" yaml:"retry_after_jitter_max,omitempty"` +} + +type konnectRLAResponseObj struct { + Item konnectRLAObj `json:"item,omitempty" yaml:"item,omitempty"` +} + +func isEmptyString(s *string) bool { + return s == nil || strings.TrimSpace(*s) == "" +} + +func CreateConsumerGroup(ctx context.Context, client *kong.Client, entity interface{}) (*kong.ConsumerGroup, error) { + endpoint := "/v1/consumer-groups" + req, err := client.NewRequest(http.MethodPost, endpoint, nil, entity) + if err != nil { + return nil, err + } + var cg konnectResponseObj + _, err = client.Do(ctx, req, &cg) + if err != nil { + return nil, err + } + return &cg.Item, nil +} + +func UpdateConsumerGroup(ctx context.Context, client *kong.Client, + cgID *string, entity interface{}, +) (*kong.ConsumerGroup, error) { + if isEmptyString(cgID) { + return nil, fmt.Errorf("update consumer-group: consumer-group ID cannot be nil") + } + endpoint := fmt.Sprintf("/v1/consumer-groups/%v", *cgID) + req, err := client.NewRequest(http.MethodPut, endpoint, nil, entity) + if err != nil { + return nil, err + } + var cg konnectResponseObj + _, err = client.Do(ctx, req, &cg) + if err != nil { + return nil, err + } + return &cg.Item, nil +} + +// GetConsumerGroup fetches a ConsumerGroup from Konnect. +func GetConsumerGroup(ctx context.Context, + client *kong.Client, nameOrID *string, +) (*kong.ConsumerGroup, error) { + if isEmptyString(nameOrID) { + return nil, fmt.Errorf("getting consumer-group: nameOrID cannot be nil") + } + + endpoint := fmt.Sprintf("/v1/consumer-groups/%v", *nameOrID) + req, err := client.NewRequest("GET", endpoint, nil, nil) + if err != nil { + return nil, err + } + + var cg konnectResponseObj + _, err = client.Do(ctx, req, &cg) + if err != nil { + return nil, err + } + return &cg.Item, nil +} + +// ListAllConsumerGroupMembers fetches all ConsumerGroups members from Konnect. +func ListAllConsumerGroupMembers( + ctx context.Context, client *kong.Client, cgID *string, +) ([]*kong.Consumer, error) { + if isEmptyString(cgID) { + return nil, fmt.Errorf("list consumer-group-members: consumer-group ID cannot be nil") + } + var members, data []*kong.Consumer + var err error + opt := &KonnectListOpt{Page: &PageOpt{Size: 100}} + for opt != nil { + endpoint := fmt.Sprintf("/v1/consumer-groups/%v/members", *cgID) + data, opt, err = ListConsumerGroupMembers(ctx, client, endpoint, opt) + if err != nil { + return nil, err + } + members = append(members, data...) + } + return members, nil +} + +func upsertRateLimitingAdvancedPlugin( + ctx context.Context, client *kong.Client, id string, config kong.Configuration, method string, +) (*kong.ConsumerGroupRLA, error) { + endpoint := fmt.Sprintf("/v1/consumer-groups/%v/rate-limiting-advanced-config", id) + req, err := client.NewRequest(method, endpoint, nil, config) + if err != nil { + return nil, err + } + var rla konnectRLAResponseObj + _, err = client.Do(ctx, req, &rla) + if err != nil { + return nil, err + } + rlaConfig := kong.Configuration{} + if rla.Item.Limit != nil { + rlaConfig["limit"] = rla.Item.Limit + } + if rla.Item.WindowSize != nil { + rlaConfig["window_size"] = rla.Item.WindowSize + } + if rla.Item.WindowType != nil { + rlaConfig["window_type"] = rla.Item.WindowType + } + if rla.Item.RetryAfterJitterMax != nil { + rlaConfig["retry_after_jitter_max"] = rla.Item.RetryAfterJitterMax + } + return &kong.ConsumerGroupRLA{ + Plugin: kong.String("rate-limiting-advanced"), + ConsumerGroup: rla.Item.ConsumerGroupID, + Config: rlaConfig, + }, nil +} + +func CreateRateLimitingAdvancedPlugin( + ctx context.Context, client *kong.Client, cgID *string, config kong.Configuration, +) (*kong.ConsumerGroupRLA, error) { + if isEmptyString(cgID) { + return nil, fmt.Errorf("create consumer-group override: consumer-group ID cannot be nil") + } + return upsertRateLimitingAdvancedPlugin( + ctx, client, *cgID, config, http.MethodPost, + ) +} + +func UpdateRateLimitingAdvancedPlugin( + ctx context.Context, client *kong.Client, cgID *string, config kong.Configuration, +) (*kong.ConsumerGroupRLA, error) { + if isEmptyString(cgID) { + return nil, fmt.Errorf("update consumer-group override: consumer-group ID cannot be nil") + } + return upsertRateLimitingAdvancedPlugin( + ctx, client, *cgID, config, http.MethodPut, + ) +} + +// GetConsumerGroupRateLimitingAdvancedPlugin fetches the RLA override for +// a ConsumerGroup from Konnect. +func GetConsumerGroupRateLimitingAdvancedPlugin( + ctx context.Context, client *kong.Client, cgID *string, +) (*kong.ConsumerGroupPlugin, error) { + if isEmptyString(cgID) { + return nil, fmt.Errorf("get consumer-group override: consumer-group ID cannot be nil") + } + endpoint := fmt.Sprintf("/v1/consumer-groups/%v/rate-limiting-advanced-config", *cgID) + req, err := client.NewRequest("GET", endpoint, nil, nil) + if err != nil { + return nil, err + } + + var rla konnectRLAResponseObj + res, err := client.Do(ctx, req, &rla) + if err != nil { + // Konnect returns a 404 if no plugin exists yet. + if res.StatusCode == http.StatusNotFound { + return nil, nil + } + return nil, err + } + + config := kong.Configuration{} + if rla.Item.Limit != nil { + config["limit"] = rla.Item.Limit + } + if rla.Item.WindowSize != nil { + config["window_size"] = rla.Item.WindowSize + } + if rla.Item.WindowType != nil { + config["window_type"] = rla.Item.WindowType + } + if rla.Item.RetryAfterJitterMax != nil { + config["retry_after_jitter_max"] = rla.Item.RetryAfterJitterMax + } + return &kong.ConsumerGroupPlugin{ + ID: rla.Item.ID, + Name: kong.String("rate-limiting-advanced"), + ConsumerGroup: &kong.ConsumerGroup{ + ID: cgID, + }, + Config: config, + }, nil +} + +// DeleteRateLimitingAdvancedPlugin deletes a ConsumerGroup plugin in Kong +func DeleteRateLimitingAdvancedPlugin( + ctx context.Context, client *kong.Client, cgID *string, +) error { + if isEmptyString(cgID) { + return fmt.Errorf("deleting consumer-group plugin: id cannot be nil") + } + + endpoint := fmt.Sprintf("/v1/consumer-groups/%v/rate-limiting-advanced-config", *cgID) + req, err := client.NewRequest("DELETE", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = client.Do(ctx, req, nil) + return err +} + +// ListConsumerGroupMembers fetches a page members for a ConsumerGroup from Konnect. +func ListConsumerGroupMembers(ctx context.Context, + client *kong.Client, endpoint string, opt *KonnectListOpt, +) ([]*kong.Consumer, *KonnectListOpt, error) { + data, next, err := list(ctx, client, endpoint, opt) + if err != nil { + return nil, nil, err + } + + var consumers []*kong.Consumer + + for _, object := range data { + b, err := object.MarshalJSON() + if err != nil { + return nil, nil, err + } + var consumer kong.Consumer + err = json.Unmarshal(b, &consumer) + if err != nil { + return nil, nil, err + } + consumers = append(consumers, &consumer) + } + + return consumers, next, nil +} + +// GetConsumerGroupObject Get fetches a ConsumerGroup from Kong. +func GetConsumerGroupObject(ctx context.Context, + client *kong.Client, cgID *string, +) (*kong.ConsumerGroupObject, error) { + r, err := GetConsumerGroup(ctx, client, cgID) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + + m, err := ListAllConsumerGroupMembers(ctx, client, cgID) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + + var plugins []*kong.ConsumerGroupPlugin + p, err := GetConsumerGroupRateLimitingAdvancedPlugin(ctx, client, cgID) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + if p != nil { + plugins = append(plugins, p) + } + + group := &kong.ConsumerGroupObject{ + ConsumerGroup: r, + Consumers: m, + Plugins: plugins, + } + return group, nil +} + +// DeleteConsumerGroup deletes a ConsumerGroup in Kong +func DeleteConsumerGroup( + ctx context.Context, client *kong.Client, cgID *string, +) error { + if isEmptyString(cgID) { + return fmt.Errorf("delete consumer-group: ID cannot be nil") + } + + endpoint := fmt.Sprintf("/v1/consumer-groups/%v", *cgID) + req, err := client.NewRequest("DELETE", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = client.Do(ctx, req, nil) + return err +} + +func DeleteConsumerGroupMember( + ctx context.Context, client *kong.Client, cgID, consumer *string, +) error { + if isEmptyString(cgID) { + return fmt.Errorf("delete consumer-group-member: ID cannot be nil") + } + + endpoint := fmt.Sprintf("/v1/consumers/%s/groups/%s/members", *consumer, *cgID) + req, err := client.NewRequest("DELETE", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = client.Do(ctx, req, nil) + return err +} + +func CreateConsumerGroupMember( + ctx context.Context, client *kong.Client, cgID, consumer *string, +) error { + if isEmptyString(consumer) { + return fmt.Errorf("create consumer-group-member: consumer cannot be nil") + } else if isEmptyString(cgID) { + return fmt.Errorf("create consumer-group-member: consumer group ID cannot be nil") + } + + endpoint := fmt.Sprintf("/v1/consumers/%s/groups/%s/members", *consumer, *cgID) + req, err := client.NewRequest("POST", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = client.Do(ctx, req, nil) + return err +} + +func UpdateConsumerGroupMember( + ctx context.Context, client *kong.Client, cgID, consumer *string, +) error { + if isEmptyString(consumer) { + return fmt.Errorf("create consumer-group-member: consumer cannot be nil") + } else if isEmptyString(cgID) { + return fmt.Errorf("create consumer-group-member: consumer group ID cannot be nil") + } + + endpoint := fmt.Sprintf("/v1/consumers/%s/groups/%s/members", *consumer, *cgID) + req, err := client.NewRequest("PUT", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = client.Do(ctx, req, nil) + return err +} + +// list fetches a list of an entity in Kong. +// opt can be used to control pagination. +func list(ctx context.Context, + client *kong.Client, endpoint string, opt *KonnectListOpt, +) ([]json.RawMessage, *KonnectListOpt, error) { + req, err := client.NewRequest("GET", endpoint, nil, nil) + if err != nil { + return nil, nil, err + } + var list struct { + Items []json.RawMessage `json:"items"` + KonnectListOpt + } + + _, err = client.Do(ctx, req, &list) + if err != nil { + return nil, nil, err + } + + var next *KonnectListOpt + if list.Page != nil && list.Page.NextPageNum != nil { + next = &KonnectListOpt{ + Page: &PageOpt{ + Size: opt.Page.Size, + Number: list.Page.NextPageNum, + }, + Tags: opt.Tags, + } + } + + return list.Items, next, nil +} + +// List fetches a list of ConsumerGroup in Kong. +// opt can be used to control pagination. +func ListConsumerGroups(ctx context.Context, + client *kong.Client, opt *KonnectListOpt, +) ([]*kong.ConsumerGroup, *KonnectListOpt, error) { + data, next, err := list(ctx, client, "/v1/consumer-groups", opt) + if err != nil { + return nil, nil, err + } + + var consumers []*kong.ConsumerGroup + + for _, object := range data { + b, err := object.MarshalJSON() + if err != nil { + return nil, nil, err + } + var consumer kong.ConsumerGroup + err = json.Unmarshal(b, &consumer) + if err != nil { + return nil, nil, err + } + consumers = append(consumers, &consumer) + } + + return consumers, next, nil +} + +// ListAll fetches all ConsumerGroup in Kong. +func ListAllConsumerGroups(ctx context.Context, client *kong.Client, tags []*string) ([]*kong.ConsumerGroup, error) { + var consumerGroups, data []*kong.ConsumerGroup + var err error + opt := &KonnectListOpt{Page: &PageOpt{Size: 100}} + if tags != nil { + opt.Tags = tags + } + + for opt != nil { + data, opt, err = ListConsumerGroups(ctx, client, opt) + if err != nil { + return nil, err + } + consumerGroups = append(consumerGroups, data...) + } + return consumerGroups, nil +} diff --git a/pkg/konnect/control_plane_relations_service.go b/pkg/konnect/control_plane_relations_service.go new file mode 100644 index 0000000..74197f7 --- /dev/null +++ b/pkg/konnect/control_plane_relations_service.go @@ -0,0 +1,137 @@ +package konnect + +import ( + "context" + "encoding/json" + "fmt" + "net/http" +) + +type ControlPlaneRelationsService service + +type ControlPlaneServiceRelationCreateRequest struct { + ServiceVersionID string `json:"service_version"` + ControlPlaneEntityID string `json:"control_plane_entity_id"` + ControlPlane string `json:"control_plane"` +} + +type ControlPlaneServiceRelationUpdateRequest struct { + ID string + ControlPlaneServiceRelationCreateRequest +} + +// Create creates a ControlPlaneServiceRelation in Konnect. +func (s *ControlPlaneRelationsService) Create( + ctx context.Context, + relation *ControlPlaneServiceRelationCreateRequest, +) (*ControlPlaneServiceRelation, error) { + if relation == nil { + return nil, fmt.Errorf("cannot create a nil ControlPlaneServiceRelation") + } + relation.ControlPlane = s.controlPlaneID + + endpoint := "/api/control_plane_service_relations" + method := http.MethodPost + + req, err := s.client.NewRequest(method, endpoint, nil, relation) + if err != nil { + return nil, err + } + + var createdRelation ControlPlaneServiceRelation + _, err = s.client.Do(ctx, req, &createdRelation) + if err != nil { + return nil, err + } + return &createdRelation, nil +} + +// Delete deletes a ControlPlaneServiceRelation in Konnect. +func (s *ControlPlaneRelationsService) Delete(ctx context.Context, + relationID *string, +) error { + if emptyString(relationID) { + return fmt.Errorf("id cannot be nil for Delete operation") + } + + endpoint := fmt.Sprintf("/api/control_plane_service_relations/%v", + *relationID) + req, err := s.client.NewRequest("DELETE", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = s.client.Do(ctx, req, nil) + return err +} + +// Update updates a ControlPlaneServiceRelation in Konnect. +func (s ControlPlaneRelationsService) Update(ctx context.Context, + relation *ControlPlaneServiceRelationUpdateRequest, +) (*ServiceVersion, error) { + if relation == nil { + return nil, fmt.Errorf("cannot update a nil ControlPlaneServiceRelation") + } + + if relation.ID == "" { + return nil, fmt.Errorf("ID cannot be nil for Update operation") + } + relation.ControlPlane = s.controlPlaneID + + endpoint := fmt.Sprintf("/api/control_plane_service_relations/%v", relation.ID) + req, err := s.client.NewRequest("PATCH", endpoint, nil, relation) + if err != nil { + return nil, err + } + + var updatedSP ServiceVersion + _, err = s.client.Do(ctx, req, &updatedSP) + if err != nil { + return nil, err + } + return &updatedSP, nil +} + +// List fetches a list of control_plane_service_relations. +func (s *ControlPlaneRelationsService) List(ctx context.Context, + opt *ListOpt, +) ([]*ControlPlaneServiceRelation, *ListOpt, error) { + data, next, err := s.client.list(ctx, "/api/control_plane_service_relations", opt) + if err != nil { + return nil, nil, err + } + var relations []*ControlPlaneServiceRelation + + for _, object := range data { + b, err := object.MarshalJSON() + if err != nil { + return nil, nil, err + } + var relation ControlPlaneServiceRelation + err = json.Unmarshal(b, &relation) + if err != nil { + return nil, nil, err + } + relations = append(relations, &relation) + } + + return relations, next, nil +} + +// ListAll fetches all control_plane_service_relations. +func (s *ControlPlaneRelationsService) ListAll(ctx context.Context) ([]*ControlPlaneServiceRelation, + error, +) { + var relations, data []*ControlPlaneServiceRelation + var err error + opt := &ListOpt{Size: pageSize} + + for opt != nil { + data, opt, err = s.List(ctx, opt) + if err != nil { + return nil, err + } + relations = append(relations, data...) + } + return relations, nil +} diff --git a/pkg/konnect/control_plane_service.go b/pkg/konnect/control_plane_service.go new file mode 100644 index 0000000..3508f95 --- /dev/null +++ b/pkg/konnect/control_plane_service.go @@ -0,0 +1,36 @@ +package konnect + +import ( + "context" + "encoding/json" +) + +type ControlPlaneService service + +// List fetches a list of control planes. +// No pagination is being performed because the number of control planes +// is expected to be very small. +func (s *ControlPlaneService) List(ctx context.Context, + opt *ListOpt, +) ([]ControlPlane, *ListOpt, error) { + data, next, err := s.client.list(ctx, "/api/control_planes", opt) + if err != nil { + return nil, nil, err + } + var controlPlanes []ControlPlane + + for _, object := range data { + b, err := object.MarshalJSON() + if err != nil { + return nil, nil, err + } + var controlPlane ControlPlane + err = json.Unmarshal(b, &controlPlane) + if err != nil { + return nil, nil, err + } + controlPlanes = append(controlPlanes, controlPlane) + } + + return controlPlanes, next, nil +} diff --git a/pkg/konnect/document_service.go b/pkg/konnect/document_service.go new file mode 100644 index 0000000..3e31014 --- /dev/null +++ b/pkg/konnect/document_service.go @@ -0,0 +1,150 @@ +package konnect + +import ( + "context" + "encoding/json" + "fmt" + "net/http" +) + +type DocumentService service + +// Create creates a Document in Konnect. +func (d *DocumentService) Create(ctx context.Context, doc *Document) (*Document, error) { + if doc == nil { + return nil, fmt.Errorf("cannot create a nil document") + } + + if doc.Parent == nil { + return nil, fmt.Errorf("document must have a Parent") + } + + endpoint := doc.Parent.URL() + "/documents/" + method := http.MethodPost + if doc.ID != nil { + method = "PUT" + endpoint = endpoint + *doc.ID + } + req, err := d.client.NewRequest(method, endpoint, nil, doc) + if err != nil { + return nil, err + } + + var createdDoc Document + _, err = d.client.Do(ctx, req, &createdDoc) + if err != nil { + return nil, err + } + createdDoc.Parent = doc.Parent + return &createdDoc, nil +} + +// Delete deletes a Document in Konnect. +func (d *DocumentService) Delete(ctx context.Context, doc *Document) error { + if emptyString(doc.ID) { + return fmt.Errorf("id cannot be nil for Delete operation") + } + + if doc.Parent == nil { + return fmt.Errorf("document must have a Parent") + } + + endpoint := fmt.Sprintf("%s/documents/%s", doc.Parent.URL(), *doc.ID) + req, err := d.client.NewRequest("DELETE", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = d.client.Do(ctx, req, nil) + return err +} + +// Update updates a Document in Konnect. +func (d *DocumentService) Update(ctx context.Context, doc *Document) (*Document, error) { + if doc == nil { + return nil, fmt.Errorf("cannot update a nil document") + } + + if emptyString(doc.ID) { + return nil, fmt.Errorf("ID cannot be nil for Update operation") + } + + if doc.Parent == nil { + return nil, fmt.Errorf("document must have a Parent") + } + + // Document PATCHes run through POST validation logic. Attempting to PATCH a Published: true + // document without toggling Published results in a 400, as if you'd tried to POST another + // Published: true document under the same resource. As such, this PUTs instead. + endpoint := fmt.Sprintf("%s/documents/%s", doc.Parent.URL(), *doc.ID) + putReq, err := d.client.NewRequest("PUT", endpoint, nil, doc) + if err != nil { + return nil, err + } + + var updatedDoc Document + _, err = d.client.Do(ctx, putReq, &updatedDoc) + if err != nil { + return nil, err + } + updatedDoc.Parent = doc.Parent + return &updatedDoc, nil +} + +// listByPath fetches a list of Documents in Konnect on a specific path. +// This is a helper method for listing all documents for specific entities. +func (d *DocumentService) listByPath(ctx context.Context, path string, opt *ListOpt) ([]*Document, *ListOpt, error) { + data, next, err := d.client.list(ctx, path, opt) + if err != nil { + return nil, nil, err + } + var docs []*Document + + for _, object := range data { + b, err := object.MarshalJSON() + if err != nil { + return nil, nil, err + } + var doc Document + err = json.Unmarshal(b, &doc) + if err != nil { + return nil, nil, err + } + docs = append(docs, &doc) + } + + return docs, next, nil +} + +// ListAll fetches all Documents in Kong. +func (d *DocumentService) listAllByPath(ctx context.Context, path string) ([]*Document, error) { + var docs, data []*Document + var err error + opt := &ListOpt{Size: pageSize} + + for opt != nil { + data, opt, err = d.listByPath(ctx, path, opt) + if err != nil { + return nil, err + } + docs = append(docs, data...) + } + return docs, nil +} + +// ListAllForParent fetches all Documents in Konnect for a parent entity. +func (d *DocumentService) ListAllForParent(ctx context.Context, parent ParentInfoer) ([]*Document, error) { + if parent == nil { + return nil, fmt.Errorf("parent cannot be nil") + } + var docs []*Document + var err error + docs, err = d.listAllByPath(ctx, parent.URL()+"/documents") + if err != nil { + return nil, err + } + for _, doc := range docs { + doc.Parent = parent + } + return docs, nil +} diff --git a/pkg/konnect/error.go b/pkg/konnect/error.go new file mode 100644 index 0000000..1c12d16 --- /dev/null +++ b/pkg/konnect/error.go @@ -0,0 +1,62 @@ +package konnect + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" +) + +func hasError(res *http.Response) error { + if res.StatusCode >= 200 && res.StatusCode <= 399 { + return nil + } + + body, _ := ioutil.ReadAll(res.Body) // TODO error in error? + return &APIError{ + httpCode: res.StatusCode, + message: messageFromBody(body), + } +} + +func messageFromBody(b []byte) string { + s := struct { + Message string + }{} + + if err := json.Unmarshal(b, &s); err != nil { + return fmt.Sprintf("", err) + } + + return s.Message +} + +// APIError is used for Kong Admin API errors. +type APIError struct { + httpCode int + message string +} + +func (e *APIError) Error() string { + return fmt.Sprintf("HTTP status %d (message: %q)", e.httpCode, e.message) +} + +// Code returns the HTTP status code for the error. +func (e *APIError) Code() int { + return e.httpCode +} + +// IsNotFoundErr returns true if the error or it's cause is +// a 404 response from Kong. +func IsNotFoundErr(e error) bool { + var apiErr *APIError + return errors.As(e, &apiErr) && apiErr.httpCode == http.StatusNotFound +} + +// IsUnauthorizedErr returns true if the error or it's cause is +// a 401 response from Konnect. +func IsUnauthorizedErr(e error) bool { + var apiErr *APIError + return errors.As(e, &apiErr) && apiErr.httpCode == http.StatusUnauthorized +} diff --git a/pkg/konnect/list.go b/pkg/konnect/list.go new file mode 100644 index 0000000..ccdad2f --- /dev/null +++ b/pkg/konnect/list.go @@ -0,0 +1,60 @@ +package konnect + +import ( + "context" + "encoding/json" +) + +// ListOpt aids in paginating through list endpoints. +type ListOpt struct { + // Size of the page + Size int `url:"size,omitempty"` + // Page number to fetch + Page int `url:"page,omitempty"` +} + +const ( + // max page size in Konnect's API is 100 + pageSize = 100 +) + +// list fetches a list of an entity in Kong. +// opt can be used to control pagination. +func (c *Client) list(ctx context.Context, + endpoint string, opt *ListOpt, +) ([]json.RawMessage, *ListOpt, error) { + pageSize := 100 + if opt != nil { + if opt.Size > 100 { + opt.Size = pageSize + } else { + pageSize = opt.Size + } + } + + req, err := c.NewRequest("GET", endpoint, opt, nil) + if err != nil { + return nil, nil, err + } + var list struct { + Data []json.RawMessage `json:"data"` + Page int `json:"page"` + PageCount int `json:"pageCount"` + } + + _, err = c.Do(ctx, req, &list) + if err != nil { + return nil, nil, err + } + + // convenient for end user to use this opt till it's nil + var next *ListOpt + if len(list.Data) > 0 && list.Page != list.PageCount { + next = &ListOpt{ + Page: list.Page + 1, + Size: pageSize, + } + } + + return list.Data, next, nil +} diff --git a/pkg/konnect/login_service.go b/pkg/konnect/login_service.go new file mode 100644 index 0000000..ce40fb5 --- /dev/null +++ b/pkg/konnect/login_service.go @@ -0,0 +1,146 @@ +package konnect + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/cookiejar" + "net/url" + "strings" +) + +type OrgUserInfo struct { + Name string `json:"name,omitempty"` + OrgID string `json:"id,omitempty"` +} + +type AuthService service + +func (s *AuthService) Login(ctx context.Context, email, + password string, +) (AuthResponse, error) { + body := map[string]string{ + "username": email, + "password": password, + } + req, err := s.client.NewRequest(http.MethodPost, authEndpoint, nil, body) + if err != nil { + return AuthResponse{}, err + } + var authResponse AuthResponse + resp, err := s.client.Do(ctx, req, &authResponse) + if err != nil { + return AuthResponse{}, err + } + url, _ := url.Parse(s.client.baseURL) + jar, err := cookiejar.New(nil) + if err != nil { + return AuthResponse{}, err + } + + jar.SetCookies(url, resp.Cookies()) + s.client.client.Jar = jar + return authResponse, nil +} + +// getGlobalAuthEndpoint returns the global auth endpoint +// given a base Konnect URL. +func getGlobalAuthEndpoint(baseURL string) string { + parts := strings.Split(baseURL, "api.konghq") + return baseEndpointUS + parts[len(parts)-1] + authEndpointV2 +} + +func createAuthRequest(baseURL, email, password string) (*http.Request, error) { + var ( + buf []byte + err error + + body = map[string]string{ + "username": email, + "password": password, + } + ) + + buf, err = json.Marshal(body) + if err != nil { + return nil, err + } + + endpoint := getGlobalAuthEndpoint(baseURL) + return http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(buf)) +} + +func (s *AuthService) sessionAuth(ctx context.Context, email, + password string, +) (AuthResponse, error) { + req, err := createAuthRequest(s.client.baseURL, email, password) + if err != nil { + return AuthResponse{}, err + } + + var authResponse AuthResponse + resp, err := s.client.Do(ctx, req, &authResponse) + if err != nil { + return AuthResponse{}, err + } + url, _ := url.Parse(s.client.baseURL) + jar, err := cookiejar.New(nil) + if err != nil { + return AuthResponse{}, err + } + + jar.SetCookies(url, resp.Cookies()) + s.client.client.Jar = jar + return authResponse, nil +} + +func (s *AuthService) LoginV2(ctx context.Context, email, + password, token string, +) (AuthResponse, error) { + var ( + err error + authResponse AuthResponse + ) + + if token != "" { + s.client.token = token + } else if email != "" && password != "" { + authResponse, err = s.sessionAuth(ctx, email, password) + if err != nil { + return AuthResponse{}, err + } + } else { + return AuthResponse{}, errors.New( + "at least one of email/password or personal access token must be provided", + ) + } + + info, err := s.OrgUserInfo(ctx) + if err != nil { + return AuthResponse{}, err + } + authResponse.Name = info.Name + authResponse.OrganizationID = info.OrgID + return authResponse, nil +} + +func (s *AuthService) OrgUserInfo(ctx context.Context) (*OrgUserInfo, error) { + // replace geo-specific endpoint with global one for retrieving org info + client := *s.client + client.baseURL = strings.Replace(s.client.baseURL, "eu.", "global.", 1) + client.baseURL = strings.Replace(client.baseURL, "au.", "global.", 1) + + req, err := client.NewRequest(http.MethodGet, "/v2/organizations/me", nil, nil) + if err != nil { + return nil, err + } + + info := &OrgUserInfo{} + _, err = s.client.Do(ctx, req, info) + if err != nil { + return nil, err + } + return info, nil +} diff --git a/pkg/konnect/login_service_test.go b/pkg/konnect/login_service_test.go new file mode 100644 index 0000000..89f384f --- /dev/null +++ b/pkg/konnect/login_service_test.go @@ -0,0 +1,42 @@ +package konnect + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetGlobalAuthEndpoint(t *testing.T) { + tests := []struct { + baseURL string + expected string + }{ + { + baseURL: "https://us.api.konghq.com", + expected: "https://global.api.konghq.com/kauth/api/v1/authenticate", + }, + { + baseURL: "https://global.api.konghq.com", + expected: "https://global.api.konghq.com/kauth/api/v1/authenticate", + }, + { + baseURL: "https://eu.api.konghq.com", + expected: "https://global.api.konghq.com/kauth/api/v1/authenticate", + }, + { + baseURL: "https://au.api.konghq.com", + expected: "https://global.api.konghq.com/kauth/api/v1/authenticate", + }, + { + baseURL: "https://api.konghq.com", + expected: "https://global.api.konghq.com/kauth/api/v1/authenticate", + }, + { + baseURL: "https://eu.api.konghq.test", + expected: "https://global.api.konghq.test/kauth/api/v1/authenticate", + }, + } + for _, tt := range tests { + assert.Equal(t, tt.expected, getGlobalAuthEndpoint(tt.baseURL)) + } +} diff --git a/pkg/konnect/request.go b/pkg/konnect/request.go new file mode 100644 index 0000000..9e155bb --- /dev/null +++ b/pkg/konnect/request.go @@ -0,0 +1,53 @@ +package konnect + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + "github.com/google/go-querystring/query" +) + +// NewRequest creates a request based on the inputs. +// endpoint should be relative to the baseURL specified during +// client creation. +// body is always marshaled into JSON. +func (c *Client) NewRequest(method, endpoint string, qs interface{}, + body interface{}, +) (*http.Request, error) { + if endpoint == "" { + return nil, fmt.Errorf("endpoint can't be nil") + } + // body to be sent in JSON + var buf []byte + if body != nil { + var err error + buf, err = json.Marshal(body) + if err != nil { + return nil, err + } + } + + // Create a new request + req, err := http.NewRequest(method, c.baseURL+endpoint, + bytes.NewBuffer(buf)) + if err != nil { + return nil, err + } + + // add body if needed + if body != nil { + req.Header.Add("Content-Type", "application/json") + } + + // add query string if any + if qs != nil { + values, err := query.Values(qs) + if err != nil { + return nil, err + } + req.URL.RawQuery = values.Encode() + } + return req, nil +} diff --git a/pkg/konnect/runtime_group_service.go b/pkg/konnect/runtime_group_service.go new file mode 100644 index 0000000..b1bd46b --- /dev/null +++ b/pkg/konnect/runtime_group_service.go @@ -0,0 +1,38 @@ +package konnect + +import ( + "context" + "encoding/json" +) + +type RuntimeGroupService service + +// List fetches a list of Service packages. +func (s *RuntimeGroupService) List(ctx context.Context, + opt *ListOpt, +) ([]*RuntimeGroup, *ListOpt, error) { + data, next, err := s.client.list(ctx, "/konnect-api/api/runtime_groups", opt) + // TODO: replace the above with the following once the Konnect API is updated. + // + // Note: pagination logic will need to be fixed too. + // data, next, err := s.client.list(ctx, "/v2/control-planes", opt) + if err != nil { + return nil, nil, err + } + var runtimeGroups []*RuntimeGroup + + for _, object := range data { + b, err := object.MarshalJSON() + if err != nil { + return nil, nil, err + } + var runtimeGroup RuntimeGroup + err = json.Unmarshal(b, &runtimeGroup) + if err != nil { + return nil, nil, err + } + runtimeGroups = append(runtimeGroups, &runtimeGroup) + } + + return runtimeGroups, next, nil +} diff --git a/pkg/konnect/service_package_service.go b/pkg/konnect/service_package_service.go new file mode 100644 index 0000000..28fb6c3 --- /dev/null +++ b/pkg/konnect/service_package_service.go @@ -0,0 +1,118 @@ +package konnect + +import ( + "context" + "encoding/json" + "fmt" +) + +type ServicePackageService service + +// Create creates a ServicePackage in Konnect. +func (s *ServicePackageService) Create(ctx context.Context, + sp *ServicePackage, +) (*ServicePackage, error) { + if sp == nil { + return nil, fmt.Errorf("cannot create a nil service-package") + } + + endpoint := "/api/service_packages" + method := "POST" + req, err := s.client.NewRequest(method, endpoint, nil, sp) + if err != nil { + return nil, err + } + + var createdSP ServicePackage + _, err = s.client.Do(ctx, req, &createdSP) + if err != nil { + return nil, err + } + return &createdSP, nil +} + +// Delete deletes a ServicePackage in Konnect. +func (s *ServicePackageService) Delete(ctx context.Context, id *string) error { + if emptyString(id) { + return fmt.Errorf("id cannot be nil for Delete operation") + } + + endpoint := fmt.Sprintf("/api/service_packages/%v", *id) + req, err := s.client.NewRequest("DELETE", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = s.client.Do(ctx, req, nil) + return err +} + +// Update updates a ServicePackage in Konnect. +func (s *ServicePackageService) Update(ctx context.Context, + sp *ServicePackage, +) (*ServicePackage, error) { + if sp == nil { + return nil, fmt.Errorf("cannot update a nil service-package") + } + + if emptyString(sp.ID) { + return nil, fmt.Errorf("ID cannot be nil for Update operation") + } + + endpoint := fmt.Sprintf("/api/service_packages/%v", *sp.ID) + req, err := s.client.NewRequest("PATCH", endpoint, nil, sp) + if err != nil { + return nil, err + } + + var updatedSP ServicePackage + _, err = s.client.Do(ctx, req, &updatedSP) + if err != nil { + return nil, err + } + return &updatedSP, nil +} + +// List fetches a list of Service packages. +func (s *ServicePackageService) List(ctx context.Context, + opt *ListOpt, +) ([]*ServicePackage, *ListOpt, error) { + data, next, err := s.client.list(ctx, "/api/service_packages", opt) + if err != nil { + return nil, nil, err + } + var servicePackages []*ServicePackage + + for _, object := range data { + b, err := object.MarshalJSON() + if err != nil { + return nil, nil, err + } + var servicePackage ServicePackage + err = json.Unmarshal(b, &servicePackage) + if err != nil { + return nil, nil, err + } + servicePackages = append(servicePackages, &servicePackage) + } + + return servicePackages, next, nil +} + +// ListAll fetches all Service packages. +func (s *ServicePackageService) ListAll(ctx context.Context) ([]*ServicePackage, + error, +) { + var servicePackages, data []*ServicePackage + var err error + opt := &ListOpt{Size: pageSize} + + for opt != nil { + data, opt, err = s.List(ctx, opt) + if err != nil { + return nil, err + } + servicePackages = append(servicePackages, data...) + } + return servicePackages, nil +} diff --git a/pkg/konnect/service_version_service.go b/pkg/konnect/service_version_service.go new file mode 100644 index 0000000..08b3101 --- /dev/null +++ b/pkg/konnect/service_version_service.go @@ -0,0 +1,107 @@ +package konnect + +import ( + "context" + "fmt" + "net/http" +) + +type ServiceVersionService service + +// Create creates a ServiceVersion in Konnect. +func (s *ServiceVersionService) Create(ctx context.Context, + sv *ServiceVersion, +) (*ServiceVersion, error) { + if sv == nil { + return nil, fmt.Errorf("cannot create a nil service-package") + } + + endpoint := "/api/service_versions" + method := "POST" + + if !emptyString(sv.ID) { + method = "PUT" + endpoint += "/" + *sv.ID + + } + + req, err := s.client.NewRequest(method, endpoint, nil, map[string]string{ + "version": *sv.Version, + "service_package": *sv.ServicePackage.ID, + "control_plane": s.controlPlaneID, + }) + if err != nil { + return nil, err + } + + var createdSV ServiceVersion + _, err = s.client.Do(ctx, req, &createdSV) + if err != nil { + return nil, err + } + return &createdSV, nil +} + +// Delete deletes a ServiceVersion in Konnect. +func (s *ServiceVersionService) Delete(ctx context.Context, id *string) error { + if emptyString(id) { + return fmt.Errorf("id cannot be nil for Delete operation") + } + + endpoint := fmt.Sprintf("/api/service_versions/%v", *id) + req, err := s.client.NewRequest("DELETE", endpoint, nil, nil) + if err != nil { + return err + } + + _, err = s.client.Do(ctx, req, nil) + if err != nil { + return err + } + return err +} + +// Update updates a ServiceVersion in Konnect. +func (s *ServiceVersionService) Update(ctx context.Context, + sv *ServiceVersion, +) (*ServiceVersion, error) { + if sv == nil { + return nil, fmt.Errorf("cannot update a nil service-package") + } + + if emptyString(sv.ID) { + return nil, fmt.Errorf("ID cannot be nil for Update operation") + } + + endpoint := fmt.Sprintf("/api/service_versions/%v", *sv.ID) + req, err := s.client.NewRequest("PATCH", endpoint, nil, sv) + if err != nil { + return nil, err + } + + var updatedSV ServiceVersion + _, err = s.client.Do(ctx, req, &updatedSV) + if err != nil { + return nil, err + } + return &updatedSV, nil +} + +// ListForPackage fetches a list of Service Versions for a given servicePackageID. +func (s *ServiceVersionService) ListForPackage(ctx context.Context, + servicePackageID *string, +) ([]ServiceVersion, error) { + endpoint := "/api/service_packages/" + *servicePackageID + "/service_versions" + req, err := s.client.NewRequest(http.MethodGet, endpoint, nil, nil) + if err != nil { + return nil, err + } + // Note: This endpoint doesn't follow the structure of paginated endpoints + // and instead returns an array with all service versions. + var response []ServiceVersion + _, err = s.client.Do(ctx, req, &response) + if err != nil { + return nil, err + } + return response, nil +} diff --git a/pkg/konnect/types.go b/pkg/konnect/types.go new file mode 100644 index 0000000..40d9560 --- /dev/null +++ b/pkg/konnect/types.go @@ -0,0 +1,148 @@ +package konnect + +import ( + "fmt" +) + +const ( + baseEndpointUS = "https://global.api.konghq" + authEndpoint = "/api/auth" + authEndpointV2 = "/kauth/api/v1/authenticate" +) + +type ParentInfoer interface { + URL() string + Key() string +} + +func BaseURL() string { + const baseURL = "https://konnect.konghq.com" + return baseURL +} + +// RuntimeGroup represents a Runtime Group in Konnect. +// +k8s:deepcopy-gen=true +type RuntimeGroup struct { + ID *string `json:"id,omitempty"` + Name *string `json:"name,omitempty"` +} + +// ServicePackage represents a Service Package in Konnect. +// +k8s:deepcopy-gen=true +type ServicePackage struct { + ID *string `json:"id,omitempty"` + Name *string `json:"name,omitempty"` + Description *string `json:"description"` + + Versions []ServiceVersion `json:"versions,omitempty"` +} + +func (p *ServicePackage) URL() string { + return fmt.Sprintf("/api/service_packages/%s", *p.ID) +} + +func (p *ServicePackage) Key() string { + return "ServicePackage" + ":" + *p.ID +} + +// ServiceVersion represents a Service Version in Konnect. +// +k8s:deepcopy-gen=true +type ServiceVersion struct { + ID *string `json:"id,omitempty"` + Version *string `json:"version,omitempty"` + + ServicePackage *ServicePackage `json:"service_package,omitempty"` + + ControlPlaneServiceRelation *ControlPlaneServiceRelation `json:"control_plane_service_relation,omitempty"` +} + +func (v *ServiceVersion) URL() string { + return fmt.Sprintf("/api/service_versions/%s", *v.ID) +} + +func (v *ServiceVersion) Key() string { + return "ServiceVersion" + ":" + *v.ID +} + +type Document struct { + ID *string `json:"id,omitempty"` + Path *string `json:"path,omitempty"` + Content *string `json:"content,omitempty"` + Published *bool `json:"published,omitempty"` + Parent ParentInfoer `json:"-"` +} + +func (d *Document) ParentKey() string { + return d.Parent.Key() +} + +// ShallowCopyInto is a shallowcopy function, copying the receiver, writing into out. d must be non-nil. +func (d *Document) ShallowCopyInto(out *Document) { + *out = *d + if d.ID != nil { + d, out := &d.ID, &out.ID + *out = new(string) + **out = **d + } + if d.Path != nil { + d, out := &d.Path, &out.Path + *out = new(string) + **out = **d + } + if d.Content != nil { + d, out := &d.Content, &out.Content + *out = new(string) + **out = **d + } + if d.Published != nil { + d, out := &d.Published, &out.Published + *out = new(bool) + **out = **d + } + if d.Parent != nil { + out.Parent = d.Parent + } +} + +// ShallowCopy is a shallowcopy function, copying the receiver, creating a new Document. +func (d *Document) ShallowCopy() *Document { + if d == nil { + return nil + } + out := new(Document) + d.ShallowCopyInto(out) + return out +} + +// ControlPlaneServiceRelation represents relationship between Control plane implementation and a Service version. +// +k8s:deepcopy-gen=true +type ControlPlaneServiceRelation struct { + ID *string `json:"id,omitempty"` + ControlPlaneEntityID *string `json:"control_plane_entity_id,omitempty"` + ControlPlane *ControlPlane `json:"control_plane,omitempty"` +} + +// ControlPlane identifies a specific control plane in Konnect. +// +k8s:deepcopy-gen=true +type ControlPlane struct { + ID *string `json:"id"` + Type *ControlPlaneType `json:"type"` +} + +// ControlPlaneType represents control plane associated information. +// +k8s:deepcopy-gen=true +type ControlPlaneType struct { + Name *string `json:"name"` +} + +// AuthResponse is authentication response wrapper for login. +type AuthResponse struct { + Name string `json:"name"` + OrganizationID string `json:"org_id"` + + // deprecated fields + Organization string `json:"org_name"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + FullName string `json:"full_name"` +} diff --git a/pkg/konnect/utils.go b/pkg/konnect/utils.go new file mode 100644 index 0000000..a999f52 --- /dev/null +++ b/pkg/konnect/utils.go @@ -0,0 +1,10 @@ +package konnect + +const ( + // KonnectManagedPluginTag is used by Konnect to tag internally-managed plugins + KonnectManagedPluginTag = "konnect-managed-plugin" +) + +func emptyString(p *string) bool { + return p == nil || *p == "" +} diff --git a/pkg/konnect/zz_generated.deepcopy.go b/pkg/konnect/zz_generated.deepcopy.go new file mode 100644 index 0000000..ba7c0b3 --- /dev/null +++ b/pkg/konnect/zz_generated.deepcopy.go @@ -0,0 +1,200 @@ +//go:build !ignore_autogenerated +// +build !ignore_autogenerated + +/* +Copyright 2021 Kong Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by deepcopy-gen. DO NOT EDIT. + +package konnect + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ControlPlane) DeepCopyInto(out *ControlPlane) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Type != nil { + in, out := &in.Type, &out.Type + *out = new(ControlPlaneType) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ControlPlane. +func (in *ControlPlane) DeepCopy() *ControlPlane { + if in == nil { + return nil + } + out := new(ControlPlane) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ControlPlaneServiceRelation) DeepCopyInto(out *ControlPlaneServiceRelation) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.ControlPlaneEntityID != nil { + in, out := &in.ControlPlaneEntityID, &out.ControlPlaneEntityID + *out = new(string) + **out = **in + } + if in.ControlPlane != nil { + in, out := &in.ControlPlane, &out.ControlPlane + *out = new(ControlPlane) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ControlPlaneServiceRelation. +func (in *ControlPlaneServiceRelation) DeepCopy() *ControlPlaneServiceRelation { + if in == nil { + return nil + } + out := new(ControlPlaneServiceRelation) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ControlPlaneType) DeepCopyInto(out *ControlPlaneType) { + *out = *in + if in.Name != nil { + in, out := &in.Name, &out.Name + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ControlPlaneType. +func (in *ControlPlaneType) DeepCopy() *ControlPlaneType { + if in == nil { + return nil + } + out := new(ControlPlaneType) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RuntimeGroup) DeepCopyInto(out *RuntimeGroup) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Name != nil { + in, out := &in.Name, &out.Name + *out = new(string) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RuntimeGroup. +func (in *RuntimeGroup) DeepCopy() *RuntimeGroup { + if in == nil { + return nil + } + out := new(RuntimeGroup) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ServicePackage) DeepCopyInto(out *ServicePackage) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Name != nil { + in, out := &in.Name, &out.Name + *out = new(string) + **out = **in + } + if in.Description != nil { + in, out := &in.Description, &out.Description + *out = new(string) + **out = **in + } + if in.Versions != nil { + in, out := &in.Versions, &out.Versions + *out = make([]ServiceVersion, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServicePackage. +func (in *ServicePackage) DeepCopy() *ServicePackage { + if in == nil { + return nil + } + out := new(ServicePackage) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ServiceVersion) DeepCopyInto(out *ServiceVersion) { + *out = *in + if in.ID != nil { + in, out := &in.ID, &out.ID + *out = new(string) + **out = **in + } + if in.Version != nil { + in, out := &in.Version, &out.Version + *out = new(string) + **out = **in + } + if in.ServicePackage != nil { + in, out := &in.ServicePackage, &out.ServicePackage + *out = new(ServicePackage) + (*in).DeepCopyInto(*out) + } + if in.ControlPlaneServiceRelation != nil { + in, out := &in.ControlPlaneServiceRelation, &out.ControlPlaneServiceRelation + *out = new(ControlPlaneServiceRelation) + (*in).DeepCopyInto(*out) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServiceVersion. +func (in *ServiceVersion) DeepCopy() *ServiceVersion { + if in == nil { + return nil + } + out := new(ServiceVersion) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/scripts/header-template.go.tmpl b/pkg/scripts/header-template.go.tmpl new file mode 100644 index 0000000..d741198 --- /dev/null +++ b/pkg/scripts/header-template.go.tmpl @@ -0,0 +1,15 @@ +/* +Copyright 2021 Kong Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ diff --git a/pkg/scripts/update-deepcopy-gen.sh b/pkg/scripts/update-deepcopy-gen.sh new file mode 100755 index 0000000..a25d0c9 --- /dev/null +++ b/pkg/scripts/update-deepcopy-gen.sh @@ -0,0 +1,23 @@ +#!/bin/bash -e + +go install k8s.io/code-generator/cmd/deepcopy-gen +TMP_DIR=$(mktemp -d) +trap "rm -rf $TMP_DIR" EXIT + +# konnect package +deepcopy-gen --input-dirs github.com/kong/deck/konnect \ + -O zz_generated.deepcopy \ + --go-header-file scripts/header-template.go.tmpl \ + --output-base $TMP_DIR + +cp $TMP_DIR/github.com/kong/deck/konnect/zz_generated.deepcopy.go \ + konnect/zz_generated.deepcopy.go + +# file package +deepcopy-gen --input-dirs github.com/kong/deck/file \ + -O zz_generated.deepcopy \ + --go-header-file scripts/header-template.go.tmpl \ + --output-base $TMP_DIR + +cp $TMP_DIR/github.com/kong/deck/file/zz_generated.deepcopy.go \ + file/zz_generated.deepcopy.go diff --git a/pkg/scripts/verify-codegen.sh b/pkg/scripts/verify-codegen.sh new file mode 100755 index 0000000..4a3ab87 --- /dev/null +++ b/pkg/scripts/verify-codegen.sh @@ -0,0 +1,7 @@ +#!/bin/bash -e + +FILE="kong_json_schema.json" +cp file/${FILE} /tmp/${FILE} +go generate ./... + +diff -u /tmp/${FILE} file/${FILE} diff --git a/pkg/scripts/verify-deepcopy-gen.sh b/pkg/scripts/verify-deepcopy-gen.sh new file mode 100755 index 0000000..ce911ed --- /dev/null +++ b/pkg/scripts/verify-deepcopy-gen.sh @@ -0,0 +1,23 @@ +#!/bin/bash -e + +go install k8s.io/code-generator/cmd/deepcopy-gen +TMP_DIR=$(mktemp -d) +trap "rm -rf $TMP_DIR" EXIT + +# konnect package +deepcopy-gen --input-dirs github.com/kong/deck/konnect \ + -O zz_generated.deepcopy \ + --go-header-file scripts/header-template.go.tmpl \ + --output-base $TMP_DIR + +diff -Naur konnect/zz_generated.deepcopy.go \ + $TMP_DIR/github.com/kong/deck/konnect/zz_generated.deepcopy.go + +# file package +deepcopy-gen --input-dirs github.com/kong/deck/file \ + -O zz_generated.deepcopy \ + --go-header-file scripts/header-template.go.tmpl \ + --output-base $TMP_DIR + +diff -Naur file/zz_generated.deepcopy.go \ + $TMP_DIR/github.com/kong/deck/file/zz_generated.deepcopy.go diff --git a/pkg/state/aclgroup.go b/pkg/state/aclgroup.go new file mode 100644 index 0000000..c3531e5 --- /dev/null +++ b/pkg/state/aclgroup.go @@ -0,0 +1,262 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +var ( + errGroupRequired = fmt.Errorf("name of ACL group required") + errConsumerRequired = fmt.Errorf("consumer required") +) + +const ( + aclGroupTableName = "aclGroup" + aclGroupsByConsumerID = "aclGroupsByConsumerID" +) + +var aclGroupTableSchema = &memdb.TableSchema{ + Name: aclGroupTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "group": { + Name: "group", + Indexer: &memdb.StringFieldIndex{Field: "Group"}, + }, + all: allIndex, + // foreign + aclGroupsByConsumerID: { + Name: aclGroupsByConsumerID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Consumer", + Sub: "ID", + }, + }, + }, + }, + }, +} + +// ACLGroupsCollection stores and indexes acl-group credentials. +type ACLGroupsCollection collection + +// Add adds aclGroup to ACLGroupsCollection +func (k *ACLGroupsCollection) Add(aclGroup ACLGroup) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(aclGroup.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := insertACLGroup(txn, aclGroup) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func insertACLGroup(txn *memdb.Txn, aclGroup ACLGroup) error { + if utils.Empty(aclGroup.ID) { + return errIDRequired + } + + // err out if group with same ID is present + _, err := getACLGroupByID(txn, *aclGroup.ID) + if err == nil { + return fmt.Errorf("inserting acl-group %v: %w", aclGroup.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + // check if the same combination is present + if utils.Empty(aclGroup.Group) { + return errGroupRequired + } + if aclGroup.Consumer == nil || utils.Empty(aclGroup.Consumer.ID) { + return errConsumerRequired + } + _, err = getACLGroup(txn, *aclGroup.Consumer.ID, *aclGroup.Group) + if err == nil { + return fmt.Errorf("inserting acl-group %v: %w", aclGroup.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + // all good + err = txn.Insert(aclGroupTableName, &aclGroup) + if err != nil { + return err + } + return nil +} + +func getACLGroupByID(txn *memdb.Txn, id string) (*ACLGroup, error) { + res, err := multiIndexLookupUsingTxn(txn, aclGroupTableName, + []string{"id"}, id) + if err != nil { + return nil, err + } + aclGroup, ok := res.(*ACLGroup) + if !ok { + panic(unexpectedType) + } + return &ACLGroup{ACLGroup: *aclGroup.DeepCopy()}, nil +} + +// GetByID gets an acl-group with id. +func (k *ACLGroupsCollection) GetByID(id string) (*ACLGroup, error) { + if id == "" { + return nil, errIDRequired + } + txn := k.db.Txn(false) + defer txn.Abort() + return getACLGroupByID(txn, id) +} + +func getACLGroup(txn *memdb.Txn, consumerID, groupOrID string) (*ACLGroup, error) { + groups, err := getAllACLGroupsByConsumerID(txn, consumerID) + if err != nil { + return nil, err + } + for _, group := range groups { + if groupOrID == *group.ID || groupOrID == *group.Group { + return &ACLGroup{ACLGroup: *group.DeepCopy()}, nil + } + } + return nil, ErrNotFound +} + +func getAllACLGroupsByConsumerID(txn *memdb.Txn, consumerID string) ([]*ACLGroup, error) { + iter, err := txn.Get(aclGroupTableName, aclGroupsByConsumerID, consumerID) + if err != nil { + return nil, err + } + var res []*ACLGroup + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*ACLGroup) + if !ok { + panic(unexpectedType) + } + res = append(res, &ACLGroup{ACLGroup: *r.DeepCopy()}) + } + return res, nil +} + +// Get gets a acl-group for a consumer by group or ID. +func (k *ACLGroupsCollection) Get(consumerID, + groupOrID string, +) (*ACLGroup, error) { + if groupOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + return getACLGroup(txn, consumerID, groupOrID) +} + +// GetAllByConsumerID returns all acl-group credentials +// belong to a Consumer with id. +func (k *ACLGroupsCollection) GetAllByConsumerID(id string) ([]*ACLGroup, + error, +) { + if id == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + + return getAllACLGroupsByConsumerID(txn, id) +} + +// Update updates an existing acl-group credential. +func (k *ACLGroupsCollection) Update(aclGroup ACLGroup) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(aclGroup.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteACLGroup(txn, *aclGroup.ID) + if err != nil { + return err + } + + err = insertACLGroup(txn, aclGroup) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteACLGroup(txn *memdb.Txn, id string) error { + group, err := getACLGroupByID(txn, id) + if err != nil { + return err + } + + err = txn.Delete(aclGroupTableName, group) + if err != nil { + return err + } + return nil +} + +// Delete deletes an acl-group by id. +func (k *ACLGroupsCollection) Delete(id string) error { + if id == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteACLGroup(txn, id) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all acl-groups. +func (k *ACLGroupsCollection) GetAll() ([]*ACLGroup, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(aclGroupTableName, all, true) + if err != nil { + return nil, err + } + + var res []*ACLGroup + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*ACLGroup) + if !ok { + panic(unexpectedType) + } + res = append(res, &ACLGroup{ACLGroup: *r.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/aclgroup_test.go b/pkg/state/aclgroup_test.go new file mode 100644 index 0000000..9becc30 --- /dev/null +++ b/pkg/state/aclgroup_test.go @@ -0,0 +1,240 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func aclGroupsCollection() *ACLGroupsCollection { + return state().ACLGroups +} + +func TestACLGroupInsert(t *testing.T) { + assert := assert.New(t) + collection := aclGroupsCollection() + + var aclGroup ACLGroup + assert.NotNil(collection.Add(aclGroup)) + + aclGroup.Group = kong.String("my-group") + aclGroup.ID = kong.String("first") + err := collection.Add(aclGroup) + assert.NotNil(err) + + var aclGroup2 ACLGroup + aclGroup2.Group = kong.String("my-group") + aclGroup2.ID = kong.String("first") + aclGroup2.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + } + err = collection.Add(aclGroup2) + assert.Nil(err) + + // re-insert + err = collection.Add(aclGroup2) + assert.NotNil(err) + + // re-insert with a different ID + aclGroup2.ID = kong.String("second") + err = collection.Add(aclGroup2) + assert.NotNil(err) + + // re-insert for different consumer + aclGroup2.Consumer = &kong.Consumer{ + ID: kong.String("consumer2-id"), + } + err = collection.Add(aclGroup2) + assert.Nil(err) +} + +func TestACLGroupGetByID(t *testing.T) { + assert := assert.New(t) + collection := aclGroupsCollection() + + var aclGroup ACLGroup + aclGroup.Group = kong.String("my-group") + aclGroup.ID = kong.String("first") + aclGroup.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + } + + err := collection.Add(aclGroup) + assert.Nil(err) + + res, err := collection.GetByID("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-group", *res.Group) + + res, err = collection.GetByID("my-group") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.GetByID("does-not-exist") + assert.NotNil(err) + assert.Nil(res) +} + +func TestACLGroupGet(t *testing.T) { + assert := assert.New(t) + collection := aclGroupsCollection() + + populateWithACLGroupFixtures(assert, collection) + + res, err := collection.Get("first", "does-not-exist") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("does-not-exist", "my-group12") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("consumer1-id", "my-group12") + assert.Nil(err) + assert.NotNil(res) +} + +func TestACLGroupUpdate(t *testing.T) { + assert := assert.New(t) + collection := aclGroupsCollection() + + var aclGroup ACLGroup + aclGroup.Group = kong.String("my-group") + aclGroup.ID = kong.String("first") + aclGroup.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + } + + err := collection.Add(aclGroup) + assert.Nil(err) + + res, err := collection.Get("consumer1-id", "first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-group", *res.Group) + + res.Group = kong.String("my-group2") + err = collection.Update(*res) + assert.Nil(err) + + res, err = collection.Get("consumer1-id", "my-group") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("consumer1-id", "my-group2") + assert.Nil(err) + assert.Equal("first", *res.ID) +} + +func TestACLGroupDelete(t *testing.T) { + assert := assert.New(t) + collection := aclGroupsCollection() + + var aclGroup ACLGroup + aclGroup.Group = kong.String("my-group1") + aclGroup.ID = kong.String("first") + aclGroup.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + } + err := collection.Add(aclGroup) + assert.Nil(err) + + res, err := collection.Get("consumer1-id", "my-group1") + assert.Nil(err) + assert.NotNil(res) + + err = collection.Delete(*res.ID) + assert.Nil(err) + + res, err = collection.Get("consumer1-id", "my-group1") + assert.NotNil(err) + assert.Nil(res) + + // delete a non-existing one + err = collection.Delete("first") + assert.NotNil(err) + + err = collection.Delete("my-group1") + assert.NotNil(err) +} + +func TestACLGroupGetAll(t *testing.T) { + assert := assert.New(t) + collection := aclGroupsCollection() + + populateWithACLGroupFixtures(assert, collection) + + aclGroups, err := collection.GetAll() + assert.Nil(err) + assert.Equal(5, len(aclGroups)) +} + +func TestACLGroupGetByConsumer(t *testing.T) { + assert := assert.New(t) + collection := aclGroupsCollection() + + populateWithACLGroupFixtures(assert, collection) + + aclGroups, err := collection.GetAllByConsumerID("consumer1-id") + assert.Nil(err) + assert.Equal(3, len(aclGroups)) +} + +func populateWithACLGroupFixtures(assert *assert.Assertions, + collection *ACLGroupsCollection, +) { + aclGroups := []ACLGroup{ + { + ACLGroup: kong.ACLGroup{ + Group: kong.String("my-group11"), + ID: kong.String("first"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + }, + }, + }, + { + ACLGroup: kong.ACLGroup{ + Group: kong.String("my-group12"), + ID: kong.String("second"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + }, + }, + }, + { + ACLGroup: kong.ACLGroup{ + Group: kong.String("my-group13"), + ID: kong.String("third"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + }, + }, + }, + { + ACLGroup: kong.ACLGroup{ + Group: kong.String("my-group21"), + ID: kong.String("fourth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + }, + }, + }, + { + ACLGroup: kong.ACLGroup{ + Group: kong.String("my-group22"), + ID: kong.String("fifth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + }, + }, + }, + } + + for _, k := range aclGroups { + err := collection.Add(k) + assert.Nil(err) + } +} diff --git a/pkg/state/basicauth.go b/pkg/state/basicauth.go new file mode 100644 index 0000000..d2e7703 --- /dev/null +++ b/pkg/state/basicauth.go @@ -0,0 +1,85 @@ +package state + +// BasicAuthsCollection stores and indexes basic-auth credentials. +type BasicAuthsCollection struct { + credentialsCollection +} + +func newBasicAuthsCollection(common collection) *BasicAuthsCollection { + return &BasicAuthsCollection{ + credentialsCollection: credentialsCollection{ + collection: common, + CredType: "basic-auth", + }, + } +} + +// Add adds a basic-auth credential to BasicAuthsCollection +func (k *BasicAuthsCollection) Add(basicAuth BasicAuth) error { + cred := (entity)(&basicAuth) + return k.credentialsCollection.Add(cred) +} + +// Get gets a basic-auth credential by key or ID. +func (k *BasicAuthsCollection) Get(keyOrID string) (*BasicAuth, error) { + cred, err := k.credentialsCollection.Get(keyOrID) + if err != nil { + return nil, err + } + + basicAuth, ok := cred.(*BasicAuth) + if !ok { + panic(unexpectedType) + } + return &BasicAuth{BasicAuth: *basicAuth.DeepCopy()}, nil +} + +// GetAllByConsumerID returns all basic-auth credentials +// belong to a Consumer with id. +func (k *BasicAuthsCollection) GetAllByConsumerID(id string) ([]*BasicAuth, + error, +) { + creds, err := k.credentialsCollection.GetAllByConsumerID(id) + if err != nil { + return nil, err + } + + var res []*BasicAuth + for _, cred := range creds { + r, ok := cred.(*BasicAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &BasicAuth{BasicAuth: *r.DeepCopy()}) + } + return res, nil +} + +// Update updates an existing basic-auth credential. +func (k *BasicAuthsCollection) Update(basicAuth BasicAuth) error { + cred := (entity)(&basicAuth) + return k.credentialsCollection.Update(cred) +} + +// Delete deletes a basic-auth credential by key or ID. +func (k *BasicAuthsCollection) Delete(keyOrID string) error { + return k.credentialsCollection.Delete(keyOrID) +} + +// GetAll gets all basic-auth credentials. +func (k *BasicAuthsCollection) GetAll() ([]*BasicAuth, error) { + creds, err := k.credentialsCollection.GetAll() + if err != nil { + return nil, err + } + + var res []*BasicAuth + for _, cred := range creds { + r, ok := cred.(*BasicAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &BasicAuth{BasicAuth: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/basicauth_test.go b/pkg/state/basicauth_test.go new file mode 100644 index 0000000..c191e5f --- /dev/null +++ b/pkg/state/basicauth_test.go @@ -0,0 +1,219 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func basicAuthsCollection() *BasicAuthsCollection { + return state().BasicAuths +} + +func TestBasicAuthInsert(t *testing.T) { + assert := assert.New(t) + collection := basicAuthsCollection() + + var basicAuth BasicAuth + basicAuth.ID = kong.String("first") + err := collection.Add(basicAuth) + assert.NotNil(err) + + basicAuth.Username = kong.String("my-username") + err = collection.Add(basicAuth) + assert.NotNil(err) + + var basicAuth2 BasicAuth + basicAuth2.Username = kong.String("my-username") + basicAuth2.ID = kong.String("first") + basicAuth2.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + Username: kong.String("my-username"), + } + err = collection.Add(basicAuth2) + assert.Nil(err) +} + +func TestBasicAuthGet(t *testing.T) { + assert := assert.New(t) + collection := basicAuthsCollection() + + var basicAuth BasicAuth + basicAuth.Username = kong.String("my-username") + basicAuth.ID = kong.String("first") + basicAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(basicAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-username", *res.Username) + + res, err = collection.Get("my-username") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("first", *res.ID) + assert.Equal("consumer1-id", *res.Consumer.ID) + + res, err = collection.Get("does-not-exist") + assert.NotNil(err) + assert.Nil(res) +} + +func TestBasicAuthUpdate(t *testing.T) { + assert := assert.New(t) + collection := basicAuthsCollection() + + var basicAuth BasicAuth + basicAuth.Username = kong.String("my-username") + basicAuth.ID = kong.String("first") + basicAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(basicAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-username", *res.Username) + + res.Username = kong.String("my-username2") + res.Password = kong.String("password") + err = collection.Update(*res) + assert.Nil(err) + + res, err = collection.Get("my-username") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("my-username2") + assert.Nil(err) + assert.Equal("first", *res.ID) + assert.Equal("password", *res.Password) +} + +func TestBasicAuthDelete(t *testing.T) { + assert := assert.New(t) + collection := basicAuthsCollection() + + var basicAuth BasicAuth + basicAuth.Username = kong.String("my-username1") + basicAuth.ID = kong.String("first") + basicAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + err := collection.Add(basicAuth) + assert.Nil(err) + + res, err := collection.Get("my-username1") + assert.Nil(err) + assert.NotNil(res) + + err = collection.Delete(*res.ID) + assert.Nil(err) + + res, err = collection.Get("my-username1") + assert.NotNil(err) + assert.Nil(res) + + // delete a non-existing one + err = collection.Delete("first") + assert.NotNil(err) + + err = collection.Delete("my-username1") + assert.NotNil(err) +} + +func TestBasicAuthGetAll(t *testing.T) { + assert := assert.New(t) + collection := basicAuthsCollection() + + populateWithBasicAuthFixtures(assert, collection) + + basicAuths, err := collection.GetAll() + assert.Nil(err) + assert.Equal(5, len(basicAuths)) +} + +func TestBasicAuthGetByConsumer(t *testing.T) { + assert := assert.New(t) + collection := basicAuthsCollection() + + populateWithBasicAuthFixtures(assert, collection) + + basicAuths, err := collection.GetAllByConsumerID("consumer1-id") + assert.Nil(err) + assert.Equal(3, len(basicAuths)) +} + +func populateWithBasicAuthFixtures(assert *assert.Assertions, + collection *BasicAuthsCollection, +) { + basicAuths := []BasicAuth{ + { + BasicAuth: kong.BasicAuth{ + Username: kong.String("my-username11"), + ID: kong.String("first"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + BasicAuth: kong.BasicAuth{ + Username: kong.String("my-username12"), + ID: kong.String("second"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + BasicAuth: kong.BasicAuth{ + Username: kong.String("my-username13"), + ID: kong.String("third"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + BasicAuth: kong.BasicAuth{ + Username: kong.String("my-username21"), + ID: kong.String("fourth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + { + BasicAuth: kong.BasicAuth{ + Username: kong.String("my-username22"), + ID: kong.String("fifth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + } + + for _, k := range basicAuths { + err := collection.Add(k) + assert.Nil(err) + } +} diff --git a/pkg/state/builder.go b/pkg/state/builder.go new file mode 100644 index 0000000..97526da --- /dev/null +++ b/pkg/state/builder.go @@ -0,0 +1,385 @@ +package state + +import ( + "errors" + "fmt" + + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// Get builds a KongState from a raw representation of Kong. +func Get(raw *utils.KongRawState) (*KongState, error) { + kongState, err := NewKongState() + if err != nil { + return nil, fmt.Errorf("creating new in-memory state of Kong: %w", err) + } + err = buildKong(kongState, raw) + if err != nil { + return nil, err + } + return kongState, nil +} + +func ensureService(kongState *KongState, serviceID string) (bool, *kong.Service, error) { + s, err := kongState.Services.Get(serviceID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return false, nil, nil + } + return false, nil, fmt.Errorf("looking up service %q: %w", serviceID, err) + + } + return true, utils.GetServiceReference(s.Service), nil +} + +func ensureRoute(kongState *KongState, routeID string) (bool, *kong.Route, error) { + r, err := kongState.Routes.Get(routeID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return false, nil, nil + } + return false, nil, fmt.Errorf("looking up route %q: %w", routeID, err) + + } + return true, utils.GetRouteReference(r.Route), nil +} + +func ensureConsumer(kongState *KongState, consumerID string) (bool, *kong.Consumer, error) { + c, err := kongState.Consumers.GetByIDOrUsername(consumerID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return false, nil, nil + } + return false, nil, fmt.Errorf("looking up consumer %q: %w", consumerID, err) + + } + return true, utils.GetConsumerReference(c.Consumer), nil +} + +func ensureConsumerGroup(kongState *KongState, consumerGroupID string) (bool, *kong.ConsumerGroup, error) { + c, err := kongState.ConsumerGroups.Get(consumerGroupID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return false, nil, nil + } + return false, nil, fmt.Errorf("looking up consumer-group %q: %w", consumerGroupID, err) + + } + return true, utils.GetConsumerGroupReference(c.ConsumerGroup), nil +} + +func buildKong(kongState *KongState, raw *utils.KongRawState) error { + for _, s := range raw.Services { + err := kongState.Services.Add(Service{Service: *s}) + if err != nil { + return fmt.Errorf("inserting service into state: %w", err) + } + } + for _, r := range raw.Routes { + if r.Service != nil && !utils.Empty(r.Service.ID) { + ok, s, err := ensureService(kongState, *r.Service.ID) + if err != nil { + return err + } + if ok { + r.Service = s + } + } + err := kongState.Routes.Add(Route{Route: *r}) + if err != nil { + return fmt.Errorf("inserting route into state: %w", err) + } + } + for _, c := range raw.Consumers { + err := kongState.Consumers.Add(Consumer{Consumer: *c}) + if err != nil { + return fmt.Errorf("inserting consumer into state: %w", err) + } + } + for _, cg := range raw.ConsumerGroups { + err := kongState.ConsumerGroups.Add(ConsumerGroup{ConsumerGroup: *cg.ConsumerGroup}) + if err != nil { + return fmt.Errorf("inserting consumer group into state: %w", err) + } + utils.ZeroOutTimestamps(cg.ConsumerGroup) + for _, c := range cg.Consumers { + err := kongState.ConsumerGroupConsumers.Add( + ConsumerGroupConsumer{ + ConsumerGroupConsumer: kong.ConsumerGroupConsumer{ + Consumer: c, ConsumerGroup: cg.ConsumerGroup, + }, + }, + ) + if err != nil { + return fmt.Errorf("inserting consumer group consumer into state: %w", err) + } + } + + for _, p := range cg.Plugins { + err := kongState.ConsumerGroupPlugins.Add( + ConsumerGroupPlugin{ + ConsumerGroupPlugin: kong.ConsumerGroupPlugin{ + ID: p.ID, + Name: p.Name, + Config: p.Config, + ConsumerGroup: cg.ConsumerGroup, + }, + }, + ) + if err != nil { + return fmt.Errorf("inserting consumer group plugin into state: %w", err) + } + } + } + for _, cred := range raw.KeyAuths { + ok, c, err := ensureConsumer(kongState, *cred.Consumer.ID) + if err != nil { + return err + } + if !ok { + continue + } + cred.Consumer = c + err = kongState.KeyAuths.Add(KeyAuth{KeyAuth: *cred}) + if err != nil { + return fmt.Errorf("inserting key-auth into state: %w", err) + } + } + for _, cred := range raw.HMACAuths { + ok, c, err := ensureConsumer(kongState, *cred.Consumer.ID) + if err != nil { + return err + } + if !ok { + continue + } + cred.Consumer = c + err = kongState.HMACAuths.Add(HMACAuth{HMACAuth: *cred}) + if err != nil { + return fmt.Errorf("inserting hmac-auth into state: %w", err) + } + } + for _, cred := range raw.JWTAuths { + ok, c, err := ensureConsumer(kongState, *cred.Consumer.ID) + if err != nil { + return err + } + if !ok { + continue + } + cred.Consumer = c + err = kongState.JWTAuths.Add(JWTAuth{JWTAuth: *cred}) + if err != nil { + return fmt.Errorf("inserting jwt into state: %w", err) + } + } + for _, cred := range raw.BasicAuths { + ok, c, err := ensureConsumer(kongState, *cred.Consumer.ID) + if err != nil { + return err + } + if !ok { + continue + } + cred.Consumer = c + err = kongState.BasicAuths.Add(BasicAuth{BasicAuth: *cred}) + if err != nil { + return fmt.Errorf("inserting basic-auth into state: %w", err) + } + } + for _, cred := range raw.Oauth2Creds { + ok, c, err := ensureConsumer(kongState, *cred.Consumer.ID) + if err != nil { + return err + } + if !ok { + continue + } + cred.Consumer = c + err = kongState.Oauth2Creds.Add(Oauth2Credential{Oauth2Credential: *cred}) + if err != nil { + return fmt.Errorf("inserting oauth2-cred into state: %w", err) + } + } + for _, cred := range raw.ACLGroups { + ok, c, err := ensureConsumer(kongState, *cred.Consumer.ID) + if err != nil { + return err + } + if !ok { + continue + } + cred.Consumer = c + err = kongState.ACLGroups.Add(ACLGroup{ACLGroup: *cred}) + if err != nil { + return fmt.Errorf("inserting basic-auth into state: %w", err) + } + } + for _, cred := range raw.MTLSAuths { + ok, c, err := ensureConsumer(kongState, *cred.Consumer.ID) + if err != nil { + return err + } + if !ok { + continue + } + cred.Consumer = c + err = kongState.MTLSAuths.Add(MTLSAuth{MTLSAuth: *cred}) + if err != nil { + return fmt.Errorf("inserting mtls-auth into state: %w", err) + } + } + for _, u := range raw.Upstreams { + err := kongState.Upstreams.Add(Upstream{Upstream: *u}) + if err != nil { + return fmt.Errorf("inserting upstream into state: %w", err) + } + } + for _, t := range raw.Targets { + err := kongState.Targets.Add(Target{Target: *t}) + if err != nil { + return fmt.Errorf("inserting target into state: %w", err) + } + } + + for _, c := range raw.Certificates { + err := kongState.Certificates.Add(Certificate{Certificate: *c}) + if err != nil { + return fmt.Errorf("inserting certificate into state: %w", err) + } + } + + for _, s := range raw.SNIs { + err := kongState.SNIs.Add(SNI{SNI: *s}) + if err != nil { + return fmt.Errorf("inserting sni into state: %w", err) + } + } + + for _, c := range raw.CACertificates { + err := kongState.CACertificates.Add(CACertificate{ + CACertificate: *c, + }) + if err != nil { + return fmt.Errorf("inserting ca_certificate into state: %w", err) + } + } + + for _, p := range raw.Plugins { + if p.Service != nil && !utils.Empty(p.Service.ID) { + ok, s, err := ensureService(kongState, *p.Service.ID) + if err != nil { + return err + } + if ok { + p.Service = s + } + } + if p.Route != nil && !utils.Empty(p.Route.ID) { + ok, r, err := ensureRoute(kongState, *p.Route.ID) + if err != nil { + return err + } + if ok { + p.Route = r + } + } + if p.Consumer != nil && !utils.Empty(p.Consumer.ID) { + ok, c, err := ensureConsumer(kongState, *p.Consumer.ID) + if err != nil { + return err + } + if ok { + p.Consumer = c + } + } + if p.ConsumerGroup != nil && !utils.Empty(p.ConsumerGroup.ID) { + ok, cg, err := ensureConsumerGroup(kongState, *p.ConsumerGroup.ID) + if err != nil { + return err + } + if ok { + p.ConsumerGroup = cg + } + } + err := kongState.Plugins.Add(Plugin{Plugin: *p}) + if err != nil { + return fmt.Errorf("inserting plugins into state: %w", err) + } + } + + for _, r := range raw.RBACRoles { + err := kongState.RBACRoles.Add(RBACRole{RBACRole: *r}) + if err != nil { + return fmt.Errorf("inserting rbac roles into state: %w", err) + } + } + for _, r := range raw.RBACEndpointPermissions { + err := kongState.RBACEndpointPermissions.Add(RBACEndpointPermission{RBACEndpointPermission: *r}) + if err != nil { + return fmt.Errorf("inserting rbac endpoint permissions into state: %w", err) + } + } + for _, v := range raw.Vaults { + err := kongState.Vaults.Add(Vault{Vault: *v}) + if err != nil { + return fmt.Errorf("inserting vault into state: %w", err) + } + } + return nil +} + +func buildKonnect(kongState *KongState, raw *utils.KonnectRawState) error { + for _, s := range raw.ServicePackages { + servicePackage := s.DeepCopy() + servicePackage.Versions = nil + err := kongState.ServicePackages.Add(ServicePackage{ + ServicePackage: *servicePackage, + }) + if err != nil { + return fmt.Errorf("inserting service-package into state: %w", err) + } + + for _, v := range s.Versions { + v = *v.DeepCopy() + v.ServicePackage = servicePackage.DeepCopy() + err := kongState.ServiceVersions.Add(ServiceVersion{ + ServiceVersion: v, + }) + if err != nil { + return fmt.Errorf("inserting service-version into state: %w", err) + } + } + } + for _, d := range raw.Documents { + document := d.ShallowCopy() + err := kongState.Documents.Add(Document{ + Document: *document, + }) + if err != nil { + return fmt.Errorf("inserting document into state: %w", err) + } + } + return nil +} + +func GetKonnectState(rawKong *utils.KongRawState, + rawKonnect *utils.KonnectRawState, +) (*KongState, error) { + kongState, err := NewKongState() + if err != nil { + return nil, fmt.Errorf("creating new in-memory state of Kong: %w", err) + } + + err = buildKong(kongState, rawKong) + if err != nil { + return nil, err + } + + err = buildKonnect(kongState, rawKonnect) + if err != nil { + return nil, err + } + return kongState, nil +} diff --git a/pkg/state/cacert.go b/pkg/state/cacert.go new file mode 100644 index 0000000..a147864 --- /dev/null +++ b/pkg/state/cacert.go @@ -0,0 +1,173 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + caCertTableName = "caCert" +) + +var caCertTableSchema = &memdb.TableSchema{ + Name: caCertTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "cert": { + Name: "cert", + Unique: true, + Indexer: &indexers.MD5FieldsIndexer{ + Fields: []string{"Cert"}, + }, + }, + all: allIndex, + }, +} + +// CACertificatesCollection stores and indexes Kong CACertificates. +type CACertificatesCollection collection + +// Add adds a caCert to the collection +func (k *CACertificatesCollection) Add(caCert CACertificate) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(caCert.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *caCert.ID) + if !utils.Empty(caCert.Cert) { + searchBy = append(searchBy, *caCert.Cert) + } + _, err := getCACert(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting ca-cert %v: %w", caCert.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(caCertTableName, &caCert) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getCACert(txn *memdb.Txn, IDs ...string) (*CACertificate, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, caCertTableName, + []string{"cert", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + caCert, ok := res.(*CACertificate) + if !ok { + panic(unexpectedType) + } + return &CACertificate{CACertificate: *caCert.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a caCertificate by cert or ID. +func (k *CACertificatesCollection) Get(certOrID string) (*CACertificate, error) { + if certOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + return getCACert(txn, certOrID) +} + +// Update udpates an existing caCert. +// It returns an error if the caCert is not already present. +func (k *CACertificatesCollection) Update(caCert CACertificate) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(caCert.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteCACert(txn, *caCert.ID) + if err != nil { + return err + } + + err = txn.Insert(caCertTableName, &caCert) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteCACert(txn *memdb.Txn, certOrID string) error { + caCert, err := getCACert(txn, certOrID) + if err != nil { + return err + } + + err = txn.Delete(caCertTableName, caCert) + if err != nil { + return err + } + return nil +} + +// Delete deletes a caCertificate by looking up it's cert and key. +func (k *CACertificatesCollection) Delete(certOrID string) error { + if certOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteCACert(txn, certOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a caCertificate by name or ID. +func (k *CACertificatesCollection) GetAll() ([]*CACertificate, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(caCertTableName, all, true) + if err != nil { + return nil, err + } + + var res []*CACertificate + for el := iter.Next(); el != nil; el = iter.Next() { + c, ok := el.(*CACertificate) + if !ok { + panic(unexpectedType) + } + res = append(res, &CACertificate{CACertificate: *c.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/cacert_test.go b/pkg/state/cacert_test.go new file mode 100644 index 0000000..0b93fde --- /dev/null +++ b/pkg/state/cacert_test.go @@ -0,0 +1,144 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func caCertsCollection() *CACertificatesCollection { + return state().CACertificates +} + +func TestCACertificateInsert(t *testing.T) { + assert := assert.New(t) + collection := caCertsCollection() + + var caCert CACertificate + assert.NotNil(collection.Add(caCert)) + caCert.ID = kong.String("first") + assert.NotNil(collection.Add(caCert)) + caCert.Cert = kong.String("firstCert") + assert.Nil(collection.Add(caCert)) + + // re-inesrt + assert.NotNil(collection.Add(caCert)) +} + +func TestCACertificateGetUpdate(t *testing.T) { + assert := assert.New(t) + collection := caCertsCollection() + + var caCert CACertificate + + assert.NotNil(collection.Update(caCert)) + + caCert.Cert = kong.String("firstCert") + caCert.ID = kong.String("first") + assert.NotNil(collection.Update(caCert)) + + err := collection.Add(caCert) + assert.Nil(err) + + se, err := collection.Get("") + assert.NotNil(err) + assert.Nil(se) + + se, err = collection.Get("firstCert") + assert.Nil(err) + assert.NotNil(se) + se.Cert = kong.String("firstCert-updated") + err = collection.Update(*se) + assert.Nil(err) + + se, err = collection.Get("firstCert-updated") + assert.Nil(err) + assert.NotNil(se) + assert.Equal("firstCert-updated", *se.Cert) + + se, err = collection.Get("not-present") + assert.Equal(ErrNotFound, err) + assert.Nil(se) +} + +func TestCACertInvalidType(t *testing.T) { + assert := assert.New(t) + collection := caCertsCollection() + + var cert Certificate + cert.Cert = kong.String("my-cert") + cert.ID = kong.String("first") + txn := collection.db.Txn(true) + txn.Insert(caCertTableName, &cert) + txn.Commit() + + assert.Panics(func() { + collection.Get("my-cert") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestCACertificateDelete(t *testing.T) { + assert := assert.New(t) + collection := caCertsCollection() + + assert.NotNil(collection.Delete("")) + + var caCert CACertificate + caCert.ID = kong.String("first") + caCert.Cert = kong.String("firstCert") + err := collection.Add(caCert) + assert.Nil(err) + + se, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(se) + assert.Equal("firstCert", *se.Cert) + + err = collection.Delete(*se.ID) + assert.Nil(err) + + err = collection.Delete(*se.ID) + assert.NotNil(err) + + caCert.ID = kong.String("first") + caCert.Cert = kong.String("firstCert") + err = collection.Add(caCert) + assert.Nil(err) + + se, err = collection.Get("first") + assert.Nil(err) + assert.NotNil(se) + assert.Equal("firstCert", *se.Cert) + + err = collection.Delete(*se.Cert) + assert.Nil(err) + + err = collection.Delete(*se.ID) + assert.NotNil(err) +} + +func TestCACertificateGetAll(t *testing.T) { + assert := assert.New(t) + collection := caCertsCollection() + + var caCert CACertificate + caCert.ID = kong.String("first") + caCert.Cert = kong.String("firstCert") + err := collection.Add(caCert) + assert.Nil(err) + + var certificate2 CACertificate + certificate2.ID = kong.String("second") + certificate2.Cert = kong.String("secondCert") + err = collection.Add(certificate2) + assert.Nil(err) + + certificates, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(certificates)) +} diff --git a/pkg/state/certificate.go b/pkg/state/certificate.go new file mode 100644 index 0000000..a2f472d --- /dev/null +++ b/pkg/state/certificate.go @@ -0,0 +1,235 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + certificateTableName = "certificate" +) + +var certificateTableSchema = &memdb.TableSchema{ + Name: certificateTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "certkey": { + Name: "certkey", + Unique: true, + Indexer: &indexers.MD5FieldsIndexer{ + Fields: []string{"Cert", "Key"}, + }, + }, + all: allIndex, + }, +} + +func validateCert(certificate Certificate) error { + if utils.Empty(certificate.Key) { + return fmt.Errorf("certificate's Key cannot be empty") + } + if utils.Empty(certificate.Cert) { + return fmt.Errorf("certificate's Cert cannot be empty") + } + return nil +} + +// CertificatesCollection stores and indexes Kong Certificates. +type CertificatesCollection collection + +// Add adds a certificate to the collection +func (k *CertificatesCollection) Add(certificate Certificate) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(certificate.ID) { + return errIDRequired + } + if err := validateCert(certificate); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + _, err := getCertificate(txn, *certificate.ID) + if err == nil { + return fmt.Errorf("inserting certificate %v: %w", certificate.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(certificateTableName, &certificate) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getCertificate(txn *memdb.Txn, id string) (*Certificate, error) { + res, err := multiIndexLookupUsingTxn(txn, certificateTableName, + []string{"id"}, id) + if err != nil { + return nil, err + } + + c, ok := res.(*Certificate) + if !ok { + panic(unexpectedType) + } + return &Certificate{Certificate: *c.DeepCopy()}, nil +} + +// Get gets a certificate by ID. +func (k *CertificatesCollection) Get(id string) (*Certificate, error) { + if id == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + certificate, err := getCertificate(txn, id) + if err != nil { + return nil, err + } + return certificate, nil +} + +func getCertificateByCertKey(txn *memdb.Txn, cert, key string) (*Certificate, error) { + res, err := txn.First(certificateTableName, "certkey", cert, key) + if err != nil { + return nil, err + } + if res == nil { + return nil, ErrNotFound + } + c, ok := res.(*Certificate) + if !ok { + panic(unexpectedType) + } + return &Certificate{Certificate: *c.DeepCopy()}, nil +} + +// GetByCertKey gets a certificate with +// the same key and cert from the collection. +func (k *CertificatesCollection) GetByCertKey(cert, + key string, +) (*Certificate, error) { + if cert == "" || key == "" { + return nil, fmt.Errorf("cert/key cannot be empty string") + } + + txn := k.db.Txn(false) + defer txn.Abort() + + return getCertificateByCertKey(txn, cert, key) +} + +// Update udpates an existing certificate. +// It returns an error if the certificate is not already present. +func (k *CertificatesCollection) Update(certificate Certificate) error { + if utils.Empty(certificate.ID) { + return errIDRequired + } + if err := validateCert(certificate); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteCertificate(txn, *certificate.ID) + if err != nil { + return err + } + + err = txn.Insert(certificateTableName, &certificate) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteCertificate(txn *memdb.Txn, id string) error { + cert, err := getCertificate(txn, id) + if err != nil { + return err + } + + err = txn.Delete(certificateTableName, cert) + if err != nil { + return err + } + return nil +} + +// Delete deletes a certificate by ID. +func (k *CertificatesCollection) Delete(id string) error { + if id == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteCertificate(txn, id) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// DeleteByCertKey deletes a certificate by looking up it's cert and key. +func (k *CertificatesCollection) DeleteByCertKey(cert, key string) error { + if cert == "" || key == "" { + return fmt.Errorf("cert/key cannot be empty string") + } + + txn := k.db.Txn(true) + defer txn.Abort() + + certificate, err := getCertificateByCertKey(txn, cert, key) + if err != nil { + return err + } + err = deleteCertificate(txn, *certificate.ID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a certificate by name or ID. +func (k *CertificatesCollection) GetAll() ([]*Certificate, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(certificateTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Certificate + for el := iter.Next(); el != nil; el = iter.Next() { + c, ok := el.(*Certificate) + if !ok { + panic(unexpectedType) + } + res = append(res, &Certificate{Certificate: *c.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/certificate_test.go b/pkg/state/certificate_test.go new file mode 100644 index 0000000..fd4fb50 --- /dev/null +++ b/pkg/state/certificate_test.go @@ -0,0 +1,220 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func certsCollection() *CertificatesCollection { + return state().Certificates +} + +func TestCertificateInsert(t *testing.T) { + assert := assert.New(t) + collection := certsCollection() + + var certificate Certificate + assert.NotNil(collection.Add(certificate)) + + certificate.ID = kong.String("first") + assert.NotNil(collection.Add(certificate)) + + certificate.Key = kong.String("firstKey") + assert.NotNil(collection.Add(certificate)) + + certificate.Cert = kong.String("firstCert") + err := collection.Add(certificate) + assert.Nil(err) + + // re-insert + assert.NotNil(collection.Add(certificate)) +} + +func TestCertificateGetUpdate(t *testing.T) { + assert := assert.New(t) + collection := certsCollection() + + var certificate Certificate + certificate.Cert = kong.String("firstCert") + certificate.Key = kong.String("firstKey") + certificate.ID = kong.String("first") + err := collection.Add(certificate) + assert.Nil(err) + + se, err := collection.GetByCertKey("firstCert", "firstKey") + assert.Nil(err) + assert.NotNil(se) + se.ID = nil + assert.NotNil(collection.Update(*se)) + + se.ID = kong.String("first") + se.Key = nil + se.Cert = kong.String("firstCert-updated") + err = collection.Update(*se) + assert.NotNil(err) + + se.Key = kong.String("firstKey-updated") + err = collection.Update(*se) + assert.Nil(err) + + se, err = collection.Get("") + assert.Nil(se) + assert.NotNil(err) + + se, err = collection.GetByCertKey("firstCert-updated", "firstKey-updated") + assert.Nil(err) + assert.NotNil(se) + assert.Equal("firstCert-updated", *se.Cert) + + se, err = collection.GetByCertKey("", "") + assert.NotNil(err) + assert.Nil(se) + + se, err = collection.GetByCertKey("not-present", "firstsdfsdfKey") + assert.Equal(ErrNotFound, err) + assert.Nil(se) +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestCertificateGetMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := certsCollection() + + var cert Certificate + cert.Cert = kong.String("my-cert") + cert.Key = kong.String("my-key") + cert.ID = kong.String("first") + err := collection.Add(cert) + assert.Nil(err) + + c, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(c) + c.Cert = kong.String("my-new-cert") + + c, err = collection.Get("first") + assert.Nil(err) + assert.NotNil(c) + assert.Equal("my-cert", *c.Cert) +} + +func TestCertificatesInvalidType(t *testing.T) { + assert := assert.New(t) + collection := certsCollection() + + var upstream Upstream + upstream.Name = kong.String("my-upstream") + upstream.ID = kong.String("first") + txn := collection.db.Txn(true) + err := txn.Insert(certificateTableName, &upstream) + assert.NotNil(err) + txn.Abort() + + type badCertificate struct { + kong.Certificate + Meta + } + + certificate := badCertificate{ + Certificate: kong.Certificate{ + ID: kong.String("id"), + Cert: kong.String("Cert"), + Key: kong.String("Key"), + }, + } + + txn = collection.db.Txn(true) + err = txn.Insert(certificateTableName, &certificate) + assert.Nil(err) + txn.Commit() + + assert.Panics(func() { + collection.Get("id") + }) + + assert.Panics(func() { + collection.GetByCertKey("Cert", "Key") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestCertificateDelete(t *testing.T) { + assert := assert.New(t) + collection := certsCollection() + + var certificate Certificate + certificate.ID = kong.String("first") + certificate.Cert = kong.String("firstCert") + certificate.Key = kong.String("firstKey") + err := collection.Add(certificate) + assert.Nil(err) + + se, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(se) + assert.Equal("firstCert", *se.Cert) + + err = collection.Delete(*se.ID) + assert.Nil(err) + + err = collection.Delete(*se.ID) + assert.NotNil(err) + + certificate.ID = kong.String("first") + certificate.Cert = kong.String("firstCert") + certificate.Key = kong.String("firstKey") + err = collection.Add(certificate) + assert.Nil(err) + + se, err = collection.Get("first") + assert.Nil(err) + assert.NotNil(se) + assert.Equal("firstCert", *se.Cert) + + assert.NotNil(collection.DeleteByCertKey("", "")) + + assert.NotNil(collection.DeleteByCertKey("foo", "bar")) + + err = collection.DeleteByCertKey(*se.Cert, *se.Key) + assert.Nil(err) + + err = collection.Delete("") + assert.NotNil(err) + + err = collection.Delete(*se.ID) + assert.NotNil(err) + + se, err = collection.Get("first") + assert.NotNil(err) + assert.Nil(se) +} + +func TestCertificateGetAll(t *testing.T) { + assert := assert.New(t) + collection := certsCollection() + + var certificate Certificate + certificate.ID = kong.String("first") + certificate.Cert = kong.String("firstCert") + certificate.Key = kong.String("firstKey") + err := collection.Add(certificate) + assert.Nil(err) + + var certificate2 Certificate + certificate2.ID = kong.String("second") + certificate2.Cert = kong.String("secondCert") + certificate2.Key = kong.String("secondKey") + err = collection.Add(certificate2) + assert.Nil(err) + + certificates, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(certificates)) +} diff --git a/pkg/state/consumer.go b/pkg/state/consumer.go new file mode 100644 index 0000000..ac956cd --- /dev/null +++ b/pkg/state/consumer.go @@ -0,0 +1,204 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/utils" +) + +const ( + consumerTableName = "consumer" +) + +var consumerTableSchema = &memdb.TableSchema{ + Name: consumerTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "Username": { + Name: "Username", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Username"}, + AllowMissing: true, + }, + "CustomID": { + Name: "CustomID", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "CustomID"}, + AllowMissing: true, + }, + all: allIndex, + }, +} + +// ConsumersCollection stores and indexes Kong Consumers. +type ConsumersCollection collection + +// Add adds a consumer to the collection +// An error is thrown if consumer.ID is empty. +func (k *ConsumersCollection) Add(consumer Consumer) error { + if utils.Empty(consumer.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *consumer.ID) + if !utils.Empty(consumer.Username) { + searchBy = append(searchBy, *consumer.Username) + } + + // search separately by id+username and by custom_id. + // + // This is because the custom_id is unique, but it may be equal to + // the username of another consumer. If we search by both id+username and + // custom_id, we may get a false positive. + _, err := getConsumer(txn, []string{"Username", "id"}, searchBy...) + if err == nil { + return fmt.Errorf("inserting consumer by username %v: %w", consumer.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + if !utils.Empty(consumer.CustomID) { + searchBy = []string{*consumer.CustomID} + _, err = getConsumer(txn, []string{"CustomID"}, searchBy...) + if err == nil { + return fmt.Errorf("inserting consumer by custom_id %v: %w", consumer.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + } + + err = txn.Insert(consumerTableName, &consumer) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getConsumer(txn *memdb.Txn, indexes []string, IDs ...string) (*Consumer, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, consumerTableName, indexes, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + consumer, ok := res.(*Consumer) + if !ok { + panic(unexpectedType) + } + return &Consumer{Consumer: *consumer.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// GetByIDOrUsername gets a consumer by name or ID. +func (k *ConsumersCollection) GetByIDOrUsername(userNameOrID string) (*Consumer, error) { + if userNameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + return getConsumer(txn, []string{"Username", "id"}, userNameOrID) +} + +// GetByCustomID gets a consumer by customID. +func (k *ConsumersCollection) GetByCustomID(customID string) (*Consumer, error) { + if customID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + return getConsumer(txn, []string{"CustomID"}, customID) +} + +// Update udpates an existing consumer. +// It returns an error if the consumer is not already present. +func (k *ConsumersCollection) Update(consumer Consumer) error { + // TODO abstract this in the go-memdb library itself + if utils.Empty(consumer.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteConsumer(txn, *consumer.ID) + if err != nil { + return err + } + + err = txn.Insert(consumerTableName, &consumer) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteConsumer(txn *memdb.Txn, userNameOrID string) error { + consumer, err := getConsumer(txn, []string{"Username", "id"}, userNameOrID) + if err != nil { + return err + } + + err = txn.Delete(consumerTableName, consumer) + if err != nil { + return err + } + return nil +} + +// Delete deletes a consumer by name or ID. +func (k *ConsumersCollection) Delete(userNameOrID string) error { + if userNameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteConsumer(txn, userNameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a consumer by name or ID. +func (k *ConsumersCollection) GetAll() ([]*Consumer, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(consumerTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Consumer + for el := iter.Next(); el != nil; el = iter.Next() { + s, ok := el.(*Consumer) + if !ok { + panic(unexpectedType) + } + res = append(res, &Consumer{Consumer: *s.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/consumer_group.go b/pkg/state/consumer_group.go new file mode 100644 index 0000000..7b1c6ff --- /dev/null +++ b/pkg/state/consumer_group.go @@ -0,0 +1,176 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/utils" +) + +const ( + consumerGroupTableName = "consumerGroup" +) + +var consumerGroupTableSchema = &memdb.TableSchema{ + Name: consumerGroupTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + }, + all: allIndex, + }, +} + +// consumerGroupsCollection stores and indexes Kong consumerGroups. +type ConsumerGroupsCollection collection + +// Add adds an consumerGroup to the collection. +// consumerGroup.ID should not be nil else an error is thrown. +func (k *ConsumerGroupsCollection) Add(consumerGroup ConsumerGroup) error { + if utils.Empty(consumerGroup.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *consumerGroup.ID) + if !utils.Empty(consumerGroup.Name) { + searchBy = append(searchBy, *consumerGroup.Name) + } + _, err := getConsumerGroup(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting consumerGroup %v: %w", consumerGroup.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(consumerGroupTableName, &consumerGroup) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getConsumerGroup(txn *memdb.Txn, IDs ...string) (*ConsumerGroup, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, consumerGroupTableName, + []string{"name", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + + consumerGroup, ok := res.(*ConsumerGroup) + if !ok { + panic(unexpectedType) + } + return &ConsumerGroup{ConsumerGroup: *consumerGroup.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets an consumerGroup by name or ID. +func (k *ConsumerGroupsCollection) Get(nameOrID string) (*ConsumerGroup, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + consumerGroup, err := getConsumerGroup(txn, nameOrID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return consumerGroup, nil +} + +// Update updates an existing consumerGroup. +func (k *ConsumerGroupsCollection) Update(consumerGroup ConsumerGroup) error { + if utils.Empty(consumerGroup.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteConsumerGroup(txn, *consumerGroup.ID) + if err != nil { + return err + } + + err = txn.Insert(consumerGroupTableName, &consumerGroup) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteConsumerGroup(txn *memdb.Txn, nameOrID string) error { + consumerGroup, err := getConsumerGroup(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(consumerGroupTableName, consumerGroup) + if err != nil { + return err + } + return nil +} + +// Delete deletes an consumerGroup by its name or ID. +func (k *ConsumerGroupsCollection) Delete(nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteConsumerGroup(txn, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all consumerGroups in the state. +func (k *ConsumerGroupsCollection) GetAll() ([]*ConsumerGroup, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(consumerGroupTableName, all, true) + if err != nil { + return nil, err + } + + var res []*ConsumerGroup + for el := iter.Next(); el != nil; el = iter.Next() { + u, ok := el.(*ConsumerGroup) + if !ok { + panic(unexpectedType) + } + res = append(res, &ConsumerGroup{ConsumerGroup: *u.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/consumer_group_consumers.go b/pkg/state/consumer_group_consumers.go new file mode 100644 index 0000000..c05adea --- /dev/null +++ b/pkg/state/consumer_group_consumers.go @@ -0,0 +1,244 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + consumerGroupConsumerTableName = "consumerGroupConsumer" + consumerByGroupID = "consumerByGroupID" +) + +var errInvalidConsumerGroup = fmt.Errorf("consumer_group.ID is required in consumer group consumers") + +var consumerGroupConsumerTableSchema = &memdb.TableSchema{ + Name: consumerGroupConsumerTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Consumer", + Sub: "ID", + }, + { + Struct: "ConsumerGroup", + Sub: "ID", + }, + }, + }, + }, + "username": { + Name: "username", + Unique: true, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Consumer", + Sub: "Username", + }, + }, + }, + }, + all: allIndex, + // foreign + consumerByGroupID: { + Name: consumerByGroupID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "ConsumerGroup", + Sub: "ID", + }, + }, + }, + }, + }, +} + +func validateConsumerGroup(consumer *ConsumerGroupConsumer) error { + if consumer.ConsumerGroup == nil || + utils.Empty(consumer.ConsumerGroup.ID) { + return errInvalidConsumerGroup + } + return nil +} + +// ConsumerGroupConsumersCollection stores and indexes Kong consumerGroupConsumers. +type ConsumerGroupConsumersCollection collection + +// Add adds a consumerGroupConsumer to the collection. +func (k *ConsumerGroupConsumersCollection) Add(consumer ConsumerGroupConsumer) error { + if utils.Empty(consumer.Consumer.ID) { + return errIDRequired + } + + if err := validateConsumerGroup(&consumer); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *consumer.Consumer.ID, *consumer.ConsumerGroup.ID) + if !utils.Empty(consumer.Consumer.Username) { + searchBy = append(searchBy, *consumer.Consumer.Username) + } + _, err := getConsumerGroupConsumer(txn, *consumer.ConsumerGroup.ID, searchBy...) + if err == nil { + return fmt.Errorf("inserting consumerGroupConsumer %v: %w", consumer.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(consumerGroupConsumerTableName, &consumer) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getAllByConsumerGroupID(txn *memdb.Txn, consumerGroupID string) ([]*ConsumerGroupConsumer, error) { + iter, err := txn.Get(consumerGroupConsumerTableName, consumerByGroupID, consumerGroupID) + if err != nil { + return nil, err + } + + var consumers []*ConsumerGroupConsumer + for el := iter.Next(); el != nil; el = iter.Next() { + t, ok := el.(*ConsumerGroupConsumer) + if !ok { + panic(unexpectedType) + } + consumers = append(consumers, &ConsumerGroupConsumer{ConsumerGroupConsumer: *t.DeepCopy()}) + } + return consumers, nil +} + +func getConsumerGroupConsumer(txn *memdb.Txn, consumerGroupID string, IDs ...string) (*ConsumerGroupConsumer, error) { + consumers, err := getAllByConsumerGroupID(txn, consumerGroupID) + if err != nil { + return nil, err + } + + for _, id := range IDs { + for _, consumer := range consumers { + if id == *consumer.Consumer.ID || id == *consumer.Consumer.Username { + return &ConsumerGroupConsumer{ConsumerGroupConsumer: *consumer.DeepCopy()}, nil + } + } + } + return nil, ErrNotFound +} + +// Get gets a consumerGroupConsumer. +func (k *ConsumerGroupConsumersCollection) Get( + nameOrID, consumerGroupID string, +) (*ConsumerGroupConsumer, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + return getConsumerGroupConsumer(txn, consumerGroupID, nameOrID) +} + +// Update udpates an existing consumerGroupConsumer. +func (k *ConsumerGroupConsumersCollection) Update(consumer ConsumerGroupConsumer) error { + if utils.Empty(consumer.Consumer.ID) { + return errIDRequired + } + + if err := validateConsumerGroup(&consumer); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + res, err := txn.First(consumerGroupConsumerTableName, "id", + *consumer.Consumer.ID, *consumer.ConsumerGroup.ID) + if err != nil { + return err + } + + t, ok := res.(*ConsumerGroupConsumer) + if !ok { + panic(unexpectedType) + } + err = txn.Delete(consumerGroupConsumerTableName, *t) + if err != nil { + return err + } + + err = txn.Insert(consumerGroupConsumerTableName, &consumer) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteConsumerGroupConsumer(txn *memdb.Txn, nameOrID, consumerGroupID string) error { + consumer, err := getConsumerGroupConsumer(txn, consumerGroupID, nameOrID) + if err != nil { + return err + } + err = txn.Delete(consumerGroupConsumerTableName, consumer) + if err != nil { + return err + } + return nil +} + +// Delete deletes a consumerGroupConsumer by its username or ID. +func (k *ConsumerGroupConsumersCollection) Delete(nameOrID, consumerGroupID string) error { + if nameOrID == "" { + return errIDRequired + } + + if consumerGroupID == "" { + return errInvalidConsumerGroup + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteConsumerGroupConsumer(txn, nameOrID, consumerGroupID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all consumerGroupConsumers in the state. +func (k *ConsumerGroupConsumersCollection) GetAll() ([]*ConsumerGroupConsumer, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(consumerGroupConsumerTableName, all, true) + if err != nil { + return nil, err + } + + var res []*ConsumerGroupConsumer + for el := iter.Next(); el != nil; el = iter.Next() { + u, ok := el.(*ConsumerGroupConsumer) + if !ok { + panic(unexpectedType) + } + res = append(res, &ConsumerGroupConsumer{ConsumerGroupConsumer: *u.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/consumer_group_plugin.go b/pkg/state/consumer_group_plugin.go new file mode 100644 index 0000000..f313219 --- /dev/null +++ b/pkg/state/consumer_group_plugin.go @@ -0,0 +1,242 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + consumerGroupPluginTableName = "consumerGroupPlugin" + pluginByGroupID = "pluginByGroupID" +) + +var consumerGroupPluginTableSchema = &memdb.TableSchema{ + Name: consumerGroupPluginTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "ConsumerGroupPlugin", + Sub: "ID", + }, + { + Struct: "ConsumerGroup", + Sub: "ID", + }, + }, + }, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "ConsumerGroupPlugin", + Sub: "Name", + }, + }, + }, + }, + all: allIndex, + // foreign + pluginByGroupID: { + Name: pluginByGroupID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "ConsumerGroup", + Sub: "ID", + }, + }, + }, + }, + }, +} + +func validatePluginGroup(plugin *ConsumerGroupPlugin) error { + if plugin.ConsumerGroup == nil || + utils.Empty(plugin.ConsumerGroup.ID) { + return errInvalidConsumerGroup + } + return nil +} + +// ConsumerGroupPluginsCollection stores and indexes Kong consumerGroupPlugins. +type ConsumerGroupPluginsCollection collection + +// Add adds a consumerGroupPlugin to the collection. +func (k *ConsumerGroupPluginsCollection) Add(plugin ConsumerGroupPlugin) error { + var nameOrID string + if plugin.ConsumerGroupPlugin.ID != nil { + nameOrID = *plugin.ConsumerGroupPlugin.ID + } else { + nameOrID = *plugin.ConsumerGroupPlugin.Name + } + + if err := validatePluginGroup(&plugin); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, nameOrID, *plugin.ConsumerGroup.ID) + _, err := getConsumerGroupPlugin(txn, *plugin.ConsumerGroup.ID, searchBy...) + if err == nil { + return fmt.Errorf("inserting consumerGroupPlugin %v: %w", plugin.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(consumerGroupPluginTableName, &plugin) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getAllPluginsByConsumerGroupID(txn *memdb.Txn, consumerGroupID string) ([]*ConsumerGroupPlugin, error) { + iter, err := txn.Get(consumerGroupPluginTableName, pluginByGroupID, consumerGroupID) + if err != nil { + return nil, err + } + + var plugins []*ConsumerGroupPlugin + for el := iter.Next(); el != nil; el = iter.Next() { + t, ok := el.(*ConsumerGroupPlugin) + if !ok { + panic(unexpectedType) + } + plugins = append(plugins, &ConsumerGroupPlugin{ConsumerGroupPlugin: *t.DeepCopy()}) + } + return plugins, nil +} + +func getConsumerGroupPlugin(txn *memdb.Txn, consumerGroupID string, IDs ...string) (*ConsumerGroupPlugin, error) { + plugins, err := getAllPluginsByConsumerGroupID(txn, consumerGroupID) + if err != nil { + return nil, err + } + + for _, id := range IDs { + for _, plugin := range plugins { + if id == *plugin.ID || id == *plugin.Name { + return &ConsumerGroupPlugin{ConsumerGroupPlugin: *plugin.DeepCopy()}, nil + } + } + } + return nil, ErrNotFound +} + +// Get gets a consumerGroupPlugin. +func (k *ConsumerGroupPluginsCollection) Get( + nameOrID, consumerGroupID string, +) (*ConsumerGroupPlugin, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + return getConsumerGroupPlugin(txn, consumerGroupID, nameOrID) +} + +// Update udpates an existing consumerGroupPlugin. +func (k *ConsumerGroupPluginsCollection) Update(plugin ConsumerGroupPlugin) error { + if utils.Empty(plugin.ConsumerGroupPlugin.ID) { + return errIDRequired + } + + if err := validatePluginGroup(&plugin); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + res, err := txn.First(consumerGroupPluginTableName, "id", + *plugin.ConsumerGroupPlugin.ID, *plugin.ConsumerGroup.ID) + if err != nil { + return err + } + + t, ok := res.(*ConsumerGroupPlugin) + if !ok { + panic(unexpectedType) + } + err = txn.Delete(consumerGroupPluginTableName, *t) + if err != nil { + return err + } + + err = txn.Insert(consumerGroupPluginTableName, &plugin) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteConsumerGroupPlugin(txn *memdb.Txn, nameOrID, consumerGroupID string) error { + consumer, err := getConsumerGroupPlugin(txn, consumerGroupID, nameOrID) + if err != nil { + return err + } + err = txn.Delete(consumerGroupPluginTableName, consumer) + if err != nil { + return err + } + return nil +} + +// Delete deletes a consumerGroupPlugin by its username or ID. +func (k *ConsumerGroupPluginsCollection) Delete(nameOrID, consumerGroupID string) error { + if nameOrID == "" { + return errIDRequired + } + + if consumerGroupID == "" { + return errInvalidConsumerGroup + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteConsumerGroupPlugin(txn, nameOrID, consumerGroupID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all consumerGroupPlugins in the state. +func (k *ConsumerGroupPluginsCollection) GetAll() ([]*ConsumerGroupPlugin, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(consumerGroupPluginTableName, all, true) + if err != nil { + return nil, err + } + + var res []*ConsumerGroupPlugin + for el := iter.Next(); el != nil; el = iter.Next() { + u, ok := el.(*ConsumerGroupPlugin) + if !ok { + panic(unexpectedType) + } + res = append(res, &ConsumerGroupPlugin{ConsumerGroupPlugin: *u.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/consumer_test.go b/pkg/state/consumer_test.go new file mode 100644 index 0000000..35ecf53 --- /dev/null +++ b/pkg/state/consumer_test.go @@ -0,0 +1,161 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func consumersCollection() *ConsumersCollection { + return state().Consumers +} + +func TestConsumerInsert(t *testing.T) { + assert := assert.New(t) + collection := consumersCollection() + + var consumer Consumer + + assert.NotNil(collection.Add(consumer)) + + consumer.ID = kong.String("first") + assert.Nil(collection.Add(consumer)) + + // re-insert + consumer.Username = kong.String("my-name") + assert.NotNil(collection.Add(consumer)) +} + +func TestConsumerGetUpdate(t *testing.T) { + assert := assert.New(t) + collection := consumersCollection() + + var consumer Consumer + consumer.ID = kong.String("first") + consumer.Username = kong.String("my-name") + err := collection.Add(consumer) + assert.Nil(err) + + c, err := collection.GetByIDOrUsername("") + assert.NotNil(err) + assert.Nil(c) + + c, err = collection.GetByIDOrUsername("first") + assert.Nil(err) + assert.NotNil(c) + + c.ID = nil + c.Username = kong.String("my-updated-name") + err = collection.Update(*c) + assert.NotNil(err) + + c.ID = kong.String("does-not-exist") + assert.NotNil(collection.Update(*c)) + + c.ID = kong.String("first") + assert.Nil(collection.Update(*c)) + + c, err = collection.GetByIDOrUsername("my-name") + assert.NotNil(err) + assert.Nil(c) + + c, err = collection.GetByIDOrUsername("my-updated-name") + assert.Nil(err) + assert.NotNil(c) +} + +// Test to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestConsumerGetMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := consumersCollection() + + var consumer Consumer + consumer.ID = kong.String("first") + consumer.Username = kong.String("my-name") + err := collection.Add(consumer) + assert.Nil(err) + + c, err := collection.GetByIDOrUsername("first") + assert.Nil(err) + assert.NotNil(c) + c.Username = kong.String("update-should-not-reflect") + + c, err = collection.GetByIDOrUsername("first") + assert.Nil(err) + assert.Equal("my-name", *c.Username) +} + +func TestConsumersInvalidType(t *testing.T) { + assert := assert.New(t) + collection := consumersCollection() + + type c2 Consumer + var c c2 + c.Username = kong.String("my-name") + c.ID = kong.String("first") + txn := collection.db.Txn(true) + assert.Nil(txn.Insert(consumerTableName, &c)) + txn.Commit() + + assert.Panics(func() { + collection.GetByIDOrUsername("my-name") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestConsumerDelete(t *testing.T) { + assert := assert.New(t) + collection := consumersCollection() + + var consumer Consumer + consumer.ID = kong.String("first") + consumer.Username = kong.String("my-consumer") + err := collection.Add(consumer) + assert.Nil(err) + + c, err := collection.GetByIDOrUsername("my-consumer") + assert.Nil(err) + assert.NotNil(c) + assert.Equal("first", *c.ID) + + err = collection.Delete("first") + assert.Nil(err) + + err = collection.Delete("") + assert.NotNil(err) + + err = collection.Delete(*c.ID) + assert.NotNil(err) +} + +func TestConsumerGetAll(t *testing.T) { + assert := assert.New(t) + collection := consumersCollection() + + consumers := []Consumer{ + { + Consumer: kong.Consumer{ + ID: kong.String("first"), + Username: kong.String("my-consumer1"), + }, + }, + { + Consumer: kong.Consumer{ + ID: kong.String("second"), + Username: kong.String("my-consumer2"), + }, + }, + } + for _, s := range consumers { + assert.Nil(collection.Add(s)) + } + + allConsumers, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(len(consumers), len(allConsumers)) +} diff --git a/pkg/state/credentials.go b/pkg/state/credentials.go new file mode 100644 index 0000000..ec5d534 --- /dev/null +++ b/pkg/state/credentials.go @@ -0,0 +1,201 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" +) + +const ( + byConsumerID = "byConsumerID" +) + +// credentialsCollection stores and indexes key-auth credentials. +type credentialsCollection struct { + collection + CredType string +} + +func (k *credentialsCollection) TableName() string { + return k.CredType +} + +func (k *credentialsCollection) Schema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: k.CredType, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + byConsumerID: { + Name: byConsumerID, + Indexer: &indexers.MethodIndexer{ + Method: "GetConsumer", + }, + }, + "id2": { + Name: "id2", + Unique: true, + Indexer: &indexers.MethodIndexer{ + Method: "GetID2", + }, + }, + all: allIndex, + }, + } +} + +func (k *credentialsCollection) getCred(txn *memdb.Txn, IDs ...string) (entity, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, k.CredType, + []string{"id", "id2"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + cred, ok := res.(entity) + if !ok { + panic(unexpectedType) + } + return cred, nil + } + return nil, ErrNotFound +} + +// Add adds a key-auth credential to credentialsCollection. +func (k *credentialsCollection) Add(cred entity) error { + if cred.GetID() == "" { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + // TODO detect unique constraint violation for ID2 + + _, err := k.getCred(txn, cred.GetID(), cred.GetID2()) + if err == nil { + return fmt.Errorf("inserting credential %v: %w", cred.GetID(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(k.CredType, cred) + if err != nil { + return err + } + txn.Commit() + return nil +} + +// Get gets a credential by ID or endpoint key. +func (k *credentialsCollection) Get(id string) (entity, error) { + if id == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + return k.getCred(txn, id) +} + +// Update updates an existing key-auth credential. +func (k *credentialsCollection) Update(cred entity) error { + // TODO abstract this check in the go-memdb library itself + if cred.GetID() == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := k.deleteCred(txn, cred.GetID()) + if err != nil { + return err + } + err = txn.Insert(k.CredType, cred) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func (k *credentialsCollection) deleteCred(txn *memdb.Txn, nameOrID string) error { + cred, err := k.getCred(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(k.CredType, cred) + if err != nil { + return err + } + return nil +} + +// Delete deletes a key-auth credential by key or ID. +func (k *credentialsCollection) Delete(id string) error { + if id == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := k.deleteCred(txn, id) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all key-auth credentials. +func (k *credentialsCollection) GetAll() ([]entity, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(k.CredType, all, true) + if err != nil { + return nil, err + } + + var res []entity + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(entity) + if !ok { + panic(unexpectedType) + } + res = append(res, r) + } + return res, nil +} + +// GetAllByConsumerID returns all key-auth credentials +// belong to a Consumer with id. +func (k *credentialsCollection) GetAllByConsumerID(id string) ([]entity, + error, +) { + txn := k.db.Txn(false) + iter, err := txn.Get(k.CredType, byConsumerID, id) + if err != nil { + return nil, err + } + var res []entity + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(entity) + if !ok { + panic(unexpectedType) + } + res = append(res, r) + } + return res, nil +} diff --git a/pkg/state/document.go b/pkg/state/document.go new file mode 100644 index 0000000..d5eb1c9 --- /dev/null +++ b/pkg/state/document.go @@ -0,0 +1,219 @@ +package state + +import ( + "errors" + "fmt" + + "github.com/hashicorp/go-memdb" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + documentTableName = "document" + documentsByParent = "documentsByParent" +) + +var ( + errDocumentMissingParent = fmt.Errorf("Document has no Parent") + errDocumentPathRequired = fmt.Errorf("Document must have a Path") +) + +// DocumentsCollection stores and indexes key-auth credentials. +type DocumentsCollection collection + +var documentTableSchema = &memdb.TableSchema{ + Name: documentTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + all: allIndex, + // foreign + documentsByParent: { + Name: documentsByParent, + Indexer: &indexers.MethodIndexer{ + Method: "ParentKey", + }, + }, + }, +} + +// Add adds a document into DocumentsCollection +// document.ID should not be nil else an error is thrown. +func (k *DocumentsCollection) Add(document Document) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(document.ID) { + return errIDRequired + } + + if utils.Empty(document.Path) { + return errDocumentPathRequired + } + + if document.Parent == nil { + return errDocumentMissingParent + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *document.ID) + searchBy = append(searchBy, *document.Path) + _, err := getDocument(txn, document.ParentKey(), searchBy...) + if err == nil { + return fmt.Errorf("inserting document %v: %w", document.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(documentTableName, &document) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getDocument(txn *memdb.Txn, parentKey string, IDs ...string) (*Document, error) { + if parentKey == "" { + return nil, fmt.Errorf("parentKey is required") + } + documents, err := getAllDocsByParentKey(txn, parentKey) + if err != nil { + return nil, err + } + + for _, id := range IDs { + for _, document := range documents { + if id == *document.ID || id == *document.Path { + return &Document{Document: *document.ShallowCopy()}, nil + } + } + } + return nil, ErrNotFound +} + +func getAllDocsByParentKey(txn *memdb.Txn, parentKey string) ([]*Document, error) { + iter, err := txn.Get(documentTableName, documentsByParent, parentKey) + if err != nil { + return nil, err + } + + var documents []*Document + for el := iter.Next(); el != nil; el = iter.Next() { + d, ok := el.(*Document) + if !ok { + panic(unexpectedType) + } + documents = append(documents, &Document{Document: *d.ShallowCopy()}) + } + return documents, nil +} + +// Update updates a Document +func (k *DocumentsCollection) Update(document Document) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(document.ID) { + return errIDRequired + } + + if document.Parent == nil { + return errDocumentMissingParent + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteDocument(txn, document.ParentKey(), documentsByParent, *document.ID) + if err != nil { + return err + } + + err = txn.Insert(documentTableName, &document) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteDocument(txn *memdb.Txn, key string, index string, pathOrID string) error { + document, err := getDocument(txn, key, index, pathOrID) + if err != nil { + return err + } + + err = txn.Delete(documentTableName, document) + if err != nil { + return err + } + return nil +} + +// DeleteByParent deletes a Document by parent and path or ID. +func (k *DocumentsCollection) DeleteByParent(parent konnect.ParentInfoer, pathOrID string) error { + if pathOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteDocument(txn, parent.Key(), documentsByParent, pathOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all Documents. +func (k *DocumentsCollection) GetAll() ([]*Document, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(documentTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Document + for el := iter.Next(); el != nil; el = iter.Next() { + d, ok := el.(*Document) + if !ok { + panic(unexpectedType) + } + res = append(res, &Document{Document: *d.ShallowCopy()}) + } + txn.Commit() + return res, nil +} + +// GetAllByParent returns all documents for a Parent +func (k *DocumentsCollection) GetAllByParent(parent konnect.ParentInfoer) ([]*Document, error) { + if parent == nil { + return make([]*Document, 0), errDocumentMissingParent + } + txn := k.db.Txn(false) + return getAllDocsByParentKey(txn, parent.Key()) +} + +// GetByParent returns a document attached to a Parent with a given path or ID +func (k *DocumentsCollection) GetByParent(parent konnect.ParentInfoer, pathOrID string) (*Document, error) { + if parent == nil { + return nil, errDocumentMissingParent + } + txn := k.db.Txn(false) + document, err := getDocument(txn, parent.Key(), documentsByParent, pathOrID) + if err != nil { + return nil, err + } + return document, nil +} diff --git a/pkg/state/document_test.go b/pkg/state/document_test.go new file mode 100644 index 0000000..04ea0e8 --- /dev/null +++ b/pkg/state/document_test.go @@ -0,0 +1,438 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/deck/konnect" + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func documentCollection() *DocumentsCollection { + return state().Documents +} + +func TestDocumentCollection_Add(t *testing.T) { + type args struct { + document Document + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + document: Document{ + Document: konnect.Document{ + Path: kong.String("foo"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("abc"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "errors without a path", + args: args{ + document: Document{ + Document: konnect.Document{ + ID: kong.String("id1"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("abc"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "errors without a parent", + args: args{ + document: Document{ + Document: konnect.Document{ + ID: kong.String("id1"), + Path: kong.String("foo"), + }, + }, + }, + wantErr: true, + }, + { + name: "works with ServiceVersion parent", + args: args{ + document: Document{ + Document: konnect.Document{ + ID: kong.String("id2"), + Path: kong.String("bar"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("whatever"), + Version: kong.String("abc"), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "works with ServicePackage parent", + args: args{ + document: Document{ + Document: konnect.Document{ + ID: kong.String("id3"), + Path: kong.String("bar"), + Parent: &konnect.ServicePackage{ + ID: kong.String("whatever"), + Name: kong.String("abc"), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert when id is present", + args: args{ + document: Document{ + Document: konnect.Document{ + ID: kong.String("id4"), + Path: kong.String("abc"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + }, + wantErr: true, + }, + } + k := documentCollection() + d1 := Document{ + Document: konnect.Document{ + ID: kong.String("id4"), + Path: kong.String("abc"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + k.Add(d1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.document); (err != nil) != tt.wantErr { + t.Errorf("DocumentCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDocumentCollection_GetByParent(t *testing.T) { + type args struct { + pathOrID string + parent konnect.ParentInfoer + } + d1 := Document{ + Document: konnect.Document{ + ID: kong.String("foo-id"), + Path: kong.String("path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + d2 := Document{ + Document: konnect.Document{ + ID: kong.String("bar-id"), + Path: kong.String("path"), + Parent: &konnect.ServiceVersion{ + ID: kong.String("id2"), + }, + }, + } + tests := []struct { + name string + args args + want *Document + wantErr bool + }{ + { + name: "gets a document by parent and ID", + args: args{ + pathOrID: "foo-id", + parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + want: &d1, + wantErr: false, + }, + { + name: "gets a document by parent and path", + args: args{ + pathOrID: "path", + parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + want: &d1, + wantErr: false, + }, + { + name: "returns an error when parent missing", + args: args{ + pathOrID: "bar-name", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an ErrNotFound when no document found", + args: args{ + pathOrID: "baz-id", + parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + pathOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := documentCollection() + k.Add(d1) + k.Add(d2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := k.GetByParent(tt.args.parent, tt.args.pathOrID) + if (err != nil) != tt.wantErr { + t.Errorf("DocumentCollection.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DocumentCollection.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDocumentCollection_Update(t *testing.T) { + d1 := Document{ + Document: konnect.Document{ + ID: kong.String("foo-id"), + Path: kong.String("foo-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + d2 := Document{ + Document: konnect.Document{ + ID: kong.String("bar-id"), + Path: kong.String("bar-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + d3 := Document{ + Document: konnect.Document{ + ID: kong.String("foo-id"), + Path: kong.String("new-foo-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + type args struct { + document Document + } + tests := []struct { + name string + args args + wantErr bool + updatedDocument *Document + }{ + { + name: "update errors if document.ID is nil", + args: args{ + document: Document{ + Document: konnect.Document{ + Path: kong.String("name"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "update errors if document does not exist", + args: args{ + document: Document{ + Document: konnect.Document{ + ID: kong.String("does-not-exist"), + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + document: d3, + }, + wantErr: false, + updatedDocument: &d3, + }, + } + k := documentCollection() + k.Add(d1) + k.Add(d2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.document); (err != nil) != tt.wantErr { + t.Errorf("DocumentCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.GetByParent(tt.updatedDocument.Parent, *tt.updatedDocument.ID) + + if !reflect.DeepEqual(got, tt.updatedDocument) { + t.Errorf("update document, got = %#v, want %#v", got, tt.updatedDocument) + } + } + }) + } +} + +func TestDocumentDeleteByParent(t *testing.T) { + assert := assert.New(t) + collection := documentCollection() + + var document Document + document.Path = kong.String("my-document") + document.ID = kong.String("first") + document.Parent = &konnect.ServicePackage{ + ID: kong.String("package-id1"), + } + err := collection.Add(document) + assert.Nil(err) + + re, err := collection.GetByParent(document.Parent, "my-document") + assert.Nil(err) + assert.NotNil(re) + + err = collection.DeleteByParent(document.Parent, *re.ID) + assert.Nil(err) + + err = collection.DeleteByParent(document.Parent, *re.ID) + assert.NotNil(err) +} + +func TestDocumentGetAll(t *testing.T) { + assert := assert.New(t) + collection := documentCollection() + + var d1 Document + d1.Path = kong.String("my-d1") + d1.ID = kong.String("first") + d1.Parent = &konnect.ServicePackage{ + ID: kong.String("id1"), + } + err := collection.Add(d1) + assert.Nil(err) + + var d2 Document + d2.Path = kong.String("my-d2") + d2.ID = kong.String("second") + d2.Parent = &konnect.ServicePackage{ + ID: kong.String("id1"), + } + err = collection.Add(d2) + assert.Nil(err) + + documents, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(documents)) +} + +func TestDocumentGetAllByParent(t *testing.T) { + assert := assert.New(t) + collection := documentCollection() + + documents := []*Document{ + { + Document: konnect.Document{ + ID: kong.String("d1-id"), + Path: kong.String("d1-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + { + Document: konnect.Document{ + ID: kong.String("d2-id"), + Path: kong.String("d2-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + { + Document: konnect.Document{ + ID: kong.String("d3-id"), + Path: kong.String("d3-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id2"), + }, + }, + }, + { + Document: konnect.Document{ + ID: kong.String("d4-id"), + Path: kong.String("d4-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id2"), + }, + }, + }, + { + Document: konnect.Document{ + ID: kong.String("d5-id"), + Path: kong.String("d5-path"), + Parent: &konnect.ServicePackage{ + ID: kong.String("id2"), + }, + }, + }, + } + + for _, document := range documents { + err := collection.Add(*document) + assert.Nil(err) + } + + documents, err := collection.GetAllByParent(&konnect.ServicePackage{ID: kong.String("id1")}) + assert.Nil(err) + assert.Equal(2, len(documents)) + + documents, err = collection.GetAllByParent(&konnect.ServicePackage{ID: kong.String("id2")}) + assert.Nil(err) + assert.Equal(3, len(documents)) +} diff --git a/pkg/state/hmacauth.go b/pkg/state/hmacauth.go new file mode 100644 index 0000000..48a01cf --- /dev/null +++ b/pkg/state/hmacauth.go @@ -0,0 +1,85 @@ +package state + +// HMACAuthsCollection stores and indexes hmac-auth credentials. +type HMACAuthsCollection struct { + credentialsCollection +} + +func newHMACAuthsCollection(common collection) *HMACAuthsCollection { + return &HMACAuthsCollection{ + credentialsCollection: credentialsCollection{ + collection: common, + CredType: "hmac-auth", + }, + } +} + +// Add adds a hmac-auth credential to HMACAuthsCollection +func (k *HMACAuthsCollection) Add(hmacAuth HMACAuth) error { + cred := (entity)(&hmacAuth) + return k.credentialsCollection.Add(cred) +} + +// Get gets a hmac-auth credential by key or ID. +func (k *HMACAuthsCollection) Get(keyOrID string) (*HMACAuth, error) { + cred, err := k.credentialsCollection.Get(keyOrID) + if err != nil { + return nil, err + } + + hmacAuth, ok := cred.(*HMACAuth) + if !ok { + panic(unexpectedType) + } + return &HMACAuth{HMACAuth: *hmacAuth.DeepCopy()}, nil +} + +// GetAllByConsumerID returns all hmac-auth credentials +// belong to a Consumer with id. +func (k *HMACAuthsCollection) GetAllByConsumerID(id string) ([]*HMACAuth, + error, +) { + creds, err := k.credentialsCollection.GetAllByConsumerID(id) + if err != nil { + return nil, err + } + + var res []*HMACAuth + for _, cred := range creds { + r, ok := cred.(*HMACAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &HMACAuth{HMACAuth: *r.DeepCopy()}) + } + return res, nil +} + +// Update updates an existing hmac-auth credential. +func (k *HMACAuthsCollection) Update(hmacAuth HMACAuth) error { + cred := (entity)(&hmacAuth) + return k.credentialsCollection.Update(cred) +} + +// Delete deletes a hmac-auth credential by key or ID. +func (k *HMACAuthsCollection) Delete(keyOrID string) error { + return k.credentialsCollection.Delete(keyOrID) +} + +// GetAll gets all hmac-auth credentials. +func (k *HMACAuthsCollection) GetAll() ([]*HMACAuth, error) { + creds, err := k.credentialsCollection.GetAll() + if err != nil { + return nil, err + } + + var res []*HMACAuth + for _, cred := range creds { + r, ok := cred.(*HMACAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &HMACAuth{HMACAuth: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/hmacauth_test.go b/pkg/state/hmacauth_test.go new file mode 100644 index 0000000..9cbaa30 --- /dev/null +++ b/pkg/state/hmacauth_test.go @@ -0,0 +1,218 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func hmacAuthsCollection() *HMACAuthsCollection { + return state().HMACAuths +} + +func TestHMACAuthInsert(t *testing.T) { + assert := assert.New(t) + collection := hmacAuthsCollection() + + var hmacAuth HMACAuth + hmacAuth.ID = kong.String("first") + err := collection.Add(hmacAuth) + assert.NotNil(err) + + hmacAuth.Username = kong.String("my-username") + err = collection.Add(hmacAuth) + assert.NotNil(err) + + var hmacAuth2 HMACAuth + hmacAuth2.Username = kong.String("my-username") + hmacAuth2.ID = kong.String("first") + hmacAuth2.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + Username: kong.String("my-username"), + } + err = collection.Add(hmacAuth2) + assert.Nil(err) +} + +func TestHMACAuthGet(t *testing.T) { + assert := assert.New(t) + collection := hmacAuthsCollection() + + var hmacAuth HMACAuth + hmacAuth.Username = kong.String("my-username") + hmacAuth.ID = kong.String("first") + hmacAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(hmacAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-username", *res.Username) + + res, err = collection.Get("my-username") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("first", *res.ID) + assert.Equal("consumer1-id", *res.Consumer.ID) + + res, err = collection.Get("does-not-exist") + assert.NotNil(err) + assert.Nil(res) +} + +func TestHMACAuthUpdate(t *testing.T) { + assert := assert.New(t) + collection := hmacAuthsCollection() + + var hmacAuth HMACAuth + hmacAuth.Username = kong.String("my-username") + hmacAuth.ID = kong.String("first") + hmacAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + } + + err := collection.Add(hmacAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-username", *res.Username) + + res.Username = kong.String("my-username2") + res.Secret = kong.String("secret") + err = collection.Update(*res) + assert.Nil(err) + + res, err = collection.Get("my-username") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("my-username2") + assert.Nil(err) + assert.Equal("first", *res.ID) + assert.Equal("secret", *res.Secret) +} + +func TestHMACAuthDelete(t *testing.T) { + assert := assert.New(t) + collection := hmacAuthsCollection() + + var hmacAuth HMACAuth + hmacAuth.Username = kong.String("my-username1") + hmacAuth.ID = kong.String("first") + hmacAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + err := collection.Add(hmacAuth) + assert.Nil(err) + + res, err := collection.Get("my-username1") + assert.Nil(err) + assert.NotNil(res) + + err = collection.Delete(*res.ID) + assert.Nil(err) + + res, err = collection.Get("my-username1") + assert.NotNil(err) + assert.Nil(res) + + // delete a non-existing one + err = collection.Delete("first") + assert.NotNil(err) + + err = collection.Delete("my-username1") + assert.NotNil(err) +} + +func TestHMACAuthGetAll(t *testing.T) { + assert := assert.New(t) + collection := hmacAuthsCollection() + + populateWithHMACAuthFixtures(assert, collection) + + hmacAuths, err := collection.GetAll() + assert.Nil(err) + assert.Equal(5, len(hmacAuths)) +} + +func TestHMACAuthGetByConsumer(t *testing.T) { + assert := assert.New(t) + collection := hmacAuthsCollection() + + populateWithHMACAuthFixtures(assert, collection) + + hmacAuths, err := collection.GetAllByConsumerID("consumer1-id") + assert.Nil(err) + assert.Equal(3, len(hmacAuths)) +} + +func populateWithHMACAuthFixtures(assert *assert.Assertions, + collection *HMACAuthsCollection, +) { + hmacAuths := []HMACAuth{ + { + HMACAuth: kong.HMACAuth{ + Username: kong.String("my-username11"), + ID: kong.String("first"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + HMACAuth: kong.HMACAuth{ + Username: kong.String("my-username12"), + ID: kong.String("second"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + HMACAuth: kong.HMACAuth{ + Username: kong.String("my-username13"), + ID: kong.String("third"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + HMACAuth: kong.HMACAuth{ + Username: kong.String("my-username21"), + ID: kong.String("fourth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + { + HMACAuth: kong.HMACAuth{ + Username: kong.String("my-username22"), + ID: kong.String("fifth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + } + + for _, k := range hmacAuths { + err := collection.Add(k) + assert.Nil(err) + } +} diff --git a/pkg/state/indexers/md5Indexer.go b/pkg/state/indexers/md5Indexer.go new file mode 100644 index 0000000..32d4dc3 --- /dev/null +++ b/pkg/state/indexers/md5Indexer.go @@ -0,0 +1,55 @@ +package indexers + +import ( + "crypto/md5" + "fmt" + "reflect" +) + +// MD5FieldsIndexer is used to create an index based on md5sum of +// string or *string fields. +type MD5FieldsIndexer struct { + // Fields to use for md5sum calculation + Fields []string +} + +// FromObject take Obj and returns index key formed using +// the fields. +func (s *MD5FieldsIndexer) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + blob := "" + for _, field := range s.Fields { + fv := v.FieldByName(field) + fv = reflect.Indirect(fv) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", field, obj) + } + blob += fv.String() + } + if blob == "" { + return false, nil, nil + } + md5Sum := md5.Sum([]byte(blob)) + return true, md5Sum[:], nil +} + +// FromArgs takes in a string and returns its byte form. +func (s *MD5FieldsIndexer) FromArgs(args ...interface{}) ([]byte, error) { + blob := "" + for _, arg := range args { + s, ok := arg.(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", arg) + } + blob += s + } + if blob == "" { + return nil, fmt.Errorf("empty args is not a valid value") + } + // Add the null character as a terminator + md5Sum := md5.Sum([]byte(blob)) + return md5Sum[:], nil +} diff --git a/pkg/state/indexers/md5Indexer_test.go b/pkg/state/indexers/md5Indexer_test.go new file mode 100644 index 0000000..7c03b8c --- /dev/null +++ b/pkg/state/indexers/md5Indexer_test.go @@ -0,0 +1,60 @@ +package indexers + +import ( + "crypto/md5" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMD5FieldsIndexer(t *testing.T) { + assert := assert.New(t) + + type Foo struct { + Bar *string + Baz *string + } + + in := &MD5FieldsIndexer{ + Fields: []string{"Bar", "Baz"}, + } + s1 := "yolo" + s2 := "oloy" + b := Foo{ + Bar: &s1, + Baz: &s2, + } + + ok, val, err := in.FromObject(b) + assert.True(ok) + assert.Nil(err) + sum := md5.Sum([]byte(s1 + s2)) + assert.Equal(sum[:], val) + + val, err = in.FromArgs(s1, s2) + assert.Nil(err) + assert.Equal(sum[:], val) + + ok, val, err = in.FromObject(Foo{}) + assert.False(ok) + assert.NotNil(err) + assert.Empty(val) + + s1 = "" + s2 = "" + ok, val, err = in.FromObject(Foo{ + Bar: &s1, + Baz: &s2, + }) + assert.False(ok) + assert.Nil(err) + assert.Empty(val) + + val, err = in.FromArgs("") + assert.NotNil(err) + assert.Nil(val) + + val, err = in.FromArgs(2) + assert.NotNil(err) + assert.Nil(val) +} diff --git a/pkg/state/indexers/methodIndexer.go b/pkg/state/indexers/methodIndexer.go new file mode 100644 index 0000000..e076d27 --- /dev/null +++ b/pkg/state/indexers/methodIndexer.go @@ -0,0 +1,48 @@ +package indexers + +import ( + "fmt" + "reflect" +) + +// MethodIndexer is used to create an index based on a string returned +// as a result of calling a method on the object. +// It is assumed that the method has no arguments. +type MethodIndexer struct { + // Method name to call to get the string to bulid the index on. + Method string +} + +// FromObject take Obj and returns index key formed using +// the fields. +func (s *MethodIndexer) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + + method := v.MethodByName(s.Method) + resp := method.Call(nil) + if len(resp) != 1 { + return false, nil, fmt.Errorf("function call returned unexpected result") + } + key := resp[0].String() + + if key == "" { + return false, nil, nil + } + return true, []byte(key), nil +} + +// FromArgs takes in a string and returns its byte form. +func (s *MethodIndexer) FromArgs(args ...interface{}) ([]byte, error) { + blob := "" + for _, arg := range args { + s, ok := arg.(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", arg) + } + blob += s + } + if blob == "" { + return nil, fmt.Errorf("empty args is not a valid value") + } + return []byte(blob), nil +} diff --git a/pkg/state/indexers/methodIndexer_test.go b/pkg/state/indexers/methodIndexer_test.go new file mode 100644 index 0000000..3267df5 --- /dev/null +++ b/pkg/state/indexers/methodIndexer_test.go @@ -0,0 +1,71 @@ +package indexers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Foo struct { + id string +} + +func (f Foo) ID() string { + return f.id +} + +func (f Foo) BadID() (string, error) { + return f.id, nil +} + +type ID interface { + ID() string +} + +func TestMethodIndexer(t *testing.T) { + assert := assert.New(t) + + in := &MethodIndexer{ + Method: "ID", + } + b := Foo{ + id: "id1", + } + + ok, val, err := in.FromObject(b) + assert.True(ok) + assert.Nil(err) + assert.Equal([]byte("id1"), val) + + ok, val, err = in.FromObject(Foo{}) + assert.False(ok) + assert.Nil(err) + assert.Empty(val) + + idInterface := (ID)(b) + ok, val, err = in.FromObject(idInterface) + assert.True(ok) + assert.Nil(err) + assert.Equal([]byte("id1"), val) + + val, err = in.FromArgs("id1") + assert.Nil(err) + assert.Equal([]byte("id1"), val) + + val, err = in.FromArgs("") + assert.NotNil(err) + assert.Nil(val) + + val, err = in.FromArgs(42) + assert.NotNil(err) + assert.Nil(val) + + in = &MethodIndexer{ + Method: "BadID", + } + + ok, val, err = in.FromObject(Foo{id: "id1"}) + assert.False(ok) + assert.NotNil(err) + assert.Empty(val) +} diff --git a/pkg/state/indexers/subFieldIndexer.go b/pkg/state/indexers/subFieldIndexer.go new file mode 100644 index 0000000..7c56519 --- /dev/null +++ b/pkg/state/indexers/subFieldIndexer.go @@ -0,0 +1,66 @@ +package indexers + +import ( + "fmt" + "reflect" +) + +// Field represents a field that needs to be used for +// subfield indexing. +type Field struct { + // Struct is the name of the field of the struct + // being indexed. + Struct string + // Sub is the name of the field inside the struct Struct, + // which is being indexed. + Sub string +} + +// SubFieldIndexer is used to extract a field from an object +// using reflection and builds an index on that field. +type SubFieldIndexer struct { + Fields []Field +} + +// FromObject take Obj and returns index key formed using +// the field SubField. +func (s *SubFieldIndexer) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + val := "" + for _, f := range s.Fields { + structV := v.FieldByName(f.Struct) + structV = reflect.Indirect(structV) + if !structV.IsValid() { + continue + } + subField := structV.FieldByName(f.Sub) + subField = reflect.Indirect(subField) + + val += subField.String() + } + + if val == "" { + return false, nil, nil + } + + // Add the null character as a terminator + val += "\x00" + return true, []byte(val), nil +} + +// FromArgs takes in a string and returns its byte form. +func (s *SubFieldIndexer) FromArgs(args ...interface{}) ([]byte, error) { + val := "" + for _, arg := range args { + s, ok := arg.(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + val += s + } + // Add the null character as a terminator + val += "\x00" + return []byte(val), nil +} diff --git a/pkg/state/indexers/subFieldIndexer_test.go b/pkg/state/indexers/subFieldIndexer_test.go new file mode 100644 index 0000000..d869b05 --- /dev/null +++ b/pkg/state/indexers/subFieldIndexer_test.go @@ -0,0 +1,100 @@ +package indexers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSubFieldIndexer(t *testing.T) { + type Foo struct { + Bar *string + } + + type Baz struct { + A *Foo + } + + in := &SubFieldIndexer{ + Fields: []Field{ + { + Struct: "A", + Sub: "Bar", + }, + }, + } + s := "fubar" + b := Baz{ + A: &Foo{ + Bar: &s, + }, + } + + ok, val, err := in.FromObject(b) + assert := assert.New(t) + assert.True(ok) + assert.Nil(err) + assert.Equal("fubar\x00", string(val)) + + ok, val, err = in.FromObject(Baz{}) + assert.False(ok) + assert.Nil(err) + assert.Empty(val) + + s = "" + ok, val, err = in.FromObject(Baz{ + A: &Foo{ + Bar: &s, + }, + }) + assert.False(ok) + assert.Nil(err) + assert.Empty(val) + + val, err = in.FromArgs("fubar") + assert.Nil(err) + assert.Equal("fubar\x00", string(val)) + + val, err = in.FromArgs(2) + assert.Nil(val) + assert.NotNil(err) + + val, err = in.FromArgs("1", "2") + assert.Equal([]byte("12\x00"), val) + assert.Nil(err) +} + +func TestSubFieldIndexerPointer(t *testing.T) { + type Foo struct { + Bar *string + } + + type Baz struct { + A *Foo + } + + in := &SubFieldIndexer{ + Fields: []Field{ + { + Struct: "A", + Sub: "Bar", + }, + }, + } + s := "fubar" + b := Baz{ + A: &Foo{ + Bar: &s, + }, + } + + ok, val, err := in.FromObject(b) + assert := assert.New(t) + assert.True(ok) + assert.Nil(err) + assert.Equal("fubar\x00", string(val)) + + val, err = in.FromArgs("fubar") + assert.Nil(err) + assert.Equal("fubar\x00", string(val)) +} diff --git a/pkg/state/jwtauth.go b/pkg/state/jwtauth.go new file mode 100644 index 0000000..519be03 --- /dev/null +++ b/pkg/state/jwtauth.go @@ -0,0 +1,85 @@ +package state + +// JWTAuthsCollection stores and indexes jwt-auth credentials. +type JWTAuthsCollection struct { + credentialsCollection +} + +func newJWTAuthsCollection(common collection) *JWTAuthsCollection { + return &JWTAuthsCollection{ + credentialsCollection: credentialsCollection{ + collection: common, + CredType: "jwt-auth", + }, + } +} + +// Add adds a jwt-auth credential to JWTAuthsCollection +func (k *JWTAuthsCollection) Add(jwtAuth JWTAuth) error { + cred := (entity)(&jwtAuth) + return k.credentialsCollection.Add(cred) +} + +// Get gets a jwt-auth credential by key or ID. +func (k *JWTAuthsCollection) Get(keyOrID string) (*JWTAuth, error) { + cred, err := k.credentialsCollection.Get(keyOrID) + if err != nil { + return nil, err + } + + jwtAuth, ok := cred.(*JWTAuth) + if !ok { + panic(unexpectedType) + } + return &JWTAuth{JWTAuth: *jwtAuth.DeepCopy()}, nil +} + +// GetAllByConsumerID returns all jwt-auth credentials +// belong to a Consumer with id. +func (k *JWTAuthsCollection) GetAllByConsumerID(id string) ([]*JWTAuth, + error, +) { + creds, err := k.credentialsCollection.GetAllByConsumerID(id) + if err != nil { + return nil, err + } + + var res []*JWTAuth + for _, cred := range creds { + r, ok := cred.(*JWTAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &JWTAuth{JWTAuth: *r.DeepCopy()}) + } + return res, nil +} + +// Update updates an existing jwt-auth credential. +func (k *JWTAuthsCollection) Update(jwtAuth JWTAuth) error { + cred := (entity)(&jwtAuth) + return k.credentialsCollection.Update(cred) +} + +// Delete deletes a jwt-auth credential by key or ID. +func (k *JWTAuthsCollection) Delete(keyOrID string) error { + return k.credentialsCollection.Delete(keyOrID) +} + +// GetAll gets all jwt-auth credentials. +func (k *JWTAuthsCollection) GetAll() ([]*JWTAuth, error) { + creds, err := k.credentialsCollection.GetAll() + if err != nil { + return nil, err + } + + var res []*JWTAuth + for _, cred := range creds { + r, ok := cred.(*JWTAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &JWTAuth{JWTAuth: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/jwtauth_test.go b/pkg/state/jwtauth_test.go new file mode 100644 index 0000000..9f9f89d --- /dev/null +++ b/pkg/state/jwtauth_test.go @@ -0,0 +1,214 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func jwtAuthsCollection() *JWTAuthsCollection { + return state().JWTAuths +} + +func TestJWTAuthInsert(t *testing.T) { + assert := assert.New(t) + collection := jwtAuthsCollection() + + var jwtAuth JWTAuth + jwtAuth.Key = kong.String("my-key") + jwtAuth.ID = kong.String("first") + err := collection.Add(jwtAuth) + assert.NotNil(err) + + var jwtAuth2 JWTAuth + jwtAuth2.Key = kong.String("my-key") + jwtAuth2.ID = kong.String("first") + jwtAuth2.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + Username: kong.String("my-username"), + } + err = collection.Add(jwtAuth2) + assert.Nil(err) +} + +func TestJWTAuthGet(t *testing.T) { + assert := assert.New(t) + collection := jwtAuthsCollection() + + var jwtAuth JWTAuth + jwtAuth.Key = kong.String("my-key") + jwtAuth.ID = kong.String("first") + jwtAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(jwtAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-key", *res.Key) + + res, err = collection.Get("my-key") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("first", *res.ID) + assert.Equal("consumer1-id", *res.Consumer.ID) + + res, err = collection.Get("does-not-exist") + assert.NotNil(err) + assert.Nil(res) +} + +func TestJWTAuthUpdate(t *testing.T) { + assert := assert.New(t) + collection := jwtAuthsCollection() + + var jwtAuth JWTAuth + jwtAuth.Key = kong.String("my-key") + jwtAuth.ID = kong.String("first") + jwtAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(jwtAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-key", *res.Key) + + res.Key = kong.String("my-key2") + err = collection.Update(*res) + assert.Nil(err) + + res, err = collection.Get("my-key") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("my-key2") + assert.Nil(err) + assert.Equal("first", *res.ID) +} + +func TestJWTAuthDelete(t *testing.T) { + assert := assert.New(t) + collection := jwtAuthsCollection() + + var jwtAuth JWTAuth + jwtAuth.Key = kong.String("my-key1") + jwtAuth.ID = kong.String("first") + jwtAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + err := collection.Add(jwtAuth) + assert.Nil(err) + + res, err := collection.Get("my-key1") + assert.Nil(err) + assert.NotNil(res) + + err = collection.Delete(*res.ID) + assert.Nil(err) + + res, err = collection.Get("my-key1") + assert.NotNil(err) + assert.Nil(res) + + // delete a non-existing one + err = collection.Delete("first") + assert.NotNil(err) + + err = collection.Delete("my-key1") + assert.NotNil(err) +} + +func TestJWTAuthGetAll(t *testing.T) { + assert := assert.New(t) + collection := jwtAuthsCollection() + + populateWithJWTAuthFixtures(assert, collection) + + jwtAuths, err := collection.GetAll() + assert.Nil(err) + assert.Equal(5, len(jwtAuths)) +} + +func TestJWTAuthGetByConsumer(t *testing.T) { + assert := assert.New(t) + collection := jwtAuthsCollection() + + populateWithJWTAuthFixtures(assert, collection) + + jwtAuths, err := collection.GetAllByConsumerID("consumer1-id") + assert.Nil(err) + assert.Equal(3, len(jwtAuths)) +} + +func populateWithJWTAuthFixtures(assert *assert.Assertions, + collection *JWTAuthsCollection, +) { + jwtAuths := []JWTAuth{ + { + JWTAuth: kong.JWTAuth{ + Key: kong.String("my-key11"), + ID: kong.String("first"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + JWTAuth: kong.JWTAuth{ + Key: kong.String("my-key12"), + ID: kong.String("second"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + JWTAuth: kong.JWTAuth{ + Key: kong.String("my-key13"), + ID: kong.String("third"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + JWTAuth: kong.JWTAuth{ + Key: kong.String("my-key21"), + ID: kong.String("fourth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + { + JWTAuth: kong.JWTAuth{ + Key: kong.String("my-key22"), + ID: kong.String("fifth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + } + + for _, k := range jwtAuths { + err := collection.Add(k) + assert.Nil(err) + } +} diff --git a/pkg/state/keyauth.go b/pkg/state/keyauth.go new file mode 100644 index 0000000..e474077 --- /dev/null +++ b/pkg/state/keyauth.go @@ -0,0 +1,85 @@ +package state + +// KeyAuthsCollection stores and indexes key-auth credentials. +type KeyAuthsCollection struct { + credentialsCollection +} + +func newKeyAuthsCollection(common collection) *KeyAuthsCollection { + return &KeyAuthsCollection{ + credentialsCollection: credentialsCollection{ + collection: common, + CredType: "key-auth", + }, + } +} + +// Add adds a key-auth credential to KeyAuthsCollection +func (k *KeyAuthsCollection) Add(keyAuth KeyAuth) error { + cred := (entity)(&keyAuth) + return k.credentialsCollection.Add(cred) +} + +// Get gets a key-auth credential by key or ID. +func (k *KeyAuthsCollection) Get(keyOrID string) (*KeyAuth, error) { + cred, err := k.credentialsCollection.Get(keyOrID) + if err != nil { + return nil, err + } + + keyAuth, ok := cred.(*KeyAuth) + if !ok { + panic(unexpectedType) + } + return &KeyAuth{KeyAuth: *keyAuth.DeepCopy()}, nil +} + +// GetAllByConsumerID returns all key-auth credentials +// belong to a Consumer with id. +func (k *KeyAuthsCollection) GetAllByConsumerID(id string) ([]*KeyAuth, + error, +) { + creds, err := k.credentialsCollection.GetAllByConsumerID(id) + if err != nil { + return nil, err + } + + var res []*KeyAuth + for _, cred := range creds { + r, ok := cred.(*KeyAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &KeyAuth{KeyAuth: *r.DeepCopy()}) + } + return res, nil +} + +// Update updates an existing key-auth credential. +func (k *KeyAuthsCollection) Update(keyAuth KeyAuth) error { + cred := (entity)(&keyAuth) + return k.credentialsCollection.Update(cred) +} + +// Delete deletes a key-auth credential by key or ID. +func (k *KeyAuthsCollection) Delete(keyOrID string) error { + return k.credentialsCollection.Delete(keyOrID) +} + +// GetAll gets all key-auth credentials. +func (k *KeyAuthsCollection) GetAll() ([]*KeyAuth, error) { + creds, err := k.credentialsCollection.GetAll() + if err != nil { + return nil, err + } + + var res []*KeyAuth + for _, cred := range creds { + r, ok := cred.(*KeyAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &KeyAuth{KeyAuth: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/keyauth_test.go b/pkg/state/keyauth_test.go new file mode 100644 index 0000000..74bfe7e --- /dev/null +++ b/pkg/state/keyauth_test.go @@ -0,0 +1,256 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func keyAuthsCollection() *KeyAuthsCollection { + return state().KeyAuths +} + +func TestKeyAuthInsert(t *testing.T) { + assert := assert.New(t) + collection := keyAuthsCollection() + + var keyAuth KeyAuth + keyAuth.Key = kong.String("my-secret-apikey") + keyAuth.ID = kong.String("first") + err := collection.Add(keyAuth) + assert.NotNil(err) + + var keyAuth2 KeyAuth + keyAuth2.Key = kong.String("my-secret-apikey") + keyAuth2.ID = kong.String("first") + keyAuth2.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + } + err = collection.Add(keyAuth2) + assert.Nil(err) + + // same API key + keyAuth2.Key = kong.String("my-secret-apikey") + keyAuth2.ID = kong.String("second") + keyAuth2.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + } + err = collection.Add(keyAuth2) + assert.NotNil(err) + + // re-insert + err = collection.Add(keyAuth2) + assert.NotNil(err) +} + +func TestKeyAuthGet(t *testing.T) { + assert := assert.New(t) + collection := keyAuthsCollection() + + var keyAuth KeyAuth + keyAuth.Key = kong.String("my-apikey") + keyAuth.ID = kong.String("first") + keyAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + } + + err := collection.Add(keyAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-apikey", *res.Key) + + res, err = collection.Get("my-apikey") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("first", *res.ID) + assert.Equal("consumer1-id", *res.Consumer.ID) + + res, err = collection.Get("does-not-exist") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("") + assert.NotNil(err) + assert.Nil(res) +} + +func TestKeyAuthUpdate(t *testing.T) { + assert := assert.New(t) + collection := keyAuthsCollection() + + var keyAuth KeyAuth + + assert.NotNil(collection.Add(keyAuth)) + + keyAuth.Key = kong.String("my-apikey") + keyAuth.ID = kong.String("first") + keyAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + } + + err := collection.Add(keyAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-apikey", *res.Key) + + res.Key = kong.String("my-apikey2") + err = collection.Update(*res) + assert.Nil(err) + + res, err = collection.Get("first") + assert.Nil(err) + assert.Equal("my-apikey2", *res.Key) + + res, err = collection.Get("my-apikey") + assert.NotNil(err) + assert.Nil(res) +} + +func TestKeyAuthDelete(t *testing.T) { + assert := assert.New(t) + collection := keyAuthsCollection() + + var keyAuth KeyAuth + keyAuth.Key = kong.String("my-apikey1") + keyAuth.ID = kong.String("first") + keyAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + } + err := collection.Add(keyAuth) + assert.Nil(err) + + res, err := collection.Get("my-apikey1") + assert.Nil(err) + assert.NotNil(res) + + err = collection.Delete(*res.ID) + assert.Nil(err) + + res, err = collection.Get("my-apikey1") + assert.NotNil(err) + assert.Nil(res) + + // delete a non-existing one + err = collection.Delete("first") + assert.NotNil(err) + + err = collection.Delete("my-apikey1") + assert.NotNil(err) + + err = collection.Delete("does-not-exist") + assert.NotNil(err) + + err = collection.Delete("") + assert.NotNil(err) +} + +func TestKeyAuthGetAll(t *testing.T) { + assert := assert.New(t) + collection := keyAuthsCollection() + + populateWithKeyAuthFixtures(assert, collection) + + keyAuths, err := collection.GetAll() + assert.Nil(err) + assert.Equal(5, len(keyAuths)) +} + +func TestKeyAuthGetByConsumer(t *testing.T) { + assert := assert.New(t) + collection := keyAuthsCollection() + + populateWithKeyAuthFixtures(assert, collection) + + keyAuths, err := collection.GetAllByConsumerID("consumer1-id") + assert.Nil(err) + assert.Equal(3, len(keyAuths)) +} + +func populateWithKeyAuthFixtures(assert *assert.Assertions, + collection *KeyAuthsCollection, +) { + keyAuths := []KeyAuth{ + { + KeyAuth: kong.KeyAuth{ + Key: kong.String("my-apikey11"), + ID: kong.String("first"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + }, + }, + }, + { + KeyAuth: kong.KeyAuth{ + Key: kong.String("my-apikey12"), + ID: kong.String("second"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + }, + }, + }, + { + KeyAuth: kong.KeyAuth{ + Key: kong.String("my-apikey13"), + ID: kong.String("third"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + }, + }, + }, + { + KeyAuth: kong.KeyAuth{ + Key: kong.String("my-apikey21"), + ID: kong.String("fourth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + }, + }, + }, + { + KeyAuth: kong.KeyAuth{ + Key: kong.String("my-apikey22"), + ID: kong.String("fifth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + }, + }, + }, + } + + for _, k := range keyAuths { + err := collection.Add(k) + assert.Nil(err) + } +} + +func TestKeyAuthInvalidType(t *testing.T) { + assert := assert.New(t) + collection := keyAuthsCollection() + + var hmacAuth HMACAuth + hmacAuth.Username = kong.String("my-hmacAuth") + hmacAuth.ID = kong.String("first") + hmacAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + } + txn := collection.db.Txn(true) + assert.Nil(txn.Insert("key-auth", &hmacAuth)) + txn.Commit() + + assert.Panics(func() { + collection.Get("first") + }) + assert.Panics(func() { + collection.GetAll() + }) + assert.Panics(func() { + collection.GetAllByConsumerID("consumer-id") + }) +} diff --git a/pkg/state/konnect_types.go b/pkg/state/konnect_types.go new file mode 100644 index 0000000..85bd7e1 --- /dev/null +++ b/pkg/state/konnect_types.go @@ -0,0 +1,143 @@ +package state + +import ( + "reflect" + + "github.com/kong/deck/konnect" +) + +// Document represents a document in Konnect. +// It adds some helper methods along with Meta to the original Document object. +type Document struct { + konnect.Document `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (d1 *Document) Identifier() string { + if d1.Path != nil { + return *d1.Path + } + return *d1.ID +} + +// Console returns an entity's identity in a human-readable string. +func (d1 *Document) Console() string { + return *d1.Path +} + +// Equal returns true if s1 and s2 are equal. +func (d1 *Document) Equal(d2 *Document) bool { + return d1.EqualWithOpts(d2, false, false, false) +} + +// EqualWithOpts returns true if d1 and d2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (d1 *Document) EqualWithOpts(d2 *Document, + ignoreID, _, ignoreForeign bool, +) bool { + d1Copy := d1.Document.ShallowCopy() + d2Copy := d2.Document.ShallowCopy() + + if ignoreID { + d1Copy.ID = nil + d2Copy.ID = nil + } + if ignoreForeign { + d1Copy.Parent = nil + d2Copy.Parent = nil + } + return reflect.DeepEqual(d1Copy, d2Copy) +} + +// ServicePackage represents a service package in Konnect. +// It adds some helper methods along with Meta to the original ServicePackage object. +type ServicePackage struct { + konnect.ServicePackage `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (s1 *ServicePackage) Identifier() string { + if s1.Name != nil { + return *s1.Name + } + return *s1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (s1 *ServicePackage) Console() string { + return s1.Identifier() +} + +// Equal returns true if s1 and s2 are equal. +func (s1 *ServicePackage) Equal(s2 *ServicePackage) bool { + return s1.EqualWithOpts(s2, false, false) +} + +// EqualWithOpts returns true if s1 and s2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (s1 *ServicePackage) EqualWithOpts(s2 *ServicePackage, + ignoreID bool, _ bool, +) bool { + s1Copy := s1.ServicePackage.DeepCopy() + s2Copy := s2.ServicePackage.DeepCopy() + + if ignoreID { + s1Copy.ID = nil + s2Copy.ID = nil + } + return reflect.DeepEqual(s1Copy, s2Copy) +} + +// ServiceVersion represents a service version in Konnect. +// It adds some helper methods along with Meta to the original ServiceVersion +// object. +type ServiceVersion struct { + konnect.ServiceVersion `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (s1 *ServiceVersion) Identifier() string { + if s1.Version != nil { + return *s1.Version + } + return *s1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (s1 *ServiceVersion) Console() string { + return s1.Identifier() +} + +// Equal returns true if s1 and s2 are equal. +func (s1 *ServiceVersion) Equal(s2 *ServiceVersion) bool { + return s1.EqualWithOpts(s2, false, false, false) +} + +// EqualWithOpts returns true if s1 and s2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (s1 *ServiceVersion) EqualWithOpts(s2 *ServiceVersion, + ignoreID, _, ignoreForeign bool, +) bool { + s1Copy := s1.ServiceVersion.DeepCopy() + s2Copy := s2.ServiceVersion.DeepCopy() + + if ignoreID { + s1Copy.ID = nil + s2Copy.ID = nil + } + if ignoreForeign { + s1Copy.ServicePackage = nil + s1Copy.ControlPlaneServiceRelation = nil + s2Copy.ServicePackage = nil + s2Copy.ControlPlaneServiceRelation = nil + } + return reflect.DeepEqual(s1Copy, s2Copy) +} diff --git a/pkg/state/mtlsauth.go b/pkg/state/mtlsauth.go new file mode 100644 index 0000000..25e715a --- /dev/null +++ b/pkg/state/mtlsauth.go @@ -0,0 +1,85 @@ +package state + +// MTLSAuthsCollection stores and indexes mtls-auth credentials. +type MTLSAuthsCollection struct { + credentialsCollection +} + +func newMTLSAuthsCollection(common collection) *MTLSAuthsCollection { + return &MTLSAuthsCollection{ + credentialsCollection: credentialsCollection{ + collection: common, + CredType: "mtls-auth", + }, + } +} + +// Add adds a mtls-auth credential to MTLSAuthsCollection +func (k *MTLSAuthsCollection) Add(mtlsAuth MTLSAuth) error { + cred := (entity)(&mtlsAuth) + return k.credentialsCollection.Add(cred) +} + +// Get gets a mtls-auth credential by ID. +func (k *MTLSAuthsCollection) Get(ID string) (*MTLSAuth, error) { + cred, err := k.credentialsCollection.Get(ID) + if err != nil { + return nil, err + } + + mtlsAuth, ok := cred.(*MTLSAuth) + if !ok { + panic(unexpectedType) + } + return &MTLSAuth{MTLSAuth: *mtlsAuth.DeepCopy()}, nil +} + +// GetAllByConsumerID returns all mtls-auth credentials +// belong to a Consumer with id. +func (k *MTLSAuthsCollection) GetAllByConsumerID(id string) ([]*MTLSAuth, + error, +) { + creds, err := k.credentialsCollection.GetAllByConsumerID(id) + if err != nil { + return nil, err + } + + var res []*MTLSAuth + for _, cred := range creds { + r, ok := cred.(*MTLSAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &MTLSAuth{MTLSAuth: *r.DeepCopy()}) + } + return res, nil +} + +// Update updates an existing mtls-auth credential. +func (k *MTLSAuthsCollection) Update(mtlsAuth MTLSAuth) error { + cred := (entity)(&mtlsAuth) + return k.credentialsCollection.Update(cred) +} + +// Delete deletes a mtls-auth credential by ID. +func (k *MTLSAuthsCollection) Delete(ID string) error { + return k.credentialsCollection.Delete(ID) +} + +// GetAll gets all mtls-auth credentials. +func (k *MTLSAuthsCollection) GetAll() ([]*MTLSAuth, error) { + creds, err := k.credentialsCollection.GetAll() + if err != nil { + return nil, err + } + + var res []*MTLSAuth + for _, cred := range creds { + r, ok := cred.(*MTLSAuth) + if !ok { + panic(unexpectedType) + } + res = append(res, &MTLSAuth{MTLSAuth: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/mtlsauth_test.go b/pkg/state/mtlsauth_test.go new file mode 100644 index 0000000..3a02588 --- /dev/null +++ b/pkg/state/mtlsauth_test.go @@ -0,0 +1,204 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func mtlsAuthsCollection() *MTLSAuthsCollection { + return state().MTLSAuths +} + +func TestMTLSAuthInsert(t *testing.T) { + assert := assert.New(t) + collection := mtlsAuthsCollection() + + var mtlsAuth MTLSAuth + mtlsAuth.ID = kong.String("first") + err := collection.Add(mtlsAuth) + assert.NotNil(err) + + mtlsAuth.SubjectName = kong.String("test@example.com") + err = collection.Add(mtlsAuth) + assert.NotNil(err) + + var mtlsAuth2 MTLSAuth + mtlsAuth2.SubjectName = kong.String("test@example.com") + mtlsAuth2.ID = kong.String("first") + mtlsAuth2.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + Username: kong.String("my-username"), + } + err = collection.Add(mtlsAuth2) + assert.Nil(err) +} + +func TestMTLSAuthGet(t *testing.T) { + assert := assert.New(t) + collection := mtlsAuthsCollection() + + var mtlsAuth MTLSAuth + mtlsAuth.SubjectName = kong.String("test@example.com") + mtlsAuth.ID = kong.String("first") + mtlsAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(mtlsAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("test@example.com", *res.SubjectName) + + res, err = collection.Get("does-not-exist") + assert.NotNil(err) + assert.Nil(res) +} + +func TestMTLSAuthUpdate(t *testing.T) { + assert := assert.New(t) + collection := mtlsAuthsCollection() + + var mtlsAuth MTLSAuth + mtlsAuth.SubjectName = kong.String("test@example.com") + mtlsAuth.ID = kong.String("first") + mtlsAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(mtlsAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("test@example.com", *res.SubjectName) + + res.SubjectName = kong.String("test2@example.com") + err = collection.Update(*res) + assert.Nil(err) + + res, err = collection.Get("first") + assert.Nil(err) + assert.Equal("test2@example.com", *res.SubjectName) +} + +func TestMTLSAuthDelete(t *testing.T) { + assert := assert.New(t) + collection := mtlsAuthsCollection() + + var mtlsAuth MTLSAuth + mtlsAuth.SubjectName = kong.String("test@example.com") + mtlsAuth.ID = kong.String("first") + mtlsAuth.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + err := collection.Add(mtlsAuth) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + + err = collection.Delete(*res.ID) + assert.Nil(err) + + res, err = collection.Get("first") + assert.NotNil(err) + assert.Nil(res) + + // delete a non-existing one + err = collection.Delete("first") + assert.NotNil(err) +} + +func TestMTLSAuthGetAll(t *testing.T) { + assert := assert.New(t) + collection := mtlsAuthsCollection() + + populateWithMTLSAuthFixtures(assert, collection) + + mtlsAuths, err := collection.GetAll() + assert.Nil(err) + assert.Equal(5, len(mtlsAuths)) +} + +func TestMTLSAuthGetByConsumer(t *testing.T) { + assert := assert.New(t) + collection := mtlsAuthsCollection() + + populateWithMTLSAuthFixtures(assert, collection) + + mtlsAuths, err := collection.GetAllByConsumerID("consumer1-id") + assert.Nil(err) + assert.Equal(3, len(mtlsAuths)) +} + +func populateWithMTLSAuthFixtures(assert *assert.Assertions, + collection *MTLSAuthsCollection, +) { + mtlsAuths := []MTLSAuth{ + { + MTLSAuth: kong.MTLSAuth{ + SubjectName: kong.String("test11@example.com"), + ID: kong.String("first"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + MTLSAuth: kong.MTLSAuth{ + SubjectName: kong.String("test12@example.com"), + ID: kong.String("second"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + MTLSAuth: kong.MTLSAuth{ + SubjectName: kong.String("test13@example.com"), + ID: kong.String("third"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + MTLSAuth: kong.MTLSAuth{ + SubjectName: kong.String("test21@example.com"), + ID: kong.String("fourth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + { + MTLSAuth: kong.MTLSAuth{ + SubjectName: kong.String("test22@example.com"), + ID: kong.String("fifth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + } + + for _, k := range mtlsAuths { + err := collection.Add(k) + assert.Nil(err) + } +} diff --git a/pkg/state/oauth2.go b/pkg/state/oauth2.go new file mode 100644 index 0000000..ffe919c --- /dev/null +++ b/pkg/state/oauth2.go @@ -0,0 +1,85 @@ +package state + +// Oauth2CredsCollection stores and indexes oauth2 credentials. +type Oauth2CredsCollection struct { + credentialsCollection +} + +func newOauth2CredsCollection(common collection) *Oauth2CredsCollection { + return &Oauth2CredsCollection{ + credentialsCollection: credentialsCollection{ + collection: common, + CredType: "oauth2", + }, + } +} + +// Add adds a oauth2 credential to Oauth2CredsCollection +func (k *Oauth2CredsCollection) Add(keyAuth Oauth2Credential) error { + cred := (entity)(&keyAuth) + return k.credentialsCollection.Add(cred) +} + +// Get gets a oauth2 credential by key or ID. +func (k *Oauth2CredsCollection) Get(keyOrID string) (*Oauth2Credential, error) { + cred, err := k.credentialsCollection.Get(keyOrID) + if err != nil { + return nil, err + } + + keyAuth, ok := cred.(*Oauth2Credential) + if !ok { + panic(unexpectedType) + } + return &Oauth2Credential{Oauth2Credential: *keyAuth.DeepCopy()}, nil +} + +// GetAllByConsumerID returns all oauth2 credentials +// belong to a Consumer with id. +func (k *Oauth2CredsCollection) GetAllByConsumerID(id string) ([]*Oauth2Credential, + error, +) { + creds, err := k.credentialsCollection.GetAllByConsumerID(id) + if err != nil { + return nil, err + } + + var res []*Oauth2Credential + for _, cred := range creds { + r, ok := cred.(*Oauth2Credential) + if !ok { + panic(unexpectedType) + } + res = append(res, &Oauth2Credential{Oauth2Credential: *r.DeepCopy()}) + } + return res, nil +} + +// Update updates an existing oauth2 credential. +func (k *Oauth2CredsCollection) Update(keyAuth Oauth2Credential) error { + cred := (entity)(&keyAuth) + return k.credentialsCollection.Update(cred) +} + +// Delete deletes a oauth2 credential by key or ID. +func (k *Oauth2CredsCollection) Delete(keyOrID string) error { + return k.credentialsCollection.Delete(keyOrID) +} + +// GetAll gets all oauth2 credentials. +func (k *Oauth2CredsCollection) GetAll() ([]*Oauth2Credential, error) { + creds, err := k.credentialsCollection.GetAll() + if err != nil { + return nil, err + } + + var res []*Oauth2Credential + for _, cred := range creds { + r, ok := cred.(*Oauth2Credential) + if !ok { + panic(unexpectedType) + } + res = append(res, &Oauth2Credential{Oauth2Credential: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/oauth2_test.go b/pkg/state/oauth2_test.go new file mode 100644 index 0000000..cff35b7 --- /dev/null +++ b/pkg/state/oauth2_test.go @@ -0,0 +1,211 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func oauth2CredsCollection() *Oauth2CredsCollection { + return state().Oauth2Creds +} + +func TestOauth2CredInsert(t *testing.T) { + assert := assert.New(t) + collection := oauth2CredsCollection() + + var oauth2Cred Oauth2Credential + oauth2Cred.ClientID = kong.String("client-id") + oauth2Cred.ID = kong.String("first") + err := collection.Add(oauth2Cred) + assert.NotNil(err) + + oauth2Cred.Consumer = &kong.Consumer{ + ID: kong.String("consumer-id"), + Username: kong.String("my-username"), + } + err = collection.Add(oauth2Cred) + assert.Nil(err) +} + +func TestOauth2CredentialGet(t *testing.T) { + assert := assert.New(t) + collection := oauth2CredsCollection() + + var oauth2Cred Oauth2Credential + oauth2Cred.ClientID = kong.String("my-clientid") + oauth2Cred.ID = kong.String("first") + oauth2Cred.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(oauth2Cred) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-clientid", *res.ClientID) + + res, err = collection.Get("my-clientid") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("first", *res.ID) + assert.Equal("consumer1-id", *res.Consumer.ID) + + res, err = collection.Get("does-not-exist") + assert.NotNil(err) + assert.Nil(res) +} + +func TestOauth2CredentialUpdate(t *testing.T) { + assert := assert.New(t) + collection := oauth2CredsCollection() + + var oauth2Cred Oauth2Credential + oauth2Cred.ClientID = kong.String("my-clientid") + oauth2Cred.ID = kong.String("first") + oauth2Cred.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + + err := collection.Add(oauth2Cred) + assert.Nil(err) + + res, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(res) + assert.Equal("my-clientid", *res.ClientID) + + res.ClientID = kong.String("my-clientid2") + err = collection.Update(*res) + assert.Nil(err) + + res, err = collection.Get("my-clientid") + assert.NotNil(err) + assert.Nil(res) + + res, err = collection.Get("my-clientid2") + assert.Nil(err) + assert.Equal("first", *res.ID) +} + +func TestOauth2CredentialDelete(t *testing.T) { + assert := assert.New(t) + collection := oauth2CredsCollection() + + var oauth2Cred Oauth2Credential + oauth2Cred.ClientID = kong.String("my-clientid1") + oauth2Cred.ID = kong.String("first") + oauth2Cred.Consumer = &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + } + err := collection.Add(oauth2Cred) + assert.Nil(err) + + res, err := collection.Get("my-clientid1") + assert.Nil(err) + assert.NotNil(res) + + err = collection.Delete(*res.ID) + assert.Nil(err) + + res, err = collection.Get("my-clientid1") + assert.NotNil(err) + assert.Nil(res) + + // delete a non-existing one + err = collection.Delete("first") + assert.NotNil(err) + + err = collection.Delete("my-clientid1") + assert.NotNil(err) +} + +func TestOauth2CredentialGetAll(t *testing.T) { + assert := assert.New(t) + collection := oauth2CredsCollection() + + populateWithOauth2CredentialFixtures(assert, collection) + + oauth2Creds, err := collection.GetAll() + assert.Nil(err) + assert.Equal(5, len(oauth2Creds)) +} + +func TestOauth2CredentialGetByConsumer(t *testing.T) { + assert := assert.New(t) + collection := oauth2CredsCollection() + + populateWithOauth2CredentialFixtures(assert, collection) + + oauth2Creds, err := collection.GetAllByConsumerID("consumer1-id") + assert.Nil(err) + assert.Equal(3, len(oauth2Creds)) +} + +func populateWithOauth2CredentialFixtures(assert *assert.Assertions, + collection *Oauth2CredsCollection, +) { + oauth2Creds := []Oauth2Credential{ + { + Oauth2Credential: kong.Oauth2Credential{ + ClientID: kong.String("my-clientid11"), + ID: kong.String("first"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + Oauth2Credential: kong.Oauth2Credential{ + ClientID: kong.String("my-clientid12"), + ID: kong.String("second"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + Oauth2Credential: kong.Oauth2Credential{ + ClientID: kong.String("my-clientid13"), + ID: kong.String("third"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1-id"), + Username: kong.String("consumer1-name"), + }, + }, + }, + { + Oauth2Credential: kong.Oauth2Credential{ + ClientID: kong.String("my-clientid21"), + ID: kong.String("fourth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + { + Oauth2Credential: kong.Oauth2Credential{ + ClientID: kong.String("my-clientid22"), + ID: kong.String("fifth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer2-id"), + Username: kong.String("consumer2-name"), + }, + }, + }, + } + + for _, k := range oauth2Creds { + err := collection.Add(k) + assert.Nil(err) + } +} diff --git a/pkg/state/plugin.go b/pkg/state/plugin.go new file mode 100644 index 0000000..0b39519 --- /dev/null +++ b/pkg/state/plugin.go @@ -0,0 +1,380 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +var errPluginNameRequired = fmt.Errorf("name of plugin required") + +const ( + pluginTableName = "plugin" + pluginsByServiceID = "pluginsByServiceID" + pluginsByRouteID = "pluginsByRouteID" + pluginsByConsumerID = "pluginsByConsumerID" + pluginsByConsumerGroupID = "pluginsByConsumerGroupID" +) + +var pluginTableSchema = &memdb.TableSchema{ + Name: pluginTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + }, + all: allIndex, + // foreign + pluginsByServiceID: { + Name: pluginsByServiceID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Service", + Sub: "ID", + }, + }, + }, + AllowMissing: true, + }, + pluginsByRouteID: { + Name: pluginsByRouteID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Route", + Sub: "ID", + }, + }, + }, + AllowMissing: true, + }, + pluginsByConsumerID: { + Name: pluginsByConsumerID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Consumer", + Sub: "ID", + }, + }, + }, + AllowMissing: true, + }, + pluginsByConsumerGroupID: { + Name: pluginsByConsumerGroupID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "ConsumerGroup", + Sub: "ID", + }, + }, + }, + AllowMissing: true, + }, + // combined foreign fields + // FIXME bug: collision if svc/route/consumer has the same ID + // and same type of plugin is created. Consider the case when only + // of the association is present + "fields": { + Name: "fields", + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Plugin", + Sub: "Name", + }, + { + Struct: "Service", + Sub: "ID", + }, + { + Struct: "Route", + Sub: "ID", + }, + { + Struct: "Consumer", + Sub: "ID", + }, + { + Struct: "ConsumerGroup", + Sub: "ID", + }, + }, + }, + }, + }, +} + +// PluginsCollection stores and indexes Kong Services. +type PluginsCollection collection + +// Add adds a plugin to PluginsCollection +func (k *PluginsCollection) Add(plugin Plugin) error { + txn := k.db.Txn(true) + defer txn.Abort() + + err := insertPlugin(txn, plugin) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func insertPlugin(txn *memdb.Txn, plugin Plugin) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(plugin.ID) { + return errIDRequired + } + if utils.Empty(plugin.Name) { + return errPluginNameRequired + } + + // err out if plugin with same ID is present + _, err := getPluginByID(txn, *plugin.ID) + if err == nil { + return fmt.Errorf("inserting plugin %v: %w", plugin.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + // err out if another plugin with exact same combination is present + sID, rID, cID, cgID := "", "", "", "" + if plugin.Service != nil && !utils.Empty(plugin.Service.ID) { + sID = *plugin.Service.ID + } + if plugin.Route != nil && !utils.Empty(plugin.Route.ID) { + rID = *plugin.Route.ID + } + if plugin.Consumer != nil && !utils.Empty(plugin.Consumer.ID) { + cID = *plugin.Consumer.ID + } + if plugin.ConsumerGroup != nil && !utils.Empty(plugin.ConsumerGroup.ID) { + cgID = *plugin.ConsumerGroup.ID + } + _, err = getPluginBy(txn, *plugin.Name, sID, rID, cID, cgID) + if err == nil { + return fmt.Errorf("inserting plugin %v: %w", plugin.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + // all good + err = txn.Insert(pluginTableName, &plugin) + if err != nil { + return err + } + return nil +} + +func getPluginByID(txn *memdb.Txn, id string) (*Plugin, error) { + res, err := multiIndexLookupUsingTxn(txn, pluginTableName, + []string{"id"}, id) + if err != nil { + return nil, err + } + + plugin, ok := res.(*Plugin) + if !ok { + panic(unexpectedType) + } + return &Plugin{Plugin: *plugin.DeepCopy()}, nil +} + +// Get gets a plugin by id. +func (k *PluginsCollection) Get(id string) (*Plugin, error) { + if id == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + + plugin, err := getPluginByID(txn, id) + if err != nil { + return nil, err + } + return plugin, nil +} + +// GetAllByName returns all plugins of a specific type +// (key-auth, ratelimiting, etc). +func (k *PluginsCollection) GetAllByName(name string) ([]*Plugin, error) { + return k.getAllPluginsBy("name", name) +} + +func getPluginBy(txn *memdb.Txn, name, svcID, routeID, consumerID, consumerGroupID string) ( + *Plugin, error, +) { + if name == "" { + return nil, errPluginNameRequired + } + + res, err := txn.First(pluginTableName, "fields", + name, svcID, routeID, consumerID, consumerGroupID) + if err != nil { + return nil, err + } + if res == nil { + return nil, ErrNotFound + } + p, ok := res.(*Plugin) + if !ok { + panic(unexpectedType) + } + return &Plugin{Plugin: *p.DeepCopy()}, nil +} + +// GetByProp returns a plugin which matches all the properties passed in +// the arguments. If serviceID, routeID, consumerID and consumerGroupID +// are empty strings, then a global plugin is searched. +// Otherwise, a plugin with name and the supplied foreign references is +// searched. +// name is required. +func (k *PluginsCollection) GetByProp( + name, serviceID, routeID, consumerID, consumerGroupID string, +) (*Plugin, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + return getPluginBy(txn, name, serviceID, routeID, consumerID, consumerGroupID) +} + +func (k *PluginsCollection) getAllPluginsBy(index, identifier string) ( + []*Plugin, error, +) { + if identifier == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(pluginTableName, index, identifier) + if err != nil { + return nil, err + } + var res []*Plugin + for el := iter.Next(); el != nil; el = iter.Next() { + p, ok := el.(*Plugin) + if !ok { + panic(unexpectedType) + } + res = append(res, &Plugin{Plugin: *p.DeepCopy()}) + } + return res, nil +} + +// GetAllByServiceID returns all plugins referencing a service +// by its id. +func (k *PluginsCollection) GetAllByServiceID(id string) ([]*Plugin, + error, +) { + return k.getAllPluginsBy(pluginsByServiceID, id) +} + +// GetAllByRouteID returns all plugins referencing a route +// by its id. +func (k *PluginsCollection) GetAllByRouteID(id string) ([]*Plugin, + error, +) { + return k.getAllPluginsBy(pluginsByRouteID, id) +} + +// GetAllByConsumerID returns all plugins referencing a consumer +// by its id. +func (k *PluginsCollection) GetAllByConsumerID(id string) ([]*Plugin, + error, +) { + return k.getAllPluginsBy(pluginsByConsumerID, id) +} + +// GetAllByConsumerGroupID returns all plugins referencing a consumer-group +// by its id. +func (k *PluginsCollection) GetAllByConsumerGroupID(id string) ([]*Plugin, + error, +) { + return k.getAllPluginsBy(pluginsByConsumerGroupID, id) +} + +// Update updates a plugin +func (k *PluginsCollection) Update(plugin Plugin) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(plugin.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deletePlugin(txn, *plugin.ID) + if err != nil { + return err + } + + err = insertPlugin(txn, plugin) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deletePlugin(txn *memdb.Txn, id string) error { + plugin, err := getPluginByID(txn, id) + if err != nil { + return err + } + return txn.Delete(pluginTableName, plugin) +} + +// Delete deletes a plugin by ID. +func (k *PluginsCollection) Delete(id string) error { + if id == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deletePlugin(txn, id) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a plugin by name or ID. +func (k *PluginsCollection) GetAll() ([]*Plugin, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(pluginTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Plugin + for el := iter.Next(); el != nil; el = iter.Next() { + p, ok := el.(*Plugin) + if !ok { + panic(unexpectedType) + } + res = append(res, &Plugin{Plugin: *p.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/plugin_test.go b/pkg/state/plugin_test.go new file mode 100644 index 0000000..f808806 --- /dev/null +++ b/pkg/state/plugin_test.go @@ -0,0 +1,574 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func pluginsCollection() *PluginsCollection { + return state().Plugins +} + +func TestPluginsCollection_Add(t *testing.T) { + type args struct { + plugin Plugin + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + Name: kong.String("foo"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors when Name is nil", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id1"), + }, + }, + }, + wantErr: true, + }, + { + name: "inserts with a name and ID", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id2"), + Name: kong.String("bar-name"), + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert when same ID is present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id3"), + Name: kong.String("foo-name"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when id is present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id3"), + Name: kong.String("foobar-name"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when when same association is present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id4-new"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when when same (multiple) association is present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id5-new"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + Service: &kong.Service{ + ID: kong.String("svc1"), + }, + }, + }, + }, + wantErr: true, + }, + } + k := pluginsCollection() + plugin1 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id3"), + Name: kong.String("foo-name"), + }, + } + plugin2 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id4"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + }, + } + plugin3 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id5"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + Service: &kong.Service{ + ID: kong.String("svc1"), + }, + }, + } + k.Add(plugin1) + k.Add(plugin2) + k.Add(plugin3) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.plugin); (err != nil) != tt.wantErr { + t.Errorf("PluginsCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestPluginsCollection_Update(t *testing.T) { + type args struct { + plugin Plugin + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + Name: kong.String("foo"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors when Name is nil", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id1"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors when the plugin is not present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("does-not-exist-yet"), + Name: kong.String("bar-name"), + }, + }, + }, + wantErr: true, + }, + { + name: "updates when ID is present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id3"), + Name: kong.String("foo-name-new"), + }, + }, + }, + wantErr: false, + }, + { + name: "errors on update when when same association is present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("new-id"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "errors on update when when same (multiple) association is present", + args: args{ + plugin: Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("new-id"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + Service: &kong.Service{ + ID: kong.String("svc1"), + }, + }, + }, + }, + wantErr: true, + }, + } + k := pluginsCollection() + plugin1 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id1"), + Name: kong.String("foo-name"), + }, + } + plugin2 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id2"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + }, + } + plugin3 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id3"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + Service: &kong.Service{ + ID: kong.String("svc1"), + }, + }, + } + plugin4 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id4"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + Service: &kong.Service{ + ID: kong.String("svc1"), + }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("cg1"), + }, + }, + } + k.Add(plugin1) + k.Add(plugin2) + k.Add(plugin3) + k.Add(plugin4) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Update(tt.args.plugin); (err != nil) != tt.wantErr { + t.Errorf("PluginsCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestPluginGet(t *testing.T) { + assert := assert.New(t) + collection := pluginsCollection() + + var plugin Plugin + plugin.Name = kong.String("my-plugin") + plugin.ID = kong.String("first") + plugin.Service = &kong.Service{ + ID: kong.String("service1-id"), + Name: kong.String("service1-name"), + } + assert.NotNil(plugin.Service) + err := collection.Add(plugin) + assert.NotNil(plugin.Service) + assert.Nil(err) + + re, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(re) + assert.Equal("my-plugin", *re.Name) + re.Service = &kong.Service{ + ID: kong.String("service2-id"), + Name: kong.String("service2-name"), + } + + re, err = collection.Get("does-not-exists") + assert.Equal(ErrNotFound, err) + assert.Nil(re) +} + +func TestGetPluginByProp(t *testing.T) { + plugins := []Plugin{ + { + Plugin: kong.Plugin{ + ID: kong.String("1"), + Name: kong.String("key-auth"), + Config: map[string]interface{}{ + "key1": "value1", + }, + }, + }, + { + Plugin: kong.Plugin{ + ID: kong.String("2"), + Name: kong.String("key-auth"), + Service: &kong.Service{ + ID: kong.String("svc1"), + }, + Config: map[string]interface{}{ + "key2": "value2", + }, + }, + }, + { + Plugin: kong.Plugin{ + ID: kong.String("3"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + Config: map[string]interface{}{ + "key3": "value3", + }, + }, + }, + { + Plugin: kong.Plugin{ + ID: kong.String("4"), + Name: kong.String("key-auth"), + Consumer: &kong.Consumer{ + ID: kong.String("consumer1"), + }, + Config: map[string]interface{}{ + "key4": "value4", + }, + }, + }, + { + Plugin: kong.Plugin{ + ID: kong.String("5"), + Name: kong.String("key-auth"), + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("cg1"), + }, + Config: map[string]interface{}{ + "key5": "value5", + }, + }, + }, + } + assert := assert.New(t) + collection := pluginsCollection() + + for _, p := range plugins { + assert.Nil(collection.Add(p)) + } + + plugin, err := collection.GetByProp("", "", "", "", "") + assert.Nil(plugin) + assert.Error(err) + + plugin, err = collection.GetByProp("foo", "", "", "", "") + assert.Nil(plugin) + assert.Equal(ErrNotFound, err) + + plugin, err = collection.GetByProp("key-auth", "", "", "", "") + assert.NoError(err) + assert.NotNil(plugin) + assert.Equal("value1", plugin.Config["key1"]) + + plugin, err = collection.GetByProp("key-auth", "svc1", "", "", "") + assert.NoError(err) + assert.NotNil(plugin) + assert.Equal("value2", plugin.Config["key2"]) + + plugin, err = collection.GetByProp("key-auth", "", "route1", "", "") + assert.NoError(err) + assert.NotNil(plugin) + assert.Equal("value3", plugin.Config["key3"]) + + plugin, err = collection.GetByProp("key-auth", "", "", "consumer1", "") + assert.NoError(err) + assert.NotNil(plugin) + assert.Equal("value4", plugin.Config["key4"]) + + plugin, err = collection.GetByProp("key-auth", "", "", "", "cg1") + assert.NoError(err) + assert.NotNil(plugin) + assert.Equal("value5", plugin.Config["key5"]) +} + +func TestPluginsInvalidType(t *testing.T) { + assert := assert.New(t) + + collection := pluginsCollection() + + var service Service + service.Name = kong.String("my-service") + service.ID = kong.String("first") + txn := collection.db.Txn(true) + txn.Insert(pluginTableName, &service) + txn.Commit() + + assert.Panics(func() { + collection.Get("first") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestPluginDelete(t *testing.T) { + assert := assert.New(t) + collection := pluginsCollection() + + var plugin Plugin + plugin.ID = kong.String("first") + plugin.Name = kong.String("my-plugin") + plugin.Config = map[string]interface{}{ + "foo": "bar", + "baz": "bar", + } + plugin.Service = &kong.Service{ + ID: kong.String("service1-id"), + Name: kong.String("service1-name"), + } + err := collection.Add(plugin) + assert.Nil(err) + + p, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(p) + assert.Equal("bar", p.Config["foo"]) + + err = collection.Delete(*p.ID) + assert.Nil(err) + + err = collection.Delete(*p.ID) + assert.NotNil(err) + + assert.NotNil(collection.Delete("")) +} + +func TestPluginGetAll(t *testing.T) { + assert := assert.New(t) + collection := pluginsCollection() + + plugins := []*Plugin{ + { + Plugin: kong.Plugin{ + ID: kong.String("first-id"), + Name: kong.String("key-auth"), + Service: &kong.Service{ + ID: kong.String("service1-id"), + Name: kong.String("service1-name"), + }, + Config: map[string]interface{}{ + "foo": "bar", + "baz": "bar", + }, + }, + }, + { + Plugin: kong.Plugin{ + ID: kong.String("second-id"), + Name: kong.String("basic-auth"), + Service: &kong.Service{ + ID: kong.String("service1-id"), + Name: kong.String("service1-name"), + }, + }, + }, + { + Plugin: kong.Plugin{ + ID: kong.String("third-id"), + Name: kong.String("rate-limiting"), + Route: &kong.Route{ + ID: kong.String("route1-id"), + Name: kong.String("route1-name"), + }, + }, + }, + { + Plugin: kong.Plugin{ + ID: kong.String("fourth-id"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1-id"), + Name: kong.String("route1-name"), + }, + }, + }, + } + + for _, p := range plugins { + assert.Nil(collection.Add(*p)) + } + + allPlugins, err := collection.GetAll() + assert.Nil(err) + assert.Equal(len(plugins), len(allPlugins)) + + allPlugins, err = collection.GetAllByName("") + assert.NotNil(err) + assert.Nil(allPlugins) + allPlugins, err = collection.GetAllByConsumerID("") + assert.NotNil(err) + assert.Nil(allPlugins) + allPlugins, err = collection.GetAllByRouteID("") + assert.NotNil(err) + assert.Nil(allPlugins) + allPlugins, err = collection.GetAllByServiceID("") + assert.NotNil(err) + assert.Nil(allPlugins) + + allPlugins, err = collection.GetAllByName("key-auth") + assert.Nil(err) + assert.Equal(2, len(allPlugins)) + + allPlugins, err = collection.GetAllByRouteID("route1-id") + assert.Nil(err) + assert.Equal(2, len(allPlugins)) + + allPlugins, err = collection.GetAllByServiceID("service1-id") + assert.Nil(err) + assert.Equal(2, len(allPlugins)) + + allPlugins, err = collection.GetAllByServiceID("service-nope") + assert.Nil(err) + assert.Equal(0, len(allPlugins)) +} diff --git a/pkg/state/rbac_endpoint_permission.go b/pkg/state/rbac_endpoint_permission.go new file mode 100644 index 0000000..87da91e --- /dev/null +++ b/pkg/state/rbac_endpoint_permission.go @@ -0,0 +1,210 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + rbacEndpointPermissionTableName = "rbac-endpointpermission" + rbacEndpointPermissionsByRoleID = "rbacEndpointPermissionsByRoleID" +) + +var ( + errInvalidRole = fmt.Errorf("role.ID is required in rbacEndpointPermission") + rbacEndpointPermissionTableSchema = &memdb.TableSchema{ + Name: rbacEndpointPermissionTableName, + Indexes: map[string]*memdb.IndexSchema{ + // ID in the case of an RBACEndpointPermission is a composite key of role ID, workspace, and endpoint + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + all: allIndex, + // foreign + rbacEndpointPermissionsByRoleID: { + Name: rbacEndpointPermissionsByRoleID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Role", + Sub: "ID", + }, + }, + }, + }, + }, + } +) + +func validateRoleForRBACEndpointPermission(rbacEndpointPermission *RBACEndpointPermission) error { + if rbacEndpointPermission.Role == nil || + utils.Empty(rbacEndpointPermission.Role.ID) { + return errInvalidRole + } + return nil +} + +// RBACEndpointPermissionsCollection stores and indexes Kong RBACEndpointPermissions. +type RBACEndpointPermissionsCollection collection + +// Add adds a rbacEndpointPermission into RBACEndpointPermissionsCollection +// rbacEndpointPermission.Endpoint should not be nil else an error is thrown. +func (k *RBACEndpointPermissionsCollection) Add(rbacEndpointPermission RBACEndpointPermission) error { + if err := validateRoleForRBACEndpointPermission(&rbacEndpointPermission); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, rbacEndpointPermission.FriendlyName()) + + _, err := getRBACEndpointPermission(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting rbacEndpointPermission %v: %w", rbacEndpointPermission.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + rbacEndpointPermission.ID = rbacEndpointPermission.FriendlyName() + err = txn.Insert(rbacEndpointPermissionTableName, &rbacEndpointPermission) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getRBACEndpointPermission(txn *memdb.Txn, IDs ...string) (*RBACEndpointPermission, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, rbacEndpointPermissionTableName, + []string{"id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + rbacEndpointPermission, ok := res.(*RBACEndpointPermission) + if !ok { + panic(unexpectedType) + } + return &RBACEndpointPermission{RBACEndpointPermission: *rbacEndpointPermission.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a rbacEndpointPermission by name or ID. +func (k *RBACEndpointPermissionsCollection) Get(nameOrID string) (*RBACEndpointPermission, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + rbacEndpointPermission, err := getRBACEndpointPermission(txn, nameOrID) + if err != nil { + return nil, err + } + return rbacEndpointPermission, nil +} + +// Update updates a rbacEndpointPermission +func (k *RBACEndpointPermissionsCollection) Update(rbacEndpointPermission RBACEndpointPermission) error { + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteRBACEndpointPermission(txn, rbacEndpointPermission.FriendlyName()) + if err != nil { + return err + } + + rbacEndpointPermission.ID = rbacEndpointPermission.FriendlyName() + err = txn.Insert(rbacEndpointPermissionTableName, &rbacEndpointPermission) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteRBACEndpointPermission(txn *memdb.Txn, nameOrID string) error { + rbacEndpointPermission, err := getRBACEndpointPermission(txn, nameOrID) + if err != nil { + return err + } + rbacEndpointPermission.ID = rbacEndpointPermission.FriendlyName() + err = txn.Delete(rbacEndpointPermissionTableName, rbacEndpointPermission) + if err != nil { + return err + } + return nil +} + +// Delete deletes a rbacEndpointPermission by name or ID. +func (k *RBACEndpointPermissionsCollection) Delete(endpointIdentifier string) error { + if endpointIdentifier == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteRBACEndpointPermission(txn, endpointIdentifier) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a rbacEndpointPermission by name or ID. +func (k *RBACEndpointPermissionsCollection) GetAll() ([]*RBACEndpointPermission, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(rbacEndpointPermissionTableName, all, true) + if err != nil { + return nil, err + } + + var res []*RBACEndpointPermission + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*RBACEndpointPermission) + if !ok { + panic(unexpectedType) + } + res = append(res, &RBACEndpointPermission{RBACEndpointPermission: *r.DeepCopy()}) + } + txn.Commit() + return res, nil +} + +// GetAllByRoleID returns all endpoint permissions by referencing a role +// by its id. +func (k *RBACEndpointPermissionsCollection) GetAllByRoleID(id string) ([]*RBACEndpointPermission, + error, +) { + txn := k.db.Txn(false) + iter, err := txn.Get(rbacEndpointPermissionTableName, rbacEndpointPermissionsByRoleID, id) + if err != nil { + return nil, err + } + var res []*RBACEndpointPermission + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*RBACEndpointPermission) + if !ok { + panic(unexpectedType) + } + res = append(res, &RBACEndpointPermission{RBACEndpointPermission: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/rbac_endpoint_permission_test.go b/pkg/state/rbac_endpoint_permission_test.go new file mode 100644 index 0000000..05e6931 --- /dev/null +++ b/pkg/state/rbac_endpoint_permission_test.go @@ -0,0 +1,333 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func rbacEndpointPermissionsCollection() *RBACEndpointPermissionsCollection { + return state().RBACEndpointPermissions +} + +func TestRBACEndpointPermissionsCollection_Add(t *testing.T) { + type args struct { + rbacEndpointPermission RBACEndpointPermission + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when role is nil", + args: args{ + rbacEndpointPermission: RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Actions: kong.StringSlice("read"), + Endpoint: kong.String("/foo"), + }, + }, + }, + wantErr: true, + }, + { + name: "inserts without a workspace, endpoint, and role", + args: args{ + rbacEndpointPermission: RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/foo"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }, + }, + }, + wantErr: false, + }, + } + k := rbacEndpointPermissionsCollection() + rbacEndpointPermission1 := RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("*"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }, + } + k.Add(rbacEndpointPermission1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.rbacEndpointPermission); (err != nil) != tt.wantErr { + t.Errorf("RBACEndpointPermissionsCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRBACEndpointPermissionsCollection_Get(t *testing.T) { + type args struct { + nameOrID string + } + rbacEndpointPermission1 := RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/foo"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }, + } + rbacEndpointPermission2 := RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/bar"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }, + } + tests := []struct { + name string + args args + want *RBACEndpointPermission + wantErr bool + }{ + { + name: "gets a rbacEndpointPermission by ID", + args: args{ + nameOrID: rbacEndpointPermission1.FriendlyName(), + }, + want: &rbacEndpointPermission1, + wantErr: false, + }, + { + name: "gets a rbacEndpointPermission by Name", + args: args{ + nameOrID: rbacEndpointPermission2.FriendlyName(), + }, + want: &rbacEndpointPermission2, + wantErr: false, + }, + { + name: "returns an ErrNotFound when no rbacEndpointPermission found", + args: args{ + nameOrID: "baz-id", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + nameOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := rbacEndpointPermissionsCollection() + k.Add(rbacEndpointPermission1) + k.Add(rbacEndpointPermission2) + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := k.Get(tc.args.nameOrID) + if (err != nil) != tc.wantErr { + t.Errorf("RBACEndpointPermissionsCollection.Get() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("RBACEndpointPermissionsCollection.Get() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestRBACEndpointPermissionsCollection_Update(t *testing.T) { + rbacEndpointPermission1 := RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/foo"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }, + } + rbacEndpointPermission2 := RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/bar"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }, + } + rbacEndpointPermission3 := RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/foo"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + Comment: kong.String("updated!"), + }, + } + type args struct { + rbacEndpointPermission RBACEndpointPermission + } + tests := []struct { + name string + args args + wantErr bool + updatedRBACEndpointPermission *RBACEndpointPermission + }{ + { + name: "update errors if rbacEndpointPermission does not exist", + args: args{ + rbacEndpointPermission: RBACEndpointPermission{ + RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("foo"), + Endpoint: kong.String("bad"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + rbacEndpointPermission: rbacEndpointPermission3, + }, + wantErr: false, + updatedRBACEndpointPermission: &rbacEndpointPermission3, + }, + } + k := rbacEndpointPermissionsCollection() + k.Add(rbacEndpointPermission1) + k.Add(rbacEndpointPermission2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.rbacEndpointPermission); (err != nil) != tt.wantErr { + t.Errorf("RBACEndpointPermissionsCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.Get(tt.updatedRBACEndpointPermission.FriendlyName()) + + if !reflect.DeepEqual(got, tt.updatedRBACEndpointPermission) { + t.Errorf("update rbacEndpointPermission, got = %#v, want %#v", got, tt.updatedRBACEndpointPermission) + } + } + }) + } +} + +func TestRBACEndpointPermissionDelete(t *testing.T) { + assert := assert.New(t) + collection := rbacEndpointPermissionsCollection() + + rbacEndpointPermission := RBACEndpointPermission{RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/foo"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }} + + err := collection.Add(rbacEndpointPermission) + assert.Nil(err) + + re, err := collection.Get(rbacEndpointPermission.FriendlyName()) + assert.Nil(err) + assert.NotNil(re) + + err = collection.Delete(re.FriendlyName()) + assert.Nil(err) + + err = collection.Delete(re.FriendlyName()) + assert.NotNil(err) +} + +func TestRBACEndpointPermissionGetAll(t *testing.T) { + assert := assert.New(t) + collection := rbacEndpointPermissionsCollection() + + rbacEndpointPermission := RBACEndpointPermission{RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/first"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }} + + err := collection.Add(rbacEndpointPermission) + assert.Nil(err) + + rbacEndpointPermission2 := RBACEndpointPermission{RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/second"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }} + + err = collection.Add(rbacEndpointPermission2) + assert.Nil(err) + + rbacEndpointPermissions, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(rbacEndpointPermissions)) +} + +func TestRBACEndpointPermissionGetAllByServiceID(t *testing.T) { + assert := assert.New(t) + collection := rbacEndpointPermissionsCollection() + + rbacEndpointPermissions := []*RBACEndpointPermission{ + {RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/first"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }}, + {RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/second"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }}, + {RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/third"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("1234")}, + }}, + {RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/fourth"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("4321")}, + }}, + {RBACEndpointPermission: kong.RBACEndpointPermission{ + Workspace: kong.String("*"), + Endpoint: kong.String("/fifth"), + Actions: kong.StringSlice("read"), + Role: &kong.RBACRole{ID: kong.String("4321")}, + }}, + } + + for _, rbacEndpointPermission := range rbacEndpointPermissions { + err := collection.Add(*rbacEndpointPermission) + assert.Nil(err) + } + + rbacEndpointPermissions, err := collection.GetAllByRoleID("1234") + assert.Nil(err) + assert.Equal(3, len(rbacEndpointPermissions)) + + rbacEndpointPermissions, err = collection.GetAllByRoleID("4321") + assert.Nil(err) + assert.Equal(2, len(rbacEndpointPermissions)) +} diff --git a/pkg/state/rbac_role.go b/pkg/state/rbac_role.go new file mode 100644 index 0000000..bd6f31d --- /dev/null +++ b/pkg/state/rbac_role.go @@ -0,0 +1,176 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/utils" +) + +const ( + rbacRoleTableName = "rbac-role" +) + +var rbacRoleTableSchema = &memdb.TableSchema{ + Name: rbacRoleTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + }, + all: allIndex, + }, +} + +// RBACRolesCollection stores and indexes Kong RBACRoles. +type RBACRolesCollection collection + +// Add adds a rbacRole into RBACRolesCollection +// rbacRole.ID should not be nil else an error is thrown. +func (k *RBACRolesCollection) Add(rbacRole RBACRole) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(rbacRole.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *rbacRole.ID) + if !utils.Empty(rbacRole.Name) { + searchBy = append(searchBy, *rbacRole.Name) + } + _, err := getRBACRole(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting rbacRole %v: %w", rbacRole.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(rbacRoleTableName, &rbacRole) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getRBACRole(txn *memdb.Txn, IDs ...string) (*RBACRole, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, rbacRoleTableName, + []string{"name", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + + rbacRole, ok := res.(*RBACRole) + if !ok { + panic(unexpectedType) + } + return &RBACRole{RBACRole: *rbacRole.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a rbacRole by name or ID. +func (k *RBACRolesCollection) Get(nameOrID string) (*RBACRole, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + rbacRole, err := getRBACRole(txn, nameOrID) + if err != nil { + return nil, err + } + return rbacRole, nil +} + +// Update updates a rbacRole +func (k *RBACRolesCollection) Update(rbacRole RBACRole) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(rbacRole.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteRBACRole(txn, *rbacRole.ID) + if err != nil { + return err + } + + err = txn.Insert(rbacRoleTableName, &rbacRole) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteRBACRole(txn *memdb.Txn, nameOrID string) error { + rbacRole, err := getRBACRole(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(rbacRoleTableName, rbacRole) + if err != nil { + return err + } + return nil +} + +// Delete deletes a rbacRole by name or ID. +func (k *RBACRolesCollection) Delete(nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteRBACRole(txn, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a rbacRole by name or ID. +func (k *RBACRolesCollection) GetAll() ([]*RBACRole, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(rbacRoleTableName, all, true) + if err != nil { + return nil, err + } + + var res []*RBACRole + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*RBACRole) + if !ok { + panic(unexpectedType) + } + res = append(res, &RBACRole{RBACRole: *r.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/rbac_role_test.go b/pkg/state/rbac_role_test.go new file mode 100644 index 0000000..4a4998a --- /dev/null +++ b/pkg/state/rbac_role_test.go @@ -0,0 +1,298 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func rbacRolesCollection() *RBACRolesCollection { + return state().RBACRoles +} + +func TestRBACRolesCollection_Add(t *testing.T) { + type args struct { + rbacRole RBACRole + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + rbacRole: RBACRole{ + RBACRole: kong.RBACRole{ + Name: kong.String("foo"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors without a name", + args: args{ + rbacRole: RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("id1"), + }, + }, + }, + wantErr: true, + }, + { + name: "inserts with a name and ID", + args: args{ + rbacRole: RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("id2"), + Name: kong.String("bar-name"), + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert when name is present", + args: args{ + rbacRole: RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("id4"), + Name: kong.String("foo-name"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when id is present", + args: args{ + rbacRole: RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("id3"), + Name: kong.String("foobar-name"), + }, + }, + }, + wantErr: true, + }, + } + k := rbacRolesCollection() + rbacRole1 := RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("id3"), + Name: kong.String("foo-name"), + }, + } + k.Add(rbacRole1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.rbacRole); (err != nil) != tt.wantErr { + t.Errorf("RBACRolesCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRBACRolesCollection_Get(t *testing.T) { + type args struct { + nameOrID string + } + rbacRole1 := RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("foo-id"), + }, + } + rbacRole2 := RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + }, + } + tests := []struct { + name string + args args + want *RBACRole + wantErr bool + }{ + { + name: "gets a rbacRole by ID", + args: args{ + nameOrID: "foo-id", + }, + want: &rbacRole1, + wantErr: false, + }, + { + name: "gets a rbacRole by Name", + args: args{ + nameOrID: "bar-name", + }, + want: &rbacRole2, + wantErr: false, + }, + { + name: "returns an ErrNotFound when no rbacRole found", + args: args{ + nameOrID: "baz-id", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + nameOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := rbacRolesCollection() + k.Add(rbacRole1) + k.Add(rbacRole2) + for _, tt := range tests { + tc := &tt //nolint:gosec + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := k.Get(tc.args.nameOrID) + if (err != nil) != tc.wantErr { + t.Errorf("RBACRolesCollection.Get() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("RBACRolesCollection.Get() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestRBACRolesCollection_Update(t *testing.T) { + rbacRole1 := RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + } + rbacRole2 := RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + }, + } + rbacRole3 := RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("foo-id"), + Name: kong.String("foo-new-name"), + }, + } + type args struct { + rbacRole RBACRole + } + tests := []struct { + name string + args args + wantErr bool + updatedRBACRole *RBACRole + }{ + { + name: "update errors if rbacRole.ID is nil", + args: args{ + rbacRole: RBACRole{ + RBACRole: kong.RBACRole{ + Name: kong.String("name"), + }, + }, + }, + wantErr: true, + }, + { + name: "update errors if rbacRole does not exist", + args: args{ + rbacRole: RBACRole{ + RBACRole: kong.RBACRole{ + ID: kong.String("does-not-exist"), + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + rbacRole: rbacRole3, + }, + wantErr: false, + updatedRBACRole: &rbacRole3, + }, + } + k := rbacRolesCollection() + k.Add(rbacRole1) + k.Add(rbacRole2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.rbacRole); (err != nil) != tt.wantErr { + t.Errorf("RBACRolesCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.Get(*tt.updatedRBACRole.ID) + + if !reflect.DeepEqual(got, tt.updatedRBACRole) { + t.Errorf("update rbacRole, got = %#v, want %#v", got, tt.updatedRBACRole) + } + } + }) + } +} + +func TestRBACRoleDelete(t *testing.T) { + assert := assert.New(t) + collection := rbacRolesCollection() + + var rbacRole RBACRole + rbacRole.Name = kong.String("my-rbacRole") + rbacRole.ID = kong.String("first") + + err := collection.Add(rbacRole) + assert.Nil(err) + + re, err := collection.Get("my-rbacRole") + assert.Nil(err) + assert.NotNil(re) + + err = collection.Delete(*re.ID) + assert.Nil(err) + + err = collection.Delete(*re.ID) + assert.NotNil(err) +} + +func TestRBACRoleGetAll(t *testing.T) { + assert := assert.New(t) + collection := rbacRolesCollection() + + var rbacRole RBACRole + rbacRole.Name = kong.String("my-rbacRole1") + rbacRole.ID = kong.String("first") + + err := collection.Add(rbacRole) + assert.Nil(err) + + var rbacRole2 RBACRole + rbacRole2.Name = kong.String("my-rbacRole2") + rbacRole2.ID = kong.String("second") + + err = collection.Add(rbacRole2) + assert.Nil(err) + + rbacRoles, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(rbacRoles)) +} diff --git a/pkg/state/route.go b/pkg/state/route.go new file mode 100644 index 0000000..dfd3f22 --- /dev/null +++ b/pkg/state/route.go @@ -0,0 +1,213 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + routeTableName = "route" + routesByServiceID = "routesByServiceID" +) + +var routeTableSchema = &memdb.TableSchema{ + Name: routeTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + AllowMissing: true, + }, + all: allIndex, + // foreign + routesByServiceID: { + Name: routesByServiceID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Service", + Sub: "ID", + }, + }, + }, + AllowMissing: true, + }, + }, +} + +// RoutesCollection stores and indexes Kong Routes. +type RoutesCollection collection + +// Add adds a route into RoutesCollection +// route.ID should not be nil else an error is thrown. +func (k *RoutesCollection) Add(route Route) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(route.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *route.ID) + if !utils.Empty(route.Name) { + searchBy = append(searchBy, *route.Name) + } + _, err := getRoute(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting route %v: %w", route.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(routeTableName, &route) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getRoute(txn *memdb.Txn, IDs ...string) (*Route, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, routeTableName, + []string{"name", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + + route, ok := res.(*Route) + if !ok { + panic(unexpectedType) + } + return &Route{Route: *route.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a route by name or ID. +func (k *RoutesCollection) Get(nameOrID string) (*Route, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + route, err := getRoute(txn, nameOrID) + if err != nil { + return nil, err + } + return route, nil +} + +// Update updates a route +func (k *RoutesCollection) Update(route Route) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(route.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteRoute(txn, *route.ID) + if err != nil { + return err + } + + err = txn.Insert(routeTableName, &route) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteRoute(txn *memdb.Txn, nameOrID string) error { + route, err := getRoute(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(routeTableName, route) + if err != nil { + return err + } + return nil +} + +// Delete deletes a route by name or ID. +func (k *RoutesCollection) Delete(nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteRoute(txn, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a route by name or ID. +func (k *RoutesCollection) GetAll() ([]*Route, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(routeTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Route + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*Route) + if !ok { + panic(unexpectedType) + } + res = append(res, &Route{Route: *r.DeepCopy()}) + } + txn.Commit() + return res, nil +} + +// GetAllByServiceID returns all routes referencing a service +// by its id. +func (k *RoutesCollection) GetAllByServiceID(id string) ([]*Route, + error, +) { + txn := k.db.Txn(false) + iter, err := txn.Get(routeTableName, routesByServiceID, id) + if err != nil { + return nil, err + } + var res []*Route + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*Route) + if !ok { + panic(unexpectedType) + } + res = append(res, &Route{Route: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/route_test.go b/pkg/state/route_test.go new file mode 100644 index 0000000..f85a672 --- /dev/null +++ b/pkg/state/route_test.go @@ -0,0 +1,439 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func routesCollection() *RoutesCollection { + return state().Routes +} + +func TestRoutesCollection_Add(t *testing.T) { + type args struct { + route Route + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + route: Route{ + Route: kong.Route{ + Name: kong.String("foo"), + Hosts: kong.StringSlice("example.com"), + }, + }, + }, + wantErr: true, + }, + { + name: "inserts without a name", + args: args{ + route: Route{ + Route: kong.Route{ + ID: kong.String("id1"), + Hosts: kong.StringSlice("example.com"), + }, + }, + }, + wantErr: false, + }, + { + name: "inserts with a name and ID", + args: args{ + route: Route{ + Route: kong.Route{ + ID: kong.String("id2"), + Name: kong.String("bar-name"), + Hosts: kong.StringSlice("example.com"), + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert when name is present", + args: args{ + route: Route{ + Route: kong.Route{ + ID: kong.String("id4"), + Name: kong.String("foo-name"), + Hosts: kong.StringSlice("example.com"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when id is present", + args: args{ + route: Route{ + Route: kong.Route{ + ID: kong.String("id3"), + Name: kong.String("foobar-name"), + Hosts: kong.StringSlice("example.com"), + }, + }, + }, + wantErr: true, + }, + } + k := routesCollection() + route1 := Route{ + Route: kong.Route{ + ID: kong.String("id3"), + Name: kong.String("foo-name"), + Hosts: kong.StringSlice("example.com"), + }, + } + k.Add(route1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.route); (err != nil) != tt.wantErr { + t.Errorf("RoutesCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRoutesCollection_Get(t *testing.T) { + type args struct { + nameOrID string + } + route1 := Route{ + Route: kong.Route{ + ID: kong.String("foo-id"), + Hosts: kong.StringSlice("example.com"), + }, + } + route2 := Route{ + Route: kong.Route{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + Hosts: kong.StringSlice("example.com"), + }, + } + tests := []struct { + name string + args args + want *Route + wantErr bool + }{ + { + name: "gets a route by ID", + args: args{ + nameOrID: "foo-id", + }, + want: &route1, + wantErr: false, + }, + { + name: "gets a route by Name", + args: args{ + nameOrID: "bar-name", + }, + want: &route2, + wantErr: false, + }, + { + name: "returns an ErrNotFound when no route found", + args: args{ + nameOrID: "baz-id", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + nameOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := routesCollection() + k.Add(route1) + k.Add(route2) + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := k.Get(tc.args.nameOrID) + if (err != nil) != tc.wantErr { + t.Errorf("RoutesCollection.Get() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("RoutesCollection.Get() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestRoutesInvalidType(t *testing.T) { + assert := assert.New(t) + + collection := routesCollection() + + var service Service + service.Name = kong.String("my-service") + service.ID = kong.String("first") + txn := collection.db.Txn(true) + txn.Insert(routeTableName, &service) + txn.Commit() + + assert.Panics(func() { + collection.Get("my-service") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestRoutesCollection_Update(t *testing.T) { + route1 := Route{ + Route: kong.Route{ + ID: kong.String("foo-id"), + Hosts: kong.StringSlice("example.com"), + }, + } + route2 := Route{ + Route: kong.Route{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + Hosts: kong.StringSlice("example.com"), + }, + } + route3 := Route{ + Route: kong.Route{ + ID: kong.String("foo-id"), + Name: kong.String("name"), + Hosts: kong.StringSlice("example.com"), + }, + } + type args struct { + route Route + } + tests := []struct { + name string + args args + wantErr bool + updatedRoute *Route + }{ + { + name: "update errors if route.ID is nil", + args: args{ + route: Route{ + Route: kong.Route{ + Name: kong.String("name"), + }, + }, + }, + wantErr: true, + }, + { + name: "update errors if route does not exist", + args: args{ + route: Route{ + Route: kong.Route{ + ID: kong.String("does-not-exist"), + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + route: route3, + }, + wantErr: false, + updatedRoute: &route3, + }, + } + k := routesCollection() + k.Add(route1) + k.Add(route2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.route); (err != nil) != tt.wantErr { + t.Errorf("RoutesCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.Get(*tt.updatedRoute.ID) + + if !reflect.DeepEqual(got, tt.updatedRoute) { + t.Errorf("update route, got = %#v, want %#v", got, tt.updatedRoute) + } + } + }) + } +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestRouteGetMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := routesCollection() + + var route Route + route.Name = kong.String("my-route") + route.ID = kong.String("first") + route.Hosts = kong.StringSlice("example.com", "demo.example.com") + route.Service = &kong.Service{ + ID: kong.String("service1-id"), + } + assert.NotNil(route.Service) + err := collection.Add(route) + assert.NotNil(route.Service) + assert.Nil(err) + + re, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(re) + assert.Equal("my-route", *re.Name) + + re.SNIs = kong.StringSlice("example.com", "demo.example.com") + + re, err = collection.Get("my-route") + assert.Nil(err) + assert.NotNil(re) + assert.Nil(re.SNIs) +} + +func TestRouteDelete(t *testing.T) { + assert := assert.New(t) + collection := routesCollection() + + var route Route + route.Name = kong.String("my-route") + route.ID = kong.String("first") + route.Hosts = kong.StringSlice("example.com", "demo.example.com") + route.Service = &kong.Service{ + ID: kong.String("service1-id"), + } + err := collection.Add(route) + assert.Nil(err) + + re, err := collection.Get("my-route") + assert.Nil(err) + assert.NotNil(re) + assert.Equal("example.com", *re.Hosts[0]) + + err = collection.Delete(*re.ID) + assert.Nil(err) + + err = collection.Delete(*re.ID) + assert.NotNil(err) +} + +func TestRouteGetAll(t *testing.T) { + assert := assert.New(t) + collection := routesCollection() + + var route Route + route.Name = kong.String("my-route1") + route.ID = kong.String("first") + route.Hosts = kong.StringSlice("example.com", "demo.example.com") + route.Service = &kong.Service{ + ID: kong.String("service1-id"), + } + err := collection.Add(route) + assert.Nil(err) + + var route2 Route + route2.Name = kong.String("my-route2") + route2.ID = kong.String("second") + route2.Hosts = kong.StringSlice("example.com", "demo.example.com") + route2.Service = &kong.Service{ + ID: kong.String("service1-id"), + } + err = collection.Add(route2) + assert.Nil(err) + + routes, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(routes)) +} + +func TestRouteGetAllByServiceID(t *testing.T) { + assert := assert.New(t) + collection := routesCollection() + + routes := []*Route{ + { + Route: kong.Route{ + ID: kong.String("route0-id"), + }, + }, + { + Route: kong.Route{ + ID: kong.String("route1-id"), + Name: kong.String("route1-name"), + Service: &kong.Service{ + ID: kong.String("service1-id"), + }, + }, + }, + { + Route: kong.Route{ + ID: kong.String("route2-id"), + Service: &kong.Service{ + ID: kong.String("service1-id"), + }, + }, + }, + { + Route: kong.Route{ + ID: kong.String("route3-id"), + Name: kong.String("route3-name"), + Service: &kong.Service{ + ID: kong.String("service2-id"), + }, + }, + }, + { + Route: kong.Route{ + ID: kong.String("route4-id"), + Name: kong.String("route4-name"), + Service: &kong.Service{ + ID: kong.String("service2-id"), + }, + }, + }, + { + Route: kong.Route{ + ID: kong.String("route5-id"), + Service: &kong.Service{ + ID: kong.String("service2-id"), + }, + }, + }, + } + + for _, route := range routes { + err := collection.Add(*route) + assert.Nil(err) + } + + routes, err := collection.GetAllByServiceID("service1-id") + assert.Nil(err) + assert.Equal(2, len(routes)) + + routes, err = collection.GetAllByServiceID("service2-id") + assert.Nil(err) + assert.Equal(3, len(routes)) +} diff --git a/pkg/state/service.go b/pkg/state/service.go new file mode 100644 index 0000000..5b6d05a --- /dev/null +++ b/pkg/state/service.go @@ -0,0 +1,172 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/utils" +) + +const ( + serviceTableName = "service" +) + +var serviceTableSchema = &memdb.TableSchema{ + Name: serviceTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + AllowMissing: true, + }, + all: allIndex, + }, +} + +// ServicesCollection stores and indexes Kong Services. +type ServicesCollection collection + +// Add adds a service to the collection. +// service.ID should not be nil else an error is thrown. +func (k *ServicesCollection) Add(service Service) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(service.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *service.ID) + if !utils.Empty(service.Name) { + searchBy = append(searchBy, *service.Name) + } + _, err := getService(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting service %v: %w", service.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(serviceTableName, &service) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getService(txn *memdb.Txn, IDs ...string) (*Service, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, serviceTableName, + []string{"name", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + service, ok := res.(*Service) + if !ok { + panic(unexpectedType) + } + return &Service{Service: *service.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a service by name or ID. +func (k *ServicesCollection) Get(nameOrID string) (*Service, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + return getService(txn, nameOrID) +} + +// Update udpates an existing service. +// It returns an error if the service is not already present. +func (k *ServicesCollection) Update(service Service) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(service.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteService(txn, *service.ID) + if err != nil { + return err + } + + err = txn.Insert(serviceTableName, &service) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteService(txn *memdb.Txn, nameOrID string) error { + service, err := getService(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(serviceTableName, service) + if err != nil { + return err + } + return nil +} + +// Delete deletes a service by name or ID. +func (k *ServicesCollection) Delete(nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteService(txn, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll returns all the services. +func (k *ServicesCollection) GetAll() ([]*Service, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(serviceTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Service + for el := iter.Next(); el != nil; el = iter.Next() { + s, ok := el.(*Service) + if !ok { + panic(unexpectedType) + } + res = append(res, &Service{Service: *s.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/service_package.go b/pkg/state/service_package.go new file mode 100644 index 0000000..c10dcfd --- /dev/null +++ b/pkg/state/service_package.go @@ -0,0 +1,171 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/utils" +) + +const ( + servicePackageTableName = "service-package" +) + +var servicePackageTableSchema = &memdb.TableSchema{ + Name: servicePackageTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + }, + all: allIndex, + }, +} + +// ServicePackagesCollection stores and indexes Kong Services. +type ServicePackagesCollection collection + +// Add adds a servicePackage to the collection. +// service.ID should not be nil else an error is thrown. +func (k *ServicePackagesCollection) Add(servicePackage ServicePackage) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(servicePackage.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *servicePackage.ID) + if !utils.Empty(servicePackage.Name) { + searchBy = append(searchBy, *servicePackage.Name) + } + _, err := getServicePackage(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting servicePackage %v: %w", servicePackage.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(servicePackageTableName, &servicePackage) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getServicePackage(txn *memdb.Txn, IDs ...string) (*ServicePackage, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, servicePackageTableName, + []string{"name", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + servicePackage, ok := res.(*ServicePackage) + if !ok { + panic(unexpectedType) + } + return &ServicePackage{ServicePackage: *servicePackage.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a servicePackage by name or ID. +func (k *ServicePackagesCollection) Get(nameOrID string) (*ServicePackage, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + return getServicePackage(txn, nameOrID) +} + +// Update udpates an existing service. +// It returns an error if the servicePackage is not already present. +func (k *ServicePackagesCollection) Update(servicePackage ServicePackage) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(servicePackage.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteServicePackage(txn, *servicePackage.ID) + if err != nil { + return err + } + + err = txn.Insert(servicePackageTableName, &servicePackage) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteServicePackage(txn *memdb.Txn, nameOrID string) error { + servicePackage, err := getServicePackage(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(servicePackageTableName, servicePackage) + if err != nil { + return err + } + return nil +} + +// Delete deletes a servicePackage by name or ID. +func (k *ServicePackagesCollection) Delete(nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteServicePackage(txn, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll returns all the servicePackages. +func (k *ServicePackagesCollection) GetAll() ([]*ServicePackage, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(servicePackageTableName, all, true) + if err != nil { + return nil, err + } + + var res []*ServicePackage + for el := iter.Next(); el != nil; el = iter.Next() { + s, ok := el.(*ServicePackage) + if !ok { + panic(unexpectedType) + } + res = append(res, &ServicePackage{ServicePackage: *s.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/service_package_test.go b/pkg/state/service_package_test.go new file mode 100644 index 0000000..431f79b --- /dev/null +++ b/pkg/state/service_package_test.go @@ -0,0 +1,390 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/deck/konnect" + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func servicePackagesCollection() *ServicePackagesCollection { + return state().ServicePackages +} + +func TestServicePackagesCollection_Add(t *testing.T) { + type args struct { + servicePackage ServicePackage + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + servicePackage: ServicePackage{ + ServicePackage: konnect.ServicePackage{ + Name: kong.String("foo"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors out without a name", + args: args{ + servicePackage: ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + wantErr: true, + }, + { + name: "inserts with a name and ID", + args: args{ + servicePackage: ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("id2"), + Name: kong.String("foo-name"), + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert by ID", + args: args{ + servicePackage: ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("id3"), + Name: kong.String("foo-name"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert by Name", + args: args{ + servicePackage: ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("new-id"), + Name: kong.String("bar-name"), + }, + }, + }, + wantErr: true, + }, + } + k := servicePackagesCollection() + svc1 := ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("id3"), + Name: kong.String("bar-name"), + }, + } + k.Add(svc1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.servicePackage); (err != nil) != tt.wantErr { + t.Errorf("ServicePackageCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestServicePackagesCollection_Get(t *testing.T) { + type args struct { + nameOrID string + } + svc1 := ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + } + svc2 := ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + }, + } + tests := []struct { + name string + args args + want *ServicePackage + wantErr bool + }{ + { + name: "gets a servicePackage by ID", + args: args{ + nameOrID: "foo-id", + }, + want: &svc1, + wantErr: false, + }, + { + name: "gets a servicePackage by Name", + args: args{ + nameOrID: "bar-name", + }, + want: &svc2, + wantErr: false, + }, + { + name: "returns an ErrNotFound when no servicePackage found", + args: args{ + nameOrID: "baz-id", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + nameOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := servicePackagesCollection() + k.Add(svc1) + k.Add(svc2) + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := k.Get(tc.args.nameOrID) + if (err != nil) != tc.wantErr { + t.Errorf("ServicePackageCollection.Get() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("ServicePackageCollection.Get() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestServicePackagesCollection_Update(t *testing.T) { + svc1 := ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + } + svc2 := ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + }, + } + svc3 := ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("foo-id"), + Name: kong.String("name"), + }, + } + type args struct { + servicePackage ServicePackage + } + tests := []struct { + name string + args args + wantErr bool + updatedService *ServicePackage + }{ + { + name: "update errors if servicePackage.ID is nil", + args: args{ + servicePackage: ServicePackage{ + ServicePackage: konnect.ServicePackage{ + Name: kong.String("name"), + }, + }, + }, + wantErr: true, + }, + { + name: "update errors if servicePackage does not exist", + args: args{ + servicePackage: ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("does-not-exist"), + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + servicePackage: svc3, + }, + wantErr: false, + updatedService: &svc3, + }, + } + k := servicePackagesCollection() + k.Add(svc1) + k.Add(svc2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.servicePackage); (err != nil) != tt.wantErr { + t.Errorf("ServicePackageCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.Get(*tt.updatedService.ID) + + if !reflect.DeepEqual(got, tt.updatedService) { + t.Errorf("update servicePackage, got = %#v, want %#v", got, tt.updatedService) + } + } + }) + } +} + +func TestServicePackageUpdate(t *testing.T) { + assert := assert.New(t) + k := servicePackagesCollection() + svc1 := ServicePackage{ + ServicePackage: konnect.ServicePackage{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + } + assert.Nil(k.Add(svc1)) + + svc1.Name = kong.String("bar-name") + assert.Nil(k.Update(svc1)) + + r, err := k.Get("foo-id") + assert.Nil(err) + assert.NotNil(r) + + r, err = k.Get("bar-name") + assert.Nil(err) + assert.NotNil(r) + + r, err = k.Get("foo-name") + assert.NotNil(err) + assert.Nil(r) +} + +func TestServicePackagesInvalidType(t *testing.T) { + assert := assert.New(t) + collection := servicePackagesCollection() + + var route Route + route.Name = kong.String("my-route") + route.ID = kong.String("first") + txn := collection.db.Txn(true) + txn.Insert(servicePackageTableName, &route) + txn.Commit() + + assert.Panics(func() { + collection.Get("my-route") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestServicePackageDelete(t *testing.T) { + assert := assert.New(t) + collection := servicePackagesCollection() + + var servicePackage ServicePackage + servicePackage.ID = kong.String("first-id") + servicePackage.Name = kong.String("first-name") + err := collection.Add(servicePackage) + assert.Nil(err) + + err = collection.Delete("does-not-exist") + assert.NotNil(err) + err = collection.Delete("first-id") + assert.Nil(err) + + err = collection.Delete("first-name") + assert.NotNil(err) + + err = collection.Delete("") + assert.NotNil(err) +} + +func TestServicePackageGetAll(t *testing.T) { + assert := assert.New(t) + collection := servicePackagesCollection() + + services := []ServicePackage{ + { + ServicePackage: konnect.ServicePackage{ + ID: kong.String("first"), + Name: kong.String("my-service1"), + }, + }, + { + ServicePackage: konnect.ServicePackage{ + ID: kong.String("second"), + Name: kong.String("my-service2"), + }, + }, + } + for _, s := range services { + assert.Nil(collection.Add(s)) + } + + allServices, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(len(services), len(allServices)) +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestServicePackagesGetAllMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := servicePackagesCollection() + + services := []ServicePackage{ + { + ServicePackage: konnect.ServicePackage{ + ID: kong.String("first"), + Name: kong.String("my-service1"), + Description: kong.String("service1-desc"), + }, + }, + { + ServicePackage: konnect.ServicePackage{ + ID: kong.String("second"), + Name: kong.String("my-service2"), + Description: kong.String("service2-desc"), + }, + }, + } + for _, s := range services { + assert.Nil(collection.Add(s)) + } + + allServices, err := collection.GetAll() + assert.Nil(err) + assert.Equal(len(services), len(allServices)) + + allServices[0].Description = kong.String("new-service1-desc") + allServices[1].Description = kong.String("new-service2-desc") + + servicePackage, err := collection.Get("my-service1") + assert.Nil(err) + assert.NotNil(servicePackage) + assert.Equal("service1-desc", *servicePackage.Description) +} diff --git a/pkg/state/service_test.go b/pkg/state/service_test.go new file mode 100644 index 0000000..6624683 --- /dev/null +++ b/pkg/state/service_test.go @@ -0,0 +1,426 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func servicesCollection() *ServicesCollection { + return state().Services +} + +func TestServicesCollection_Add(t *testing.T) { + type args struct { + service Service + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + service: Service{ + Service: kong.Service{ + Name: kong.String("foo"), + Host: kong.String("example.com"), + }, + }, + }, + wantErr: true, + }, + { + name: "inserts without a name", + args: args{ + service: Service{ + Service: kong.Service{ + ID: kong.String("id1"), + Host: kong.String("example.com"), + }, + }, + }, + wantErr: false, + }, + { + name: "inserts with a name and ID", + args: args{ + service: Service{ + Service: kong.Service{ + ID: kong.String("id2"), + Name: kong.String("foo-name"), + Host: kong.String("example.com"), + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert by ID", + args: args{ + service: Service{ + Service: kong.Service{ + ID: kong.String("id3"), + Name: kong.String("foo-name"), + Host: kong.String("example.com"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert by Name", + args: args{ + service: Service{ + Service: kong.Service{ + ID: kong.String("new-id"), + Name: kong.String("bar-name"), + Host: kong.String("example.com"), + }, + }, + }, + wantErr: true, + }, + } + k := servicesCollection() + svc1 := Service{ + Service: kong.Service{ + ID: kong.String("id3"), + Name: kong.String("bar-name"), + Host: kong.String("example.com"), + }, + } + k.Add(svc1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.service); (err != nil) != tt.wantErr { + t.Errorf("ServicesCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestServicesCollection_Get(t *testing.T) { + type args struct { + nameOrID string + } + svc1 := Service{ + Service: kong.Service{ + ID: kong.String("foo-id"), + Host: kong.String("example.com"), + }, + } + svc2 := Service{ + Service: kong.Service{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + Host: kong.String("example.com"), + }, + } + tests := []struct { + name string + args args + want *Service + wantErr bool + }{ + { + name: "gets a service by ID", + args: args{ + nameOrID: "foo-id", + }, + want: &svc1, + wantErr: false, + }, + { + name: "gets a service by Name", + args: args{ + nameOrID: "bar-name", + }, + want: &svc2, + wantErr: false, + }, + { + name: "returns an ErrNotFound when no service found", + args: args{ + nameOrID: "baz-id", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + nameOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := servicesCollection() + k.Add(svc1) + k.Add(svc2) + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := k.Get(tc.args.nameOrID) + if (err != nil) != tc.wantErr { + t.Errorf("ServicesCollection.Get() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("ServicesCollection.Get() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestServicesCollection_Update(t *testing.T) { + svc1 := Service{ + Service: kong.Service{ + ID: kong.String("foo-id"), + Host: kong.String("example.com"), + }, + } + svc2 := Service{ + Service: kong.Service{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + Host: kong.String("example.com"), + }, + } + svc3 := Service{ + Service: kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("name"), + Host: kong.String("2.example.com"), + Port: kong.Int(42), + }, + } + type args struct { + service Service + } + tests := []struct { + name string + args args + wantErr bool + updatedService *Service + }{ + { + name: "update errors if service.ID is nil", + args: args{ + service: Service{ + Service: kong.Service{ + Name: kong.String("name"), + }, + }, + }, + wantErr: true, + }, + { + name: "update errors if service does not exist", + args: args{ + service: Service{ + Service: kong.Service{ + ID: kong.String("does-not-exist"), + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + service: svc3, + }, + wantErr: false, + updatedService: &svc3, + }, + } + k := servicesCollection() + k.Add(svc1) + k.Add(svc2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.service); (err != nil) != tt.wantErr { + t.Errorf("ServicesCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.Get(*tt.updatedService.ID) + + if !reflect.DeepEqual(got, tt.updatedService) { + t.Errorf("update service, got = %#v, want %#v", got, tt.updatedService) + } + } + }) + } +} + +func TestServiceUpdate(t *testing.T) { + assert := assert.New(t) + k := servicesCollection() + svc1 := Service{ + Service: kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + Host: kong.String("example.com"), + }, + } + assert.Nil(k.Add(svc1)) + + svc1.Name = kong.String("bar-name") + assert.Nil(k.Update(svc1)) + + r, err := k.Get("foo-id") + assert.Nil(err) + assert.NotNil(r) + + r, err = k.Get("bar-name") + assert.Nil(err) + assert.NotNil(r) + + r, err = k.Get("foo-name") + assert.NotNil(err) + assert.Nil(r) +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestServiceGetMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := servicesCollection() + + var service Service + service.Name = kong.String("my-service") + service.ID = kong.String("first") + err := collection.Add(service) + assert.Nil(err) + + se, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(se) + se.Host = kong.String("example.com") + + se, err = collection.Get("my-service") + assert.Nil(err) + assert.NotNil(se) + assert.Nil(se.Host) +} + +func TestServicesInvalidType(t *testing.T) { + assert := assert.New(t) + collection := servicesCollection() + + var route Route + route.Name = kong.String("my-route") + route.ID = kong.String("first") + txn := collection.db.Txn(true) + txn.Insert(serviceTableName, &route) + txn.Commit() + + assert.Panics(func() { + collection.Get("my-route") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestServiceDelete(t *testing.T) { + assert := assert.New(t) + collection := servicesCollection() + + var service Service + service.ID = kong.String("first") + service.Host = kong.String("example.com") + err := collection.Add(service) + assert.Nil(err) + + err = collection.Delete("does-not-exist") + assert.NotNil(err) + err = collection.Delete("first") + assert.Nil(err) + + err = collection.Delete("first") + assert.NotNil(err) + + err = collection.Delete("") + assert.NotNil(err) +} + +func TestServiceGetAll(t *testing.T) { + assert := assert.New(t) + collection := servicesCollection() + + services := []Service{ + { + Service: kong.Service{ + ID: kong.String("first"), + Name: kong.String("my-service1"), + Host: kong.String("example.com"), + }, + }, + { + Service: kong.Service{ + ID: kong.String("second"), + Name: kong.String("my-service2"), + Host: kong.String("example.com"), + }, + }, + } + for _, s := range services { + assert.Nil(collection.Add(s)) + } + + allServices, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(len(services), len(allServices)) +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestServiceGetAllMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := servicesCollection() + + services := []Service{ + { + Service: kong.Service{ + ID: kong.String("first"), + Name: kong.String("my-service1"), + Host: kong.String("example.com"), + }, + }, + { + Service: kong.Service{ + ID: kong.String("second"), + Name: kong.String("my-service2"), + Host: kong.String("example.com"), + }, + }, + } + for _, s := range services { + assert.Nil(collection.Add(s)) + } + + allServices, err := collection.GetAll() + assert.Nil(err) + assert.Equal(len(services), len(allServices)) + + allServices[0].Host = kong.String("new.example.com") + allServices[1].Host = kong.String("new.example.com") + + service, err := collection.Get("my-service1") + assert.Nil(err) + assert.NotNil(service) + assert.Equal("example.com", *service.Host) +} diff --git a/pkg/state/service_version.go b/pkg/state/service_version.go new file mode 100644 index 0000000..af0b981 --- /dev/null +++ b/pkg/state/service_version.go @@ -0,0 +1,226 @@ +package state + +import ( + "errors" + "fmt" + + "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + serviceVersionTableName = "service-version" + versionsByServicePackageID = "serviceVersionsByServicePackageID" +) + +var errInvalidPackage = fmt.Errorf("servicePackage.ID is required in ServiceVersion") + +var serviceVersionTableSchema = &memdb.TableSchema{ + Name: serviceVersionTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + all: allIndex, + // foreign + versionsByServicePackageID: { + Name: versionsByServicePackageID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "ServicePackage", + Sub: "ID", + }, + }, + }, + }, + }, +} + +func validatePackage(version ServiceVersion) error { + if version.ServicePackage == nil || + utils.Empty(version.ServicePackage.ID) { + return errInvalidPackage + } + return nil +} + +// ServiceVersionsCollection stores and indexes Service Versions. +type ServiceVersionsCollection collection + +// Add adds a serviceVersion into ServiceVersionsCollection +// serviceVersion.ID should not be nil else an error is thrown. +func (k *ServiceVersionsCollection) Add(serviceVersion ServiceVersion) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(serviceVersion.ID) { + return errIDRequired + } + + if err := validatePackage(serviceVersion); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *serviceVersion.ID) + if !utils.Empty(serviceVersion.Version) { + searchBy = append(searchBy, *serviceVersion.Version) + } + _, err := getServiceVersion(txn, *serviceVersion.ServicePackage.ID, searchBy...) + if err == nil { + return fmt.Errorf("inserting serviceVersion %v: %w", serviceVersion.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(serviceVersionTableName, &serviceVersion) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getServiceVersion(txn *memdb.Txn, packageID string, IDs ...string) (*ServiceVersion, error) { + if packageID == "" { + return nil, fmt.Errorf("packageID is required") + } + versions, err := getAllByPackageID(txn, packageID) + if err != nil { + return nil, err + } + + for _, id := range IDs { + for _, version := range versions { + if id == *version.ID || id == *version.Version { + return &ServiceVersion{ServiceVersion: *version.DeepCopy()}, nil + } + } + } + return nil, ErrNotFound +} + +// Get gets a Service Version by name or ID. +func (k *ServiceVersionsCollection) Get(packageID, nameOrID string) (*ServiceVersion, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + serviceVersion, err := getServiceVersion(txn, packageID, nameOrID) + if err != nil { + return nil, err + } + return serviceVersion, nil +} + +// Update updates a Service Version. +func (k *ServiceVersionsCollection) Update(serviceVersion ServiceVersion) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(serviceVersion.ID) { + return errIDRequired + } + if err := validatePackage(serviceVersion); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteServiceVersion(txn, *serviceVersion.ServicePackage.ID, *serviceVersion.ID) + if err != nil { + return err + } + + err = txn.Insert(serviceVersionTableName, &serviceVersion) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteServiceVersion(txn *memdb.Txn, packageID, nameOrID string) error { + serviceVersion, err := getServiceVersion(txn, packageID, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(serviceVersionTableName, serviceVersion) + if err != nil { + return err + } + return nil +} + +// Delete deletes a serviceVersion by name or ID. +func (k *ServiceVersionsCollection) Delete(packageID, nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteServiceVersion(txn, packageID, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all serviceVersios. +func (k *ServiceVersionsCollection) GetAll() ([]*ServiceVersion, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(serviceVersionTableName, all, true) + if err != nil { + return nil, err + } + + var res []*ServiceVersion + for el := iter.Next(); el != nil; el = iter.Next() { + s, ok := el.(*ServiceVersion) + if !ok { + panic(unexpectedType) + } + res = append(res, &ServiceVersion{ServiceVersion: *s.DeepCopy()}) + } + txn.Commit() + return res, nil +} + +func getAllByPackageID(txn *memdb.Txn, packageID string) ([]*ServiceVersion, error) { + iter, err := txn.Get(serviceVersionTableName, versionsByServicePackageID, packageID) + if err != nil { + return nil, err + } + + var versions []*ServiceVersion + for el := iter.Next(); el != nil; el = iter.Next() { + v, ok := el.(*ServiceVersion) + if !ok { + panic(unexpectedType) + } + versions = append(versions, &ServiceVersion{ServiceVersion: *v.DeepCopy()}) + } + return versions, nil +} + +// GetAllByServicePackageID returns all serviceVersions for a servicePackage id. +func (k *ServiceVersionsCollection) GetAllByServicePackageID(id string) ([]*ServiceVersion, + error, +) { + txn := k.db.Txn(false) + return getAllByPackageID(txn, id) +} diff --git a/pkg/state/service_version_test.go b/pkg/state/service_version_test.go new file mode 100644 index 0000000..8e93506 --- /dev/null +++ b/pkg/state/service_version_test.go @@ -0,0 +1,415 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/deck/konnect" + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func serviceVersionCollection() *ServiceVersionsCollection { + return state().ServiceVersions +} + +func TestServiceVersionCollection_Add(t *testing.T) { + type args struct { + serviceVersion ServiceVersion + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + Version: kong.String("foo"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors without a version", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("id1"), + }, + }, + }, + wantErr: true, + }, + { + name: "errors without a ServicePackage", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("id1"), + Version: kong.String("bar-name"), + }, + }, + }, + wantErr: true, + }, + { + name: "inserts with all valid details", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("id2"), + Version: kong.String("bar-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert when version is already present", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("id4"), + Version: kong.String("foo-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when id is present", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("id3"), + Version: kong.String("foobar-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + }, + wantErr: true, + }, + } + k := serviceVersionCollection() + sv1 := ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("id3"), + Version: kong.String("foo-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + k.Add(sv1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.serviceVersion); (err != nil) != tt.wantErr { + t.Errorf("ServiceVersionCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestServiceVersionCollection_Get(t *testing.T) { + type args struct { + nameOrID string + packageID string + } + sv1 := ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("foo-id"), + Version: kong.String("foo-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + sv2 := ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("bar-id"), + Version: kong.String("bar-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + tests := []struct { + name string + args args + want *ServiceVersion + wantErr bool + }{ + { + name: "gets a serviceVersion by package and version ID", + args: args{ + nameOrID: "foo-id", + packageID: "id1", + }, + want: &sv1, + wantErr: false, + }, + { + name: "returns an error when only version is specified", + args: args{ + nameOrID: "bar-name", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an ErrNotFound when no serviceVersion found", + args: args{ + nameOrID: "baz-id", + packageID: "id1", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + nameOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := serviceVersionCollection() + k.Add(sv1) + k.Add(sv2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := k.Get(tt.args.packageID, tt.args.nameOrID) + if (err != nil) != tt.wantErr { + t.Errorf("ServiceVersionCollection.Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ServiceVersionCollection.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestServiceVersionCollection_Update(t *testing.T) { + sv1 := ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("foo-id"), + Version: kong.String("foo-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + sv2 := ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("bar-id"), + Version: kong.String("bar-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + sv3 := ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("foo-id"), + Version: kong.String("new-foo-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + } + type args struct { + serviceVersion ServiceVersion + } + tests := []struct { + name string + args args + wantErr bool + updatedVersion *ServiceVersion + }{ + { + name: "update errors if serviceVersion.ID is nil", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + Version: kong.String("name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "update errors if serviceVersion does not exist", + args: args{ + serviceVersion: ServiceVersion{ + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("does-not-exist"), + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + serviceVersion: sv3, + }, + wantErr: false, + updatedVersion: &sv3, + }, + } + k := serviceVersionCollection() + k.Add(sv1) + k.Add(sv2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.serviceVersion); (err != nil) != tt.wantErr { + t.Errorf("ServiceVersionCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.Get(*tt.updatedVersion.ServicePackage.ID, *tt.updatedVersion.ID) + + if !reflect.DeepEqual(got, tt.updatedVersion) { + t.Errorf("update serviceVersion, got = %#v, want %#v", got, tt.updatedVersion) + } + } + }) + } +} + +func TestServiceVersionDelete(t *testing.T) { + assert := assert.New(t) + collection := serviceVersionCollection() + + var serviceVersion ServiceVersion + serviceVersion.Version = kong.String("my-serviceVersion") + serviceVersion.ID = kong.String("first") + serviceVersion.ServicePackage = &konnect.ServicePackage{ + ID: kong.String("package-id1"), + } + err := collection.Add(serviceVersion) + assert.Nil(err) + + re, err := collection.Get("package-id1", "my-serviceVersion") + assert.Nil(err) + assert.NotNil(re) + + err = collection.Delete("package-id1", *re.ID) + assert.Nil(err) + + err = collection.Delete("package-id1", *re.ID) + assert.NotNil(err) +} + +func TestServiceVersionGetAll(t *testing.T) { + assert := assert.New(t) + collection := serviceVersionCollection() + + var serviceVersion ServiceVersion + serviceVersion.Version = kong.String("my-sv1") + serviceVersion.ID = kong.String("first") + serviceVersion.ServicePackage = &konnect.ServicePackage{ + ID: kong.String("id1"), + } + err := collection.Add(serviceVersion) + assert.Nil(err) + + var sv2 ServiceVersion + sv2.Version = kong.String("my-sv2") + sv2.ID = kong.String("second") + sv2.ServicePackage = &konnect.ServicePackage{ + ID: kong.String("id1"), + } + err = collection.Add(sv2) + assert.Nil(err) + + serviceVersions, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(serviceVersions)) +} + +func TestServiceVersionGetAllByServiceID(t *testing.T) { + assert := assert.New(t) + collection := serviceVersionCollection() + + serviceVersions := []*ServiceVersion{ + { + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("sv1-id"), + Version: kong.String("sv1-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + { + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("sv2-id"), + Version: kong.String("sv2-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id1"), + }, + }, + }, + { + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("sv3-id"), + Version: kong.String("sv3-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id2"), + }, + }, + }, + { + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("sv4-id"), + Version: kong.String("sv4-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id2"), + }, + }, + }, + { + ServiceVersion: konnect.ServiceVersion{ + ID: kong.String("sv5-id"), + Version: kong.String("sv5-name"), + ServicePackage: &konnect.ServicePackage{ + ID: kong.String("id2"), + }, + }, + }, + } + + for _, serviceVersion := range serviceVersions { + err := collection.Add(*serviceVersion) + assert.Nil(err) + } + + serviceVersions, err := collection.GetAllByServicePackageID("id1") + assert.Nil(err) + assert.Equal(2, len(serviceVersions)) + + serviceVersions, err = collection.GetAllByServicePackageID("id2") + assert.Nil(err) + assert.Equal(3, len(serviceVersions)) +} diff --git a/pkg/state/sni.go b/pkg/state/sni.go new file mode 100644 index 0000000..280dae9 --- /dev/null +++ b/pkg/state/sni.go @@ -0,0 +1,226 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + sniTableName = "sni" + snisByCertID = "snisByCertID" +) + +var errInvalidCert = fmt.Errorf("certificate.ID is required in sni") + +var sniTableSchema = &memdb.TableSchema{ + Name: sniTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + AllowMissing: true, + }, + all: allIndex, + // foreign + snisByCertID: { + Name: snisByCertID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Certificate", + Sub: "ID", + }, + }, + }, + }, + }, +} + +func validateCertForSNI(sni *SNI) error { + if sni.Certificate == nil || + utils.Empty(sni.Certificate.ID) { + return errInvalidCert + } + return nil +} + +// SNIsCollection stores and indexes Kong SNIs. +type SNIsCollection collection + +// Add adds a sni into SNIsCollection +// sni.ID should not be nil else an error is thrown. +func (k *SNIsCollection) Add(sni SNI) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(sni.ID) { + return errIDRequired + } + + if err := validateCertForSNI(&sni); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *sni.ID) + if !utils.Empty(sni.Name) { + searchBy = append(searchBy, *sni.Name) + } + _, err := getSNI(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting sni %v: %w", sni.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(sniTableName, &sni) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getSNI(txn *memdb.Txn, IDs ...string) (*SNI, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, sniTableName, + []string{"name", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + + sni, ok := res.(*SNI) + if !ok { + panic(unexpectedType) + } + return &SNI{SNI: *sni.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a sni by name or ID. +func (k *SNIsCollection) Get(nameOrID string) (*SNI, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + sni, err := getSNI(txn, nameOrID) + if err != nil { + return nil, err + } + return sni, nil +} + +// Update updates a sni +func (k *SNIsCollection) Update(sni SNI) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(sni.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteSNI(txn, *sni.ID) + if err != nil { + return err + } + + err = txn.Insert(sniTableName, &sni) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteSNI(txn *memdb.Txn, nameOrID string) error { + sni, err := getSNI(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(sniTableName, sni) + if err != nil { + return err + } + return nil +} + +// Delete deletes a sni by name or ID. +func (k *SNIsCollection) Delete(nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteSNI(txn, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a sni by name or ID. +func (k *SNIsCollection) GetAll() ([]*SNI, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(sniTableName, all, true) + if err != nil { + return nil, err + } + + var res []*SNI + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*SNI) + if !ok { + panic(unexpectedType) + } + res = append(res, &SNI{SNI: *r.DeepCopy()}) + } + txn.Commit() + return res, nil +} + +// GetAllByCertID returns all routes referencing a service +// by its id. +func (k *SNIsCollection) GetAllByCertID(id string) ([]*SNI, + error, +) { + txn := k.db.Txn(false) + iter, err := txn.Get(sniTableName, snisByCertID, id) + if err != nil { + return nil, err + } + var res []*SNI + for el := iter.Next(); el != nil; el = iter.Next() { + r, ok := el.(*SNI) + if !ok { + panic(unexpectedType) + } + res = append(res, &SNI{SNI: *r.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/sni_test.go b/pkg/state/sni_test.go new file mode 100644 index 0000000..0ec3803 --- /dev/null +++ b/pkg/state/sni_test.go @@ -0,0 +1,481 @@ +package state + +import ( + "reflect" + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func snisCollection() *SNIsCollection { + return state().SNIs +} + +func TestSNIsCollection_Add(t *testing.T) { + type args struct { + sni SNI + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "errors when ID is nil", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + Name: kong.String("foo"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "inserts without a name", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + ID: kong.String("id1"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "inserts with a name and ID", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + ID: kong.String("id2"), + Name: kong.String("bar-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: false, + }, + { + name: "errors on re-insert when name is present", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + ID: kong.String("id4"), + Name: kong.String("foo-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when id is present", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + ID: kong.String("id3"), + Name: kong.String("foobar-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "errors on re-insert when id is present", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + ID: kong.String("id3"), + Name: kong.String("foobar-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: true, + }, + } + k := snisCollection() + sni1 := SNI{ + SNI: kong.SNI{ + ID: kong.String("id3"), + Name: kong.String("foo-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + } + k.Add(sni1) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := k.Add(tt.args.sni); (err != nil) != tt.wantErr { + t.Errorf("SNIsCollection.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSNIsCollection_Get(t *testing.T) { + type args struct { + nameOrID string + } + sni1 := SNI{ + SNI: kong.SNI{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + } + sni2 := SNI{ + SNI: kong.SNI{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + } + tests := []struct { + name string + args args + want *SNI + wantErr bool + }{ + { + name: "gets a sni by ID", + args: args{ + nameOrID: "foo-id", + }, + want: &sni1, + wantErr: false, + }, + { + name: "gets a sni by Name", + args: args{ + nameOrID: "bar-name", + }, + want: &sni2, + wantErr: false, + }, + { + name: "returns an ErrNotFound when no sni found", + args: args{ + nameOrID: "baz-id", + }, + want: nil, + wantErr: true, + }, + { + name: "returns an error when ID is empty", + args: args{ + nameOrID: "", + }, + want: nil, + wantErr: true, + }, + } + k := snisCollection() + k.Add(sni1) + k.Add(sni2) + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := k.Get(tc.args.nameOrID) + if (err != nil) != tc.wantErr { + t.Errorf("SNIsCollection.Get() error = %v, wantErr %v", err, tc.wantErr) + return + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("SNIsCollection.Get() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestSNIsInvalidType(t *testing.T) { + assert := assert.New(t) + + collection := snisCollection() + + type derivedSNI struct { + SNI + } + + var sni derivedSNI + sni.SNI = SNI{ + SNI: kong.SNI{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + } + txn := collection.db.Txn(true) + txn.Insert(sniTableName, &sni) + txn.Commit() + + assert.Panics(func() { + collection.Get("foo-id") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestSNIsCollection_Update(t *testing.T) { + sni1 := SNI{ + SNI: kong.SNI{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + } + sni2 := SNI{ + SNI: kong.SNI{ + ID: kong.String("bar-id"), + Name: kong.String("bar-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + } + sni3 := SNI{ + SNI: kong.SNI{ + ID: kong.String("foo-id"), + Name: kong.String("name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + } + type args struct { + sni SNI + } + tests := []struct { + name string + args args + wantErr bool + updatedSNI *SNI + }{ + { + name: "update errors if sni.ID is nil", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + Name: kong.String("name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "update errors if sni does not exist", + args: args{ + sni: SNI{ + SNI: kong.SNI{ + ID: kong.String("does-not-exist"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "update succeeds when ID is supplied", + args: args{ + sni: sni3, + }, + wantErr: false, + updatedSNI: &sni3, + }, + } + k := snisCollection() + k.Add(sni1) + k.Add(sni2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // t.Parallel() + if err := k.Update(tt.args.sni); (err != nil) != tt.wantErr { + t.Errorf("SNIsCollection.Update() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + got, _ := k.Get(*tt.updatedSNI.ID) + + if !reflect.DeepEqual(got, tt.updatedSNI) { + t.Errorf("update sni, got = %#v, want %#v", got, tt.updatedSNI) + } + } + }) + } +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestSNIGetMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := snisCollection() + + var sni SNI + sni.Name = kong.String("my-sni") + sni.ID = kong.String("first") + sni.Certificate = &kong.Certificate{ + ID: kong.String("cert1-id"), + } + err := collection.Add(sni) + assert.Nil(err) + + re, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(re) + assert.Equal("my-sni", *re.Name) + + re, err = collection.Get("my-sni") + assert.Nil(err) + assert.NotNil(re) +} + +func TestSNIDelete(t *testing.T) { + assert := assert.New(t) + collection := snisCollection() + + var sni SNI + sni.Name = kong.String("my-sni") + sni.ID = kong.String("first") + sni.Certificate = &kong.Certificate{ + ID: kong.String("cert1-id"), + } + err := collection.Add(sni) + assert.Nil(err) + + re, err := collection.Get("my-sni") + assert.Nil(err) + assert.NotNil(re) + assert.Equal("first", *re.ID) + + err = collection.Delete(*re.ID) + assert.Nil(err) + + err = collection.Delete(*re.ID) + assert.NotNil(err) +} + +func TestSNIGetAll(t *testing.T) { + assert := assert.New(t) + collection := snisCollection() + + var sni SNI + sni.Name = kong.String("my-sni1") + sni.ID = kong.String("first") + sni.Certificate = &kong.Certificate{ + ID: kong.String("cert1-id"), + } + err := collection.Add(sni) + assert.Nil(err) + + var sni2 SNI + sni2.Name = kong.String("my-sni2") + sni2.ID = kong.String("second") + sni2.Certificate = &kong.Certificate{ + ID: kong.String("cert1-id"), + } + err = collection.Add(sni2) + assert.Nil(err) + + snis, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(snis)) +} + +func TestSNIGetAllByServiceID(t *testing.T) { + assert := assert.New(t) + collection := snisCollection() + + snis := []*SNI{ + { + SNI: kong.SNI{ + ID: kong.String("sni1-id"), + Name: kong.String("sni1-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + { + SNI: kong.SNI{ + ID: kong.String("sni2-id"), + Certificate: &kong.Certificate{ + ID: kong.String("cert1-id"), + }, + }, + }, + { + SNI: kong.SNI{ + ID: kong.String("sni3-id"), + Name: kong.String("sni3-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert2-id"), + }, + }, + }, + { + SNI: kong.SNI{ + ID: kong.String("sni4-id"), + Name: kong.String("sni4-name"), + Certificate: &kong.Certificate{ + ID: kong.String("cert2-id"), + }, + }, + }, + { + SNI: kong.SNI{ + ID: kong.String("sni5-id"), + Certificate: &kong.Certificate{ + ID: kong.String("cert2-id"), + }, + }, + }, + } + + for _, sni := range snis { + err := collection.Add(*sni) + assert.Nil(err) + } + + snis, err := collection.GetAllByCertID("cert1-id") + assert.Nil(err) + assert.Equal(2, len(snis)) + + snis, err = collection.GetAllByCertID("cert2-id") + assert.Nil(err) + assert.Equal(3, len(snis)) +} diff --git a/pkg/state/state.go b/pkg/state/state.go new file mode 100644 index 0000000..7972b77 --- /dev/null +++ b/pkg/state/state.go @@ -0,0 +1,131 @@ +package state + +import ( + "fmt" + + memdb "github.com/hashicorp/go-memdb" +) + +type collection struct { + db *memdb.MemDB +} + +// KongState is an in-memory database representation +// of Kong's configuration. +type KongState struct { + common collection + Services *ServicesCollection + Routes *RoutesCollection + Upstreams *UpstreamsCollection + Targets *TargetsCollection + Certificates *CertificatesCollection + SNIs *SNIsCollection + CACertificates *CACertificatesCollection + Plugins *PluginsCollection + Consumers *ConsumersCollection + Vaults *VaultsCollection + ConsumerGroups *ConsumerGroupsCollection + ConsumerGroupConsumers *ConsumerGroupConsumersCollection + ConsumerGroupPlugins *ConsumerGroupPluginsCollection + + KeyAuths *KeyAuthsCollection + HMACAuths *HMACAuthsCollection + JWTAuths *JWTAuthsCollection + BasicAuths *BasicAuthsCollection + ACLGroups *ACLGroupsCollection + Oauth2Creds *Oauth2CredsCollection + MTLSAuths *MTLSAuthsCollection + RBACRoles *RBACRolesCollection + RBACEndpointPermissions *RBACEndpointPermissionsCollection + + // konnect-specific entities + ServicePackages *ServicePackagesCollection + ServiceVersions *ServiceVersionsCollection + Documents *DocumentsCollection +} + +// NewKongState creates a new in-memory KongState. +func NewKongState() (*KongState, error) { + // TODO FIXME clean up the mess + keyAuthTemp := newKeyAuthsCollection(collection{}) + hmacAuthTemp := newHMACAuthsCollection(collection{}) + basicAuthTemp := newBasicAuthsCollection(collection{}) + jwtAuthTemp := newJWTAuthsCollection(collection{}) + oauth2CredsTemp := newOauth2CredsCollection(collection{}) + mtlsAuthTemp := newMTLSAuthsCollection(collection{}) + + schema := &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + serviceTableName: serviceTableSchema, + routeTableName: routeTableSchema, + upstreamTableName: upstreamTableSchema, + targetTableName: targetTableSchema, + certificateTableName: certificateTableSchema, + sniTableName: sniTableSchema, + caCertTableName: caCertTableSchema, + pluginTableName: pluginTableSchema, + consumerTableName: consumerTableSchema, + consumerGroupTableName: consumerGroupTableSchema, + consumerGroupConsumerTableName: consumerGroupConsumerTableSchema, + consumerGroupPluginTableName: consumerGroupPluginTableSchema, + rbacRoleTableName: rbacRoleTableSchema, + rbacEndpointPermissionTableName: rbacEndpointPermissionTableSchema, + vaultTableName: vaultTableSchema, + + keyAuthTemp.TableName(): keyAuthTemp.Schema(), + hmacAuthTemp.TableName(): hmacAuthTemp.Schema(), + basicAuthTemp.TableName(): basicAuthTemp.Schema(), + jwtAuthTemp.TableName(): jwtAuthTemp.Schema(), + oauth2CredsTemp.TableName(): oauth2CredsTemp.Schema(), + mtlsAuthTemp.TableName(): mtlsAuthTemp.Schema(), + + aclGroupTableName: aclGroupTableSchema, + + // konnect-specific entities + servicePackageTableName: servicePackageTableSchema, + serviceVersionTableName: serviceVersionTableSchema, + documentTableName: documentTableSchema, + }, + } + + memDB, err := memdb.NewMemDB(schema) + if err != nil { + return nil, fmt.Errorf("creating new ServiceCollection: %w", err) + } + var state KongState + state.common = collection{ + db: memDB, + } + + state.Services = (*ServicesCollection)(&state.common) + state.Routes = (*RoutesCollection)(&state.common) + state.Upstreams = (*UpstreamsCollection)(&state.common) + state.Targets = (*TargetsCollection)(&state.common) + state.Certificates = (*CertificatesCollection)(&state.common) + state.SNIs = (*SNIsCollection)(&state.common) + state.CACertificates = (*CACertificatesCollection)(&state.common) + state.Plugins = (*PluginsCollection)(&state.common) + state.Consumers = (*ConsumersCollection)(&state.common) + state.ConsumerGroups = (*ConsumerGroupsCollection)(&state.common) + state.ConsumerGroupConsumers = (*ConsumerGroupConsumersCollection)(&state.common) + state.ConsumerGroupPlugins = (*ConsumerGroupPluginsCollection)(&state.common) + state.RBACRoles = (*RBACRolesCollection)(&state.common) + state.RBACEndpointPermissions = (*RBACEndpointPermissionsCollection)(&state.common) + state.Vaults = (*VaultsCollection)(&state.common) + + state.KeyAuths = newKeyAuthsCollection(state.common) + state.HMACAuths = newHMACAuthsCollection(state.common) + state.BasicAuths = newBasicAuthsCollection(state.common) + state.JWTAuths = newJWTAuthsCollection(state.common) + state.Oauth2Creds = newOauth2CredsCollection(state.common) + state.MTLSAuths = newMTLSAuthsCollection(state.common) + + state.ACLGroups = (*ACLGroupsCollection)(&state.common) + + // konnect-specific entities + state.ServicePackages = (*ServicePackagesCollection)(&state.common) + state.ServiceVersions = (*ServiceVersionsCollection)(&state.common) + state.Documents = (*DocumentsCollection)(&state.common) + + return &state, nil +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go new file mode 100644 index 0000000..b6b3e62 --- /dev/null +++ b/pkg/state/state_test.go @@ -0,0 +1,22 @@ +package state + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewState(t *testing.T) { + state, err := NewKongState() + assert := assert.New(t) + assert.Nil(err) + assert.NotNil(state) +} + +func state() *KongState { + s, err := NewKongState() + if err != nil { + panic(err) + } + return s +} diff --git a/pkg/state/target.go b/pkg/state/target.go new file mode 100644 index 0000000..b75f801 --- /dev/null +++ b/pkg/state/target.go @@ -0,0 +1,236 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/state/indexers" + "github.com/kong/deck/utils" +) + +const ( + targetTableName = "target" + targetsByUpstreamID = "targetsByUpstreamID" +) + +var errInvalidUpstream = fmt.Errorf("upstream.ID is required in target") + +var targetTableSchema = &memdb.TableSchema{ + Name: targetTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + all: allIndex, + // foreign + targetsByUpstreamID: { + Name: targetsByUpstreamID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Upstream", + Sub: "ID", + }, + }, + }, + }, + }, +} + +func validateUpstream(target *Target) error { + if target.Upstream == nil || + utils.Empty(target.Upstream.ID) { + return errInvalidUpstream + } + return nil +} + +// TargetsCollection stores and indexes Kong Upstreams. +type TargetsCollection collection + +// Add adds a target to TargetsCollection. +// target should have an ID, Target and it's upstream's ID is set. +func (k *TargetsCollection) Add(target Target) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(target.ID) { + return errIDRequired + } + + if err := validateUpstream(&target); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *target.ID) + if !utils.Empty(target.Target.Target) { + searchBy = append(searchBy, *target.Target.Target) + } + _, err := getTarget(txn, *target.Upstream.ID, searchBy...) + if err == nil { + return fmt.Errorf("inserting target %v: %w", target.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(targetTableName, &target) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getTarget(txn *memdb.Txn, upstreamID string, IDs ...string) (*Target, error) { + targets, err := getAllByUpstreamID(txn, upstreamID) + if err != nil { + return nil, err + } + + for _, id := range IDs { + for _, target := range targets { + if id == *target.ID || id == *target.Target.Target { + return &Target{Target: *target.DeepCopy()}, nil + } + } + } + return nil, ErrNotFound +} + +func getAllByUpstreamID(txn *memdb.Txn, upstreamID string) ([]*Target, error) { + iter, err := txn.Get(targetTableName, targetsByUpstreamID, upstreamID) + if err != nil { + return nil, err + } + + var targets []*Target + for el := iter.Next(); el != nil; el = iter.Next() { + t, ok := el.(*Target) + if !ok { + panic(unexpectedType) + } + targets = append(targets, &Target{Target: *t.DeepCopy()}) + } + return targets, nil +} + +// Get returns a specific target for upstream with upstreamID. +func (k *TargetsCollection) Get(upstreamID, + targetOrID string, +) (*Target, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + return getTarget(txn, upstreamID, targetOrID) +} + +// Update updates a target +func (k *TargetsCollection) Update(target Target) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(target.ID) { + return errIDRequired + } + + if err := validateUpstream(&target); err != nil { + return err + } + + txn := k.db.Txn(true) + defer txn.Abort() + + // This doesn't follow the usual getTarget() because + // the target.Upstream.ID can be different from the one in the DB. + res, err := multiIndexLookupUsingTxn(txn, targetTableName, + []string{"id"}, *target.ID) + if err != nil { + return err + } + + t, ok := res.(*Target) + if !ok { + panic(unexpectedType) + } + err = txn.Delete(targetTableName, *t) + if err != nil { + return err + } + + err = txn.Insert(targetTableName, &target) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteTarget(txn *memdb.Txn, upstreamID, targetOrID string) error { + target, err := getTarget(txn, upstreamID, targetOrID) + if err != nil { + return err + } + + err = txn.Delete(targetTableName, target) + if err != nil { + return err + } + return nil +} + +// Delete deletes a target by its ID. +func (k *TargetsCollection) Delete(upstreamID, targetOrID string) error { + if targetOrID == "" { + return errIDRequired + } + + if upstreamID == "" { + return errInvalidUpstream + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteTarget(txn, upstreamID, targetOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a target by Target or ID. +func (k *TargetsCollection) GetAll() ([]*Target, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(targetTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Target + for el := iter.Next(); el != nil; el = iter.Next() { + t, ok := el.(*Target) + if !ok { + panic(unexpectedType) + } + res = append(res, &Target{Target: *t.DeepCopy()}) + } + txn.Commit() + return res, nil +} + +// GetAllByUpstreamID returns all targets referencing a Upstream +// by its ID. +func (k *TargetsCollection) GetAllByUpstreamID(id string) ([]*Target, + error, +) { + txn := k.db.Txn(false) + return getAllByUpstreamID(txn, id) +} diff --git a/pkg/state/target_test.go b/pkg/state/target_test.go new file mode 100644 index 0000000..a3b3a7d --- /dev/null +++ b/pkg/state/target_test.go @@ -0,0 +1,263 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func targetsCollection() *TargetsCollection { + return state().Targets +} + +func TestTargetInsert(t *testing.T) { + assert := assert.New(t) + collection := targetsCollection() + + var t0 Target + t0.Target.Target = kong.String("my-target") + err := collection.Add(t0) + assert.NotNil(err) + + t0.ID = kong.String("first") + err = collection.Add(t0) + assert.NotNil(err) + + var t1 Target + t1.Target.Target = kong.String("my-target") + t1.ID = kong.String("first") + t1.Upstream = &kong.Upstream{ + ID: kong.String("upstream1-id"), + } + err = collection.Add(t1) + assert.Nil(err) + + var t2 Target + t2.Target.Target = kong.String("my-target") + t2.ID = kong.String("second") + t2.Upstream = &kong.Upstream{ + ID: kong.String("upstream1-id"), + } + err = collection.Add(t2) + assert.NotNil(err) + + var t3 Target + t3.Target.Target = kong.String("my-target") + t3.ID = kong.String("third") + t3.Upstream = &kong.Upstream{ + Name: kong.String("upstream1-id"), + } + err = collection.Add(t3) + assert.NotNil(err) +} + +func TestTargetGetUpdate(t *testing.T) { + assert := assert.New(t) + collection := targetsCollection() + + var target Target + target.Target.Target = kong.String("my-target") + target.ID = kong.String("first") + target.Upstream = &kong.Upstream{ + ID: kong.String("upstream1-id"), + } + assert.NotNil(target.Upstream) + err := collection.Add(target) + assert.Nil(err) + + re, err := collection.Get("upstream1-id", "first") + assert.Nil(err) + assert.NotNil(re) + assert.Equal("my-target", *re.Target.Target) + + re.ID = nil + re.Upstream.ID = nil + assert.NotNil(collection.Update(*re)) + + re.ID = kong.String("does-not-exist") + assert.NotNil(collection.Update(*re)) + + re.ID = kong.String("first") + assert.NotNil(collection.Update(*re)) + + re.Upstream.ID = kong.String("upstream1-id") + assert.Nil(collection.Update(*re)) + + re.Upstream.ID = kong.String("upstream2-id") + assert.Nil(collection.Update(*re)) +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestTargetGetMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := targetsCollection() + + var target Target + target.Target.Target = kong.String("my-target") + target.ID = kong.String("first") + target.Upstream = &kong.Upstream{ + ID: kong.String("upstream1-id"), + } + err := collection.Add(target) + assert.Nil(err) + + re, err := collection.Get("upstream1-id", "first") + assert.Nil(err) + assert.NotNil(re) + assert.Equal("my-target", *re.Target.Target) + + re.Weight = kong.Int(1) + + re, err = collection.Get("upstream1-id", "my-target") + assert.Nil(err) + assert.NotNil(re) + assert.Nil(re.Weight) +} + +func TestTargetsInvalidType(t *testing.T) { + assert := assert.New(t) + + collection := targetsCollection() + + type badTarget struct { + kong.Target + Meta + } + + target := badTarget{ + Target: kong.Target{ + ID: kong.String("id"), + Target: kong.String("target"), + Upstream: &kong.Upstream{ + ID: kong.String("upstream-id"), + }, + }, + } + + txn := collection.db.Txn(true) + err := txn.Insert(targetTableName, &target) + assert.Nil(err) + txn.Commit() + + assert.Panics(func() { + collection.Get("upstream-id", "id") + }) + + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestTargetDelete(t *testing.T) { + assert := assert.New(t) + collection := targetsCollection() + + var target Target + target.Target.Target = kong.String("my-target") + target.ID = kong.String("first") + target.Upstream = &kong.Upstream{ + ID: kong.String("upstream1-id"), + } + err := collection.Add(target) + assert.Nil(err) + + re, err := collection.Get("upstream1-id", "my-target") + assert.Nil(err) + assert.NotNil(re) + + err = collection.Delete("upstream1-id", *re.ID) + assert.Nil(err) + + err = collection.Delete("upstream1-id", *re.ID) + assert.NotNil(err) + + err = collection.Delete("", "first") + assert.NotNil(err) + + err = collection.Delete("foo", "") + assert.NotNil(err) +} + +func TestTargetGetAll(t *testing.T) { + assert := assert.New(t) + collection := targetsCollection() + + var target Target + target.Target.Target = kong.String("my-target1") + target.ID = kong.String("first") + target.Upstream = &kong.Upstream{ + ID: kong.String("upstream1-id"), + } + err := collection.Add(target) + assert.Nil(err) + + var target2 Target + target2.Target.Target = kong.String("my-target2") + target2.ID = kong.String("second") + target2.Upstream = &kong.Upstream{ + ID: kong.String("upstream1-id"), + } + err = collection.Add(target2) + assert.Nil(err) + + targets, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(targets)) +} + +func TestTargetGetAllByUpstreamName(t *testing.T) { + assert := assert.New(t) + collection := targetsCollection() + + targets := []*Target{ + { + Target: kong.Target{ + ID: kong.String("target1-id"), + Target: kong.String("target1-name"), + Upstream: &kong.Upstream{ + ID: kong.String("upstream1-id"), + }, + }, + }, + { + Target: kong.Target{ + ID: kong.String("target2-id"), + Target: kong.String("target2-name"), + Upstream: &kong.Upstream{ + ID: kong.String("upstream1-id"), + }, + }, + }, + { + Target: kong.Target{ + ID: kong.String("target3-id"), + Target: kong.String("target3-name"), + Upstream: &kong.Upstream{ + ID: kong.String("upstream2-id"), + }, + }, + }, + { + Target: kong.Target{ + ID: kong.String("target4-id"), + Target: kong.String("target4-name"), + Upstream: &kong.Upstream{ + ID: kong.String("upstream2-id"), + }, + }, + }, + } + + for _, target := range targets { + err := collection.Add(*target) + assert.Nil(err) + } + + targets, err := collection.GetAllByUpstreamID("upstream1-id") + assert.Nil(err) + assert.Equal(2, len(targets)) +} diff --git a/pkg/state/types.go b/pkg/state/types.go new file mode 100644 index 0000000..b9f14fb --- /dev/null +++ b/pkg/state/types.go @@ -0,0 +1,1601 @@ +package state + +import ( + "fmt" + "reflect" + "sort" + + "github.com/kong/go-kong/kong" +) + +// entity abstracts out common fields in a credentials. +// TODO generalize for each and every entity. +type entity interface { + // ID of the cred. + GetID() string + // ID2 is the second endpoint key. + GetID2() string + // Consumer returns consumer ID associated with the cred. + GetConsumer() string +} + +// ConsoleString contains methods to be used to print +// entity to console. +type ConsoleString interface { + // Console returns a string to uniquely identify an + // entity in human-readable form. + // It should have the ID or endpoint key along-with + // foreign references if they exist. + // It will be used to communicate to the human user + // that this entity is undergoing some change. + Console() string +} + +// Meta contains additional information for an entity +// type Meta struct { +// Name *string `json:"name,omitempty" yaml:"name,omitempty"` +// Global *bool `json:"global,omitempty" yaml:"global,omitempty"` +// Kind *string `json:"type,omitempty" yaml:"type,omitempty"` +// } + +// Meta stores metadata for any entity. +type Meta struct { + metaMap map[string]interface{} +} + +func (m *Meta) initMeta() { + if m.metaMap == nil { + m.metaMap = make(map[string]interface{}) + } +} + +// AddMeta adds key->obj metadata. +// It will override the old obj in key is already present. +func (m *Meta) AddMeta(key string, obj interface{}) { + m.initMeta() + m.metaMap[key] = obj +} + +// GetMeta returns the obj previously added using AddMeta(). +// It returns nil if key is not present. +func (m *Meta) GetMeta(key string) interface{} { + m.initMeta() + return m.metaMap[key] +} + +// Service represents a service in Kong. +// It adds some helper methods along with Meta to the original Service object. +type Service struct { + kong.Service `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (s1 *Service) Identifier() string { + if s1.Name != nil { + return *s1.Name + } + return *s1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (s1 *Service) Console() string { + return s1.FriendlyName() +} + +// Equal returns true if s1 and s2 are equal. +func (s1 *Service) Equal(s2 *Service) bool { + return s1.EqualWithOpts(s2, false, false) +} + +// EqualWithOpts returns true if s1 and s2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (s1 *Service) EqualWithOpts(s2 *Service, + ignoreID bool, ignoreTS bool, +) bool { + s1Copy := s1.Service.DeepCopy() + s2Copy := s2.Service.DeepCopy() + + if len(s1Copy.Tags) == 0 { + s1Copy.Tags = nil + } + if len(s2Copy.Tags) == 0 { + s2Copy.Tags = nil + } + + // Cassandra can sometimes mess up tag order, but tag order doesn't actually matter: tags are sets + // even though we represent them with slices. Sort before comparison to avoid spurious diff detection. + sort.Slice(s1Copy.Tags, func(i, j int) bool { return *(s1Copy.Tags[i]) < *(s1Copy.Tags[j]) }) + sort.Slice(s2Copy.Tags, func(i, j int) bool { return *(s2Copy.Tags[i]) < *(s2Copy.Tags[j]) }) + + if ignoreID { + s1Copy.ID = nil + s2Copy.ID = nil + } + if ignoreTS { + s1Copy.CreatedAt = nil + s2Copy.CreatedAt = nil + + s1Copy.UpdatedAt = nil + s2Copy.UpdatedAt = nil + } + return reflect.DeepEqual(s1Copy, s2Copy) +} + +// Route represents a route in Kong. +// It adds some helper methods along with Meta to the original Route object. +type Route struct { + kong.Route `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (r1 *Route) Identifier() string { + if r1.Name != nil { + return *r1.Name + } + return *r1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (r1 *Route) Console() string { + return r1.FriendlyName() +} + +// Equal returns true if r1 and r2 are equal. +// TODO add compare array without position +func (r1 *Route) Equal(r2 *Route) bool { + return r1.EqualWithOpts(r2, false, false, false) +} + +// EqualWithOpts returns true if r1 and r2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (r1 *Route) EqualWithOpts(r2 *Route, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + r1Copy := r1.Route.DeepCopy() + r2Copy := r2.Route.DeepCopy() + + if len(r1Copy.Tags) == 0 { + r1Copy.Tags = nil + } + if len(r2Copy.Tags) == 0 { + r2Copy.Tags = nil + } + + sort.Slice(r1Copy.Tags, func(i, j int) bool { return *(r1Copy.Tags[i]) < *(r1Copy.Tags[j]) }) + sort.Slice(r2Copy.Tags, func(i, j int) bool { return *(r2Copy.Tags[i]) < *(r2Copy.Tags[j]) }) + + if ignoreID { + r1Copy.ID = nil + r2Copy.ID = nil + } + if ignoreTS { + r1Copy.CreatedAt = nil + r2Copy.CreatedAt = nil + + r1Copy.UpdatedAt = nil + r2Copy.UpdatedAt = nil + } + if ignoreForeign { + r1Copy.Service = nil + r2Copy.Service = nil + } + + if r1Copy.Service != nil { + r1Copy.Service.Name = nil + } + if r2Copy.Service != nil { + r2Copy.Service.Name = nil + } + return reflect.DeepEqual(r1Copy, r2Copy) +} + +// Upstream represents a upstream in Kong. +// It adds some helper methods along with Meta to the original Upstream object. +type Upstream struct { + kong.Upstream `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (u1 *Upstream) Identifier() string { + if u1.Name != nil { + return *u1.Name + } + return *u1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (u1 *Upstream) Console() string { + return u1.FriendlyName() +} + +// Equal returns true if u1 and u2 are equal. +func (u1 *Upstream) Equal(u2 *Upstream) bool { + return u1.EqualWithOpts(u2, false, false) +} + +// EqualWithOpts returns true if u1 and u2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (u1 *Upstream) EqualWithOpts(u2 *Upstream, + ignoreID bool, ignoreTS bool, +) bool { + u1Copy := u1.Upstream.DeepCopy() + u2Copy := u2.Upstream.DeepCopy() + + if len(u1Copy.Tags) == 0 { + u1Copy.Tags = nil + } + if len(u2Copy.Tags) == 0 { + u2Copy.Tags = nil + } + + sort.Slice(u1Copy.Tags, func(i, j int) bool { return *(u1Copy.Tags[i]) < *(u1Copy.Tags[j]) }) + sort.Slice(u2Copy.Tags, func(i, j int) bool { return *(u2Copy.Tags[i]) < *(u2Copy.Tags[j]) }) + + if ignoreID { + u1Copy.ID = nil + u2Copy.ID = nil + } + if ignoreTS { + u1Copy.CreatedAt = nil + u2Copy.CreatedAt = nil + } + return reflect.DeepEqual(u1Copy, u2Copy) +} + +// Target represents a Target in Kong. +// It adds some helper methods along with Meta to the original Target object. +type Target struct { + kong.Target `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (t1 *Target) Identifier() string { + if t1.Target.Target != nil { + return *t1.Target.Target + } + return *t1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (t1 *Target) Console() string { + res := t1.FriendlyName() + if t1.Upstream != nil { + res = res + " for upstream " + t1.Upstream.FriendlyName() + } + return res +} + +// Equal returns true if t1 and t2 are equal. +// TODO add compare array without position +func (t1 *Target) Equal(t2 *Target) bool { + return t1.EqualWithOpts(t2, false, false, false) +} + +// EqualWithOpts returns true if t1 and t2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (t1 *Target) EqualWithOpts(t2 *Target, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + t1Copy := t1.Target.DeepCopy() + t2Copy := t2.Target.DeepCopy() + + if len(t1Copy.Tags) == 0 { + t1Copy.Tags = nil + } + if len(t2Copy.Tags) == 0 { + t2Copy.Tags = nil + } + + sort.Slice(t1Copy.Tags, func(i, j int) bool { return *(t1Copy.Tags[i]) < *(t1Copy.Tags[j]) }) + sort.Slice(t2Copy.Tags, func(i, j int) bool { return *(t2Copy.Tags[i]) < *(t2Copy.Tags[j]) }) + + if ignoreID { + t1Copy.ID = nil + t2Copy.ID = nil + } + if ignoreTS { + t1Copy.CreatedAt = nil + t2Copy.CreatedAt = nil + } + if ignoreForeign { + t1Copy.Upstream = nil + t2Copy.Upstream = nil + } + return reflect.DeepEqual(t1Copy, t2Copy) +} + +// Certificate represents a upstream in Kong. +// It adds some helper methods along with Meta to the +// original Certificate object. +type Certificate struct { + kong.Certificate `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (c1 *Certificate) Identifier() string { + if c1.ID != nil { + return *c1.ID + } + return *c1.Cert +} + +// Console returns an entity's identity in a human +// readable string. +func (c1 *Certificate) Console() string { + return c1.FriendlyName() +} + +// Equal returns true if c1 and c2 are equal. +func (c1 *Certificate) Equal(c2 *Certificate) bool { + return c1.EqualWithOpts(c2, false, false) +} + +// EqualWithOpts returns true if c1 and c2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (c1 *Certificate) EqualWithOpts(c2 *Certificate, + ignoreID bool, ignoreTS bool, +) bool { + c1Copy := c1.Certificate.DeepCopy() + c2Copy := c2.Certificate.DeepCopy() + + if len(c1Copy.Tags) == 0 { + c1Copy.Tags = nil + } + if len(c2Copy.Tags) == 0 { + c2Copy.Tags = nil + } + + sort.Slice(c1Copy.Tags, func(i, j int) bool { return *(c1Copy.Tags[i]) < *(c1Copy.Tags[j]) }) + sort.Slice(c2Copy.Tags, func(i, j int) bool { return *(c2Copy.Tags[i]) < *(c2Copy.Tags[j]) }) + + if ignoreID { + c1Copy.ID = nil + c2Copy.ID = nil + } + if ignoreTS { + c1Copy.CreatedAt = nil + c2Copy.CreatedAt = nil + } + return reflect.DeepEqual(c1Copy, c2Copy) +} + +// SNI represents a SNI in Kong. +// It adds some helper methods along with Meta to the original SNI object. +type SNI struct { + kong.SNI `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (s1 *SNI) Identifier() string { + if s1.Name != nil { + return *s1.Name + } + return *s1.ID +} + +// Equal returns true if s1 and s2 are equal. +// TODO add compare array without position +func (s1 *SNI) Equal(s2 *SNI) bool { + return s1.EqualWithOpts(s2, false, false, false) +} + +// Console returns an entity's identity in a human +// readable string. +func (s1 *SNI) Console() string { + return s1.FriendlyName() +} + +// EqualWithOpts returns true if s1 and s2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (s1 *SNI) EqualWithOpts(s2 *SNI, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + s1Copy := s1.SNI.DeepCopy() + s2Copy := s2.SNI.DeepCopy() + + if len(s1Copy.Tags) == 0 { + s1Copy.Tags = nil + } + if len(s2Copy.Tags) == 0 { + s2Copy.Tags = nil + } + + sort.Slice(s1Copy.Tags, func(i, j int) bool { return *(s1Copy.Tags[i]) < *(s1Copy.Tags[j]) }) + sort.Slice(s2Copy.Tags, func(i, j int) bool { return *(s2Copy.Tags[i]) < *(s2Copy.Tags[j]) }) + + if ignoreID { + s1Copy.ID = nil + s2Copy.ID = nil + } + if ignoreTS { + s1Copy.CreatedAt = nil + s2Copy.CreatedAt = nil + } + if ignoreForeign { + s1Copy.Certificate = nil + s2Copy.Certificate = nil + } + return reflect.DeepEqual(s1Copy, s2Copy) +} + +// Plugin represents a route in Kong. +// It adds some helper methods along with Meta to the original Plugin object. +type Plugin struct { + kong.Plugin `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (p1 *Plugin) Identifier() string { + if p1.Name != nil { + return *p1.Name + } + return *p1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (p1 *Plugin) Console() string { + res := *p1.Name + " " + + if p1.Service == nil && p1.Route == nil && p1.Consumer == nil { + return res + "(global)" + } + associations := []string{} + if p1.Service != nil { + associations = append(associations, "service "+p1.Service.FriendlyName()) + } + if p1.Route != nil { + associations = append(associations, "route "+p1.Route.FriendlyName()) + } + if p1.Consumer != nil { + associations = append(associations, "consumer "+p1.Consumer.FriendlyName()) + } + if p1.ConsumerGroup != nil { + associations = append(associations, "consumer-group "+p1.ConsumerGroup.FriendlyName()) + } + if len(associations) > 0 { + res += "for " + } + for i := 0; i < len(associations); i++ { + res += associations[i] + if i < len(associations)-1 { + res += " and " + } + } + return res +} + +// Equal returns true if r1 and r2 are equal. +// TODO add compare array without position +func (p1 *Plugin) Equal(p2 *Plugin) bool { + return p1.EqualWithOpts(p2, false, false, false) +} + +// EqualWithOpts returns true if p1 and p2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (p1 *Plugin) EqualWithOpts(p2 *Plugin, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + p1Copy := p1.Plugin.DeepCopy() + p2Copy := p2.Plugin.DeepCopy() + + if len(p1Copy.Tags) == 0 { + p1Copy.Tags = nil + } + if len(p2Copy.Tags) == 0 { + p2Copy.Tags = nil + } + + sort.Slice(p1Copy.Tags, func(i, j int) bool { return *(p1Copy.Tags[i]) < *(p1Copy.Tags[j]) }) + sort.Slice(p2Copy.Tags, func(i, j int) bool { return *(p2Copy.Tags[i]) < *(p2Copy.Tags[j]) }) + + if ignoreID { + p1Copy.ID = nil + p2Copy.ID = nil + } + if ignoreTS { + p1Copy.CreatedAt = nil + p2Copy.CreatedAt = nil + } + if ignoreForeign { + p1Copy.Service = nil + p1Copy.Route = nil + p1Copy.Consumer = nil + p2Copy.Service = nil + p2Copy.Route = nil + p2Copy.Consumer = nil + p2Copy.ConsumerGroup = nil + } + + if p1Copy.Service != nil { + p1Copy.Service.Name = nil + } + if p2Copy.Service != nil { + p2Copy.Service.Name = nil + } + if p1Copy.Route != nil { + p1Copy.Route.Name = nil + } + if p2Copy.Route != nil { + p2Copy.Route.Name = nil + } + if p1Copy.Consumer != nil { + p1Copy.Consumer.Username = nil + } + if p2Copy.Consumer != nil { + p2Copy.Consumer.Username = nil + } + if p1Copy.ConsumerGroup != nil { + p1Copy.ConsumerGroup.Name = nil + } + if p2Copy.ConsumerGroup != nil { + p2Copy.ConsumerGroup.Name = nil + } + return reflect.DeepEqual(p1Copy, p2Copy) +} + +// Consumer represents a consumer in Kong. +// It adds some helper methods along with Meta to the original Consumer object. +type Consumer struct { + kong.Consumer `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (c1 *Consumer) Identifier() string { + if c1.Username != nil { + return *c1.Username + } + return *c1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (c1 *Consumer) Console() string { + return c1.FriendlyName() +} + +// Equal returns true if c1 and c2 are equal. +func (c1 *Consumer) Equal(c2 *Consumer) bool { + return c1.EqualWithOpts(c2, false, false) +} + +// EqualWithOpts returns true if c1 and c2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (c1 *Consumer) EqualWithOpts(c2 *Consumer, + ignoreID bool, ignoreTS bool, +) bool { + c1Copy := c1.Consumer.DeepCopy() + c2Copy := c2.Consumer.DeepCopy() + + if len(c1Copy.Tags) == 0 { + c1Copy.Tags = nil + } + if len(c2Copy.Tags) == 0 { + c2Copy.Tags = nil + } + + sort.Slice(c1Copy.Tags, func(i, j int) bool { return *(c1Copy.Tags[i]) < *(c1Copy.Tags[j]) }) + sort.Slice(c2Copy.Tags, func(i, j int) bool { return *(c2Copy.Tags[i]) < *(c2Copy.Tags[j]) }) + + if ignoreID { + c1Copy.ID = nil + c2Copy.ID = nil + } + if ignoreTS { + c1Copy.CreatedAt = nil + c2Copy.CreatedAt = nil + } + return reflect.DeepEqual(c1Copy, c2Copy) +} + +func forConsumerString(c *kong.Consumer) string { + if c != nil { + friendlyName := c.FriendlyName() + if friendlyName != "" { + return " for consumer " + friendlyName + } + } + return "" +} + +// ConsumerGroupObject represents a ConsumerGroupObject in Kong. +// It adds some helper methods along with Meta to the original Upstream object. +type ConsumerGroupObject struct { + kong.ConsumerGroupObject `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (c1 *ConsumerGroupObject) Identifier() string { + if c1.ConsumerGroup != nil && c1.ConsumerGroup.Name != nil { + return *c1.ConsumerGroup.Name + } + return *c1.ConsumerGroup.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (c1 *ConsumerGroupObject) Console() string { + return c1.ConsumerGroup.FriendlyName() +} + +// Equal returns true if u1 and u2 are equal. +func (c1 *ConsumerGroupObject) Equal(c2 *ConsumerGroupObject) bool { + return c1.EqualWithOpts(c2, false, false) +} + +// EqualWithOpts returns true if c1 and c2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (c1 *ConsumerGroupObject) EqualWithOpts(c2 *ConsumerGroupObject, + ignoreID bool, ignoreTS bool, +) bool { + c1Copy := c1.ConsumerGroup.DeepCopy() + c2Copy := c2.ConsumerGroup.DeepCopy() + + if ignoreID { + c1Copy.ID = nil + c2Copy.ID = nil + } + + if ignoreTS { + c1Copy.CreatedAt = nil + c2Copy.CreatedAt = nil + } + return reflect.DeepEqual(c1Copy, c2Copy) +} + +// ConsumerGroup represents a ConsumerGroup in Kong. +// It adds some helper methods along with Meta to the original ConsumerGroup object. +type ConsumerGroup struct { + kong.ConsumerGroup `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (c1 *ConsumerGroup) Identifier() string { + if c1.ConsumerGroup.Name != nil { + return *c1.ConsumerGroup.Name + } + return *c1.ConsumerGroup.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (c1 *ConsumerGroup) Console() string { + return c1.ConsumerGroup.FriendlyName() +} + +// Equal returns true if c1 and c2 are equal. +func (c1 *ConsumerGroup) Equal(c2 *ConsumerGroup) bool { + return c1.EqualWithOpts(c2, false, false) +} + +// EqualWithOpts returns true if c1 and c2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (c1 *ConsumerGroup) EqualWithOpts(c2 *ConsumerGroup, + ignoreID bool, ignoreTS bool, +) bool { + u1Copy := c1.ConsumerGroup.DeepCopy() + u2Copy := c2.ConsumerGroup.DeepCopy() + + if ignoreID { + u1Copy.ID = nil + u2Copy.ID = nil + } + if ignoreTS { + u1Copy.CreatedAt = nil + u2Copy.CreatedAt = nil + } + return reflect.DeepEqual(u1Copy, u2Copy) +} + +// ConsumerGroupConsumer represents a ConsumerGroupConsumer in Kong. +// It adds some helper methods along with Meta to the original ConsumerGroupConsumer object. +type ConsumerGroupConsumer struct { + kong.ConsumerGroupConsumer `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key Ursername or ID. +func (c1 *ConsumerGroupConsumer) Identifier() string { + if c1.Consumer.Username != nil { + return *c1.Consumer.Username + } + return *c1.Consumer.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (c1 *ConsumerGroupConsumer) Console() string { + return *c1.ConsumerGroupConsumer.Consumer.Username +} + +// Equal returns true if c1 and c2 are equal. +func (c1 *ConsumerGroupConsumer) Equal(c2 *ConsumerGroupConsumer) bool { + return c1.EqualWithOpts(c2, false, false) +} + +// EqualWithOpts returns true if c1 and c2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (c1 *ConsumerGroupConsumer) EqualWithOpts(c2 *ConsumerGroupConsumer, + ignoreID bool, ignoreTS bool, +) bool { + c1Copy := c1.ConsumerGroupConsumer.DeepCopy() + c2Copy := c2.ConsumerGroupConsumer.DeepCopy() + if ignoreID { + c1Copy.Consumer.ID = nil + c2Copy.Consumer.ID = nil + } + if ignoreTS { + c1Copy.CreatedAt = nil + c2Copy.CreatedAt = nil + c1Copy.Consumer.CreatedAt = nil + c2Copy.Consumer.CreatedAt = nil + c2Copy.ConsumerGroup.CreatedAt = nil + c1Copy.ConsumerGroup.CreatedAt = nil + } + return reflect.DeepEqual(c1Copy, c2Copy) +} + +// ConsumerGroupPlugin represents a ConsumerGroupConsumer in Kong. +// It adds some helper methods along with Meta to the original ConsumerGroupConsumer object. +type ConsumerGroupPlugin struct { + kong.ConsumerGroupPlugin `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (c1 *ConsumerGroupPlugin) Identifier() string { + if c1.Name != nil { + return *c1.Name + } + return *c1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (c1 *ConsumerGroupPlugin) Console() string { + return *c1.Name +} + +// Equal returns true if c1 and c2 are equal. +func (c1 *ConsumerGroupPlugin) Equal(c2 *ConsumerGroupPlugin) bool { + return c1.EqualWithOpts(c2, false, false) +} + +// EqualWithOpts returns true if c1 and c2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (c1 *ConsumerGroupPlugin) EqualWithOpts(c2 *ConsumerGroupPlugin, + ignoreID bool, ignoreTS bool, +) bool { + c1Copy := c1.DeepCopy() + c2Copy := c2.DeepCopy() + if ignoreID { + c1Copy.ID = nil + c2Copy.ID = nil + } + if ignoreTS { + c1Copy.CreatedAt = nil + c2Copy.CreatedAt = nil + c1Copy.ConsumerGroup.CreatedAt = nil + c2Copy.ConsumerGroup.CreatedAt = nil + } + return reflect.DeepEqual(c1Copy, c2Copy) +} + +// KeyAuth represents a key-auth credential in Kong. +// It adds some helper methods along with Meta to the original KeyAuth object. +type KeyAuth struct { + kong.KeyAuth `yaml:",inline"` + Meta +} + +// stripKey returns the last 5 characters of key. +// If key is less than or equal to 5 characters, then the key is returned as is. +func stripKey(key string) string { + const keyIdentifierLength = 5 + if len(key) <= keyIdentifierLength { + return key + } + return key[len(key)-keyIdentifierLength:] +} + +// Console returns an entity's identity in a human +// readable string. +func (k1 *KeyAuth) Console() string { + return stripKey(*k1.Key) + forConsumerString(k1.Consumer) +} + +// Equal returns true if k1 and k2 are equal. +func (k1 *KeyAuth) Equal(k2 *KeyAuth) bool { + return k1.EqualWithOpts(k2, false, false, false) +} + +// EqualWithOpts returns true if k1 and k2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (k1 *KeyAuth) EqualWithOpts(k2 *KeyAuth, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + k1Copy := k1.KeyAuth.DeepCopy() + k2Copy := k2.KeyAuth.DeepCopy() + + if len(k1Copy.Tags) == 0 { + k1Copy.Tags = nil + } + if len(k2Copy.Tags) == 0 { + k2Copy.Tags = nil + } + + sort.Slice(k1Copy.Tags, func(i, j int) bool { return *(k1Copy.Tags[i]) < *(k1Copy.Tags[j]) }) + sort.Slice(k2Copy.Tags, func(i, j int) bool { return *(k2Copy.Tags[i]) < *(k2Copy.Tags[j]) }) + + if ignoreID { + k1Copy.ID = nil + k2Copy.ID = nil + } + if ignoreTS { + k1Copy.CreatedAt = nil + k2Copy.CreatedAt = nil + } + if ignoreForeign { + k1Copy.Consumer = nil + k2Copy.Consumer = nil + } + if k1Copy.Consumer != nil { + k1Copy.Consumer.Username = nil + } + if k2Copy.Consumer != nil { + k2Copy.Consumer.Username = nil + } + return reflect.DeepEqual(k1Copy, k2Copy) +} + +// GetID returns ID. +// If ID is empty, it returns an empty string. +func (k1 *KeyAuth) GetID() string { + if k1.ID == nil { + return "" + } + return *k1.ID +} + +// GetID2 returns the endpoint key of the entity, +// the Key field for KeyAuth. +func (k1 *KeyAuth) GetID2() string { + if k1.Key == nil { + return "" + } + return *k1.Key +} + +// GetConsumer returns the credential's Consumer's ID. +// If Consumer's ID is empty, it returns an empty string. +func (k1 *KeyAuth) GetConsumer() string { + if k1.Consumer == nil || k1.Consumer.ID == nil { + return "" + } + return *k1.Consumer.ID +} + +// HMACAuth represents a key-auth credential in Kong. +// It adds some helper methods along with Meta to the original HMACAuth object. +type HMACAuth struct { + kong.HMACAuth `yaml:",inline"` + Meta +} + +// Console returns an entity's identity in a human +// readable string. +func (h1 *HMACAuth) Console() string { + return *h1.Username + forConsumerString(h1.Consumer) +} + +// Equal returns true if h1 and h2 are equal. +func (h1 *HMACAuth) Equal(h2 *HMACAuth) bool { + return h1.EqualWithOpts(h2, false, false, false) +} + +// EqualWithOpts returns true if h1 and h2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (h1 *HMACAuth) EqualWithOpts(h2 *HMACAuth, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + h1Copy := h1.HMACAuth.DeepCopy() + h2Copy := h2.HMACAuth.DeepCopy() + + if len(h1Copy.Tags) == 0 { + h1Copy.Tags = nil + } + if len(h2Copy.Tags) == 0 { + h2Copy.Tags = nil + } + + sort.Slice(h1Copy.Tags, func(i, j int) bool { return *(h1Copy.Tags[i]) < *(h1Copy.Tags[j]) }) + sort.Slice(h2Copy.Tags, func(i, j int) bool { return *(h2Copy.Tags[i]) < *(h2Copy.Tags[j]) }) + + if ignoreID { + h1Copy.ID = nil + h2Copy.ID = nil + } + if ignoreTS { + h1Copy.CreatedAt = nil + h2Copy.CreatedAt = nil + } + if ignoreForeign { + h1Copy.Consumer = nil + h2Copy.Consumer = nil + } + if h1Copy.Consumer != nil { + h1Copy.Consumer.Username = nil + } + if h2Copy.Consumer != nil { + h2Copy.Consumer.Username = nil + } + return reflect.DeepEqual(h1Copy, h2Copy) +} + +// GetID returns ID. +// If ID is empty, it returns an empty string. +func (h1 *HMACAuth) GetID() string { + if h1.ID == nil { + return "" + } + return *h1.ID +} + +// GetID2 returns the endpoint key of the entity, +// the Username field for HMACAuth. +func (h1 *HMACAuth) GetID2() string { + if h1.Username == nil { + return "" + } + return *h1.Username +} + +// GetConsumer returns the credential's Consumer's ID. +// If Consumer's ID is empty, it returns an empty string. +func (h1 *HMACAuth) GetConsumer() string { + if h1.Consumer == nil || h1.Consumer.ID == nil { + return "" + } + return *h1.Consumer.ID +} + +// JWTAuth represents a jwt credential in Kong. +// It adds some helper methods along with Meta to the original JWTAuth object. +type JWTAuth struct { + kong.JWTAuth `yaml:",inline"` + Meta +} + +// Console returns an entity's identity in a human +// readable string. +func (j1 *JWTAuth) Console() string { + return *j1.Key + forConsumerString(j1.Consumer) +} + +// Equal returns true if j1 and j2 are equal. +func (j1 *JWTAuth) Equal(j2 *JWTAuth) bool { + return j1.EqualWithOpts(j2, false, false, false) +} + +// EqualWithOpts returns true if j1 and j2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (j1 *JWTAuth) EqualWithOpts(j2 *JWTAuth, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + j1Copy := j1.JWTAuth.DeepCopy() + j2Copy := j2.JWTAuth.DeepCopy() + + if len(j1Copy.Tags) == 0 { + j1Copy.Tags = nil + } + if len(j2Copy.Tags) == 0 { + j2Copy.Tags = nil + } + + sort.Slice(j1Copy.Tags, func(i, j int) bool { return *(j1Copy.Tags[i]) < *(j1Copy.Tags[j]) }) + sort.Slice(j2Copy.Tags, func(i, j int) bool { return *(j2Copy.Tags[i]) < *(j2Copy.Tags[j]) }) + + if ignoreID { + j1Copy.ID = nil + j2Copy.ID = nil + } + if ignoreTS { + j1Copy.CreatedAt = nil + j2Copy.CreatedAt = nil + } + if ignoreForeign { + j1Copy.Consumer = nil + j2Copy.Consumer = nil + } + if j1Copy.Consumer != nil { + j1Copy.Consumer.Username = nil + } + if j2Copy.Consumer != nil { + j2Copy.Consumer.Username = nil + } + return reflect.DeepEqual(j1Copy, j2Copy) +} + +// GetID returns ID. +// If ID is empty, it returns an empty string. +func (j1 *JWTAuth) GetID() string { + if j1.ID == nil { + return "" + } + return *j1.ID +} + +// GetID2 returns the endpoint key of the entity, +// the Key field for JWTAuth. +func (j1 *JWTAuth) GetID2() string { + if j1.Key == nil { + return "" + } + return *j1.Key +} + +// GetConsumer returns the credential's Consumer's ID. +// If Consumer's ID is empty, it returns an empty string. +func (j1 *JWTAuth) GetConsumer() string { + if j1.Consumer == nil || j1.Consumer.ID == nil { + return "" + } + return *j1.Consumer.ID +} + +// BasicAuth represents a basic-auth credential in Kong. +// It adds some helper methods along with Meta to the original BasicAuth object. +type BasicAuth struct { + kong.BasicAuth `yaml:",inline"` + Meta +} + +// Console returns an entity's identity in a human +// readable string. +func (b1 *BasicAuth) Console() string { + return *b1.Username + forConsumerString(b1.Consumer) +} + +// Equal returns true if b1 and b2 are equal. +func (b1 *BasicAuth) Equal(b2 *BasicAuth) bool { + return b1.EqualWithOpts(b2, false, false, false, false) +} + +// EqualWithOpts returns true if j1 and j2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (b1 *BasicAuth) EqualWithOpts(b2 *BasicAuth, ignoreID, + ignoreTS, ignorePassword, ignoreForeign bool, +) bool { + b1Copy := b1.BasicAuth.DeepCopy() + b2Copy := b2.BasicAuth.DeepCopy() + + if len(b1Copy.Tags) == 0 { + b1Copy.Tags = nil + } + if len(b2Copy.Tags) == 0 { + b2Copy.Tags = nil + } + + sort.Slice(b1Copy.Tags, func(i, j int) bool { return *(b1Copy.Tags[i]) < *(b1Copy.Tags[j]) }) + sort.Slice(b2Copy.Tags, func(i, j int) bool { return *(b2Copy.Tags[i]) < *(b2Copy.Tags[j]) }) + + if ignoreID { + b1Copy.ID = nil + b2Copy.ID = nil + } + if ignoreTS { + b1Copy.CreatedAt = nil + b2Copy.CreatedAt = nil + } + if ignorePassword { + b1Copy.Password = nil + b2Copy.Password = nil + } + if ignoreForeign { + b1Copy.Consumer = nil + b2Copy.Consumer = nil + } + if b1Copy.Consumer != nil { + b1Copy.Consumer.Username = nil + } + if b2Copy.Consumer != nil { + b2Copy.Consumer.Username = nil + } + return reflect.DeepEqual(b1Copy, b2Copy) +} + +// GetID returns ID. +// If ID is empty, it returns an empty string. +func (b1 *BasicAuth) GetID() string { + if b1.ID == nil { + return "" + } + return *b1.ID +} + +// GetID2 returns the endpoint key of the entity, +// the Username field for BasicAuth. +func (b1 *BasicAuth) GetID2() string { + if b1.Username == nil { + return "" + } + return *b1.Username +} + +// GetConsumer returns the credential's Consumer's ID. +// If Consumer's ID is empty, it returns an empty string. +func (b1 *BasicAuth) GetConsumer() string { + if b1.Consumer == nil || b1.Consumer.ID == nil { + return "" + } + return *b1.Consumer.ID +} + +// ACLGroup represents an ACL group for a consumer in Kong. +// It adds some helper methods along with Meta to the original ACLGroup object. +type ACLGroup struct { + kong.ACLGroup `yaml:",inline"` + Meta +} + +// Console returns an entity's identity in a human +// readable string. +func (b1 *ACLGroup) Console() string { + return *b1.Group + forConsumerString(b1.Consumer) +} + +// Equal returns true if b1 and b2 are equal. +func (b1 *ACLGroup) Equal(b2 *ACLGroup) bool { + return b1.EqualWithOpts(b2, false, false, false) +} + +// EqualWithOpts returns true if j1 and j2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (b1 *ACLGroup) EqualWithOpts(b2 *ACLGroup, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + b1Copy := b1.ACLGroup.DeepCopy() + b2Copy := b2.ACLGroup.DeepCopy() + + if len(b1Copy.Tags) == 0 { + b1Copy.Tags = nil + } + if len(b2Copy.Tags) == 0 { + b2Copy.Tags = nil + } + + sort.Slice(b1Copy.Tags, func(i, j int) bool { return *(b1Copy.Tags[i]) < *(b1Copy.Tags[j]) }) + sort.Slice(b2Copy.Tags, func(i, j int) bool { return *(b2Copy.Tags[i]) < *(b2Copy.Tags[j]) }) + + if ignoreID { + b1Copy.ID = nil + b2Copy.ID = nil + } + if ignoreTS { + b1Copy.CreatedAt = nil + b2Copy.CreatedAt = nil + } + if ignoreForeign { + b1Copy.Consumer = nil + b2Copy.Consumer = nil + } + if b1Copy.Consumer != nil { + b1Copy.Consumer.Username = nil + } + if b2Copy.Consumer != nil { + b2Copy.Consumer.Username = nil + } + return reflect.DeepEqual(b1Copy, b2Copy) +} + +// CACertificate represents a CACertificate in Kong. +// It adds some helper methods along with Meta to the +// original CACertificate object. +type CACertificate struct { + kong.CACertificate `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (c1 *CACertificate) Identifier() string { + if c1.ID != nil { + return *c1.ID + } + return *c1.Cert +} + +// Console returns an entity's identity in a human +// readable string. +func (c1 *CACertificate) Console() string { + return c1.FriendlyName() +} + +// Equal returns true if c1 and c2 are equal. +func (c1 *CACertificate) Equal(c2 *CACertificate) bool { + return c1.EqualWithOpts(c2, false, false) +} + +// EqualWithOpts returns true if c1 and c2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (c1 *CACertificate) EqualWithOpts(c2 *CACertificate, + ignoreID bool, ignoreTS bool, +) bool { + c1Copy := c1.CACertificate.DeepCopy() + c2Copy := c2.CACertificate.DeepCopy() + + if len(c1Copy.Tags) == 0 { + c1Copy.Tags = nil + } + if len(c2Copy.Tags) == 0 { + c2Copy.Tags = nil + } + + sort.Slice(c1Copy.Tags, func(i, j int) bool { return *(c1Copy.Tags[i]) < *(c1Copy.Tags[j]) }) + sort.Slice(c2Copy.Tags, func(i, j int) bool { return *(c2Copy.Tags[i]) < *(c2Copy.Tags[j]) }) + + if ignoreID { + c1Copy.ID = nil + c2Copy.ID = nil + } + if ignoreTS { + c1Copy.CreatedAt = nil + c2Copy.CreatedAt = nil + } + return reflect.DeepEqual(c1Copy, c2Copy) +} + +// Oauth2Credential represents an Oauth2 credential in Kong. +// It adds some helper methods along with Meta to the original Oauth2Credential object. +type Oauth2Credential struct { + kong.Oauth2Credential `yaml:",inline"` + Meta +} + +// Console returns an entity's identity in a human +// readable string. +func (k1 *Oauth2Credential) Console() string { + return *k1.Name + forConsumerString(k1.Consumer) +} + +// Equal returns true if k1 and k2 are equal. +func (k1 *Oauth2Credential) Equal(k2 *Oauth2Credential) bool { + return k1.EqualWithOpts(k2, false, false, false) +} + +// EqualWithOpts returns true if k1 and k2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (k1 *Oauth2Credential) EqualWithOpts(k2 *Oauth2Credential, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + k1Copy := k1.Oauth2Credential.DeepCopy() + k2Copy := k2.Oauth2Credential.DeepCopy() + + if len(k1Copy.Tags) == 0 { + k1Copy.Tags = nil + } + if len(k2Copy.Tags) == 0 { + k2Copy.Tags = nil + } + + sort.Slice(k1Copy.Tags, func(i, j int) bool { return *(k1Copy.Tags[i]) < *(k1Copy.Tags[j]) }) + sort.Slice(k2Copy.Tags, func(i, j int) bool { return *(k2Copy.Tags[i]) < *(k2Copy.Tags[j]) }) + + if ignoreID { + k1Copy.ID = nil + k2Copy.ID = nil + } + if ignoreTS { + k1Copy.CreatedAt = nil + k2Copy.CreatedAt = nil + } + if ignoreForeign { + k1Copy.Consumer = nil + k2Copy.Consumer = nil + } + if k1Copy.Consumer != nil { + k1Copy.Consumer.Username = nil + } + if k2Copy.Consumer != nil { + k2Copy.Consumer.Username = nil + } + return reflect.DeepEqual(k1Copy, k2Copy) +} + +// GetID returns ID. +// If ID is empty, it returns an empty string. +func (k1 *Oauth2Credential) GetID() string { + if k1.ID == nil { + return "" + } + return *k1.ID +} + +// GetID2 returns the endpoint key of the entity, +// the ClientID field for Oauth2Credential. +func (k1 *Oauth2Credential) GetID2() string { + if k1.ClientID == nil { + return "" + } + return *k1.ClientID +} + +// GetConsumer returns the credential's Consumer's ID. +// If Consumer's ID is empty, it returns an empty string. +func (k1 *Oauth2Credential) GetConsumer() string { + if k1.Consumer == nil || k1.Consumer.ID == nil { + return "" + } + return *k1.Consumer.ID +} + +// MTLSAuth represents an mtls-auth credential in Kong. +// It adds some helper methods along with Meta to the original MTLSAuth object. +type MTLSAuth struct { + kong.MTLSAuth `yaml:",inline"` + Meta +} + +// Console returns an entity's identity in a human +// readable string. +func (b1 *MTLSAuth) Console() string { + return *b1.SubjectName + forConsumerString(b1.Consumer) +} + +// Equal returns true if b1 and b2 are equal. +func (b1 *MTLSAuth) Equal(b2 *MTLSAuth) bool { + return b1.EqualWithOpts(b2, false, false, false) +} + +// EqualWithOpts returns true if j1 and j2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (b1 *MTLSAuth) EqualWithOpts(b2 *MTLSAuth, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + b1Copy := b1.MTLSAuth.DeepCopy() + b2Copy := b2.MTLSAuth.DeepCopy() + + if len(b1Copy.Tags) == 0 { + b1Copy.Tags = nil + } + if len(b2Copy.Tags) == 0 { + b2Copy.Tags = nil + } + + sort.Slice(b1Copy.Tags, func(i, j int) bool { return *(b1Copy.Tags[i]) < *(b1Copy.Tags[j]) }) + sort.Slice(b2Copy.Tags, func(i, j int) bool { return *(b2Copy.Tags[i]) < *(b2Copy.Tags[j]) }) + + if ignoreID { + b1Copy.ID = nil + b2Copy.ID = nil + } + if ignoreTS { + b1Copy.CreatedAt = nil + b2Copy.CreatedAt = nil + } + if ignoreForeign { + b1Copy.Consumer = nil + b2Copy.Consumer = nil + } + if b1Copy.Consumer != nil { + b1Copy.Consumer.Username = nil + } + if b2Copy.Consumer != nil { + b2Copy.Consumer.Username = nil + } + return reflect.DeepEqual(b1Copy, b2Copy) +} + +// RBACRole represents an RBAC Role in Kong. +// It adds some helper methods along with Meta to the original RBACRole object. +type RBACRole struct { + kong.RBACRole `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (r1 *RBACRole) Identifier() string { + if r1.Name != nil { + return *r1.Name + } + return *r1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (r1 *RBACRole) Console() string { + return r1.FriendlyName() +} + +// Equal returns true if r1 and r2 are equal. +// TODO add compare array without position +func (r1 *RBACRole) Equal(r2 *RBACRole) bool { + return r1.EqualWithOpts(r2, false, false, false) +} + +// EqualWithOpts returns true if r1 and r2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (r1 *RBACRole) EqualWithOpts(r2 *RBACRole, ignoreID, + ignoreTS, _ bool, +) bool { + r1Copy := r1.RBACRole.DeepCopy() + r2Copy := r2.RBACRole.DeepCopy() + + if ignoreID { + r1Copy.ID = nil + r2Copy.ID = nil + } + if ignoreTS { + r1Copy.CreatedAt = nil + r2Copy.CreatedAt = nil + } + + return reflect.DeepEqual(r1Copy, r2Copy) +} + +// RBACEndpointPermission represents an RBAC Role in Kong. +// It adds some helper methods along with Meta to the original RBACEndpointPermission object. +type RBACEndpointPermission struct { + ID string + kong.RBACEndpointPermission `yaml:",inline"` + Meta +} + +// Identifier returns a composite ID base on Role ID, workspace, and endpoint +func (r1 *RBACEndpointPermission) Identifier() string { + if r1.Endpoint != nil { + return fmt.Sprintf("%s-%s-%s", *r1.Role.ID, *r1.Workspace, *r1.Endpoint) + } + return *r1.Endpoint +} + +// Console returns an entity's identity in a human +// readable string. +func (r1 *RBACEndpointPermission) Console() string { + return r1.FriendlyName() +} + +// Equal returns true if r1 and r2 are equal. +// TODO add compare array without position +func (r1 *RBACEndpointPermission) Equal(r2 *RBACEndpointPermission) bool { + return r1.EqualWithOpts(r2, false, false, false) +} + +// EqualWithOpts returns true if r1 and r2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (r1 *RBACEndpointPermission) EqualWithOpts(r2 *RBACEndpointPermission, ignoreID, + ignoreTS, _ bool, +) bool { + r1Copy := r1.RBACEndpointPermission.DeepCopy() + r2Copy := r2.RBACEndpointPermission.DeepCopy() + + if ignoreID { + r1Copy.Endpoint = nil + r2Copy.Endpoint = nil + } + if ignoreTS { + r1Copy.CreatedAt = nil + r2Copy.CreatedAt = nil + } + + return reflect.DeepEqual(r1Copy, r2Copy) +} + +// GetID returns ID. +// If ID is empty, it returns an empty string. +func (b1 *MTLSAuth) GetID() string { + if b1.ID == nil { + return "" + } + return *b1.ID +} + +// GetID2 returns the endpoint key of the entity, +// BUT NO SUCH THING EXISTS 😱 +// TODO: this is kind of a pointless clone of GetID for MTLSAuth. the mtls-auth +// entity cannot be referenced by anything other than its ID (it has no unique +// fields), but the entity interface requires this function. this duplication +// doesn't appear to be harmful, but it's weird. +func (b1 *MTLSAuth) GetID2() string { + return (*b1).GetID() +} + +func (b1 *MTLSAuth) GetConsumer() string { + if b1.Consumer == nil || b1.Consumer.ID == nil { + return "" + } + return *b1.Consumer.ID +} + +// Vault represents a vault in Kong. +// It adds some helper methods along with Meta to the original Vault object. +type Vault struct { + kong.Vault `yaml:",inline"` + Meta +} + +// Identifier returns the endpoint key name or ID. +func (v1 *Vault) Identifier() string { + if v1.Name != nil { + return *v1.Name + } + return *v1.ID +} + +// Console returns an entity's identity in a human +// readable string. +func (v1 *Vault) Console() string { + return v1.FriendlyName() +} + +// Equal returns true if v1 and v2 are equal. +// TODO add compare array without position +func (v1 *Vault) Equal(v2 *Vault) bool { + return v1.EqualWithOpts(v2, false, false) +} + +// EqualWithOpts returns true if v1 and v2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (v1 *Vault) EqualWithOpts(v2 *Vault, ignoreID, ignoreTS bool) bool { + v1Copy := v1.Vault.DeepCopy() + v2Copy := v2.Vault.DeepCopy() + + if len(v1Copy.Tags) == 0 { + v1Copy.Tags = nil + } + if len(v2Copy.Tags) == 0 { + v2Copy.Tags = nil + } + + sort.Slice(v1Copy.Tags, func(i, j int) bool { return *(v1Copy.Tags[i]) < *(v1Copy.Tags[j]) }) + sort.Slice(v2Copy.Tags, func(i, j int) bool { return *(v2Copy.Tags[i]) < *(v2Copy.Tags[j]) }) + + if ignoreID { + v1Copy.ID = nil + v2Copy.ID = nil + } + if ignoreTS { + v1Copy.CreatedAt = nil + v2Copy.CreatedAt = nil + + v1Copy.UpdatedAt = nil + v2Copy.UpdatedAt = nil + } + return reflect.DeepEqual(v1Copy, v2Copy) +} diff --git a/pkg/state/types_test.go b/pkg/state/types_test.go new file mode 100644 index 0000000..043beba --- /dev/null +++ b/pkg/state/types_test.go @@ -0,0 +1,502 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +// getTags returns a slice of test tags. If reversed is true, the tags are backwards! +// backwards tag slices are useful for confirming that our equality checks ignore tag order +func getTags(reversed bool) []*string { + fooString := "foo" + barString := "bar" + if reversed { + return []*string{&barString, &fooString} + } + return []*string{&fooString, &barString} +} + +func TestMeta(t *testing.T) { + assert := assert.New(t) + + var m Meta + + m.AddMeta("foo", "bar") + r := m.GetMeta("foo") + res, ok := r.(string) + assert.True(ok) + assert.Equal("bar", res) + // assert.Equal(reflect.TypeOf(r).String(), "string") + + s := "string-pointer" + m.AddMeta("baz", &s) + r = m.GetMeta("baz") + res2, ok := r.(*string) + assert.True(ok) + assert.Equal("string-pointer", *res2) + + // can retrieve a previous value + r = m.GetMeta("foo") + res, ok = r.(string) + assert.True(ok) + assert.Equal("bar", res) +} + +func TestServiceEqual(t *testing.T) { + assert := assert.New(t) + + var s1, s2 Service + s1.ID = kong.String("foo") + s1.Name = kong.String("bar") + + s2.ID = kong.String("foo") + s2.Name = kong.String("baz") + + assert.False(s1.Equal(&s2)) + assert.False(s1.EqualWithOpts(&s2, false, false)) + + s2.Name = kong.String("bar") + assert.True(s1.Equal(&s2)) + assert.True(s1.EqualWithOpts(&s2, false, false)) + s1.Tags = getTags(true) + s2.Tags = getTags(false) + assert.True(s1.EqualWithOpts(&s2, false, false)) + + s1.ID = kong.String("fuu") + assert.False(s1.EqualWithOpts(&s2, false, false)) + assert.True(s1.EqualWithOpts(&s2, true, false)) + + s2.CreatedAt = kong.Int(1) + s1.UpdatedAt = kong.Int(2) + assert.False(s1.EqualWithOpts(&s2, false, false)) + assert.False(s1.EqualWithOpts(&s2, false, true)) +} + +func TestRouteEqual(t *testing.T) { + assert := assert.New(t) + + var r1, r2 Route + r1.ID = kong.String("foo") + r1.Name = kong.String("bar") + + r2.ID = kong.String("foo") + r2.Name = kong.String("baz") + + assert.False(r1.Equal(&r2)) + assert.False(r1.EqualWithOpts(&r2, false, false, false)) + + r2.Name = kong.String("bar") + assert.True(r1.Equal(&r2)) + assert.True(r1.EqualWithOpts(&r2, false, false, false)) + r1.Tags = getTags(true) + r2.Tags = getTags(false) + assert.True(r1.EqualWithOpts(&r2, false, false, false)) + + r1.ID = kong.String("fuu") + assert.False(r1.EqualWithOpts(&r2, false, false, false)) + assert.True(r1.EqualWithOpts(&r2, true, false, false)) + + r2.CreatedAt = kong.Int(1) + r1.UpdatedAt = kong.Int(2) + assert.False(r1.EqualWithOpts(&r2, false, false, false)) + assert.False(r1.EqualWithOpts(&r2, false, true, false)) + assert.True(r1.EqualWithOpts(&r2, true, true, false)) + + r1.Hosts = kong.StringSlice("demo1.example.com", "demo2.example.com") + + // order matters + r2.Hosts = kong.StringSlice("demo2.example.com", "demo1.example.com") + assert.False(r1.EqualWithOpts(&r2, true, true, false)) + + r2.Hosts = kong.StringSlice("demo1.example.com", "demo2.example.com") + assert.True(r1.EqualWithOpts(&r2, true, true, false)) + + r1.Service = &kong.Service{ID: kong.String("1")} + r2.Service = &kong.Service{ID: kong.String("2")} + assert.False(r1.EqualWithOpts(&r2, true, true, false)) + assert.True(r1.EqualWithOpts(&r2, true, true, true)) + + r1.Service = &kong.Service{ID: kong.String("2")} + assert.True(r1.EqualWithOpts(&r2, true, true, false)) +} + +func TestUpstreamEqual(t *testing.T) { + assert := assert.New(t) + + var u1, u2 Upstream + u1.ID = kong.String("foo") + u1.Name = kong.String("bar") + + u2.ID = kong.String("foo") + u2.Name = kong.String("baz") + + assert.False(u1.Equal(&u2)) + assert.False(u1.EqualWithOpts(&u2, false, false)) + + u2.Name = kong.String("bar") + assert.True(u1.Equal(&u2)) + assert.True(u1.EqualWithOpts(&u2, false, false)) + u1.Tags = getTags(true) + u2.Tags = getTags(false) + assert.True(u1.EqualWithOpts(&u2, false, false)) + + u1.ID = kong.String("fuu") + assert.False(u1.EqualWithOpts(&u2, false, false)) + assert.True(u1.EqualWithOpts(&u2, true, false)) + + var timestamp int64 = 1 + u2.CreatedAt = ×tamp + assert.False(u1.EqualWithOpts(&u2, false, false)) + assert.False(u1.EqualWithOpts(&u2, false, true)) +} + +func TestTargetEqual(t *testing.T) { + assert := assert.New(t) + + var t1, t2 Target + t1.ID = kong.String("foo") + t1.Target.Target = kong.String("bar") + + t2.ID = kong.String("foo") + t2.Target.Target = kong.String("baz") + + assert.False(t1.Equal(&t2)) + assert.False(t1.EqualWithOpts(&t2, false, false, false)) + + t2.Target.Target = kong.String("bar") + assert.True(t1.Equal(&t2)) + assert.True(t1.EqualWithOpts(&t2, false, false, false)) + t1.Tags = getTags(true) + t2.Tags = getTags(false) + assert.True(t1.EqualWithOpts(&t2, false, false, false)) + + t1.ID = kong.String("fuu") + assert.False(t1.EqualWithOpts(&t2, false, false, false)) + assert.True(t1.EqualWithOpts(&t2, true, false, false)) + + var timestamp float64 = 1 + t2.CreatedAt = ×tamp + assert.False(t1.EqualWithOpts(&t2, false, false, false)) + assert.False(t1.EqualWithOpts(&t2, false, true, false)) + + t1.Upstream = &kong.Upstream{ID: kong.String("1")} + t2.Upstream = &kong.Upstream{ID: kong.String("2")} + assert.False(t1.EqualWithOpts(&t2, true, true, false)) + assert.True(t1.EqualWithOpts(&t2, true, true, true)) + + t1.Upstream = &kong.Upstream{ID: kong.String("2")} + assert.True(t1.EqualWithOpts(&t2, true, true, false)) +} + +func TestCertificateEqual(t *testing.T) { + assert := assert.New(t) + + var c1, c2 Certificate + c1.ID = kong.String("foo") + c1.Cert = kong.String("certfoo") + c1.Key = kong.String("keyfoo") + + c2.ID = kong.String("foo") + c2.Cert = kong.String("certfoo") + c2.Key = kong.String("keyfoo-unequal") + + assert.False(c1.Equal(&c2)) + assert.False(c1.EqualWithOpts(&c2, false, false)) + + c2.Key = kong.String("keyfoo") + assert.True(c1.Equal(&c2)) + assert.True(c1.EqualWithOpts(&c2, false, false)) + c1.Tags = getTags(true) + c2.Tags = getTags(false) + assert.True(c1.EqualWithOpts(&c2, false, false)) + + c1.ID = kong.String("fuu") + assert.False(c1.EqualWithOpts(&c2, false, false)) + assert.True(c1.EqualWithOpts(&c2, true, false)) + + var timestamp int64 = 1 + c2.CreatedAt = ×tamp + assert.False(c1.EqualWithOpts(&c2, false, false)) + assert.False(c1.EqualWithOpts(&c2, false, true)) +} + +func TestSNIEqual(t *testing.T) { + assert := assert.New(t) + + var s1, s2 SNI + s1.ID = kong.String("foo") + s1.Name = kong.String("bar") + + s2.ID = kong.String("foo") + s2.Name = kong.String("baz") + + assert.False(s1.Equal(&s2)) + assert.False(s1.EqualWithOpts(&s2, false, false, false)) + + s2.Name = kong.String("bar") + assert.True(s1.Equal(&s2)) + assert.True(s1.EqualWithOpts(&s2, false, false, false)) + s1.Tags = getTags(true) + s2.Tags = getTags(false) + assert.True(s1.EqualWithOpts(&s2, false, false, false)) + + s1.ID = kong.String("fuu") + assert.False(s1.EqualWithOpts(&s2, false, false, false)) + assert.True(s1.EqualWithOpts(&s2, true, false, false)) + + var timestamp int64 = 1 + s2.CreatedAt = ×tamp + assert.False(s1.EqualWithOpts(&s2, false, false, false)) + assert.False(s1.EqualWithOpts(&s2, false, true, false)) + + s1.Certificate = &kong.Certificate{ID: kong.String("1")} + s2.Certificate = &kong.Certificate{ID: kong.String("2")} + assert.False(s1.EqualWithOpts(&s2, true, true, false)) + assert.True(s1.EqualWithOpts(&s2, true, true, true)) + + s1.Certificate = &kong.Certificate{ID: kong.String("2")} + assert.True(s1.EqualWithOpts(&s2, true, true, false)) +} + +func TestPluginEqual(t *testing.T) { + assert := assert.New(t) + + var p1, p2 Plugin + p1.ID = kong.String("foo") + p1.Name = kong.String("bar") + + p2.ID = kong.String("foo") + p2.Name = kong.String("baz") + + assert.False(p1.Equal(&p2)) + assert.False(p1.EqualWithOpts(&p2, false, false, false)) + + p2.Name = kong.String("bar") + assert.True(p1.Equal(&p2)) + assert.True(p1.EqualWithOpts(&p2, false, false, false)) + p1.Tags = getTags(true) + p2.Tags = getTags(false) + assert.True(p1.EqualWithOpts(&p2, false, false, false)) + + p1.ID = kong.String("fuu") + assert.False(p1.EqualWithOpts(&p2, false, false, false)) + assert.True(p1.EqualWithOpts(&p2, true, false, false)) + + timestamp := 1 + p2.CreatedAt = ×tamp + assert.False(p1.EqualWithOpts(&p2, false, false, false)) + assert.False(p1.EqualWithOpts(&p2, false, true, false)) + + p1.Service = &kong.Service{ID: kong.String("1")} + p2.Service = &kong.Service{ID: kong.String("2")} + assert.False(p1.EqualWithOpts(&p2, true, true, false)) + assert.True(p1.EqualWithOpts(&p2, true, true, true)) + + p1.Service = &kong.Service{ID: kong.String("2")} + assert.True(p1.EqualWithOpts(&p2, true, true, false)) +} + +func TestConsumerEqual(t *testing.T) { + assert := assert.New(t) + + var c1, c2 Consumer + c1.ID = kong.String("foo") + c1.Username = kong.String("bar") + + c2.ID = kong.String("foo") + c2.Username = kong.String("baz") + + assert.False(c1.Equal(&c2)) + assert.False(c1.EqualWithOpts(&c2, false, false)) + + c2.Username = kong.String("bar") + assert.True(c1.Equal(&c2)) + assert.True(c1.EqualWithOpts(&c2, false, false)) + c1.Tags = getTags(true) + c2.Tags = getTags(false) + assert.True(c1.EqualWithOpts(&c2, false, false)) + + c1.ID = kong.String("fuu") + assert.False(c1.EqualWithOpts(&c2, false, false)) + assert.True(c1.EqualWithOpts(&c2, true, false)) + + var a int64 = 1 + c2.CreatedAt = &a + assert.False(c1.EqualWithOpts(&c2, false, false)) + assert.False(c1.EqualWithOpts(&c2, false, true)) +} + +func TestKeyAuthEqual(t *testing.T) { + assert := assert.New(t) + + var k1, k2 KeyAuth + k1.ID = kong.String("foo") + k1.Key = kong.String("bar") + + k2.ID = kong.String("foo") + k2.Key = kong.String("baz") + + assert.False(k1.Equal(&k2)) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + + k2.Key = kong.String("bar") + assert.True(k1.Equal(&k2)) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + k1.Tags = getTags(true) + k2.Tags = getTags(false) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + + k1.ID = kong.String("fuu") + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.True(k1.EqualWithOpts(&k2, true, false, false)) + + k2.CreatedAt = kong.Int(1) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + + k2.Consumer = &kong.Consumer{Username: kong.String("u1")} + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + assert.False(k1.EqualWithOpts(&k2, false, false, true)) +} + +func TestHMACAuthEqual(t *testing.T) { + assert := assert.New(t) + + var k1, k2 HMACAuth + k1.ID = kong.String("foo") + k1.Username = kong.String("bar") + + k2.ID = kong.String("foo") + k2.Username = kong.String("baz") + + assert.False(k1.Equal(&k2)) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + + k2.Username = kong.String("bar") + assert.True(k1.Equal(&k2)) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + k1.Tags = getTags(true) + k2.Tags = getTags(false) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + + k1.ID = kong.String("fuu") + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.True(k1.EqualWithOpts(&k2, true, false, false)) + + k2.CreatedAt = kong.Int(1) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + + k2.Consumer = &kong.Consumer{Username: kong.String("u1")} + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + assert.False(k1.EqualWithOpts(&k2, false, false, true)) +} + +func TestJWTAuthEqual(t *testing.T) { + assert := assert.New(t) + + var k1, k2 JWTAuth + k1.ID = kong.String("foo") + k1.Key = kong.String("bar") + + k2.ID = kong.String("foo") + k2.Key = kong.String("baz") + + assert.False(k1.Equal(&k2)) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + + k2.Key = kong.String("bar") + assert.True(k1.Equal(&k2)) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + k1.Tags = getTags(true) + k2.Tags = getTags(false) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + + k1.ID = kong.String("fuu") + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.True(k1.EqualWithOpts(&k2, true, false, false)) + + k2.CreatedAt = kong.Int(1) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + + k2.Consumer = &kong.Consumer{Username: kong.String("u1")} + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + assert.False(k1.EqualWithOpts(&k2, false, false, true)) +} + +func TestBasicAuthEqual(t *testing.T) { + assert := assert.New(t) + + var k1, k2 BasicAuth + k1.ID = kong.String("foo") + k1.Password = kong.String("bar") + + k2.ID = kong.String("foo") + k2.Password = kong.String("baz") + + assert.False(k1.Equal(&k2)) + assert.False(k1.EqualWithOpts(&k2, false, false, false, false)) + + k2.Password = kong.String("bar") + assert.True(k1.Equal(&k2)) + assert.True(k1.EqualWithOpts(&k2, false, false, false, false)) + assert.True(k1.EqualWithOpts(&k2, false, false, false, true)) + k1.Tags = getTags(true) + k2.Tags = getTags(false) + assert.True(k1.EqualWithOpts(&k2, false, false, false, false)) + + k1.ID = kong.String("fuu") + assert.False(k1.EqualWithOpts(&k2, false, false, false, false)) + assert.True(k1.EqualWithOpts(&k2, true, false, false, false)) + + k2.CreatedAt = kong.Int(1) + assert.False(k1.EqualWithOpts(&k2, false, false, false, false)) + assert.False(k1.EqualWithOpts(&k2, false, true, false, false)) + + k2.Consumer = &kong.Consumer{Username: kong.String("u1")} + assert.False(k1.EqualWithOpts(&k2, false, true, false, false)) + assert.False(k1.EqualWithOpts(&k2, false, false, true, false)) +} + +func TestACLGroupEqual(t *testing.T) { + assert := assert.New(t) + + var k1, k2 ACLGroup + k1.ID = kong.String("foo") + k1.Group = kong.String("bar") + + k2.ID = kong.String("foo") + k2.Group = kong.String("baz") + + assert.False(k1.Equal(&k2)) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + + k2.Group = kong.String("bar") + assert.True(k1.Equal(&k2)) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + k1.Tags = getTags(true) + k2.Tags = getTags(false) + assert.True(k1.EqualWithOpts(&k2, false, false, false)) + + k1.ID = kong.String("fuu") + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.True(k1.EqualWithOpts(&k2, true, false, false)) + + k2.CreatedAt = kong.Int(1) + assert.False(k1.EqualWithOpts(&k2, false, false, false)) + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + + k2.Consumer = &kong.Consumer{Username: kong.String("u1")} + assert.False(k1.EqualWithOpts(&k2, false, true, false)) + assert.False(k1.EqualWithOpts(&k2, false, false, true)) +} + +func TestStripKey(t *testing.T) { + assert := assert.New(t) + assert.Equal("hello", stripKey("hello")) + assert.Equal("yolo", stripKey("yolo")) + assert.Equal("world", stripKey("hello world")) +} diff --git a/pkg/state/upstream.go b/pkg/state/upstream.go new file mode 100644 index 0000000..f228841 --- /dev/null +++ b/pkg/state/upstream.go @@ -0,0 +1,178 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/utils" +) + +const ( + upstreamTableName = "upstream" +) + +var upstreamTableSchema = &memdb.TableSchema{ + Name: upstreamTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Name"}, + }, + all: allIndex, + }, +} + +// UpstreamsCollection stores and indexes Kong Upstreams. +type UpstreamsCollection collection + +// Add adds an upstream to the collection. +// upstream.ID should not be nil else an error is thrown. +func (k *UpstreamsCollection) Add(upstream Upstream) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(upstream.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *upstream.ID) + if !utils.Empty(upstream.Name) { + searchBy = append(searchBy, *upstream.Name) + } + _, err := getUpstream(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting upstream %v: %w", upstream.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(upstreamTableName, &upstream) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getUpstream(txn *memdb.Txn, IDs ...string) (*Upstream, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, upstreamTableName, + []string{"name", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + + upstream, ok := res.(*Upstream) + if !ok { + panic(unexpectedType) + } + return &Upstream{Upstream: *upstream.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets an upstream by name or ID. +func (k *UpstreamsCollection) Get(nameOrID string) (*Upstream, error) { + if nameOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + upstream, err := getUpstream(txn, nameOrID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return upstream, nil +} + +// Update udpates an existing upstream. +func (k *UpstreamsCollection) Update(upstream Upstream) error { + // TODO abstract this in the go-memdb library itself + if utils.Empty(upstream.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteUpstream(txn, *upstream.ID) + if err != nil { + return err + } + + err = txn.Insert(upstreamTableName, &upstream) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteUpstream(txn *memdb.Txn, nameOrID string) error { + upstream, err := getUpstream(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(upstreamTableName, upstream) + if err != nil { + return err + } + return nil +} + +// Delete deletes an upstream by it's name or ID. +func (k *UpstreamsCollection) Delete(nameOrID string) error { + if nameOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteUpstream(txn, nameOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all upstreams in the state. +func (k *UpstreamsCollection) GetAll() ([]*Upstream, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(upstreamTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Upstream + for el := iter.Next(); el != nil; el = iter.Next() { + u, ok := el.(*Upstream) + if !ok { + panic(unexpectedType) + } + res = append(res, &Upstream{Upstream: *u.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/state/upstream_test.go b/pkg/state/upstream_test.go new file mode 100644 index 0000000..c46226b --- /dev/null +++ b/pkg/state/upstream_test.go @@ -0,0 +1,172 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func upstreamsCollection() *UpstreamsCollection { + return state().Upstreams +} + +func TestUpstreamInsert(t *testing.T) { + assert := assert.New(t) + collection := upstreamsCollection() + + // name is required + var upstream Upstream + upstream.ID = kong.String("first") + err := collection.Add(upstream) + assert.NotNil(err) + + // happy path + upstream.Name = kong.String("my-upstream") + assert.Nil(collection.Add(upstream)) + + // ID is required + var upstream2 Upstream + upstream2.Name = kong.String("my-upstream") + err = collection.Add(upstream2) + assert.NotNil(err) + + // re-insert + upstream2.ID = kong.String("first") + assert.NotNil(collection.Add(upstream2)) + + upstream2.ID = kong.String("same-name-but-different-id") + assert.NotNil(collection.Add(upstream2)) +} + +func TestUpstreamGetUpdate(t *testing.T) { + assert := assert.New(t) + collection := upstreamsCollection() + + se, err := collection.Get("does-not-exist") + assert.NotNil(err) + assert.Nil(se) + + se, err = collection.Get("") + assert.NotNil(err) + assert.Nil(se) + + var upstream Upstream + upstream.Name = kong.String("my-upstream") + upstream.ID = kong.String("first") + err = collection.Add(upstream) + assert.Nil(err) + + se, err = collection.Get("first") + assert.Nil(err) + assert.NotNil(se) + + se.Name = kong.String("my-updated-upstream") + err = collection.Update(*se) + assert.Nil(err) + + se, err = collection.Get("my-updated-upstream") + assert.Nil(err) + assert.NotNil(se) + + se.ID = nil + err = collection.Update(*se) + assert.NotNil(err) + + se, err = collection.Get("my-upstream") + assert.Equal(ErrNotFound, err) + assert.Nil(se) +} + +// Regression test +// to ensure that the memory reference of the pointer returned by Get() +// is different from the one stored in MemDB. +func TestUpstreamGetMemoryReference(t *testing.T) { + assert := assert.New(t) + collection := upstreamsCollection() + + var upstream Upstream + upstream.Name = kong.String("my-upstream") + upstream.ID = kong.String("first") + err := collection.Add(upstream) + assert.Nil(err) + + se, err := collection.Get("first") + assert.Nil(err) + assert.NotNil(se) + se.Slots = kong.Int(1) + + se, err = collection.Get("my-upstream") + assert.Nil(err) + assert.NotNil(se) + assert.Nil(se.Slots) +} + +func TestUpstreamsInvalidType(t *testing.T) { + assert := assert.New(t) + + collection := upstreamsCollection() + + var route Route + route.Name = kong.String("my-route") + route.ID = kong.String("first") + txn := collection.db.Txn(true) + txn.Insert(upstreamTableName, &route) + txn.Commit() + + assert.Panics(func() { + collection.Get("my-route") + }) + assert.Panics(func() { + collection.GetAll() + }) +} + +func TestUpstreamDelete(t *testing.T) { + assert := assert.New(t) + collection := upstreamsCollection() + + var upstream Upstream + upstream.Name = kong.String("my-upstream") + upstream.ID = kong.String("first") + err := collection.Add(upstream) + assert.Nil(err) + + se, err := collection.Get("my-upstream") + assert.Nil(err) + assert.NotNil(se) + + err = collection.Delete(*se.ID) + assert.Nil(err) + + err = collection.Delete("") + assert.NotNil(err) + + _, err = collection.Get("my-upstream") + assert.Equal(ErrNotFound, err) + + err = collection.Delete(*se.ID) + assert.NotNil(err) +} + +func TestUpstreamGetAll(t *testing.T) { + assert := assert.New(t) + collection := upstreamsCollection() + + var upstream Upstream + upstream.Name = kong.String("my-upstream1") + upstream.ID = kong.String("first") + err := collection.Add(upstream) + assert.Nil(err) + + var upstream2 Upstream + upstream2.Name = kong.String("my-upstream2") + upstream2.ID = kong.String("second") + err = collection.Add(upstream2) + assert.Nil(err) + + upstreams, err := collection.GetAll() + + assert.Nil(err) + assert.Equal(2, len(upstreams)) +} diff --git a/pkg/state/utils.go b/pkg/state/utils.go new file mode 100644 index 0000000..77aa572 --- /dev/null +++ b/pkg/state/utils.go @@ -0,0 +1,56 @@ +package state + +import ( + "fmt" + + memdb "github.com/hashicorp/go-memdb" +) + +const ( + all = "all" +) + +// ErrNotFound is an error type that is +// returned when an entity is not found in the state. +var ErrNotFound = fmt.Errorf("entity not found") + +// ErrAlreadyExists represents an entity is already present in the state. +var ErrAlreadyExists = fmt.Errorf("entity already exists") + +// internal errors +var errIDRequired = fmt.Errorf("ID is required") + +// error annotation messages +const ( + unexpectedType = "unexpected type found" +) + +var allIndex = &memdb.IndexSchema{ + Name: all, + Indexer: &memdb.ConditionalIndex{ + Conditional: func(v interface{}) (bool, error) { + return true, nil + }, + }, +} + +// multiIndexLookupUsingTxn can be used to search for an entity +// based on search on multiple indexes with same key. +func multiIndexLookupUsingTxn(txn *memdb.Txn, tableName string, + indices []string, + args ...interface{}, +) (interface{}, error) { + for _, indexName := range indices { + res, err := txn.First(tableName, indexName, args...) + if res == nil && err == nil { + continue + } + if err != nil { + return nil, err + } + if res != nil { + return res, nil + } + } + return nil, ErrNotFound +} diff --git a/pkg/state/vault.go b/pkg/state/vault.go new file mode 100644 index 0000000..b0c1fa4 --- /dev/null +++ b/pkg/state/vault.go @@ -0,0 +1,176 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/deck/utils" +) + +const ( + vaultTableName = "vault" +) + +var vaultTableSchema = &memdb.TableSchema{ + Name: vaultTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + "prefix": { + Name: "prefix", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "Prefix"}, + }, + all: allIndex, + }, +} + +// VaultsCollection stores and indexes Kong Vaults. +type VaultsCollection collection + +// Add adds a vault to the collection. +// vault.ID should not be nil else an error is thrown. +func (k *VaultsCollection) Add(vault Vault) error { + if utils.Empty(vault.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + var searchBy []string + searchBy = append(searchBy, *vault.ID) + if !utils.Empty(vault.Prefix) { + searchBy = append(searchBy, *vault.Prefix) + } + _, err := getVault(txn, searchBy...) + if err == nil { + return fmt.Errorf("inserting vault %v: %w", vault.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + err = txn.Insert(vaultTableName, &vault) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func getVault(txn *memdb.Txn, IDs ...string) (*Vault, error) { + for _, id := range IDs { + res, err := multiIndexLookupUsingTxn(txn, vaultTableName, + []string{"prefix", "id"}, id) + if errors.Is(err, ErrNotFound) { + continue + } + if err != nil { + return nil, err + } + + vault, ok := res.(*Vault) + if !ok { + panic(unexpectedType) + } + return &Vault{Vault: *vault.DeepCopy()}, nil + } + return nil, ErrNotFound +} + +// Get gets a vault by prefix or ID. +func (k *VaultsCollection) Get(prefixOrID string) (*Vault, error) { + if prefixOrID == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + vault, err := getVault(txn, prefixOrID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return vault, nil +} + +// Update udpates an existing vault. +func (k *VaultsCollection) Update(vault Vault) error { + if utils.Empty(vault.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteVault(txn, *vault.ID) + if err != nil { + return err + } + + err = txn.Insert(vaultTableName, &vault) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteVault(txn *memdb.Txn, nameOrID string) error { + vault, err := getVault(txn, nameOrID) + if err != nil { + return err + } + + err = txn.Delete(vaultTableName, vault) + if err != nil { + return err + } + return nil +} + +// Delete deletes a vault by its prefix or ID. +func (k *VaultsCollection) Delete(prefixOrID string) error { + if prefixOrID == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteVault(txn, prefixOrID) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets all vaults in the state. +func (k *VaultsCollection) GetAll() ([]*Vault, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(vaultTableName, all, true) + if err != nil { + return nil, err + } + + var res []*Vault + for el := iter.Next(); el != nil; el = iter.Next() { + v, ok := el.(*Vault) + if !ok { + panic(unexpectedType) + } + res = append(res, &Vault{Vault: *v.DeepCopy()}) + } + txn.Commit() + return res, nil +} diff --git a/pkg/types/aclgroup.go b/pkg/types/aclgroup.go new file mode 100644 index 0000000..2b5c201 --- /dev/null +++ b/pkg/types/aclgroup.go @@ -0,0 +1,185 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// aclGroupCRUD implements crud.Actions interface. +type aclGroupCRUD struct { + client *kong.Client +} + +func aclGroupFromStruct(arg crud.Event) *state.ACLGroup { + aclGroup, ok := arg.Obj.(*state.ACLGroup) + if !ok { + panic("unexpected type, expected *state.ACLGroup") + } + + return aclGroup +} + +// Create creates a Route in Kong. +// The arg should be of type crud.Event, containing the aclGroup to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *aclGroupCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + aclGroup := aclGroupFromStruct(event) + cid := "" + if !utils.Empty(aclGroup.Consumer.Username) { + cid = *aclGroup.Consumer.Username + } + if !utils.Empty(aclGroup.Consumer.ID) { + cid = *aclGroup.Consumer.ID + } + createdACLGroup, err := s.client.ACLs.Create(ctx, &cid, + &aclGroup.ACLGroup) + if err != nil { + return nil, err + } + return &state.ACLGroup{ACLGroup: *createdACLGroup}, nil +} + +// Delete deletes a Route in Kong. +// The arg should be of type crud.Event, containing the aclGroup to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *aclGroupCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + aclGroup := aclGroupFromStruct(event) + cid := "" + if !utils.Empty(aclGroup.Consumer.Username) { + cid = *aclGroup.Consumer.Username + } + if !utils.Empty(aclGroup.Consumer.ID) { + cid = *aclGroup.Consumer.ID + } + err := s.client.ACLs.Delete(ctx, &cid, aclGroup.ID) + if err != nil { + return nil, err + } + return aclGroup, nil +} + +// Update updates a Route in Kong. +// The arg should be of type crud.Event, containing the aclGroup to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *aclGroupCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + aclGroup := aclGroupFromStruct(event) + + cid := "" + if !utils.Empty(aclGroup.Consumer.Username) { + cid = *aclGroup.Consumer.Username + } + if !utils.Empty(aclGroup.Consumer.ID) { + cid = *aclGroup.Consumer.ID + } + updatedACLGroup, err := s.client.ACLs.Create(ctx, &cid, &aclGroup.ACLGroup) + if err != nil { + return nil, err + } + return &state.ACLGroup{ACLGroup: *updatedACLGroup}, nil +} + +type aclGroupDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *aclGroupDiffer) Deletes(handler func(crud.Event) error) error { + currentACLGroups, err := d.currentState.ACLGroups.GetAll() + if err != nil { + return fmt.Errorf("error fetching acls from state: %w", err) + } + + for _, aclGroup := range currentACLGroups { + n, err := d.deleteACLGroup(aclGroup) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *aclGroupDiffer) deleteACLGroup(aclGroup *state.ACLGroup) (*crud.Event, error) { + // lookup by consumerID and Group + _, err := d.targetState.ACLGroups.Get(*aclGroup.Consumer.ID, *aclGroup.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: aclGroup, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up acl %q: %w", *aclGroup.Group, err) + } + return nil, nil +} + +func (d *aclGroupDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetACLGroups, err := d.targetState.ACLGroups.GetAll() + if err != nil { + return fmt.Errorf("error fetching acls from state: %w", err) + } + + for _, aclGroup := range targetACLGroups { + n, err := d.createUpdateACLGroup(aclGroup) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *aclGroupDiffer) createUpdateACLGroup(aclGroup *state.ACLGroup) (*crud.Event, error) { + aclGroup = &state.ACLGroup{ACLGroup: *aclGroup.DeepCopy()} + currentACLGroup, err := d.currentState.ACLGroups.Get( + *aclGroup.Consumer.ID, *aclGroup.ID) + if errors.Is(err, state.ErrNotFound) { + // aclGroup not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: aclGroup, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up acl %q: %w", + *aclGroup.Group, err) + } + // found, check if update needed + + if !currentACLGroup.EqualWithOpts(aclGroup, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: aclGroup, + OldObj: currentACLGroup, + }, nil + } + return nil, nil +} diff --git a/pkg/types/basicauth.go b/pkg/types/basicauth.go new file mode 100644 index 0000000..fbdaea2 --- /dev/null +++ b/pkg/types/basicauth.go @@ -0,0 +1,199 @@ +package types + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/kong/deck/cprint" + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// basicAuthCRUD implements crud.Actions interface. +type basicAuthCRUD struct { + client *kong.Client +} + +func basicAuthFromStruct(arg crud.Event) *state.BasicAuth { + basicAuth, ok := arg.Obj.(*state.BasicAuth) + if !ok { + panic("unexpected type, expected *state.BasicAuth") + } + + return basicAuth +} + +// Create creates a Route in Kong. +// The arg should be of type crud.Event, containing the basicAuth to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *basicAuthCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + basicAuth := basicAuthFromStruct(event) + cid := "" + if !utils.Empty(basicAuth.Consumer.Username) { + cid = *basicAuth.Consumer.Username + } + if !utils.Empty(basicAuth.Consumer.ID) { + cid = *basicAuth.Consumer.ID + } + createdBasicAuth, err := s.client.BasicAuths.Create(ctx, &cid, + &basicAuth.BasicAuth) + if err != nil { + return nil, err + } + return &state.BasicAuth{BasicAuth: *createdBasicAuth}, nil +} + +// Delete deletes a Route in Kong. +// The arg should be of type crud.Event, containing the basicAuth to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *basicAuthCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + basicAuth := basicAuthFromStruct(event) + cid := "" + if !utils.Empty(basicAuth.Consumer.Username) { + cid = *basicAuth.Consumer.Username + } + if !utils.Empty(basicAuth.Consumer.ID) { + cid = *basicAuth.Consumer.ID + } + err := s.client.BasicAuths.Delete(ctx, &cid, basicAuth.ID) + if err != nil { + return nil, err + } + return basicAuth, nil +} + +// Update updates a Route in Kong. +// The arg should be of type crud.Event, containing the basicAuth to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *basicAuthCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + basicAuth := basicAuthFromStruct(event) + + cid := "" + if !utils.Empty(basicAuth.Consumer.Username) { + cid = *basicAuth.Consumer.Username + } + if !utils.Empty(basicAuth.Consumer.ID) { + cid = *basicAuth.Consumer.ID + } + updatedBasicAuth, err := s.client.BasicAuths.Create(ctx, &cid, &basicAuth.BasicAuth) + if err != nil { + return nil, err + } + return &state.BasicAuth{BasicAuth: *updatedBasicAuth}, nil +} + +type basicAuthDiffer struct { + kind crud.Kind + once sync.Once + + currentState, targetState *state.KongState +} + +func (d *basicAuthDiffer) warnBasicAuth() { + const ( + basicAuthPasswordWarning = "Warning: import/export of basic-auth" + + "credentials using decK doesn't work due to hashing of passwords in Kong." + ) + d.once.Do(func() { + cprint.UpdatePrintln(basicAuthPasswordWarning) + }) +} + +func (d *basicAuthDiffer) Deletes(handler func(crud.Event) error) error { + currentBasicAuths, err := d.currentState.BasicAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching basic-auths from state: %w", err) + } + + for _, basicAuth := range currentBasicAuths { + n, err := d.deleteBasicAuth(basicAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *basicAuthDiffer) deleteBasicAuth(basicAuth *state.BasicAuth) (*crud.Event, error) { + d.warnBasicAuth() + _, err := d.targetState.BasicAuths.Get(*basicAuth.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: basicAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up basic-auth %q: %w", + *basicAuth.Username, err) + } + return nil, nil +} + +func (d *basicAuthDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetBasicAuths, err := d.targetState.BasicAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching basic-auths from state: %w", err) + } + + for _, basicAuth := range targetBasicAuths { + n, err := d.createUpdateBasicAuth(basicAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *basicAuthDiffer) createUpdateBasicAuth(basicAuth *state.BasicAuth) (*crud.Event, error) { + d.warnBasicAuth() + basicAuth = &state.BasicAuth{BasicAuth: *basicAuth.DeepCopy()} + currentBasicAuth, err := d.currentState.BasicAuths.Get(*basicAuth.ID) + if errors.Is(err, state.ErrNotFound) { + // basicAuth not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: basicAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up basic-auth %q: %w", + *basicAuth.Username, err) + } + // found, check if update needed + + if !currentBasicAuth.EqualWithOpts(basicAuth, false, true, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: basicAuth, + OldObj: currentBasicAuth, + }, nil + } + return nil, nil +} diff --git a/pkg/types/ca_cert.go b/pkg/types/ca_cert.go new file mode 100644 index 0000000..62f09cd --- /dev/null +++ b/pkg/types/ca_cert.go @@ -0,0 +1,165 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// caCertificateCRUD implements crud.Actions interface. +type caCertificateCRUD struct { + client *kong.Client +} + +func caCertFromStruct(arg crud.Event) *state.CACertificate { + caCert, ok := arg.Obj.(*state.CACertificate) + if !ok { + panic("unexpected type, expected *state.CACertificate") + } + return caCert +} + +// Create creates a CACertificate in Kong. +// The arg should be of type crud.Event, containing the certificate to be created, +// else the function will panic. +// It returns a the created *state.CACertificate. +func (s *caCertificateCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + certificate := caCertFromStruct(event) + createdCertificate, err := s.client.CACertificates.Create(ctx, + &certificate.CACertificate) + if err != nil { + return nil, err + } + return &state.CACertificate{CACertificate: *createdCertificate}, nil +} + +// Delete deletes a CACertificate in Kong. +// The arg should be of type crud.Event, containing the certificate to be deleted, +// else the function will panic. +// It returns a the deleted *state.CACertificate. +func (s *caCertificateCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + certificate := caCertFromStruct(event) + err := s.client.CACertificates.Delete(ctx, certificate.ID) + if err != nil { + return nil, err + } + return certificate, nil +} + +// Update updates a CACertificate in Kong. +// The arg should be of type crud.Event, containing the certificate to be updated, +// else the function will panic. +// It returns a the updated *state.CACertificate. +func (s *caCertificateCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + certificate := caCertFromStruct(event) + updatedCertificate, err := s.client.CACertificates.Create(ctx, + &certificate.CACertificate) + if err != nil { + return nil, err + } + return &state.CACertificate{CACertificate: *updatedCertificate}, nil +} + +type caCertificateDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *caCertificateDiffer) Deletes(handler func(crud.Event) error) error { + currentCACertificates, err := d.currentState.CACertificates.GetAll() + if err != nil { + return fmt.Errorf("error fetching caCertificates from state: %w", err) + } + + for _, certificate := range currentCACertificates { + n, err := d.deleteCACertificate(certificate) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *caCertificateDiffer) deleteCACertificate( + caCert *state.CACertificate, +) (*crud.Event, error) { + _, err := d.targetState.CACertificates.Get(*caCert.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: caCert, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up caCertificate %q: %w", + caCert.FriendlyName(), err) + } + return nil, nil +} + +func (d *caCertificateDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetCACertificates, err := d.targetState.CACertificates.GetAll() + if err != nil { + return fmt.Errorf("error fetching caCertificates from state: %w", err) + } + + for _, caCert := range targetCACertificates { + n, err := d.createUpdateCACertificate(caCert) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *caCertificateDiffer) createUpdateCACertificate( + caCert *state.CACertificate, +) (*crud.Event, error) { + caCertCopy := &state.CACertificate{CACertificate: *caCert.DeepCopy()} + currentCACert, err := d.currentState.CACertificates.Get(*caCert.ID) + + if errors.Is(err, state.ErrNotFound) { + // caCertificate not present, create it + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: caCertCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up caCertificate %q: %w", + caCert.FriendlyName(), err) + } + + // found, check if update needed + if !currentCACert.EqualWithOpts(caCertCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: caCertCopy, + OldObj: currentCACert, + }, nil + } + return nil, nil +} diff --git a/pkg/types/certificate.go b/pkg/types/certificate.go new file mode 100644 index 0000000..8381168 --- /dev/null +++ b/pkg/types/certificate.go @@ -0,0 +1,202 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// certificateCRUD implements crud.Actions interface. +type certificateCRUD struct { + client *kong.Client + isKonnect bool +} + +func certificateFromStruct(arg crud.Event) *state.Certificate { + certificate, ok := arg.Obj.(*state.Certificate) + if !ok { + panic("unexpected type, expected *state.certificate") + } + return certificate +} + +// Create creates a Certificate in Kong. +// The arg should be of type crud.Event, containing the certificate to be created, +// else the function will panic. +// It returns a the created *state.Certificate. +func (s *certificateCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + certificate := certificateFromStruct(event) + if s.isKonnect { + certificate.SNIs = nil + } + createdCertificate, err := s.client.Certificates.Create(ctx, &certificate.Certificate) + if err != nil { + return nil, err + } + return &state.Certificate{Certificate: *createdCertificate}, nil +} + +// Delete deletes a Certificate in Kong. +// The arg should be of type crud.Event, containing the certificate to be deleted, +// else the function will panic. +// It returns a the deleted *state.Certificate. +func (s *certificateCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + certificate := certificateFromStruct(event) + err := s.client.Certificates.Delete(ctx, certificate.ID) + if err != nil { + return nil, err + } + return certificate, nil +} + +// Update updates a Certificate in Kong. +// The arg should be of type crud.Event, containing the certificate to be updated, +// else the function will panic. +// It returns a the updated *state.Certificate. +func (s *certificateCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + certificate := certificateFromStruct(event) + + if s.isKonnect { + certificate.SNIs = nil + } + updatedCertificate, err := s.client.Certificates.Create(ctx, &certificate.Certificate) + if err != nil { + return nil, err + } + return &state.Certificate{Certificate: *updatedCertificate}, nil +} + +type certificateDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState + + isKonnect bool +} + +func (d *certificateDiffer) Deletes(handler func(crud.Event) error) error { + currentCertificates, err := d.currentState.Certificates.GetAll() + if err != nil { + return fmt.Errorf("error fetching certificates from state: %w", err) + } + + for _, certificate := range currentCertificates { + n, err := d.deleteCertificate(certificate) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *certificateDiffer) deleteCertificate( + certificate *state.Certificate, +) (*crud.Event, error) { + _, err := d.targetState.Certificates.Get(*certificate.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: certificate, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up certificate %q': %w", + certificate.FriendlyName(), err) + } + return nil, nil +} + +func (d *certificateDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetCertificates, err := d.targetState.Certificates.GetAll() + if err != nil { + return fmt.Errorf("error fetching certificates from state: %w", err) + } + + for _, certificate := range targetCertificates { + n, err := d.createUpdateCertificate(certificate) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *certificateDiffer) createUpdateCertificate( + certificate *state.Certificate, +) (*crud.Event, error) { + certificateCopy := &state.Certificate{Certificate: *certificate.DeepCopy()} + currentCertificate, err := d.currentState.Certificates.Get(*certificate.ID) + + if d.isKonnect { + certificateCopy.SNIs = nil + if currentCertificate != nil { + currentCertificate.SNIs = nil + } + } + + if errors.Is(err, state.ErrNotFound) { + // certificate not present, create it + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: certificateCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up certificate %q: %w", + certificate.FriendlyName(), err) + } + + // found, check if update needed + if !currentCertificate.EqualWithOpts(certificateCopy, false, true) { + // Certificate and SNI objects have a special relationship. A PUT request + // (which we use for updates) with a certificate that contains no SNI + // children will in fact delete any existing SNI objects associated with + // that certificate, rather than leaving them as-is. + + // To work around this issues, we set SNIs on certificates here using the + // current certificate's SNI list. If there are changes to the SNIs, + // subsequent actions on the SNI objects will handle those. + if !d.isKonnect { + currentSNIs, err := d.currentState.SNIs.GetAllByCertID(*currentCertificate.ID) + if err != nil { + return nil, fmt.Errorf("error looking up current certificate SNIs %q: %w", + certificate.FriendlyName(), err) + } + sniNames := make([]*string, 0) + for _, s := range currentSNIs { + sniNames = append(sniNames, s.Name) + } + + certificateCopy.SNIs = sniNames + currentCertificate.SNIs = sniNames + } + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: certificateCopy, + OldObj: currentCertificate, + }, nil + } + return nil, nil +} diff --git a/pkg/types/consumer.go b/pkg/types/consumer.go new file mode 100644 index 0000000..4d5b52c --- /dev/null +++ b/pkg/types/consumer.go @@ -0,0 +1,228 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// consumerCRUD implements crud.Actions interface. +type consumerCRUD struct { + client *kong.Client +} + +func consumerFromStruct(arg crud.Event) *state.Consumer { + consumer, ok := arg.Obj.(*state.Consumer) + if !ok { + panic("unexpected type, expected *state.consumer") + } + return consumer +} + +// Create creates a Consumer in Kong. +// The arg should be of type crud.Event, containing the consumer to be created, +// else the function will panic. +// It returns a the created *state.Consumer. +func (s *consumerCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumer := consumerFromStruct(event) + createdConsumer, err := s.client.Consumers.Create(ctx, &consumer.Consumer) + if err != nil { + return nil, err + } + return &state.Consumer{Consumer: *createdConsumer}, nil +} + +// Delete deletes a Consumer in Kong. +// The arg should be of type crud.Event, containing the consumer to be deleted, +// else the function will panic. +// It returns a the deleted *state.Consumer. +func (s *consumerCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumer := consumerFromStruct(event) + err := s.client.Consumers.Delete(ctx, consumer.ID) + if err != nil { + return nil, err + } + return consumer, nil +} + +// Update updates a Consumer in Kong. +// The arg should be of type crud.Event, containing the consumer to be updated, +// else the function will panic. +// It returns a the updated *state.Consumer. +func (s *consumerCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumer := consumerFromStruct(event) + + updatedConsumer, err := s.client.Consumers.Create(ctx, &consumer.Consumer) + if err != nil { + return nil, err + } + return &state.Consumer{Consumer: *updatedConsumer}, nil +} + +type consumerDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *consumerDiffer) Deletes(handler func(crud.Event) error) error { + currentConsumers, err := d.currentState.Consumers.GetAll() + if err != nil { + return fmt.Errorf("error fetching consumers from state: %w", err) + } + + for _, consumer := range currentConsumers { + n, err := d.deleteConsumer(consumer) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *consumerDiffer) deleteConsumer(consumer *state.Consumer) (*crud.Event, error) { + _, err := d.targetState.Consumers.GetByIDOrUsername(*consumer.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: consumer, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up consumer %q: %w", + consumer.FriendlyName(), err) + } + return nil, nil +} + +func (d *consumerDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetConsumers, err := d.targetState.Consumers.GetAll() + if err != nil { + return fmt.Errorf("error fetching consumers from state: %w", err) + } + + for _, consumer := range targetConsumers { + n, err := d.createUpdateConsumer(consumer) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *consumerDiffer) createUpdateConsumer(consumer *state.Consumer) (*crud.Event, error) { + consumerCopy := &state.Consumer{Consumer: *consumer.DeepCopy()} + currentConsumer, err := d.currentState.Consumers.GetByIDOrUsername(*consumer.ID) + + if errors.Is(err, state.ErrNotFound) { + // consumer not present, create it + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: consumerCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up consumer %q: %w", + consumer.FriendlyName(), err) + } + + // found, check if update needed + if !currentConsumer.EqualWithOpts(consumerCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: consumerCopy, + OldObj: currentConsumer, + }, nil + } + return nil, nil +} + +func (d *consumerDiffer) DuplicatesDeletes() ([]crud.Event, error) { + targetConsumers, err := d.targetState.Consumers.GetAll() + if err != nil { + return nil, fmt.Errorf("error fetching consumers from state: %w", err) + } + + var events []crud.Event + for _, targetConsumer := range targetConsumers { + event, err := d.deleteDuplicateConsumer(targetConsumer) + if err != nil { + return nil, err + } + if event != nil { + events = append(events, *event) + } + } + + return events, nil +} + +func (d *consumerDiffer) deleteDuplicateConsumer(targetConsumer *state.Consumer) (*crud.Event, error) { + var ( + idOrUsername string + + currentConsumer *state.Consumer + err error + ) + + if targetConsumer.Username != nil { + idOrUsername = *targetConsumer.Username + } else if targetConsumer.ID != nil { + idOrUsername = *targetConsumer.ID + } + + if idOrUsername != "" { + currentConsumer, err = d.currentState.Consumers.GetByIDOrUsername(idOrUsername) + } + if errors.Is(err, state.ErrNotFound) || idOrUsername == "" { + if targetConsumer.CustomID != nil { + currentConsumer, err = d.currentState.Consumers.GetByCustomID(*targetConsumer.CustomID) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up consumer by custom_id %q: %w", + *targetConsumer.Username, err) + } + } else { + return nil, nil + } + } + if err != nil { + return nil, fmt.Errorf("error looking up consumer by username or id %q: %w", + *targetConsumer.Username, err) + } + + if *currentConsumer.ID != *targetConsumer.ID { + return &crud.Event{ + Op: crud.Delete, + Kind: "consumer", + Obj: currentConsumer, + }, nil + } + + return nil, nil +} diff --git a/pkg/types/consumer_group.go b/pkg/types/consumer_group.go new file mode 100644 index 0000000..772a454 --- /dev/null +++ b/pkg/types/consumer_group.go @@ -0,0 +1,183 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// consumerGroupCRUD implements crud.Actions interface. +type consumerGroupCRUD struct { + client *kong.Client + isKonnect bool +} + +func consumerGroupFromStruct(arg crud.Event) *state.ConsumerGroup { + consumerGroup, ok := arg.Obj.(*state.ConsumerGroup) + if !ok { + panic("unexpected type, expected *state.ConsumerGroup") + } + return consumerGroup +} + +// Create creates a consumerGroup in Kong. +// The arg should be of type crud.Event, containing the consumerGroup to be created, +// else the function will panic. +// It returns the created *state.consumerGroup. +func (s *consumerGroupCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumerGroup := consumerGroupFromStruct(event) + + var createdConsumerGroup *kong.ConsumerGroup + var err error + if s.isKonnect { + createdConsumerGroup, err = konnect.CreateConsumerGroup(ctx, s.client, &consumerGroup.ConsumerGroup) + } else { + createdConsumerGroup, err = s.client.ConsumerGroups.Create(ctx, &consumerGroup.ConsumerGroup) + } + if err != nil { + return nil, err + } + return &state.ConsumerGroup{ConsumerGroup: *createdConsumerGroup}, nil +} + +// Delete deletes a consumerGroup in Kong. +// The arg should be of type crud.Event, containing the consumerGroup to be deleted, +// else the function will panic. +// It returns the deleted *state.consumerGroup. +func (s *consumerGroupCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumerGroup := consumerGroupFromStruct(event) + + var err error + if s.isKonnect { + err = konnect.DeleteConsumerGroup(ctx, s.client, consumerGroup.ConsumerGroup.ID) + } else { + err = s.client.ConsumerGroups.Delete(ctx, consumerGroup.ConsumerGroup.ID) + } + if err != nil { + return nil, err + } + return consumerGroup, nil +} + +// Update updates a consumerGroup in Kong. +// The arg should be of type crud.Event, containing the consumerGroup to be updated, +// else the function will panic. +// It returns the updated *state.consumerGroup. +func (s *consumerGroupCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumerGroup := consumerGroupFromStruct(event) + + var err error + var updatedConsumerGroup *kong.ConsumerGroup + if s.isKonnect { + updatedConsumerGroup, err = konnect.UpdateConsumerGroup(ctx, s.client, consumerGroup.ID, &consumerGroup.ConsumerGroup) + } else { + updatedConsumerGroup, err = s.client.ConsumerGroups.Update(ctx, &consumerGroup.ConsumerGroup) + } + if err != nil { + return nil, err + } + return &state.ConsumerGroup{ConsumerGroup: *updatedConsumerGroup}, nil +} + +type consumerGroupDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *consumerGroupDiffer) Deletes(handler func(crud.Event) error) error { + currentconsumerGroups, err := d.currentState.ConsumerGroups.GetAll() + if err != nil { + return fmt.Errorf("error fetching consumerGroups from state: %w", err) + } + + for _, consumerGroup := range currentconsumerGroups { + n, err := d.deleteConsumerGroup(consumerGroup) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *consumerGroupDiffer) deleteConsumerGroup(consumerGroup *state.ConsumerGroup) (*crud.Event, error) { + _, err := d.targetState.ConsumerGroups.Get(*consumerGroup.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: "consumer-group", + Obj: consumerGroup, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up consumerGroup %q: %w", + *consumerGroup.Name, err) + } + return nil, nil +} + +func (d *consumerGroupDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetconsumerGroups, err := d.targetState.ConsumerGroups.GetAll() + if err != nil { + return fmt.Errorf("error fetching consumerGroups from state: %w", err) + } + + for _, consumerGroup := range targetconsumerGroups { + n, err := d.createUpdateConsumerGroup(consumerGroup) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *consumerGroupDiffer) createUpdateConsumerGroup(consumerGroup *state.ConsumerGroup) (*crud.Event, + error, +) { + consumerGroupCopy := &state.ConsumerGroup{ConsumerGroup: *consumerGroup.DeepCopy()} + currentconsumerGroup, err := d.currentState.ConsumerGroups.Get(*consumerGroup.Name) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "consumer-group", + Obj: consumerGroupCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up consumerGroup %v: %w", + *consumerGroup.Name, err) + } + + // found, check if update needed + if !currentconsumerGroup.EqualWithOpts(consumerGroupCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "consumer-group", + Obj: consumerGroupCopy, + OldObj: currentconsumerGroup, + }, nil + } + return nil, nil +} diff --git a/pkg/types/consumer_group_consumer.go b/pkg/types/consumer_group_consumer.go new file mode 100644 index 0000000..6daee66 --- /dev/null +++ b/pkg/types/consumer_group_consumer.go @@ -0,0 +1,220 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// consumerGroupConsumerCRUD implements crud.Actions interface. +type consumerGroupConsumerCRUD struct { + client *kong.Client + isKonnect bool +} + +func consumerGroupConsumerFromStruct(arg crud.Event) *state.ConsumerGroupConsumer { + consumerGroup, ok := arg.Obj.(*state.ConsumerGroupConsumer) + if !ok { + panic("unexpected type, expected *state.ConsumerGroupConsumer") + } + return consumerGroup +} + +// Create creates a consumerGroupConsumer in Kong. +// The arg should be of type crud.Event, containing the consumerGroupConsumer to be created, +// else the function will panic. +// It returns the created *state.consumerGroupConsumer. +func (s *consumerGroupConsumerCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumer := consumerGroupConsumerFromStruct(event) + + var err error + if s.isKonnect { + err = konnect.CreateConsumerGroupMember( + ctx, s.client, consumer.ConsumerGroup.ID, consumer.Consumer.ID, + ) + } else { + _, err = s.client.ConsumerGroupConsumers.Create(ctx, consumer.ConsumerGroup.ID, consumer.Consumer.Username) + } + if err != nil { + return nil, err + } + + return &state.ConsumerGroupConsumer{ + ConsumerGroupConsumer: kong.ConsumerGroupConsumer{ + Consumer: consumer.Consumer, + ConsumerGroup: consumer.ConsumerGroup, + }, + }, nil +} + +// Delete deletes a consumerGroupConsumer in Kong. +// The arg should be of type crud.Event, containing the consumerGroupConsumer to be deleted, +// else the function will panic. +// It returns the deleted *state.consumerGroupConsumer. +func (s *consumerGroupConsumerCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumer := consumerGroupConsumerFromStruct(event) + + var err error + if s.isKonnect { + err = konnect.DeleteConsumerGroupMember(ctx, s.client, consumer.ConsumerGroup.ID, consumer.Consumer.ID) + } else { + err = s.client.ConsumerGroupConsumers.Delete(ctx, consumer.ConsumerGroup.ID, consumer.Consumer.Username) + } + if err != nil { + return nil, err + } + + return consumer, nil +} + +// Update updates a consumerGroupConsumer in Kong. +// The arg should be of type crud.Event, containing the consumerGroupConsumer to be updated, +// else the function will panic. +// It returns the updated *state.consumerGroupConsumer. +func (s *consumerGroupConsumerCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + consumer := consumerGroupConsumerFromStruct(event) + + var err error + // delete the old member + if s.isKonnect { + err = konnect.DeleteConsumerGroupMember( + ctx, s.client, consumer.ConsumerGroup.ID, consumer.Consumer.ID, + ) + } else { + err = s.client.ConsumerGroupConsumers.Delete( + ctx, consumer.ConsumerGroup.ID, consumer.Consumer.Username, + ) + } + if err != nil { + return nil, err + } + + // recreate it + if s.isKonnect { + err = konnect.CreateConsumerGroupMember( + ctx, s.client, consumer.ConsumerGroup.ID, consumer.Consumer.ID, + ) + } else { + _, err = s.client.ConsumerGroupConsumers.Create( + ctx, consumer.ConsumerGroup.ID, consumer.Consumer.Username, + ) + } + if err != nil { + return nil, err + } + + return &state.ConsumerGroupConsumer{ + ConsumerGroupConsumer: kong.ConsumerGroupConsumer{ + Consumer: consumer.Consumer, + ConsumerGroup: consumer.ConsumerGroup, + }, + }, nil +} + +type consumerGroupConsumerDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *consumerGroupConsumerDiffer) Deletes(handler func(crud.Event) error) error { + currentConsumers, err := d.currentState.ConsumerGroupConsumers.GetAll() + if err != nil { + return fmt.Errorf("error fetching consumerGroupConsumers from state: %w", err) + } + + for _, consumer := range currentConsumers { + n, err := d.deleteConsumerGroupConsumer(consumer) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *consumerGroupConsumerDiffer) deleteConsumerGroupConsumer( + consumer *state.ConsumerGroupConsumer, +) (*crud.Event, error) { + _, err := d.targetState.ConsumerGroupConsumers.Get( + *consumer.Consumer.Username, *consumer.ConsumerGroup.ID, + ) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: "consumer-group-consumer", + Obj: consumer, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up consumerGroupConsumer %q: %w", + *consumer.Consumer.Username, err) + } + return nil, nil +} + +func (d *consumerGroupConsumerDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetConsumers, err := d.targetState.ConsumerGroupConsumers.GetAll() + if err != nil { + return fmt.Errorf("error fetching consumerGroupConsumers from state: %w", err) + } + + for _, consumer := range targetConsumers { + n, err := d.createUpdateConsumerGroupConsumer(consumer) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *consumerGroupConsumerDiffer) createUpdateConsumerGroupConsumer( + consumer *state.ConsumerGroupConsumer, +) (*crud.Event, error) { + consumerCopy := &state.ConsumerGroupConsumer{ConsumerGroupConsumer: *consumer.DeepCopy()} + currentConsumer, err := d.currentState.ConsumerGroupConsumers.Get( + *consumer.Consumer.Username, *consumer.ConsumerGroup.ID, + ) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "consumer-group-consumer", + Obj: consumerCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up consumerGroupConsumer %v: %w", + *currentConsumer.Consumer.Username, err) + } + + // found, check if update needed + if !currentConsumer.EqualWithOpts(consumerCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "consumer-group-consumer", + Obj: consumerCopy, + OldObj: currentConsumer, + }, nil + } + return nil, nil +} diff --git a/pkg/types/consumer_group_plugin.go b/pkg/types/consumer_group_plugin.go new file mode 100644 index 0000000..4d37c47 --- /dev/null +++ b/pkg/types/consumer_group_plugin.go @@ -0,0 +1,177 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// consumerGroupPluginCRUD implements crud.Actions interface. +type consumerGroupPluginCRUD struct { + client *kong.Client + isKonnect bool +} + +func consumerGroupPluginFromStruct(arg crud.Event) *state.ConsumerGroupPlugin { + plugin, ok := arg.Obj.(*state.ConsumerGroupPlugin) + if !ok { + panic("unexpected type, expected *state.ConsumerGroupPlugin") + } + return plugin +} + +// Create creates a consumerGroupPlugin in Kong. +// The arg should be of type crud.Event, containing the consumerGroupPlugin to be created, +// else the function will panic. +// It returns the created *state.consumerGroupPlugin. +func (s *consumerGroupPluginCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + plugin := consumerGroupPluginFromStruct(event) + config := map[string]kong.Configuration{ + "config": plugin.Config, + } + var ( + res *kong.ConsumerGroupRLA + err error + ) + if s.isKonnect { + res, err = konnect.CreateRateLimitingAdvancedPlugin(ctx, s.client, plugin.ConsumerGroup.ID, plugin.Config) + if err != nil { + return nil, err + } + } else { + res, err = s.client.ConsumerGroups.UpdateRateLimitingAdvancedPlugin(ctx, plugin.ConsumerGroup.ID, config) + if err != nil { + return nil, err + } + } + return &state.ConsumerGroupPlugin{ + ConsumerGroupPlugin: kong.ConsumerGroupPlugin{ + Name: res.Plugin, + Config: res.Config, + ConsumerGroup: &kong.ConsumerGroup{ + ID: res.ConsumerGroup, + }, + }, + }, nil +} + +// Update updates a consumerGroupConsumer in Kong. +// The arg should be of type crud.Event, containing the consumerGroupConsumer to be updated, +// else the function will panic. +// It returns the updated *state.consumerGroupConsumer. +func (s *consumerGroupPluginCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + plugin := consumerGroupPluginFromStruct(event) + config := map[string]kong.Configuration{ + "config": plugin.Config, + } + var ( + res *kong.ConsumerGroupRLA + err error + ) + if s.isKonnect { + res, err = konnect.UpdateRateLimitingAdvancedPlugin(ctx, s.client, plugin.ConsumerGroup.ID, plugin.Config) + if err != nil { + return nil, err + } + } else { + res, err = s.client.ConsumerGroups.UpdateRateLimitingAdvancedPlugin(ctx, plugin.ConsumerGroup.ID, config) + if err != nil { + return nil, err + } + } + return &state.ConsumerGroupPlugin{ + ConsumerGroupPlugin: kong.ConsumerGroupPlugin{ + ID: plugin.ID, + Name: res.Plugin, + Config: res.Config, + ConsumerGroup: &kong.ConsumerGroup{ + ID: plugin.ConsumerGroup.ID, + Name: res.ConsumerGroup, + }, + }, + }, nil +} + +// Delete is just a placeholder because Admin API doesn't support DELETEs +// for consumer groups plugins. +func (s *consumerGroupPluginCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + plugin := consumerGroupPluginFromStruct(event) + if s.isKonnect { + err := konnect.DeleteRateLimitingAdvancedPlugin(ctx, s.client, plugin.ConsumerGroup.ID) + if err != nil { + return nil, err + } + return plugin, nil + } + return nil, nil +} + +type consumerGroupPluginDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *consumerGroupPluginDiffer) Deletes(_ func(crud.Event) error) error { + return nil +} + +func (d *consumerGroupPluginDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetPlugins, err := d.targetState.ConsumerGroupPlugins.GetAll() + if err != nil { + return fmt.Errorf("error fetching consumerGroupPlugins from state: %w", err) + } + + for _, plugin := range targetPlugins { + n, err := d.createUpdateConsumerGroupPlugin(plugin) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *consumerGroupPluginDiffer) createUpdateConsumerGroupPlugin( + plugin *state.ConsumerGroupPlugin, +) (*crud.Event, error) { + pluginCopy := &state.ConsumerGroupPlugin{ConsumerGroupPlugin: *plugin.DeepCopy()} + currentPlugin, err := d.currentState.ConsumerGroupPlugins.Get( + *plugin.Name, *plugin.ConsumerGroup.ID, + ) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "consumer-group-plugin", + Obj: pluginCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up consumerGroupPlugin %v: %w", + *currentPlugin.ID, err) + } + + // found, check if update needed + if !currentPlugin.EqualWithOpts(pluginCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "consumer-group-plugin", + Obj: pluginCopy, + OldObj: currentPlugin, + }, nil + } + return nil, nil +} diff --git a/pkg/types/core.go b/pkg/types/core.go new file mode 100644 index 0000000..9d27a22 --- /dev/null +++ b/pkg/types/core.go @@ -0,0 +1,535 @@ +package types + +import ( + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +type Differ interface { + Deletes(func(crud.Event) error) error + CreateAndUpdates(func(crud.Event) error) error +} + +type DuplicatesDeleter interface { + // DuplicatesDeletes returns delete events for entities that have duplicates in the current and target state. + // A duplicate is defined as an entity with the same name but different ID. + DuplicatesDeletes() ([]crud.Event, error) +} + +type Entity interface { + Type() EntityType + CRUDActions() crud.Actions + PostProcessActions() crud.Actions + Differ() Differ +} + +type entityImpl struct { + typ EntityType + crudActions crud.Actions // needs to set client + postProcessActions crud.Actions // needs currentstate Set + differ Differ +} + +func (e entityImpl) Type() EntityType { + return e.typ +} + +func (e entityImpl) CRUDActions() crud.Actions { + return e.crudActions +} + +func (e entityImpl) PostProcessActions() crud.Actions { + return e.postProcessActions +} + +func (e entityImpl) Differ() Differ { + return e.differ +} + +type EntityOpts struct { + CurrentState *state.KongState + TargetState *state.KongState + KongClient *kong.Client + KonnectClient *konnect.Client + + IsKonnect bool +} + +// EntityType defines a type of entity that is managed by decK. +type EntityType string + +const ( + // Service identifies a Service in Kong. + Service EntityType = "service" + // Route identifies a Route in Kong. + Route EntityType = "route" + // Plugin identifies a Plugin in Kong. + Plugin EntityType = "plugin" + + // Certificate identifies a Certificate in Kong. + Certificate EntityType = "certificate" + // SNI identifies a SNI in Kong. + SNI EntityType = "sni" + // CACertificate identifies a CACertificate in Kong. + CACertificate EntityType = "ca-certificate" + + // Upstream identifies a Upstream in Kong. + Upstream EntityType = "upstream" + // Target identifies a Target in Kong. + Target EntityType = "target" + + // Consumer identifies a Consumer in Kong. + Consumer EntityType = "consumer" + // ConsumerGroup identifies a ConsumerGroup in Kong. + ConsumerGroup EntityType = "consumer-group" + // ConsumerGroupConsumer identifies a ConsumerGroupConsumer in Kong. + ConsumerGroupConsumer EntityType = "consumer-group-consumer" + // ConsumerGroupPlugin identifies a ConsumerGroupPlugin in Kong. + ConsumerGroupPlugin EntityType = "consumer-group-plugin" + // ACLGroup identifies a ACLGroup in Kong. + ACLGroup EntityType = "acl-group" + // BasicAuth identifies a BasicAuth in Kong. + BasicAuth EntityType = "basic-auth" + // HMACAuth identifies a HMACAuth in Kong. + HMACAuth EntityType = "hmac-auth" + // JWTAuth identifies a JWTAuth in Kong. + JWTAuth EntityType = "jwt-auth" + // MTLSAuth identifies a MTLSAuth in Kong. + MTLSAuth EntityType = "mtls-auth" + // KeyAuth identifies aKeyAuth in Kong. + KeyAuth EntityType = "key-auth" + // OAuth2Cred identifies a OAuth2Cred in Kong. + OAuth2Cred EntityType = "oauth2-cred" //nolint:gosec + + // RBACRole identifies a RBACRole in Kong Enterprise. + RBACRole EntityType = "rbac-role" + // RBACEndpointPermission identifies a RBACEndpointPermission in Kong Enterprise. + RBACEndpointPermission EntityType = "rbac-endpoint-permission" + + // ServicePackage identifies a ServicePackage in Konnect. + ServicePackage EntityType = "service-package" + // ServiceVersion identifies a ServiceVersion in Konnect. + ServiceVersion EntityType = "service-version" + // Document identifies a Document in Konnect. + Document EntityType = "document" + + // Vault identifies a Vault in Kong. + Vault EntityType = "vault" +) + +// AllTypes represents all types defined in the +// package. +var AllTypes = []EntityType{ + Service, Route, Plugin, + + Certificate, SNI, CACertificate, + + Upstream, Target, + + Consumer, + ConsumerGroup, ConsumerGroupConsumer, ConsumerGroupPlugin, + ACLGroup, BasicAuth, KeyAuth, + HMACAuth, JWTAuth, OAuth2Cred, + MTLSAuth, + + RBACRole, RBACEndpointPermission, + + ServicePackage, ServiceVersion, Document, + + Vault, +} + +func entityTypeToKind(t EntityType) crud.Kind { + return crud.Kind(t) +} + +func NewEntity(t EntityType, opts EntityOpts) (Entity, error) { + switch t { + case Service: + return entityImpl{ + typ: Service, + crudActions: &serviceCRUD{ + client: opts.KongClient, + }, + postProcessActions: &servicePostAction{ + currentState: opts.CurrentState, + }, + differ: &serviceDiffer{ + kind: entityTypeToKind(Service), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Route: + return entityImpl{ + typ: Route, + crudActions: &routeCRUD{ + client: opts.KongClient, + }, + postProcessActions: &routePostAction{ + currentState: opts.CurrentState, + }, + differ: &routeDiffer{ + kind: entityTypeToKind(Route), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Upstream: + return entityImpl{ + typ: Upstream, + crudActions: &upstreamCRUD{ + client: opts.KongClient, + }, + postProcessActions: &upstreamPostAction{ + currentState: opts.CurrentState, + }, + differ: &upstreamDiffer{ + kind: entityTypeToKind(Upstream), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Target: + return entityImpl{ + typ: Target, + crudActions: &targetCRUD{ + client: opts.KongClient, + }, + postProcessActions: &targetPostAction{ + currentState: opts.CurrentState, + }, + differ: &targetDiffer{ + kind: entityTypeToKind(Target), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Plugin: + return entityImpl{ + typ: Plugin, + crudActions: &pluginCRUD{ + client: opts.KongClient, + }, + postProcessActions: &pluginPostAction{ + currentState: opts.CurrentState, + }, + differ: &pluginDiffer{ + kind: entityTypeToKind(Plugin), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Consumer: + return entityImpl{ + typ: Consumer, + crudActions: &consumerCRUD{ + client: opts.KongClient, + }, + postProcessActions: &consumerPostAction{ + currentState: opts.CurrentState, + }, + differ: &consumerDiffer{ + kind: entityTypeToKind(Consumer), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case ConsumerGroup: + return entityImpl{ + typ: ConsumerGroup, + crudActions: &consumerGroupCRUD{ + client: opts.KongClient, + isKonnect: opts.IsKonnect, + }, + postProcessActions: &consumerGroupPostAction{ + currentState: opts.CurrentState, + }, + differ: &consumerGroupDiffer{ + kind: entityTypeToKind(ConsumerGroup), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case ConsumerGroupConsumer: + return entityImpl{ + typ: ConsumerGroupConsumer, + crudActions: &consumerGroupConsumerCRUD{ + client: opts.KongClient, + isKonnect: opts.IsKonnect, + }, + postProcessActions: &consumerGroupConsumerPostAction{ + currentState: opts.CurrentState, + }, + differ: &consumerGroupConsumerDiffer{ + kind: entityTypeToKind(ConsumerGroupConsumer), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case ConsumerGroupPlugin: + return entityImpl{ + typ: ConsumerGroupPlugin, + crudActions: &consumerGroupPluginCRUD{ + client: opts.KongClient, + isKonnect: opts.IsKonnect, + }, + postProcessActions: &consumerGroupPluginPostAction{ + currentState: opts.CurrentState, + }, + differ: &consumerGroupPluginDiffer{ + kind: entityTypeToKind(ConsumerGroupPlugin), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case ServicePackage: + return entityImpl{ + typ: ServicePackage, + crudActions: &servicePackageCRUD{ + client: opts.KonnectClient, + }, + postProcessActions: &servicePackagePostAction{ + currentState: opts.CurrentState, + }, + differ: &servicePackageDiffer{ + kind: entityTypeToKind(ServicePackage), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case ServiceVersion: + return entityImpl{ + typ: ServiceVersion, + crudActions: &serviceVersionCRUD{ + client: opts.KonnectClient, + }, + postProcessActions: &serviceVersionPostAction{ + currentState: opts.CurrentState, + }, + differ: &serviceVersionDiffer{ + kind: entityTypeToKind(ServiceVersion), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Document: + return entityImpl{ + typ: Document, + crudActions: &documentCRUD{ + client: opts.KonnectClient, + }, + postProcessActions: &documentPostAction{ + currentState: opts.CurrentState, + }, + differ: &documentDiffer{ + kind: entityTypeToKind(Document), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Certificate: + return entityImpl{ + typ: Certificate, + crudActions: &certificateCRUD{ + client: opts.KongClient, + isKonnect: opts.IsKonnect, + }, + postProcessActions: &certificatePostAction{ + currentState: opts.CurrentState, + }, + differ: &certificateDiffer{ + kind: entityTypeToKind(Certificate), + currentState: opts.CurrentState, + targetState: opts.TargetState, + isKonnect: opts.IsKonnect, + }, + }, nil + case CACertificate: + return entityImpl{ + typ: CACertificate, + crudActions: &caCertificateCRUD{ + client: opts.KongClient, + }, + postProcessActions: &caCertificatePostAction{ + currentState: opts.CurrentState, + }, + differ: &caCertificateDiffer{ + kind: entityTypeToKind(CACertificate), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case SNI: + return entityImpl{ + typ: SNI, + crudActions: &sniCRUD{ + client: opts.KongClient, + }, + postProcessActions: &sniPostAction{ + currentState: opts.CurrentState, + }, + differ: &sniDiffer{ + kind: entityTypeToKind(SNI), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case RBACEndpointPermission: + return entityImpl{ + typ: RBACEndpointPermission, + crudActions: &rbacEndpointPermissionCRUD{ + client: opts.KongClient, + }, + postProcessActions: &rbacEndpointPermissionPostAction{ + currentState: opts.CurrentState, + }, + differ: &rbacEndpointPermissionDiffer{ + kind: entityTypeToKind(RBACEndpointPermission), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case RBACRole: + return entityImpl{ + typ: RBACRole, + crudActions: &rbacRoleCRUD{ + client: opts.KongClient, + }, + postProcessActions: &rbacRolePostAction{ + currentState: opts.CurrentState, + }, + differ: &rbacRoleDiffer{ + kind: entityTypeToKind(RBACRole), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case ACLGroup: + return entityImpl{ + typ: ACLGroup, + crudActions: &aclGroupCRUD{ + client: opts.KongClient, + }, + postProcessActions: &aclGroupPostAction{ + currentState: opts.CurrentState, + }, + differ: &aclGroupDiffer{ + kind: entityTypeToKind(ACLGroup), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case BasicAuth: + return entityImpl{ + typ: BasicAuth, + crudActions: &basicAuthCRUD{ + client: opts.KongClient, + }, + postProcessActions: &basicAuthPostAction{ + currentState: opts.CurrentState, + }, + differ: &basicAuthDiffer{ + kind: entityTypeToKind(BasicAuth), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case KeyAuth: + return entityImpl{ + typ: KeyAuth, + crudActions: &keyAuthCRUD{ + client: opts.KongClient, + }, + postProcessActions: &keyAuthPostAction{ + currentState: opts.CurrentState, + }, + differ: &keyAuthDiffer{ + kind: entityTypeToKind(KeyAuth), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case HMACAuth: + return entityImpl{ + typ: HMACAuth, + crudActions: &hmacAuthCRUD{ + client: opts.KongClient, + }, + postProcessActions: &hmacAuthPostAction{ + currentState: opts.CurrentState, + }, + differ: &hmacAuthDiffer{ + kind: entityTypeToKind(HMACAuth), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case JWTAuth: + return entityImpl{ + typ: JWTAuth, + crudActions: &jwtAuthCRUD{ + client: opts.KongClient, + }, + postProcessActions: &jwtAuthPostAction{ + currentState: opts.CurrentState, + }, + differ: &jwtAuthDiffer{ + kind: entityTypeToKind(JWTAuth), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case MTLSAuth: + return entityImpl{ + typ: MTLSAuth, + crudActions: &mtlsAuthCRUD{ + client: opts.KongClient, + }, + postProcessActions: &mtlsAuthPostAction{ + currentState: opts.CurrentState, + }, + differ: &mtlsAuthDiffer{ + kind: entityTypeToKind(MTLSAuth), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case OAuth2Cred: + return entityImpl{ + typ: OAuth2Cred, + crudActions: &oauth2CredCRUD{ + client: opts.KongClient, + }, + postProcessActions: &oauth2CredPostAction{ + currentState: opts.CurrentState, + }, + differ: &oauth2CredDiffer{ + kind: entityTypeToKind(OAuth2Cred), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case Vault: + return entityImpl{ + typ: Vault, + crudActions: &vaultCRUD{ + client: opts.KongClient, + }, + postProcessActions: &vaultPostAction{ + currentState: opts.CurrentState, + }, + differ: &vaultDiffer{ + kind: entityTypeToKind(Vault), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + default: + return nil, fmt.Errorf("unknown type: %q", t) + } +} diff --git a/pkg/types/document.go b/pkg/types/document.go new file mode 100644 index 0000000..8e4e5d1 --- /dev/null +++ b/pkg/types/document.go @@ -0,0 +1,183 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" +) + +// documentCRUD implements crud.Actions interface. +type documentCRUD struct { + client *konnect.Client +} + +func documentFromStruct(arg crud.Event) *state.Document { + d, ok := arg.Obj.(*state.Document) + if !ok { + panic("unexpected type, expected *state.Document") + } + return d +} + +func oldDocumentFromStruct(arg crud.Event) *state.Document { + d, ok := arg.OldObj.(*state.Document) + if !ok { + panic("unexpected type, expected *state.Document") + } + return d +} + +// Create creates a document in Konnect. +// The arg should be of type crud.Event, containing the document to be created, +// else the function will panic. +// It returns the created *state.Document. +func (s *documentCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + d := documentFromStruct(event) + createdDoc, err := s.client.Documents.Create(ctx, &d.Document) + if err != nil { + return nil, err + } + return &state.Document{Document: *createdDoc}, nil +} + +// Delete deletes a document in Konnect. +// The arg should be of type crud.Event, containing the document to be deleted, +// else the function will panic. +// It returns a the deleted *state.Document. +func (s *documentCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + d := documentFromStruct(event) + err := s.client.Documents.Delete(ctx, &d.Document) + if err != nil { + return nil, err + } + return d, nil +} + +// Update updates a document in Konnect. +// The arg should be of type crud.Event, containing the document to be updated, +// else the function will panic. +// It returns a the updated *state.Document. +func (s *documentCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + var ( + err error + updatedDoc *konnect.Document + ) + event := crud.EventFromArg(arg[0]) + document := documentFromStruct(event) + oldDocument := oldDocumentFromStruct(event) + + // if there is a change in document entity, make a PATCH + if !document.EqualWithOpts(oldDocument, false, true, true) { + documentCopy := &state.Document{Document: *document.ShallowCopy()} + updatedDoc, err = s.client.Documents.Update(ctx, &documentCopy.Document) + if err != nil { + return nil, err + } + } else { + updatedDoc = &document.Document + } + + return &state.Document{Document: *updatedDoc}, nil +} + +type documentDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *documentDiffer) Deletes(handler func(crud.Event) error) error { + currentDocuments, err := d.currentState.Documents.GetAll() + if err != nil { + return fmt.Errorf("error fetching documents from state: %w", err) + } + + for _, doc := range currentDocuments { + n, err := d.deleteDocument(doc) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *documentDiffer) deleteDocument(doc *state.Document) (*crud.Event, error) { + _, err := d.targetState.Documents.GetByParent(doc.Parent, *doc.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: "document", + Obj: doc, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up document %q: %w", + doc.Identifier(), err) + } + return nil, nil +} + +func (d *documentDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetDocuments, err := d.targetState.Documents.GetAll() + if err != nil { + return fmt.Errorf("error fetching documents from state: %w", err) + } + + for _, doc := range targetDocuments { + n, err := d.createUpdateDocument(doc) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *documentDiffer) createUpdateDocument(doc *state.Document) (*crud.Event, error) { + dCopy := &state.Document{Document: *doc.ShallowCopy()} + currentDoc, err := d.currentState.Documents.GetByParent(doc.Parent, *doc.ID) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "document", + Obj: dCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up document %q: %w", + doc.Identifier(), err) + } + + // found, check if update needed + // Service Package-attached Documents fail equality checks if ignoreForeign + // is disabled. This appears to be related to an invalid diff detection for + // Service Versions attached to the package. + if !currentDoc.EqualWithOpts(dCopy, false, true, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "document", + Obj: dCopy, + OldObj: currentDoc, + }, nil + } + return nil, nil +} diff --git a/pkg/types/hmacauth.go b/pkg/types/hmacauth.go new file mode 100644 index 0000000..553a550 --- /dev/null +++ b/pkg/types/hmacauth.go @@ -0,0 +1,184 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// hmacAuthCRUD implements crud.Actions interface. +type hmacAuthCRUD struct { + client *kong.Client +} + +func hmacAuthFromStruct(arg crud.Event) *state.HMACAuth { + hmacAuth, ok := arg.Obj.(*state.HMACAuth) + if !ok { + panic("unexpected type, expected *state.HMACAuth") + } + + return hmacAuth +} + +// Create creates a Route in Kong. +// The arg should be of type crud.Event, containing the hmacAuth to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *hmacAuthCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + hmacAuth := hmacAuthFromStruct(event) + cid := "" + if !utils.Empty(hmacAuth.Consumer.Username) { + cid = *hmacAuth.Consumer.Username + } + if !utils.Empty(hmacAuth.Consumer.ID) { + cid = *hmacAuth.Consumer.ID + } + createdHMACAuth, err := s.client.HMACAuths.Create(ctx, &cid, + &hmacAuth.HMACAuth) + if err != nil { + return nil, err + } + return &state.HMACAuth{HMACAuth: *createdHMACAuth}, nil +} + +// Delete deletes a Route in Kong. +// The arg should be of type crud.Event, containing the hmacAuth to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *hmacAuthCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + hmacAuth := hmacAuthFromStruct(event) + cid := "" + if !utils.Empty(hmacAuth.Consumer.Username) { + cid = *hmacAuth.Consumer.Username + } + if !utils.Empty(hmacAuth.Consumer.ID) { + cid = *hmacAuth.Consumer.ID + } + err := s.client.HMACAuths.Delete(ctx, &cid, hmacAuth.ID) + if err != nil { + return nil, err + } + return hmacAuth, nil +} + +// Update updates a Route in Kong. +// The arg should be of type crud.Event, containing the hmacAuth to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *hmacAuthCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + hmacAuth := hmacAuthFromStruct(event) + + cid := "" + if !utils.Empty(hmacAuth.Consumer.Username) { + cid = *hmacAuth.Consumer.Username + } + if !utils.Empty(hmacAuth.Consumer.ID) { + cid = *hmacAuth.Consumer.ID + } + updatedHMACAuth, err := s.client.HMACAuths.Create(ctx, &cid, &hmacAuth.HMACAuth) + if err != nil { + return nil, err + } + return &state.HMACAuth{HMACAuth: *updatedHMACAuth}, nil +} + +type hmacAuthDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *hmacAuthDiffer) Deletes(handler func(crud.Event) error) error { + currentHMACAuths, err := d.currentState.HMACAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching hmac-auths from state: %w", err) + } + + for _, hmacAuth := range currentHMACAuths { + n, err := d.deleteHMACAuth(hmacAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *hmacAuthDiffer) deleteHMACAuth(hmacAuth *state.HMACAuth) (*crud.Event, error) { + _, err := d.targetState.HMACAuths.Get(*hmacAuth.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: hmacAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up hmac-auth %q: %w", + *hmacAuth.Username, err) + } + return nil, nil +} + +func (d *hmacAuthDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetHMACAuths, err := d.targetState.HMACAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching hmac-auths from state: %w", err) + } + + for _, hmacAuth := range targetHMACAuths { + n, err := d.createUpdateHMACAuth(hmacAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *hmacAuthDiffer) createUpdateHMACAuth(hmacAuth *state.HMACAuth) (*crud.Event, error) { + hmacAuth = &state.HMACAuth{HMACAuth: *hmacAuth.DeepCopy()} + currentHMACAuth, err := d.currentState.HMACAuths.Get(*hmacAuth.ID) + if errors.Is(err, state.ErrNotFound) { + // hmacAuth not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: hmacAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up hmac-auth %q: %w", + *hmacAuth.Username, err) + } + // found, check if update needed + + if !currentHMACAuth.EqualWithOpts(hmacAuth, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: hmacAuth, + OldObj: currentHMACAuth, + }, nil + } + return nil, nil +} diff --git a/pkg/types/jwtauth.go b/pkg/types/jwtauth.go new file mode 100644 index 0000000..8ceba6f --- /dev/null +++ b/pkg/types/jwtauth.go @@ -0,0 +1,183 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// jwtAuthCRUD implements crud.Actions interface. +type jwtAuthCRUD struct { + client *kong.Client +} + +func jwtAuthFromStruct(arg crud.Event) *state.JWTAuth { + jwtAuth, ok := arg.Obj.(*state.JWTAuth) + if !ok { + panic("unexpected type, expected *state.JWTAuth") + } + + return jwtAuth +} + +// Create creates a Route in Kong. +// The arg should be of type crud.Event, containing the jwtAuth to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *jwtAuthCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + jwtAuth := jwtAuthFromStruct(event) + cid := "" + if !utils.Empty(jwtAuth.Consumer.Username) { + cid = *jwtAuth.Consumer.Username + } + if !utils.Empty(jwtAuth.Consumer.ID) { + cid = *jwtAuth.Consumer.ID + } + createdJWTAuth, err := s.client.JWTAuths.Create(ctx, &cid, + &jwtAuth.JWTAuth) + if err != nil { + return nil, err + } + return &state.JWTAuth{JWTAuth: *createdJWTAuth}, nil +} + +// Delete deletes a Route in Kong. +// The arg should be of type crud.Event, containing the jwtAuth to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *jwtAuthCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + jwtAuth := jwtAuthFromStruct(event) + cid := "" + if !utils.Empty(jwtAuth.Consumer.Username) { + cid = *jwtAuth.Consumer.Username + } + if !utils.Empty(jwtAuth.Consumer.ID) { + cid = *jwtAuth.Consumer.ID + } + err := s.client.JWTAuths.Delete(ctx, &cid, jwtAuth.ID) + if err != nil { + return nil, err + } + return jwtAuth, nil +} + +// Update updates a Route in Kong. +// The arg should be of type crud.Event, containing the jwtAuth to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *jwtAuthCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + jwtAuth := jwtAuthFromStruct(event) + + cid := "" + if !utils.Empty(jwtAuth.Consumer.Username) { + cid = *jwtAuth.Consumer.Username + } + if !utils.Empty(jwtAuth.Consumer.ID) { + cid = *jwtAuth.Consumer.ID + } + updatedJWTAuth, err := s.client.JWTAuths.Create(ctx, &cid, &jwtAuth.JWTAuth) + if err != nil { + return nil, err + } + return &state.JWTAuth{JWTAuth: *updatedJWTAuth}, nil +} + +type jwtAuthDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *jwtAuthDiffer) Deletes(handler func(crud.Event) error) error { + currentJWTAuths, err := d.currentState.JWTAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching jwt-auths from state: %w", err) + } + + for _, jwtAuth := range currentJWTAuths { + n, err := d.deleteJWTAuth(jwtAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *jwtAuthDiffer) deleteJWTAuth(jwtAuth *state.JWTAuth) (*crud.Event, error) { + _, err := d.targetState.JWTAuths.Get(*jwtAuth.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: jwtAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up jwt-auth %q: %w", *jwtAuth.Key, err) + } + return nil, nil +} + +func (d *jwtAuthDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetJWTAuths, err := d.targetState.JWTAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching jwt-auths from state: %w", err) + } + + for _, jwtAuth := range targetJWTAuths { + n, err := d.createUpdateJWTAuth(jwtAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *jwtAuthDiffer) createUpdateJWTAuth(jwtAuth *state.JWTAuth) (*crud.Event, error) { + jwtAuth = &state.JWTAuth{JWTAuth: *jwtAuth.DeepCopy()} + currentJWTAuth, err := d.currentState.JWTAuths.Get(*jwtAuth.ID) + if errors.Is(err, state.ErrNotFound) { + // jwtAuth not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: jwtAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up jwt-auth %q: %w", + *jwtAuth.Key, err) + } + // found, check if update needed + + if !currentJWTAuth.EqualWithOpts(jwtAuth, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: jwtAuth, + OldObj: currentJWTAuth, + }, nil + } + return nil, nil +} diff --git a/pkg/types/keyauth.go b/pkg/types/keyauth.go new file mode 100644 index 0000000..ad6ff00 --- /dev/null +++ b/pkg/types/keyauth.go @@ -0,0 +1,170 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// keyAuthCRUD implements crud.Actions interface. +type keyAuthCRUD struct { + client *kong.Client +} + +func keyAuthFromStruct(arg crud.Event) *state.KeyAuth { + keyAuth, ok := arg.Obj.(*state.KeyAuth) + if !ok { + panic("unexpected type, expected *state.KeyAuth") + } + + return keyAuth +} + +// Create creates a Route in Kong. +// The arg should be of type crud.Event, containing the keyAuth to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *keyAuthCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + keyAuth := keyAuthFromStruct(event) + createdKeyAuth, err := s.client.KeyAuths.Create(ctx, keyAuth.Consumer.ID, + &keyAuth.KeyAuth) + if err != nil { + return nil, err + } + return &state.KeyAuth{KeyAuth: *createdKeyAuth}, nil +} + +// Delete deletes a Route in Kong. +// The arg should be of type crud.Event, containing the keyAuth to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *keyAuthCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + keyAuth := keyAuthFromStruct(event) + cid := "" + if !utils.Empty(keyAuth.Consumer.Username) { + cid = *keyAuth.Consumer.Username + } + if !utils.Empty(keyAuth.Consumer.ID) { + cid = *keyAuth.Consumer.ID + } + err := s.client.KeyAuths.Delete(ctx, &cid, keyAuth.ID) + if err != nil { + return nil, err + } + return keyAuth, nil +} + +// Update updates a Route in Kong. +// The arg should be of type crud.Event, containing the keyAuth to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *keyAuthCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + keyAuth := keyAuthFromStruct(event) + + updatedKeyAuth, err := s.client.KeyAuths.Create(ctx, keyAuth.Consumer.ID, + &keyAuth.KeyAuth) + if err != nil { + return nil, err + } + return &state.KeyAuth{KeyAuth: *updatedKeyAuth}, nil +} + +type keyAuthDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *keyAuthDiffer) Deletes(handler func(crud.Event) error) error { + currentKeyAuths, err := d.currentState.KeyAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching key-auths from state: %w", err) + } + + for _, keyAuth := range currentKeyAuths { + n, err := d.deleteKeyAuth(keyAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *keyAuthDiffer) deleteKeyAuth(keyAuth *state.KeyAuth) (*crud.Event, error) { + _, err := d.targetState.KeyAuths.Get(*keyAuth.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: keyAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up key-auth %q: %w", *keyAuth.ID, err) + } + return nil, nil +} + +func (d *keyAuthDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetKeyAuths, err := d.targetState.KeyAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching key-auths from state: %w", err) + } + + for _, keyAuth := range targetKeyAuths { + n, err := d.createUpdateKeyAuth(keyAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *keyAuthDiffer) createUpdateKeyAuth(keyAuth *state.KeyAuth) (*crud.Event, error) { + keyAuth = &state.KeyAuth{KeyAuth: *keyAuth.DeepCopy()} + currentKeyAuth, err := d.currentState.KeyAuths.Get(*keyAuth.ID) + if errors.Is(err, state.ErrNotFound) { + // keyAuth not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: keyAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up key-auth %q: %w", + *keyAuth.ID, err) + } + // found, check if update needed + + if !currentKeyAuth.EqualWithOpts(keyAuth, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: keyAuth, + OldObj: currentKeyAuth, + }, nil + } + return nil, nil +} diff --git a/pkg/types/mtlsauth.go b/pkg/types/mtlsauth.go new file mode 100644 index 0000000..4e4f310 --- /dev/null +++ b/pkg/types/mtlsauth.go @@ -0,0 +1,170 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// mtlsAuthCRUD implements crud.Actions interface. +type mtlsAuthCRUD struct { + client *kong.Client +} + +func mtlsAuthFromStruct(arg crud.Event) *state.MTLSAuth { + mtlsAuth, ok := arg.Obj.(*state.MTLSAuth) + if !ok { + panic("unexpected type, expected *state.MTLSAuth") + } + + return mtlsAuth +} + +// Create creates an mtls-auth credential in Kong. +// The arg should be of type crud.Event, containing the mtlsAuth to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *mtlsAuthCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + mtlsAuth := mtlsAuthFromStruct(event) + createdMTLSAuth, err := s.client.MTLSAuths.Create(ctx, mtlsAuth.Consumer.ID, + &mtlsAuth.MTLSAuth) + if err != nil { + return nil, err + } + return &state.MTLSAuth{MTLSAuth: *createdMTLSAuth}, nil +} + +// Delete deletes an mtls-auth credential in Kong. +// The arg should be of type crud.Event, containing the mtlsAuth to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *mtlsAuthCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + mtlsAuth := mtlsAuthFromStruct(event) + cid := "" + if !utils.Empty(mtlsAuth.Consumer.Username) { + cid = *mtlsAuth.Consumer.Username + } + if !utils.Empty(mtlsAuth.Consumer.ID) { + cid = *mtlsAuth.Consumer.ID + } + err := s.client.MTLSAuths.Delete(ctx, &cid, mtlsAuth.ID) + if err != nil { + return nil, err + } + return mtlsAuth, nil +} + +// Update updates an mtls-auth credential in Kong. +// The arg should be of type crud.Event, containing the mtlsAuth to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *mtlsAuthCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + mtlsAuth := mtlsAuthFromStruct(event) + + updatedMTLSAuth, err := s.client.MTLSAuths.Create(ctx, mtlsAuth.Consumer.ID, + &mtlsAuth.MTLSAuth) + if err != nil { + return nil, err + } + return &state.MTLSAuth{MTLSAuth: *updatedMTLSAuth}, nil +} + +type mtlsAuthDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *mtlsAuthDiffer) Deletes(handler func(crud.Event) error) error { + currentMTLSAuths, err := d.currentState.MTLSAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching mtls-auths from state: %w", err) + } + + for _, mtlsAuth := range currentMTLSAuths { + n, err := d.deleteMTLSAuth(mtlsAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *mtlsAuthDiffer) deleteMTLSAuth(mtlsAuth *state.MTLSAuth) (*crud.Event, error) { + _, err := d.targetState.MTLSAuths.Get(*mtlsAuth.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: mtlsAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up mtls-auth %q: %w", *mtlsAuth.ID, err) + } + return nil, nil +} + +func (d *mtlsAuthDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetMTLSAuths, err := d.targetState.MTLSAuths.GetAll() + if err != nil { + return fmt.Errorf("error fetching mtls-auths from state: %w", err) + } + + for _, mtlsAuth := range targetMTLSAuths { + n, err := d.createUpdateMTLSAuth(mtlsAuth) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *mtlsAuthDiffer) createUpdateMTLSAuth(mtlsAuth *state.MTLSAuth) (*crud.Event, error) { + mtlsAuth = &state.MTLSAuth{MTLSAuth: *mtlsAuth.DeepCopy()} + currentMTLSAuth, err := d.currentState.MTLSAuths.Get(*mtlsAuth.ID) + if errors.Is(err, state.ErrNotFound) { + // mtlsAuth not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: mtlsAuth, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up mtls-auth %q: %w", + *mtlsAuth.ID, err) + } + // found, check if update needed + + if !currentMTLSAuth.EqualWithOpts(mtlsAuth, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: mtlsAuth, + OldObj: currentMTLSAuth, + }, nil + } + return nil, nil +} diff --git a/pkg/types/oauth2.go b/pkg/types/oauth2.go new file mode 100644 index 0000000..9f447b8 --- /dev/null +++ b/pkg/types/oauth2.go @@ -0,0 +1,187 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/deck/utils" + "github.com/kong/go-kong/kong" +) + +// oauth2CredCRUD implements crud.Actions interface. +type oauth2CredCRUD struct { + client *kong.Client +} + +func oauth2CredFromStruct(arg crud.Event) *state.Oauth2Credential { + oauth2Cred, ok := arg.Obj.(*state.Oauth2Credential) + if !ok { + panic("unexpected type, expected *state.OAuth2") + } + + return oauth2Cred +} + +// Create creates a Route in Kong. +// The arg should be of type crud.Event, containing the oauth2Cred to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *oauth2CredCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + oauth2Cred := oauth2CredFromStruct(event) + cid := "" + if !utils.Empty(oauth2Cred.Consumer.Username) { + cid = *oauth2Cred.Consumer.Username + } + if !utils.Empty(oauth2Cred.Consumer.ID) { + cid = *oauth2Cred.Consumer.ID + } + createdOauth2Cred, err := s.client.Oauth2Credentials.Create(ctx, &cid, + &oauth2Cred.Oauth2Credential) + if err != nil { + return nil, err + } + return &state.Oauth2Credential{Oauth2Credential: *createdOauth2Cred}, nil +} + +// Delete deletes a Route in Kong. +// The arg should be of type crud.Event, containing the oauth2Cred to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *oauth2CredCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + oauth2Cred := oauth2CredFromStruct(event) + cid := "" + if !utils.Empty(oauth2Cred.Consumer.Username) { + cid = *oauth2Cred.Consumer.Username + } + if !utils.Empty(oauth2Cred.Consumer.ID) { + cid = *oauth2Cred.Consumer.ID + } + err := s.client.Oauth2Credentials.Delete(ctx, &cid, oauth2Cred.ID) + if err != nil { + return nil, err + } + return oauth2Cred, nil +} + +// Update updates a Route in Kong. +// The arg should be of type crud.Event, containing the oauth2Cred to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *oauth2CredCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + oauth2Cred := oauth2CredFromStruct(event) + + cid := "" + if !utils.Empty(oauth2Cred.Consumer.Username) { + cid = *oauth2Cred.Consumer.Username + } + if !utils.Empty(oauth2Cred.Consumer.ID) { + cid = *oauth2Cred.Consumer.ID + } + updatedOauth2Cred, err := s.client.Oauth2Credentials.Create(ctx, &cid, + &oauth2Cred.Oauth2Credential) + if err != nil { + return nil, err + } + return &state.Oauth2Credential{Oauth2Credential: *updatedOauth2Cred}, nil +} + +type oauth2CredDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *oauth2CredDiffer) Deletes(handler func(crud.Event) error) error { + currentOauth2Creds, err := d.currentState.Oauth2Creds.GetAll() + if err != nil { + return fmt.Errorf("error fetching oauth2-cred from state: %w", err) + } + + for _, oauth2Cred := range currentOauth2Creds { + n, err := d.deleteOauth2Cred(oauth2Cred) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *oauth2CredDiffer) deleteOauth2Cred(oauth2Cred *state.Oauth2Credential) ( + *crud.Event, error, +) { + _, err := d.targetState.Oauth2Creds.Get(*oauth2Cred.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: oauth2Cred, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up oauth2-cred %q: %w", *oauth2Cred.Name, err) + } + return nil, nil +} + +func (d *oauth2CredDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetOauth2Creds, err := d.targetState.Oauth2Creds.GetAll() + if err != nil { + return fmt.Errorf("error fetching oauth2-creds from state: %w", err) + } + + for _, oauth2Cred := range targetOauth2Creds { + n, err := d.createUpdateOauth2Cred(oauth2Cred) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *oauth2CredDiffer) createUpdateOauth2Cred(oauth2Cred *state.Oauth2Credential) (*crud.Event, error) { + oauth2Cred = &state.Oauth2Credential{Oauth2Credential: *oauth2Cred.DeepCopy()} + currentOauth2Cred, err := d.currentState.Oauth2Creds.Get(*oauth2Cred.ID) + if errors.Is(err, state.ErrNotFound) { + // oauth2Cred not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: oauth2Cred, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up oauth2-cred %q: %w", + *oauth2Cred.Name, err) + } + currentOauth2Cred = &state.Oauth2Credential{Oauth2Credential: *currentOauth2Cred.DeepCopy()} + // found, check if update needed + + if !currentOauth2Cred.EqualWithOpts(oauth2Cred, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: oauth2Cred, + OldObj: currentOauth2Cred, + }, nil + } + return nil, nil +} diff --git a/pkg/types/plugin.go b/pkg/types/plugin.go new file mode 100644 index 0000000..fc3a595 --- /dev/null +++ b/pkg/types/plugin.go @@ -0,0 +1,206 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// pluginCRUD implements crud.Actions interface. +type pluginCRUD struct { + client *kong.Client +} + +// kong and konnect APIs only require IDs for referenced entities. +func stripPluginReferencesName(plugin *state.Plugin) { + if plugin.Plugin.Service != nil && plugin.Plugin.Service.Name != nil { + plugin.Plugin.Service.Name = nil + } + if plugin.Plugin.Route != nil && plugin.Plugin.Route.Name != nil { + plugin.Plugin.Route.Name = nil + } + if plugin.Plugin.Consumer != nil && plugin.Plugin.Consumer.Username != nil { + plugin.Plugin.Consumer.Username = nil + } + if plugin.Plugin.ConsumerGroup != nil && plugin.Plugin.ConsumerGroup.Name != nil { + plugin.Plugin.ConsumerGroup.Name = nil + } +} + +func pluginFromStruct(arg crud.Event) *state.Plugin { + plugin, ok := arg.Obj.(*state.Plugin) + if !ok { + panic("unexpected type, expected *state.Plugin") + } + stripPluginReferencesName(plugin) + return plugin +} + +// Create creates a Plugin in Kong. +// The arg should be of type crud.Event, containing the plugin to be created, +// else the function will panic. +// It returns a the created *state.Plugin. +func (s *pluginCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + plugin := pluginFromStruct(event) + + createdPlugin, err := s.client.Plugins.Create(ctx, &plugin.Plugin) + if err != nil { + return nil, err + } + return &state.Plugin{Plugin: *createdPlugin}, nil +} + +// Delete deletes a Plugin in Kong. +// The arg should be of type crud.Event, containing the plugin to be deleted, +// else the function will panic. +// It returns a the deleted *state.Plugin. +func (s *pluginCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + plugin := pluginFromStruct(event) + err := s.client.Plugins.Delete(ctx, plugin.ID) + if err != nil { + return nil, err + } + return plugin, nil +} + +// Update updates a Plugin in Kong. +// The arg should be of type crud.Event, containing the plugin to be updated, +// else the function will panic. +// It returns a the updated *state.Plugin. +func (s *pluginCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + plugin := pluginFromStruct(event) + + updatedPlugin, err := s.client.Plugins.Create(ctx, &plugin.Plugin) + if err != nil { + return nil, err + } + return &state.Plugin{Plugin: *updatedPlugin}, nil +} + +type pluginDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *pluginDiffer) Deletes(handler func(crud.Event) error) error { + currentPlugins, err := d.currentState.Plugins.GetAll() + if err != nil { + return fmt.Errorf("error fetching plugins from state: %w", err) + } + + for _, plugin := range currentPlugins { + n, err := d.deletePlugin(plugin) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *pluginDiffer) deletePlugin(plugin *state.Plugin) (*crud.Event, error) { + plugin = &state.Plugin{Plugin: *plugin.DeepCopy()} + name := *plugin.Name + serviceID, routeID, consumerID, consumerGroupID := foreignNames(plugin) + _, err := d.targetState.Plugins.GetByProp( + name, serviceID, routeID, consumerID, consumerGroupID, + ) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: plugin, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up plugin %q: %w", *plugin.ID, err) + } + return nil, nil +} + +func (d *pluginDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetPlugins, err := d.targetState.Plugins.GetAll() + if err != nil { + return fmt.Errorf("error fetching plugins from state: %w", err) + } + + for _, plugin := range targetPlugins { + n, err := d.createUpdatePlugin(plugin) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *pluginDiffer) createUpdatePlugin(plugin *state.Plugin) (*crud.Event, error) { + plugin = &state.Plugin{Plugin: *plugin.DeepCopy()} + name := *plugin.Name + serviceID, routeID, consumerID, consumerGroupID := foreignNames(plugin) + currentPlugin, err := d.currentState.Plugins.GetByProp( + name, serviceID, routeID, consumerID, consumerGroupID, + ) + if errors.Is(err, state.ErrNotFound) { + // plugin not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: plugin, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up plugin %q: %w", + *plugin.Name, err) + } + currentPlugin = &state.Plugin{Plugin: *currentPlugin.DeepCopy()} + // found, check if update needed + + if !currentPlugin.EqualWithOpts(plugin, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: plugin, + OldObj: currentPlugin, + }, nil + } + return nil, nil +} + +func foreignNames(p *state.Plugin) (serviceID, routeID, consumerID, consumerGroupID string) { + if p == nil { + return + } + if p.Service != nil && p.Service.ID != nil { + serviceID = *p.Service.ID + } + if p.Route != nil && p.Route.ID != nil { + routeID = *p.Route.ID + } + if p.Consumer != nil && p.Consumer.ID != nil { + consumerID = *p.Consumer.ID + } + if p.ConsumerGroup != nil && p.ConsumerGroup.ID != nil { + consumerGroupID = *p.ConsumerGroup.ID + } + return +} diff --git a/pkg/types/postProcess.go b/pkg/types/postProcess.go new file mode 100644 index 0000000..9b9b9c2 --- /dev/null +++ b/pkg/types/postProcess.go @@ -0,0 +1,457 @@ +package types + +import ( + "context" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" +) + +type servicePostAction struct { + currentState *state.KongState +} + +func (crud *servicePostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Services.Add(*args[0].(*state.Service)) +} + +func (crud *servicePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + serviceID := *args[0].(*state.Service).ID + + // Delete all plugins associated with this service as that's the implicit behavior of Kong (cascade delete). + plugins, err := crud.currentState.Plugins.GetAllByServiceID(serviceID) + if err != nil { + return nil, fmt.Errorf("error looking up plugins for service '%v': %w", serviceID, err) + } + for _, plugin := range plugins { + err = crud.currentState.Plugins.Delete(*plugin.ID) + if err != nil { + return nil, fmt.Errorf("error deleting plugin '%v' for service '%v': %w", *plugin.ID, serviceID, err) + } + } + return nil, crud.currentState.Services.Delete(serviceID) +} + +func (crud *servicePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Services.Update(*args[0].(*state.Service)) +} + +type routePostAction struct { + currentState *state.KongState +} + +func (crud *routePostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Routes.Add(*args[0].(*state.Route)) +} + +func (crud *routePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + routeID := *args[0].(*state.Route).ID + + // Delete all plugins associated with this route as that's the implicit behavior of Kong (cascade delete). + plugins, err := crud.currentState.Plugins.GetAllByRouteID(routeID) + if err != nil { + return nil, fmt.Errorf("error looking up plugins for route '%v': %w", routeID, err) + } + for _, plugin := range plugins { + err = crud.currentState.Plugins.Delete(*plugin.ID) + if err != nil { + return nil, fmt.Errorf("error deleting plugin '%v' for route '%v': %w", *plugin.ID, routeID, err) + } + } + return nil, crud.currentState.Routes.Delete(routeID) +} + +func (crud *routePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Routes.Update(*args[0].(*state.Route)) +} + +type upstreamPostAction struct { + currentState *state.KongState +} + +func (crud *upstreamPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Upstreams.Add(*args[0].(*state.Upstream)) +} + +func (crud *upstreamPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Upstreams.Delete(*((args[0].(*state.Upstream)).ID)) +} + +func (crud *upstreamPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Upstreams.Update(*args[0].(*state.Upstream)) +} + +type targetPostAction struct { + currentState *state.KongState +} + +func (crud *targetPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Targets.Add(*args[0].(*state.Target)) +} + +func (crud *targetPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + target := args[0].(*state.Target) + return nil, crud.currentState.Targets.Delete(*target.Upstream.ID, *target.ID) +} + +func (crud *targetPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Targets.Update(*args[0].(*state.Target)) +} + +type certificatePostAction struct { + currentState *state.KongState +} + +func (crud *certificatePostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Certificates.Add(*args[0].(*state.Certificate)) +} + +func (crud *certificatePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Certificates.Delete(*((args[0].(*state.Certificate)).ID)) +} + +func (crud *certificatePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Certificates.Update(*args[0].(*state.Certificate)) +} + +type sniPostAction struct { + currentState *state.KongState +} + +func (crud *sniPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.SNIs.Add(*args[0].(*state.SNI)) +} + +func (crud *sniPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + sni := args[0].(*state.SNI) + return nil, crud.currentState.SNIs.Delete(*sni.ID) +} + +func (crud *sniPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.SNIs.Update(*args[0].(*state.SNI)) +} + +type caCertificatePostAction struct { + currentState *state.KongState +} + +func (crud *caCertificatePostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.CACertificates.Add(*args[0].(*state.CACertificate)) +} + +func (crud *caCertificatePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.CACertificates.Delete(*((args[0].(*state.CACertificate)).ID)) +} + +func (crud *caCertificatePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.CACertificates.Update(*args[0].(*state.CACertificate)) +} + +type pluginPostAction struct { + currentState *state.KongState +} + +func (crud *pluginPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Plugins.Add(*args[0].(*state.Plugin)) +} + +func (crud *pluginPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Plugins.Delete(*((args[0].(*state.Plugin)).ID)) +} + +func (crud *pluginPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Plugins.Update(*args[0].(*state.Plugin)) +} + +type consumerPostAction struct { + currentState *state.KongState +} + +func (crud *consumerPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Consumers.Add(*args[0].(*state.Consumer)) +} + +func (crud *consumerPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + consumerID := *args[0].(*state.Consumer).ID + + // Delete all plugins associated with this consumer as that's the implicit behavior of Kong (cascade delete). + plugins, err := crud.currentState.Plugins.GetAllByConsumerID(consumerID) + if err != nil { + return nil, fmt.Errorf("error looking up plugins for consumer '%v': %w", consumerID, err) + } + for _, plugin := range plugins { + if err := crud.currentState.Plugins.Delete(*plugin.ID); err != nil { + return nil, fmt.Errorf("error deleting plugin '%v' for consumer '%v': %w", *plugin.ID, consumerID, err) + } + } + return nil, crud.currentState.Consumers.Delete(consumerID) +} + +func (crud *consumerPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Consumers.Update(*args[0].(*state.Consumer)) +} + +type consumerGroupPostAction struct { + currentState *state.KongState +} + +func (crud *consumerGroupPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroups.Add(*args[0].(*state.ConsumerGroup)) +} + +func (crud *consumerGroupPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroups.Delete(*((args[0].(*state.ConsumerGroup)).ID)) +} + +func (crud *consumerGroupPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroups.Update(*args[0].(*state.ConsumerGroup)) +} + +type consumerGroupConsumerPostAction struct { + currentState *state.KongState +} + +func (crud *consumerGroupConsumerPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroupConsumers.Add(*args[0].(*state.ConsumerGroupConsumer)) +} + +func (crud *consumerGroupConsumerPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroupConsumers.Delete( + *((args[0].(*state.ConsumerGroupConsumer)).Consumer.ID), + *((args[0].(*state.ConsumerGroupConsumer)).ConsumerGroup.ID), + ) +} + +func (crud *consumerGroupConsumerPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroupConsumers.Update(*args[0].(*state.ConsumerGroupConsumer)) +} + +type consumerGroupPluginPostAction struct { + currentState *state.KongState +} + +func (crud *consumerGroupPluginPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroupPlugins.Add(*args[0].(*state.ConsumerGroupPlugin)) +} + +func (crud *consumerGroupPluginPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroupPlugins.Delete( + *((args[0].(*state.ConsumerGroupPlugin)).ID), + *((args[0].(*state.ConsumerGroupConsumer)).ConsumerGroup.ID), + ) +} + +func (crud *consumerGroupPluginPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ConsumerGroupPlugins.Update(*args[0].(*state.ConsumerGroupPlugin)) +} + +type keyAuthPostAction struct { + currentState *state.KongState +} + +func (crud *keyAuthPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.KeyAuths.Add(*args[0].(*state.KeyAuth)) +} + +func (crud *keyAuthPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.KeyAuths.Delete(*((args[0].(*state.KeyAuth)).ID)) +} + +func (crud *keyAuthPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.KeyAuths.Update(*args[0].(*state.KeyAuth)) +} + +type hmacAuthPostAction struct { + currentState *state.KongState +} + +func (crud hmacAuthPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.HMACAuths.Add(*args[0].(*state.HMACAuth)) +} + +func (crud hmacAuthPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.HMACAuths.Delete(*((args[0].(*state.HMACAuth)).ID)) +} + +func (crud hmacAuthPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.HMACAuths.Update(*args[0].(*state.HMACAuth)) +} + +type jwtAuthPostAction struct { + currentState *state.KongState +} + +func (crud jwtAuthPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.JWTAuths.Add(*args[0].(*state.JWTAuth)) +} + +func (crud jwtAuthPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.JWTAuths.Delete(*((args[0].(*state.JWTAuth)).ID)) +} + +func (crud jwtAuthPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.JWTAuths.Update(*args[0].(*state.JWTAuth)) +} + +type basicAuthPostAction struct { + currentState *state.KongState +} + +func (crud basicAuthPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.BasicAuths.Add(*args[0].(*state.BasicAuth)) +} + +func (crud basicAuthPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.BasicAuths.Delete(*((args[0].(*state.BasicAuth)).ID)) +} + +func (crud basicAuthPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.BasicAuths.Update(*args[0].(*state.BasicAuth)) +} + +type aclGroupPostAction struct { + currentState *state.KongState +} + +func (crud aclGroupPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ACLGroups.Add(*args[0].(*state.ACLGroup)) +} + +func (crud aclGroupPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ACLGroups.Delete(*((args[0].(*state.ACLGroup)).ID)) +} + +func (crud aclGroupPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ACLGroups.Update(*args[0].(*state.ACLGroup)) +} + +type oauth2CredPostAction struct { + currentState *state.KongState +} + +func (crud oauth2CredPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Oauth2Creds.Add(*args[0].(*state.Oauth2Credential)) +} + +func (crud oauth2CredPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Oauth2Creds.Delete(*((args[0].(*state.Oauth2Credential)).ID)) +} + +func (crud oauth2CredPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Oauth2Creds.Update(*args[0].(*state.Oauth2Credential)) +} + +type mtlsAuthPostAction struct { + currentState *state.KongState +} + +func (crud *mtlsAuthPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.MTLSAuths.Add(*args[0].(*state.MTLSAuth)) +} + +func (crud *mtlsAuthPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.MTLSAuths.Delete(*((args[0].(*state.MTLSAuth)).ID)) +} + +func (crud *mtlsAuthPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.MTLSAuths.Update(*args[0].(*state.MTLSAuth)) +} + +type rbacRolePostAction struct { + currentState *state.KongState +} + +func (crud *rbacRolePostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.RBACRoles.Add(*args[0].(*state.RBACRole)) +} + +func (crud *rbacRolePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.RBACRoles.Delete(*((args[0].(*state.RBACRole)).ID)) +} + +func (crud *rbacRolePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.RBACRoles.Update(*args[0].(*state.RBACRole)) +} + +type rbacEndpointPermissionPostAction struct { + currentState *state.KongState +} + +func (crud *rbacEndpointPermissionPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.RBACEndpointPermissions.Add(*args[0].(*state.RBACEndpointPermission)) +} + +func (crud *rbacEndpointPermissionPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.RBACEndpointPermissions.Delete(args[0].(*state.RBACEndpointPermission).FriendlyName()) +} + +func (crud *rbacEndpointPermissionPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.RBACEndpointPermissions.Update(*args[0].(*state.RBACEndpointPermission)) +} + +type servicePackagePostAction struct { + currentState *state.KongState +} + +func (crud servicePackagePostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ServicePackages.Add(*args[0].(*state.ServicePackage)) +} + +func (crud servicePackagePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ServicePackages.Delete(*((args[0].(*state.ServicePackage)).ID)) +} + +func (crud servicePackagePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ServicePackages.Update(*args[0].(*state.ServicePackage)) +} + +type serviceVersionPostAction struct { + currentState *state.KongState +} + +func (crud serviceVersionPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ServiceVersions.Add(*args[0].(*state.ServiceVersion)) +} + +func (crud serviceVersionPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + sv := args[0].(*state.ServiceVersion) + return nil, crud.currentState.ServiceVersions.Delete(*sv.ServicePackage.ID, *sv.ID) +} + +func (crud serviceVersionPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.ServiceVersions.Update(*args[0].(*state.ServiceVersion)) +} + +type documentPostAction struct { + currentState *state.KongState +} + +func (crud documentPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Documents.Add(*args[0].(*state.Document)) +} + +func (crud documentPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + d := args[0].(*state.Document) + return nil, crud.currentState.Documents.DeleteByParent(d.Parent, *d.ID) +} + +func (crud documentPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Documents.Update(*args[0].(*state.Document)) +} + +type vaultPostAction struct { + currentState *state.KongState +} + +func (crud vaultPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Vaults.Add(*args[0].(*state.Vault)) +} + +func (crud vaultPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Vaults.Delete(*((args[0].(*state.Vault)).ID)) +} + +func (crud vaultPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Vaults.Update(*args[0].(*state.Vault)) +} diff --git a/pkg/types/rbac_endpoint_permission.go b/pkg/types/rbac_endpoint_permission.go new file mode 100644 index 0000000..e72aff2 --- /dev/null +++ b/pkg/types/rbac_endpoint_permission.go @@ -0,0 +1,174 @@ +package types + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// rbacEndpointPermissionCRUD implements crud.Actions interface. +type rbacEndpointPermissionCRUD struct { + client *kong.Client +} + +func rbacEndpointPermissionFromStruct(arg crud.Event) *state.RBACEndpointPermission { + ep, ok := arg.Obj.(*state.RBACEndpointPermission) + if !ok { + panic("unexpected type, expected *state.RBACEndpointPermission") + } + + return ep +} + +// Create creates a RBACEndpointPermission in Kong. +// The arg should be of type crud.Event, containing the ep to be created, +// else the function will panic. +// It returns a the created *state.RBACEndpointPermission. +func (s *rbacEndpointPermissionCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + ep := rbacEndpointPermissionFromStruct(event) + createdRBACEndpointPermission, err := s.client.RBACEndpointPermissions.Create(ctx, &ep.RBACEndpointPermission) + if err != nil { + return nil, err + } + return &state.RBACEndpointPermission{RBACEndpointPermission: *createdRBACEndpointPermission}, nil +} + +// Delete deletes a RBACEndpointPermission in Kong. +// The arg should be of type crud.Event, containing the ep to be deleted, +// else the function will panic. +// It returns a the deleted *state.RBACEndpointPermission. +func (s *rbacEndpointPermissionCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + ep := rbacEndpointPermissionFromStruct(event) + + // for DELETE calls, the endpoint is passed in the URL only + // including the leading slash results in a URL like + // /rbac/roles/ROLEID/endpoints/workspace//foo/ + // Kong expects a URL like + // /rbac/roles/ROLEID/endpoints/workspace/foo/ + // so we strip this before passing it to go-kong + trimmed := strings.TrimLeft(*ep.Endpoint, "/") + err := s.client.RBACEndpointPermissions.Delete(ctx, ep.Role.ID, ep.Workspace, &trimmed) + if err != nil { + return nil, err + } + return ep, nil +} + +// Update updates a RBACEndpointPermission in Kong. +// The arg should be of type crud.Event, containing the ep to be updated, +// else the function will panic. +// It returns a the updated *state.RBACEndpointPermission. +func (s *rbacEndpointPermissionCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + ep := rbacEndpointPermissionFromStruct(event) + + updatedRBACEndpointPermission, err := s.client.RBACEndpointPermissions.Update(ctx, &ep.RBACEndpointPermission) + if err != nil { + return nil, err + } + return &state.RBACEndpointPermission{RBACEndpointPermission: *updatedRBACEndpointPermission}, nil +} + +func (d *rbacEndpointPermissionDiffer) Deletes(handler func(crud.Event) error) error { + currentRBACEndpointPermissions, err := d.currentState.RBACEndpointPermissions.GetAll() + if err != nil { + return fmt.Errorf("error fetching eps from state: %w", err) + } + + for _, ep := range currentRBACEndpointPermissions { + n, err := d.deleteRBACEndpointPermission(ep) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +type rbacEndpointPermissionDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *rbacEndpointPermissionDiffer) deleteRBACEndpointPermission(ep *state.RBACEndpointPermission) ( + *crud.Event, error, +) { + _, err := d.targetState.RBACEndpointPermissions.Get(ep.FriendlyName()) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: ep, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up rbac ep %q: %w", + ep.ID, err) + } + return nil, nil +} + +func (d *rbacEndpointPermissionDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetRBACEndpointPermissions, err := d.targetState.RBACEndpointPermissions.GetAll() + if err != nil { + return fmt.Errorf("error fetching rbac eps from state: %w", err) + } + + for _, ep := range targetRBACEndpointPermissions { + n, err := d.createUpdateRBACEndpointPermission(ep) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *rbacEndpointPermissionDiffer) createUpdateRBACEndpointPermission(ep *state.RBACEndpointPermission) ( + *crud.Event, error, +) { + epCopy := &state.RBACEndpointPermission{RBACEndpointPermission: *ep.DeepCopy()} + currentEp, err := d.currentState.RBACEndpointPermissions.Get(ep.FriendlyName()) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: epCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up rbac endpoint permission %q: %w", + ep.FriendlyName(), err) + } + + // found, check if update needed + if !currentEp.EqualWithOpts(epCopy, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: epCopy, + OldObj: currentEp, + }, nil + } + return nil, nil +} diff --git a/pkg/types/rbac_role.go b/pkg/types/rbac_role.go new file mode 100644 index 0000000..dfede9a --- /dev/null +++ b/pkg/types/rbac_role.go @@ -0,0 +1,161 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// rbacRoleCRUD implements crud.Actions interface. +type rbacRoleCRUD struct { + client *kong.Client +} + +func rbacRoleFromStruct(arg crud.Event) *state.RBACRole { + role, ok := arg.Obj.(*state.RBACRole) + if !ok { + panic("unexpected type, expected *state.RBACRole") + } + + return role +} + +// Create creates a RBACRole in Kong. +// The arg should be of type crud.Event, containing the role to be created, +// else the function will panic. +// It returns a the created *state.RBACRole. +func (s *rbacRoleCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + role := rbacRoleFromStruct(event) + createdRBACRole, err := s.client.RBACRoles.Create(ctx, &role.RBACRole) + if err != nil { + return nil, err + } + return &state.RBACRole{RBACRole: *createdRBACRole}, nil +} + +// Delete deletes a RBACRole in Kong. +// The arg should be of type crud.Event, containing the role to be deleted, +// else the function will panic. +// It returns a the deleted *state.RBACRole. +func (s *rbacRoleCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + role := rbacRoleFromStruct(event) + err := s.client.RBACRoles.Delete(ctx, role.ID) + if err != nil { + return nil, err + } + return role, nil +} + +// Update updates a RBACRole in Kong. +// The arg should be of type crud.Event, containing the role to be updated, +// else the function will panic. +// It returns a the updated *state.RBACRole. +func (s *rbacRoleCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + role := rbacRoleFromStruct(event) + + updatedRBACRole, err := s.client.RBACRoles.Create(ctx, &role.RBACRole) + if err != nil { + return nil, err + } + return &state.RBACRole{RBACRole: *updatedRBACRole}, nil +} + +type rbacRoleDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *rbacRoleDiffer) Deletes(handler func(crud.Event) error) error { + currentRBACRoles, err := d.currentState.RBACRoles.GetAll() + if err != nil { + return fmt.Errorf("error fetching rbac roles from state: %w", err) + } + + for _, role := range currentRBACRoles { + n, err := d.deleteRBACRole(role) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *rbacRoleDiffer) deleteRBACRole(role *state.RBACRole) (*crud.Event, error) { + _, err := d.targetState.RBACRoles.Get(*role.Name) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: role, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up rbac role %q: %w", + role.FriendlyName(), err) + } + return nil, nil +} + +func (d *rbacRoleDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetRBACRoles, err := d.targetState.RBACRoles.GetAll() + if err != nil { + return fmt.Errorf("error fetching rbac roles from state: %w", err) + } + + for _, role := range targetRBACRoles { + n, err := d.createUpdateRBACRole(role) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *rbacRoleDiffer) createUpdateRBACRole(role *state.RBACRole) (*crud.Event, error) { + roleCopy := &state.RBACRole{RBACRole: *role.DeepCopy()} + currentRole, err := d.currentState.RBACRoles.Get(*role.Name) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: roleCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up rbac role %q: %w", + role.FriendlyName(), err) + } + + // found, check if update needed + if !currentRole.EqualWithOpts(roleCopy, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: roleCopy, + OldObj: currentRole, + }, nil + } + return nil, nil +} diff --git a/pkg/types/route.go b/pkg/types/route.go new file mode 100644 index 0000000..e7ce4ad --- /dev/null +++ b/pkg/types/route.go @@ -0,0 +1,213 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// routeCRUD implements crud.Actions interface. +type routeCRUD struct { + client *kong.Client +} + +// kong and konnect APIs only require IDs for referenced entities. +func stripRouteReferencesName(route *state.Route) { + if route.Route.Service != nil && route.Route.Service.Name != nil { + route.Route.Service.Name = nil + } +} + +func routeFromStruct(arg crud.Event) *state.Route { + route, ok := arg.Obj.(*state.Route) + if !ok { + panic("unexpected type, expected *state.Route") + } + stripRouteReferencesName(route) + return route +} + +// Create creates a Route in Kong. +// The arg should be of type crud.Event, containing the route to be created, +// else the function will panic. +// It returns a the created *state.Route. +func (s *routeCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + route := routeFromStruct(event) + createdRoute, err := s.client.Routes.Create(ctx, &route.Route) + if err != nil { + return nil, err + } + return &state.Route{Route: *createdRoute}, nil +} + +// Delete deletes a Route in Kong. +// The arg should be of type crud.Event, containing the route to be deleted, +// else the function will panic. +// It returns a the deleted *state.Route. +func (s *routeCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + route := routeFromStruct(event) + err := s.client.Routes.Delete(ctx, route.ID) + if err != nil { + return nil, err + } + return route, nil +} + +// Update updates a Route in Kong. +// The arg should be of type crud.Event, containing the route to be updated, +// else the function will panic. +// It returns a the updated *state.Route. +func (s *routeCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + route := routeFromStruct(event) + + updatedRoute, err := s.client.Routes.Create(ctx, &route.Route) + if err != nil { + return nil, err + } + return &state.Route{Route: *updatedRoute}, nil +} + +type routeDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *routeDiffer) Deletes(handler func(crud.Event) error) error { + currentRoutes, err := d.currentState.Routes.GetAll() + if err != nil { + return fmt.Errorf("error fetching routes from state: %w", err) + } + + for _, route := range currentRoutes { + n, err := d.deleteRoute(route) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *routeDiffer) deleteRoute(route *state.Route) (*crud.Event, error) { + _, err := d.targetState.Routes.Get(*route.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: route, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up route %q: %w", + route.FriendlyName(), err) + } + return nil, nil +} + +func (d *routeDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetRoutes, err := d.targetState.Routes.GetAll() + if err != nil { + return fmt.Errorf("error fetching routes from state: %w", err) + } + + for _, route := range targetRoutes { + n, err := d.createUpdateRoute(route) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *routeDiffer) createUpdateRoute(route *state.Route) (*crud.Event, error) { + route = &state.Route{Route: *route.DeepCopy()} + currentRoute, err := d.currentState.Routes.Get(*route.ID) + if errors.Is(err, state.ErrNotFound) { + // route not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: route, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up route %q: %w", + route.FriendlyName(), err) + } + // found, check if update needed + + if !currentRoute.EqualWithOpts(route, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: route, + OldObj: currentRoute, + }, nil + } + return nil, nil +} + +func (d *routeDiffer) DuplicatesDeletes() ([]crud.Event, error) { + targetRoutes, err := d.targetState.Routes.GetAll() + if err != nil { + return nil, fmt.Errorf("error fetching routes from state: %w", err) + } + + var events []crud.Event + for _, route := range targetRoutes { + event, err := d.deleteDuplicateRoute(route) + if err != nil { + return nil, err + } + if event != nil { + events = append(events, *event) + } + } + + return events, nil +} + +func (d *routeDiffer) deleteDuplicateRoute(targetRoute *state.Route) (*crud.Event, error) { + if targetRoute == nil || targetRoute.Name == nil { + // Nothing to do, cannot be a duplicate with no name. + return nil, nil + } + + currentRoute, err := d.currentState.Routes.Get(*targetRoute.Name) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up route %q: %w", *targetRoute.Name, err) + } + + if *currentRoute.ID != *targetRoute.ID { + return &crud.Event{ + Op: crud.Delete, + Kind: "route", + Obj: currentRoute, + }, nil + } + + return nil, nil +} diff --git a/pkg/types/service.go b/pkg/types/service.go new file mode 100644 index 0000000..a715152 --- /dev/null +++ b/pkg/types/service.go @@ -0,0 +1,222 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// serviceCRUD implements crud.Actions interface. +type serviceCRUD struct { + client *kong.Client +} + +func serviceFromStruct(arg crud.Event) *state.Service { + service, ok := arg.Obj.(*state.Service) + if !ok { + panic("unexpected type, expected *state.service") + } + return service +} + +// Create creates a Service in Kong. +// The arg should be of type crud.Event, containing the service to be created, +// else the function will panic. +// It returns a the created *state.Service. +func (s *serviceCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + service := serviceFromStruct(event) + createdService, err := s.client.Services.Create(ctx, &service.Service) + if err != nil { + return nil, err + } + return &state.Service{Service: *createdService}, nil +} + +// Delete deletes a Service in Kong. +// The arg should be of type crud.Event, containing the service to be deleted, +// else the function will panic. +// It returns a the deleted *state.Service. +func (s *serviceCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + service := serviceFromStruct(event) + err := s.client.Services.Delete(ctx, service.ID) + if err != nil { + return nil, err + } + return service, nil +} + +// Update updates a Service in Kong. +// The arg should be of type crud.Event, containing the service to be updated, +// else the function will panic. +// It returns a the updated *state.Service. +func (s *serviceCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + service := serviceFromStruct(event) + + updatedService, err := s.client.Services.Create(ctx, &service.Service) + if err != nil { + return nil, err + } + return &state.Service{Service: *updatedService}, nil +} + +type serviceDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *serviceDiffer) Deletes(handler func(crud.Event) error) error { + currentServices, err := d.currentState.Services.GetAll() + if err != nil { + return err + } + + for _, service := range currentServices { + n, err := d.deleteService(service) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *serviceDiffer) deleteService(service *state.Service) (*crud.Event, error) { + _, err := d.targetState.Services.Get(*service.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: service, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up service %q: %w", + service.FriendlyName(), err) + } + return nil, nil +} + +func (d *serviceDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetServices, err := d.targetState.Services.GetAll() + if err != nil { + return fmt.Errorf("error fetching services from state: %w", err) + } + + for _, service := range targetServices { + n, err := d.createUpdateService(service) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *serviceDiffer) createUpdateService(service *state.Service) (*crud.Event, error) { + serviceCopy := &state.Service{Service: *service.DeepCopy()} + currentService, err := d.currentState.Services.Get(*service.ID) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "service", + Obj: serviceCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up service %q: %w", + *service.Name, err) + } + + // found, check if update needed + if !currentService.EqualWithOpts(serviceCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "service", + Obj: serviceCopy, + OldObj: currentService, + }, nil + } + return nil, nil +} + +func (d *serviceDiffer) DuplicatesDeletes() ([]crud.Event, error) { + targetServices, err := d.targetState.Services.GetAll() + if err != nil { + return nil, fmt.Errorf("error fetching services from state: %w", err) + } + var events []crud.Event + for _, service := range targetServices { + serviceEvents, err := d.deleteDuplicateService(service) + if err != nil { + return nil, err + } + events = append(events, serviceEvents...) + } + + return events, nil +} + +func (d *serviceDiffer) deleteDuplicateService(targetService *state.Service) ([]crud.Event, error) { + if targetService == nil || targetService.Name == nil { + // Nothing to do, cannot be a duplicate with no name. + return nil, nil + } + + currentService, err := d.currentState.Services.Get(*targetService.Name) + if errors.Is(err, state.ErrNotFound) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up service %q: %w", + *targetService.Name, err) + } + + if *currentService.ID != *targetService.ID { + var events []crud.Event + + // We have to delete all routes beforehand as otherwise we will get a foreign key error when deleting the service + // as routes are not deleted by the cascading delete of the service. + // See https://github.com/Kong/kong/discussions/7314 for more details. + routesToDelete, err := d.currentState.Routes.GetAllByServiceID(*currentService.ID) + if err != nil { + return nil, fmt.Errorf("error looking up routes for service %q: %w", + *currentService.Name, err) + } + + for _, route := range routesToDelete { + events = append(events, crud.Event{ + Op: crud.Delete, + Kind: "route", + Obj: route, + }) + } + + return append(events, crud.Event{ + Op: crud.Delete, + Kind: "service", + Obj: currentService, + }), nil + } + + return nil, nil +} diff --git a/pkg/types/service_package.go b/pkg/types/service_package.go new file mode 100644 index 0000000..c768d7f --- /dev/null +++ b/pkg/types/service_package.go @@ -0,0 +1,160 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" +) + +// servicePackageCRUD implements crud.Actions interface. +type servicePackageCRUD struct { + client *konnect.Client +} + +func servicePackageFromStruct(arg crud.Event) *state.ServicePackage { + sp, ok := arg.Obj.(*state.ServicePackage) + if !ok { + panic("unexpected type, expected *state.ServicePackage") + } + return sp +} + +// Create creates a Service package in Konnect. +// The arg should be of type crud.Event, containing the service to be created, +// else the function will panic. +// It returns a the created *state.ServicePackage. +func (s *servicePackageCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sp := servicePackageFromStruct(event) + createdSP, err := s.client.ServicePackages.Create(ctx, &sp.ServicePackage) + if err != nil { + return nil, err + } + return &state.ServicePackage{ServicePackage: *createdSP}, nil +} + +// Delete deletes a Service package in Konnect. +// The arg should be of type crud.Event, containing the service to be deleted, +// else the function will panic. +// It returns a the deleted *state.ServicePackage. +func (s *servicePackageCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sp := servicePackageFromStruct(event) + err := s.client.ServicePackages.Delete(ctx, sp.ID) + if err != nil { + return nil, err + } + return sp, nil +} + +// Update updates a Service package in Konnect. +// The arg should be of type crud.Event, containing the service to be updated, +// else the function will panic. +// It returns a the updated *state.ServicePackage. +func (s *servicePackageCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sp := servicePackageFromStruct(event) + + updatedSP, err := s.client.ServicePackages.Update(ctx, &sp.ServicePackage) + if err != nil { + return nil, err + } + return &state.ServicePackage{ServicePackage: *updatedSP}, nil +} + +type servicePackageDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *servicePackageDiffer) Deletes(handler func(crud.Event) error) error { + currentServicePackages, err := d.currentState.ServicePackages.GetAll() + if err != nil { + return fmt.Errorf("error fetching services-packages from state: %w", err) + } + + for _, sp := range currentServicePackages { + n, err := d.deleteServicePackage(sp) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *servicePackageDiffer) deleteServicePackage(sp *state.ServicePackage) (*crud.Event, error) { + _, err := d.targetState.ServicePackages.Get(*sp.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: "service-package", + Obj: sp, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up service-package %q: %w", + sp.Identifier(), err) + } + return nil, nil +} + +func (d *servicePackageDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetServicePackages, err := d.targetState.ServicePackages.GetAll() + if err != nil { + return fmt.Errorf("error fetching services-packages from state: %w", err) + } + + for _, sp := range targetServicePackages { + n, err := d.createUpdateServicePackage(sp) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *servicePackageDiffer) createUpdateServicePackage(sp *state.ServicePackage) (*crud.Event, error) { + spCopy := &state.ServicePackage{ServicePackage: *sp.DeepCopy()} + currentSP, err := d.currentState.ServicePackages.Get(*sp.ID) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "service-package", + Obj: spCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up service-package %q: %w", + sp.Identifier(), err) + } + + // found, check if update needed + if !currentSP.EqualWithOpts(spCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "service-package", + Obj: spCopy, + OldObj: currentSP, + }, nil + } + return nil, nil +} diff --git a/pkg/types/service_version.go b/pkg/types/service_version.go new file mode 100644 index 0000000..b659013 --- /dev/null +++ b/pkg/types/service_version.go @@ -0,0 +1,232 @@ +package types + +import ( + "context" + "errors" + "fmt" + "reflect" + + "github.com/kong/deck/crud" + "github.com/kong/deck/konnect" + "github.com/kong/deck/state" +) + +// serviceVersionCRUD implements crud.Actions interface. +type serviceVersionCRUD struct { + client *konnect.Client +} + +func serviceVersionFromStruct(arg crud.Event) *state.ServiceVersion { + sv, ok := arg.Obj.(*state.ServiceVersion) + if !ok { + panic("unexpected type, expected *state.ServiceVersion") + } + return sv +} + +func oldServiceVersionFromStruct(arg crud.Event) *state.ServiceVersion { + sv, ok := arg.OldObj.(*state.ServiceVersion) + if !ok { + panic("unexpected type, expected *state.ServiceVersion") + } + return sv +} + +// Create creates a Service version in Konnect. +// The arg should be of type crud.Event, containing the service to be created, +// else the function will panic. +// It returns a the created *state.ServiceVersion. +func (s *serviceVersionCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sv := serviceVersionFromStruct(event) + createdSV, err := s.client.ServiceVersions.Create(ctx, &sv.ServiceVersion) + if err != nil { + return nil, err + } + if sv.ControlPlaneServiceRelation != nil { + _, err := s.client.ControlPlaneRelations.Create(ctx, &konnect.ControlPlaneServiceRelationCreateRequest{ + ServiceVersionID: *createdSV.ID, + ControlPlaneEntityID: *sv.ControlPlaneServiceRelation.ControlPlaneEntityID, + }) + if err != nil { + return nil, err + } + } + return &state.ServiceVersion{ServiceVersion: *createdSV}, nil +} + +// Delete deletes a Service version in Konnect. +// The arg should be of type crud.Event, containing the service to be deleted, +// else the function will panic. +// It returns a the deleted *state.ServiceVersion. +func (s *serviceVersionCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sv := serviceVersionFromStruct(event) + err := s.client.ServiceVersions.Delete(ctx, sv.ID) + if err != nil { + return nil, err + } + return sv, nil +} + +// Update updates a Service version in Konnect. +// The arg should be of type crud.Event, containing the service to be updated, +// else the function will panic. +// It returns a the updated *state.ServiceVersion. +func (s *serviceVersionCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + var ( + err error + updatedSV *konnect.ServiceVersion + ) + event := crud.EventFromArg(arg[0]) + version := serviceVersionFromStruct(event) + oldVersion := oldServiceVersionFromStruct(event) + + // if there is a change in service version entity, make a PATCH + if !version.EqualWithOpts(oldVersion, false, true, true) { + versionCopy := &state.ServiceVersion{ServiceVersion: *version.DeepCopy()} + versionCopy.ControlPlaneServiceRelation = nil + versionCopy.ServicePackage = nil + updatedSV, err = s.client.ServiceVersions.Update(ctx, &versionCopy.ServiceVersion) + if err != nil { + return nil, err + } + } else { + updatedSV = &version.ServiceVersion + } + + // When a service versions update is detected, it could be due to changes in + // control-plane-entity and service-version relations + // This is possible only during update events + err = s.relationCRUD(ctx, &version.ServiceVersion, &oldVersion.ServiceVersion) + if err != nil { + return nil, err + } + + return &state.ServiceVersion{ServiceVersion: *updatedSV}, nil +} + +func (s *serviceVersionCRUD) relationCRUD(ctx context.Context, version, + oldVersion *konnect.ServiceVersion, +) error { + var err error + + if version.ControlPlaneServiceRelation != nil && + oldVersion.ControlPlaneServiceRelation == nil { + // no version existed before, create a new relation + _, err = s.client.ControlPlaneRelations.Create(ctx, &konnect.ControlPlaneServiceRelationCreateRequest{ + ServiceVersionID: *version.ID, + ControlPlaneEntityID: *version.ControlPlaneServiceRelation.ControlPlaneEntityID, + }) + } else if version.ControlPlaneServiceRelation == nil && oldVersion. + ControlPlaneServiceRelation != nil { + // version doesn't need to exist anymore, delete it + err = s.client.ControlPlaneRelations.Delete(ctx, + oldVersion.ControlPlaneServiceRelation.ID) + } else if !reflect.DeepEqual(version.ControlPlaneServiceRelation, + oldVersion.ControlPlaneServiceRelation) { + // relations are different, update it + _, err = s.client.ControlPlaneRelations.Update(ctx, + &konnect.ControlPlaneServiceRelationUpdateRequest{ + ID: *oldVersion.ControlPlaneServiceRelation.ID, + ControlPlaneServiceRelationCreateRequest: konnect.ControlPlaneServiceRelationCreateRequest{ + ServiceVersionID: *version.ID, + ControlPlaneEntityID: *version.ControlPlaneServiceRelation.ControlPlaneEntityID, + }, + }) + } + return err +} + +type serviceVersionDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *serviceVersionDiffer) Deletes(handler func(crud.Event) error) error { + currentServiceVersions, err := d.currentState.ServiceVersions.GetAll() + if err != nil { + return fmt.Errorf("error fetching service-versions from state: %w", err) + } + + for _, sv := range currentServiceVersions { + n, err := d.deleteServiceVersion(sv) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *serviceVersionDiffer) deleteServiceVersion(sv *state.ServiceVersion) (*crud.Event, error) { + _, err := d.targetState.ServiceVersions.Get(*sv.ServicePackage.ID, *sv.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: sv, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up service-version %q': %w", + sv.Identifier(), err) + } + return nil, nil +} + +func (d *serviceVersionDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetServiceVersions, err := d.targetState.ServiceVersions.GetAll() + if err != nil { + return fmt.Errorf("error fetching services from state: %w", err) + } + + for _, sv := range targetServiceVersions { + n, err := d.createUpdateServiceVersion(sv) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *serviceVersionDiffer) createUpdateServiceVersion(sv *state.ServiceVersion) (*crud.Event, error) { + svCopy := &state.ServiceVersion{ServiceVersion: *sv.DeepCopy()} + currentSV, err := d.currentState.ServiceVersions.Get(*sv.ServicePackage.ID, *sv.ID) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: svCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up service-version %q: %w", + sv.Identifier(), err) + } + + // found, check if update needed + if !currentSV.EqualWithOpts(svCopy, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: svCopy, + OldObj: currentSV, + }, nil + } + return nil, nil +} diff --git a/pkg/types/sni.go b/pkg/types/sni.go new file mode 100644 index 0000000..f5eb59a --- /dev/null +++ b/pkg/types/sni.go @@ -0,0 +1,159 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// sniCRUD implements crud.Actions interface. +type sniCRUD struct { + client *kong.Client +} + +func sniFromStruct(arg crud.Event) *state.SNI { + sni, ok := arg.Obj.(*state.SNI) + if !ok { + panic("unexpected type, expected *state.SNI") + } + + return sni +} + +// Create creates a SNI in Kong. +// The arg should be of type crud.Event, containing the sni to be created, +// else the function will panic. +// It returns a the created *state.SNI. +func (s *sniCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sni := sniFromStruct(event) + createdSNI, err := s.client.SNIs.Create(ctx, &sni.SNI) + if err != nil { + return nil, err + } + return &state.SNI{SNI: *createdSNI}, nil +} + +// Delete deletes a SNI in Kong. +// The arg should be of type crud.Event, containing the sni to be deleted, +// else the function will panic. +// It returns a the deleted *state.SNI. +func (s *sniCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sni := sniFromStruct(event) + err := s.client.SNIs.Delete(ctx, sni.ID) + if err != nil { + return nil, err + } + return sni, nil +} + +// Update updates a SNI in Kong. +// The arg should be of type crud.Event, containing the sni to be updated, +// else the function will panic. +// It returns a the updated *state.SNI. +func (s *sniCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + sni := sniFromStruct(event) + + updatedSNI, err := s.client.SNIs.Create(ctx, &sni.SNI) + if err != nil { + return nil, err + } + return &state.SNI{SNI: *updatedSNI}, nil +} + +type sniDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *sniDiffer) Deletes(handler func(crud.Event) error) error { + currentSNIs, err := d.currentState.SNIs.GetAll() + if err != nil { + return fmt.Errorf("error fetching snis from state: %w", err) + } + + for _, sni := range currentSNIs { + n, err := d.deleteSNI(sni) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *sniDiffer) deleteSNI(sni *state.SNI) (*crud.Event, error) { + _, err := d.targetState.SNIs.Get(*sni.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: sni, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up sni %q: %w", *sni.Name, err) + } + return nil, nil +} + +func (d *sniDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + sniSNIs, err := d.targetState.SNIs.GetAll() + if err != nil { + return fmt.Errorf("error fetching snis from state: %w", err) + } + + for _, sni := range sniSNIs { + n, err := d.createUpdateSNI(sni) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *sniDiffer) createUpdateSNI(sni *state.SNI) (*crud.Event, error) { + sni = &state.SNI{SNI: *sni.DeepCopy()} + currentSNI, err := d.currentState.SNIs.Get(*sni.ID) + if errors.Is(err, state.ErrNotFound) { + // sni not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: sni, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up sni %q: %w", *sni.Name, err) + } + // found, check if update needed + + if !currentSNI.EqualWithOpts(sni, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: sni, + OldObj: currentSNI, + }, nil + } + return nil, nil +} diff --git a/pkg/types/target.go b/pkg/types/target.go new file mode 100644 index 0000000..28b7416 --- /dev/null +++ b/pkg/types/target.go @@ -0,0 +1,169 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// targetCRUD implements crud.Actions interface. +type targetCRUD struct { + client *kong.Client +} + +func targetFromStruct(arg crud.Event) *state.Target { + target, ok := arg.Obj.(*state.Target) + if !ok { + panic("unexpected type, expected *state.Target") + } + + return target +} + +// Create creates a Target in Kong. +// The arg should be of type crud.Event, containing the target to be created, +// else the function will panic. +// It returns a the created *state.Target. +func (s *targetCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + target := targetFromStruct(event) + createdTarget, err := s.client.Targets.Create(ctx, + target.Upstream.ID, &target.Target) + if err != nil { + return nil, err + } + return &state.Target{Target: *createdTarget}, nil +} + +// Delete deletes a Target in Kong. +// The arg should be of type crud.Event, containing the target to be deleted, +// else the function will panic. +// It returns a the deleted *state.Target. +func (s *targetCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + target := targetFromStruct(event) + err := s.client.Targets.Delete(ctx, target.Upstream.ID, target.ID) + if err != nil { + return nil, err + } + return target, nil +} + +// Update updates a Target in Kong. +// The arg should be of type crud.Event, containing the target to be updated, +// else the function will panic. +// It returns a the updated *state.Target. +func (s *targetCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + target := targetFromStruct(event) + // Targets in Kong cannot be updated + err := s.client.Targets.Delete(ctx, target.Upstream.ID, target.ID) + if err != nil { + return nil, err + } + createdTarget, err := s.client.Targets.Create(ctx, + target.Upstream.ID, &target.Target) + if err != nil { + return nil, err + } + return &state.Target{Target: *createdTarget}, nil +} + +type targetDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *targetDiffer) Deletes(handler func(crud.Event) error) error { + currentTargets, err := d.currentState.Targets.GetAll() + if err != nil { + return fmt.Errorf("error fetching targets from state: %w", err) + } + + for _, target := range currentTargets { + n, err := d.deleteTarget(target) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *targetDiffer) deleteTarget(target *state.Target) (*crud.Event, error) { + _, err := d.targetState.Targets.Get(*target.Upstream.ID, + *target.Target.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: target, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up target %q: %w", + *target.Target.Target, err) + } + return nil, nil +} + +func (d *targetDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetTargets, err := d.targetState.Targets.GetAll() + if err != nil { + return fmt.Errorf("error fetching targets from state: %w", err) + } + + for _, target := range targetTargets { + n, err := d.createUpdateTarget(target) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *targetDiffer) createUpdateTarget(target *state.Target) (*crud.Event, error) { + target = &state.Target{Target: *target.DeepCopy()} + currentTarget, err := d.currentState.Targets.Get(*target.Upstream.ID, + *target.Target.ID) + if errors.Is(err, state.ErrNotFound) { + // target not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: target, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up target %q: %w", + *target.Target.Target, err) + } + // found, check if update needed + + if !currentTarget.EqualWithOpts(target, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: target, + OldObj: currentTarget, + }, nil + } + return nil, nil +} diff --git a/pkg/types/upstream.go b/pkg/types/upstream.go new file mode 100644 index 0000000..2e888dd --- /dev/null +++ b/pkg/types/upstream.go @@ -0,0 +1,162 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// upstreamCRUD implements crud.Actions interface. +type upstreamCRUD struct { + client *kong.Client +} + +func upstreamFromStruct(arg crud.Event) *state.Upstream { + upstream, ok := arg.Obj.(*state.Upstream) + if !ok { + panic("unexpected type, expected *state.upstream") + } + return upstream +} + +// Create creates a Upstream in Kong. +// The arg should be of type crud.Event, containing the upstream to be created, +// else the function will panic. +// It returns a the created *state.Upstream. +func (s *upstreamCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + upstream := upstreamFromStruct(event) + createdUpstream, err := s.client.Upstreams.Create(ctx, &upstream.Upstream) + if err != nil { + return nil, err + } + return &state.Upstream{Upstream: *createdUpstream}, nil +} + +// Delete deletes a Upstream in Kong. +// The arg should be of type crud.Event, containing the upstream to be deleted, +// else the function will panic. +// It returns a the deleted *state.Upstream. +func (s *upstreamCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + upstream := upstreamFromStruct(event) + err := s.client.Upstreams.Delete(ctx, upstream.ID) + if err != nil { + return nil, err + } + return upstream, nil +} + +// Update updates a Upstream in Kong. +// The arg should be of type crud.Event, containing the upstream to be updated, +// else the function will panic. +// It returns a the updated *state.Upstream. +func (s *upstreamCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + upstream := upstreamFromStruct(event) + + updatedUpstream, err := s.client.Upstreams.Create(ctx, &upstream.Upstream) + if err != nil { + return nil, err + } + return &state.Upstream{Upstream: *updatedUpstream}, nil +} + +type upstreamDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *upstreamDiffer) Deletes(handler func(crud.Event) error) error { + currentUpstreams, err := d.currentState.Upstreams.GetAll() + if err != nil { + return fmt.Errorf("error fetching upstreams from state: %w", err) + } + + for _, upstream := range currentUpstreams { + n, err := d.deleteUpstream(upstream) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *upstreamDiffer) deleteUpstream(upstream *state.Upstream) (*crud.Event, error) { + _, err := d.targetState.Upstreams.Get(*upstream.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: "upstream", + Obj: upstream, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up upstream %q: %w", + *upstream.Name, err) + } + return nil, nil +} + +func (d *upstreamDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetUpstreams, err := d.targetState.Upstreams.GetAll() + if err != nil { + return fmt.Errorf("error fetching upstreams from state: %w", err) + } + + for _, upstream := range targetUpstreams { + n, err := d.createUpdateUpstream(upstream) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *upstreamDiffer) createUpdateUpstream(upstream *state.Upstream) (*crud.Event, + error, +) { + upstreamCopy := &state.Upstream{Upstream: *upstream.DeepCopy()} + currentUpstream, err := d.currentState.Upstreams.Get(*upstream.Name) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "upstream", + Obj: upstreamCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up upstream %v: %w", + *upstream.Name, err) + } + + // found, check if update needed + if !currentUpstream.EqualWithOpts(upstreamCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "upstream", + Obj: upstreamCopy, + OldObj: currentUpstream, + }, nil + } + return nil, nil +} diff --git a/pkg/types/vault.go b/pkg/types/vault.go new file mode 100644 index 0000000..e742934 --- /dev/null +++ b/pkg/types/vault.go @@ -0,0 +1,166 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/deck/crud" + "github.com/kong/deck/state" + "github.com/kong/go-kong/kong" +) + +// vaultCRUD implements crud.Actions interface. +type vaultCRUD struct { + client *kong.Client +} + +func vaultFromStruct(arg crud.Event) *state.Vault { + vault, ok := arg.Obj.(*state.Vault) + if !ok { + panic("unexpected type, expected *state.Vault") + } + return vault +} + +// Create creates a Vault in Kong. +// The arg should be of type crud.Event, containing the vault to be created, +// else the function will panic. +// It returns a the created *state.Vault. +func (s *vaultCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + vault := vaultFromStruct(event) + createdVault, err := s.client.Vaults.Create(ctx, &vault.Vault) + if err != nil { + return nil, err + } + return &state.Vault{Vault: *createdVault}, nil +} + +// Delete deletes a Vault in Kong. +// The arg should be of type crud.Event, containing the vault to be deleted, +// else the function will panic. +// It returns a the deleted *state.Vault. +func (s *vaultCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + vault := vaultFromStruct(event) + err := s.client.Vaults.Delete(ctx, vault.ID) + if err != nil { + return nil, err + } + return vault, nil +} + +// Update updates a Vault in Kong. +// The arg should be of type crud.Event, containing the vault to be updated, +// else the function will panic. +// It returns a the updated *state.Vault. +func (s *vaultCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + vault := vaultFromStruct(event) + + updatedVault, err := s.client.Vaults.Create(ctx, &vault.Vault) + if err != nil { + return nil, err + } + return &state.Vault{Vault: *updatedVault}, nil +} + +type vaultDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +// Deletes generates a memdb CRUD DELETE event for Vaults +// which is then consumed by the differ and used to gate Kong client calls. +func (d *vaultDiffer) Deletes(handler func(crud.Event) error) error { + currentVaults, err := d.currentState.Vaults.GetAll() + if err != nil { + return fmt.Errorf("error fetching vaults from state: %w", err) + } + + for _, vault := range currentVaults { + n, err := d.deleteVault(vault) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + + } + return nil +} + +func (d *vaultDiffer) deleteVault(vault *state.Vault) (*crud.Event, error) { + _, err := d.targetState.Vaults.Get(*vault.ID) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: "vault", + Obj: vault, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up vault %q: %w", + *vault.Prefix, err) + } + return nil, nil +} + +// CreateAndUpdates generates a memdb CRUD CREATE/UPDATE event for Vaults +// which is then consumed by the differ and used to gate Kong client calls. +func (d *vaultDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetVaults, err := d.targetState.Vaults.GetAll() + if err != nil { + return fmt.Errorf("error fetching vaults from state: %w", err) + } + + for _, vault := range targetVaults { + n, err := d.createUpdateVault(vault) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *vaultDiffer) createUpdateVault(vault *state.Vault) (*crud.Event, + error, +) { + vaultCopy := &state.Vault{Vault: *vault.DeepCopy()} + currentVault, err := d.currentState.Vaults.Get(*vault.Prefix) + + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + Kind: "vault", + Obj: vaultCopy, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up vault %v: %w", + *vault.Prefix, err) + } + + // found, check if update needed + if !currentVault.EqualWithOpts(vaultCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "vault", + Obj: vaultCopy, + OldObj: currentVault, + }, nil + } + return nil, nil +} diff --git a/pkg/utils/analytics.go b/pkg/utils/analytics.go new file mode 100644 index 0000000..1b321e7 --- /dev/null +++ b/pkg/utils/analytics.go @@ -0,0 +1,76 @@ +package utils + +import ( + "bytes" + "fmt" + "net" + "os" + "runtime" + "strings" + + "github.com/shirou/gopsutil/v3/host" +) + +const ( + reportsHost = "kong-hf.konghq.com" + reportsPort = 61829 + konnectMode = "konnect" +) + +func SendAnalytics(cmd, deckVersion, kongVersion, mode string) error { + if strings.ToLower(os.Getenv("DECK_ANALYTICS")) == "off" { + return nil + } + if cmd == "" { + return fmt.Errorf("invalid argument, 'cmd' cannot be empty") + } + + stats := collectStats(cmd, deckVersion, kongVersion, mode) + body := formatStats(stats) + + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", reportsHost, reportsPort)) + if err != nil { + return err + } + conn, err := net.DialUDP("udp", nil, addr) + if err != nil { + return err + } + defer conn.Close() + + _, err = conn.Write([]byte(body)) + if err != nil { + return err + } + return nil +} + +func formatStats(stats map[string]string) string { + var buffer bytes.Buffer + buffer.WriteString("<14>") + for k, v := range stats { + buffer.WriteString(fmt.Sprintf("%s=%s;", k, v)) + } + return buffer.String() +} + +func collectStats(cmd, deckVersion, kongVersion, mode string) map[string]string { + result := map[string]string{ + "signal": "decK", + "v": deckVersion, + "cmd": cmd, + "os": runtime.GOOS, + "arch": runtime.GOARCH, + } + if mode == konnectMode { + result["mode"] = mode + } + if kongVersion != "" && mode != konnectMode { + result["kv"] = kongVersion + } + info, err := host.Info() + if err == nil { + result["osv"] = info.Platform + " " + info.PlatformVersion + } + return result +} diff --git a/pkg/utils/constants.go b/pkg/utils/constants.go new file mode 100644 index 0000000..0176b6a --- /dev/null +++ b/pkg/utils/constants.go @@ -0,0 +1,88 @@ +package utils + +import ( + "github.com/kong/go-kong/kong" +) + +const ( + defaultTimeout = 60000 + defaultSlots = 10000 + defaultWeight = 100 + defaultConcurrency = 10 +) + +var ( + serviceDefaults = kong.Service{ + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(defaultTimeout), + WriteTimeout: kong.Int(defaultTimeout), + ReadTimeout: kong.Int(defaultTimeout), + } + routeDefaults = kong.Route{ + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(true), + Protocols: kong.StringSlice("http", "https"), + } + targetDefaults = kong.Target{ + Weight: kong.Int(defaultWeight), + } + upstreamDefaults = kong.Upstream{ + Slots: kong.Int(defaultSlots), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(defaultConcurrency), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + Interval: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + } + consumerGroupPluginDefault = kong.ConsumerGroupPlugin{ + Config: kong.Configuration{ + "window_type": "sliding", + }, + } + defaultsRestrictedFields = map[string][]string{ + "Service": {"ID", "Name"}, + "Route": {"ID", "Name"}, + "Target": {"ID", "Target"}, + "Upstream": {"ID", "Name"}, + } +) + +const ( + // ImplementationTypeKongGateway indicates an implementation backed by Kong Gateway. + ImplementationTypeKongGateway = "kong-gateway" +) diff --git a/pkg/utils/counter.go b/pkg/utils/counter.go new file mode 100644 index 0000000..14b2cc3 --- /dev/null +++ b/pkg/utils/counter.go @@ -0,0 +1,20 @@ +package utils + +import "sync" + +type AtomicInt32Counter struct { + counter int32 + lock sync.RWMutex +} + +func (a *AtomicInt32Counter) Increment(delta int32) { + a.lock.Lock() + defer a.lock.Unlock() + a.counter += delta +} + +func (a *AtomicInt32Counter) Count() int32 { + a.lock.RLock() + defer a.lock.RUnlock() + return a.counter +} diff --git a/pkg/utils/counter_test.go b/pkg/utils/counter_test.go new file mode 100644 index 0000000..0d45b5a --- /dev/null +++ b/pkg/utils/counter_test.go @@ -0,0 +1,23 @@ +package utils + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicInt32Counter(t *testing.T) { + var a AtomicInt32Counter + var wg sync.WaitGroup + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + a.Increment(int32(1)) + }() + } + wg.Wait() + assert.Equal(t, int32(10), a.Count()) +} diff --git a/pkg/utils/defaulter.go b/pkg/utils/defaulter.go new file mode 100644 index 0000000..d506014 --- /dev/null +++ b/pkg/utils/defaulter.go @@ -0,0 +1,410 @@ +package utils + +import ( + "context" + "fmt" + "net/http" + "reflect" + "strings" + + "github.com/imdario/mergo" + "github.com/kong/go-kong/kong" +) + +var kongToKonnectEntitiesMap = map[string]string{ + "services": "service", + "routes": "route", + "upstreams": "upstream", + "targets": "target", +} + +// Defaulter registers types and fills in struct fields with +// default values. +type Defaulter struct { + r map[string]interface{} + + ctx context.Context + client *kong.Client + isKonnect bool + + service *kong.Service + route *kong.Route + upstream *kong.Upstream + target *kong.Target + consumerGroupPlugin *kong.ConsumerGroupPlugin +} + +type DefaulterOpts struct { + KongDefaults interface{} + DisableDynamicDefaults bool + IsKonnect bool + Client *kong.Client +} + +// NewDefaulter initializes a Defaulter with empty entities. +func NewDefaulter() *Defaulter { + return &Defaulter{ + service: &kong.Service{}, + route: &kong.Route{}, + upstream: &kong.Upstream{}, + target: &kong.Target{}, + consumerGroupPlugin: &kong.ConsumerGroupPlugin{}, + } +} + +func getKongDefaulter(opts DefaulterOpts) (*Defaulter, error) { + d := NewDefaulter() + if err := d.populateDefaultsFromInput(opts.KongDefaults); err != nil { + return nil, err + } + + if opts.DisableDynamicDefaults { + if err := d.populateStaticDefaultsForKonnect(); err != nil { + return nil, err + } + } + + err := d.Register(d.service) + if err != nil { + return nil, fmt.Errorf("registering service with defaulter: %w", err) + } + err = d.Register(d.route) + if err != nil { + return nil, fmt.Errorf("registering route with defaulter: %w", err) + } + err = d.Register(d.upstream) + if err != nil { + return nil, fmt.Errorf("registering upstream with defaulter: %w", err) + } + err = d.Register(d.target) + if err != nil { + return nil, fmt.Errorf("registering target with defaulter: %w", err) + } + err = d.Register(d.consumerGroupPlugin) + if err != nil { + return nil, fmt.Errorf("registering consumer-group-plugin with defaulter: %w", err) + } + return d, nil +} + +// Check if `entity` has restricted fields set +func checkEntityDefaults(entity interface{}, restrictedFields []string) error { + var invalidFields []string + r := reflect.ValueOf(entity) + for _, fieldName := range restrictedFields { + field := reflect.Indirect(r).FieldByName(fieldName) + if field.IsValid() && !field.IsNil() { + invalidFields = append(invalidFields, strings.ToLower(fieldName)) + } + } + if len(invalidFields) > 0 { + return fmt.Errorf("cannot have these restricted fields set: %s", + strings.Join(invalidFields, ", ")) + } + return nil +} + +func (d *Defaulter) once() { + if d.r == nil { + d.r = make(map[string]interface{}) + } +} + +// Register registers a type and it's default value. +// The default value is passed in and the type is inferred from the +// default value. +func (d *Defaulter) Register(def interface{}) error { + d.once() + v := reflect.ValueOf(def) + if !v.IsValid() { + return fmt.Errorf("invalid value") + } + v = reflect.Indirect(v) + d.r[v.Type().String()] = def + return nil +} + +type kongTransformer struct{} + +func (t kongTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { + var a *int + var ar []int + var b *bool + switch typ { + case reflect.TypeOf(ar): + + return func(dst, src reflect.Value) error { + if dst.CanSet() { + if reflect.DeepEqual(reflect.Zero(dst.Type()).Interface(), dst.Interface()) { + return nil + } + } + return nil + } + case reflect.TypeOf(a): + + return func(dst, src reflect.Value) error { + if dst.CanSet() { + if reflect.DeepEqual(reflect.Zero(dst.Type()).Interface(), dst.Interface()) { + return nil + } + } + return nil + } + case reflect.TypeOf(b): + + return func(dst, src reflect.Value) error { + if dst.CanSet() { + if reflect.DeepEqual(reflect.Zero(dst.Type()).Interface(), dst.Interface()) { + return nil + } + } + return nil + } + default: + + return nil + } +} + +// Set fills in default values in a struct of a registered type. +func (d *Defaulter) Set(arg interface{}) error { + d.once() + v := reflect.ValueOf(arg) + if !v.IsValid() { + return fmt.Errorf("invalid value") + } + v = reflect.Indirect(v) + defValue, ok := d.r[v.Type().String()] + if !ok { + return fmt.Errorf("type not registered: %v", reflect.TypeOf(arg)) + } + err := mergo.Merge(arg, defValue, mergo.WithTransformers(kongTransformer{})) + if err != nil { + err = fmt.Errorf("merging: %w", err) + } + return err + // return defaulter.Set(arg, defValue) +} + +// MustSet is like Set but panics if there is an error. +func (d *Defaulter) MustSet(arg interface{}) { + err := d.Set(arg) + if err != nil { + panic(err) + } +} + +func (d *Defaulter) getEntitySchema(entityType string) (map[string]interface{}, error) { + var ( + schema map[string]interface{} + ok bool + ) + endpoint := fmt.Sprintf("/schemas/%s", entityType) + if d.isKonnect { + entityType, ok = kongToKonnectEntitiesMap[entityType] + // if no mapping is found, then the schema cannot be fetched + // from Konnet and we should proceed without defaults. + if !ok { + return schema, nil + } + endpoint = fmt.Sprintf("/v1/schemas/json/%s", entityType) + } + req, err := d.client.NewRequest(http.MethodGet, endpoint, nil, nil) + if err != nil { + return schema, err + } + resp, err := d.client.Do(d.ctx, req, &schema) + if resp == nil { + return schema, fmt.Errorf("invalid HTTP response: %w", err) + } + // in case the schema is not found - like in case of EE features, + // no error should be returned. + if resp.StatusCode == http.StatusNotFound { + return schema, nil + } + return schema, err +} + +func (d *Defaulter) addEntityDefaults(entityType string, entity interface{}) error { + schema, err := d.getEntitySchema(entityType) + if schema == nil && err == nil { + return nil + } + if err != nil { + return fmt.Errorf("retrieve schema for %v from Kong: %w", entityType, err) + } + return kong.FillEntityDefaults(entity, schema) +} + +func getKongDefaulterWithClient(ctx context.Context, opts DefaulterOpts) (*Defaulter, error) { + // fills defaults from input + d, err := getKongDefaulter(opts) + if err != nil { + return nil, err + } + d.ctx = ctx + d.client = opts.Client + d.isKonnect = opts.IsKonnect + + // fills defaults from Kong API + if err := d.addEntityDefaults("services", d.service); err != nil { + return nil, fmt.Errorf("get defaults for services: %w", err) + } + if err := d.Register(d.service); err != nil { + return nil, fmt.Errorf("registering service with defaulter: %w", err) + } + + if err := d.addEntityDefaults("routes", d.route); err != nil { + return nil, fmt.Errorf("get defaults for routes: %w", err) + } + if err := d.Register(d.route); err != nil { + return nil, fmt.Errorf("registering route with defaulter: %w", err) + } + + if err := d.addEntityDefaults("upstreams", d.upstream); err != nil { + return nil, fmt.Errorf("get defaults for upstreams: %w", err) + } + if err := d.Register(d.upstream); err != nil { + return nil, fmt.Errorf("registering upstream with defaulter: %w", err) + } + + if err := d.addEntityDefaults("targets", d.target); err != nil { + return nil, fmt.Errorf("get defaults for targets: %w", err) + } + if err := d.Register(d.target); err != nil { + return nil, fmt.Errorf("registering target with defaulter: %w", err) + } + + // since Konnect implements a different consumer-group API than the one from the + // Kong Gateway, it's not straight-forward to handle defaults injection the same + // way due to schema differences. In order to overcome this limitation, we are + // statically loading defaults for the consumer-group plugin override when running + // against Konnect, while still relying on the Admin API for Kong Gateway. + if d.isKonnect { + if err := mergo.Merge( + d.consumerGroupPlugin, &consumerGroupPluginDefault, mergo.WithTransformers(kongTransformer{}), + ); err != nil { + return nil, fmt.Errorf("merging consumer-group-plugin static defaults: %w", err) + } + } else { + if err := d.addEntityDefaults("consumer_group_plugins", d.consumerGroupPlugin); err != nil { + return nil, fmt.Errorf("get defaults for consumer-group-plugin: %w", err) + } + if err := d.Register(d.consumerGroupPlugin); err != nil { + return nil, fmt.Errorf("registering consumer-group-plugin with defaulter: %w", err) + } + } + return d, nil +} + +// GetDefaulter returns a Defaulter object to be used to set defaults +// on Kong entities. The order of precedence is as follow, from higher to lower: +// +// 1. values set in the state file +// 2. values set in the {_info: defaults:} object in the state file +// 3. hardcoded defaults under utils/constants.go (Konnect-only) +func GetDefaulter(ctx context.Context, opts DefaulterOpts) (*Defaulter, error) { + exists, err := WorkspaceExists(ctx, opts.Client) + if err != nil { + return nil, fmt.Errorf("ensure workspace exists: %w", err) + } + if opts.Client != nil && !opts.DisableDynamicDefaults && exists { + return getKongDefaulterWithClient(ctx, opts) + } + opts.DisableDynamicDefaults = true + return getKongDefaulter(opts) +} + +func (d *Defaulter) populateDefaultsFromInput(defaults interface{}) error { + err := validateKongDefaults(defaults) + if err != nil { + return fmt.Errorf("validating defaults: %w", err) + } + + r := reflect.ValueOf(defaults) + + service := reflect.Indirect(r).FieldByName("Service") + serviceObj := service.Interface().(*kong.Service) + if serviceObj != nil { + err := mergo.Merge(d.service, serviceObj, mergo.WithTransformers(kongTransformer{})) + if err != nil { + return fmt.Errorf("merging: %w", err) + } + } + + route := reflect.Indirect(r).FieldByName("Route") + routeObj := route.Interface().(*kong.Route) + if routeObj != nil { + err := mergo.Merge(d.route, routeObj, mergo.WithTransformers(kongTransformer{})) + if err != nil { + return fmt.Errorf("merging: %w", err) + } + } + + upstream := reflect.Indirect(r).FieldByName("Upstream") + upstreamObj := upstream.Interface().(*kong.Upstream) + if upstreamObj != nil { + err := mergo.Merge(d.upstream, upstreamObj, mergo.WithTransformers(kongTransformer{})) + if err != nil { + return fmt.Errorf("merging: %w", err) + } + } + + target := reflect.Indirect(r).FieldByName("Target") + targetObj := target.Interface().(*kong.Target) + if targetObj != nil { + err := mergo.Merge(d.target, targetObj, mergo.WithTransformers(kongTransformer{})) + if err != nil { + return fmt.Errorf("merging: %w", err) + } + } + return nil +} + +func validateKongDefaults(defaults interface{}) error { + var errs ErrArray + r := reflect.ValueOf(defaults) + for objectName, restrictedFields := range defaultsRestrictedFields { + objectValue := reflect.Indirect(r).FieldByName(objectName) + if objectValue.IsNil() || !objectValue.IsValid() { + continue + } + object := objectValue.Interface() + err := checkEntityDefaults(object, restrictedFields) + if err != nil { + entityErr := fmt.Errorf( + "%s defaults %w", strings.ToLower(objectName), err) + errs.Errors = append(errs.Errors, entityErr) + } + } + if errs.Errors != nil { + return errs + } + return nil +} + +func (d *Defaulter) populateStaticDefaultsForKonnect() error { + if err := mergo.Merge( + d.service, &serviceDefaults, mergo.WithTransformers(kongTransformer{}), + ); err != nil { + return fmt.Errorf("merging service static defaults: %w", err) + } + if err := mergo.Merge( + d.route, &routeDefaults, mergo.WithTransformers(kongTransformer{}), + ); err != nil { + return fmt.Errorf("merging route static defaults: %w", err) + } + if err := mergo.Merge( + d.upstream, &upstreamDefaults, mergo.WithTransformers(kongTransformer{}), + ); err != nil { + return fmt.Errorf("merging upstream static defaults: %w", err) + } + if err := mergo.Merge( + d.target, &targetDefaults, mergo.WithTransformers(kongTransformer{}), + ); err != nil { + return fmt.Errorf("merging target static defaults: %w", err) + } + + return nil +} diff --git a/pkg/utils/defaulter_test.go b/pkg/utils/defaulter_test.go new file mode 100644 index 0000000..9e25f9e --- /dev/null +++ b/pkg/utils/defaulter_test.go @@ -0,0 +1,737 @@ +package utils + +import ( + "context" + "reflect" + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +type kongDefaultForTesting struct { + Service *kong.Service + Route *kong.Route + Upstream *kong.Upstream + Target *kong.Target +} + +var kongDefaults = kongDefaultForTesting{ + Service: &serviceDefaults, + Route: &routeDefaults, + Upstream: &upstreamDefaults, + Target: &targetDefaults, +} + +var defaulterTestOpts = DefaulterOpts{ + KongDefaults: kongDefaults, + DisableDynamicDefaults: false, +} + +func TestDefaulter(t *testing.T) { + assert := assert.New(t) + + var d Defaulter + + assert.NotNil(d.Register(nil)) + assert.NotNil(d.Set(nil)) + + assert.Panics(func() { + d.MustSet(d) + }) + + type Foo struct { + A string + B []string + } + defaultFoo := &Foo{ + A: "defaultA", + B: []string{"default1"}, + } + assert.Nil(d.Register(defaultFoo)) + + // sets a default + var arg Foo + assert.Nil(d.Set(&arg)) + assert.Equal("defaultA", arg.A) + assert.Equal([]string{"default1"}, arg.B) + + // doesn't set a default + arg1 := Foo{ + A: "non-default-value", + } + assert.Nil(d.Set(&arg1)) + assert.Equal("non-default-value", arg1.A) + + // errors on an unregistered type + type Bar struct { + A string + } + assert.NotNil(d.Set(&Bar{})) + assert.Panics(func() { + d.MustSet(&Bar{}) + }) +} + +func TestServiceSetTest(t *testing.T) { + assert := assert.New(t) + ctx := context.Background() + d, err := GetDefaulter(ctx, defaulterTestOpts) + assert.NotNil(d) + assert.Nil(err) + + testCases := []struct { + desc string + arg *kong.Service + want *kong.Service + }{ + { + desc: "empty service", + arg: &kong.Service{}, + want: &serviceDefaults, + }, + { + desc: "retries can be set to 0", + arg: &kong.Service{ + Retries: kong.Int(0), + }, + want: &kong.Service{ + Retries: kong.Int(0), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + }, + }, + { + desc: "timeout value value is not overridden", + arg: &kong.Service{ + WriteTimeout: kong.Int(42), + }, + want: &kong.Service{ + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(42), + ReadTimeout: kong.Int(60000), + }, + }, + { + desc: "path value is not overridden", + arg: &kong.Service{ + Path: kong.String("/foo"), + }, + want: &kong.Service{ + Protocol: kong.String("http"), + Path: kong.String("/foo"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + }, + }, + { + desc: "Name is not reset", + arg: &kong.Service{ + Name: kong.String("foo"), + Host: kong.String("example.com"), + Path: kong.String("/bar"), + }, + want: &kong.Service{ + Name: kong.String("foo"), + Host: kong.String("example.com"), + Protocol: kong.String("http"), + Path: kong.String("/bar"), + ConnectTimeout: kong.Int(60000), + WriteTimeout: kong.Int(60000), + ReadTimeout: kong.Int(60000), + }, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + err := d.Set(tC.arg) + assert.Nil(err) + assert.Equal(tC.want, tC.arg) + }) + } +} + +func TestRouteSetTest(t *testing.T) { + assert := assert.New(t) + ctx := context.Background() + d, err := GetDefaulter(ctx, defaulterTestOpts) + assert.NotNil(d) + assert.Nil(err) + + testCases := []struct { + desc string + arg *kong.Route + want *kong.Route + }{ + { + desc: "empty route", + arg: &kong.Route{}, + want: &routeDefaults, + }, + { + desc: "preserve host is not overridden", + arg: &kong.Route{ + PreserveHost: kong.Bool(true), + }, + want: &kong.Route{ + PreserveHost: kong.Bool(true), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(true), + Protocols: kong.StringSlice("http", "https"), + }, + }, + { + desc: "Protocols is not reset", + arg: &kong.Route{ + Protocols: kong.StringSlice("http", "tls"), + }, + want: &kong.Route{ + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(true), + Protocols: kong.StringSlice("http", "tls"), + }, + }, + { + desc: "non-default feilds is not reset", + arg: &kong.Route{ + Name: kong.String("foo"), + Hosts: kong.StringSlice("1.example.com", "2.example.com"), + Methods: kong.StringSlice("GET", "POST"), + StripPath: kong.Bool(true), + }, + want: &kong.Route{ + Name: kong.String("foo"), + Hosts: kong.StringSlice("1.example.com", "2.example.com"), + Methods: kong.StringSlice("GET", "POST"), + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(true), + Protocols: kong.StringSlice("http", "https"), + }, + }, + { + desc: "strip-path can be set to false", + arg: &kong.Route{ + StripPath: kong.Bool(false), + }, + want: &kong.Route{ + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(false), + Protocols: kong.StringSlice("http", "https"), + }, + }, + { + desc: "strip-path can be set to true", + arg: &kong.Route{ + StripPath: kong.Bool(true), + }, + want: &kong.Route{ + PreserveHost: kong.Bool(false), + RegexPriority: kong.Int(0), + StripPath: kong.Bool(true), + Protocols: kong.StringSlice("http", "https"), + }, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + err := d.Set(tC.arg) + assert.Nil(err) + assert.Equal(tC.want, tC.arg) + }) + } +} + +func TestUpstreamSetTest(t *testing.T) { + assert := assert.New(t) + ctx := context.Background() + d, err := GetDefaulter(ctx, defaulterTestOpts) + assert.NotNil(d) + assert.Nil(err) + + testCases := []struct { + desc string + arg *kong.Upstream + want *kong.Upstream + }{ + { + desc: "empty upstream", + arg: &kong.Upstream{}, + want: &upstreamDefaults, + }, + { + desc: "Healthchecks.Active.Healthy.HTTPStatuses is not overridden", + arg: &kong.Upstream{ + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200}, + }, + }, + }, + }, + want: &kong.Upstream{ + Slots: kong.Int(10000), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200}, + Interval: kong.Int(0), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + Interval: kong.Int(0), + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + { + desc: "Healthchecks.Active.Healthy.Timeout is not overridden", + arg: &kong.Upstream{ + Name: kong.String("foo"), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Healthy: &kong.Healthy{ + Interval: kong.Int(1), + }, + }, + }, + }, + want: &kong.Upstream{ + Name: kong.String("foo"), + Slots: kong.Int(10000), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(1), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + Interval: kong.Int(0), + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + { + desc: "Healthchecks.Active.HTTPSVerifyCertificate can be set to false", + arg: &kong.Upstream{ + Name: kong.String("foo"), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Healthy: &kong.Healthy{ + Interval: kong.Int(1), + }, + HTTPSVerifyCertificate: kong.Bool(false), + }, + }, + }, + want: &kong.Upstream{ + Name: kong.String("foo"), + Slots: kong.Int(10000), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(1), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + HTTPSVerifyCertificate: kong.Bool(false), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + Interval: kong.Int(0), + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + { + desc: "Healthchecks.Active.HTTPSVerifyCertificate can be set to true", + arg: &kong.Upstream{ + Name: kong.String("foo"), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Healthy: &kong.Healthy{ + Interval: kong.Int(1), + }, + HTTPSVerifyCertificate: kong.Bool(true), + }, + }, + }, + want: &kong.Upstream{ + Name: kong.String("foo"), + Slots: kong.Int(10000), + Healthchecks: &kong.Healthcheck{ + Active: &kong.ActiveHealthcheck{ + Concurrency: kong.Int(10), + Healthy: &kong.Healthy{ + HTTPStatuses: []int{200, 302}, + Interval: kong.Int(1), + Successes: kong.Int(0), + }, + HTTPPath: kong.String("/"), + HTTPSVerifyCertificate: kong.Bool(true), + Type: kong.String("http"), + Timeout: kong.Int(1), + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 404, 500, 501, 502, 503, 504, 505}, + Interval: kong.Int(0), + }, + }, + Passive: &kong.PassiveHealthcheck{ + Healthy: &kong.Healthy{ + HTTPStatuses: []int{ + 200, 201, 202, 203, 204, 205, + 206, 207, 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308, + }, + Successes: kong.Int(0), + }, + Unhealthy: &kong.Unhealthy{ + HTTPFailures: kong.Int(0), + TCPFailures: kong.Int(0), + Timeouts: kong.Int(0), + HTTPStatuses: []int{429, 500, 503}, + }, + }, + }, + HashOn: kong.String("none"), + HashFallback: kong.String("none"), + HashOnCookiePath: kong.String("/"), + }, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + err := d.Set(tC.arg) + assert.Nil(err) + assert.Equal(tC.want, tC.arg) + }) + } +} + +func TestGetDefaulter_Konnect(t *testing.T) { + assert := assert.New(t) + + testCases := []struct { + desc string + opts DefaulterOpts + want *Defaulter + }{ + { + desc: "empty user defaults", + opts: DefaulterOpts{ + KongDefaults: &kongDefaultForTesting{}, + DisableDynamicDefaults: true, + }, + want: &Defaulter{ + service: &serviceDefaults, + route: &routeDefaults, + upstream: &upstreamDefaults, + target: &targetDefaults, + }, + }, + { + desc: "user defaults take precedence", + opts: DefaulterOpts{ + KongDefaults: &kongDefaultForTesting{ + Service: &kong.Service{ + Port: kong.Int(8080), + Path: kong.String("/v1"), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(defaultTimeout), + WriteTimeout: kong.Int(defaultTimeout), + ReadTimeout: kong.Int(defaultTimeout), + }, + }, + DisableDynamicDefaults: true, + }, + want: &Defaulter{ + service: &kong.Service{ + Port: kong.Int(8080), + Path: kong.String("/v1"), + Protocol: kong.String("http"), + ConnectTimeout: kong.Int(defaultTimeout), + WriteTimeout: kong.Int(defaultTimeout), + ReadTimeout: kong.Int(defaultTimeout), + }, + route: &routeDefaults, + upstream: &upstreamDefaults, + target: &targetDefaults, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.Background() + d, err := GetDefaulter(ctx, tc.opts) + assert.NotNil(d) + assert.Nil(err) + + if !reflect.DeepEqual(d.service, tc.want.service) { + assert.Equal(t, tc.want.service, d.service) + } + if !reflect.DeepEqual(d.route, tc.want.route) { + assert.Equal(t, tc.want.route, d.route) + } + if !reflect.DeepEqual(d.upstream, tc.want.upstream) { + assert.Equal(t, tc.want.upstream, d.upstream) + } + if !reflect.DeepEqual(d.target, tc.want.target) { + assert.Equal(t, tc.want.target, d.target) + } + }) + } +} + +func TestCheckRestrictedFields(t *testing.T) { + assert := assert.New(t) + + testCases := []struct { + desc string + entity *kong.Service + restrictedFields []string + wantErr bool + expectedErr string + }{ + { + desc: "no restricted fields", + entity: &kong.Service{ + ID: kong.String("testID"), + Name: kong.String("testName"), + }, + restrictedFields: []string{}, + }, + { + desc: "one restricted fields", + entity: &kong.Service{ + ID: kong.String("testID"), + Name: kong.String("testName"), + }, + restrictedFields: []string{"ID"}, + wantErr: true, + expectedErr: "cannot have these restricted fields set: id", + }, + { + desc: "multiple restricted fields", + entity: &kong.Service{ + ID: kong.String("testID"), + Name: kong.String("testName"), + Port: kong.Int(80), + }, + restrictedFields: []string{"ID", "Name", "Port"}, + wantErr: true, + expectedErr: "cannot have these restricted fields set: id, name, port", + }, + } + + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + err := checkEntityDefaults(tC.entity, tC.restrictedFields) + if (err != nil) != tC.wantErr { + t.Errorf("got error = %v, expected error = %v", err, tC.wantErr) + } + if tC.expectedErr != "" { + assert.Equal(err.Error(), tC.expectedErr) + } + }) + } +} + +func TestKongDefaultsRestrictedFields(t *testing.T) { + assert := assert.New(t) + ctx := context.Background() + + testCases := []struct { + desc string + kongDefaults *kongDefaultForTesting + wantErr bool + expectedErr string + }{ + { + desc: "service no restricted fields", + kongDefaults: &kongDefaultForTesting{ + Service: &kong.Service{ + Path: kong.String("/v1"), + }, + }, + }, + { + desc: "route no restricted fields", + kongDefaults: &kongDefaultForTesting{ + Route: &kong.Route{ + StripPath: kong.Bool(false), + }, + }, + }, + { + desc: "target no restricted fields", + kongDefaults: &kongDefaultForTesting{ + Target: &kong.Target{ + Weight: kong.Int(42), + }, + }, + }, + { + desc: "upstream no restricted fields", + kongDefaults: &kongDefaultForTesting{ + Upstream: &kong.Upstream{ + HostHeader: kong.String("testHostHeader"), + }, + }, + }, + { + desc: "service restricted fields", + kongDefaults: &kongDefaultForTesting{ + Service: &kong.Service{ + ID: kong.String("testID"), + Name: kong.String("testName"), + Path: kong.String("/v1"), + }, + }, + wantErr: true, + expectedErr: "service defaults cannot have these restricted fields set: id, name", + }, + { + desc: "route restricted fields", + kongDefaults: &kongDefaultForTesting{ + Route: &kong.Route{ + ID: kong.String("testID"), + Name: kong.String("testName"), + StripPath: kong.Bool(false), + }, + }, + wantErr: true, + expectedErr: "route defaults cannot have these restricted fields set: id, name", + }, + { + desc: "target restricted fields", + kongDefaults: &kongDefaultForTesting{ + Target: &kong.Target{ + ID: kong.String("testID"), + Target: kong.String("testTarget"), + }, + }, + wantErr: true, + expectedErr: "target defaults cannot have these restricted fields set: id, target", + }, + { + desc: "upstream restricted fields", + kongDefaults: &kongDefaultForTesting{ + Upstream: &kong.Upstream{ + ID: kong.String("testID"), + Name: kong.String("testName"), + HostHeader: kong.String("testHostHeader"), + }, + }, + wantErr: true, + expectedErr: "upstream defaults cannot have these restricted fields set: id, name", + }, + } + + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + opts := DefaulterOpts{ + KongDefaults: tC.kongDefaults, + } + _, err := GetDefaulter(ctx, opts) + if (err != nil) != tC.wantErr { + t.Errorf("got error = %v, expected error = %v", err, tC.wantErr) + } + if tC.expectedErr != "" { + assert.Contains(err.Error(), tC.expectedErr) + } + }) + } +} diff --git a/pkg/utils/prompt.go b/pkg/utils/prompt.go new file mode 100644 index 0000000..afcb1aa --- /dev/null +++ b/pkg/utils/prompt.go @@ -0,0 +1,49 @@ +package utils + +import ( + "errors" + "fmt" + "os" + "strings" +) + +// Confirm prompts a user for a confirmation with message +// and returns true with no error if input is "yes" or "y" (case-insensitive), +// otherwise false. +func Confirm(message string) (bool, error) { + fmt.Print(message) + validOptions := []string{"yes", "y"} + var input string + _, err := fmt.Scanln(&input) + if err != nil { + return false, err + } + input = strings.ToLower(input) + for _, validOption := range validOptions { + if input == validOption { + return true, nil + } + } + return false, nil +} + +// ConfirmFileOverwrite is a helper function to determine whether or not the program should +// truncate and overwrite a file given its name and extension. If the file doesn't already exist +// in the filesystem, then this will return true, otherwise it will prompt the user for confirmation. +func ConfirmFileOverwrite(filename string, ext string, assumeYes bool) (bool, error) { + if assumeYes { + return true, nil + } + + filename = AddExtToFilename(filename, ext) + _, err := os.Stat(filename) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + return false, err + } + + // file exists, prompt user + return Confirm("File '" + filename + "' already exists. Do you want to overwrite it? ") +} diff --git a/pkg/utils/tags.go b/pkg/utils/tags.go new file mode 100644 index 0000000..e09f3a6 --- /dev/null +++ b/pkg/utils/tags.go @@ -0,0 +1,84 @@ +package utils + +import ( + "fmt" + "reflect" +) + +// MustMergeTags is same as MergeTags but panics if there is an error. +func MustMergeTags(obj interface{}, tags []string) { + err := MergeTags(obj, tags) + if err != nil { + panic(err) + } +} + +// MergeTags merges Tags in the object with tags. +func MergeTags(obj interface{}, tags []string) error { + if len(tags) == 0 { + return nil + } + ptr := reflect.ValueOf(obj) + if ptr.Kind() != reflect.Ptr { + return fmt.Errorf("obj is not a pointer") + } + v := reflect.Indirect(ptr) + structTags := v.FieldByName("Tags") + var zero reflect.Value + if structTags == zero { + return nil + } + m := make(map[string]bool) + for i := 0; i < structTags.Len(); i++ { + tag := reflect.Indirect(structTags.Index(i)).String() + m[tag] = true + } + for _, tag := range tags { + if _, ok := m[tag]; !ok { + t := tag + structTags.Set(reflect.Append(structTags, reflect.ValueOf(&t))) + } + } + return nil +} + +// MustRemoveTags is same as RemoveTags but panics if there is an error. +func MustRemoveTags(obj interface{}, tags []string) { + err := RemoveTags(obj, tags) + if err != nil { + panic(err) + } +} + +// RemoveTags removes tags from the Tags in obj. +func RemoveTags(obj interface{}, tags []string) error { + if len(tags) == 0 { + return nil + } + + m := make(map[string]bool) + for _, tag := range tags { + m[tag] = true + } + + ptr := reflect.ValueOf(obj) + if ptr.Kind() != reflect.Ptr { + return fmt.Errorf("obj is not a pointer") + } + v := reflect.Indirect(ptr) + structTags := v.FieldByName("Tags") + var zero reflect.Value + if structTags == zero { + return nil + } + + res := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(reflect.TypeOf(""))), 0, 0) + for i := 0; i < structTags.Len(); i++ { + tag := reflect.Indirect(structTags.Index(i)).String() + if !m[tag] { + res = reflect.Append(res, structTags.Index(i)) + } + } + structTags.Set(res) + return nil +} diff --git a/pkg/utils/tags_test.go b/pkg/utils/tags_test.go new file mode 100644 index 0000000..c74f114 --- /dev/null +++ b/pkg/utils/tags_test.go @@ -0,0 +1,92 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMergeTags(t *testing.T) { + type Foo struct { + Tags []*string + } + type Bar struct{} + + assert := assert.New(t) + + a := "tag1" + b := "tag2" + c := "tag3" + + var f Foo + err := MergeTags(f, []string{"tag1"}) + assert.NotNil(err) + + assert.Panics(func() { + MustMergeTags(f, []string{"tag1"}) + }) + + var bar Bar + err = MergeTags(&bar, []string{"tag1"}) + assert.Nil(err) + + f = Foo{Tags: []*string{&a, &b}} + assert.Nil(MergeTags(&f, []string{"tag1", "tag2", "tag3"})) + assert.True(equalArray([]*string{&a, &b, &c}, f.Tags)) + + f = Foo{Tags: []*string{}} + assert.Nil(MergeTags(&f, []string{"tag1", "tag2", "tag3"})) + assert.True(equalArray([]*string{&a, &b, &c}, f.Tags)) + + f = Foo{Tags: []*string{&a, &b}} + assert.Nil(MergeTags(&f, nil)) + assert.True(equalArray([]*string{&a, &b}, f.Tags)) +} + +func equalArray(want, have []*string) bool { + if len(want) != len(have) { + return false + } + for i := 0; i < len(want); i++ { + if *want[i] != *have[i] { + return false + } + } + return true +} + +func TestRemoveTags(t *testing.T) { + type Foo struct { + Tags []*string + } + type Bar struct{} + + assert := assert.New(t) + + a := "tag1" + b := "tag2" + + var f Foo + err := RemoveTags(f, []string{"tag1"}) + assert.NotNil(err) + + assert.Panics(func() { + MustRemoveTags(f, []string{"tag1"}) + }) + + var bar Bar + err = RemoveTags(&bar, []string{"tag1"}) + assert.Nil(err) + + f = Foo{Tags: []*string{&a, &b}} + RemoveTags(&f, []string{"tag2", "tag3"}) + assert.True(equalArray([]*string{&a}, f.Tags)) + + f = Foo{Tags: []*string{}} + RemoveTags(&f, []string{"tag1", "tag2", "tag3"}) + assert.True(equalArray([]*string{}, f.Tags)) + + f = Foo{Tags: []*string{&a, &b}} + RemoveTags(&f, nil) + assert.True(equalArray([]*string{&a, &b}, f.Tags)) +} diff --git a/pkg/utils/types.go b/pkg/utils/types.go new file mode 100644 index 0000000..3d4986e --- /dev/null +++ b/pkg/utils/types.go @@ -0,0 +1,347 @@ +package utils + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "math" + "net" + "net/http" + "net/url" + "os" + "regexp" + "strconv" + "strings" + "time" + + "github.com/hashicorp/go-retryablehttp" + "github.com/kong/deck/konnect" + "github.com/kong/go-kong/kong" + "github.com/kong/go-kong/kong/custom" + "github.com/ssgelm/cookiejarparser" +) + +var clientTimeout time.Duration + +// KongRawState contains all of Kong Data +type KongRawState struct { + Services []*kong.Service + Routes []*kong.Route + + Plugins []*kong.Plugin + + Upstreams []*kong.Upstream + Targets []*kong.Target + + Certificates []*kong.Certificate + SNIs []*kong.SNI + CACertificates []*kong.CACertificate + + Consumers []*kong.Consumer + ConsumerGroups []*kong.ConsumerGroupObject + CustomEntities []*custom.Entity + + Vaults []*kong.Vault + + KeyAuths []*kong.KeyAuth + HMACAuths []*kong.HMACAuth + JWTAuths []*kong.JWTAuth + BasicAuths []*kong.BasicAuth + ACLGroups []*kong.ACLGroup + Oauth2Creds []*kong.Oauth2Credential + MTLSAuths []*kong.MTLSAuth + + RBACRoles []*kong.RBACRole + RBACEndpointPermissions []*kong.RBACEndpointPermission +} + +// KonnectRawState contains all of Konnect resources. +type KonnectRawState struct { + ServicePackages []*konnect.ServicePackage + Documents []*konnect.Document +} + +// ErrArray holds an array of errors. +type ErrArray struct { + Errors []error +} + +// Error returns a pretty string of errors present. +func (e ErrArray) Error() string { + if len(e.Errors) == 0 { + return "nil" + } + var res string + + res = strconv.Itoa(len(e.Errors)) + " errors occurred:\n" + for _, err := range e.Errors { + res += fmt.Sprintf("\t%v\n", err) + } + return res +} + +func (e ErrArray) ErrorList() []string { + errList := []string{} + + for _, err := range e.Errors { + errList = append(errList, err.Error()) + } + return errList +} + +// KongClientConfig holds config details to use to talk to a Kong server. +type KongClientConfig struct { + Address string + Workspace string + + TLSServerName string + + TLSCACert string + + TLSSkipVerify bool + Debug bool + + SkipWorkspaceCrud bool + + Headers []string + + HTTPClient *http.Client + + Timeout int + + CookieJarPath string + + TLSClientCert string + + TLSClientKey string + + // whether or not the client should retry on 429s + Retryable bool +} + +type KonnectConfig struct { + Email string + Password string + Token string + Debug bool + + Address string + + Headers []string + + ControlPlaneName string +} + +// ForWorkspace returns a copy of KongClientConfig that produces a KongClient for the workspace specified by argument. +func (kc *KongClientConfig) ForWorkspace(name string) KongClientConfig { + result := *kc + result.Workspace = name + return result +} + +// backoffStrategy provides a callback for Client.Backoff which +// will perform exponential backoff based on the attempt number and limited +// by the provided minimum and maximum durations. +// +// It also tries to parse Retry-After response header when a http.StatusTooManyRequests +// (HTTP Code 429) is found in the resp parameter. Hence it will return the number of +// seconds the server states it may be ready to process more requests from this client. +// +// This is the same as DefaultBackoff (https://github.com/hashicorp/go-retryablehttp/blob/v0.7.1/client.go#L503) +// except that here we are only retrying on 429s. +func backoffStrategy(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { + const ( + base = 10 + bitSize = 64 + baseExponential = 2 + ) + if resp != nil && resp.StatusCode == http.StatusTooManyRequests { + if s, ok := resp.Header["Retry-After"]; ok { + if sleep, err := strconv.ParseInt(s[0], base, bitSize); err == nil { + return time.Second * time.Duration(sleep) + } + } + } + + mult := math.Pow(baseExponential, float64(attemptNum)) * float64(min) + sleep := time.Duration(mult) + if float64(sleep) != mult || sleep > max { + sleep = max + } + return sleep +} + +// retryPolicy provides a callback for Client.CheckRetry, which +// will retry on 429s errors. +func retryPolicy(ctx context.Context, resp *http.Response, _ error) (bool, error) { + // do not retry on context.Canceled or context.DeadlineExceeded + if ctx.Err() != nil { + return false, ctx.Err() + } + + // 429 Too Many Requests is recoverable. Sometimes the server puts + // a Retry-After response header to indicate when the server is + // available to start processing request from client. + if resp != nil && resp.StatusCode == http.StatusTooManyRequests { + return true, nil + } + return false, nil +} + +func getRetryableClient(client *http.Client) *http.Client { + const ( + minRetryWait = 10 * time.Second + maxRetryWait = 60 * time.Second + retryMax = 10 + ) + retryClient := retryablehttp.NewClient() + retryClient.HTTPClient = client + retryClient.Backoff = backoffStrategy + retryClient.CheckRetry = retryPolicy + retryClient.RetryMax = retryMax + retryClient.RetryWaitMax = maxRetryWait + retryClient.RetryWaitMin = minRetryWait + // logging is handled by deck. + retryClient.Logger = nil + return retryClient.StandardClient() +} + +// GetKongClient returns a Kong client +func GetKongClient(opt KongClientConfig) (*kong.Client, error) { + var tlsConfig tls.Config + if opt.TLSSkipVerify { + tlsConfig.InsecureSkipVerify = true //nolint:gosec + } + if opt.TLSServerName != "" { + tlsConfig.ServerName = opt.TLSServerName + } + + if opt.TLSCACert != "" { + certPool := x509.NewCertPool() + ok := certPool.AppendCertsFromPEM([]byte(opt.TLSCACert)) + if !ok { + return nil, fmt.Errorf("failed to load TLSCACert") + } + tlsConfig.RootCAs = certPool + } + + if opt.TLSClientCert != "" && opt.TLSClientKey != "" { + // Read the key pair to create certificate + cert, err := tls.X509KeyPair([]byte(opt.TLSClientCert), []byte(opt.TLSClientKey)) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + clientTimeout = time.Duration(opt.Timeout) * time.Second + c := opt.HTTPClient + if c == nil { + c = HTTPClient() + } + defaultTransport := http.DefaultTransport.(*http.Transport) + defaultTransport.TLSClientConfig = &tlsConfig + c.Transport = defaultTransport + address := CleanAddress(opt.Address) + + headers, err := parseHeaders(opt.Headers) + if err != nil { + return nil, fmt.Errorf("parsing headers: %w", err) + } + c = kong.HTTPClientWithHeaders(c, headers) + + if opt.Retryable { + c = getRetryableClient(c) + } + + url, err := url.ParseRequestURI(address) + if err != nil { + return nil, fmt.Errorf("failed to parse kong address: %w", err) + } + // Add Session Cookie support if required + if opt.CookieJarPath != "" { + jar, err := cookiejarparser.LoadCookieJarFile(opt.CookieJarPath) + if err != nil { + return nil, fmt.Errorf("failed to initialize cookie-jar: %w", err) + } + c.Jar = jar + } + + kongClient, err := kong.NewClient(kong.String(url.String()), c) + if err != nil { + return nil, fmt.Errorf("creating client for Kong's Admin API: %w", err) + } + if opt.Debug { + kongClient.SetDebugMode(true) + kongClient.SetLogger(os.Stderr) + } + if opt.Workspace != "" { + kongClient.SetWorkspace(opt.Workspace) + } + return kongClient, nil +} + +func parseHeaders(headers []string) (http.Header, error) { + res := http.Header{} + const splitLen = 2 + for _, keyValue := range headers { + split := strings.SplitN(keyValue, ":", 2) + if len(split) >= splitLen { + res.Add(split[0], split[1]) + } else { + return nil, fmt.Errorf("splitting header key-value '%s'", keyValue) + } + } + return res, nil +} + +func GetKonnectClient(httpClient *http.Client, config KonnectConfig) (*konnect.Client, + error, +) { + address := CleanAddress(config.Address) + + if httpClient == nil { + defaultTransport := http.DefaultTransport.(*http.Transport) + defaultTransport.Proxy = http.ProxyFromEnvironment + httpClient = http.DefaultClient + httpClient.Transport = defaultTransport + } + headers, err := parseHeaders(config.Headers) + if err != nil { + return nil, fmt.Errorf("parsing headers: %w", err) + } + httpClient = kong.HTTPClientWithHeaders(httpClient, headers) + client, err := konnect.NewClient(httpClient, konnect.ClientOpts{ + BaseURL: address, + }) + if err != nil { + return nil, err + } + if config.Debug { + client.SetDebugMode(true) + client.SetLogger(os.Stderr) + } + return client, nil +} + +// CleanAddress removes trailling / from a URL. +func CleanAddress(address string) string { + re := regexp.MustCompile("[/]+$") + return re.ReplaceAllString(address, "") +} + +// HTTPClient returns a new Go stdlib's net/http.Client with +// sane default timeouts. +func HTTPClient() *http.Client { + return &http.Client{ + Timeout: clientTimeout, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: clientTimeout, + }).DialContext, + TLSHandshakeTimeout: clientTimeout, + Proxy: http.ProxyFromEnvironment, + }, + } +} diff --git a/pkg/utils/types_test.go b/pkg/utils/types_test.go new file mode 100644 index 0000000..75273e5 --- /dev/null +++ b/pkg/utils/types_test.go @@ -0,0 +1,138 @@ +package utils + +import ( + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrArrayString(t *testing.T) { + assert := assert.New(t) + var err ErrArray + assert.Equal("nil", err.Error()) + + err.Errors = append(err.Errors, fmt.Errorf("foo failed")) + + assert.Equal(err.Error(), "1 errors occurred:\n\tfoo failed\n") + + err.Errors = append(err.Errors, fmt.Errorf("bar failed")) + + assert.Equal(err.Error(), "2 errors occurred:\n\tfoo failed\n\tbar failed\n") +} + +func Test_cleanAddress(t *testing.T) { + type args struct { + address string + } + tests := []struct { + name string + args args + want string + }{ + { + args: args{ + address: "foo", + }, + want: "foo", + }, + { + args: args{ + address: "http://localhost:8001", + }, + want: "http://localhost:8001", + }, + { + args: args{ + address: "http://localhost:8001/", + }, + want: "http://localhost:8001", + }, + { + args: args{ + address: "http://localhost:8001//", + }, + want: "http://localhost:8001", + }, + { + args: args{ + address: "https://subdomain.example.com///", + }, + want: "https://subdomain.example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CleanAddress(tt.args.address); got != tt.want { + t.Errorf("cleanAddress() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseHeaders(t *testing.T) { + type args struct { + headers []string + } + tests := []struct { + name string + args args + want http.Header + wantErr bool + }{ + { + name: "nil headers returns without an error", + args: args{ + headers: nil, + }, + want: http.Header{}, + wantErr: false, + }, + { + name: "empty headers returns without an error", + args: args{ + headers: []string{}, + }, + want: http.Header{}, + wantErr: false, + }, + { + name: "headers returns without an error", + args: args{ + headers: []string{ + "foo:bar", + "baz:fubar", + }, + }, + want: http.Header{ + "Foo": []string{"bar"}, + "Baz": []string{"fubar"}, + }, + wantErr: false, + }, + { + name: "invalid headers value returns an error", + args: args{ + headers: []string{ + "fubar", + }, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseHeaders(tt.args.headers) + if (err != nil) != tt.wantErr { + t.Errorf("parseHeaders() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseHeaders() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 0000000..f87d034 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,252 @@ +package utils + +import ( + "context" + "errors" + "fmt" + "net/url" + "os" + "path/filepath" + "reflect" + "regexp" + "strings" + + "github.com/blang/semver/v4" + "github.com/kong/deck/cprint" + "github.com/kong/go-kong/kong" +) + +var ( + kongVersionRegex = regexp.MustCompile(`^\d+\.\d+`) + pathRegexPattern = regexp.MustCompile(`[^a-zA-Z0-9._~/%-]`) + + Kong140Version = semver.MustParse("1.4.0") + Kong300Version = semver.MustParse("3.0.0") + Kong340Version = semver.MustParse("3.4.0") +) + +var ErrorConsumerGroupUpgrade = errors.New( + "a rate-limiting-advanced plugin with config.consumer_groups\n" + + "and/or config.enforce_consumer_groups was found. Please use Consumer Groups scoped\n" + + "Plugins when running against Kong Enterprise 3.4.0 and above.\n\n" + + "Check https://docs.konghq.com/gateway/latest/kong-enterprise/consumer-groups/ for more information", +) + +var UpgradeMessage = "Please upgrade your configuration to account for 3.0\n" + + "breaking changes using the following command:\n\n" + + "deck convert --from kong-gateway-2.x --to kong-gateway-3.x\n\n" + + "This command performs the following changes:\n" + + " - upgrade the `_format_version` value to `3.0`\n" + + " - add the `~` prefix to all routes' paths containing a regex-pattern\n\n" + + "These changes may not be correct or exhaustive enough.\n" + + "It is strongly recommended to perform a manual audit\n" + + "of the updated configuration file before applying\n" + + "the configuration in production. Incorrect changes will result in\n" + + "unintended traffic routing by Kong Gateway.\n\n" + + + "For more information about this and related changes,\n" + + "please visit: https://docs.konghq.com/deck/latest/3.0-upgrade\n\n" + +// IsPathRegexLike checks if a path string contains a regex pattern. +func IsPathRegexLike(path string) bool { + return pathRegexPattern.MatchString(path) +} + +// Empty checks if a string referenced by s or s itself is empty. +func Empty(s *string) bool { + return s == nil || *s == "" +} + +// CleanKongVersion takes a version of Kong and returns back a string in +// the form of `/major.minor` version. There are various dashes and dots +// and other descriptors in Kong version strings, which has often created +// confusion in code and incorrect parsing, and hence this function does +// not return the patch version (on which shouldn't rely on anyways). +func CleanKongVersion(version string) (string, error) { + matches := kongVersionRegex.FindStringSubmatch(version) + if len(matches) < 1 { + return "", fmt.Errorf("unknown Kong version") + } + return matches[0], nil +} + +func AddExtToFilename(filename, ext string) string { + if filepath.Ext(filename) == "" { + filename = filename + "." + ext + } + return filename +} + +// NameToFilename clears path separators from strings. Some entity names in Kong and Konnect +// allow path directory separators. Some decK operations write files using entity names, +// which is not compatible with names that contain path separators. NameToFilename strips leading +// separator characters and replaces other instances of the separator with its URL-encoded representation. +func NameToFilename(name string) string { + s := strings.TrimPrefix(name, string(os.PathSeparator)) + s = strings.ReplaceAll(s, string(os.PathSeparator), url.PathEscape(string(os.PathSeparator))) + return s +} + +// FilenameToName (partially) reverses NameToFilename, replacing all URL-encoded path separator characters +// with the path separator character. It does not re-add a leading separator, because there is no way to know +// if that separator was included originally, and only some names (document paths) typically include one. +func FilenameToName(filename string) string { + return strings.ReplaceAll(filename, url.PathEscape(string(os.PathSeparator)), string(os.PathSeparator)) +} + +func CallGetAll(obj interface{}) (reflect.Value, error) { + // call GetAll method on entity + var result reflect.Value + method := reflect.ValueOf(obj).MethodByName("GetAll") + if !method.IsValid() { + return result, fmt.Errorf("GetAll() method not found for type '%v'. "+ + "Please file a bug with Kong Inc", reflect.ValueOf(obj).Type()) + } + entities := method.Call([]reflect.Value{})[0].Interface() + result = reflect.ValueOf(entities) + return result, nil +} + +func alreadyInSlice(elem string, slice []string) bool { + for _, s := range slice { + if s == elem { + return true + } + } + return false +} + +// RemoveDuplicates removes duplicated elements from a slice. +func RemoveDuplicates(slice *[]string) { + newSlice := []string{} + for _, s := range *slice { + if alreadyInSlice(s, newSlice) { + continue + } + newSlice = append(newSlice, s) + } + *slice = newSlice +} + +func WorkspaceExists(ctx context.Context, client *kong.Client) (bool, error) { + if client == nil { + return false, nil + } + workspace := client.Workspace() + if workspace == "" { + return true, nil + } + return client.Workspaces.Exists(ctx, &workspace) +} + +// These GetFooReference functions return stripped copies (ID and Name only) of Kong resource +// structs. We use these within KongRawState structs to indicate entity relationships. +// While state files indicate relationships by nesting (A collection of services is +// [{name: "foo", id: "1234", connect_timeout: 600000, routes: [{name: "fooRoute"}]}]), +// KongRawState is flattened, with all entities listed independently at the top level. +// To preserve the relationships, these flattened entities include references (the route from +// earlier becomes {name: "fooRoute", service: {name: "foo", id: "1234"}}). + +// GetConsumerReference returns a username+ID only copy of the input consumer, +// for use in references from other objects +func GetConsumerReference(c kong.Consumer) *kong.Consumer { + consumer := &kong.Consumer{ID: kong.String(*c.ID)} + if c.Username != nil { + consumer.Username = kong.String(*c.Username) + } + return consumer +} + +// GetConsumerGroupReference returns a name+ID only copy of the input consumer-group, +// for use in references from other objects +func GetConsumerGroupReference(c kong.ConsumerGroup) *kong.ConsumerGroup { + consumerGroup := &kong.ConsumerGroup{ID: kong.String(*c.ID)} + if c.Name != nil { + consumerGroup.Name = kong.String(*c.Name) + } + return consumerGroup +} + +// GetServiceReference returns a name+ID only copy of the input service, +// for use in references from other objects +func GetServiceReference(s kong.Service) *kong.Service { + service := &kong.Service{ID: kong.String(*s.ID)} + if s.Name != nil { + service.Name = kong.String(*s.Name) + } + return service +} + +// GetRouteReference returns a name+ID only copy of the input route, +// for use in references from other objects +func GetRouteReference(r kong.Route) *kong.Route { + route := &kong.Route{ID: kong.String(*r.ID)} + if r.Name != nil { + route.Name = kong.String(*r.Name) + } + return route +} + +// ParseKongVersion takes a version string from the Gateway and +// turns it into a semver-compliant version to be used for +// comparison across the code. +func ParseKongVersion(version string) (semver.Version, error) { + v, err := CleanKongVersion(version) + if err != nil { + return semver.Version{}, err + } + return semver.ParseTolerant(v) +} + +// ConfigFilesInDir traverses the directory rooted at dir and +// returns all the files with a case-insensitive extension of `yml` or `yaml`. +func ConfigFilesInDir(dir string) ([]string, error) { + var res []string + err := filepath.Walk( + dir, + func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + switch strings.ToLower(filepath.Ext(path)) { + case ".yaml", ".yml", ".json": + res = append(res, path) + } + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("reading state directory: %w", err) + } + return res, nil +} + +// HasPathsWithRegex300AndAbove checks routes' paths format and returns true +// if these math a regex-pattern without a '~' prefix. +func HasPathsWithRegex300AndAbove(route kong.Route) bool { + for _, p := range route.Paths { + if strings.HasPrefix(*p, "~/") || !IsPathRegexLike(*p) { + continue + } + return true + } + return false +} + +// PrintRouteRegexWarning prints out a warning about 3.x routes' path usage. +func PrintRouteRegexWarning(unsupportedRoutes []string) { + unsupportedRoutesLen := len(unsupportedRoutes) + // do not consider more than 10 sample routes to print out. + if unsupportedRoutesLen > 10 { + unsupportedRoutes = unsupportedRoutes[:10] + } + cprint.UpdatePrintf( + "%d unsupported routes' paths format with Kong version 3.0\n"+ + "or above were detected. Some of these routes are (not an exhaustive list):\n\n"+ + "%s\n\n"+UpgradeMessage, + unsupportedRoutesLen, strings.Join(unsupportedRoutes, "\n"), + ) +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 0000000..c1629e0 --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,328 @@ +package utils + +import ( + "net/url" + "os" + "testing" + + "github.com/blang/semver/v4" + "github.com/stretchr/testify/assert" +) + +func TestEmpty(t *testing.T) { + assert := assert.New(t) + notEmpty := "not-empty" + emptyString := "" + var nilPointer *string + assert.False(Empty(¬Empty)) + assert.True(Empty(nilPointer)) + assert.True(Empty(&emptyString)) +} + +func Test_cleanKongVersion(t *testing.T) { + type args struct { + version string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + args: args{ + version: "1.0.1", + }, + want: "1.0", + wantErr: false, + }, + { + args: args{ + version: "1.3.0.1", + }, + want: "1.3", + wantErr: false, + }, + { + args: args{ + version: "0.14.1", + }, + want: "0.14", + wantErr: false, + }, + { + args: args{ + version: "0.14.2rc", + }, + want: "0.14", + wantErr: false, + }, + { + args: args{ + version: "0.14.2rc1", + }, + want: "0.14", + wantErr: false, + }, + { + args: args{ + version: "0.33-enterprise-edition", + }, + want: "0.33", + wantErr: false, + }, + { + args: args{ + version: "1.3.0-0-enterprise-edition", + }, + want: "1.3", + wantErr: false, + }, + { + args: args{ + version: "", + }, + want: "", + wantErr: true, + }, + { + args: args{ + version: "0-1.1", + }, + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CleanKongVersion(tt.args.version) + if (err != nil) != tt.wantErr { + t.Errorf("cleanKongVersion() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("cleanKongVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_AddExtToFilename(t *testing.T) { + type args struct { + filename string + ext string + } + tests := []struct { + name string + args args + want string + }{ + { + args: args{ + filename: "foo", + ext: "yolo", + }, + want: "foo.yolo", + }, + { + args: args{ + filename: "foo.json", + ext: "yolo", + }, + want: "foo.json", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := AddExtToFilename(tt.args.filename, tt.args.ext); got != tt.want { + t.Errorf("AddExtToFilename() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_NameToFilename(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "leading separator", + args: args{ + name: string(os.PathSeparator) + "foo.md", + }, + want: "foo.md", + }, + { + name: "inner separator", + args: args{ + name: "bar" + string(os.PathSeparator) + "foo.md", + }, + want: "bar" + url.PathEscape(string(os.PathSeparator)) + "foo.md", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NameToFilename(tt.args.name); got != tt.want { + t.Errorf("NameToFilename() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_FilenameToName(t *testing.T) { + type args struct { + filename string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "inner separator", + args: args{ + filename: "bar" + url.PathEscape(string(os.PathSeparator)) + "foo.md", + }, + want: "bar" + string(os.PathSeparator) + "foo.md", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FilenameToName(tt.args.filename); got != tt.want { + t.Errorf("FilenameToName() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_ParseKongVersion(t *testing.T) { + type args struct { + version string + } + tests := []struct { + name string + args args + want semver.Version + wantErr bool + }{ + { + args: args{ + version: "1.0.1", + }, + want: semver.Version{Major: 1, Minor: 0}, + wantErr: false, + }, + { + args: args{ + version: "1.3.0.1", + }, + want: semver.Version{Major: 1, Minor: 3}, + wantErr: false, + }, + { + args: args{ + version: "0.14.1", + }, + want: semver.Version{Major: 0, Minor: 14}, + wantErr: false, + }, + { + args: args{ + version: "0.14.2rc", + }, + want: semver.Version{Major: 0, Minor: 14}, + wantErr: false, + }, + { + args: args{ + version: "0.14.2rc1", + }, + want: semver.Version{Major: 0, Minor: 14}, + wantErr: false, + }, + { + args: args{ + version: "0.33-enterprise-edition", + }, + want: semver.Version{Major: 0, Minor: 33}, + wantErr: false, + }, + { + args: args{ + version: "1.3.0-0-enterprise-edition", + }, + want: semver.Version{Major: 1, Minor: 3}, + wantErr: false, + }, + { + args: args{ + version: "", + }, + want: semver.Version{}, + wantErr: true, + }, + { + args: args{ + version: "0-1.1", + }, + want: semver.Version{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseKongVersion(tt.args.version) + if (err != nil) != tt.wantErr { + t.Errorf("ParseKongVersion() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !got.Equals(tt.want) { + t.Errorf("ParseKongVersion() = %v, want %v", got.String(), tt.want.String()) + } + }) + } +} + +func Test_IsPathRegexLike(t *testing.T) { + tests := []struct { + name string + paths []string + expected bool + }{ + { + name: "regex-like paths", + paths: []string{ + `/.*`, + `/[fF][oO]{2}`, + `/foo|bar`, + `/blog-\\d+`, + `/bl[ao]go(sphere|web)`, + }, + expected: true, + }, + { + name: "no regex-like paths", + paths: []string{ + "/", + "/foo", + "/foo/", + "/abcd~user~2", + "/abcd%aa%10%ff%AA%FF", + }, + expected: false, + }, + } + + for _, test := range tests { + for _, path := range test.paths { + assert.Equal( + t, test.expected, IsPathRegexLike(path), "test: '%v', path: '%v'", test.name, path, + ) + } + } +} diff --git a/pkg/utils/uuid.go b/pkg/utils/uuid.go new file mode 100644 index 0000000..5be7f48 --- /dev/null +++ b/pkg/utils/uuid.go @@ -0,0 +1,10 @@ +package utils + +import ( + "github.com/google/uuid" +) + +// UUID will generate a random v4 unique identifier +func UUID() string { + return uuid.NewString() +} diff --git a/pkg/utils/uuid_test.go b/pkg/utils/uuid_test.go new file mode 100644 index 0000000..e43abf1 --- /dev/null +++ b/pkg/utils/uuid_test.go @@ -0,0 +1,17 @@ +package utils + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUUID(t *testing.T) { + assert := assert.New(t) + uuid := UUID() + assert.NotEmpty(uuid) + assert.Regexp(regexp.MustCompile( + "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"), + uuid) +} diff --git a/pkg/utils/zero.go b/pkg/utils/zero.go new file mode 100644 index 0000000..a7cf762 --- /dev/null +++ b/pkg/utils/zero.go @@ -0,0 +1,38 @@ +package utils + +import ( + "reflect" +) + +var zero reflect.Value + +func ZeroOutField(obj interface{}, field string) { + ptr := reflect.ValueOf(obj) + if ptr.Kind() != reflect.Ptr { + return + } + v := reflect.Indirect(ptr) + ts := v.FieldByName(field) + if ts == zero { + return + } + ts.Set(reflect.Zero(ts.Type())) +} + +func ZeroOutID(obj interface{}, altName *string, withID bool) { + // withID is set, export the ID + if withID { + return + } + // altName is not set, export the ID + if Empty(altName) { + return + } + // zero the ID field + ZeroOutField(obj, "ID") +} + +func ZeroOutTimestamps(obj interface{}) { + ZeroOutField(obj, "CreatedAt") + ZeroOutField(obj, "UpdatedAt") +} diff --git a/pkg/utils/zero_test.go b/pkg/utils/zero_test.go new file mode 100644 index 0000000..3015fe0 --- /dev/null +++ b/pkg/utils/zero_test.go @@ -0,0 +1,138 @@ +package utils + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +func TestZeroOutID(t *testing.T) { + type args struct { + obj interface{} + altName *string + withID bool + } + tests := []struct { + name string + args args + expectedObj interface{} + }{ + { + name: "zeros out ID when name is set", + args: args{ + obj: &kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + altName: kong.String("Name"), + withID: false, + }, + expectedObj: &kong.Service{ + Name: kong.String("foo-name"), + }, + }, + { + name: "does not error out if ID is already zero value", + args: args{ + obj: &kong.Service{ + Name: kong.String("foo-name"), + }, + altName: kong.String("Name"), + withID: false, + }, + expectedObj: &kong.Service{ + Name: kong.String("foo-name"), + }, + }, + { + name: "does not error out if provided value is not a pointer", + args: args{ + obj: kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + altName: kong.String("Name"), + withID: false, + }, + expectedObj: kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + }, + { + name: "does not zero out ID when withID is set to true", + args: args{ + obj: &kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + altName: kong.String("Name"), + withID: true, + }, + expectedObj: &kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + }, + { + name: "does not zero out ID when altName is not provided", + args: args{ + obj: &kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + withID: false, + }, + expectedObj: &kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + }, + } + t.Parallel() + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ZeroOutID(tc.args.obj, tc.args.altName, tc.args.withID) + assert.Equal(t, tc.expectedObj, tc.args.obj) + }) + } +} + +func TestZeroOutTimestamps(t *testing.T) { + type args struct { + obj interface{} + } + tests := []struct { + name string + args args + expectedObj interface{} + }{ + { + name: "clears timestamps when set", + args: args{ + obj: &kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + CreatedAt: kong.Int(42), + UpdatedAt: kong.Int(42), + }, + }, + expectedObj: &kong.Service{ + ID: kong.String("foo-id"), + Name: kong.String("foo-name"), + }, + }, + } + t.Parallel() + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ZeroOutTimestamps(tc.args.obj) + assert.Equal(t, tc.expectedObj, tc.args.obj) + }) + } +}