diff --git a/internal/docker.go b/internal/docker.go index d77e261..4af67ef 100644 --- a/internal/docker.go +++ b/internal/docker.go @@ -102,6 +102,7 @@ func (d *DockerRun) Run( runCommand string, runCommandArgs []string, exposePort int, + gpuIDs []string, ) error { fmt.Printf("killing container %s\n", containerName) @@ -148,8 +149,9 @@ func (d *DockerRun) Run( if _, err := os.Stat("/dev/nvidia0"); err == nil { fmt.Printf("host has gpu, adding gpu to device requests\n") dr = append(dr, container.DeviceRequest{ - Count: -1, + Count: len(gpuIDs), Capabilities: [][]string{{"gpu"}}, + DeviceIDs: gpuIDs, }) } else { fmt.Printf("host does not have gpu, not adding gpu to device requests\n") diff --git a/internal/misc.go b/internal/misc.go index 346e7d4..d0d9d50 100644 --- a/internal/misc.go +++ b/internal/misc.go @@ -193,3 +193,11 @@ func parseOrExitInternal[T ~string | ~int | ~[]string](cmd *cobra.Command, flag return nil, false } + +func toStringSlice[T any](slice []T) []string { + var stringSlice []string + for _, v := range slice { + stringSlice = append(stringSlice, fmt.Sprintf("%v", v)) + } + return stringSlice +} diff --git a/internal/run.go b/internal/run.go index 0e716ff..bd0d42b 100644 --- a/internal/run.go +++ b/internal/run.go @@ -2,6 +2,7 @@ package internal import ( "context" + "encoding/json" "fmt" "os" "strings" @@ -10,62 +11,90 @@ import ( type RunArgs struct { ProjectName string `validate:"required,varname"` Hosts []string `validate:"required"` - NProcPerNode int `validate:"required,min=1"` + NProcPerNode string `validate:"required"` ExperimentName string `validate:"required,varname"` Port int `validate:"required,min=1"` RunName string `validate:"required,varname"` MaxRepeats int `validate:"required,min=-1"` Rest []string - ContainerName *string + ContainerName *string } const runScript = `#!/usr/bin/env python from higgsfield.internal.main import cli; cli() ` + func nameFromRunArgs(args RunArgs) string { - if args.ContainerName != nil && *args.ContainerName != "" { + if args.ContainerName != nil && *args.ContainerName != "" { return *args.ContainerName - } + } return DefaultProjExpContainerName(args.ProjectName, args.ExperimentName) } func trimPathForLength(path string, length int) string { - // check if path is less than length - if len(path) < length { - return path - } - - // get rid of home directory and replace is with ~ - // e.g. /home/user/... -> ~/... - if path[0] == '/' { - path = path[1:] - } - - branches := strings.Split(path, "/") - slashes := len(branches) - 1 - if slashes == 0 { - return path[:length] - } - - if branches[0] == "home" { - path = "~/" + strings.Join(branches[2:], "/") - } - - if len(path) < length { - return path - } - - return path[:length] + "..." + // check if path is less than length + if len(path) < length { + return path + } + + // get rid of home directory and replace is with ~ + // e.g. /home/user/... -> ~/... + if path[0] == '/' { + path = path[1:] + } + + branches := strings.Split(path, "/") + slashes := len(branches) - 1 + if slashes == 0 { + return path[:length] + } + + if branches[0] == "home" { + path = "~/" + strings.Join(branches[2:], "/") + } + + if len(path) < length { + return path + } + + return path[:length] + "..." +} + +type nProcPerNode map[string][]int + +// parseNProcPerNode converts +func parseNProcPerNode(host, nppn string) []int { + var procMap nProcPerNode + if err := json.Unmarshal([]byte(nppn), &procMap); err != nil { + fmt.Printf("failed to parse nProcPerNode: %v\n", err) + os.Exit(1) + } + + hostNProc, ok := procMap[host] + if !ok { + fmt.Printf("failed to find host %s in nProcPerNode map\n", host) + os.Exit(1) + } + + return hostNProc } func Run(args RunArgs) { if err := Validator().Struct(args); err != nil { panic(err) } - - master := args.Hosts[0] + + myIP, err := myPublicIP() + if err != nil { + fmt.Printf("failed to get my public IP: %v\n", err) + os.Exit(1) + } + + gpuIDs := parseNProcPerNode(myIP, args.NProcPerNode) + + master := args.Hosts[0] rank := 0 if len(args.Hosts) > 1 { @@ -88,7 +117,7 @@ func Run(args RunArgs) { os.Exit(1) } - containerName := nameFromRunArgs(args) + containerName := nameFromRunArgs(args) fmt.Printf(` ╔══════════════════════════════════════════════════════════════════════════════════════════════════════ @@ -100,9 +129,15 @@ func Run(args RunArgs) { ║ > RUN NAME = %s ║ > CONTAINER NAME = %s ║ > MODEL CHKPT PATH = %s -║ +║ > GPU IDs = %v ╚══════════════════════════════════════════════════════════════════════════════════════════════════════ -`, args.ExperimentName, args.RunName, containerName, trimPathForLength(checkpointDir, 70)) +`, + args.ExperimentName, + args.RunName, + containerName, + trimPathForLength(checkpointDir, 70), + gpuIDs, + ) cmd, cmdArgs := buildArgs( nodeNum, @@ -110,7 +145,7 @@ func Run(args RunArgs) { master, args.Port, []string{"hf.py", "run"}, - args.NProcPerNode, + len(gpuIDs), args.ExperimentName, args.RunName, args.MaxRepeats, @@ -132,8 +167,18 @@ func Run(args RunArgs) { f.Write([]byte(runScript)) - dr := NewDockerRun(context.Background(), args.ProjectName, cwd, hostCachePath) - if err := dr.Run(containerName, cmd, cmdArgs, args.Port); err != nil { + dr := NewDockerRun( + context.Background(), + args.ProjectName, + cwd, + hostCachePath) + if err := dr.Run( + containerName, + cmd, + cmdArgs, + args.Port, + toStringSlice(gpuIDs), + ); err != nil { fmt.Printf("error occured while running experiment: %+v\n", err) os.Exit(1) }