-
Notifications
You must be signed in to change notification settings - Fork 6
/
experiment_tracking.py
117 lines (94 loc) · 3.55 KB
/
experiment_tracking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python3
"""
@author: Jithin Sasikumar
Tracks model training and logs the model artifacts along with the resulting
metrics and parameters. `MLFlow` is used for that purpose. The design can be
extended to support other tracking mechanisms such as TensorBoard.
This is facilitated via the `ExperimentTracker` protocol, which is similar to an interface.
"""
import mlflow
import pandas as pd
from typing import Protocol
from dataclasses import dataclass, field
from src.exception_handler import MLFlowError
class ExperimentTracker(Protocol):
    """
    Structural interface for experiment trackers, declared as a
    `typing.Protocol`.

    Any class that provides these three methods with compatible signatures
    satisfies the protocol; explicit inheritance is not required.
    """

    def __start__(self) -> None:
        """Prepare the tracking backend before anything is logged."""
        ...

    def log(self) -> None:
        """Enable logging of the training run."""
        ...

    def find_best_model(self, metric: str) -> "ModelSelection":
        """
        Return the tracked models ranked by `metric`.

        Parameters
        ----------
        metric: str
            Metric name used to sort the models.
        """
        # The `metric` parameter mirrors MLFlowTracker.find_best_model so the
        # concrete tracker matches the protocol's call signature.
        ...
@dataclass
class ModelSelection:
    """
    Container for the outcome of a model-selection query.

    Instance variables
    ------------------
    model_selection_dataframe: DataFrame
        Dataframe listing the models sorted by the chosen metric.
        A new empty dataframe is created when none is supplied.
    """

    # The DataFrame constructor itself is the factory: every instance gets
    # its own fresh, empty frame.
    model_selection_dataframe: pd.DataFrame = field(default_factory=pd.DataFrame)
@dataclass
class MLFlowTracker:
    """
    Dataclass to track experiments via MLFlow.

    Instance variables
    ------------------
    experiment_name: str
        Name of the experiment to be activated.
    tracking_uri: str
        An HTTP URI or local file path, prefixed with `file:/`.
        Defaults to `file:/./artifacts`.
    """

    experiment_name: str
    tracking_uri: str = "file:/./artifacts"

    def __start__(self) -> None:
        """
        Set the tracking URI and the active experiment name on the
        MLFlow engine.
        """
        mlflow.set_tracking_uri(self.tracking_uri)
        mlflow.set_experiment(self.experiment_name)

    def log(self) -> None:
        """
        Initialize auto-logging for tracking.

        Configures MLFlow through `__start__` and then enables Keras
        auto-logging via `mlflow.keras.autolog()`.
        """
        self.__start__()
        mlflow.keras.autolog()

    def find_best_model(self, metric: str) -> "ModelSelection":
        """
        Rank the runs of this experiment by the given metric.

        Looks up the experiment by name, searches its runs ordered by
        `metrics.<metric>` in descending order, and returns the
        `experiment_id`, `run_id` and metric columns of the result.

        Note: the same ranking is available through the `mlflow ui` command;
        this is a code implementation of it.

        Parameters
        ----------
        metric: str
            Metric name to sort the models.

        Returns
        -------
        ModelSelection
            Wraps the resulting dataframe.

        Raises
        ------
        MLFlowError: Exception
            If no experiment with this name exists or its id is None.

        Notes
        -----
        This method does not set the tracking URI itself; it uses whatever
        MLFlow is currently configured with (e.g. by an earlier `log()`).
        NOTE(review): the final column selection presumably raises KeyError
        when no run logged `metric` -- confirm against `search_runs` output.
        """
        experiment = mlflow.get_experiment_by_name(self.experiment_name)

        # Validate before touching the object: the lookup yields None for an
        # unknown name, and converting None with dict() would raise TypeError
        # before the intended MLFlowError could be raised.
        if experiment is None or experiment.experiment_id is None:
            raise MLFlowError(
                "Invalid experiment details. Please re-check them and try again !!!")

        metric_column = f"metrics.{metric}"
        result_df = mlflow.search_runs(
            [experiment.experiment_id],
            order_by=[f"{metric_column} DESC"],
        )

        return ModelSelection(
            model_selection_dataframe=result_df[
                ["experiment_id", "run_id", metric_column]
            ]
        )