Merge pull request #1079 from mindsdb/staging
Release 22.12.2.0
paxcema authored Dec 15, 2022
2 parents cf32de4 + ac326a1 commit f1ebf7b
Showing 10 changed files with 355 additions and 64 deletions.
2 changes: 1 addition & 1 deletion lightwood/__about__.py
100755 → 100644
@@ -1,6 +1,6 @@
 __title__ = 'lightwood'
 __package_name__ = 'lightwood'
-__version__ = '22.12.1.1'
+__version__ = '22.12.2.0'
 __description__ = "Lightwood is a toolkit for automatic machine learning model building"
 __email__ = "[email protected]"
 __author__ = 'MindsDB Inc'
68 changes: 31 additions & 37 deletions lightwood/api/json_ai.py
@@ -280,7 +280,7 @@ def generate_json_ai(
         submodels.extend(
             [
                 {
-                    "module": "LightGBM",
+                    "module": "XGBoostMixer",
                     "args": {
                         "stop_after": "$problem_definition.seconds_per_mixer",
                         "fit_on_dev": True,
@@ -301,48 +301,35 @@ def generate_json_ai(
                 },
             ]
         )
-    elif tss.is_timeseries and tss.horizon > 1:
+    elif tss.is_timeseries and tss.horizon > 1 and tss.use_previous_target and \
+            dtype_dict[target] in (dtype.integer, dtype.float, dtype.quantity):
+
         submodels.extend(
             [
                 {
-                    "module": "LightGBMArray",
+                    "module": "SkTime",
                     "args": {
-                        "fit_on_dev": True,
                         "stop_after": "$problem_definition.seconds_per_mixer",
-                        "ts_analysis": "$ts_analysis",
-                        "tss": "$problem_definition.timeseries_settings",
+                        "horizon": "$problem_definition.timeseries_settings.horizon",
                     },
                 },
+                {
+                    "module": "ETSMixer",
+                    "args": {
+                        "stop_after": "$problem_definition.seconds_per_mixer",
+                        "horizon": "$problem_definition.timeseries_settings.horizon",
+                    },
+                },
+                {
+                    "module": "ARIMAMixer",
+                    "args": {
+                        "stop_after": "$problem_definition.seconds_per_mixer",
+                        "horizon": "$problem_definition.timeseries_settings.horizon",
+                    },
+                }
             ]
         )
-
-        if tss.use_previous_target and dtype_dict[target] in (dtype.integer, dtype.float, dtype.quantity):
-            submodels.extend(
-                [
-                    {
-                        "module": "SkTime",
-                        "args": {
-                            "stop_after": "$problem_definition.seconds_per_mixer",
-                            "horizon": "$problem_definition.timeseries_settings.horizon",
-                        },
-                    },
-                    {
-                        "module": "ETSMixer",
-                        "args": {
-                            "stop_after": "$problem_definition.seconds_per_mixer",
-                            "horizon": "$problem_definition.timeseries_settings.horizon",
-                        },
-                    },
-                    {
-                        "module": "ARIMAMixer",
-                        "args": {
-                            "stop_after": "$problem_definition.seconds_per_mixer",
-                            "horizon": "$problem_definition.timeseries_settings.horizon",
-                        },
-                    }
-                ]
-            )
 
     model = {
         "module": "BestOf",
         "args": {
@@ -600,13 +587,16 @@ def _add_implicit_values(json_ai: JsonAI) -> JsonAI:
             )
             mixers[i]["args"]["ts_analysis"] = mixers[i]["args"].get("ts_analysis", "$ts_analysis")
 
-        elif mixers[i]["module"] == "LightGBM":
+        elif mixers[i]["module"] in ("LightGBM", "XGBoostMixer"):
             mixers[i]["args"]["input_cols"] = mixers[i]["args"].get(
                 "input_cols", "$input_cols"
             )
             mixers[i]["args"]["target_encoder"] = mixers[i]["args"].get(
                 "target_encoder", "$encoders[self.target]"
            )
+            mixers[i]["args"]["fit_on_dev"] = mixers[i]["args"].get(
+                "fit_on_dev", True
+            )
             mixers[i]["args"]["use_optuna"] = True
 
         elif mixers[i]["module"] == "Regression":
@@ -743,7 +733,8 @@ def _add_implicit_values(json_ai: JsonAI) -> JsonAI:
                 "dtype_dict": "$dtype_dict",
                 "target": "$target",
                 "mode": "$mode",
-                "ts_analysis": "$ts_analysis"
+                "ts_analysis": "$ts_analysis",
+                "pred_args": "$pred_args",
             },
         },
         "timeseries_analyzer": {
@@ -1040,7 +1031,6 @@ def code_from_json_ai(json_ai: JsonAI) -> str:
 # --------------- #
 log.info('Ensembling the mixer')
 # Create an ensemble of mixers to identify best performing model
-self.pred_args = PredictionArguments()
 # Dirty hack
 self.ensemble = {call(json_ai.model)}
 self.supports_proba = self.ensemble.supports_proba
@@ -1173,6 +1163,8 @@ def code_from_json_ai(json_ai: JsonAI) -> str:
 if len(data) == 0:
     raise Exception("Empty input, aborting prediction. Please try again with some input data.")
 
+self.pred_args = PredictionArguments.from_dict(args)
+
 log.info(f'[Predict phase 1/{{n_phases}}] - Data preprocessing')
 if self.problem_definition.ignore_features:
     log.info(f'Dropping features: {{self.problem_definition.ignore_features}}')
@@ -1190,14 +1182,14 @@ def code_from_json_ai(json_ai: JsonAI) -> str:
 encoded_data = encoded_ds.get_encoded_data(include_target=False)
 
 log.info(f'[Predict phase 3/{{n_phases}}] - Calling ensemble')
-self.pred_args = PredictionArguments.from_dict(args)
 df = self.ensemble(encoded_ds, args=self.pred_args)
 if self.pred_args.all_mixers:
     return df
 else:
     log.info(f'[Predict phase 4/{{n_phases}}] - Analyzing output')
     insights, global_insights = {call(json_ai.explainer)}
+    self.global_insights = {{**self.global_insights, **global_insights}}
     return insights
 """

@@ -1223,6 +1215,7 @@ def __init__(self):
         self.identifiers = {json_ai.identifiers}
         self.dtype_dict = {inline_dict(dtype_dict)}
         self.lightwood_version = '{lightwood_version}'
+        self.pred_args = PredictionArguments()
 
         # Any feature-column dependencies
         self.dependencies = {inline_dict(json_ai.dependency_dict)}
@@ -1233,6 +1226,7 @@ def __init__(self):
         self.statistical_analysis = None
         self.ts_analysis = None
         self.runtime_log = dict()
+        self.global_insights = dict()
 
     @timed
     def analyze_data(self, data: pd.DataFrame) -> None:
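Taken together, the json_ai.py changes make XGBoostMixer the default gradient-boosting submodel and fold the statistical forecasters (SkTime, ETSMixer, ARIMAMixer) into a single branch gated on a numeric target. A minimal sketch of how this surfaces through the high-level API; the entry points (json_ai_from_problem, code_from_json_ai, predictor_from_code) are standard lightwood, while the CSV and its 'target' column are hypothetical:

import pandas as pd
from lightwood.api.high_level import json_ai_from_problem, code_from_json_ai, predictor_from_code
from lightwood.api.types import ProblemDefinition

df = pd.read_csv('my_data.csv')  # hypothetical dataset with a numeric 'target' column
json_ai = json_ai_from_problem(df, ProblemDefinition.from_dict({'target': 'target'}))

# After this release, non-timeseries tasks should list XGBoostMixer among the submodels:
print([submodel['module'] for submodel in json_ai.model['args']['submodels']])

# Json AI is plain data, so a submodel can still be tuned by editing its args dict:
json_ai.model['args']['submodels'][0]['args']['stop_after'] = 30

predictor = predictor_from_code(code_from_json_ai(json_ai))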
4 changes: 4 additions & 0 deletions lightwood/api/types.py
@@ -440,6 +440,7 @@ class PredictionArguments:
     :param anomaly_cooldown: Sets the minimum amount of timesteps between consecutive firings of the anomaly \
         detector.
     :param time_format: For time series predictors. If set to `infer`, predicted `order_by` timestamps will be formatted back to the original dataset's `order_by` format. Any other string value will be used as a formatting string, unless empty (''), which disables the feature (this is the default behavior).
+    :param force_ts_infer: For time series predictors. If set to `true`, an additional row will be produced per group in the input DF, corresponding to an out-of-sample forecast w.r.t. the input timestamps.
     """  # noqa
 
     predict_proba: bool = True
@@ -449,6 +450,7 @@ class PredictionArguments:
     forecast_offset: int = 0
     simple_ts_bounds: bool = False
     time_format: str = ''
+    force_ts_infer: bool = False
 
     @staticmethod
     def from_dict(obj: Dict):
@@ -468,6 +470,7 @@ def from_dict(obj: Dict):
         forecast_offset = obj.get('forecast_offset', PredictionArguments.forecast_offset)
         simple_ts_bounds = obj.get('simple_ts_bounds', PredictionArguments.simple_ts_bounds)
         time_format = obj.get('time_format', PredictionArguments.time_format)
+        force_ts_infer = obj.get('force_ts_infer', PredictionArguments.force_ts_infer)
 
         pred_args = PredictionArguments(
             predict_proba=predict_proba,
@@ -477,6 +480,7 @@ def from_dict(obj: Dict):
             forecast_offset=forecast_offset,
             simple_ts_bounds=simple_ts_bounds,
             time_format=time_format,
+            force_ts_infer=force_ts_infer,
         )
 
         return pred_args
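A short usage sketch for the new flag; `PredictionArguments` and `from_dict` are as defined above, and the dict mirrors what a predictor's `predict(df, args={...})` call forwards internally:

from lightwood.api.types import PredictionArguments

# Explicit opt-in to the extra out-of-sample row per group:
pred_args = PredictionArguments.from_dict({'force_ts_infer': True})
assert pred_args.force_ts_infer is True

# Missing keys fall back to the dataclass defaults:
assert PredictionArguments.from_dict({}).force_ts_infer is False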
26 changes: 16 additions & 10 deletions lightwood/data/timeseries_transform.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Optional
 from functools import partial
 import multiprocessing as mp
 
@@ -8,13 +8,15 @@
 from lightwood.helpers.ts import get_ts_groups, get_delta, get_group_matches
 
 from type_infer.dtype import dtype
-from lightwood.api.types import TimeseriesSettings
+from lightwood.api.types import TimeseriesSettings, PredictionArguments
 from lightwood.helpers.log import log
 
 
 def transform_timeseries(
         data: pd.DataFrame, dtype_dict: Dict[str, str], ts_analysis: dict,
-        timeseries_settings: TimeseriesSettings, target: str, mode: str) -> pd.DataFrame:
+        timeseries_settings: TimeseriesSettings, target: str, mode: str,
+        pred_args: Optional[PredictionArguments] = None
+) -> pd.DataFrame:
     """
     Block that transforms the dataframe of a time series task to a convenient format for use in later phases, such as model training.
@@ -31,10 +33,12 @@ def transform_timeseries(
     :param timeseries_settings: A `TimeseriesSettings` object.
     :param target: The name of the target column to forecast.
     :param mode: Either "train" or "predict", depending on what phase is calling this procedure.
+    :param pred_args: Optional prediction arguments to control the transformation process.
     :return: A dataframe with all the transformations applied.
     """  # noqa
 
+    pred_args = PredictionArguments() if pred_args is None else pred_args
     tss = timeseries_settings
     gb_arr = tss.group_by if tss.group_by is not None else []
     oby = tss.order_by
@@ -93,16 +97,16 @@ def transform_timeseries(
             offset = min(int(original_df['__mdb_forecast_offset'].unique()[0]), 1)
         else:
             offset = 0
-        infer_mode = offset_available and offset == 1
+        cutoff_mode = offset_available and offset == 1
     else:
         offset_available = False
         offset = 0
-        infer_mode = False
+        cutoff_mode = False
 
     original_index_list = []
     idx = 0
     for row in original_df.itertuples():
-        if _make_pred(row) or infer_mode:
+        if _make_pred(row) or cutoff_mode:
             original_df.at[row.Index, '__make_predictions'] = True
             original_index_list.append(idx)
             idx += 1
@@ -130,7 +134,7 @@ def transform_timeseries(
     n_groups = len(df_arr)
     for i, subdf in enumerate(df_arr):
         if '__mdb_forecast_offset' in subdf.columns and mode == 'predict':
-            if infer_mode:
+            if cutoff_mode:
                 df_arr[i] = _ts_infer_next_row(subdf, oby)
                 make_preds = [False for _ in range(max(0, len(df_arr[i]) - 1))] + [True]
             elif offset_available:
@@ -139,6 +143,8 @@ def transform_timeseries(
                 make_preds = [False for _ in range(max(0, len(new_index) - 1))] + [True]
                 df_arr[i] = df_arr[i].loc[new_index]
             else:
+                if pred_args.force_ts_infer:
+                    df_arr[i] = _ts_infer_next_row(subdf, oby)  # force-infer out-of-sample forecast in default mode
                 make_preds = [True for _ in range(len(df_arr[i]))]
             df_arr[i]['__make_predictions'] = make_preds
 
@@ -179,7 +185,7 @@ def transform_timeseries(
     combined_df = pd.DataFrame(combined_df[combined_df['__make_predictions']])  # filters by True only
     del combined_df['__make_predictions']
 
-    if not infer_mode and any([i < tss.window for i in group_lengths]):
+    if not cutoff_mode and any([i < tss.window for i in group_lengths]):
         if tss.allow_incomplete_history:
             log.warning("Forecasting with incomplete historical context, predictions might be subpar")
         else:
@@ -198,7 +204,7 @@ def transform_timeseries(
     if df_gb_map is None:
         for i in range(len(combined_df)):
             row = combined_df.iloc[i]
-            if not infer_mode:
+            if not cutoff_mode:
                 timeseries_row_mapping[idx] = int(
                     row['original_index']) if row['original_index'] is not None and not np.isnan(
                     row['original_index']) else None
@@ -209,7 +215,7 @@ def transform_timeseries(
         for gb in df_gb_map:
             for i in range(len(df_gb_map[gb])):
                 row = df_gb_map[gb].iloc[i]
-                if not infer_mode:
+                if not cutoff_mode:
                     timeseries_row_mapping[idx] = int(
                         row['original_index']) if row['original_index'] is not None and not np.isnan(
                         row['original_index']) else None
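To make the control flow concrete: when the input has no '__mdb_forecast_offset' column, `cutoff_mode` stays False and, with `force_ts_infer=True`, the transform appends one inferred out-of-sample row per group via `_ts_infer_next_row`. A hedged sketch; the toy frame and settings dict are hypothetical, and the `dtype_dict`/`ts_analysis` inputs are elided because they normally come from earlier pipeline stages:

import pandas as pd
from lightwood.api.types import PredictionArguments, TimeseriesSettings

tss = TimeseriesSettings.from_dict({'order_by': 'ts', 'window': 3, 'horizon': 2})  # hypothetical settings
df = pd.DataFrame({
    'ts': pd.date_range('2022-01-01', periods=5, freq='D'),  # toy series
    'target': [1.0, 2.0, 3.0, 4.0, 5.0],
})
pred_args = PredictionArguments.from_dict({'force_ts_infer': True})

# Inside a generated predictor this call is wired up through the '$pred_args' binding added in json_ai.py:
# transform_timeseries(df, dtype_dict, ts_analysis, tss, 'target', 'predict', pred_args=pred_args)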
12 changes: 9 additions & 3 deletions lightwood/mixer/__init__.py
@@ -2,9 +2,8 @@
 from lightwood.mixer.unit import Unit
 from lightwood.mixer.neural import Neural
 from lightwood.mixer.neural_ts import NeuralTs
-from lightwood.mixer.lightgbm import LightGBM
+from lightwood.mixer.xgboost import XGBoostMixer
 from lightwood.mixer.random_forest import RandomForest
-from lightwood.mixer.lightgbm_array import LightGBMArray
 from lightwood.mixer.sktime import SkTime
 from lightwood.mixer.arima import ARIMAMixer
 from lightwood.mixer.ets import ETSMixer
@@ -26,5 +25,12 @@
 except Exception:
     ProphetMixer = None
 
+try:
+    from lightwood.mixer.lightgbm import LightGBM
+    from lightwood.mixer.lightgbm_array import LightGBMArray
+except Exception:
+    LightGBM = None
+    LightGBMArray = None
+
 __all__ = ['BaseMixer', 'Neural', 'NeuralTs', 'LightGBM', 'RandomForest', 'LightGBMArray', 'Unit', 'Regression',
-           'SkTime', 'QClassic', 'ProphetMixer', 'ETSMixer', 'ARIMAMixer', 'NHitsMixer', 'GluonTSMixer']
+           'SkTime', 'QClassic', 'ProphetMixer', 'ETSMixer', 'ARIMAMixer', 'NHitsMixer', 'GluonTSMixer', 'XGBoostMixer']
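Because the LightGBM mixers are now import-guarded like ProphetMixer and the other optional mixers, downstream code should check for None before using them; a small sketch of the resulting behavior:

from lightwood import mixer

# With the optional import in place, a missing lightgbm package no longer breaks `import lightwood.mixer`:
if mixer.LightGBM is None:
    print('lightgbm not installed; XGBoostMixer is the default gradient-boosting mixer')
else:
    print('LightGBM mixer available:', mixer.LightGBM.__name__)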