Skip to content

Commit

Permalink
refactor(diffusers): add cache and fix copy code
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemarsden authored and philwinder committed Nov 27, 2024
1 parent 8b2c3c4 commit d2210b7
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 33 deletions.
40 changes: 25 additions & 15 deletions api/cmd/helix/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,27 +306,37 @@ func runnerCLI(cmd *cobra.Command, options *RunnerOptions) error {
// inbuiltModelsDirectory is the directory inside the Docker image that can hold
// a cache of models that were already downloaded during the build process.
// These files need to be copied into the runner cache dir.
const inbuiltModelsDirectory = "/workspace/ollama"
var bakedModelDirectories = []string{"/workspace/ollama", "/workspace/diffusers"}

func initializeModelsCache(cfg *config.RunnerConfig) error {
log.Info().Msgf("Copying baked models from %s into cache dir %s - this may take a while...", inbuiltModelsDirectory, cfg.CacheDir)
_, err := os.Stat(inbuiltModelsDirectory)
if err != nil {
if os.IsNotExist(err) {
// If the directory doesn't exist, nothing to do
return nil
log.Info().Msgf("Copying baked models from %v into container cache dir %s - this may take a while the first time...", bakedModelDirectories, cfg.CacheDir)

for _, dir := range bakedModelDirectories {
// If the directory doesn't exist, nothing to do
_, err := os.Stat(dir)
if err != nil {
if os.IsNotExist(err) {
log.Debug().Msgf("Baked models directory %s does not exist", dir)
continue
}
return fmt.Errorf("error checking inbuilt models directory: %w", err)
}
return fmt.Errorf("error checking inbuilt models directory: %w", err)
}

// Check if the cache dir exists, if not create it
if _, err := os.Stat(cfg.CacheDir); os.IsNotExist(err) {
err = os.MkdirAll(cfg.CacheDir, 0755)
// Check if the cache dir exists, if not create it
if _, err := os.Stat(cfg.CacheDir); os.IsNotExist(err) {
err = os.MkdirAll(cfg.CacheDir, 0755)
if err != nil {
return fmt.Errorf("error creating cache dir: %w", err)
}
}

// Copy the directory from the Docker image into the cache dir
log.Debug().Msgf("Copying %s into container dir %s", dir, cfg.CacheDir)
err = copydir.CopyDir(cfg.CacheDir, dir)
if err != nil {
return err
return fmt.Errorf("error copying inbuilt models directory: %w", err)
}
}

return copydir.CopyDir(cfg.CacheDir, inbuiltModelsDirectory)

return nil
}
3 changes: 2 additions & 1 deletion api/pkg/runner/diffusers_model_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,10 @@ func (i *DiffusersModelInstance) Start(ctx context.Context) error {
cmd.Dir = "/workspace/helix/runner/helix-diffusers"

cmd.Env = append(cmd.Env,
fmt.Sprintf("CACHE_DIR=%s", path.Join(i.runnerOptions.Config.CacheDir, "hub")), // Mimic the diffusers library's default cache dir
fmt.Sprintf("MODEL_ID=%s", i.initialSession.ModelName),
// Add the HF_TOKEN environment variable which is required by the diffusers library
fmt.Sprintf("HF_TOKEN=[REDACTED_TOKEN]KEN=%s", os.Getenv("HF_TOKEN")),
fmt.Sprintf("HF_TOKEN=%s", os.Getenv("HF_TOKEN")),
// Set python to be unbuffered so we get logs in real time
"PYTHONUNBUFFERED=1",
)
Expand Down
31 changes: 16 additions & 15 deletions api/pkg/util/copydir/copy_dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,6 @@ func CopyDir(dst, src string) error {
// destination with the path without the src on it.
dstPath := filepath.Join(dst, path[len(src):])

// If dstPath exists and has the same size as path, don't copy it again.
// We're mainly copying content addressed blobs here, so this is
// probably fine.
dstInfo, err := os.Stat(dstPath)
if err == nil && dstInfo.Size() == info.Size() {
return nil
}

// we don't want to try and copy the same file over itself.
if eq, err := SameFile(path, dstPath); eq {
return nil
} else if err != nil {
return err
}

// If we have a directory, make that subdirectory, then continue
// the walk.
if info.IsDir() {
Expand All @@ -65,6 +50,22 @@ func CopyDir(dst, src string) error {
return nil
}

// If dstPath exists and has the same size as path, don't copy it again.
// We're mainly copying content addressed blobs here, so this is
// probably fine.
// Must use Lstat to get the file status here in case the file is a symlink
dstInfo, err := os.Lstat(dstPath)
if err == nil && dstInfo.Size() == info.Size() {
return nil
}

// we don't want to try and copy the same file over itself.
if eq, err := SameFile(path, dstPath); eq {
return nil
} else if err != nil {
return err
}

// If the current path is a symlink, recreate the symlink relative to
// the dst directory
if info.Mode()&os.ModeSymlink == os.ModeSymlink {
Expand Down
1 change: 0 additions & 1 deletion docker-compose.runner.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
version: '3'
services:
runner:
restart: always
Expand Down
9 changes: 8 additions & 1 deletion runner/helix-diffusers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
server_port = int(os.getenv("SERVER_PORT", 8000))
server_url = f"http://{server_host}:{server_port}"
model_id = os.getenv("MODEL_ID", "stabilityai/sd-turbo")
cache_dir = os.getenv("CACHE_DIR", "/root/.cache/huggingface/hub")

# Check that the cache dir exists
if not os.path.exists(cache_dir):
raise RuntimeError(f"Cache directory {cache_dir} does not exist")


class TextToImageInput(BaseModel):
Expand All @@ -39,7 +44,7 @@ def __init__(self):
logging.info("Pipeline instance created")

def start(self, model_id: str):
logging.info(f"Starting pipeline with model: {model_id}")
logging.info(f"Starting pipeline for model {model_id}, cache dir: {cache_dir}")
try:
if torch.cuda.is_available():
logger.info("Loading CUDA")
Expand All @@ -48,6 +53,7 @@ def start(self, model_id: str):
model_id,
torch_dtype=torch.bfloat16,
local_files_only=True,
cache_dir=cache_dir,
).to(device=self.device)
elif torch.backends.mps.is_available():
logger.info("Loading MPS for Mac M Series")
Expand All @@ -56,6 +62,7 @@ def start(self, model_id: str):
model_id,
torch_dtype=torch.bfloat16,
local_files_only=True,
cache_dir=cache_dir,
).to(device=self.device)
else:
raise Exception("No CUDA or MPS device available")
Expand Down

0 comments on commit d2210b7

Please sign in to comment.