Skip to content

Commit

Permalink
fix: heartbeat and id gen
Browse files Browse the repository at this point in the history
  • Loading branch information
akhileshh committed Jul 28, 2023
1 parent 495c183 commit dcef011
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 20 deletions.
17 changes: 17 additions & 0 deletions zetta_utils/cloud/execution_tracker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import time
from contextlib import contextmanager
from datetime import datetime
from enum import Enum
from typing import Mapping
Expand All @@ -9,6 +10,7 @@
import fsspec
from cloudfiles import paths

from zetta_utils.common import RepeatTimer
from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
from zetta_utils.layer.db_layer.datastore import DatastoreBackend
from zetta_utils.log import get_logger
Expand Down Expand Up @@ -110,3 +112,18 @@ def record_execution_run(execution_id: str) -> None: # pragma: no cover

with fsspec.open(info_path, "w") as f:
json.dump(execution_run, f, indent=2)


@contextmanager
def heartbeat_tracking_ctx_mngr(execution_id, heartbeat_interval=30):
def _send_heartbeat():
update_execution_heartbeat(execution_id)

Check warning on line 120 in zetta_utils/cloud/execution_tracker.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/cloud/execution_tracker.py#L119-L120

Added lines #L119 - L120 were not covered by tests

heart = RepeatTimer(heartbeat_interval, _send_heartbeat)
heart.start()
try:
yield
except Exception as e:
raise e from None

Check warning on line 127 in zetta_utils/cloud/execution_tracker.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/cloud/execution_tracker.py#L122-L127

Added lines #L122 - L127 were not covered by tests
finally:
heart.cancel()

Check warning on line 129 in zetta_utils/cloud/execution_tracker.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/cloud/execution_tracker.py#L129

Added line #L129 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import copy
import os
from contextlib import AbstractContextManager, ExitStack, contextmanager
from contextlib import AbstractContextManager, ExitStack
from typing import Dict, Final, Iterable, Optional, Union

from zetta_utils import builder, log, mazepa
from zetta_utils.cloud import execution_tracker, resource_allocation
from zetta_utils.common import RepeatTimer

logger = log.get_logger("zetta_utils")

Expand Down Expand Up @@ -96,21 +95,6 @@ def get_gcp_with_sqs_config(
return exec_queue, ctx_managers


@contextmanager
def heartbeat_tracking_ctx_mngr(execution_id, heartbeat_interval=30):
def _send_heartbeat():
execution_tracker.update_execution_heartbeat(execution_id)

heart = RepeatTimer(heartbeat_interval, _send_heartbeat)
heart.start()
try:
yield
except Exception as e:
raise e from None
finally:
heart.cancel()


@builder.register("mazepa.execute_on_gcp_with_sqs")
def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
target: Union[mazepa.Flow, mazepa.ExecutionState],
Expand Down Expand Up @@ -174,7 +158,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
)

with ExitStack() as stack:
stack.enter_context(heartbeat_tracking_ctx_mngr(execution_id))
stack.enter_context(execution_tracker.heartbeat_tracking_ctx_mngr(execution_id))
for mngr in ctx_managers:
stack.enter_context(mngr)

Expand Down
8 changes: 6 additions & 2 deletions zetta_utils/training/lightning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytorch_lightning.utilities.cloud_io import get_filesystem

from kubernetes import client as k8s_client # type: ignore
from zetta_utils import builder, log, parsing
from zetta_utils import builder, log, mazepa, parsing
from zetta_utils.cloud import resource_allocation

DEFAULT_GCP_CLUSTER_NAME: Final = "zutils-x3"
Expand Down Expand Up @@ -88,7 +88,11 @@ def lightning_train(
@builder.register("lightning_train_remote")
@typeguard.typechecked
def lightning_train_remote(image: str, resources: dict, spec_path: str) -> None:
execution_id = "test"

execution_id = mazepa.id_generation.get_unique_id(
prefix="exec", slug_len=4, add_uuid=False, max_len=50
)

spec = parsing.cue.load(spec_path)
configmap = resource_allocation.k8s.get_configmap(
name=execution_id, data={"spec.cue": json.dumps(spec)}
Expand Down

0 comments on commit dcef011

Please sign in to comment.