Skip to content

Commit

Permalink
Merge pull request #17 from higgsfield-ai/feat/no-python_wa_master-host
Browse files Browse the repository at this point in the history
add args
  • Loading branch information
arpanetus authored Mar 23, 2024
2 parents bd31536 + e5b8d93 commit 7a192c0
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 81 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/go-envparse v0.1.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/klauspost/compress v1.17.6 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/go-envparse v0.1.0 h1:bE++6bhIsNCPLvgDZkYqo3nA+/PFI51pkrHdmPSDFPY=
github.com/hashicorp/go-envparse v0.1.0/go.mod h1:OHheN1GoygLlAkTlXLXvAdnXdZxy8JUweQ1rAXx1xnc=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
Expand Down
139 changes: 93 additions & 46 deletions internal/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/docker/docker/pkg/archive"
units "github.com/docker/go-units"
"github.com/pkg/errors"
"path/filepath"
)

type DockerRun struct {
Expand Down Expand Up @@ -167,18 +168,80 @@ func createDeviceMapping(devices []string) []container.DeviceMapping {
return mappings
}

func (d *DockerRun) Run(
containerName string,
runCommand string,
runCommandArgs []string,
exposePort int,
) error {
// ldMap lists host directory -> guest directory pairs. ldBinds turns each
// entry whose host path can be stat'ed into a "host:guest" bind spec.
var ldMap = map[string]string{
"/var/lib/nvidia/lib64": "/usr/local/nvidia/lib64",
"/var/lib/tcpx": "/usr/local/tcpx",
"/run/tcpx": "/run/tcpx",
}

fmt.Printf("killing container %s\n", containerName)
if err := d.Kill(containerName); err != nil {
return errors.WithMessagef(err, "failed to kill container %s", containerName)
// ldBinds builds "host:guest" bind specs from ldMap.
//
// An entry is included only when os.Stat on its host path returns no error;
// every other entry is skipped silently. Each included bind is also printed
// to stdout. The result follows map iteration order, so it is not stable
// between calls.
func ldBinds() []string {
	specs := make([]string, 0, len(ldMap))

	for hostPath, guestPath := range ldMap {
		_, statErr := os.Stat(hostPath)
		if statErr == nil {
			fmt.Printf("adding bind: %s:%s\n", hostPath, guestPath)
			specs = append(specs, fmt.Sprintf("%s:%s", hostPath, guestPath))
		}
	}

	return specs
}

// capAdd returns the capability names passed as HostConfig.CapAdd when the
// container is created. A fresh slice is returned on every call, so callers
// may modify it freely.
func capAdd() []string {
	caps := make([]string, 0, 4)
	caps = append(caps, "NET_ADMIN", "SYS_ADMIN", "SYS_PTRACE", "IPC_LOCK")
	return caps
}

// volbinds returns every bind spec for the container, in this order:
//
//  1. host root path   -> guest root path
//  2. host cache path  -> guest cache path
//  3. host cache path  -> guestRootCachePath (the same host cache directory
//     bound at a second guest location)
//  4. whatever ldBinds reports for this host
func (d *DockerRun) volbinds() []string {
	rootBind := fmt.Sprintf("%s:%s", d.hostRootPath, d.guestRootPath)
	cacheBind := fmt.Sprintf("%s:%s", d.hostCachePath, d.guestCachePath)
	rootCacheBind := fmt.Sprintf("%s:%s", d.hostCachePath, guestRootCachePath)

	return append([]string{rootBind, cacheBind, rootCacheBind}, ldBinds()...)
}

// deviceMapsAndRequests works out which devices to hand to the container.
//
// The presence of /dev/nvidia0 is used as the signal that the host has a
// GPU. Without it both returned slices are empty. With it:
//   - unless isCos reports true, a single device request with Count -1 and
//     the "gpu" capability is added;
//   - device mappings for everything listed by listNvidiaGPUs and
//     listOtherNvidiaDevices are always added (the tcpx setup needs the
//     extra nvidia devices, which plain bare-metal usually does not).
//
// It returns the device mappings first and the device requests second.
func (d *DockerRun) deviceMapsAndRequests() ([]container.DeviceMapping, []container.DeviceRequest) {
	// The error from isCos is deliberately dropped: the invoker cannot run
	// natively on cos, and detection is best-effort here. Whatever boolean
	// isCos returned alongside the error is used as-is.
	onCos, _ := isCos()

	requests := make([]container.DeviceRequest, 0, 1)
	mappings := make([]container.DeviceMapping, 0, 1)

	// No GPU device node (e.g. on macOS, where docker has no gpu
	// passthrough): request nothing and map nothing.
	if _, statErr := os.Stat("/dev/nvidia0"); statErr != nil {
		fmt.Printf("host does not have gpu, not adding gpu to device requests\n")
		return mappings, requests
	}

	fmt.Printf("host has gpu, adding gpu to device requests\n")
	if !onCos {
		gpuRequest := container.DeviceRequest{
			Count:        -1,
			Capabilities: [][]string{{"gpu"}},
		}
		requests = append(requests, gpuRequest)
	}

	// GPUs first, then the remaining nvidia-ish devices.
	mappings = append(mappings, createDeviceMapping(listNvidiaGPUs())...)
	mappings = append(mappings, createDeviceMapping(listOtherNvidiaDevices())...)

	return mappings, requests
}

func (d *DockerRun) build() error {
buildCtx, err := archive.TarWithOptions(d.hostRootPath, &archive.TarOptions{})
if err != nil {
panic(err)
Expand Down Expand Up @@ -208,41 +271,28 @@ func (d *DockerRun) Run(
return errors.WithMessagef(err, "failed to build image %s", d.imageTag)
}

// check if host has gpu
// if yes, add gpu to device requests
// else, don't add gpu to device requests
// this is a hacky way to get around the fact that docker doesn't support
// gpu passthrough on macos
dr := make([]container.DeviceRequest, 0, 1)
cos, _ := isCos()
dm := make([]container.DeviceMapping, 0, 1)
if _, err := os.Stat("/dev/nvidia0"); err == nil {
fmt.Printf("host has gpu, adding gpu to device requests\n")
if cos {
fmt.Printf("host is cos, not adding gpu to device requests\n")
} else {
dr = append(dr, container.DeviceRequest{
Count: -1,
Capabilities: [][]string{{"gpu"}},
})
}
// usually there's no need to add additional devices on bare-metal
// but with tcpx setup we need to add other nvidia-ish devices
dm = append(dm, createDeviceMapping(listNvidiaGPUs())...)
dm = append(dm, createDeviceMapping(listOtherNvidiaDevices())...)
} else {
fmt.Printf("host does not have gpu, not adding gpu to device requests\n")
return nil
}

func (d *DockerRun) Run(
containerName string,
runCommand string,
runCommandArgs []string,
exposePort int,
) error {
fmt.Printf("killing container %s\n", containerName)
if err := d.Kill(containerName); err != nil {
return errors.WithMessagef(err, "failed to kill container %s", containerName)
}

binds := []string{
fmt.Sprintf("%s:%s", d.hostRootPath, d.guestRootPath),
fmt.Sprintf("%s:%s", d.hostCachePath, d.guestCachePath),
fmt.Sprintf("%s:%s", d.hostCachePath, guestRootCachePath),
if err := d.build(); err != nil {
return errors.WithMessagef(err, "failed to build image %s", d.imageTag)
}

if _, err := os.Stat("/run/tcpx"); cos && err == nil {
fmt.Printf("host is cos, adding /run/tcpx to binds\n")
binds = append(binds, "/run/tcpx:/run/tcpx")
dm, dr := d.deviceMapsAndRequests()
envVars, err := loadEnvFile(filepath.Join(d.hostRootPath, "nccl_config_env"))
if err != nil {
return errors.WithMessagef(err, "failed to load nccl_config_env file")
}

fmt.Printf("creating container %s\n", containerName)
Expand All @@ -251,13 +301,14 @@ func (d *DockerRun) Run(
Config: &container.Config{
Image: d.imageTag,
Entrypoint: append([]string{runCommand}, runCommandArgs...),
Env: envVars,
},
HostConfig: &container.HostConfig{
Binds: binds,
Binds: d.volbinds(),
IpcMode: container.IPCModeHost,
PidMode: container.PidMode("host"),
NetworkMode: container.NetworkMode("host"),
CapAdd: []string{"NET_ADMIN"},
CapAdd: capAdd(),
Resources: container.Resources{
DeviceRequests: dr,
Ulimits: []*units.Ulimit{
Expand Down Expand Up @@ -292,7 +343,3 @@ func (d *DockerRun) Run(

return nil
}

func PtrTo[T any](e T) *T {
return &e
}
37 changes: 35 additions & 2 deletions internal/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"os"

envparse "github.com/hashicorp/go-envparse"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"slices"
Expand All @@ -17,6 +18,38 @@ import (

const url = "https://api.ipify.org"

// PtrTo returns a pointer to a newly allocated copy of e. Each call yields a
// distinct pointer, and writes through it never affect the caller's value.
func PtrTo[T any](e T) *T {
	p := new(T)
	*p = e
	return p
}

// loadEnvFile parses the env file at path and returns its variables as
// "KEY=value" strings, sorted so the result is identical on every call.
//
// A missing file is not an error: it yields an empty, non-nil slice, as does
// a file that defines no variables. Any other failure to open the file, or a
// parse failure, is returned wrapped with context.
func loadEnvFile(path string) ([]string, error) {
	// Open directly instead of stat-then-open: one syscall fewer, and a file
	// that disappears between the two calls is still treated as "missing"
	// rather than surfacing as an open error.
	file, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			return []string{}, nil
		}
		return nil, errors.WithMessage(err, "failed to open env file")
	}
	defer file.Close()

	envs, err := envparse.Parse(file)
	if err != nil {
		return nil, errors.WithMessage(err, "failed to parse env file")
	}

	lines := make([]string, 0, len(envs))
	for key, value := range envs {
		lines = append(lines, key+"="+value)
	}

	// Map iteration order is random; sort so the container's Env is
	// reproducible from one run to the next.
	slices.Sort(lines)

	return lines, nil
}

func myPublicIP() (string, error) {
resp, err := http.Get(url)
if err != nil {
Expand Down Expand Up @@ -152,8 +185,8 @@ func exitIfError(flag string, err error) {
func nothingIfError(flag string, err error) {}

func ParseOrNil[T ~string | ~int | ~[]string](cmd *cobra.Command, flag string) *T {
// TODO: buddy, need to fix this
got, ok := parseOrExitInternal[T](cmd, flag, false)
// TODO: buddy, need to fix this
got, ok := parseOrExitInternal[T](cmd, flag, false)
if !ok {
return nil
}
Expand Down
Loading

0 comments on commit 7a192c0

Please sign in to comment.