diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py index e7590a3..aa644ff 100644 --- a/pymc_bart/tree.py +++ b/pymc_bart/tree.py @@ -296,11 +296,11 @@ def _traverse_tree( params[0][nd_dims] + params[1][nd_dims] * X[..., idx_split_variable] ) else: + idx_split_variable = node.idx_split_variable left_node_index, right_node_index = get_idx_left_child( node_index ), get_idx_right_child(node_index) - idx_split_variable = node.idx_split_variable - if excluded is not None and node.idx_split_variable in excluded: + if excluded is not None and idx_split_variable in excluded: prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable)) stack.append( diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 250fc51..0840e0e 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -408,14 +408,14 @@ def plot_pdp( fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax) count = 0 + fake_X = _create_pdp_data(X, xs_interval, xs_values) for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) - fake_X, new_x = _create_pdp_data(X, xs_interval, var, xs_values, var_discrete) p_d = _sample_posterior( all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape ) - + new_x = fake_X[:, var] for s_i in range(shape): p_di = func(p_d[:, :, s_i]) if var in var_discrete: @@ -621,10 +621,8 @@ def _prepare_plot_data( def _create_pdp_data( X: npt.NDArray[np.float_], xs_interval: str, - var: int, xs_values: Optional[Union[int, List[float]]] = None, - var_discrete: Optional[List[int]] = None, -) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]: +) -> npt.NDArray[np.float_]: """ Create data for partial dependence plot. @@ -636,28 +634,23 @@ def _create_pdp_data( Interval for x-axis. Available options are 'insample', 'linear' or 'quantiles'. xs_values : int or list Number of points for 'linear' or list of quantiles for 'quantiles'. - var : int - Index of variable of interest - var_discrete : None or list - Indices of discrete variables. Returns ------- - Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]] - A tuple containing a 2D array for the fake_X data and 1D array for new_x data. + npt.NDArray[np.float_] + A 2D array for the fake_X data. """ if xs_interval == "insample": - return X, X[:, var] + return X else: - if var_discrete is not None and var in var_discrete: - new_x = np.unique(X[:, var]) - else: - if xs_interval == "linear" and isinstance(xs_values, int): - new_x = np.linspace(np.nanmin(X[:, var]), np.nanmax(X[:, var]), xs_values) - elif xs_interval == "quantiles" and isinstance(xs_values, list): - new_x = np.quantile(X[:, var], q=xs_values) - - return np.tile(new_x[:, None], X.shape[1]), new_x + if xs_interval == "linear" and isinstance(xs_values, int): + min_vals = np.min(X, axis=0) + max_vals = np.max(X, axis=0) + fake_X = np.linspace(min_vals, max_vals, num=xs_values, axis=0) + elif xs_interval == "quantiles" and isinstance(xs_values, list): + fake_X = np.quantile(X, q=xs_values, axis=0) + + return fake_X def _smooth_mean(