Skip to content

Commit

Permalink
fix(lightning_train_remote): accept dict / PartialBuilder spec
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Sep 6, 2023
1 parent 1b4ac20 commit e118ae3
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions zetta_utils/training/lightning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from kubernetes import client as k8s_client # type: ignore
from zetta_utils import builder, load_all_modules, log, mazepa, parsing
from zetta_utils.builder.build import BuilderPartial
from zetta_utils.cloud_management import execution_tracker, resource_allocation

logger = log.get_logger("zetta_utils")
Expand Down Expand Up @@ -275,7 +276,7 @@ def _create_ddp_master_job(
def lightning_train_remote(
worker_image: str,
worker_resources: dict,
spec_path: str,
spec_path: str | dict | BuilderPartial,
num_nodes: int = 1,
env_vars: Optional[Dict[str, str]] = None,
worker_cluster_name: Optional[str] = None,
Expand All @@ -292,7 +293,12 @@ def lightning_train_remote(
execution_id = mazepa.id_generation.get_unique_id(
prefix="exec", slug_len=4, add_uuid=False, max_len=50
)
spec = parsing.cue.load(spec_path)
if isinstance(spec_path, str):
spec = parsing.cue.load(spec_path)
elif isinstance(spec_path, dict):
spec = spec_path
elif isinstance(spec_path, BuilderPartial):
spec = spec_path.spec

_create_ddp_master_job(
execution_id,
Expand Down

0 comments on commit e118ae3

Please sign in to comment.