-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #559 from mindsdb/staging
Release 1.3.0
- Loading branch information
Showing
70 changed files
with
2,002 additions
and
1,219 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
name: Question | ||
about: Ask a question | ||
labels: question | ||
--- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
--- | ||
name: Suggestion | ||
about: Suggest a feature, improvement, doc change, etc. | ||
labels: enhancement | ||
--- | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
__title__ = 'lightwood' | ||
__package_name__ = 'lightwood' | ||
__version__ = '1.2.0' | ||
__version__ = '1.3.0' | ||
__description__ = "Lightwood is a toolkit for automatic machine learning model building" | ||
__email__ = "[email protected]" | ||
__author__ = 'MindsDB Inc' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,12 @@ | ||
from lightwood.analysis.model_analyzer import model_analyzer | ||
# Base | ||
from lightwood.analysis.analyze import model_analyzer | ||
from lightwood.analysis.explain import explain | ||
|
||
__all__ = ['model_analyzer', 'explain'] | ||
# Blocks | ||
from lightwood.analysis.base import BaseAnalysisBlock | ||
from lightwood.analysis.nc.calibrate import ICP | ||
from lightwood.analysis.helpers.acc_stats import AccStats | ||
from lightwood.analysis.helpers.feature_importance import GlobalFeatureImportance | ||
|
||
|
||
__all__ = ['model_analyzer', 'explain', 'ICP', 'AccStats', 'GlobalFeatureImportance', 'BaseAnalysisBlock'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from typing import Dict, List, Tuple, Optional | ||
|
||
from lightwood.api import dtype | ||
from lightwood.ensemble import BaseEnsemble | ||
from lightwood.analysis.base import BaseAnalysisBlock | ||
from lightwood.data.encoded_ds import EncodedDs | ||
from lightwood.encoder.text.pretrained import PretrainedLangEncoder | ||
from lightwood.api.types import ModelAnalysis, StatisticalAnalysis, TimeseriesSettings | ||
|
||
|
||
def model_analyzer( | ||
predictor: BaseEnsemble, | ||
data: EncodedDs, | ||
train_data: EncodedDs, | ||
stats_info: StatisticalAnalysis, | ||
target: str, | ||
ts_cfg: TimeseriesSettings, | ||
dtype_dict: Dict[str, str], | ||
accuracy_functions, | ||
analysis_blocks: Optional[List[BaseAnalysisBlock]] = [] | ||
) -> Tuple[ModelAnalysis, Dict[str, object]]: | ||
""" | ||
Analyses model on a validation subset to evaluate accuracy, estimate feature importance and generate a | ||
calibration model to estimating confidence in future predictions. | ||
Additionally, any user-specified analysis blocks (see class `BaseAnalysisBlock`) are also called here. | ||
:return: | ||
runtime_analyzer: This dictionary object gets populated in a sequential fashion with data generated from | ||
any `.analyze()` block call. This dictionary object is stored in the predictor itself, and used when | ||
calling the `.explain()` method of all analysis blocks when generating predictions. | ||
model_analysis: `ModelAnalysis` object that contains core analysis metrics, not necessarily needed when predicting. | ||
""" | ||
|
||
runtime_analyzer = {} | ||
data_type = dtype_dict[target] | ||
|
||
# retrieve encoded data representations | ||
encoded_train_data = train_data | ||
encoded_val_data = data | ||
data = encoded_val_data.data_frame | ||
input_cols = list([col for col in data.columns if col != target]) | ||
|
||
# predictive task | ||
is_numerical = data_type in (dtype.integer, dtype.float, dtype.array, dtype.tsarray, dtype.quantity) | ||
is_classification = data_type in (dtype.categorical, dtype.binary) | ||
is_multi_ts = ts_cfg.is_timeseries and ts_cfg.nr_predictions > 1 | ||
has_pretrained_text_enc = any([isinstance(enc, PretrainedLangEncoder) | ||
for enc in encoded_train_data.encoders.values()]) | ||
|
||
# raw predictions for validation dataset | ||
normal_predictions = predictor(encoded_val_data) if not is_classification else predictor(encoded_val_data, | ||
predict_proba=True) | ||
normal_predictions = normal_predictions.set_index(data.index) | ||
|
||
# ------------------------- # | ||
# Run analysis blocks, both core and user-defined | ||
# ------------------------- # | ||
kwargs = { | ||
'predictor': predictor, | ||
'target': target, | ||
'input_cols': input_cols, | ||
'dtype_dict': dtype_dict, | ||
'normal_predictions': normal_predictions, | ||
'data': data, | ||
'train_data': train_data, | ||
'encoded_val_data': encoded_val_data, | ||
'is_classification': is_classification, | ||
'is_numerical': is_numerical, | ||
'is_multi_ts': is_multi_ts, | ||
'stats_info': stats_info, | ||
'ts_cfg': ts_cfg, | ||
'accuracy_functions': accuracy_functions, | ||
'has_pretrained_text_enc': has_pretrained_text_enc | ||
} | ||
|
||
for block in analysis_blocks: | ||
runtime_analyzer = block.analyze(runtime_analyzer, **kwargs) | ||
|
||
# ------------------------- # | ||
# Populate ModelAnalysis object | ||
# ------------------------- # | ||
model_analysis = ModelAnalysis( | ||
accuracies=runtime_analyzer['score_dict'], | ||
accuracy_histogram=runtime_analyzer['acc_histogram'], | ||
accuracy_samples=runtime_analyzer['acc_samples'], | ||
train_sample_size=len(encoded_train_data), | ||
test_sample_size=len(encoded_val_data), | ||
confusion_matrix=runtime_analyzer['cm'], | ||
column_importances=runtime_analyzer['column_importances'], | ||
histograms=stats_info.histograms, | ||
dtypes=dtype_dict | ||
) | ||
|
||
return model_analysis, runtime_analyzer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Tuple, Dict, Optional | ||
|
||
import pandas as pd | ||
from lightwood.helpers.log import log | ||
|
||
|
||
class BaseAnalysisBlock: | ||
"""Class to be inherited by any analysis/explainer block.""" | ||
def __init__(self, | ||
deps: Optional[Tuple] = () | ||
): | ||
|
||
self.dependencies = deps # can be parallelized when there are no dependencies @TODO enforce | ||
|
||
def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]: | ||
""" | ||
This method should be called once during the analysis phase, or not called at all. | ||
It computes any information that the block may either output to the model analysis object, | ||
or use at inference time when `.explain()` is called (in this case, make sure all needed | ||
objects are added to the runtime analyzer so that `.explain()` can access them). | ||
:param info: Dictionary where any new information or objects are added. The next analysis block will use | ||
the output of the previous block as a starting point. | ||
:param kwargs: Dictionary with named variables from either the core analysis or the rest of the prediction | ||
pipeline. | ||
""" | ||
log.info(f"{self.__class__.__name__}.analyze() has not been implemented, no modifications will be done to the model analysis.") # noqa | ||
return info | ||
|
||
def explain(self, | ||
row_insights: pd.DataFrame, | ||
global_insights: Dict[str, object], **kwargs) -> Tuple[pd.DataFrame, Dict[str, object]]: | ||
""" | ||
This method should be called once during the explaining phase at inference time, or not called at all. | ||
Additional explanations can be at an instance level (row-wise) or global. | ||
For the former, return a data frame with any new insights. For the latter, a dictionary is required. | ||
:param row_insights: dataframe with previously computed row-level explanations. | ||
:param global_insights: dict() with any explanations that concern all predicted instances or the model itself. | ||
:returns: | ||
- row_insights: modified input dataframe with any new row insights added here. | ||
- global_insights: dict() with any explanations that concern all predicted instances or the model itself. | ||
""" | ||
log.info(f"{self.__class__.__name__}.explain() has not been implemented, no modifications will be done to the data insights.") # noqa | ||
return row_insights, global_insights |
Oops, something went wrong.