diff --git a/boilerplate/flyte/golang_support_tools/go.mod b/boilerplate/flyte/golang_support_tools/go.mod index 13941936c..307398c89 100644 --- a/boilerplate/flyte/golang_support_tools/go.mod +++ b/boilerplate/flyte/golang_support_tools/go.mod @@ -1,13 +1,191 @@ module github.com/flyteorg/boilerplate -go 1.16 +go 1.17 require ( github.com/alvaroloes/enumer v1.1.2 github.com/flyteorg/flytestdlib v0.4.7 github.com/golangci/golangci-lint v1.38.0 + github.com/pseudomuto/protoc-gen-doc v1.4.1 github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 - github.com/pseudomuto/protoc-gen-doc v0.1.1 // indirect +) + +require ( + 4d63.com/gochecknoglobals v0.0.0-20201008074935-acfc0b28355a // indirect + cloud.google.com/go v0.75.0 // indirect + cloud.google.com/go/storage v1.12.0 // indirect + github.com/Azure/azure-sdk-for-go v51.0.0+incompatible // indirect + github.com/Azure/go-autorest v14.2.0+incompatible // indirect + github.com/Azure/go-autorest/autorest v0.11.17 // indirect + github.com/Azure/go-autorest/autorest/adal v0.9.10 // indirect + github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect + github.com/Azure/go-autorest/logger v0.2.0 // indirect + github.com/Azure/go-autorest/tracing v0.6.0 // indirect + github.com/BurntSushi/toml v0.3.1 // indirect + github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect + github.com/Masterminds/semver v1.5.0 // indirect + github.com/Masterminds/sprig v2.15.0+incompatible // indirect + github.com/OpenPeeDeeP/depguard v1.0.1 // indirect + github.com/alexkohler/prealloc v1.0.0 // indirect + github.com/aokoli/goutils v1.0.1 // indirect + github.com/ashanbrown/forbidigo v1.1.0 // indirect + github.com/ashanbrown/makezero v0.0.0-20201205152432-7b7cdbb3025a // indirect + github.com/aws/aws-sdk-go v1.37.1 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/bkielbasa/cyclop v1.2.0 // indirect + github.com/bombsimon/wsl/v3 v3.2.0 // indirect + github.com/cespare/xxhash v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.1.1 // indirect + github.com/charithe/durationcheck v0.0.6 // indirect + github.com/coocood/freecache v1.1.1 // indirect + github.com/daixiang0/gci v0.2.8 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/denis-tingajkin/go-header v0.4.2 // indirect + github.com/envoyproxy/protoc-gen-validate v0.3.0-java // indirect + github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607 // indirect + github.com/esimonov/ifshort v1.0.1 // indirect + github.com/fatih/color v1.10.0 // indirect + github.com/fatih/structtag v1.2.0 // indirect + github.com/form3tech-oss/jwt-go v3.2.2+incompatible // indirect + github.com/fsnotify/fsnotify v1.4.9 // indirect + github.com/fzipp/gocyclo v0.3.1 // indirect + github.com/ghodss/yaml v1.0.0 // indirect + github.com/go-critic/go-critic v0.5.4 // indirect + github.com/go-logr/logr v0.4.0 // indirect + github.com/go-toolsmith/astcast v1.0.0 // indirect + github.com/go-toolsmith/astcopy v1.0.0 // indirect + github.com/go-toolsmith/astequal v1.0.0 // indirect + github.com/go-toolsmith/astfmt v1.0.0 // indirect + github.com/go-toolsmith/astp v1.0.0 // indirect + github.com/go-toolsmith/strparse v1.0.0 // indirect + github.com/go-toolsmith/typep v1.0.2 // indirect + github.com/go-xmlfmt/xmlfmt v0.0.0-20191208150333-d5b6f63a941b // indirect + github.com/gobwas/glob v0.2.3 // indirect + github.com/gofrs/flock v0.8.0 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect + github.com/golang/protobuf v1.4.3 // indirect + github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2 // indirect + github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect + github.com/golangci/go-misc v0.0.0-20180628070357-927a3d87b613 // indirect + github.com/golangci/gofmt v0.0.0-20190930125516-244bba706f1a // indirect + github.com/golangci/lint-1 v0.0.0-20191013205115-297bf364a8e0 // indirect + github.com/golangci/maligned v0.0.0-20180506175553-b1d89398deca // indirect + github.com/golangci/misspell v0.3.5 // indirect + github.com/golangci/revgrep v0.0.0-20210208091834-cd28932614b5 // indirect + github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4 // indirect + github.com/google/go-cmp v0.5.4 // indirect + github.com/google/uuid v1.1.2 // indirect + github.com/googleapis/gax-go/v2 v2.0.5 // indirect + github.com/gordonklaus/ineffassign v0.0.0-20210225214923-2e10b2664254 // indirect + github.com/gostaticanalysis/analysisutil v0.4.1 // indirect + github.com/gostaticanalysis/comment v1.4.1 // indirect + github.com/gostaticanalysis/forcetypeassert v0.0.0-20200621232751-01d4955beaa5 // indirect + github.com/gostaticanalysis/nilerr v0.1.1 // indirect + github.com/graymeta/stow v0.2.7 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/huandu/xstrings v1.0.0 // indirect + github.com/imdario/mergo v0.3.5 // indirect + github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jgautheron/goconst v1.4.0 // indirect + github.com/jingyugao/rowserrcheck v0.0.0-20210130005344-c6a0c12dd98d // indirect + github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/jstemmer/go-junit-report v0.9.1 // indirect + github.com/julz/importas v0.0.0-20210226073942-60b4fa260dd0 // indirect + github.com/kisielk/errcheck v1.6.0 // indirect + github.com/kisielk/gotool v1.0.0 // indirect + github.com/kulti/thelper v0.4.0 // indirect + github.com/kunwardeep/paralleltest v1.0.2 // indirect + github.com/kyoh86/exportloopref v0.1.8 // indirect + github.com/magefile/mage v1.10.0 // indirect + github.com/magiconair/properties v1.8.4 // indirect + github.com/maratori/testpackage v1.0.1 // indirect + github.com/matoous/godox v0.0.0-20210227103229-6504466cf951 // indirect + github.com/mattn/go-colorable v0.1.8 // indirect + github.com/mattn/go-isatty v0.0.12 // indirect + github.com/mattn/go-runewidth v0.0.7 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect + github.com/mbilski/exhaustivestruct v1.2.0 // indirect + github.com/mgechev/dots v0.0.0-20190921121421-c36f7dcfbb81 // indirect + github.com/mgechev/revive v1.0.3 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/mitchellh/mapstructure v1.4.1 // indirect + github.com/moricho/tparallel v0.2.1 // indirect + github.com/mwitkow/go-proto-validators v0.0.0-20180403085117-0950a7990007 // indirect + github.com/nakabonne/nestif v0.3.0 // indirect + github.com/nbutton23/zxcvbn-go v0.0.0-20201221231540-e56b841a3c88 // indirect + github.com/ncw/swift v1.0.53 // indirect + github.com/nishanths/exhaustive v0.1.0 // indirect + github.com/nishanths/predeclared v0.2.1 // indirect + github.com/olekukonko/tablewriter v0.0.4 // indirect + github.com/pascaldekloe/name v0.0.0-20180628100202-0fd16699aae1 // indirect + github.com/pelletier/go-toml v1.8.1 // indirect + github.com/phayes/checkstyle v0.0.0-20170904204023-bfd46e6a821d // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/polyfloyd/go-errorlint v0.0.0-20201127212506-19bd8db6546f // indirect + github.com/prometheus/client_golang v1.9.0 // indirect + github.com/prometheus/client_model v0.2.0 // indirect + github.com/prometheus/common v0.15.0 // indirect + github.com/prometheus/procfs v0.3.0 // indirect + github.com/pseudomuto/protokit v0.2.0 // indirect + github.com/quasilyte/go-ruleguard v0.3.0 // indirect + github.com/quasilyte/regex/syntax v0.0.0-20200407221936-30656e2c4a95 // indirect + github.com/ryancurrah/gomodguard v1.2.0 // indirect + github.com/ryanrolds/sqlclosecheck v0.3.0 // indirect + github.com/sanposhiho/wastedassign v0.1.3 // indirect + github.com/satori/go.uuid v1.2.0 // indirect + github.com/securego/gosec/v2 v2.6.1 // indirect + github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c // indirect + github.com/sirupsen/logrus v1.8.0 // indirect + github.com/sonatard/noctx v0.0.1 // indirect + github.com/sourcegraph/go-diff v0.6.1 // indirect + github.com/spf13/afero v1.5.1 // indirect + github.com/spf13/cast v1.3.1 // indirect + github.com/spf13/cobra v1.1.3 // indirect + github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/viper v1.7.1 // indirect + github.com/ssgreg/nlreturn/v2 v2.1.0 // indirect + github.com/stretchr/objx v0.3.0 // indirect + github.com/stretchr/testify v1.7.0 // indirect + github.com/subosito/gotenv v1.2.0 // indirect + github.com/tdakkota/asciicheck v0.0.0-20200416200610-e657995f937b // indirect + github.com/tetafro/godot v1.4.4 // indirect + github.com/timakin/bodyclose v0.0.0-20200424151742-cb6215831a94 // indirect + github.com/tomarrell/wrapcheck v0.0.0-20201130113247-1683564d9756 // indirect + github.com/tommy-muehle/go-mnd/v2 v2.3.1 // indirect + github.com/ultraware/funlen v0.0.3 // indirect + github.com/ultraware/whitespace v0.0.4 // indirect + github.com/uudashr/gocognit v1.0.1 // indirect + go.opencensus.io v0.22.6 // indirect + golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad // indirect + golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 // indirect + golang.org/x/mod v0.4.1 // indirect + golang.org/x/net v0.0.0-20210119194325-5f4716e94777 // indirect + golang.org/x/oauth2 v0.0.0-20210126194326-f9ce19ea3013 // indirect + golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c // indirect + golang.org/x/text v0.3.5 // indirect + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 // indirect + golang.org/x/tools v0.1.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/api v0.38.0 // indirect + google.golang.org/appengine v1.6.7 // indirect + google.golang.org/genproto v0.0.0-20210126160654-44e461bb6506 // indirect + google.golang.org/grpc v1.35.0 // indirect + google.golang.org/protobuf v1.25.0 // indirect + gopkg.in/ini.v1 v1.62.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + honnef.co/go/tools v0.1.2 // indirect + k8s.io/apimachinery v0.20.2 // indirect + k8s.io/client-go v0.0.0-20210217172142-7279fc64d847 // indirect + k8s.io/klog/v2 v2.5.0 // indirect + mvdan.cc/gofumpt v0.1.0 // indirect + mvdan.cc/interfacer v0.0.0-20180901003855-c20040233aed // indirect + mvdan.cc/lint v0.0.0-20170908181259-adc824a0674b // indirect + mvdan.cc/unparam v0.0.0-20210104141923-aac4ce9116a7 // indirect ) replace github.com/vektra/mockery => github.com/enghabu/mockery v0.0.0-20191009061720-9d0c8670c2f0 diff --git a/boilerplate/flyte/golang_support_tools/tools.go b/boilerplate/flyte/golang_support_tools/tools.go index eee691d8c..d970d2106 100644 --- a/boilerplate/flyte/golang_support_tools/tools.go +++ b/boilerplate/flyte/golang_support_tools/tools.go @@ -7,5 +7,5 @@ import ( _ "github.com/flyteorg/flytestdlib/cli/pflags" _ "github.com/golangci/golangci-lint/cmd/golangci-lint" _ "github.com/vektra/mockery/cmd/mockery" - - "github.com/pseudomuto/protoc-gen-doc/cmd/protoc-gen-doc" + _ "github.com/pseudomuto/protoc-gen-doc/cmd/protoc-gen-doc" ) diff --git a/boilerplate/flyte/golang_test_targets/go-gen.sh b/boilerplate/flyte/golang_test_targets/go-gen.sh new file mode 100755 index 000000000..54bd6af61 --- /dev/null +++ b/boilerplate/flyte/golang_test_targets/go-gen.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -ex + +echo "Running go generate" +go generate ./... + +# This section is used by GitHub workflow to ensure that the generation step was run +if [ -n "$DELTA_CHECK" ]; then + DIRTY=$(git status --porcelain) + if [ -n "$DIRTY" ]; then + echo "FAILED: Go code updated without commiting generated code." + echo "Ensure make generate has run and all changes are committed." + DIFF=$(git diff) + echo "diff detected: $DIFF" + DIFF=$(git diff --name-only) + echo "files different: $DIFF" + exit 1 + else + echo "SUCCESS: Generated code is up to date." + fi +fi diff --git a/go.mod b/go.mod index d979e2ea0..8e88e45ec 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.0.0 github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v0.21.11 + github.com/flyteorg/flyteidl v0.21.23 github.com/flyteorg/flytestdlib v0.4.7 github.com/go-logr/zapr v0.4.0 // indirect github.com/go-test/deep v1.0.7 diff --git a/go.sum b/go.sum index d4b5eb510..e264258a2 100644 --- a/go.sum +++ b/go.sum @@ -325,8 +325,8 @@ github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= -github.com/flyteorg/flyteidl v0.21.11 h1:oH9YPoR7scO9GFF/I8D0gCTOB+JP5HRK7b7cLUBRz90= -github.com/flyteorg/flyteidl v0.21.11/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= +github.com/flyteorg/flyteidl v0.21.23 h1:hzGIFNOt3VooW/NdnaicXijn3EKjNKTz1kY+tlHkED4= +github.com/flyteorg/flyteidl v0.21.23/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= github.com/flyteorg/flytestdlib v0.4.7 h1:SMPPXI3j/MjP7D2fqaR+lPQkTrqYS7xZbwsgJI2F8SU= github.com/flyteorg/flytestdlib v0.4.7/go.mod h1:fv1ar34LJLMTaf0tbfetisLykUlARi7rP+NQTUn6QQs= diff --git a/go/tasks/config_load_test.go b/go/tasks/config_load_test.go index b827cda1c..5cbad7122 100755 --- a/go/tasks/config_load_test.go +++ b/go/tasks/config_load_test.go @@ -87,6 +87,8 @@ func TestLoadConfig(t *testing.T) { assert.NotNil(t, k8sConfig.DefaultSecurityContext) assert.NotNil(t, k8sConfig.DefaultSecurityContext.AllowPrivilegeEscalation) assert.False(t, *k8sConfig.DefaultSecurityContext.AllowPrivilegeEscalation) + assert.NotNil(t, k8sConfig.EnableHostNetworkingPod) + assert.True(t, *k8sConfig.EnableHostNetworkingPod) }) t.Run("logs-config-test", func(t *testing.T) { diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index fe4aed060..4e0c15dc5 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -4,8 +4,6 @@ import ( "fmt" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" structpb "github.com/golang/protobuf/ptypes/struct" ) @@ -69,6 +67,18 @@ func (p Phase) IsWaitingForResources() bool { return p == PhaseWaitingForResources } +type ExternalResource struct { + // A unique identifier for the external resource + ExternalID string + // A unique index for the external resource. Although the ID may change, this will remain the same + // throughout task event reports and retries. + Index uint32 + // The nubmer of times this external resource has been attempted + RetryAttempt uint32 + // Phase (if exists) associated with the external resource + Phase Phase +} + type TaskInfo struct { // log information for the task execution Logs []*core.TaskLog @@ -77,8 +87,8 @@ type TaskInfo struct { OccurredAt *time.Time // Custom Event information that the plugin would like to expose to the front-end CustomInfo *structpb.Struct - // Metadata around how a task was executed - Metadata *event.TaskExecutionMetadata + // A collection of information about external resources launched by this task + ExternalResources []*ExternalResource } func (t *TaskInfo) String() string { diff --git a/go/tasks/pluginmachinery/flytek8s/config/config.go b/go/tasks/pluginmachinery/flytek8s/config/config.go index 827db81cb..e53fad473 100755 --- a/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -139,6 +139,11 @@ type K8sPluginConfig struct { // DefaultSecurityContext provides a default container security context that should be applied for the primary container launched and created by FlytePropeller. This may not be applicable to all plugins. For // // downstream plugins - i.e. TensorflowOperators may not support setting this, but Spark does. DefaultSecurityContext *v1.SecurityContext `json:"default-security-context" pflag:"-,Optionally specify a default security context that should be applied to every container launched/created by FlytePropeller. This will not be applied to plugins that do not support it or to user supplied containers in pod tasks."` + + // EnableHostNetworkingPod is a binary switch to enable `hostNetwork: true` for all pods launched by Flyte. + // Refer to - https://kubernetes.io/docs/concepts/policy/pod-security-policy/#host-namespaces. + // As a follow up, the default pod configurations will now be adjusted using podTemplates per namespace + EnableHostNetworkingPod *bool `json:"enable-host-networking-pod" pflag:"-,If true, will schedule all pods with hostNetwork: true."` } // FlyteCoPilotConfig specifies configuration for the Flyte CoPilot system. FlyteCoPilot, allows running flytekit-less containers diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 8841cdfe2..039e75fba 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -91,6 +91,9 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut if podSpec.SecurityContext == nil && config.GetK8sPluginConfig().DefaultPodSecurityContext != nil { podSpec.SecurityContext = config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy() } + if config.GetK8sPluginConfig().EnableHostNetworkingPod != nil { + podSpec.HostNetwork = *config.GetK8sPluginConfig().EnableHostNetworkingPod + } ApplyInterruptibleNodeAffinity(isInterruptible, podSpec) } diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index cca1f3a6a..ee4b8e142 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -494,6 +494,36 @@ func TestToK8sPod(t *testing.T) { assert.NotNil(t, p.SecurityContext) assert.Equal(t, *p.SecurityContext.RunAsGroup, v) }) + + t.Run("enableHostNetwork", func(t *testing.T) { + enabled := true + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + EnableHostNetworkingPod: &enabled, + })) + x := dummyExecContext(&v1.ResourceRequirements{}) + p, err := ToK8sPodSpec(ctx, x) + assert.NoError(t, err) + assert.True(t, p.HostNetwork) + }) + + t.Run("explicitDisableHostNetwork", func(t *testing.T) { + enabled := false + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + EnableHostNetworkingPod: &enabled, + })) + x := dummyExecContext(&v1.ResourceRequirements{}) + p, err := ToK8sPodSpec(ctx, x) + assert.NoError(t, err) + assert.False(t, p.HostNetwork) + }) + + t.Run("skipSettingHostNetwork", func(t *testing.T) { + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) + x := dummyExecContext(&v1.ResourceRequirements{}) + p, err := ToK8sPodSpec(ctx, x) + assert.NoError(t, err) + assert.False(t, p.HostNetwork) + }) } func TestDemystifyPending(t *testing.T) { diff --git a/go/tasks/pluginmachinery/webapi/example/plugin.go b/go/tasks/pluginmachinery/webapi/example/plugin.go index 401cb9064..2c300e92f 100644 --- a/go/tasks/pluginmachinery/webapi/example/plugin.go +++ b/go/tasks/pluginmachinery/webapi/example/plugin.go @@ -4,8 +4,6 @@ import ( "context" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/errors" @@ -96,11 +94,9 @@ func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase co }, }, OccurredAt: &tNow, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "abc", - }, + ExternalResources: []*core.ExternalResource{ + { + ExternalID: "abc", }, }, }), nil diff --git a/go/tasks/plugins/array/awsbatch/launcher.go b/go/tasks/plugins/array/awsbatch/launcher.go index 32f2c5d1d..8959af267 100644 --- a/go/tasks/plugins/array/awsbatch/launcher.go +++ b/go/tasks/plugins/array/awsbatch/launcher.go @@ -6,6 +6,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/logger" arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" @@ -53,6 +54,13 @@ func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchCl } metrics.SubTasksSubmitted.Add(ctx, float64(size)) + + retryAttemptsArray, err := bitarray.NewCompactArray(uint(size), bitarray.Item(pluginConfig.MaxRetries)) + if err != nil { + logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", size, pluginConfig.MaxRetries) + return nil, err + } + parentState := currentState. SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, 0). SetArrayStatus(arraystatus.ArrayStatus{ @@ -61,7 +69,8 @@ func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchCl }, Detailed: arrayCore.NewPhasesCompactArray(uint(size)), }). - SetReason("Successfully launched subtasks.") + SetReason("Successfully launched subtasks."). + SetRetryAttempts(retryAttemptsArray) nextState = currentState.SetExternalJobID(j) nextState.State = parentState diff --git a/go/tasks/plugins/array/awsbatch/launcher_test.go b/go/tasks/plugins/array/awsbatch/launcher_test.go index 5520a9b3d..d135500b1 100644 --- a/go/tasks/plugins/array/awsbatch/launcher_test.go +++ b/go/tasks/plugins/array/awsbatch/launcher_test.go @@ -3,6 +3,7 @@ package awsbatch import ( "testing" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/mock" @@ -110,6 +111,9 @@ func TestLaunchSubTasks(t *testing.T) { JobDefinitionArn: "arn", } + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + expectedState := &State{ State: &core2.State{ CurrentPhase: core2.PhaseCheckingSubTaskExecutions, @@ -123,6 +127,7 @@ func TestLaunchSubTasks(t *testing.T) { }, Detailed: arrayCore.NewPhasesCompactArray(5), }, + RetryAttempts: retryAttemptsArray, }, ExternalJobID: refStr("qpxyarq"), diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 87ef4ade7..a7d033aa1 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -112,6 +112,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(actualPhase)) newArrayStatus.Summary.Inc(actualPhase) + parentState.RetryAttempts.SetItem(childIdx, bitarray.Item(len(subJob.Attempts))) } if queued > 0 { diff --git a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index b54af4465..621512b4a 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -136,6 +136,9 @@ func TestCheckSubTasksState(t *testing.T) { inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) + retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1)) + assert.NoError(t, err) + newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, @@ -146,6 +149,7 @@ func TestCheckSubTasksState(t *testing.T) { Detailed: arrayCore.NewPhasesCompactArray(1), }, IndexesToCache: bitarray.NewBitSet(1), + RetryAttempts: retryAttemptsArray, }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", @@ -180,6 +184,9 @@ func TestCheckSubTasksState(t *testing.T) { inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) + retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) + assert.NoError(t, err) + newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, @@ -190,6 +197,7 @@ func TestCheckSubTasksState(t *testing.T) { Detailed: arrayCore.NewPhasesCompactArray(2), }, IndexesToCache: bitarray.NewBitSet(2), + RetryAttempts: retryAttemptsArray, }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index 3703d6480..a8908fec0 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -5,8 +5,6 @@ import ( "fmt" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" @@ -52,6 +50,9 @@ type State struct { // Which sub-tasks to cache, (using the original index, that is, the length is ArrayJob.size) IndexesToCache *bitarray.BitSet `json:"indexesToCache"` + + // Tracks the number of subtask retries using the execution index + RetryAttempts bitarray.CompactArray `json:"retryAttempts"` } func (s State) GetReason() string { @@ -111,6 +112,11 @@ func (s *State) SetReason(reason string) *State { return s } +func (s *State) SetRetryAttempts(retryAttempts bitarray.CompactArray) *State { + s.RetryAttempts = retryAttempts + return s +} + func (s *State) SetExecutionArraySize(size int) *State { s.ExecutionArraySize = size return s @@ -171,20 +177,24 @@ func GetPhaseVersionOffset(currentPhase Phase, length int64) uint32 { // handling as we don't have to keep an ever growing list of log links (our batch jobs can be 5000 sub-tasks, keeping // all the log links takes up a lot of space). func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idlCore.TaskLog, subTaskIDs []*string) (core.PhaseInfo, error) { - phaseInfo := core.PhaseInfoUndefined t := time.Now() + nowTaskInfo := &core.TaskInfo{ - OccurredAt: &t, - Logs: logLinks, - } - if nowTaskInfo.Metadata == nil { - nowTaskInfo.Metadata = &event.TaskExecutionMetadata{} + OccurredAt: &t, + Logs: logLinks, + ExternalResources: make([]*core.ExternalResource, len(subTaskIDs)), } - for _, subTaskID := range subTaskIDs { - nowTaskInfo.Metadata.ExternalResources = append(nowTaskInfo.Metadata.ExternalResources, &event.ExternalResourceInfo{ - ExternalId: *subTaskID, - }) + + for childIndex, subTaskID := range subTaskIDs { + originalIndex := CalculateOriginalIndex(childIndex, state.GetIndexesToCache()) + + nowTaskInfo.ExternalResources[childIndex] = &core.ExternalResource{ + ExternalID: *subTaskID, + Index: uint32(originalIndex), + RetryAttempt: uint32(state.RetryAttempts.GetItem(childIndex)), + Phase: core.Phases[state.ArrayStatus.Detailed.GetItem(childIndex)], + } } switch p, version := state.GetPhase(); p { diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 876612623..1f477a874 100644 --- a/go/tasks/plugins/array/core/state_test.go +++ b/go/tasks/plugins/array/core/state_test.go @@ -5,14 +5,13 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/golang/protobuf/proto" "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" "github.com/stretchr/testify/assert" ) @@ -51,30 +50,43 @@ func assertBitSetsEqual(t testing.TB, b1, b2 *bitarray.BitSet, len int) { } } -func assertTaskExecutionMetadata(t *testing.T, subTaskIDs []*string, metadata *event.TaskExecutionMetadata) { - assert.NotNil(t, metadata) - var externalResources = make([]*event.ExternalResourceInfo, len(subTaskIDs)) +func assertTaskExternalResources(t *testing.T, subTaskIDs []*string, retryAttemptsArray *bitarray.CompactArray, detailedArray *bitarray.CompactArray, externalResources []*core.ExternalResource) { + assert.NotNil(t, externalResources) for i, subTaskID := range subTaskIDs { - externalResources[i] = &event.ExternalResourceInfo{ - ExternalId: *subTaskID, - } + externalResource := externalResources[i] + assert.Equal(t, *subTaskID, externalResource.ExternalID) + assert.Equal(t, retryAttemptsArray.GetItem(i), bitarray.Item(externalResource.RetryAttempt)) + assert.Equal(t, core.Phases[detailedArray.GetItem(i)], externalResource.Phase) } - assert.True(t, proto.Equal(&event.TaskExecutionMetadata{ - ExternalResources: externalResources, - }, metadata)) } func TestMapArrayStateToPluginPhase(t *testing.T) { ctx := context.Background() - var subTaskIDs = make([]*string, 3) - for i := 0; i < 3; i++ { + + subTaskCount := 3 + + var subTaskIDs = make([]*string, subTaskCount) + detailedArray := NewPhasesCompactArray(uint(subTaskCount)) + indexesToCache := InvertBitSet(bitarray.NewBitSet(uint(subTaskCount)), uint(subTaskCount)) + retryAttemptsArray, err := bitarray.NewCompactArray(uint(subTaskCount), bitarray.Item(1)) + assert.NoError(t, err) + + for i := 0; i < subTaskCount; i++ { subTaskID := fmt.Sprintf("sub_task_%d", i) subTaskIDs[i] = &subTaskID + + detailedArray.SetItem(i, bitarray.Item(core.PhaseRunning)) + retryAttemptsArray.SetItem(i, bitarray.Item(1)) } t.Run("start", func(t *testing.T) { s := State{ CurrentPhase: PhaseStart, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) @@ -85,6 +97,11 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { s := State{ CurrentPhase: PhaseLaunch, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) @@ -98,13 +115,18 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 8, OriginalArraySize: 10, ExecutionArraySize: 5, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) assert.Equal(t, uint32(368), phaseInfo.Version()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("write to discovery", func(t *testing.T) { @@ -113,55 +135,80 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 8, OriginalArraySize: 10, ExecutionArraySize: 5, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) assert.Equal(t, uint32(548), phaseInfo.Version()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("success", func(t *testing.T) { s := State{ CurrentPhase: PhaseSuccess, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseSuccess, phaseInfo.Phase()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("retryable failure", func(t *testing.T) { s := State{ CurrentPhase: PhaseRetryableFailure, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRetryableFailure, phaseInfo.Phase()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("permanent failure", func(t *testing.T) { s := State{ CurrentPhase: PhasePermanentFailure, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhasePermanentFailure, phaseInfo.Phase()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("All phases", func(t *testing.T) { for _, p := range PhaseValues() { s := State{ CurrentPhase: p, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) diff --git a/go/tasks/plugins/array/k8s/monitor.go b/go/tasks/plugins/array/k8s/monitor.go index cea135181..f7d1bfcfa 100644 --- a/go/tasks/plugins/array/k8s/monitor.go +++ b/go/tasks/plugins/array/k8s/monitor.go @@ -61,6 +61,30 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon currentState.ArrayStatus = *newArrayStatus } + // If the current State is newly minted then we must initialize RetryAttempts to track how many + // times each subtask is executed. + if len(currentState.RetryAttempts.GetItems()) == 0 { + count := uint(currentState.GetExecutionArraySize()) + maxValue := bitarray.Item(tCtx.TaskExecutionMetadata().GetMaxAttempts()) + + retryAttemptsArray, err := bitarray.NewCompactArray(count, maxValue) + if err != nil { + logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue) + return currentState, logLinks, subTaskIDs, nil + } + + // Currently if any subtask fails then all subtasks are retried up to MaxAttempts. Therefore, all + // subtasks have an identical RetryAttempt, namely that of the map task execution metadata. Once + // retries over individual subtasks are implemented we should revisit this logic and instead + // increment the RetryAttempt for each subtask everytime a new pod is created. + retryAttempt := bitarray.Item(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt) + for i := 0; i < currentState.GetExecutionArraySize(); i++ { + retryAttemptsArray.SetItem(i, retryAttempt) + } + + currentState.RetryAttempts = retryAttemptsArray + } + logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config) if err != nil { logger.Errorf(ctx, "Error initializing LogPlugins: [%s]", err) diff --git a/go/tasks/plugins/array/k8s/monitor_test.go b/go/tasks/plugins/array/k8s/monitor_test.go index 8f7c3414c..fd7dc7d0e 100644 --- a/go/tasks/plugins/array/k8s/monitor_test.go +++ b/go/tasks/plugins/array/k8s/monitor_test.go @@ -75,6 +75,7 @@ func getMockTaskExecutionContext(ctx context.Context) *mocks.TaskExecutionContex tMeta.OnIsInterruptible().Return(false) tMeta.OnGetK8sServiceAccount().Return("s") + tMeta.OnGetMaxAttempts().Return(2) tMeta.OnGetNamespace().Return("n") tMeta.OnGetLabels().Return(nil) tMeta.OnGetAnnotations().Return(nil) @@ -194,6 +195,9 @@ func TestCheckSubTasksState(t *testing.T) { }, } + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -203,6 +207,7 @@ func TestCheckSubTasksState(t *testing.T) { Detailed: arrayCore.NewPhasesCompactArray(uint(5)), }, IndexesToCache: bitarray.NewBitSet(5), + RetryAttempts: retryAttemptsArray, }) assert.Nil(t, err) @@ -236,6 +241,9 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { }, } + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + cacheIndexes := bitarray.NewBitSet(5) newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, @@ -246,6 +254,7 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { ArrayStatus: arraystatus.ArrayStatus{ Detailed: arrayCore.NewPhasesCompactArray(uint(5)), }, + RetryAttempts: retryAttemptsArray, }) assert.Nil(t, err) @@ -273,6 +282,9 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { } cacheIndexes := bitarray.NewBitSet(5) + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -280,6 +292,7 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { OriginalMinSuccesses: 5, ArrayStatus: *arrayStatus, IndexesToCache: cacheIndexes, + RetryAttempts: retryAttemptsArray, }) assert.Nil(t, err) diff --git a/go/tasks/plugins/hive/execution_state.go b/go/tasks/plugins/hive/execution_state.go index cbc45cc06..ed6f17cfb 100644 --- a/go/tasks/plugins/hive/execution_state.go +++ b/go/tasks/plugins/hive/execution_state.go @@ -6,8 +6,6 @@ import ( "strconv" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -149,20 +147,18 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { logs := make([]*idlCore.TaskLog, 0, 1) t := time.Now() - metadata := &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: e.CommandID, - }, + externalResources := []*core.ExternalResource{ + { + ExternalID: e.CommandID, }, } if e.CommandID != "" { logs = append(logs, ConstructTaskLog(e)) return &core.TaskInfo{ - Logs: logs, - OccurredAt: &t, - Metadata: metadata, + Logs: logs, + OccurredAt: &t, + ExternalResources: externalResources, } } diff --git a/go/tasks/plugins/hive/execution_state_test.go b/go/tasks/plugins/hive/execution_state_test.go index cd5cd868e..749e23b46 100644 --- a/go/tasks/plugins/hive/execution_state_test.go +++ b/go/tasks/plugins/hive/execution_state_test.go @@ -7,9 +7,6 @@ import ( "testing" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ioMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" @@ -128,13 +125,8 @@ func TestConstructTaskInfo(t *testing.T) { taskInfo := ConstructTaskInfo(e) assert.Equal(t, "https://wellness.qubole.com/v2/analyze?command_id=123", taskInfo.Logs[0].Uri) - assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "123", - }, - }, - })) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "123") } func TestMapExecutionStateToPhaseInfo(t *testing.T) { diff --git a/go/tasks/plugins/k8s/container/container.go b/go/tasks/plugins/k8s/container/container.go deleted file mode 100755 index 74ab6353b..000000000 --- a/go/tasks/plugins/k8s/container/container.go +++ /dev/null @@ -1,90 +0,0 @@ -package container - -import ( - "context" - - "sigs.k8s.io/controller-runtime/pkg/client" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - - v1 "k8s.io/api/core/v1" - - "github.com/flyteorg/flyteplugins/go/tasks/logs" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" -) - -const ( - containerTaskType = "container" -) - -type Plugin struct { -} - -func (Plugin) GetProperties() k8s.PluginProperties { - return k8s.PluginProperties{} -} - -func (Plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { - - pod := r.(*v1.Pod) - - t := flytek8s.GetLastTransitionOccurredAt(pod).Time - info := pluginsCore.TaskInfo{ - OccurredAt: &t, - } - if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { - taskLogs, err := logs.GetLogsForContainerInPod(ctx, pod, 0, " (User)") - if err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - info.Logs = taskLogs - } - switch pod.Status.Phase { - case v1.PodSucceeded: - return flytek8s.DemystifySuccess(pod.Status, info) - case v1.PodFailed: - code, message := flytek8s.ConvertPodFailureToError(pod.Status) - return pluginsCore.PhaseInfoRetryableFailure(code, message, &info), nil - case v1.PodPending: - return flytek8s.DemystifyPending(pod.Status) - case v1.PodUnknown: - return pluginsCore.PhaseInfoUndefined, nil - } - if len(info.Logs) > 0 { - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion+1, &info), nil - } - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &info), nil -} - -// BuildResource creates a new Pod that will Exit on completion. The pods have no retries by design -func (Plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { - - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) - if err != nil { - return nil, err - } - - pod := flytek8s.BuildPodWithSpec(podSpec) - - pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - - return pod, nil -} - -func (Plugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) (client.Object, error) { - return flytek8s.BuildIdentityPod(), nil -} - -func init() { - pluginmachinery.PluginRegistry().RegisterK8sPlugin( - k8s.PluginEntry{ - ID: containerTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{containerTaskType}, - ResourceToWatch: &v1.Pod{}, - Plugin: Plugin{}, - IsDefault: true, - DefaultForTaskTypes: []pluginsCore.TaskType{containerTaskType}, - }) -} diff --git a/go/tasks/plugins/k8s/pod/container.go b/go/tasks/plugins/k8s/pod/container.go new file mode 100644 index 000000000..197ff502f --- /dev/null +++ b/go/tasks/plugins/k8s/pod/container.go @@ -0,0 +1,32 @@ +package pod + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + + v1 "k8s.io/api/core/v1" +) + +const ( + containerTaskType = "container" +) + +type containerPodBuilder struct { +} + +func (containerPodBuilder) buildPodSpec(ctx context.Context, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) { + podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { + return nil, err + } + + return podSpec, nil +} + +func (containerPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error { + return nil +} diff --git a/go/tasks/plugins/k8s/container/container_test.go b/go/tasks/plugins/k8s/pod/container_test.go old mode 100755 new mode 100644 similarity index 88% rename from go/tasks/plugins/k8s/container/container_test.go rename to go/tasks/plugins/k8s/pod/container_test.go index 0d1ca451c..3832b00f3 --- a/go/tasks/plugins/k8s/container/container_test.go +++ b/go/tasks/plugins/k8s/pod/container_test.go @@ -1,29 +1,29 @@ -package container +package pod import ( "context" "fmt" "testing" - "github.com/stretchr/testify/mock" - - "k8s.io/apimachinery/pkg/types" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/stretchr/testify/assert" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + v1 "k8s.io/api/core/v1" + + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" ) -var resourceRequirements = &v1.ResourceRequirements{ +var containerResourceRequirements = &v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceStorage: resource.MustParse("100M"), @@ -106,9 +106,9 @@ func dummyContainerTaskContext(resources *v1.ResourceRequirements, command []str } func TestContainerTaskExecutor_BuildIdentityResource(t *testing.T) { - c := Plugin{} + p := plugin{defaultPodBuilder, podBuilders} taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{} - r, err := c.BuildIdentityResource(context.TODO(), taskMetadata) + r, err := p.BuildIdentityResource(context.TODO(), taskMetadata) assert.NoError(t, err) assert.NotNil(t, r) _, ok := r.(*v1.Pod) @@ -117,19 +117,19 @@ func TestContainerTaskExecutor_BuildIdentityResource(t *testing.T) { } func TestContainerTaskExecutor_BuildResource(t *testing.T) { - c := Plugin{} + p := plugin{defaultPodBuilder, podBuilders} command := []string{"command"} args := []string{"{{.Input}}"} - taskCtx := dummyContainerTaskContext(resourceRequirements, command, args) + taskCtx := dummyContainerTaskContext(containerResourceRequirements, command, args) - r, err := c.BuildResource(context.TODO(), taskCtx) + r, err := p.BuildResource(context.TODO(), taskCtx) assert.NoError(t, err) assert.NotNil(t, r) j, ok := r.(*v1.Pod) assert.True(t, ok) assert.NotEmpty(t, j.Spec.Containers) - assert.Equal(t, resourceRequirements.Limits[v1.ResourceCPU], j.Spec.Containers[0].Resources.Limits[v1.ResourceCPU]) + assert.Equal(t, containerResourceRequirements.Limits[v1.ResourceCPU], j.Spec.Containers[0].Resources.Limits[v1.ResourceCPU]) // TODO: Once configurable, test when setting storage is supported on the cluster vs not. storageRes := j.Spec.Containers[0].Resources.Limits[v1.ResourceStorage] @@ -142,7 +142,7 @@ func TestContainerTaskExecutor_BuildResource(t *testing.T) { } func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { - c := Plugin{} + p := plugin{defaultPodBuilder, podBuilders} j := &v1.Pod{ Status: v1.PodStatus{}, } @@ -150,21 +150,21 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { ctx := context.TODO() t.Run("running", func(t *testing.T) { j.Status.Phase = v1.PodRunning - phaseInfo, err := c.GetTaskPhase(ctx, nil, j) + phaseInfo, err := p.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase()) }) t.Run("queued", func(t *testing.T) { j.Status.Phase = v1.PodPending - phaseInfo, err := c.GetTaskPhase(ctx, nil, j) + phaseInfo, err := p.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, phaseInfo.Phase()) }) t.Run("failNoCondition", func(t *testing.T) { j.Status.Phase = v1.PodFailed - phaseInfo, err := c.GetTaskPhase(ctx, nil, j) + phaseInfo, err := p.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) ec := phaseInfo.Err().GetCode() @@ -180,7 +180,7 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { Type: v1.PodReasonUnschedulable, }, } - phaseInfo, err := c.GetTaskPhase(ctx, nil, j) + phaseInfo, err := p.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) ec := phaseInfo.Err().GetCode() @@ -189,7 +189,7 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { t.Run("success", func(t *testing.T) { j.Status.Phase = v1.PodSucceeded - phaseInfo, err := c.GetTaskPhase(ctx, nil, j) + phaseInfo, err := p.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.NotNil(t, phaseInfo) assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase()) @@ -197,14 +197,14 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { } func TestContainerTaskExecutor_GetProperties(t *testing.T) { - plugin := Plugin{} + p := plugin{defaultPodBuilder, podBuilders} expected := k8s.PluginProperties{} - assert.Equal(t, expected, plugin.GetProperties()) + assert.Equal(t, expected, p.GetProperties()) } func TestContainerTaskExecutor_GetTaskStatus_InvalidImageName(t *testing.T) { ctx := context.TODO() - c := Plugin{} + p := plugin{defaultPodBuilder, podBuilders} reason := "InvalidImageName" message := "Failed to apply default image tag \"TEST/flyteorg/myapp:latest\": couldn't parse image reference" + " \"TEST/flyteorg/myapp:latest\": invalid reference format: repository name must be lowercase" @@ -235,7 +235,7 @@ func TestContainerTaskExecutor_GetTaskStatus_InvalidImageName(t *testing.T) { t.Run("failInvalidImageName", func(t *testing.T) { pendingPod.Status.Phase = v1.PodPending - phaseInfo, err := c.GetTaskPhase(ctx, nil, pendingPod) + phaseInfo, err := p.GetTaskPhase(ctx, nil, pendingPod) finalReason := fmt.Sprintf("|%s", reason) finalMessage := fmt.Sprintf("|%s", message) assert.NoError(t, err) diff --git a/go/tasks/plugins/k8s/pod/plugin.go b/go/tasks/plugins/k8s/pod/plugin.go new file mode 100644 index 000000000..8bc480042 --- /dev/null +++ b/go/tasks/plugins/k8s/pod/plugin.go @@ -0,0 +1,171 @@ +package pod + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flyteplugins/go/tasks/logs" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + + v1 "k8s.io/api/core/v1" + + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + podTaskType = "pod" + primaryContainerKey = "primary_container_name" +) + +var ( + defaultPodBuilder = containerPodBuilder{} + podBuilders = map[string]podBuilder{ + sidecarTaskType: sidecarPodBuilder{}, + } +) + +type podBuilder interface { + buildPodSpec(ctx context.Context, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) + updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error +} + +type plugin struct { + defaultPodBuilder podBuilder + podBuilders map[string]podBuilder +} + +func (plugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) ( + client.Object, error) { + return flytek8s.BuildIdentityPod(), nil +} + +func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { + // read TaskTemplate + task, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "TaskSpecification cannot be read, Err: [%v]", err.Error()) + } + + // initialize PodBuilder + builder, exists := p.podBuilders[task.Type] + if !exists { + builder = p.defaultPodBuilder + } + + // build pod + podSpec, err := builder.buildPodSpec(ctx, task, taskCtx) + if err != nil { + return nil, err + } + + podSpec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + pod := flytek8s.BuildPodWithSpec(podSpec) + + // update pod metadata + if err = builder.updatePodMetadata(ctx, pod, task, taskCtx); err != nil { + return nil, err + } + + return pod, nil +} + +func (plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { + pod := r.(*v1.Pod) + + transitionOccurredAt := flytek8s.GetLastTransitionOccurredAt(pod).Time + info := pluginsCore.TaskInfo{ + OccurredAt: &transitionOccurredAt, + } + + if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { + taskLogs, err := logs.GetLogsForContainerInPod(ctx, pod, 0, " (User)") + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + info.Logs = taskLogs + } + + switch pod.Status.Phase { + case v1.PodSucceeded: + return flytek8s.DemystifySuccess(pod.Status, info) + case v1.PodFailed: + code, message := flytek8s.ConvertPodFailureToError(pod.Status) + return pluginsCore.PhaseInfoRetryableFailure(code, message, &info), nil + case v1.PodPending: + return flytek8s.DemystifyPending(pod.Status) + case v1.PodReasonUnschedulable: + return pluginsCore.PhaseInfoQueued(transitionOccurredAt, pluginsCore.DefaultPhaseVersion, "pod unschedulable"), nil + case v1.PodUnknown: + return pluginsCore.PhaseInfoUndefined, nil + } + + primaryContainerName, exists := r.GetAnnotations()[primaryContainerKey] + if !exists { + // if the primary container annotation dos not exist, then the task requires all containers + // to succeed to declare success. therefore, if the pod is not in one of the above states we + // fallback to declaring the task as 'running'. + if len(info.Logs) > 0 { + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion+1, &info), nil + } + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &info), nil + } + + // if the primary container annotation exists, we use the status of the specified container + primaryContainerPhase := flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &info) + if primaryContainerPhase.Phase() == pluginsCore.PhaseRunning && len(info.Logs) > 0 { + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion+1, primaryContainerPhase.Info()), nil + } + return primaryContainerPhase, nil +} + +func (plugin) GetProperties() k8s.PluginProperties { + return k8s.PluginProperties{} +} + +func init() { + podPlugin := plugin{ + defaultPodBuilder: defaultPodBuilder, + podBuilders: podBuilders, + } + + // Register containerTaskType and sidecarTaskType plugin entries. These separate task types + // still exist within the system, only now both are evaluated using the same internal pod plugin + // instance. This simplifies migration as users may keep the same configuration but are + // seamlessly transitioned from separate container and sidecar plugins to a single pod plugin. + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: containerTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{containerTaskType}, + ResourceToWatch: &v1.Pod{}, + Plugin: podPlugin, + IsDefault: true, + DefaultForTaskTypes: []pluginsCore.TaskType{containerTaskType}, + }) + + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: sidecarTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{sidecarTaskType}, + ResourceToWatch: &v1.Pod{}, + Plugin: podPlugin, + IsDefault: false, + DefaultForTaskTypes: []pluginsCore.TaskType{sidecarTaskType}, + }) + + // register podTaskType plugin entry + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: podTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{containerTaskType, sidecarTaskType}, + ResourceToWatch: &v1.Pod{}, + Plugin: podPlugin, + IsDefault: true, + DefaultForTaskTypes: []pluginsCore.TaskType{containerTaskType, sidecarTaskType}, + }) +} diff --git a/go/tasks/plugins/k8s/pod/sidecar.go b/go/tasks/plugins/k8s/pod/sidecar.go new file mode 100644 index 000000000..ba208a8a8 --- /dev/null +++ b/go/tasks/plugins/k8s/pod/sidecar.go @@ -0,0 +1,185 @@ +package pod + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + + v1 "k8s.io/api/core/v1" +) + +const ( + sidecarTaskType = "sidecar" +) + +// Why, you might wonder do we recreate the generated go struct generated from the plugins.SidecarJob proto? Because +// although we unmarshal the task custom json, the PodSpec itself is not generated from a proto definition, +// but a proper go struct defined in k8s libraries. Therefore we only unmarshal the sidecar as a json, rather than jsonpb. +type sidecarJob struct { + PodSpec *v1.PodSpec + PrimaryContainerName string + Annotations map[string]string + Labels map[string]string +} + +type sidecarPodBuilder struct { +} + +func (sidecarPodBuilder) buildPodSpec(ctx context.Context, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) { + var podSpec v1.PodSpec + switch task.TaskTypeVersion { + case 0: + // Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. + sidecarJob := sidecarJob{} + err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + + if sidecarJob.PodSpec == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, nil PodSpec [%v]", task.GetCustom()) + } + + podSpec = *sidecarJob.PodSpec + case 1: + // Handles pod tasks that marshal the pod spec to the task custom. + err := utils.UnmarshalStructToObj(task.GetCustom(), &podSpec) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + default: + // Handles pod tasks that marshal the pod spec to the k8s_pod task target. + if task.GetK8SPod() == nil || task.GetK8SPod().PodSpec == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Pod tasks with task type version > 1 should specify their target as a K8sPod with a defined pod spec") + } + + err := utils.UnmarshalStructToObj(task.GetK8SPod().PodSpec, &podSpec) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + } + + // Set the restart policy to *not* inherit from the default so that a completed pod doesn't get caught in a + // CrashLoopBackoff after the initial job completion. + podSpec.RestartPolicy = v1.RestartPolicyNever + + return &podSpec, nil +} + +func getPrimaryContainerNameFromConfig(task *core.TaskTemplate) (string, error) { + if len(task.GetConfig()) == 0 { + return "", errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", primaryContainerKey) + } + + primaryContainerName, ok := task.GetConfig()[primaryContainerKey] + if !ok { + return "", errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, config missing [%s] key in [%v]", primaryContainerKey, task.GetConfig()) + } + + return primaryContainerName, nil +} + +func mergeMapInto(src map[string]string, dst map[string]string) { + for key, value := range src { + dst[key] = value + } +} + +func (sidecarPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error { + pod.Annotations = make(map[string]string) + pod.Labels = make(map[string]string) + + var primaryContainerName string + switch task.TaskTypeVersion { + case 0: + // Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. + sidecarJob := sidecarJob{} + err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) + if err != nil { + return errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + + mergeMapInto(sidecarJob.Annotations, pod.Annotations) + mergeMapInto(sidecarJob.Labels, pod.Labels) + + primaryContainerName = sidecarJob.PrimaryContainerName + case 1: + // Handles pod tasks that marshal the pod spec to the task custom. + containerName, err := getPrimaryContainerNameFromConfig(task) + if err != nil { + return err + } + + primaryContainerName = containerName + default: + // Handles pod tasks that marshal the pod spec to the k8s_pod task target. + if task.GetK8SPod() == nil || task.GetK8SPod().Metadata != nil { + mergeMapInto(task.GetK8SPod().Metadata.Annotations, pod.Annotations) + mergeMapInto(task.GetK8SPod().Metadata.Labels, pod.Labels) + } + + containerName, err := getPrimaryContainerNameFromConfig(task) + if err != nil { + return err + } + + primaryContainerName = containerName + } + + // validate pod and update resource requirements + if err := validateAndFinalizePodSpec(ctx, taskCtx, primaryContainerName, &pod.Spec); err != nil { + return err + } + + pod.Annotations[primaryContainerKey] = primaryContainerName + return nil +} + +// This method handles templatizing primary container input args, env variables and adds a GPU toleration to the pod +// spec if necessary. +func validateAndFinalizePodSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string, podSpec *v1.PodSpec) error { + var hasPrimaryContainer bool + + resReqs := make([]v1.ResourceRequirements, 0, len(podSpec.Containers)) + for index, container := range podSpec.Containers { + var resourceMode = flytek8s.ResourceCustomizationModeEnsureExistingResourcesInRange + if container.Name == primaryContainerName { + hasPrimaryContainer = true + resourceMode = flytek8s.ResourceCustomizationModeMergeExistingResources + } + + templateParameters := template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + } + + err := flytek8s.AddFlyteCustomizationsToContainer(ctx, templateParameters, resourceMode, &podSpec.Containers[index]) + if err != nil { + return err + } + + resReqs = append(resReqs, container.Resources) + } + + if !hasPrimaryContainer { + return errors.Errorf(errors.BadTaskSpecification, "invalid Sidecar task, primary container [%s] not defined", primaryContainerName) + } + + flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), resReqs, podSpec) + return nil +} diff --git a/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/go/tasks/plugins/k8s/pod/sidecar_test.go old mode 100755 new mode 100644 similarity index 87% rename from go/tasks/plugins/k8s/sidecar/sidecar_test.go rename to go/tasks/plugins/k8s/pod/sidecar_test.go index c8d3057b0..77cb40afa --- a/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/go/tasks/plugins/k8s/pod/sidecar_test.go @@ -1,4 +1,4 @@ -package sidecar +package pod import ( "context" @@ -32,7 +32,7 @@ import ( const ResourceNvidiaGPU = "nvidia.com/gpu" -var resourceRequirements = &v1.ResourceRequirements{ +var sidecarResourceRequirements = &v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("2048m"), v1.ResourceEphemeralStorage: resource.MustParse("100M"), @@ -51,11 +51,12 @@ func getSidecarTaskTemplateForTest(sideCarJob sidecarJob) *core.TaskTemplate { panic(err) } return &core.TaskTemplate{ + Type: sidecarTaskType, Custom: &structObj, } } -func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.TaskExecutionMetadata { +func dummySidecarTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.TaskExecutionMetadata { taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{} taskMetadata.On("GetNamespace").Return("test-namespace") taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) @@ -96,7 +97,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. func getDummySidecarTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements) pluginsCore.TaskExecutionContext { taskCtx := &pluginsCoreMock.TaskExecutionContext{} - dummyTaskMetadata := dummyContainerTaskMetadata(resources) + dummyTaskMetadata := dummySidecarTaskMetadata(resources) inputReader := &pluginsIOMock.InputReader{} inputReader.OnGetInputPrefixPath().Return("test-data-prefix") inputReader.OnGetInputPath().Return("test-data-reference") @@ -198,6 +199,7 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { } task := core.TaskTemplate{ + Type: sidecarTaskType, TaskTypeVersion: 2, Config: map[string]string{ primaryContainerKey: "primary container", @@ -239,9 +241,9 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { DefaultMemoryRequest: resource.MustParse("1024Mi"), GpuResourceName: ResourceNvidiaGPU, })) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) - res, err := handler.BuildResource(context.TODO(), taskCtx) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) + res, err := p.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ primaryContainerKey: "primary container", @@ -282,6 +284,7 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { func TestBuildSidecarResource_TaskType2_Invalid_Spec(t *testing.T) { task := core.TaskTemplate{ + Type: sidecarTaskType, TaskTypeVersion: 2, Config: map[string]string{ primaryContainerKey: "primary container", @@ -300,9 +303,9 @@ func TestBuildSidecarResource_TaskType2_Invalid_Spec(t *testing.T) { }, } - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) - _, err := handler.BuildResource(context.TODO(), taskCtx) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) + _, err := p.BuildResource(context.TODO(), taskCtx) assert.EqualError(t, err, "[BadTaskSpecification] Pod tasks with task type version > 1 should specify their target as a K8sPod with a defined pod spec") } @@ -320,6 +323,7 @@ func TestBuildSidecarResource_TaskType1(t *testing.T) { } task := core.TaskTemplate{ + Type: sidecarTaskType, Custom: structObj, TaskTypeVersion: 1, Config: map[string]string{ @@ -348,9 +352,9 @@ func TestBuildSidecarResource_TaskType1(t *testing.T) { DefaultCPURequest: resource.MustParse("1024m"), DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) - res, err := handler.BuildResource(context.TODO(), taskCtx) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) + res, err := p.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ primaryContainerKey: "primary container", @@ -401,6 +405,7 @@ func TestBuildSideResource_TaskType1_InvalidSpec(t *testing.T) { } task := core.TaskTemplate{ + Type: sidecarTaskType, Custom: structObj, TaskTypeVersion: 1, } @@ -413,16 +418,16 @@ func TestBuildSideResource_TaskType1_InvalidSpec(t *testing.T) { DefaultCPURequest: resource.MustParse("1024m"), DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) - _, err = handler.BuildResource(context.TODO(), taskCtx) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) + _, err = p.BuildResource(context.TODO(), taskCtx) assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config needs to be non-empty and include missing [primary_container_name] key") task.Config = map[string]string{ "foo": "bar", } - taskCtx = getDummySidecarTaskContext(&task, resourceRequirements) - _, err = handler.BuildResource(context.TODO(), taskCtx) + taskCtx = getDummySidecarTaskContext(&task, sidecarResourceRequirements) + _, err = p.BuildResource(context.TODO(), taskCtx) assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config missing [primary_container_name] key in [map[foo:bar]]") } @@ -441,6 +446,7 @@ func TestBuildSidecarResource(t *testing.T) { t.Fatal(err) } task := core.TaskTemplate{ + Type: sidecarTaskType, Custom: &sidecarCustom, } @@ -465,9 +471,9 @@ func TestBuildSidecarResource(t *testing.T) { DefaultCPURequest: resource.MustParse("1024m"), DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) - res, err := handler.BuildResource(context.TODO(), taskCtx) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) + res, err := p.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ primaryContainerKey: "a container", @@ -522,9 +528,9 @@ func TestBuildSidecarReosurceMissingAnnotationsAndLabels(t *testing.T) { task := getSidecarTaskTemplateForTest(sideCarJob) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(task, resourceRequirements) - resp, err := handler.BuildResource(context.TODO(), taskCtx) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(task, sidecarResourceRequirements) + resp, err := p.BuildResource(context.TODO(), taskCtx) assert.NoError(t, err) assert.EqualValues(t, map[string]string{}, resp.GetLabels()) assert.EqualValues(t, map[string]string{"primary_container_name": "PrimaryContainer"}, resp.GetAnnotations()) @@ -544,9 +550,9 @@ func TestBuildSidecarResourceMissingPrimary(t *testing.T) { task := getSidecarTaskTemplateForTest(sideCarJob) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(task, resourceRequirements) - _, err := handler.BuildResource(context.TODO(), taskCtx) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(task, sidecarResourceRequirements) + _, err := p.BuildResource(context.TODO(), taskCtx) assert.True(t, errors.Is(err, errors2.Errorf("BadTaskSpecification", ""))) } @@ -580,9 +586,9 @@ func TestGetTaskSidecarStatus(t *testing.T) { res.SetAnnotations(map[string]string{ primaryContainerKey: "PrimaryContainer", }) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(task, resourceRequirements) - phaseInfo, err := handler.GetTaskPhase(context.TODO(), taskCtx, res) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(task, sidecarResourceRequirements) + phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, expectedTaskPhase, phaseInfo.Phase(), "Expected [%v] got [%v] instead, for podPhase [%v]", expectedTaskPhase, phaseInfo.Phase(), podPhase) @@ -608,9 +614,9 @@ func TestDemystifiedSidecarStatus_PrimaryFailed(t *testing.T) { res.SetAnnotations(map[string]string{ primaryContainerKey: "Primary", }) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, resourceRequirements) - phaseInfo, err := handler.GetTaskPhase(context.TODO(), taskCtx, res) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) + phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) } @@ -634,9 +640,9 @@ func TestDemystifiedSidecarStatus_PrimarySucceeded(t *testing.T) { res.SetAnnotations(map[string]string{ primaryContainerKey: "Primary", }) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, resourceRequirements) - phaseInfo, err := handler.GetTaskPhase(context.TODO(), taskCtx, res) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) + phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase()) } @@ -660,9 +666,9 @@ func TestDemystifiedSidecarStatus_PrimaryRunning(t *testing.T) { res.SetAnnotations(map[string]string{ primaryContainerKey: "Primary", }) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, resourceRequirements) - phaseInfo, err := handler.GetTaskPhase(context.TODO(), taskCtx, res) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) + phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase()) } @@ -681,15 +687,15 @@ func TestDemystifiedSidecarStatus_PrimaryMissing(t *testing.T) { res.SetAnnotations(map[string]string{ primaryContainerKey: "Primary", }) - handler := &sidecarResourceHandler{} - taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, resourceRequirements) - phaseInfo, err := handler.GetTaskPhase(context.TODO(), taskCtx, res) + p := &plugin{defaultPodBuilder, podBuilders} + taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) + phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhasePermanentFailure, phaseInfo.Phase()) } func TestGetProperties(t *testing.T) { - handler := &sidecarResourceHandler{} + p := &plugin{defaultPodBuilder, podBuilders} expected := k8s.PluginProperties{} - assert.Equal(t, expected, handler.GetProperties()) + assert.Equal(t, expected, p.GetProperties()) } diff --git a/go/tasks/plugins/k8s/sidecar/testdata/sidecar_custom b/go/tasks/plugins/k8s/pod/testdata/sidecar_custom similarity index 100% rename from go/tasks/plugins/k8s/sidecar/testdata/sidecar_custom rename to go/tasks/plugins/k8s/pod/testdata/sidecar_custom diff --git a/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go b/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go index df02c97a3..1634ff7b7 100644 --- a/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go +++ b/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go @@ -5,9 +5,6 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -273,12 +270,7 @@ func Test_awsSagemakerPlugin_getEventInfoForTrainingJob(t *testing.T) { if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } - assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "some-acceptable-name", - }, - }, - })) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "some-acceptable-name") }) } diff --git a/go/tasks/plugins/k8s/sagemaker/custom_training_test.go b/go/tasks/plugins/k8s/sagemaker/custom_training_test.go index 91709b507..748fa7d35 100644 --- a/go/tasks/plugins/k8s/sagemaker/custom_training_test.go +++ b/go/tasks/plugins/k8s/sagemaker/custom_training_test.go @@ -6,9 +6,6 @@ import ( "strconv" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -297,12 +294,7 @@ func Test_awsSagemakerPlugin_getEventInfoForCustomTrainingJob(t *testing.T) { if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } - assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "some-acceptable-name", - }, - }, - })) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "some-acceptable-name") }) } diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go index 796949d86..e710dc564 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go @@ -5,9 +5,6 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -129,12 +126,7 @@ func Test_awsSagemakerPlugin_getEventInfoForHyperparameterTuningJob(t *testing.T if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } - assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "some-acceptable-name", - }, - }, - })) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "some-acceptable-name") }) } diff --git a/go/tasks/plugins/k8s/sagemaker/utils.go b/go/tasks/plugins/k8s/sagemaker/utils.go index 7dc26ac8b..de1354ad9 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils.go +++ b/go/tasks/plugins/k8s/sagemaker/utils.go @@ -6,7 +6,6 @@ import ( "sort" "strings" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" @@ -400,11 +399,9 @@ func createTaskInfo(_ context.Context, jobRegion string, jobName string, jobType return &pluginsCore.TaskInfo{ Logs: taskLogs, CustomInfo: customInfo, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: jobName, - }, + ExternalResources: []*pluginsCore.ExternalResource{ + { + ExternalID: jobName, }, }, }, nil diff --git a/go/tasks/plugins/k8s/sidecar/sidecar.go b/go/tasks/plugins/k8s/sidecar/sidecar.go deleted file mode 100755 index fda1bd502..000000000 --- a/go/tasks/plugins/k8s/sidecar/sidecar.go +++ /dev/null @@ -1,273 +0,0 @@ -package sidecar - -import ( - "context" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - - "sigs.k8s.io/controller-runtime/pkg/client" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" - - "github.com/flyteorg/flyteplugins/go/tasks/errors" - "github.com/flyteorg/flyteplugins/go/tasks/logs" - k8sv1 "k8s.io/api/core/v1" -) - -const ( - sidecarTaskType = "sidecar" - primaryContainerKey = "primary_container_name" -) - -type sidecarResourceHandler struct{} - -// This method handles templatizing primary container input args, env variables and adds a GPU toleration to the pod -// spec if necessary. -func validateAndFinalizePod( - ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string, pod k8sv1.Pod) (*k8sv1.Pod, error) { - var hasPrimaryContainer bool - - resReqs := make([]k8sv1.ResourceRequirements, 0, len(pod.Spec.Containers)) - for index, container := range pod.Spec.Containers { - var resourceMode = flytek8s.ResourceCustomizationModeEnsureExistingResourcesInRange - if container.Name == primaryContainerName { - hasPrimaryContainer = true - resourceMode = flytek8s.ResourceCustomizationModeMergeExistingResources - } - templateParameters := template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), - } - err := flytek8s.AddFlyteCustomizationsToContainer(ctx, templateParameters, resourceMode, &pod.Spec.Containers[index]) - if err != nil { - return nil, err - } - resReqs = append(resReqs, container.Resources) - } - if !hasPrimaryContainer { - return nil, errors.Errorf(errors.BadTaskSpecification, - "invalid Sidecar task, primary container [%s] not defined", primaryContainerName) - - } - flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), resReqs, &pod.Spec) - return &pod, nil -} - -// Why, you might wonder do we recreate the generated go struct generated from the plugins.SidecarJob proto? Because -// although we unmarshal the task custom json, the PodSpec itself is not generated from a proto definition, -// but a proper go struct defined in k8s libraries. Therefore we only unmarshal the sidecar as a json, rather than jsonpb. -type sidecarJob struct { - PodSpec *k8sv1.PodSpec - PrimaryContainerName string - Annotations map[string]string - Labels map[string]string -} - -func (sidecarResourceHandler) GetProperties() k8s.PluginProperties { - return k8s.PluginProperties{} -} - -func getPrimaryContainerNameFromConfig(task *core.TaskTemplate) (string, error) { - if len(task.GetConfig()) == 0 { - return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", primaryContainerKey) - } - primaryContainerName, ok := task.GetConfig()[primaryContainerKey] - if !ok { - return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config missing [%s] key in [%v]", primaryContainerKey, task.GetConfig()) - } - return primaryContainerName, nil -} - -type podSpecResource struct { - podSpec k8sv1.PodSpec - primaryContainerName string - annotations map[string]string - labels map[string]string -} - -func newPodSpecResource() podSpecResource { - return podSpecResource{ - annotations: make(map[string]string), - labels: make(map[string]string), - } -} - -// Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. -func buildResourceV0(task *core.TaskTemplate) (podSpecResource, error) { - res := newPodSpecResource() - sidecarJob := sidecarJob{} - err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) - if err != nil { - return podSpecResource{}, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - if sidecarJob.PodSpec == nil { - return podSpecResource{}, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, nil PodSpec [%v]", task.GetCustom()) - } - res.podSpec = *sidecarJob.PodSpec - res.primaryContainerName = sidecarJob.PrimaryContainerName - if sidecarJob.Annotations != nil { - res.annotations = sidecarJob.Annotations - } - - if sidecarJob.Labels != nil { - res.labels = sidecarJob.Labels - } - - return res, nil -} - -// Handles pod tasks that marshal the pod spec to the task custom. -func buildResourceV1(task *core.TaskTemplate) (podSpecResource, error) { - res := newPodSpecResource() - err := utils.UnmarshalStructToObj(task.GetCustom(), &res.podSpec) - if err != nil { - return podSpecResource{}, errors.Errorf(errors.BadTaskSpecification, - "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - res.primaryContainerName, err = getPrimaryContainerNameFromConfig(task) - if err != nil { - return podSpecResource{}, err - } - return res, nil -} - -// Handles pod tasks that marshal the pod spec to the k8s_pod task target. -func buildResourceV2(task *core.TaskTemplate) (podSpecResource, error) { - res := newPodSpecResource() - if task.GetK8SPod() == nil || task.GetK8SPod().PodSpec == nil { - return podSpecResource{}, errors.Errorf(errors.BadTaskSpecification, - "Pod tasks with task type version > 1 should specify their target as a K8sPod with a defined pod spec") - } - err := utils.UnmarshalStructToObj(task.GetK8SPod().PodSpec, &res.podSpec) - if err != nil { - return podSpecResource{}, errors.Errorf(errors.BadTaskSpecification, - "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - res.primaryContainerName, err = getPrimaryContainerNameFromConfig(task) - if err != nil { - return podSpecResource{}, err - } - if task.GetK8SPod().Metadata != nil { - if task.GetK8SPod().Metadata.Annotations != nil { - res.annotations = task.GetK8SPod().Metadata.Annotations - } - if task.GetK8SPod().Metadata.Labels != nil { - res.labels = task.GetK8SPod().Metadata.Labels - } - } - return res, nil -} - -func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { - task, err := taskCtx.TaskReader().Read(ctx) - if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "TaskSpecification cannot be read, Err: [%v]", err.Error()) - } - var podSpecResource podSpecResource - switch task.TaskTypeVersion { - case 0: - podSpecResource, err = buildResourceV0(task) - if err != nil { - return nil, err - } - case 1: - podSpecResource, err = buildResourceV1(task) - if err != nil { - return nil, err - } - default: - podSpecResource, err = buildResourceV2(task) - if err != nil { - return nil, err - } - } - - pod := flytek8s.BuildPodWithSpec(&podSpecResource.podSpec) - // Set the restart policy to *not* inherit from the default so that a completed pod doesn't get caught in a - // CrashLoopBackoff after the initial job completion. - pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever - - pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - - pod, err = validateAndFinalizePod(ctx, taskCtx, podSpecResource.primaryContainerName, *pod) - if err != nil { - return nil, err - } - - pod.Annotations = podSpecResource.annotations - pod.Annotations[primaryContainerKey] = podSpecResource.primaryContainerName - pod.Labels = podSpecResource.labels - return pod, nil -} - -func (sidecarResourceHandler) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) ( - client.Object, error) { - return flytek8s.BuildIdentityPod(), nil -} - -func (sidecarResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { - pod := r.(*k8sv1.Pod) - - transitionOccurredAt := flytek8s.GetLastTransitionOccurredAt(pod).Time - info := pluginsCore.TaskInfo{ - OccurredAt: &transitionOccurredAt, - } - if pod.Status.Phase != k8sv1.PodPending && pod.Status.Phase != k8sv1.PodUnknown { - taskLogs, err := logs.GetLogsForContainerInPod(ctx, pod, 0, " (User)") - if err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - info.Logs = taskLogs - } - switch pod.Status.Phase { - case k8sv1.PodSucceeded: - return flytek8s.DemystifySuccess(pod.Status, info) - case k8sv1.PodFailed: - code, message := flytek8s.ConvertPodFailureToError(pod.Status) - return pluginsCore.PhaseInfoRetryableFailure(code, message, &info), nil - case k8sv1.PodPending: - return flytek8s.DemystifyPending(pod.Status) - case k8sv1.PodReasonUnschedulable: - return pluginsCore.PhaseInfoQueued(transitionOccurredAt, pluginsCore.DefaultPhaseVersion, "pod unschedulable"), nil - case k8sv1.PodUnknown: - return pluginsCore.PhaseInfoUndefined, nil - } - - // Otherwise, assume the pod is running. - primaryContainerName, ok := r.GetAnnotations()[primaryContainerKey] - if !ok { - return pluginsCore.PhaseInfoUndefined, errors.Errorf(errors.BadTaskSpecification, - "missing primary container annotation for pod") - } - primaryContainerPhase := flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &info) - - if primaryContainerPhase.Phase() == pluginsCore.PhaseRunning && len(info.Logs) > 0 { - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion+1, primaryContainerPhase.Info()), nil - } - return primaryContainerPhase, nil -} - -func init() { - pluginmachinery.PluginRegistry().RegisterK8sPlugin( - k8s.PluginEntry{ - ID: sidecarTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{sidecarTaskType}, - ResourceToWatch: &k8sv1.Pod{}, - Plugin: sidecarResourceHandler{}, - IsDefault: false, - DefaultForTaskTypes: []pluginsCore.TaskType{sidecarTaskType}, - }) -} diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 3370b0b95..7d241cb09 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -3,8 +3,6 @@ package presto import ( "context" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -506,11 +504,9 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { return &core.TaskInfo{ Logs: logs, OccurredAt: &t, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: e.CommandID, - }, + ExternalResources: []*core.ExternalResource{ + { + ExternalID: e.CommandID, }, }, } diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index d6caad0b8..9a2c4c15b 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -106,8 +106,8 @@ func TestConstructTaskInfo(t *testing.T) { taskInfo := ConstructTaskInfo(e) assert.Equal(t, "https://prestoproxy-internal.flyteorg.net:443", taskInfo.Logs[0].Uri) - assert.Len(t, taskInfo.Metadata.ExternalResources, 1) - assert.Equal(t, taskInfo.Metadata.ExternalResources[0].ExternalId, "123") + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "123") } func TestMapExecutionStateToPhaseInfo(t *testing.T) { diff --git a/go/tasks/plugins/webapi/athena/plugin.go b/go/tasks/plugins/webapi/athena/plugin.go index f2cbba436..c39d9104a 100644 --- a/go/tasks/plugins/webapi/athena/plugin.go +++ b/go/tasks/plugins/webapi/athena/plugin.go @@ -5,8 +5,6 @@ import ( "fmt" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - errors2 "github.com/flyteorg/flyteplugins/go/tasks/errors" awsSdk "github.com/aws/aws-sdk-go-v2/aws" @@ -186,11 +184,9 @@ func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo { Name: "Athena Query Console", }, }, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: queryID, - }, + ExternalResources: []*core.ExternalResource{ + { + ExternalID: queryID, }, }, } diff --git a/go/tasks/plugins/webapi/athena/plugin_test.go b/go/tasks/plugins/webapi/athena/plugin_test.go index d85b42573..5f821bb67 100644 --- a/go/tasks/plugins/webapi/athena/plugin_test.go +++ b/go/tasks/plugins/webapi/athena/plugin_test.go @@ -5,8 +5,6 @@ import ( awsSdk "github.com/aws/aws-sdk-go-v2/aws" idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" ) @@ -20,11 +18,6 @@ func TestCreateTaskInfo(t *testing.T) { Name: "Athena Query Console", }, }, taskInfo.Logs) - assert.True(t, proto.Equal(&event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "query_id", - }, - }, - }, taskInfo.Metadata)) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "query_id") } diff --git a/go/tasks/plugins/webapi/bigquery/integration_test.go b/go/tasks/plugins/webapi/bigquery/integration_test.go index 187150633..bec6f3c2f 100644 --- a/go/tasks/plugins/webapi/bigquery/integration_test.go +++ b/go/tasks/plugins/webapi/bigquery/integration_test.go @@ -47,7 +47,6 @@ func TestEndToEnd(t *testing.T) { t.Run("SELECT 1", func(t *testing.T) { queryJobConfig := QueryJobConfig{ ProjectID: "flyte", - Query: "SELECT 1", } inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) @@ -55,6 +54,7 @@ func TestEndToEnd(t *testing.T) { template := flyteIdlCore.TaskTemplate{ Type: bigqueryQueryJobTask, Custom: custom, + Target: &flyteIdlCore.TaskTemplate_Sql{Sql: &flyteIdlCore.Sql{Statement: "SELECT 1", Dialect: flyteIdlCore.Sql_ANSI}}, } phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) @@ -75,7 +75,11 @@ func newFakeBigQueryServer() *httptest.Server { if strings.HasPrefix(request.URL.Path, "/projects/flyte/jobs/") && request.Method == "GET" { writer.WriteHeader(200) - job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"}} + job := bigquery.Job{Status: &bigquery.JobStatus{State: "DONE"}, + Configuration: &bigquery.JobConfiguration{ + Query: &bigquery.JobConfigurationQuery{ + DestinationTable: &bigquery.TableReference{ + ProjectId: "project", DatasetId: "dataset", TableId: "table"}}}} bytes, _ := json.Marshal(job) _, _ = writer.Write(bytes) return diff --git a/go/tasks/plugins/webapi/bigquery/plugin.go b/go/tasks/plugins/webapi/bigquery/plugin.go index 7a8b21d28..8dd45b650 100644 --- a/go/tasks/plugins/webapi/bigquery/plugin.go +++ b/go/tasks/plugins/webapi/bigquery/plugin.go @@ -7,10 +7,12 @@ import ( "net/http" "time" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "golang.org/x/oauth2" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/google" structpb "github.com/golang/protobuf/ptypes/struct" "google.golang.org/api/bigquery/v2" @@ -42,8 +44,9 @@ type Plugin struct { } type ResourceWrapper struct { - Status *bigquery.JobStatus - CreateError *googleapi.Error + Status *bigquery.JobStatus + CreateError *googleapi.Error + OutputLocation string } type ResourceMetaWrapper struct { @@ -105,6 +108,7 @@ func (p Plugin) createImpl(ctx context.Context, taskCtx webapi.TaskExecutionCont return nil, nil, err } + job.Configuration.Query.Query = taskTemplate.GetSql().Statement job.Configuration.Labels = taskCtx.TaskExecutionMetadata().GetLabels() resp, err := client.Jobs.Insert(job.JobReference.ProjectId, job).Do() @@ -210,8 +214,12 @@ func (p Plugin) getImpl(ctx context.Context, taskCtx webapi.GetContext) (wrapper return nil, err } + dst := job.Configuration.Query.DestinationTable + outputLocation := fmt.Sprintf("bq://%v:%v.%v", dst.ProjectId, dst.DatasetId, dst.TableId) + return &ResourceWrapper{ - Status: job.Status, + Status: job.Status, + OutputLocation: outputLocation, }, nil } @@ -243,7 +251,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } -func (p Plugin) Status(_ context.Context, tCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { +func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { resourceMeta := tCtx.ResourceMeta().(*ResourceMetaWrapper) resource := tCtx.Resource().(*ResourceWrapper) version := pluginsCore.DefaultPhaseVersion @@ -272,13 +280,54 @@ func (p Plugin) Status(_ context.Context, tCtx webapi.StatusContext) (phase core resource.Status.ErrorResult.Message, taskInfo), nil } - + err = writeOutput(ctx, tCtx, resource.OutputLocation) + if err != nil { + logger.Warnf(ctx, "Failed to write output, uri [%s], err %s", resource.OutputLocation, err.Error()) + return core.PhaseInfoUndefined, err + } return pluginsCore.PhaseInfoSuccess(taskInfo), nil } return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.Status.State) } +func writeOutput(ctx context.Context, tCtx webapi.StatusContext, OutputLocation string) error { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return err + } + + if taskTemplate.Interface == nil || taskTemplate.Interface.Outputs == nil || taskTemplate.Interface.Outputs.Variables == nil { + logger.Infof(ctx, "The task declares no outputs. Skipping writing the outputs.") + return nil + } + + resultsStructuredDatasetType, exists := taskTemplate.Interface.Outputs.Variables["results"] + if !exists { + logger.Infof(ctx, "The task declares no outputs. Skipping writing the outputs.") + return nil + } + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "results": { + Value: &flyteIdlCore.Literal_Scalar{ + Scalar: &flyteIdlCore.Scalar{ + Value: &flyteIdlCore.Scalar_StructuredDataset{ + StructuredDataset: &flyteIdlCore.StructuredDataset{ + Uri: OutputLocation, + Metadata: &flyteIdlCore.StructuredDatasetMetadata{ + StructuredDatasetType: resultsStructuredDatasetType.GetType().GetStructuredDatasetType(), + }, + }, + }, + }, + }, + }, + }, + }, nil)) +} + func handleCreateError(createError *googleapi.Error, taskInfo *core.TaskInfo) core.PhaseInfo { code := fmt.Sprintf("http%d", createError.Code) @@ -456,7 +505,8 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity) options = append(options, option.WithEndpoint(p.cfg.bigQueryEndpoint), option.WithTokenSource(oauth2.StaticTokenSource(&oauth2.Token{}))) - } else { + } else if p.cfg.GoogleTokenSource.Type != "default" { + tokenSource, err := p.googleTokenSource.GetTokenSource(ctx, identity) if err != nil { @@ -464,6 +514,8 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity) } options = append(options, option.WithTokenSource(tokenSource)) + } else { + logger.Infof(ctx, "BigQuery client read $GOOGLE_APPLICATION_CREDENTIALS by default") } return bigquery.NewService(ctx, options...) diff --git a/go/tasks/plugins/webapi/bigquery/plugin_test.go b/go/tasks/plugins/webapi/bigquery/plugin_test.go index 0de1d30c5..d39d84418 100644 --- a/go/tasks/plugins/webapi/bigquery/plugin_test.go +++ b/go/tasks/plugins/webapi/bigquery/plugin_test.go @@ -1,9 +1,21 @@ package bigquery import ( + "context" "testing" "time" + coreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + "github.com/stretchr/testify/mock" + "k8s.io/apimachinery/pkg/util/rand" + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -12,6 +24,10 @@ import ( "google.golang.org/api/googleapi" ) +func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) +} + func TestFormatJobReference(t *testing.T) { t.Run("format job reference", func(t *testing.T) { jobReference := bigquery.JobReference{ @@ -46,6 +62,79 @@ func TestCreateTaskInfo(t *testing.T) { }) } +func TestOutputWriter(t *testing.T) { + ctx := context.Background() + statusContext := &mocks.StatusContext{} + + template := flyteIdlCore.TaskTemplate{} + tr := &coreMocks.TaskReader{} + tr.OnRead(ctx).Return(&template, nil) + statusContext.OnTaskReader().Return(tr) + + outputLocation := "bq://project:flyte.table" + err := writeOutput(ctx, statusContext, outputLocation) + assert.NoError(t, err) + + ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + outputWriter := &ioMocks.OutputWriter{} + outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + or := args.Get(1).(io.OutputReader) + literals, ee, err := or.Read(ctx) + assert.NoError(t, err) + + sd := literals.GetLiterals()["results"].GetScalar().GetStructuredDataset() + assert.Equal(t, sd.Uri, outputLocation) + assert.Equal(t, sd.Metadata.GetStructuredDatasetType().Columns[0].Name, "col1") + assert.Equal(t, sd.Metadata.GetStructuredDatasetType().Columns[0].LiteralType.GetSimple(), flyteIdlCore.SimpleType_INTEGER) + + if ee != nil { + assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetErrorPath(), storage.Options{}, ee)) + } + + if literals != nil { + assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetOutputPath(), storage.Options{}, literals)) + } + }) + + execID := rand.String(3) + basePrefix := storage.DataReference("fake://bucket/prefix/" + execID) + outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb") + statusContext.OnOutputWriter().Return(outputWriter) + + template = flyteIdlCore.TaskTemplate{ + Interface: &flyteIdlCore.TypedInterface{ + Outputs: &flyteIdlCore.VariableMap{ + Variables: map[string]*flyteIdlCore.Variable{ + "results": { + Type: &flyteIdlCore.LiteralType{ + Type: &flyteIdlCore.LiteralType_StructuredDatasetType{ + StructuredDatasetType: &flyteIdlCore.StructuredDatasetType{ + Columns: []*flyteIdlCore.StructuredDatasetType_DatasetColumn{ + { + Name: "col1", + LiteralType: &flyteIdlCore.LiteralType{ + Type: &flyteIdlCore.LiteralType_Simple{ + Simple: flyteIdlCore.SimpleType_INTEGER, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + tr.OnRead(ctx).Return(&template, nil) + statusContext.OnTaskReader().Return(tr) + err = writeOutput(ctx, statusContext, outputLocation) + assert.NoError(t, err) +} + func TestHandleCreateError(t *testing.T) { occurredAt := time.Now() taskInfo := core.TaskInfo{OccurredAt: &occurredAt} diff --git a/go/tasks/plugins/webapi/bigquery/query_job.go b/go/tasks/plugins/webapi/bigquery/query_job.go index 9c1fec1cd..ccd3610f0 100644 --- a/go/tasks/plugins/webapi/bigquery/query_job.go +++ b/go/tasks/plugins/webapi/bigquery/query_job.go @@ -161,6 +161,9 @@ func getJobConfigurationQuery(custom *QueryJobConfig, inputs *flyteIdlCore.Liter return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "unable build query parameters [%v]", err.Error()) } + // BigQuery supports query parameters to help prevent SQL injection when queries are constructed using user input. + // This feature is only available with standard SQL syntax. For more detail: https://cloud.google.com/bigquery/docs/parameterized-queries + useLegacySQL := false return &bigquery.JobConfigurationQuery{ AllowLargeResults: custom.AllowLargeResults, Clustering: custom.Clustering, @@ -178,7 +181,7 @@ func getJobConfigurationQuery(custom *QueryJobConfig, inputs *flyteIdlCore.Liter SchemaUpdateOptions: custom.SchemaUpdateOptions, TableDefinitions: custom.TableDefinitions, TimePartitioning: custom.TimePartitioning, - UseLegacySql: custom.UseLegacySQL, + UseLegacySql: &useLegacySQL, UseQueryCache: custom.UseQueryCache, UserDefinedFunctionResources: custom.UserDefinedFunctionResources, WriteDisposition: custom.WriteDisposition, diff --git a/go/tasks/plugins/webapi/bigquery/query_job_test.go b/go/tasks/plugins/webapi/bigquery/query_job_test.go index 8df93268a..b53c83840 100644 --- a/go/tasks/plugins/webapi/bigquery/query_job_test.go +++ b/go/tasks/plugins/webapi/bigquery/query_job_test.go @@ -69,10 +69,11 @@ func TestGetJobConfigurationQuery(t *testing.T) { }) jobConfigurationQuery, err := getJobConfigurationQuery(&config, inputs) + useLegacySQL := false assert.NoError(t, err) assert.Equal(t, "NAMED", jobConfigurationQuery.ParameterMode) - + assert.Equal(t, &useLegacySQL, jobConfigurationQuery.UseLegacySql) assert.Equal(t, 1, len(jobConfigurationQuery.QueryParameters)) assert.Equal(t, bigquery.QueryParameter{ Name: "integer", diff --git a/go/tasks/testdata/config.yaml b/go/tasks/testdata/config.yaml index fdc326641..208e4df49 100755 --- a/go/tasks/testdata/config.yaml +++ b/go/tasks/testdata/config.yaml @@ -56,6 +56,7 @@ plugins: fsGroup: 2000 default-security-context: allowPrivilegeEscalation: false + enable-host-networking-pod: true # Spark Plugin configuration spark: spark-config-default: diff --git a/tests/end_to_end.go b/tests/end_to_end.go index b4ff1a19b..df8d9e653 100644 --- a/tests/end_to_end.go +++ b/tests/end_to_end.go @@ -161,6 +161,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i tMeta.OnGetOverrides().Return(overrides) tMeta.OnGetK8sServiceAccount().Return("s") tMeta.OnGetNamespace().Return("fake-development") + tMeta.OnGetMaxAttempts().Return(2) tMeta.OnGetSecurityContext().Return(idlCore.SecurityContext{ RunAs: &idlCore.Identity{ K8SServiceAccount: "s",