Skip to content

Commit

Permalink
fix bug (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Jul 10, 2023
1 parent aedee25 commit 0917ab4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 23 deletions.
4 changes: 2 additions & 2 deletions pymc_bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,11 @@ def _traverse_tree(
params[0][nd_dims] + params[1][nd_dims] * X[..., idx_split_variable]
)
else:
idx_split_variable = node.idx_split_variable
left_node_index, right_node_index = get_idx_left_child(
node_index
), get_idx_right_child(node_index)
idx_split_variable = node.idx_split_variable
if excluded is not None and node.idx_split_variable in excluded:
if excluded is not None and idx_split_variable in excluded:
prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable))
stack.append(
Expand Down
35 changes: 14 additions & 21 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,14 +408,14 @@ def plot_pdp(
fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax)

count = 0
fake_X = _create_pdp_data(X, xs_interval, xs_values)
for var in range(len(var_idx)):
excluded = indices[:]
excluded.remove(var)
fake_X, new_x = _create_pdp_data(X, xs_interval, var, xs_values, var_discrete)
p_d = _sample_posterior(
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
)

new_x = fake_X[:, var]
for s_i in range(shape):
p_di = func(p_d[:, :, s_i])
if var in var_discrete:
Expand Down Expand Up @@ -621,10 +621,8 @@ def _prepare_plot_data(
def _create_pdp_data(
X: npt.NDArray[np.float_],
xs_interval: str,
var: int,
xs_values: Optional[Union[int, List[float]]] = None,
var_discrete: Optional[List[int]] = None,
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
) -> npt.NDArray[np.float_]:
"""
Create data for partial dependence plot.
Expand All @@ -636,28 +634,23 @@ def _create_pdp_data(
Interval for x-axis. Available options are 'insample', 'linear' or 'quantiles'.
xs_values : int or list
Number of points for 'linear' or list of quantiles for 'quantiles'.
var : int
Index of variable of interest
var_discrete : None or list
Indices of discrete variables.
Returns
-------
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]
A tuple containing a 2D array for the fake_X data and 1D array for new_x data.
npt.NDArray[np.float_]
A 2D array for the fake_X data.
"""
if xs_interval == "insample":
return X, X[:, var]
return X
else:
if var_discrete is not None and var in var_discrete:
new_x = np.unique(X[:, var])
else:
if xs_interval == "linear" and isinstance(xs_values, int):
new_x = np.linspace(np.nanmin(X[:, var]), np.nanmax(X[:, var]), xs_values)
elif xs_interval == "quantiles" and isinstance(xs_values, list):
new_x = np.quantile(X[:, var], q=xs_values)

return np.tile(new_x[:, None], X.shape[1]), new_x
if xs_interval == "linear" and isinstance(xs_values, int):
min_vals = np.min(X, axis=0)
max_vals = np.max(X, axis=0)
fake_X = np.linspace(min_vals, max_vals, num=xs_values, axis=0)
elif xs_interval == "quantiles" and isinstance(xs_values, list):
fake_X = np.quantile(X, q=xs_values, axis=0)

return fake_X


def _smooth_mean(
Expand Down

0 comments on commit 0917ab4

Please sign in to comment.