Skip to content

Commit

Permalink
small improvements (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz authored Jul 11, 2023
1 parent 05467b3 commit 592c6b2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
8 changes: 4 additions & 4 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from multiprocessing import Manager
from typing import List, Optional, Tuple
import warnings

import numpy as np
import numpy.typing as npt
Expand All @@ -26,9 +26,9 @@
from pymc.logprob.abstract import _logprob
from pytensor.tensor.random.op import RandomVariable

from .split_rules import SplitRule
from .tree import Tree
from .utils import TensorLike, _sample_posterior
from .split_rules import SplitRule

__all__ = ["BART"]

Expand Down Expand Up @@ -93,7 +93,7 @@ class BART(Distribution):
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
split_rules : Optional[SplitRule], default None
split_rules : Optional[List[SplitRule]], default None
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
Expand Down Expand Up @@ -127,7 +127,7 @@ def __new__(
beta: float = 2.0,
response: str = "constant",
split_prior: Optional[List[float]] = None,
split_rules: Optional[SplitRule] = None,
split_rules: Optional[List[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
):
Expand Down
5 changes: 5 additions & 0 deletions pymc_bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import numpy.typing as npt
from pytensor import config

from .split_rules import SplitRule


Expand Down Expand Up @@ -101,6 +102,10 @@ class Tree:
of the tree itself.
output: Optional[npt.NDArray[np.float_]]
Array of shape number of observations, shape
split_rules : List[SplitRule]
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
idx_leaf_nodes : Optional[List[int]], by default None.
Array with the index of the leaf nodes of the tree.
Expand Down
2 changes: 1 addition & 1 deletion pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def plot_pdp(
func : Optional[Callable], by default None.
Arbitrary function to apply to the predictions. Defaults to the identity function.
samples : int
Number of posterior samples used in the predictions. Defaults to 400
Number of posterior samples used in the predictions. Defaults to 200
random_seed : Optional[int], by default None.
Seed used to sample from the posterior. Defaults to None.
sharey : bool
Expand Down

0 comments on commit 592c6b2

Please sign in to comment.