Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
julienfoenet committed Oct 8, 2024
1 parent 711a5d6 commit 87c399e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions dags/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(__file__)))) # So that airflow can find config files

from dags.config import GENERATED_DATA_PATH, DATA_FOLDER, MODEL_PATH, PREDICTIONS_FOLDER
from dags.config import GENERATED_DATA_PATH, DATA_FOLDER, PREDICTIONS_FOLDER, MODEL_REGISTRY_FOLDER
from formation_indus_ds_avancee.feature_engineering import prepare_features_with_io
from formation_indus_ds_avancee.train_and_predict import predict_with_io

Expand All @@ -27,7 +27,7 @@ def prepare_features_with_io_task():
@task
def predict_with_io_task(feature_path: str) -> None:
predict_with_io(features_path=feature_path,
model_path=MODEL_PATH,
model_registry_folder=MODEL_REGISTRY_FOLDER,
predictions_folder=PREDICTIONS_FOLDER)

feature_path = prepare_features_with_io_task()
Expand Down
6 changes: 4 additions & 2 deletions dags/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import sys
from datetime import timedelta
from datetime import timedelta, datetime

from airflow.decorators import dag, task
from airflow.utils.dates import days_ago
Expand All @@ -22,7 +22,9 @@ def prepare_features_task() -> str:

@task
def train_model_task(feature_path: str) -> None:
train_model_with_io(features_path=feature_path, model_registry_folder=MODEL_REGISTRY_FOLDER)
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
model_path = f'{timestamp}.joblib'
train_model_with_io(features_path=feature_path, model_registry_folder=MODEL_REGISTRY_FOLDER, model_path=model_path)

feature_path = prepare_features_task()
train_model_task(feature_path)
Expand Down
20 changes: 15 additions & 5 deletions formation_indus_ds_avancee/train_and_predict.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,38 @@
import os
import time
import json

import joblib
import pandas as pd
from sklearn.ensemble import RandomForestRegressor


def train_model_with_io(features_path: str, model_registry_folder: str) -> None:
def train_model_with_io(features_path: str, model_registry_folder: str, model_path: str) -> None:
features = pd.read_parquet(features_path)

train_model(features, model_registry_folder)
train_model(features, model_registry_folder, model_path)


def train_model(features: pd.DataFrame, model_registry_folder: str) -> None:
def train_model(features: pd.DataFrame, model_registry_folder: str, model_path: str) -> None:
target = 'Ba_avg'
X = features.drop(columns=[target])
y = features[target]
model = RandomForestRegressor(n_estimators=1, max_depth=10, n_jobs=1)
model.fit(X, y)
joblib.dump(model, os.path.join(model_registry_folder, 'model.joblib'))

with open(os.path.join(model_registry_folder, 'version'), 'w') as f:
json.dump({'latest': model_path}, f)

def predict_with_io(features_path: str, model_path: str, predictions_folder: str) -> None:
joblib.dump(model, os.path.join(model_registry_folder, model_path))


def predict_with_io(features_path: str, model_registry_folder: str, predictions_folder: str) -> None:
features = pd.read_parquet(features_path)

with open(os.path.join(model_registry_folder, 'version'), 'r') as f:
model_version = json.load(f)
model_path = model_version['latest']

features = predict(features, model_path)
time_str = time.strftime('%Y%m%d-%H%M%S')
features['predictions_time'] = time_str
Expand Down

0 comments on commit 87c399e

Please sign in to comment.