From 3babcd958e9e71ad4ca427da33f502e03bfbdcc2 Mon Sep 17 00:00:00 2001 From: Patricio Cerda Mardini Date: Mon, 12 Jun 2023 20:27:29 -0400 Subject: [PATCH] fix numerical encoder sign none handling --- lightwood/encoder/numeric/numeric.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lightwood/encoder/numeric/numeric.py b/lightwood/encoder/numeric/numeric.py index c62a4ba31..1a1cf8b25 100644 --- a/lightwood/encoder/numeric/numeric.py +++ b/lightwood/encoder/numeric/numeric.py @@ -57,15 +57,16 @@ def encode(self, data: Union[np.ndarray, pd.Series]): if isinstance(data, pd.Series): data = data.values - data = np.nan_to_num(data.astype(float), nan=0, posinf=20, neginf=-20) - if not self.positive_domain: - sign = np.vectorize(self._sign_fn, otypes=[float])(data) + sign_data = np.nan_to_num(data, nan=0, posinf=0, neginf=0) + sign = np.vectorize(self._sign_fn, otypes=[float])(sign_data) else: sign = np.zeros(len(data)) - log_value = np.vectorize(self._log_fn, otypes=[float])(data) + log_value = np.nan_to_num(log_value, nan=0, posinf=20, neginf=-20) + norm = np.vectorize(self._norm_fn, otypes=[float])(data) + norm = np.nan_to_num(norm, nan=0, posinf=20, neginf=-20) if self.is_target: components = [sign, log_value, norm]