Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov committed Aug 16, 2024
1 parent 2b64250 commit 947f8ff
Showing 1 changed file with 47 additions and 43 deletions.
90 changes: 47 additions & 43 deletions template/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
" # Pull required modules from this example\n",
" !git clone -b main https://github.com/zenml-io/zenml\n",
" !cp -r zenml/examples/quickstart/* .\n",
" !rm -rf zenml\n"
" !rm -rf zenml"
]
},
{
Expand All @@ -84,6 +84,7 @@
"!zenml integration install sklearn -y\n",
"\n",
"import IPython\n",
"\n",
"IPython.Application.instance().kernel.do_shutdown(restart=True)"
]
},
Expand Down Expand Up @@ -145,28 +146,22 @@
"outputs": [],
"source": [
"# Do the imports at the top\n",
"from typing_extensions import Annotated\n",
"from sklearn.datasets import load_breast_cancer\n",
"\n",
"import random\n",
"import pandas as pd\n",
"from zenml import step, pipeline, Model, get_step_context\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"from typing import List, Optional\n",
"from uuid import UUID\n",
"\n",
"from typing import Optional, List\n",
"\n",
"from zenml import pipeline\n",
"\n",
"import pandas as pd\n",
"from sklearn.datasets import load_breast_cancer\n",
"from steps import (\n",
" data_loader,\n",
" data_preprocessor,\n",
" data_splitter,\n",
" inference_preprocessor,\n",
" model_evaluator,\n",
" inference_preprocessor\n",
")\n",
"\n",
"from typing_extensions import Annotated\n",
"from zenml import Model, get_step_context, pipeline, step\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"\n",
"logger = get_logger(__name__)\n",
Expand Down Expand Up @@ -205,7 +200,7 @@
"@step\n",
"def data_loader_simplified(\n",
" random_state: int, is_inference: bool = False, target: str = \"target\"\n",
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset \n",
") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset\n",
" \"\"\"Dataset reader step.\"\"\"\n",
" dataset = load_breast_cancer(as_frame=True)\n",
" inference_size = int(len(dataset.target) * 0.05)\n",
Expand All @@ -218,7 +213,7 @@
" dataset.drop(inference_subset.index, inplace=True)\n",
" dataset.reset_index(drop=True, inplace=True)\n",
" logger.info(f\"Dataset with {len(dataset)} records loaded!\")\n",
" return dataset\n"
" return dataset"
]
},
{
Expand Down Expand Up @@ -291,7 +286,7 @@
" normalize: Optional[bool] = None,\n",
" drop_columns: Optional[List[str]] = None,\n",
" target: Optional[str] = \"target\",\n",
" random_state: int = 17\n",
" random_state: int = 17,\n",
"):\n",
" \"\"\"Feature engineering pipeline.\"\"\"\n",
" # Link all the steps together by calling them and passing the output\n",
Expand Down Expand Up @@ -402,7 +397,6 @@
"from zenml.environment import Environment\n",
"from zenml.zen_stores.rest_zen_store import RestZenStore\n",
"\n",
"\n",
"if not isinstance(client.zen_store, RestZenStore):\n",
" # Only spin up a local Dashboard in case you aren't already connected to a remote server\n",
" if Environment.in_google_colab():\n",
Expand Down Expand Up @@ -479,7 +473,9 @@
"outputs": [],
"source": [
"# Get artifact version from our run\n",
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\"dataset_trn\"] \n",
"dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\n",
" \"dataset_trn\"\n",
"]\n",
"\n",
"# Get latest version from client directly\n",
"dataset_trn_artifact_version = client.get_artifact_version(\"dataset_trn\")\n",
Expand All @@ -498,7 +494,9 @@
"source": [
"# Fetch the rest of the artifacts\n",
"dataset_tst_artifact_version = client.get_artifact_version(\"dataset_tst\")\n",
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\"preprocess_pipeline\")"
"preprocessing_pipeline_artifact_version = client.get_artifact_version(\n",
" \"preprocess_pipeline\"\n",
")"
]
},
{
Expand Down Expand Up @@ -576,23 +574,25 @@
"def model_trainer(\n",
" dataset_trn: pd.DataFrame,\n",
" model_type: str = \"sgd\",\n",
") -> Annotated[ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)]:\n",
") -> Annotated[\n",
" ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)\n",
"]:\n",
" \"\"\"Configure and train a model on the training dataset.\"\"\"\n",
" target = \"target\"\n",
" if model_type == \"sgd\":\n",
" model = SGDClassifier()\n",
" elif model_type == \"rf\":\n",
" model = RandomForestClassifier()\n",
" else:\n",
" raise ValueError(f\"Unknown model type {model_type}\") \n",
" raise ValueError(f\"Unknown model type {model_type}\")\n",
"\n",
" logger.info(f\"Training model {model}...\")\n",
"\n",
" model.fit(\n",
" dataset_trn.drop(columns=[target]),\n",
" dataset_trn[target],\n",
" )\n",
" return model\n"
" return model"
]
},
{
Expand Down Expand Up @@ -630,14 +630,14 @@
" min_train_accuracy: float = 0.0,\n",
" min_test_accuracy: float = 0.0,\n",
"):\n",
" \"\"\"Model training pipeline.\"\"\" \n",
" \"\"\"Model training pipeline.\"\"\"\n",
" if train_dataset_id is None or test_dataset_id is None:\n",
" # If we don't pass the IDs, this will run the feature engineering pipeline \n",
" # If we don't pass the IDs, this will run the feature engineering pipeline\n",
" dataset_trn, dataset_tst = feature_engineering()\n",
" else:\n",
" # Load the datasets from an older pipeline\n",
" dataset_trn = client.get_artifact_version(name_id_or_prefix=train_dataset_id)\n",
" dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id) \n",
" dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id)\n",
"\n",
" trained_model = model_trainer(\n",
" dataset_trn=dataset_trn,\n",
Expand Down Expand Up @@ -676,7 +676,7 @@
"training(\n",
" model_type=\"rf\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")\n",
"\n",
"rf_run = client.get_pipeline(\"training\").last_run"
Expand All @@ -693,7 +693,7 @@
"sgd_run = training(\n",
" model_type=\"sgd\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")\n",
"\n",
"sgd_run = client.get_pipeline(\"training\").last_run"
Expand All @@ -717,7 +717,9 @@
"outputs": [],
"source": [
"# The evaluator returns a float value with the accuracy\n",
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\"model_evaluator\"].output.load()"
"rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\n",
" \"model_evaluator\"\n",
"].output.load()"
]
},
{
Expand Down Expand Up @@ -776,7 +778,7 @@
"training_configured(\n",
" model_type=\"sgd\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")"
]
},
Expand All @@ -798,7 +800,7 @@
"training_configured(\n",
" model_type=\"rf\",\n",
" train_dataset_id=dataset_trn_artifact_version.id,\n",
" test_dataset_id=dataset_tst_artifact_version.id\n",
" test_dataset_id=dataset_tst_artifact_version.id,\n",
")"
]
},
Expand Down Expand Up @@ -848,7 +850,9 @@
"rf_zenml_model_version = client.get_model_version(\"breast_cancer_classifier\", \"rf\")\n",
"\n",
"# We can now load our classifier directly as well\n",
"random_forest_classifier = rf_zenml_model_version.get_artifact(\"sklearn_classifier\").load()\n",
"random_forest_classifier = rf_zenml_model_version.get_artifact(\n",
" \"sklearn_classifier\"\n",
").load()\n",
"\n",
"random_forest_classifier"
]
Expand Down Expand Up @@ -956,7 +960,7 @@
"\n",
" predictions = pd.Series(predictions, name=\"predicted\")\n",
"\n",
" return predictions\n"
" return predictions"
]
},
{
Expand All @@ -983,18 +987,18 @@
" random_state = 42\n",
" target = \"target\"\n",
"\n",
" df_inference = data_loader(\n",
" random_state=random_state, is_inference=True\n",
" )\n",
" df_inference = data_loader(random_state=random_state, is_inference=True)\n",
" df_inference = inference_preprocessor(\n",
" dataset_inf=df_inference,\n",
" # We use the preprocess pipeline from the feature engineering pipeline\n",
" preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n",
" preprocess_pipeline=client.get_artifact_version(\n",
" name_id_or_prefix=preprocess_pipeline_id\n",
" ),\n",
" target=target,\n",
" )\n",
" inference_predict(\n",
" dataset_inf=df_inference,\n",
" )\n"
" )"
]
},
{
Expand All @@ -1018,7 +1022,7 @@
"# Let's add some metadata to the model to make it identifiable\n",
"pipeline_settings[\"model\"] = Model(\n",
" name=\"breast_cancer_classifier\",\n",
" version=\"production\", # We can pass in the stage name here!\n",
" version=\"production\", # We can pass in the stage name here!\n",
" license=\"Apache 2.0\",\n",
" description=\"A breast cancer classifier\",\n",
" tags=[\"breast_cancer\", \"classifier\"],\n",
Expand All @@ -1039,9 +1043,7 @@
"# Let's run it again to make sure we have two versions\n",
"# We need to pass in the ID of the preprocessing done in the feature engineering pipeline\n",
"# in order to avoid training-serving skew\n",
"inference_configured(\n",
" preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id\n",
")"
"inference_configured(preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id)"
]
},
{
Expand All @@ -1061,7 +1063,9 @@
"outputs": [],
"source": [
"# Fetch production model\n",
"production_model_version = client.get_model_version(\"breast_cancer_classifier\", \"production\")\n",
"production_model_version = client.get_model_version(\n",
" \"breast_cancer_classifier\", \"production\"\n",
")\n",
"\n",
"# Get the predictions artifact\n",
"production_model_version.get_artifact(\"predictions\").load()"
Expand Down

0 comments on commit 947f8ff

Please sign in to comment.