Skip to content

Commit

Permalink
fig bug with nans (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Dec 29, 2023
1 parent 5de28ff commit 7fb9c39
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

__all__ = ["BART", "PGBART"]
__version__ = "0.5.6"
__version__ = "0.5.7"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
7 changes: 3 additions & 4 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(

for idx, rule in enumerate(self.split_rules):
if rule is ContinuousSplitRule:
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.std(self.X[:, idx]))
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))

init_mean = self.bart.Y.mean()
self.num_observations = self.X.shape[0]
Expand Down Expand Up @@ -700,7 +700,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
if are_whole_number(array):
seen = []
for idx, num in enumerate(array):
if num in seen:
if num in seen and not np.isnan(num):
array[idx] = num + np.random.normal(0, std / 12)
else:
seen.append(num)
Expand All @@ -711,8 +711,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
@njit
def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
"""Check if all values in array are whole numbers"""
new_array = np.mod(array, 1)
return np.all(new_array == 0)
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)


def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
Expand Down

0 comments on commit 7fb9c39

Please sign in to comment.