Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support Prefect Workers #41

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion block_cascade/executors/vertex/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def create_job(self) -> VertexJob:
raise RuntimeError(
f"Unable to parse bucket from storage block: {storage}"
)
deployment_path = deployment.path.rstrip("/")
deployment_path = storage.data.get("bucket_folder").rstrip("/") or deployment.path.rstrip("/")

package_path = f"{bucket}/{deployment_path}/{module_name}"
self._logger.info(
Expand Down
15 changes: 13 additions & 2 deletions block_cascade/prefect/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ async def _fetch_block(block_id: str) -> Optional[BlockDocument]:
async with get_client() as client:
return await client.read_block_document(block_id)

async def _fetch_block_by_name(block_name: str, block_type_slug: str = "gcs-bucket") -> Optional[BlockDocument]:
async with get_client() as client:
return await client.read_block_document_by_name(
name=block_name,
block_type_slug=block_type_slug,
)

def get_from_prefect_context(attr: str, default: str = "") -> str:
flow_context = FlowRunContext.get()
Expand Down Expand Up @@ -80,8 +86,13 @@ def get_storage_block() -> Optional[BlockDocument]:

global _CACHED_STORAGE # noqa: PLW0603
if not _CACHED_STORAGE:
_CACHED_STORAGE = run_async(
_fetch_block(current_deployment.storage_document_id)
if current_deployment.pull_steps:
_CACHED_STORAGE = run_async(
_fetch_block_by_name(block_name=current_deployment.pull_steps[0]["prefect.deployments.steps.pull_with_block"]["block_document_name"])
)
else:
_CACHED_STORAGE = run_async(
_fetch_block(block_id=current_deployment.storage_document_id)
)
return _CACHED_STORAGE

Expand Down
68 changes: 46 additions & 22 deletions block_cascade/prefect/v2/environment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, Optional

from prefect.context import FlowRunContext

Expand All @@ -24,47 +24,71 @@ class PrefectEnvironmentClient(VertexAIEnvironmentInfoProvider):

def __init__(self):
self._current_deployment = None
self._current_job_variables = None
self._current_infrastructure = None

def get_container_image(self) -> Optional[str]:
infra = self._get_infrastructure_block()
if not infra:
return
job_variables = self._get_job_variables()
if job_variables:
return job_variables.get("image")

deployment_details = infra.data
return deployment_details.get("image")
infra = self._get_infrastructure_block()
if infra:
return infra.data.get("image")
return None

def get_network(self) -> Optional[str]:
job_variables = self._get_job_variables()
if job_variables:
return job_variables.get("network")

infra = self._get_infrastructure_block()
if not infra:
return
if infra:
return infra.data.get("network")

deployment_details = infra.data
return deployment_details.get("network")
return None

def get_project(self) -> Optional[str]:
job_variables = self._get_job_variables()
if job_variables:
return job_variables.get("credentials", {}).get("project")

infra = self._get_infrastructure_block()
if not infra:
return
if infra:
return infra.data.get("gcp_credentials", {}).get("project")

deployment_details = infra.data
return deployment_details.get("gcp_credentials", {}).get("project")
return None

def get_region(self) -> Optional[str]:
job_variables = self._get_job_variables()
if job_variables:
return job_variables.get("region")

infra = self._get_infrastructure_block()
if not infra:
return
if infra:
return infra.data.get("region")

deployment_details = infra.data
return deployment_details.get("region")
return None

def get_service_account(self) -> Optional[str]:
job_variables = self._get_job_variables()
if job_variables:
return job_variables.get("service_account_name")

infra = self._get_infrastructure_block()
if not infra:
return
if infra:
return infra.data.get("service_account")

return None

def _get_job_variables(self) -> Optional[Dict]:
current_deployment = self._get_current_deployment()
if not current_deployment:
return None

deployment_details = infra.data
return deployment_details.get("service_account")
if not self._current_job_variables:
self._current_job_variables = current_deployment.job_variables
return self._current_job_variables

def _get_infrastructure_block(self) -> Optional[BlockDocument]:
current_deployment = self._get_current_deployment()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "block-cascade"
packages = [
{include = "block_cascade"}
]
version = "2.6.0"
version = "2.6.1"
description = "Library for model training in multi-cloud environment."
readme = "README.md"
authors = ["Block"]
Expand Down
89 changes: 89 additions & 0 deletions tests/test_prefect_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
from unittest.mock import Mock, patch

from prefect.client.schemas.responses import DeploymentResponse
from prefect.context import FlowRunContext

from block_cascade.prefect.v2.environment import PrefectEnvironmentClient


@pytest.fixture(autouse=True)
def mock_infrastructure_block():
infra_block = Mock()
infra_block.data = {
"image": "infra_image",
"network": "infra_network",
"gcp_credentials": {"project": "infra_project"},
"region": "infra_region",
"service_account": "infra_service_account"
}
with patch("block_cascade.prefect.v2.environment._fetch_block", return_value=infra_block):
yield infra_block

@pytest.fixture
def mock_job_variables():
return {
"image": "job_image",
"network": "job_network",
"credentials": {"project": "job_project"},
"region": "job_region",
"service_account_name": "job_service_account"
}

@pytest.fixture
def mock_deployment_response(mock_job_variables):
mock_deployment = Mock(spec=DeploymentResponse)
mock_deployment.job_variables = mock_job_variables
mock_deployment.infrastructure_document_id = "mock_infrastructure_id"
return mock_deployment

@pytest.fixture(autouse=True)
def mock__fetch_deployment(mock_deployment_response):
with patch("block_cascade.prefect.v2.environment._fetch_deployment", return_value=mock_deployment_response):
yield

@pytest.fixture(autouse=True)
def mock_flow_run_context():
mock_flow_run = Mock()
mock_flow_run.deployment_id = "mock_deployment_id"

mock_context = Mock(spec=FlowRunContext)
mock_context.flow_run = mock_flow_run

with patch("block_cascade.prefect.v2.environment.FlowRunContext.get", return_value=mock_context):
yield mock_context

def test_get_container_image():
client = PrefectEnvironmentClient()

assert client.get_container_image() == "job_image"

def test_get_network():
client = PrefectEnvironmentClient()

assert client.get_network() == "job_network"

def test_get_project():
client = PrefectEnvironmentClient()

assert client.get_project() == "job_project"

def test_get_region():
client = PrefectEnvironmentClient()

assert client.get_region() == "job_region"

def test_get_service_account():
client = PrefectEnvironmentClient()

assert client.get_service_account() == "job_service_account"

def test_fallback_to_infrastructure(mock_deployment_response):
client = PrefectEnvironmentClient()
mock_deployment_response.job_variables = None

assert client.get_container_image() == "infra_image"
assert client.get_network() == "infra_network"
assert client.get_project() == "infra_project"
assert client.get_region() == "infra_region"
assert client.get_service_account() == "infra_service_account"
Loading