Skip to content

Commit

Permalink
feat: pod_failure_policy to ignore disruption
Browse files (browse the repository at this point in the history)
  • Loading branch information
nkemnitz committed Sep 6, 2023
1 parent e118ae3 commit 541552e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
6 changes: 6 additions & 0 deletions zetta_utils/cloud_management/resource_allocation/k8s/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ def _get_job_spec(
meta: k8s_client.V1ObjectMeta,
active_deadline_seconds: Optional[int] = None,
backoff_limit: Optional[int] = 3,
pod_failure_policy: Optional[k8s_client.V1PodFailurePolicy] = None,
selector: Optional[k8s_client.V1LabelSelector] = None,
suspend: Optional[bool] = False,
):
pod_template = k8s_client.V1PodTemplateSpec(metadata=meta, spec=pod_spec)
return k8s_client.V1JobSpec(
active_deadline_seconds=active_deadline_seconds,
backoff_limit=backoff_limit,
pod_failure_policy=pod_failure_policy,
selector=selector,
suspend=suspend,
template=pod_template,
Expand All @@ -44,6 +46,7 @@ def get_job_template(
pod_spec: k8s_client.V1PodSpec,
active_deadline_seconds: Optional[int] = None,
backoff_limit: Optional[int] = 3,
pod_failure_policy: Optional[k8s_client.V1PodFailurePolicy] = None,
labels: Optional[Dict[str, str]] = None,
selector: Optional[k8s_client.V1LabelSelector] = None,
suspend: Optional[bool] = False,
Expand All @@ -54,6 +57,7 @@ def get_job_template(
meta=meta,
active_deadline_seconds=active_deadline_seconds,
backoff_limit=backoff_limit,
pod_failure_policy=pod_failure_policy,
selector=selector,
suspend=suspend,
)
Expand All @@ -65,6 +69,7 @@ def get_job(
pod_spec: k8s_client.V1PodSpec,
active_deadline_seconds: Optional[int] = None,
backoff_limit: Optional[int] = 3,
pod_failure_policy: Optional[k8s_client.V1PodFailurePolicy] = None,
labels: Optional[Dict[str, str]] = None,
selector: Optional[k8s_client.V1LabelSelector] = None,
suspend: Optional[bool] = False,
Expand All @@ -75,6 +80,7 @@ def get_job(
meta=meta,
active_deadline_seconds=active_deadline_seconds,
backoff_limit=backoff_limit,
pod_failure_policy=pod_failure_policy,
selector=selector,
suspend=suspend,
)
Expand Down
14 changes: 13 additions & 1 deletion zetta_utils/training/lightning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,19 @@ def _create_ddp_master_job(
volume_mounts=mounts,
)

- train_job = resource_allocation.k8s.get_job(execution_id, pod_spec=train_pod_spec)
+ train_job_failure_policy = k8s_client.V1PodFailurePolicy(
+ rules=[
+ k8s_client.V1PodFailurePolicyRule(
+ action="Ignore",
+ on_pod_conditions=[
+ k8s_client.V1PodFailurePolicyOnPodConditionsPattern(status="True", type="DisruptionTarget")
+ ],
+ )
+ ]
+ )
+ train_job = resource_allocation.k8s.get_job(
+ execution_id, pod_spec=train_pod_spec, pod_failure_policy=train_job_failure_policy
+ )
train_job_ctx = resource_allocation.k8s.job_ctx_manager(
execution_id=execution_id,
cluster_info=cluster_info,
Expand Down

0 comments on commit 541552e

Please sign in to comment.