Skip to content

Commit

Permalink
add metadata and versioning with sha
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov committed Feb 14, 2024
1 parent 3966534 commit 08c0d73
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 3 deletions.
11 changes: 9 additions & 2 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ jobs:
ZENML_API_KEY: ${{ secrets.ZENML_API_KEY }}
ZENML_STAGING_STACK: ${{ secrets.ZENML_STAGING_STACK }}
ZENML_PRODUCTION_STACK: ${{ secrets.ZENML_PRODUCTION_STACK }}
ZENML_GITHUB_SHA: ${{ github.event.pull_request.head.sha }}
ZENML_GITHUB_URL_PR: ${{ github.event.pull_request._links.html.href }}
ZENML_DEBUG: true
ZENML_ANALYTICS_OPT_IN: false
ZENML_LOGGING_VERBOSITY: INFO
Expand Down Expand Up @@ -60,14 +62,19 @@ jobs:
run: |
python run.py \
--pipeline train \
--dataset staging
--dataset staging \
--version ${{ env.ZENML_GITHUB_SHA }} \
--github-pr-url ${{ env.ZENML_GITHUB_URL_PR }}
- name: Run pipeline (Production)
if: ${{ github.base_ref == 'main' }}
run: |
python run.py \
--pipeline end-to-end \
--dataset production
--dataset production \
--version ${{ env.ZENML_GITHUB_SHA }} \
--github-pr-url ${{ env.ZENML_GITHUB_URL_PR }}
- name: Read training report
id: report
Expand Down
3 changes: 3 additions & 0 deletions pipelines/end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
data_loader,
data_splitter,
decision_tree_trainer,
metadata_logger,
model_evaluator,
model_scorer,
model_train_reference_appraiser,
Expand Down Expand Up @@ -51,9 +52,11 @@ def gitflow_end_to_end_pipeline(
ignore_reference_model: bool = False,
max_train_accuracy_diff: float = 0.1,
max_test_accuracy_diff: float = 0.05,
github_pr_url: Optional[str] = None,
):
"""Train and serve a new model if it performs better than the model
currently served."""
metadata_logger(github_pr_url=github_pr_url)

data = data_loader(version=dataset_version)
served_model = served_model_loader(
Expand Down
3 changes: 3 additions & 0 deletions pipelines/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
data_loader,
data_splitter,
decision_tree_trainer,
metadata_logger,
model_evaluator,
model_scorer,
model_train_appraiser,
Expand All @@ -47,8 +48,10 @@ def gitflow_training_pipeline(
ignore_reference_model: bool = False,
max_train_accuracy_diff: float = 0.1,
max_test_accuracy_diff: float = 0.05,
github_pr_url: Optional[str] = None,
):
"""Pipeline that trains and evaluates a new model."""
metadata_logger(github_pr_url=github_pr_url)
data = data_loader(version=dataset_version)
data_integrity_report = data_integrity_checker(dataset=data)
train_dataset, test_dataset = data_splitter(
Expand Down
23 changes: 22 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def main(
ignore_checks: bool = False,
model_name: str = "model",
dataset_version: Optional[str] = None,
version=None,
github_pr_url=None,
):
"""Main runner for all pipelines.
Expand All @@ -70,7 +72,7 @@ def main(
pipeline_args = {}
if disable_caching:
pipeline_args["enable_cache"] = False
pipeline_args["model"] = Model(name=MODEL_NAME)
pipeline_args["model"] = Model(name=MODEL_NAME, version=version)

docker_settings = DockerSettings(
install_stack_requirements=False,
Expand Down Expand Up @@ -98,6 +100,7 @@ def main(
ignore_model_evaluation_failures=ignore_checks,
ignore_reference_model=ignore_checks,
max_depth=5,
github_pr_url=github_pr_url,
)

if pipeline_name == Pipeline.TRAIN:
Expand Down Expand Up @@ -193,6 +196,22 @@ def main(
action="store_true",
required=False,
)
parser.add_argument(
"-gp",
"--github-pr-url",
default=None,
help="GitHub PR URL",
type=str,
required=False,
)
parser.add_argument(
"-v",
"--version",
default=None,
help="Model Version to create.",
type=str,
required=False,
)
args = parser.parse_args()

assert args.pipeline in [
Expand All @@ -207,4 +226,6 @@ def main(
ignore_checks=args.ignore_checks,
model_name=args.model,
dataset_version=args.dataset,
version=args.version,
github_pr_url=args.github_pr_url,
)
1 change: 1 addition & 0 deletions steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .data_loaders import data_loader, data_splitter
from .data_validators import data_drift_detector, data_integrity_checker
from .metadata_logger import metadata_logger
from .model_appraisers import (
model_train_appraiser,
model_train_reference_appraiser,
Expand Down
17 changes: 17 additions & 0 deletions steps/metadata_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Optional

from zenml import get_step_context, log_model_metadata, step


@step(enable_cache=False)
def metadata_logger(github_pr_url: Optional[str] = None):
model = get_step_context().model
if not model.version.isnumeric():
log_model_metadata(
{
"GitHub commit": (
f"https://github.com/zenml-io/zenml-gitflow/commit/{model.version}"
),
"GitHub PullRequest": github_pr_url,
}
)

0 comments on commit 08c0d73

Please sign in to comment.