Skip to content

Commit

Permalink
Fix bug in plot_ice, and clean docstring of plot_ice and plot_pdp (#135)
Browse files Browse the repository at this point in the history
* fix plot_pdp/ice

* fix test

* fix type hints
  • Loading branch information
aloctavodia authored Dec 23, 2023
1 parent 1d2287e commit a1adedf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 29 deletions.
45 changes: 18 additions & 27 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,12 @@ def plot_ice(
bartrv: Variable,
X: npt.NDArray[np.float_],
Y: Optional[npt.NDArray[np.float_]] = None,
xs_interval: str = "quantiles",
xs_values: Optional[Union[int, List[float]]] = None,
var_idx: Optional[List[int]] = None,
var_discrete: Optional[List[int]] = None,
func: Optional[Callable] = None,
centered: Optional[bool] = True,
samples: int = 50,
instances: int = 10,
samples: int = 100,
instances: int = 30,
random_seed: Optional[int] = None,
sharey: bool = True,
smooth: bool = True,
Expand All @@ -185,16 +183,6 @@ def plot_ice(
The covariate matrix.
Y : Optional[npt.NDArray[np.float_]], by default None.
The response vector.
xs_interval : str
Method used to compute the values X used to evaluate the predicted function. "linear",
evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
quantiles of X. "insample", the evaluation is done at the values of X.
For discrete variables these options are ommited.
xs_values : Optional[Union[int, List[float]]], by default None.
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
quantiles to compute, which must be between 0 and 1 inclusive.
Ignored when ``xs_interval="insample"``.
var_idx : Optional[List[int]], by default None.
List of the indices of the covariate for which to compute the pdp or ice.
var_discrete : Optional[List[int]], by default None.
Expand All @@ -205,22 +193,20 @@ def plot_ice(
If True the result is centered around the partial response evaluated at the lowest value in
``xs_interval``. Defaults to True.
samples : int
Number of posterior samples used in the predictions. Defaults to 50
Number of posterior samples used in the predictions. Defaults to 100
instances : int
Number of instances of X to plot. Defaults to 10.
Number of instances of X to plot. Defaults to 30.
random_seed : Optional[int], by default None.
Seed used to sample from the posterior. Defaults to None.
sharey : bool
Controls sharing of properties among y-axes. Defaults to True.
rug : bool
Whether to include a rugplot. Defaults to True.
smooth : bool
If True the result will be smoothed by first computing a linear interpolation of the data
over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
Defaults to True.
grid : str or tuple
How to arrange the subplots. Defaults to "long", one subplot below the other.
Other options are "wide", one subplot next to eachother or a tuple indicating the number of
Other options are "wide", one subplot next to each other or a tuple indicating the number of
rows and columns.
color : matplotlib valid color
Color used to plot the pdp or ice. Defaults to "C0"
Expand Down Expand Up @@ -257,17 +243,17 @@ def identity(x):
indices,
var_idx,
var_discrete,
xs_interval,
xs_values,
) = _prepare_plot_data(X, Y, xs_interval, xs_values, var_idx, var_discrete)
_,
_,
) = _prepare_plot_data(X, Y, "linear", None, var_idx, var_discrete)

fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax)

instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances)
idx_s = list(range(X.shape[0]))

count = 0
for var in range(len(var_idx)):
for i_var, var in enumerate(var_idx):
indices_mi = indices[:]
indices_mi.remove(var)
y_pred = []
Expand All @@ -283,6 +269,7 @@ def identity(x):

new_x = fake_X[:, var]
p_d = np.array(y_pred)
print(p_d.shape)

for s_i in range(shape):
if centered:
Expand All @@ -301,7 +288,7 @@ def identity(x):
idx = np.argsort(new_x)
axes[count].plot(new_x[idx], p_di.mean(0)[idx], color=color_mean)
axes[count].plot(new_x[idx], p_di.T[idx], color=color, alpha=alpha)
axes[count].set_xlabel(x_labels[var])
axes[count].set_xlabel(x_labels[i_var])

count += 1

Expand Down Expand Up @@ -349,7 +336,7 @@ def plot_pdp(
For discrete variables these options are ommited.
xs_values : Optional[Union[int, List[float]]], by default None.
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of
quantiles to compute, which must be between 0 and 1 inclusive.
Ignored when ``xs_interval="insample"``.
var_idx : Optional[List[int]], by default None.
Expand Down Expand Up @@ -717,7 +704,8 @@ def plot_variable_importance(
xlabel_angle: float = 0,
samples: int = 100,
random_seed: Optional[int] = None,
) -> Tuple[List[int], List[plt.Axes]]:
ax: Optional[plt.Axes] = None,
) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
"""
Estimates variable importance from the BART-posterior.
Expand Down Expand Up @@ -747,6 +735,8 @@ def plot_variable_importance(
Number of predictions used to compute correlation for subsets of variables. Defaults to 100
random_seed : Optional[int]
random_seed used to sample from the posterior. Defaults to None.
ax : axes
Matplotlib axes.
Returns
-------
Expand All @@ -771,7 +761,8 @@ def plot_variable_importance(
if figsize is None:
figsize = (8, 3)

_, ax = plt.subplots(1, 1, figsize=figsize)
if ax is None:
_, ax = plt.subplots(1, 1, figsize=figsize)

if labels is None:
labels_ary = np.arange(n_vars).astype(str)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ def test_sample_posterior(self):
{},
{
"samples": 2,
"xs_interval": "quantiles",
"xs_values": [0.25, 0.5, 0.75],
"var_discrete": [3],
},
{"instances": 2},
Expand Down

0 comments on commit a1adedf

Please sign in to comment.