diff --git a/controllers/nnf_workflow_controller_container_helpers.go b/controllers/nnf_workflow_controller_container_helpers.go index b23ddace7..db0d499ee 100644 --- a/controllers/nnf_workflow_controller_container_helpers.go +++ b/controllers/nnf_workflow_controller_container_helpers.go @@ -30,6 +30,7 @@ import ( "github.com/go-logr/logr" mpicommonv1 "github.com/kubeflow/common/pkg/apis/common/v1" mpiv2beta1 "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" + "go.openly.dev/pointy" batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -105,11 +106,6 @@ func (c *nnfUserContainer) createMPIJob() error { // Run the launcher on the first NNF node launcherSpec.NodeSelector = map[string]string{"kubernetes.io/hostname": c.nnfNodes[0]} - // Use initContainers to ensure the workers are up and discoverable before running the launcher command - for i := range c.nnfNodes { - c.addInitContainerWorkerWait(launcherSpec, i) - } - // Target all the NNF nodes for the workers replicas := int32(len(c.nnfNodes)) worker.Replicas = &replicas @@ -152,6 +148,11 @@ func (c *nnfUserContainer) createMPIJob() error { c.applyPermissions(launcherSpec, &mpiJob.Spec, false) c.applyPermissions(workerSpec, &mpiJob.Spec, true) + // Use an Init Container to test the waters for mpi - ensure it can contact the workers before + // the launcher tries it. Since this runs as the UID/GID, this needs to happen after the + // passwd Init Container. + c.addInitContainerWorkerWait(launcherSpec, len(c.nnfNodes)) + // Get the ports from the port manager ports, err := c.getHostPorts() if err != nil { @@ -303,37 +304,55 @@ exit 0 }) } -func (c *nnfUserContainer) addInitContainerWorkerWait(spec *corev1.PodSpec, worker int) { - // Add an initContainer to ensure that a worker pod is up and discoverable via dns. This - // assumes nslookup is available in the container. The nnf-mfu image provides this. - script := `# use nslookup to contact workers -echo "contacting $HOST..." +func (c *nnfUserContainer) addInitContainerWorkerWait(spec *corev1.PodSpec, numWorkers int) { + // Add an initContainer to ensure that the worker pods are up and discoverable via mpirun. + script := `# use mpirun to contact workers +echo "contacting $HOSTS..." for i in $(seq 1 100); do sleep 1 echo "attempt $i of 100..." - nslookup $HOST + echo "mpirun -H $HOSTS hostname" + mpirun -H $HOSTS hostname if [ $? -eq 0 ]; then - echo "successfully contacted $HOST; done" + echo "successfully contacted $HOSTS; done" exit 0 fi done -echo "failed to contact $HOST" +echo "failed to contact $HOSTS" exit 1 ` - // Build the worker's hostname.domain (e.g. nnf-container-example-worker-0.nnf-container-example-worker.default.svc) - // This name comes from mpi-operator. - host := strings.ToLower(fmt.Sprintf( - "%s-worker-%d.%s-worker.%s.svc", c.workflow.Name, worker, c.workflow.Name, c.workflow.Namespace)) - script = strings.ReplaceAll(script, "$HOST", host) + // Build a slice of the workers' hostname.domain (e.g. nnf-container-example-worker-0.nnf-container-example-worker.default.svc) + // This hostname comes from mpi-operator. + workers := []string{} + for i := 0; i < numWorkers; i++ { + host := strings.ToLower(fmt.Sprintf( + "%s-worker-%d.%s-worker.%s.svc", c.workflow.Name, i, c.workflow.Name, c.workflow.Namespace)) + workers = append(workers, host) + } + // mpirun takes a comma separated list of hosts (-H) + script = strings.ReplaceAll(script, "$HOSTS", strings.Join(workers, ",")) spec.InitContainers = append(spec.InitContainers, corev1.Container{ - Name: fmt.Sprintf("mpi-wait-for-worker-%d", worker), + Name: fmt.Sprintf("mpi-wait-for-worker-%d", numWorkers), Image: spec.Containers[0].Image, Command: []string{ "/bin/sh", "-c", script, }, + // mpirun needs this environment variable to use DNS hostnames + Env: []corev1.EnvVar{{Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "true"}}, + // Run this initContainer as the same UID/GID as the launcher + SecurityContext: &corev1.SecurityContext{ + RunAsUser: &c.uid, + RunAsGroup: &c.gid, + RunAsNonRoot: pointy.Bool(true), + }, + // And use the necessary volumes to support the UID/GID + VolumeMounts: []corev1.VolumeMount{ + {MountPath: "/etc/passwd", Name: "passwd", SubPath: "passwd"}, + {MountPath: "/home/mpiuser/.ssh", Name: "ssh-auth"}, + }, }) } @@ -389,16 +408,13 @@ func (c *nnfUserContainer) applyPermissions(spec *corev1.PodSpec, mpiJobSpec *mp if !worker { container.SecurityContext.RunAsUser = &c.uid container.SecurityContext.RunAsGroup = &c.gid - nonRoot := true - container.SecurityContext.RunAsNonRoot = &nonRoot - su := false - container.SecurityContext.AllowPrivilegeEscalation = &su + container.SecurityContext.RunAsNonRoot = pointy.Bool(true) + container.SecurityContext.AllowPrivilegeEscalation = pointy.Bool(false) } else { // For the worker nodes, we need to ensure we have the appropriate linux capabilities to // allow for ssh access for mpirun. Drop all capabilities and only add what is // necessary. Only do this if the Capabilities have not been set by the user. - su := true - container.SecurityContext.AllowPrivilegeEscalation = &su + container.SecurityContext.AllowPrivilegeEscalation = pointy.Bool(true) if container.SecurityContext.Capabilities == nil { container.SecurityContext.Capabilities = &corev1.Capabilities{ Drop: []corev1.Capability{"ALL"}, diff --git a/go.mod b/go.mod index ea158d25d..714c3d127 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/onsi/ginkgo/v2 v2.9.1 github.com/onsi/gomega v1.27.3 github.com/prometheus/client_golang v1.14.0 + go.openly.dev/pointy v1.3.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.1.0 k8s.io/api v0.26.1 diff --git a/go.sum b/go.sum index 0d7b65fa6..80886c8a1 100644 --- a/go.sum +++ b/go.sum @@ -233,8 +233,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= @@ -247,6 +247,8 @@ go.chromium.org/luci v0.0.0-20230227223707-c4460eb434d8/go.mod h1:vTpW7gzqLQ9mhM go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.openly.dev/pointy v1.3.0 h1:keht3ObkbDNdY8PWPwB7Kcqk+MAlNStk5kXZTxukE68= +go.openly.dev/pointy v1.3.0/go.mod h1:rccSKiQDQ2QkNfSVT2KG8Budnfhf3At8IWxy/3ElYes= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= diff --git a/vendor/go.openly.dev/pointy/.gitignore b/vendor/go.openly.dev/pointy/.gitignore new file mode 100644 index 000000000..f1c181ec9 --- /dev/null +++ b/vendor/go.openly.dev/pointy/.gitignore @@ -0,0 +1,12 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out diff --git a/vendor/go.openly.dev/pointy/LICENSE b/vendor/go.openly.dev/pointy/LICENSE new file mode 100644 index 000000000..4f639d4b8 --- /dev/null +++ b/vendor/go.openly.dev/pointy/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Mateusz Wielbut + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/go.openly.dev/pointy/README.md b/vendor/go.openly.dev/pointy/README.md new file mode 100644 index 000000000..1426a5a70 --- /dev/null +++ b/vendor/go.openly.dev/pointy/README.md @@ -0,0 +1,154 @@ +# pointy + +Simple helper functions to provide a shorthand to get a pointer to a variable holding a constant...because it's annoying when you have to do it hundreds of times in unit tests: + +```golang + +val := 42 +pointerToVal := &val +// vs. +pointerToVal := pointy.Int(42) // if using Go 1.17 or earlier w/o generics +pointerToVal := pointy.Pointer(42) // if using Go 1.18+ w/ generics +``` + +### New in release 2.0.0 + +🚨 Breaking change + +Package has changed to `go.openly.dev`. Please use +``` +import "go.openly.dev/pointy" +``` + +### New in release 1.2.0 + +Generic implementation of the pointer-to-value and value-to-pointer functions. *Requires Go 1.18+.* +The type-specific functions are still available for backwards-compatibility. + +```golang +pointerToInt := pointy.Pointer(42) +pointerToString := pointy.Pointer("foo") +// then later in your code.. +intValue := pointy.PointerValue(pointerToInt, 99) +stringValue := pointy.PointerValue(pointerToString, "bar") +``` + +Convenience functions to safely compare pointers by their dereferenced values: + +```golang +// when both values are pointers +a := pointy.Int(1) +b := pointy.Int(1) +if pointy.PointersValueEqual(a, b) { + fmt.Println("a and b contain equal dereferenced values") +} + +// or if just one is a pointer +a := pointy.Int(1) +b := 1 +if pointy.PointerValueEqual(a, b) { + fmt.Println("a and b contain equal dereferenced values") +} +``` + +### New in release 1.1.0 + +Additional helper functions have been added to safely dereference pointers +or return a fallback value: + +```golang +val := 42 +pointerToVal := &val +// then later in your code.. +myVal := pointy.IntValue(pointerToVal, 99) // returns 42 (or 99 if pointerToVal was nil) +``` + +## GoDoc + +[https://godoc.org/github.com/openly-engineering/pointy](https://pkg.go.dev/github.com/openly-engineering/pointy) + +## Installation + +`go get go.openly.dev/pointy` + +## Example + +```golang +package main + +import ( + "fmt" + + "go.openly.dev/pointy" +) + +func main() { + foo := pointy.Pointer(2018) + fmt.Println("foo is a pointer to:", *foo) + + bar := pointy.Pointer("point to me") + fmt.Println("bar is a pointer to:", *bar) + + // get the value back out (new in v1.1.0) + barVal := pointy.PointerValue(bar, "empty!") + fmt.Println("bar's value is:", barVal) +} +``` + +## Available Functions + +`Pointer[T any](x T) *T` +`PointerValue[T any](p *T, fallback T) T` +`Bool(x bool) *bool` +`BoolValue(p *bool, fallback bool) bool` +`Byte(x byte) *byte` +`ByteValue(p *byte, fallback byte) byte` +`Complex128(x complex128) *complex128` +`Complex128Value(p *complex128, fallback complex128) complex128` +`Complex64(x complex64) *complex64` +`Complex64Value(p *complex64, fallback complex64) complex64` +`Float32(x float32) *float32` +`Float32Value(p *float32, fallback float32) float32` +`Float64(x float64) *float64` +`Float64Value(p *float64, fallback float64) float64` +`Int(x int) *int` +`IntValue(p *int, fallback int) int` +`Int8(x int8) *int8` +`Int8Value(p *int8, fallback int8) int8` +`Int16(x int16) *int16` +`Int16Value(p *int16, fallback int16) int16` +`Int32(x int32) *int32` +`Int32Value(p *int32, fallback int32) int32` +`Int64(x int64) *int64` +`Int64Value(p *int64, fallback int64) int64` +`Uint(x uint) *uint` +`UintValue(p *uint, fallback uint) uint` +`Uint8(x uint8) *uint8` +`Uint8Value(p *uint8, fallback uint8) uint8` +`Uint16(x uint16) *uint16` +`Uint16Value(p *uint16, fallback uint16) uint16` +`Uint32(x uint32) *uint32` +`Uint32Value(p *uint32, fallback uint32) uint32` +`Uint64(x uint64) *uint64` +`Uint64Value(p *uint64, fallback uint64) uint64` +`String(x string) *string` +`StringValue(p *string, fallback string) string` +`Rune(x rune) *rune` +`RuneValue(p *rune, fallback rune) rune` +`PointersValueEqual[T comparable](a *T, b *T) bool` +`PointerValueEqual[T comparable](a *T, b T) bool` +## Motivation + +Creating pointers to literal constant values is useful, especially in unit tests. Go doesn't support simply using the address operator (&) to reference the location of e.g. `value := &int64(42)` so we're forced to [create](https://stackoverflow.com/questions/35146286/find-address-of-constant-in-go/35146856#35146856) [little](https://stackoverflow.com/questions/34197248/how-can-i-store-reference-to-the-result-of-an-operation-in-go/34197367#34197367) [workarounds](https://stackoverflow.com/questions/30716354/how-do-i-do-a-literal-int64-in-go/30716481#30716481). A common solution is to create a helper function: + +```golang +func createInt64Pointer(x int64) *int64 { + return &x +} +// now you can create a pointer to 42 inline +value := createInt64Pointer(42) +``` + +This package provides a library of these simple little helper functions for every native Go primitive. + +Made @ Openly. [Join us](https://careers.openly.com/) and use Go to build cool stuff. diff --git a/vendor/go.openly.dev/pointy/comparison.go b/vendor/go.openly.dev/pointy/comparison.go new file mode 100644 index 000000000..4541ab1ff --- /dev/null +++ b/vendor/go.openly.dev/pointy/comparison.go @@ -0,0 +1,25 @@ +package pointy + +// PointersValueEqual returns true if both pointer parameters are nil or contain the same dereferenced value. +func PointersValueEqual[T comparable](a *T, b *T) bool { + if a == nil && b == nil { + return true + } + if a != nil && b != nil && *a == *b { + return true + } + + return false +} + +// PointerValueEqual returns true if the pointer parameter is not nil and contains the same dereferenced value as the value parameter. +func PointerValueEqual[T comparable](a *T, b T) bool { + if a == nil { + return false + } + if *a == b { + return true + } + + return false +} diff --git a/vendor/go.openly.dev/pointy/pointy.go b/vendor/go.openly.dev/pointy/pointy.go new file mode 100644 index 000000000..0bbe4988c --- /dev/null +++ b/vendor/go.openly.dev/pointy/pointy.go @@ -0,0 +1,250 @@ +// Package pointy is a set of simple helper functions to provide a shorthand to +// get a pointer to a variable holding a constant. +package pointy + +// Bool returns a pointer to a variable holding the supplied bool constant +func Bool(x bool) *bool { + return &x +} + +// BoolValue returns the bool value pointed to by p or fallback if p is nil +func BoolValue(p *bool, fallback bool) bool { + if p == nil { + return fallback + } + return *p +} + +// Byte returns a pointer to a variable holding the supplied byte constant +func Byte(x byte) *byte { + return &x +} + +// ByteValue returns the byte value pointed to by p or fallback if p is nil +func ByteValue(p *byte, fallback byte) byte { + if p == nil { + return fallback + } + return *p +} + +// Complex128 returns a pointer to a variable holding the supplied complex128 constant +func Complex128(x complex128) *complex128 { + return &x +} + +// Complex128Value returns the complex128 value pointed to by p or fallback if p is nil +func Complex128Value(p *complex128, fallback complex128) complex128 { + if p == nil { + return fallback + } + return *p +} + +// Complex64 returns a pointer to a variable holding the supplied complex64 constant +func Complex64(x complex64) *complex64 { + return &x +} + +// Complex64Value returns the complex64 value pointed to by p or fallback if p is nil +func Complex64Value(p *complex64, fallback complex64) complex64 { + if p == nil { + return fallback + } + return *p +} + +// Float32 returns a pointer to a variable holding the supplied float32 constant +func Float32(x float32) *float32 { + return &x +} + +// Float32Value returns the float32 value pointed to by p or fallback if p is nil +func Float32Value(p *float32, fallback float32) float32 { + if p == nil { + return fallback + } + return *p +} + +// Float64 returns a pointer to a variable holding the supplied float64 constant +func Float64(x float64) *float64 { + return &x +} + +// Float64Value returns the float64 value pointed to by p or fallback if p is nil +func Float64Value(p *float64, fallback float64) float64 { + if p == nil { + return fallback + } + return *p +} + +// Int returns a pointer to a variable holding the supplied int constant +func Int(x int) *int { + return &x +} + +// IntValue returns the int value pointed to by p or fallback if p is nil +func IntValue(p *int, fallback int) int { + if p == nil { + return fallback + } + return *p +} + +// Int8 returns a pointer to a variable holding the supplied int8 constant +func Int8(x int8) *int8 { + return &x +} + +// Int8Value returns the int8 value pointed to by p or fallback if p is nil +func Int8Value(p *int8, fallback int8) int8 { + if p == nil { + return fallback + } + return *p +} + +// Int16 returns a pointer to a variable holding the supplied int16 constant +func Int16(x int16) *int16 { + return &x +} + +// Int16Value returns the int16 value pointed to by p or fallback if p is nil +func Int16Value(p *int16, fallback int16) int16 { + if p == nil { + return fallback + } + return *p +} + +// Int32 returns a pointer to a variable holding the supplied int32 constant +func Int32(x int32) *int32 { + return &x +} + +// Int32Value returns the int32 value pointed to by p or fallback if p is nil +func Int32Value(p *int32, fallback int32) int32 { + if p == nil { + return fallback + } + return *p +} + +// Int64 returns a pointer to a variable holding the supplied int64 constant +func Int64(x int64) *int64 { + return &x +} + +// Int64Value returns the int64 value pointed to by p or fallback if p is nil +func Int64Value(p *int64, fallback int64) int64 { + if p == nil { + return fallback + } + return *p +} + +// Uint returns a pointer to a variable holding the supplied uint constant +func Uint(x uint) *uint { + return &x +} + +// UintValue returns the uint value pointed to by p or fallback if p is nil +func UintValue(p *uint, fallback uint) uint { + if p == nil { + return fallback + } + return *p +} + +// Uint8 returns a pointer to a variable holding the supplied uint8 constant +func Uint8(x uint8) *uint8 { + return &x +} + +// Uint8Value returns the uint8 value pointed to by p or fallback if p is nil +func Uint8Value(p *uint8, fallback uint8) uint8 { + if p == nil { + return fallback + } + return *p +} + +// Uint16 returns a pointer to a variable holding the supplied uint16 constant +func Uint16(x uint16) *uint16 { + return &x +} + +// Uint16Value returns the uint16 value pointed to by p or fallback if p is nil +func Uint16Value(p *uint16, fallback uint16) uint16 { + if p == nil { + return fallback + } + return *p +} + +// Uint32 returns a pointer to a variable holding the supplied uint32 constant +func Uint32(x uint32) *uint32 { + return &x +} + +// Uint32Value returns the uint32 value pointed to by p or fallback if p is nil +func Uint32Value(p *uint32, fallback uint32) uint32 { + if p == nil { + return fallback + } + return *p +} + +// Uint64 returns a pointer to a variable holding the supplied uint64 constant +func Uint64(x uint64) *uint64 { + return &x +} + +// Uint64Value returns the uint64 value pointed to by p or fallback if p is nil +func Uint64Value(p *uint64, fallback uint64) uint64 { + if p == nil { + return fallback + } + return *p +} + +// String returns a pointer to a variable holding the supplied string constant +func String(x string) *string { + return &x +} + +// StringValue returns the string value pointed to by p or fallback if p is nil +func StringValue(p *string, fallback string) string { + if p == nil { + return fallback + } + return *p +} + +// Rune returns a pointer to a variable holding the supplied rune constant +func Rune(x rune) *rune { + return &x +} + +// RuneValue returns the rune value pointed to by p or fallback if p is nil +func RuneValue(p *rune, fallback rune) rune { + if p == nil { + return fallback + } + return *p +} + +// Pointer returns a pointer to a variable holding the supplied T constant +func Pointer[T any](x T) *T { + return &x +} + +// PointerValue returns the T value pointed to by p or fallback if p is nil +func PointerValue[T any](p *T, fallback T) T { + if p == nil { + return fallback + } + return *p +} diff --git a/vendor/modules.txt b/vendor/modules.txt index b5f76cbaa..5c84c5aa7 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -287,6 +287,9 @@ go.opencensus.io/internal go.opencensus.io/trace go.opencensus.io/trace/internal go.opencensus.io/trace/tracestate +# go.openly.dev/pointy v1.3.0 +## explicit; go 1.18 +go.openly.dev/pointy # go.uber.org/atomic v1.11.0 ## explicit; go 1.18 go.uber.org/atomic