diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..7701a53
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,13 @@
+[flake8]
+exclude =
+    .git,
+    __pycache__,
+    build,
+    dist,
+    docs/source/conf.py
+max-line-length = 127
+show-source = True
+statistics = True
+count = True
+verbose = 1
+ignore = E203, W503
diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index 779cd9b..3183a4f 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -15,11 +15,14 @@ on:
 jobs:
   build-linux:
-    runs-on: ubuntu-latest
+
     strategy:
       max-parallel: 5
       matrix:
-        python-version: ['3.9', '3.10']
+        python-version: ['3.9', '3.10', '3.11']
+        os: [ubuntu-latest]
+
+    runs-on: ${{ matrix.os }}
 
     steps:
     - uses: actions/checkout@v2
diff --git a/.github/workflows/notebook_smoke.yml b/.github/workflows/notebook_smoke.yml
new file mode 100644
index 0000000..5c496d0
--- /dev/null
+++ b/.github/workflows/notebook_smoke.yml
@@ -0,0 +1,54 @@
+name: notebooks
+
+on:
+  pull_request:
+    branches:
+      - '*'
+  push:
+    branches:
+      - '*'
+    tags:
+      - '*'
+
+jobs:
+  build-linux:
+
+
+    strategy:
+      matrix:
+        python-version: ['3.9', '3.10', '3.11']
+        os: [ubuntu-latest]
+
+    runs-on: ${{ matrix.os }}
+
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Python ${{ matrix.python-version }}
+
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+
+    - name: Install dependencies
+      run: |
+        sudo apt-get update -qq
+        python -m pip install --upgrade pip
+        python -m pip install flake8 pytest
+        python -m pip install jaxlib
+        python -m pip install jax
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+
+    - name: install package
+      run: |
+        pip install .
+        pip list
+
+    - name: Notebook smoke tests
+      run: |
+        pip install ipython
+        pip install nbformat
+        pip install seaborn
+        cd examples
+        ipython -c "%run simpleGP.ipynb"
+
diff --git a/gpax/models/gp.py b/gpax/models/gp.py
index 2c4a068..19b8b08 100644
--- a/gpax/models/gp.py
+++ b/gpax/models/gp.py
@@ -23,7 +23,7 @@
 from ..kernels import get_kernel
 from ..utils import split_in_batches
 
-kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
+kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
 
 clear_cache = jax._src.dispatch.xla_primitive_callable.cache_clear
@@ -62,7 +62,7 @@ class ExactGP:
         >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True)
 
         GP with custom noise prior
-
+
         >>> gp_model = gpax.ExactGP(
         >>>     input_dim=1, kernel='RBF',
         >>>     noise_prior_dist = numpyro.distributions.HalfNormal(.1)
         >>> )
         >>> # Run HMC to obtain posterior samples
         >>> gp_model.fit(rng_key, X, y)
         >>> # Make a noiseless prediction
         >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True)
@@ -73,7 +73,7 @@ class ExactGP:
         GP with custom probabilistic model as its mean function
-
+
         >>> # Define a deterministic mean function
         >>> mean_fn = lambda x, param: param["a"]*x + param["b"]
         >>>
@@ -93,25 +93,34 @@ class ExactGP:
         >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True)
     """
 
-    def __init__(self, input_dim: int, kernel: Union[str, kernel_fn_type],
-                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
-                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 noise_prior_dist: Optional[dist.Distribution] = None,
-                 lengthscale_prior_dist: Optional[dist.Distribution] = None
-                 ) -> None:
+    def __init__(
+        self,
+        input_dim: int,
+        kernel: Union[str, kernel_fn_type],
+        mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
+        kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+        mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+        noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+        noise_prior_dist: Optional[dist.Distribution] = None,
+        lengthscale_prior_dist: Optional[dist.Distribution] = None,
+    ) -> None:
         clear_cache()
         if noise_prior is not None:
-            warnings.warn("`noise_prior` is deprecated and will be removed in a future version. "
-                          "Please use `noise_prior_dist` instead, which accepts an instance of a "
-                          "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
-                          "rather than a function that calls `numpyro.sample`.", FutureWarning)
+            warnings.warn(
+                "`noise_prior` is deprecated and will be removed in a future version. "
+                "Please use `noise_prior_dist` instead, which accepts an instance of a "
+                "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
+                "rather than a function that calls `numpyro.sample`.",
+                FutureWarning,
+            )
         if kernel_prior is not None:
-            warnings.warn("`kernel_prior` will remain available for complex priors. However, for "
-                          "modifying only the lengthscales, it is recommended to use `lengthscale_prior_dist` instead. "
-                          "`lengthscale_prior_dist` accepts an instance of a numpyro.distributions Distribution object, "
-                          "e.g., `dist.Gamma(2, 5)`, rather than a function that calls `numpyro.sample`.", UserWarning)
+            warnings.warn(
+                "`kernel_prior` will remain available for complex priors. However, for "
+                "modifying only the lengthscales, it is recommended to use `lengthscale_prior_dist` instead. "
+                "`lengthscale_prior_dist` accepts an instance of a numpyro.distributions Distribution object, "
+                "e.g., `dist.Gamma(2, 5)`, rather than a function that calls `numpyro.sample`.",
+                UserWarning,
+            )
         self.kernel_dim = input_dim
         self.kernel = get_kernel(kernel)
         self.kernel_name = kernel if isinstance(kernel, str) else None
@@ -125,11 +134,7 @@ def __init__(self, input_dim: int, kernel: Union[str, kernel_fn_type],
         self.y_train = None
         self.mcmc = None
 
-    def model(self,
-              X: jnp.ndarray,
-              y: jnp.ndarray = None,
-              **kwargs: float
-              ) -> None:
+    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
         """GP probabilistic model with inputs X and targets y"""
         # Initialize mean function at zeros
         f_loc = jnp.zeros(X.shape[0])
@@ -150,12 +155,7 @@ def model(self,
             args += [self.mean_fn_prior()]
             f_loc += self.mean_fn(*args).squeeze()
         # compute kernel
-        k = self.kernel(
-            X, X,
-            kernel_params,
-            noise,
-            **kwargs
-        )
+        k = self.kernel(X, X, kernel_params, noise, **kwargs)
         # sample y according to the standard Gaussian process formula
         numpyro.sample(
             "y",
@@ -163,13 +163,20 @@ def model(self,
             obs=y,
         )
 
-    def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
-            num_warmup: int = 2000, num_samples: int = 2000,
-            num_chains: int = 1, chain_method: str = 'sequential',
-            progress_bar: bool = True, print_summary: bool = True,
-            device: Type[jaxlib.xla_extension.Device] = None,
-            **kwargs: float
-            ) -> None:
+    def fit(
+        self,
+        rng_key: jnp.array,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+        num_warmup: int = 2000,
+        num_samples: int = 2000,
+        num_chains: int = 1,
+        chain_method: str = "sequential",
+        progress_bar: bool = True,
+        print_summary: bool = True,
+        device: Type[jaxlib.xla_extension.Device] = None,
+        **kwargs: float
+    ) -> None:
         """
         Run Hamiltonian Monte Carlo to infer the GP parameters
 
@@ -185,7 +192,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             print_summary: print summary at the end of sampling
             device:
                 optionally specify a cpu or gpu device on which to run the inference;
-                e.g., ``device=jax.devices("cpu")[0]``
+                e.g., ``device=jax.devices("cpu")[0]``
             **jitter:
                 Small positive term added to the diagonal part of a covariance
                 matrix for numerical stability (Default: 1e-6)
@@ -206,7 +213,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             num_chains=num_chains,
             chain_method=chain_method,
             progress_bar=progress_bar,
-            jit_model_args=False
+            jit_model_args=False,
         )
         self.mcmc.run(rng_key, X, y, **kwargs)
         if print_summary:
@@ -228,28 +235,25 @@ def _sample_kernel_params(self, output_scale=True) -> Dict[str, jnp.ndarray]:
             length_dist = self.lengthscale_prior_dist
         else:
             length_dist = dist.LogNormal(0.0, 1.0)
-        with numpyro.plate('ard', self.kernel_dim):  # allows using ARD kernel for kernel_dim > 1
+        with numpyro.plate("ard", self.kernel_dim):  # allows using ARD kernel for kernel_dim > 1
             length = numpyro.sample("k_length", length_dist)
         if output_scale:
             scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0))
         else:
             scale = numpyro.deterministic("k_scale", jnp.array(1.0))
-        if self.kernel_name == 'Periodic':
+        if self.kernel_name == "Periodic":
             period = numpyro.sample("period", dist.LogNormal(0.0, 1.0))
-        kernel_params = {
-            "k_length": length, "k_scale": scale,
-            "period": period if self.kernel_name == "Periodic" else None}
+        kernel_params = {"k_length": length, "k_scale": scale, "period": period if self.kernel_name == "Periodic" else None}
         return kernel_params
 
     def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
"""Get posterior samples (after running the MCMC chains)""" return self.mcmc.get_samples(group_by_chain=chain_dim) - #@partial(jit, static_argnames='self') - def get_mvn_posterior(self, - X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], - noiseless: bool = False, **kwargs: float - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + # @partial(jit, static_argnames='self') + def get_mvn_posterior( + self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters @@ -273,9 +277,15 @@ def get_mvn_posterior(self, mean += self.mean_fn(*args).squeeze() return mean, cov - def _predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, - params: Dict[str, jnp.ndarray], n: int, noiseless: bool = False, - **kwargs: float) -> Tuple[jnp.ndarray, jnp.ndarray]: + def _predict( + self, + rng_key: jnp.ndarray, + X_new: jnp.ndarray, + params: Dict[str, jnp.ndarray], + n: int, + noiseless: bool = False, + **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Prediction with a single sample of GP parameters""" # Get the predictive mean and covariance y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs) @@ -283,20 +293,22 @@ def _predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled - def _predict_in_batches(self, rng_key: jnp.ndarray, - X_new: jnp.ndarray, batch_size: int = 100, - batch_dim: int = 0, - samples: Optional[Dict[str, jnp.ndarray]] = None, - n: int = 1, filter_nans: bool = False, - predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, - noiseless: bool = False, - device: Type[jaxlib.xla_extension.Device] = None, - **kwargs: float - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - + def _predict_in_batches( + self, + rng_key: jnp.ndarray, + X_new: jnp.ndarray, + batch_size: int = 100, + batch_dim: int = 0, + samples: Optional[Dict[str, jnp.ndarray]] = None, + n: int = 1, + filter_nans: bool = False, + predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, + noiseless: bool = False, + device: Type[jaxlib.xla_extension.Device] = None, + **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: if predict_fn is None: - predict_fn = lambda xi: self.predict( - rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs) + predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs) def predict_batch(Xi): out1, out2 = predict_fn(Xi) @@ -311,15 +323,19 @@ def predict_batch(Xi): y_out2.append(out2) return y_out1, y_out2 - def predict_in_batches(self, rng_key: jnp.ndarray, - X_new: jnp.ndarray, batch_size: int = 100, - samples: Optional[Dict[str, jnp.ndarray]] = None, - n: int = 1, filter_nans: bool = False, - predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, - noiseless: bool = False, - device: Type[jaxlib.xla_extension.Device] = None, - **kwargs: float - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + def predict_in_batches( + self, + rng_key: jnp.ndarray, + X_new: jnp.ndarray, + batch_size: int = 100, + samples: Optional[Dict[str, jnp.ndarray]] = None, + n: int = 1, + filter_nans: bool = False, + predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, + noiseless: bool = False, + device: Type[jaxlib.xla_extension.Device] = None, + **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Make prediction at X_new 
         with sampled GP parameters
         by splitting the input array into chunks ("batches") and running
@@ -327,17 +343,23 @@ def predict_in_batches(self, rng_key: jnp.ndarray,
         to avoid a memory overflow
         """
         y_pred, y_sampled = self._predict_in_batches(
-            rng_key, X_new, batch_size, 0, samples, n,
-            filter_nans, predict_fn, noiseless, device, **kwargs)
+            rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs
+        )
         y_pred = jnp.concatenate(y_pred, 0)
         y_sampled = jnp.concatenate(y_sampled, -1)
         return y_pred, y_sampled
 
-    def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
-                samples: Optional[Dict[str, jnp.ndarray]] = None,
-                n: int = 1, filter_nans: bool = False, noiseless: bool = False,
-                device: Type[jaxlib.xla_extension.Device] = None, **kwargs: float
-                ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    def predict(
+        self,
+        rng_key: jnp.ndarray,
+        X_new: jnp.ndarray,
+        samples: Optional[Dict[str, jnp.ndarray]] = None,
+        n: int = 1,
+        filter_nans: bool = False,
+        noiseless: bool = False,
+        device: Type[jaxlib.xla_extension.Device] = None,
+        **kwargs: float
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
         Make prediction at X_new points using posterior samples for GP parameters
 
@@ -370,38 +392,34 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
             samples = jax.device_put(samples, device)
         num_samples = samples["noise"].shape[0]
         vmap_args = (jra.split(rng_key, num_samples), samples)
-        predictive = jax.vmap(
-            lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs))
+        predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs))
         y_means, y_sampled = predictive(vmap_args)
         if filter_nans:
             y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
             y_sampled = jnp.array(y_sampled_)
         return y_means.mean(0), y_sampled
 
-    def sample_from_prior(self, rng_key: jnp.ndarray,
-                          X: jnp.ndarray, num_samples: int = 10):
+    def sample_from_prior(self, rng_key: jnp.ndarray, X: jnp.ndarray, num_samples: int = 10):
         """
         Samples from prior predictive distribution at X
         """
         X = self._set_data(X)
         prior_predictive = Predictive(self.model, num_samples=num_samples)
         samples = prior_predictive(rng_key, X)
-        return samples['y']
+        return samples["y"]
 
-    def _set_data(self,
-                  X: jnp.ndarray,
-                  y: Optional[jnp.ndarray] = None
-                  ) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
+    def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
         X = X if X.ndim > 1 else X[:, None]
         if y is not None:
             return X, y.squeeze()
         return X
 
-    def _set_training_data(self,
-                           X_train_new: jnp.ndarray = None,
-                           y_train_new: jnp.ndarray = None,
-                           device: Type[jaxlib.xla_extension.Device] = None
-                           ) -> None:
+    def _set_training_data(
+        self,
+        X_train_new: jnp.ndarray = None,
+        y_train_new: jnp.ndarray = None,
+        device: Type[jaxlib.xla_extension.Device] = None,
+    ) -> None:
         X_train = self.X_train if X_train_new is None else X_train_new
         y_train = self.y_train if y_train_new is None else y_train_new
         if device:
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..1eeed4c
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,17 @@
+[tool.black]
+line-length = 127
+include = '\.pyi?$'
+exclude = '''
+/(
+  | \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+  | docs/source/conf.py
+)/
+'''
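
For context on the gp.py deprecations above: `noise_prior_dist` and `lengthscale_prior_dist` take numpyro distribution instances directly, whereas the deprecated `noise_prior`/`kernel_prior` arguments took functions that call `numpyro.sample`. A minimal sketch of the new-style constructor call, using only names that appear in this diff (assumes gpax and numpyro are installed):

    import numpyro.distributions as dist
    import gpax

    # New style: pass distribution instances directly, matching the examples
    # cited in the FutureWarning/UserWarning messages above.
    gp_model = gpax.ExactGP(
        input_dim=1,
        kernel="RBF",
        noise_prior_dist=dist.HalfNormal(0.1),
        lengthscale_prior_dist=dist.Gamma(2, 5),
    )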