From 592c6b271be6514a35c0a349d220deaf62aabe9f Mon Sep 17 00:00:00 2001
From: Juan Orduz
Date: Tue, 11 Jul 2023 18:03:50 +0200
Subject: [PATCH] small improvements (#108)

---
 pymc_bart/bart.py  | 8 ++++----
 pymc_bart/tree.py  | 5 +++++
 pymc_bart/utils.py | 2 +-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py
index 86d6314..8138d73 100644
--- a/pymc_bart/bart.py
+++ b/pymc_bart/bart.py
@@ -14,9 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from multiprocessing import Manager
 from typing import List, Optional, Tuple
-import warnings
 
 import numpy as np
 import numpy.typing as npt
@@ -26,9 +26,9 @@
 from pymc.logprob.abstract import _logprob
 from pytensor.tensor.random.op import RandomVariable
 
+from .split_rules import SplitRule
 from .tree import Tree
 from .utils import TensorLike, _sample_posterior
-from .split_rules import SplitRule
 
 __all__ = ["BART"]
 
@@ -93,7 +93,7 @@ class BART(Distribution):
         Each element of split_prior should be in the [0, 1] interval and the elements should sum to 1.
         Otherwise they will be normalized.
         Defaults to 0, i.e. all covariates have the same prior probability to be selected.
-    split_rules : Optional[SplitRule], default None
+    split_rules : Optional[List[SplitRule]], default None
         List of SplitRule objects, one per column in input data.
         Allows using different split rules for different columns. Default is ContinuousSplitRule.
         Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
@@ -127,7 +127,7 @@ def __new__(
         beta: float = 2.0,
         response: str = "constant",
         split_prior: Optional[List[float]] = None,
-        split_rules: Optional[SplitRule] = None,
+        split_rules: Optional[List[SplitRule]] = None,
         separate_trees: Optional[bool] = False,
         **kwargs,
     ):
diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py
index aa644ff..6c6c297 100644
--- a/pymc_bart/tree.py
+++ b/pymc_bart/tree.py
@@ -18,6 +18,7 @@
 import numpy as np
 import numpy.typing as npt
 from pytensor import config
+
 from .split_rules import SplitRule
 
 
@@ -101,6 +102,10 @@ class Tree:
         of the tree itself.
     output: Optional[npt.NDArray[np.float_]]
         Array of shape number of observations, shape
+    split_rules : List[SplitRule]
+        List of SplitRule objects, one per column in input data.
+        Allows using different split rules for different columns. Default is ContinuousSplitRule.
+        Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
     idx_leaf_nodes : Optional[List[int]], by default None.
         Array with the index of the leaf nodes of the tree.
 
diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py
index 0840e0e..55855a4 100644
--- a/pymc_bart/utils.py
+++ b/pymc_bart/utils.py
@@ -357,7 +357,7 @@ def plot_pdp(
     func : Optional[Callable], by default None.
         Arbitrary function to apply to the predictions. Defaults to the identity function.
     samples : int
-        Number of posterior samples used in the predictions. Defaults to 400
+        Number of posterior samples used in the predictions. Defaults to 200
     random_seed : Optional[int], by default None.
         Seed used to sample from the posterior. Defaults to None.
     sharey : bool
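
Usage sketch (not part of the patch): the corrected annotation means split_rules is passed as a list with one SplitRule object per column of X. The data, variable names, and model below are illustrative assumptions, not taken from the patch; they assume ContinuousSplitRule and OneHotSplitRule are exported at the package level, as the docstring wording suggests.

import numpy as np
import pymc as pm
import pymc_bart as pmb

# Illustrative data: first column continuous, second column a 0/1 encoded categorical.
rng = np.random.default_rng(0)
X = np.column_stack([rng.normal(size=100), rng.integers(0, 2, size=100)])
Y = rng.normal(size=100)

with pm.Model():
    # One SplitRule object per column, matching the Optional[List[SplitRule]] annotation.
    mu = pmb.BART(
        "mu",
        X,
        Y,
        m=50,
        split_rules=[pmb.ContinuousSplitRule(), pmb.OneHotSplitRule()],
    )
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=Y)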