Skip to content

Commit

Permalink
auto-generate problem
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Oct 30, 2024
1 parent f12ff0b commit 7f0fdc2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 6 deletions.
63 changes: 63 additions & 0 deletions autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from autoemulate.plotting import _plot_model
from autoemulate.printing import _print_setup
from autoemulate.save import ModelSerialiser
from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
from autoemulate.sensitivity_analysis import sensitivity_analysis
from autoemulate.utils import _ensure_2d
from autoemulate.utils import _get_full_model_name
from autoemulate.utils import _redirect_warnings
Expand Down Expand Up @@ -523,3 +525,64 @@ def plot_eval(
)

return fig

def sensitivity_analysis(
self, model=None, problem=None, N=1024, conf_level=0.95, as_df=True
):
"""Perform Sobol sensitivity analysis on a fitted emulator.
Parameters
----------
model : object, optional
Fitted model. If None, uses the best model from cross-validation.
problem : dict, optional
The problem definition, including 'num_vars', 'names', and 'bounds', optional 'output_names'.
If None, the problem is generated from X using minimum and maximum values of the features as bounds.
Example:
```python
problem = {
"num_vars": 2,
"names": ["x1", "x2"],
"bounds": [[0, 1], [0, 1]],
}
```
N : int, optional
Number of samples to generate. Default is 1024.
conf_level : float, optional
Confidence level for the confidence intervals. Default is 0.95.
as_df : bool, optional
If True, return a long-format pandas DataFrame (default is True).
"""
if model is None:
if not hasattr(self, "best_model"):
raise RuntimeError("Must run compare() before sensitivity_analysis()")
model = self.best_model
self.logger.info(
f"No model provided, using best model {get_model_name(model)} from cross-validation for sensitivity analysis"
)

Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
return Si

def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
"""
Plot the sensitivity analysis results.
Parameters:
-----------
results : pd.DataFrame
The results from sobol_results_to_df.
index : str, default "S1"
The type of sensitivity index to plot.
- "S1": first-order indices
- "S2": second-order/interaction indices
- "ST": total-order indices
n_cols : int, optional
The number of columns in the plot. Defaults to 3 if there are 3 or more outputs,
otherwise the number of outputs.
figsize : tuple, optional
Figure size as (width, height) in inches.If None, automatically calculated.
"""
return plot_sensitivity_analysis(results, index, n_cols, figsize)
33 changes: 27 additions & 6 deletions autoemulate/sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from autoemulate.utils import _ensure_2d


def sensitivity_analysis(model, problem, N=1024, conf_level=0.95, as_df=True):
def sensitivity_analysis(
model, problem=None, X=None, N=1024, conf_level=0.95, as_df=True
):
"""Perform Sobol sensitivity analysis on a fitted emulator.
Parameters:
Expand Down Expand Up @@ -39,7 +41,7 @@ def sensitivity_analysis(model, problem, N=1024, conf_level=0.95, as_df=True):
containing the Sobol indices keys ‘S1’, ‘S1_conf’, ‘ST’, and ‘ST_conf’, where each entry
is a list of length corresponding to the number of parameters.
"""
Si = sobol_analysis(model, problem, N, conf_level)
Si = sobol_analysis(model, problem, X, N, conf_level)

if as_df:
return sobol_results_to_df(Si)
Expand Down Expand Up @@ -85,7 +87,21 @@ def _get_output_names(problem, num_outputs):
return output_names


def sobol_analysis(model, problem, N=1024, conf_level=0.95):
def _generate_problem(X):
"""
Generate a problem definition from a design matrix.
"""
if X.ndim == 1:
raise ValueError("X must be a 2D array.")

return {
"num_vars": X.shape[1],
"names": [f"x{i+1}" for i in range(X.shape[1])],
"bounds": [[X[:, i].min(), X[:, i].max()] for i in range(X.shape[1])],
}


def sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
"""
Perform Sobol sensitivity analysis on a fitted emulator.
Expand All @@ -105,8 +121,13 @@ def sobol_analysis(model, problem, N=1024, conf_level=0.95):
containing the Sobol indices keys ‘S1’, ‘S1_conf’, ‘ST’, and ‘ST_conf’, where each entry
is a list of length corresponding to the number of parameters.
"""
# correctly defined?
problem = _check_problem(problem)
# get problem
if problem is not None:
problem = _check_problem(problem)
elif X is not None:
problem = _generate_problem(X)
else:
raise ValueError("Either problem or X must be provided.")

# saltelli sampling
param_values = sample(problem, N)
Expand Down Expand Up @@ -240,7 +261,7 @@ def plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
Figure size as (width, height) in inches.If None, automatically calculated.
"""
with plt.style.context("seaborn-v0_8-whitegrid"):
with plt.style.context("fast"):
# prepare data
results = _validate_input(results, index)
unique_outputs = results["output"].unique()
Expand Down
12 changes: 12 additions & 0 deletions tests/test_sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autoemulate.experimental_design import LatinHypercube
from autoemulate.sensitivity_analysis import _calculate_layout
from autoemulate.sensitivity_analysis import _check_problem
from autoemulate.sensitivity_analysis import _generate_problem
from autoemulate.sensitivity_analysis import _get_output_names
from autoemulate.sensitivity_analysis import _validate_input
from autoemulate.sensitivity_analysis import sobol_analysis
Expand Down Expand Up @@ -192,3 +193,14 @@ def test_calculate_layout_custom():
n_rows, n_cols = _calculate_layout(3, 2)
assert n_rows == 2
assert n_cols == 2


# test _generate_problem -----------------------------------------------------


def test_generate_problem():
X = np.array([[0, 0], [1, 1], [2, 2]])
problem = _generate_problem(X)
assert problem["num_vars"] == 2
assert problem["names"] == ["x1", "x2"]
assert problem["bounds"] == [[0, 2], [0, 2]]

0 comments on commit 7f0fdc2

Please sign in to comment.