Merge pull request #1123 from mindsdb/staging
Release 23.3.2.0
paxcema authored Mar 14, 2023
2 parents e77a859 + 019df03 commit 68a71f8
Showing 19 changed files with 316 additions and 124 deletions.
52 changes: 26 additions & 26 deletions README.md
@@ -65,40 +65,40 @@ from lightwood.api.high_level import (
predictor_from_code,
)

-# Load a pandas dataset
-df = pd.read_csv(
-    "https://raw.githubusercontent.com/mindsdb/benchmarks/main/benchmarks/datasets/hdi/data.csv"
-)
+if __name__ == '__main__':
+    # Load a pandas dataset
+    df = pd.read_csv("https://raw.githubusercontent.com/mindsdb/benchmarks/main/benchmarks/datasets/hdi/data.csv"
+    )

-# Define the prediction task by naming the target column
-pdef = ProblemDefinition.from_dict(
-    {
-        "target": "Development Index", # column you want to predict
-    }
-)
+    # Define the prediction task by naming the target column
+    pdef = ProblemDefinition.from_dict(
+        {
+            "target": "Development Index", # column you want to predict
+        }
+    )

-# Generate JSON-AI code to model the problem
-json_ai = json_ai_from_problem(df, problem_definition=pdef)
+    # Generate JSON-AI code to model the problem
+    json_ai = json_ai_from_problem(df, problem_definition=pdef)

-# OPTIONAL - see the JSON-AI syntax
-#print(json_ai.to_json())
+    # OPTIONAL - see the JSON-AI syntax
+    # print(json_ai.to_json())

-# Generate python code
-code = code_from_json_ai(json_ai)
+    # Generate python code
+    code = code_from_json_ai(json_ai)

-# OPTIONAL - see generated code
-#print(code)
+    # OPTIONAL - see generated code
+    # print(code)

-# Create a predictor from python code
-predictor = predictor_from_code(code)
+    # Create a predictor from python code
+    predictor = predictor_from_code(code)

-# Train a model end-to-end from raw data to a finalized predictor
-predictor.learn(df)
+    # Train a model end-to-end from raw data to a finalized predictor
+    predictor.learn(df)

-# Make the train/test splits and show predictions for a few examples
-test_df = predictor.split(predictor.preprocess(df))["test"]
-preds = predictor.predict(test_df).iloc[:10]
-print(preds)
+    # Make the train/test splits and show predictions for a few examples
+    test_df = predictor.split(predictor.preprocess(df))["test"]
+    preds = predictor.predict(test_df).iloc[:10]
+    print(preds)
```

### BYOM: Bring your own models
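
The README example above is now wrapped in an `if __name__ == '__main__':` guard. The diff doesn't state why, but a likely reason is that code which spawns worker processes re-imports the entry module on platforms using the `spawn` start method, so unguarded module-level training code would run again in every worker. A minimal, lightwood-independent sketch of the pattern:

```python
import multiprocessing as mp


def square(x: int) -> int:
    """Toy worker function run in a separate process."""
    return x * x


if __name__ == '__main__':
    # Without this guard, 'spawn'-based platforms (Windows, recent macOS)
    # would re-execute module-level code in each worker process, which can
    # recurse into pool creation or re-run expensive training steps.
    with mp.Pool(processes=2) as pool:
        print(pool.map(square, range(4)))
```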
2 changes: 1 addition & 1 deletion lightwood/__about__.py
@@ -1,6 +1,6 @@
__title__ = 'lightwood'
__package_name__ = 'lightwood'
__version__ = '23.2.1.0'
__version__ = '23.3.2.0'
__description__ = "Lightwood is a toolkit for automatic machine learning model building"
__email__ = "[email protected]"
__author__ = 'MindsDB Inc'
36 changes: 23 additions & 13 deletions lightwood/analysis/nc/calibrate.py
@@ -105,19 +105,20 @@ def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
icp_df = deepcopy(ns.data)

# setup prediction cache to avoid additional .predict() calls
pred_is_list = isinstance(ns.normal_predictions['prediction'], list) and \
isinstance(ns.normal_predictions['prediction'][0], list)
try:
pred_is_list = isinstance(ns.normal_predictions['prediction'][0], list)
except KeyError:
pred_is_list = False

if ns.is_classification:
if ns.predictor.supports_proba:
icp.nc_function.model.prediction_cache = ns.normal_predictions[all_cat_cols].values
else:
if ns.is_multi_ts:
icp.nc_function.model.prediction_cache = np.array(
[p[0] for p in ns.normal_predictions['prediction']])
preds = icp.nc_function.model.prediction_cache
preds = np.array([p[0] for p in ns.normal_predictions['prediction']])
else:
preds = ns.normal_predictions['prediction']
predicted_classes = pd.get_dummies(preds).values # inflate to one-hot enc
preds = ns.normal_predictions['prediction'].values
predicted_classes = output['label_encoders'].transform(preds.reshape(-1, 1)) # inflate OHE
icp.nc_function.model.prediction_cache = predicted_classes

elif ns.is_multi_ts or pred_is_list:
@@ -198,8 +199,12 @@ def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:

# save relevant predictions in the caches, then calibrate the ICP
pred_cache = icp_df.pop(f'__predicted_{ns.target}').values
if ns.is_multi_ts:
if ns.is_multi_ts and ns.is_classification:
# output['label_encoders'].transform(preds.reshape(-1, 1))
pred_cache = output['label_encoders'].transform([[p[0] for p in pred_cache]])
elif ns.is_multi_ts:
pred_cache = np.array([np.array(p) for p in pred_cache])

icps[tuple(group)].nc_function.model.prediction_cache = pred_cache
icp_df, y = clean_df(icp_df, ns, output.get('label_encoders', None))
if icps[tuple(group)].nc_function.normalizer is not None:
@@ -334,7 +339,8 @@ def explain(self, row_insights: pd.DataFrame, global_insights: Dict[str, object]
for icol, cat_col in enumerate(all_cat_cols):
row_insights.loc[X.index, cat_col] = class_dists[:, icol]
else:
class_dists = pd.get_dummies(preds).values
ohe_enc = ns.analysis['label_encoders']
class_dists = ohe_enc.transform(np.array([p[0] for p in preds]).reshape(-1, 1))

base_icp.nc_function.model.prediction_cache = class_dists

@@ -360,8 +366,8 @@ def explain(self, row_insights: pd.DataFrame, global_insights: Dict[str, object]
result.loc[X.index, 'significance'] = significances

else:
significances = get_categorical_conf(all_confs.squeeze())
result.loc[X.index, 'significance'] = significances
significances = get_categorical_conf(all_confs)
result.loc[X.index, 'significance'] = significances.flatten()

# grouped time series, we replace bounds in rows that have a trained ICP
if ns.analysis['icp'].get('__mdb_groups', False):
@@ -386,6 +392,11 @@ def explain(self, row_insights: pd.DataFrame, global_insights: Dict[str, object]
for i in range(1, ns.tss.horizon)]
icp.nc_function.model.prediction_cache = X[target_cols].values
[X.pop(col) for col in target_cols]
elif is_multi_ts and is_categorical:
ohe_enc = ns.analysis['label_encoders']
preds = X.pop(ns.target_name).values
pred_cache = ohe_enc.transform(np.array([p[0] for p in preds]).reshape(-1, 1))
icp.nc_function.model.prediction_cache = pred_cache
else:
icp.nc_function.model.prediction_cache = X.pop(ns.target_name).values
if icp.nc_function.normalizer:
@@ -431,7 +442,7 @@ def explain(self, row_insights: pd.DataFrame, global_insights: Dict[str, object]
all_ranges = np.array([icp.predict(X.values)])
all_confs = np.swapaxes(np.swapaxes(all_ranges, 0, 2), 0, 1)
significances = get_categorical_conf(all_confs)
result.loc[X.index, 'significance'] = significances
result.loc[X.index, 'significance'] = significances.flatten()

row_insights['confidence'] = result['significance']

@@ -519,6 +530,5 @@ def _ts_assign_confs(result, df, confs, significances, tss) -> pd.DataFrame:
added_cols = [f'{base_col}_timestep_{t}' for t in range(1, tss.horizon)]
cols = [base_col] + added_cols
result.loc[df.index, base_col] = result.loc[df.index, cols].values.tolist()
# result[base_col] = result[cols].values.tolist()

return result
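
Several of the calibrate.py changes swap `pd.get_dummies(preds)` for a transform through the fitted `label_encoders` stored in the analysis output. A hedged illustration of the difference, using plain scikit-learn objects rather than lightwood's internals: `get_dummies` derives its columns from whatever values happen to appear in the batch, so column count and order can drift between calls, whereas an encoder fitted once on the full label set always produces the same one-hot layout.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

all_labels = np.array(['low', 'medium', 'high'])  # label set known at fit time
preds = np.array(['low', 'low', 'high'])          # one batch of predictions ('medium' missing)

# pd.get_dummies only sees the values in this batch: 2 columns, batch-dependent order.
print(pd.get_dummies(preds).values.shape)         # (3, 2)

# An encoder fitted on the full label set always yields the same 3 columns in the
# same order, so cached one-hot predictions stay aligned across batches and groups.
encoder = OneHotEncoder(handle_unknown='ignore')
encoder.fit(all_labels.reshape(-1, 1))
print(encoder.transform(preds.reshape(-1, 1)).toarray().shape)  # (3, 3)
```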
35 changes: 24 additions & 11 deletions lightwood/api/json_ai.py
@@ -571,22 +571,31 @@ def _add_implicit_values(json_ai: JsonAI) -> JsonAI:
mixers[i]["args"]["stop_after"] = mixers[i]["args"].get("stop_after", "$problem_definition.seconds_per_mixer")

# specific
if mixers[i]["module"] in ("Neural", "NeuralTs"):
if mixers[i]["module"] in ("Neural", "NeuralTs", "TabTransformerMixer"):
mixers[i]["args"]["target_encoder"] = mixers[i]["args"].get(
"target_encoder", "$encoders[self.target]"
)
mixers[i]["args"]["net"] = mixers[i]["args"].get(
"net",
'"DefaultNet"'
if not tss.is_timeseries or not tss.use_previous_target
else '"ArNet"',
)

if mixers[i]["module"] in ("Neural", "NeuralTs"):
mixers[i]["args"]["net"] = mixers[i]["args"].get(
"net",
'"DefaultNet"'
if not tss.is_timeseries or not tss.use_previous_target
else '"ArNet"',
)
mixers[i]["args"]["search_hyperparameters"] = mixers[i]["args"].get("search_hyperparameters", True)
mixers[i]["args"]["fit_on_dev"] = mixers[i]["args"].get("fit_on_dev", True)

if mixers[i]["module"] == "NeuralTs":
mixers[i]["args"]["timeseries_settings"] = mixers[i]["args"].get(
"timeseries_settings", "$problem_definition.timeseries_settings"
)
mixers[i]["args"]["ts_analysis"] = mixers[i]["args"].get("ts_analysis", "$ts_analysis")

if mixers[i]["module"] == "TabTransformerMixer":
mixers[i]["args"]["search_hyperparameters"] = mixers[i]["args"].get("search_hyperparameters", False)
mixers[i]["args"]["fit_on_dev"] = mixers[i]["args"].get("fit_on_dev", False)

elif mixers[i]["module"] in ("LightGBM", "XGBoostMixer"):
mixers[i]["args"]["input_cols"] = mixers[i]["args"].get(
"input_cols", "$input_cols"
@@ -1025,7 +1034,12 @@ def code_from_json_ai(json_ai: JsonAI) -> str:
trained_mixers = []
for mixer in self.mixers:
try:
self.fit_mixer(mixer, encoded_train_data, encoded_dev_data)
if mixer.trains_once:
self.fit_mixer(mixer,
ConcatedEncodedDs([encoded_train_data, encoded_dev_data]),
encoded_test_data)
else:
self.fit_mixer(mixer, encoded_train_data, encoded_dev_data)
trained_mixers.append(mixer)
except Exception as e:
log.warning(f'Exception: {{e}} when training mixer: {{mixer}}')
@@ -1107,7 +1121,7 @@ def code_from_json_ai(json_ai: JsonAI) -> str:
log.info('Updating the mixers')
for mixer in self.mixers:
mixer.partial_fit(train_data, dev_data, adjust_args)
mixer.partial_fit(train_data, dev_data, adjust_args)
""" # noqa

adjust_body = align(adjust_body, 2)
@@ -1154,8 +1168,7 @@ def code_from_json_ai(json_ai: JsonAI) -> str:
# SET `json_ai.problem_definition.fit_on_all=False` TO TURN THIS BLOCK OFF.
# Update the mixers with partial fit
if self.problem_definition.fit_on_all:
if self.problem_definition.fit_on_all and all([not m.trains_once for m in self.mixers]):
log.info(f'[Learn phase 8/{n_phases}] - Adjustment on validation requested')
self.adjust(enc_train_test["test"].data_frame, ConcatedEncodedDs([enc_train_test["train"],
enc_train_test["dev"]]).data_frame,
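
The generated `fit` and `learn` code in json_ai.py now branches on each mixer's `trains_once` flag: such mixers receive the concatenated train+dev data and validate on the test split, and the final adjustment phase only runs when no mixer trains once (and `fit_on_all` is set). A simplified, hypothetical sketch of that dispatch, with stand-in classes rather than the generated predictor code:

```python
from typing import List, Sequence


class Mixer:
    """Hypothetical stand-in for a lightwood mixer."""
    trains_once = False

    def fit(self, train: Sequence, dev: Sequence) -> None:
        print(f'{type(self).__name__}: fit on {len(train)} rows, validate on {len(dev)} rows')


class OneShotMixer(Mixer):
    trains_once = True  # consumes all training data up front; no later partial_fit


def fit_mixers(mixers: List[Mixer], train: list, dev: list, test: list) -> None:
    for mixer in mixers:
        if mixer.trains_once:
            # train + dev become the training set; the test split is used for validation
            mixer.fit(train + dev, test)
        else:
            mixer.fit(train, dev)
    # The adjustment pass is skipped as soon as any mixer trains only once.
    if all(not m.trains_once for m in mixers):
        print('adjust() would run here')


fit_mixers([Mixer(), OneShotMixer()],
           train=list(range(80)), dev=list(range(10)), test=list(range(10)))
```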
6 changes: 3 additions & 3 deletions lightwood/data/encoded_ds.py
@@ -144,7 +144,7 @@ class ConcatedEncodedDs(EncodedDs):
def __init__(self, encoded_ds_arr: List[EncodedDs]) -> None:
# @TODO: missing super() call here?
self.encoded_ds_arr = encoded_ds_arr
self.encoded_ds_lenghts = [len(x) for x in self.encoded_ds_arr]
self.encoded_ds_lengths = [len(x) for x in self.encoded_ds_arr]
self.encoders = self.encoded_ds_arr[0].encoders
self.encoder_spans = self.encoded_ds_arr[0].encoder_spans
self.target = self.encoded_ds_arr[0].target
@@ -155,13 +155,13 @@ def __len__(self):
See `lightwood.data.encoded_ds.EncodedDs.__len__()`.
"""
# @TODO: behavior here is not intuitive
return max(0, np.sum(self.encoded_ds_lenghts) - 2)
return max(0, np.sum(self.encoded_ds_lengths) - 2)

def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
See `lightwood.data.encoded_ds.EncodedDs.__getitem__()`.
"""
for ds_idx, length in enumerate(self.encoded_ds_lenghts):
for ds_idx, length in enumerate(self.encoded_ds_lengths):
if idx - length < 0:
return self.encoded_ds_arr[ds_idx][idx]
else:
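
Besides renaming the misspelled `encoded_ds_lenghts` attribute, these per-dataset lengths are what `__getitem__` uses to route a global index to the right sub-dataset (the `else` branch is cut off in this view, but presumably decrements the index by the current length). A minimal sketch of that routing under that assumption:

```python
from typing import Sequence, Tuple


def route_index(lengths: Sequence[int], idx: int) -> Tuple[int, int]:
    """Map a global index to (dataset index, local index) by walking per-dataset lengths."""
    for ds_idx, length in enumerate(lengths):
        if idx - length < 0:
            return ds_idx, idx
        idx -= length  # assumed behaviour of the truncated `else` branch
    raise IndexError('index out of range for the concatenated datasets')


# Two datasets of lengths 3 and 2: global index 4 maps to item 1 of the second one.
print(route_index([3, 2], 0))  # (0, 0)
print(route_index([3, 2], 4))  # (1, 1)
```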
2 changes: 1 addition & 1 deletion lightwood/data/timeseries_transform.py
@@ -69,7 +69,7 @@ def transform_timeseries(
subsets = []
for group in groups:
if (tss.group_by and group != '__default') or not tss.group_by:
idxs, subset = get_group_matches(data, group, tss.group_by)
idxs, subset = get_group_matches(data, group, tss.group_by, copy=True)
if subset.shape[0] > 0:
if periods.get(group, periods['__default']) == 0 and subset.shape[0] > 1:
raise Exception(
5 changes: 4 additions & 1 deletion lightwood/helpers/ts.py
@@ -18,7 +18,8 @@ def get_ts_groups(df: pd.DataFrame, tss) -> list:
def get_group_matches(
data: Union[pd.Series, pd.DataFrame],
combination: tuple,
group_columns: List[str]
group_columns: List[str],
copy: bool = False
) -> Tuple[list, pd.DataFrame]:
"""Given a particular group combination, return the data subset that belongs to it."""

@@ -34,6 +35,8 @@ def get_group_matches(
for val, col in zip(combination, group_columns):
subset = subset[subset[col] == val]
if len(subset) > 0:
if copy:
subset = subset.copy()
return list(subset.index), subset
else:
return [], pd.DataFrame()
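
`get_group_matches` gains a `copy` flag, and the time-series transform above now requests copies. The diff doesn't spell out the motivation, but a likely one is that the filtered subset otherwise shares data with the original frame, so later in-place edits can trigger pandas' `SettingWithCopyWarning` or fail to behave as intended. A standalone sketch of the distinction:

```python
import pandas as pd

df = pd.DataFrame({'group': ['a', 'a', 'b'], 'value': [1, 2, 3]})

# A boolean-filtered slice may share data with `df`: assigning into it in place
# is ambiguous and can raise pandas' SettingWithCopyWarning.
view_like = df[df['group'] == 'a']
print(view_like.shape)  # (2, 2)

# An explicit copy is independent, so the caller can mutate it freely
# (as the grouped time-series transform does) without touching `df`.
subset = df[df['group'] == 'a'].copy()
subset.loc[:, 'value'] = subset['value'] * 10

print(df['value'].tolist())      # [1, 2, 3] -- original left intact
print(subset['value'].tolist())  # [10, 20]
```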
8 changes: 7 additions & 1 deletion lightwood/mixer/__init__.py
@@ -36,5 +36,11 @@
LightGBM = None
LightGBMArray = None

try:
from lightwood.mixer.tabtransformer import TabTransformerMixer
except Exception:
TabTransformerMixer = None

__all__ = ['BaseMixer', 'Neural', 'NeuralTs', 'LightGBM', 'RandomForest', 'LightGBMArray', 'Unit', 'Regression',
'SkTime', 'QClassic', 'ProphetMixer', 'ETSMixer', 'ARIMAMixer', 'NHitsMixer', 'GluonTSMixer', 'XGBoostMixer']
'SkTime', 'QClassic', 'ProphetMixer', 'ETSMixer', 'ARIMAMixer', 'NHitsMixer', 'GluonTSMixer', 'XGBoostMixer',
'TabTransformerMixer']
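
The new mixer is imported behind `try/except` so that lightwood still loads when its optional dependency is unavailable, with the name exported as `None` in that case. Downstream code can follow the same convention; a small sketch (the helper below is hypothetical, not part of lightwood):

```python
# Optional-dependency pattern: the symbol is always defined, but may be None
# when the backing package could not be imported.
try:
    from lightwood.mixer.tabtransformer import TabTransformerMixer
except Exception:
    TabTransformerMixer = None


def available_optional_mixers() -> list:
    """Hypothetical helper: return only the optional mixer classes that imported cleanly."""
    candidates = [TabTransformerMixer]  # extend with other optional mixers as needed
    return [mixer for mixer in candidates if mixer is not None]


print([mixer.__name__ for mixer in available_optional_mixers()])
```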
4 changes: 3 additions & 1 deletion lightwood/mixer/base.py
@@ -21,18 +21,20 @@ class BaseMixer:
- stable: If set to `True`, this mixer should always work. Any mixer with `stable=False` can be expected to fail under some circumstances.
- fit_data_len: Length of the training data.
- supports_proba: For classification tasks, whether the mixer supports yielding per-class scores rather than only returning the predicted label.
- trains_once: If True, the mixer is trained once during `learn`, using all available input data (`train` and `dev` splits for training, `test` for validation). Otherwise, it trains once with the `train` split for training and `dev` for validation, and optionally (depending on the problem definition's `fit_on_all` and the per-mixer `fit_on_dev` arguments) a second time after post-training analysis via `partial_fit`, with the `train` and `dev` splits as the training subset and the `test` split as validation. Should only be set to True for mixers that don't require post-training analysis, as otherwise actual validation data would be treated as a held-out portion, which is a mistake.
""" # noqa
stable: bool
fit_data_len: int # @TODO (Patricio): should this really be in `BaseMixer`?
supports_proba: bool
trains_once: bool

def __init__(self, stop_after: float):
"""
:param stop_after: Time budget to train this mixer.
"""
self.stop_after = stop_after
self.supports_proba = False
self.trains_once = False

def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
"""
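
`BaseMixer.__init__` now defaults `trains_once` to `False`; a mixer that does all of its training in a single pass can flip it to `True` so the learn pipeline hands it train+dev up front and leaves it out of the adjustment phase. A hypothetical subclass sketch (not an actual lightwood mixer):

```python
from lightwood.mixer import BaseMixer


class OneShotMixer(BaseMixer):
    """Hypothetical mixer that consumes all of its training data in a single pass."""
    stable = False

    def __init__(self, stop_after: float):
        super().__init__(stop_after)
        # Tell the generated predictor code to fit this mixer on train+dev
        # (validating on test) and to skip it during adjust().
        self.trains_once = True

    def fit(self, train_data, dev_data) -> None:
        # ... a single training pass over `train_data`, early stopping on `dev_data` ...
        pass
```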
(diffs for the remaining changed files were not loaded)
