diff --git a/zetta_utils/cloud/execution_tracker.py b/zetta_utils/cloud/execution_tracker.py index b87409153..b09136b2e 100644 --- a/zetta_utils/cloud/execution_tracker.py +++ b/zetta_utils/cloud/execution_tracker.py @@ -1,6 +1,7 @@ import json import os import time +from contextlib import contextmanager from datetime import datetime from enum import Enum from typing import Mapping @@ -9,6 +10,7 @@ import fsspec from cloudfiles import paths +from zetta_utils.common import RepeatTimer from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer from zetta_utils.layer.db_layer.datastore import DatastoreBackend from zetta_utils.log import get_logger @@ -110,3 +112,18 @@ def record_execution_run(execution_id: str) -> None: # pragma: no cover with fsspec.open(info_path, "w") as f: json.dump(execution_run, f, indent=2) + + +@contextmanager +def heartbeat_tracking_ctx_mngr(execution_id, heartbeat_interval=30): + def _send_heartbeat(): + update_execution_heartbeat(execution_id) + + heart = RepeatTimer(heartbeat_interval, _send_heartbeat) + heart.start() + try: + yield + except Exception as e: + raise e from None + finally: + heart.cancel() diff --git a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py index b7fd35283..5c303950c 100644 --- a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py +++ b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py @@ -2,12 +2,11 @@ import copy import os -from contextlib import AbstractContextManager, ExitStack, contextmanager +from contextlib import AbstractContextManager, ExitStack from typing import Dict, Final, Iterable, Optional, Union from zetta_utils import builder, log, mazepa from zetta_utils.cloud import execution_tracker, resource_allocation -from zetta_utils.common import RepeatTimer logger = log.get_logger("zetta_utils") @@ -96,21 +95,6 @@ def get_gcp_with_sqs_config( return exec_queue, ctx_managers -@contextmanager -def heartbeat_tracking_ctx_mngr(execution_id, heartbeat_interval=30): - def _send_heartbeat(): - execution_tracker.update_execution_heartbeat(execution_id) - - heart = RepeatTimer(heartbeat_interval, _send_heartbeat) - heart.start() - try: - yield - except Exception as e: - raise e from None - finally: - heart.cancel() - - @builder.register("mazepa.execute_on_gcp_with_sqs") def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals target: Union[mazepa.Flow, mazepa.ExecutionState], @@ -174,7 +158,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals ) with ExitStack() as stack: - stack.enter_context(heartbeat_tracking_ctx_mngr(execution_id)) + stack.enter_context(execution_tracker.heartbeat_tracking_ctx_mngr(execution_id)) for mngr in ctx_managers: stack.enter_context(mngr) diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index ab97df5a7..bcbc86590 100644 --- a/zetta_utils/training/lightning/train.py +++ b/zetta_utils/training/lightning/train.py @@ -12,7 +12,7 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem from kubernetes import client as k8s_client # type: ignore -from zetta_utils import builder, log, parsing +from zetta_utils import builder, log, mazepa, parsing from zetta_utils.cloud import resource_allocation DEFAULT_GCP_CLUSTER_NAME: Final = "zutils-x3" @@ -88,7 +88,11 @@ def lightning_train( @builder.register("lightning_train_remote") @typeguard.typechecked def lightning_train_remote(image: str, resources: dict, spec_path: str) -> None: - execution_id = "test" + + execution_id = mazepa.id_generation.get_unique_id( + prefix="exec", slug_len=4, add_uuid=False, max_len=50 + ) + spec = parsing.cue.load(spec_path) configmap = resource_allocation.k8s.get_configmap( name=execution_id, data={"spec.cue": json.dumps(spec)}