From 6ea1fe90cd6ad94671d663fc453b2acaad9d775b Mon Sep 17 00:00:00 2001 From: DanielaBreitman Date: Wed, 10 Apr 2024 15:33:57 +0200 Subject: [PATCH 1/2] Fix non-vectorized case w 21cmEMU --- src/py21cmmc/core.py | 15 +++++++++++---- src/py21cmmc/likelihood.py | 4 ++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/py21cmmc/core.py b/src/py21cmmc/core.py index 3a7c7934..3e8b7c80 100644 --- a/src/py21cmmc/core.py +++ b/src/py21cmmc/core.py @@ -1317,9 +1317,9 @@ def build_model_data(self, ctx): astro_params = self._update_params(astro_params).defining_dict astro_params = {k: astro_params[k] for k in self.astro_param_keys} if ( - all(isinstance(v, (np.ndarray, list, int, float)) for v in values) + all(isinstance(v, (np.ndarray, list)) for v in values) and len(values) > 0 - ): + ): lengths = [len(v) for v in values] if lengths.count(lengths[0]) != len(lengths): raise ValueError( @@ -1329,9 +1329,16 @@ def build_model_data(self, ctx): for t in zip(*values): ap.append(dict(zip(keys, t))) astro_params = np.array(ap, dtype=object) + if ( + all(isinstance(v, (float, int)) for v in values) + and len(values) > 0 + ): + astro_params = dict(zip(keys, values)) + astro_params = np.array([astro_params], dtype=object) logger.debug(f"AstroParams: {astro_params}") - + n = len(astro_params) theta, outputs, errors = self.emulator.predict(astro_params=astro_params) + if self.io_options["cache_dir"] is not None: if len(astro_params.shape) == 2: pars = astro_params[0] @@ -1347,7 +1354,7 @@ def build_model_data(self, ctx): logger.debug(f"Adding {self.ctx_variables} to context data") for key in self.ctx_variables: try: - ctx.add(key + self.name, getattr(outputs, key)) + ctx.add(key + self.name, getattr(outputs, key) if n > 1 else getattr(outputs, key)[np.newaxis,...]) except AttributeError: try: ctx.add(key + self.name, errors[key]) diff --git a/src/py21cmmc/likelihood.py b/src/py21cmmc/likelihood.py index fc1603d2..3f001718 100644 --- a/src/py21cmmc/likelihood.py +++ b/src/py21cmmc/likelihood.py @@ -1436,8 +1436,8 @@ def reduce_data(self, ctx): def computeLikelihood(self, model): """Compute the likelihood.""" - n = model["xHI"].shape[0] xHI = np.atleast_2d(model["xHI"]) + n = xHI.shape[0] lnprob = np.zeros(n) for i in range(n): if self._require_spline: @@ -1462,7 +1462,7 @@ def computeLikelihood(self, model): lnprob[i] += self.lnprob(model_spline(z), data, sigma_t) logger.debug(f"Neutral fraction Likelihood computed: {lnprob}") - return lnprob + return lnprob.squeeze() def lnprob(self, model, data, sigma): """Compute the log prob given a model, data and error.""" From 1d8d5a618a13e6b9a64cf4560afbfca28f2fcc36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:36:01 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/py21cmmc/core.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/py21cmmc/core.py b/src/py21cmmc/core.py index 3e8b7c80..a01dc0da 100644 --- a/src/py21cmmc/core.py +++ b/src/py21cmmc/core.py @@ -1316,10 +1316,7 @@ def build_model_data(self, ctx): if len(values) == 0: astro_params = self._update_params(astro_params).defining_dict astro_params = {k: astro_params[k] for k in self.astro_param_keys} - if ( - all(isinstance(v, (np.ndarray, list)) for v in values) - and len(values) > 0 - ): + if all(isinstance(v, (np.ndarray, list)) for v in values) and len(values) > 0: lengths = [len(v) for v in values] if lengths.count(lengths[0]) != len(lengths): raise ValueError( @@ -1329,10 +1326,7 @@ def build_model_data(self, ctx): for t in zip(*values): ap.append(dict(zip(keys, t))) astro_params = np.array(ap, dtype=object) - if ( - all(isinstance(v, (float, int)) for v in values) - and len(values) > 0 - ): + if all(isinstance(v, (float, int)) for v in values) and len(values) > 0: astro_params = dict(zip(keys, values)) astro_params = np.array([astro_params], dtype=object) logger.debug(f"AstroParams: {astro_params}") @@ -1354,7 +1348,12 @@ def build_model_data(self, ctx): logger.debug(f"Adding {self.ctx_variables} to context data") for key in self.ctx_variables: try: - ctx.add(key + self.name, getattr(outputs, key) if n > 1 else getattr(outputs, key)[np.newaxis,...]) + ctx.add( + key + self.name, + getattr(outputs, key) + if n > 1 + else getattr(outputs, key)[np.newaxis, ...], + ) except AttributeError: try: ctx.add(key + self.name, errors[key])