Skip to content

Commit

Permalink
Merge pull request #238 from alan-turing-institute/newplot
Browse files Browse the repository at this point in the history
adds a new plot type and plotting tests
  • Loading branch information
mastoffel authored Sep 11, 2024
2 parents 6b05aea + decbfd0 commit 70cc05c
Show file tree
Hide file tree
Showing 4 changed files with 581 additions and 79 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }} # Required for private repos
file: ./coverage.xml # Specify the coverage report file
fail_ci_if_error: true

verbose: true


- name: Store coverage file
uses: actions/upload-artifact@v3
with:
name: coverage
path: .coverage.${{ matrix.python-version }}
include-hidden-files: true



# Coverage job to comment on PRs and update README badge
coverage:
Expand All @@ -90,7 +95,7 @@ jobs:
- name: Download coverage artifact
uses: actions/download-artifact@v3
with:
name: 'coverage'
name: coverage

# Comment coverage details on PR
- name: Coverage comment
Expand Down
32 changes: 27 additions & 5 deletions autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,11 @@ def print_results(self, model=None, sort_by="r2"):
def plot_results(
self,
model=None,
plot="standard",
plot="Xy",
n_cols=3,
figsize=None,
output_index=0,
input_index=0,
):
"""Plots the results of the cross-validation.
Expand All @@ -428,19 +429,22 @@ def plot_results(
If a model name is specified, plots all folds of that model.
plot_type : str, optional
The type of plot to draw:
“standard” draws the observed values (y-axis) vs. the predicted values (x-axis) (default).
"Xy" observed and predicted values vs. features, including 2σ error bands where available (default).
“standard” draws the observed values (y-axis) vs. the predicted values (x-axis).
“residual” draws the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
n_cols : int
Number of columns in the plot grid.
figsize : tuple, optional
Overrides the default figure size.
output_index : int
Index of the output to plot. Default is 0.
input_index : int
Index of the input to plot. Default is 0.
"""
model_name = (
_get_full_model_name(model, self.model_names) if model is not None else None
)
_plot_results(
figure = _plot_results(
self.cv_results,
self.X,
self.y,
Expand All @@ -449,7 +453,9 @@ def plot_results(
plot=plot,
figsize=figsize,
output_index=output_index,
input_index=input_index,
)
return figure

def evaluate_model(self, model=None):
"""
Expand Down Expand Up @@ -484,7 +490,15 @@ def evaluate_model(self, model=None):

return scores_df

def plot_model(self, model, plot="standard", n_cols=2, figsize=None):
def plot_model(
self,
model,
plot="Xy",
n_cols=3,
figsize=None,
output_index=0,
input_index=0,
):
"""Plots the model predictions vs. the true values.
Parameters
Expand All @@ -497,12 +511,20 @@ def plot_model(self, model, plot="standard", n_cols=2, figsize=None):
“residual” draws the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
n_cols : int, optional
Number of columns in the plot grid for multi-output. Default is 2.
output_index : int
Index of the output to plot. Default is 0..
input_index : int
Index of the input to plot. Default is 0. Only used if plot_type="Xy".
"""
_plot_model(
fig = _plot_model(
model,
self.X[self.test_idxs],
self.y[self.test_idxs],
plot,
n_cols,
figsize,
input_index=input_index,
output_index=output_index,
)

return fig
Loading

0 comments on commit 70cc05c

Please sign in to comment.