diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 53db71f..79b4a64 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: run: | conda install pip pip install -r requirements-dev.txt - pip install . + pip install ".[lasso]" python --version conda list pip freeze diff --git a/README.md b/README.md index a5213fd..bb88930 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,10 @@ Assuming a standard Python environment is installed on your machine (including p pip install kulprit +By default Kulprit performs a forward search, if you want to use Lasso (L1 search) you need to install `scikit-learn` package. You can install it using pip: + + pip install kulprit[lasso] + Alternatively, if you want the bleeding edge version of the package you can install it from GitHub: pip install git+https://github.com/bambinos/kulprit.git diff --git a/docs/index.rst b/docs/index.rst index 8749604..d5036d6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,6 +30,13 @@ Assuming a standard Python environment is installed on your machine (including p pip install kulprit + +By default Kulprit performs a forward search, if you want to use Lasso (L1 search) you need to install `scikit-learn` package. You can install it using pip: + +.. code-block:: bash + + pip install kulprit[lasso] + Alternatively, if you want the bleeding edge version of the package you can install it from GitHub: .. code-block:: bash diff --git a/kulprit/projection/search_strategies.py b/kulprit/projection/search_strategies.py index 7bc8081..8d97741 100644 --- a/kulprit/projection/search_strategies.py +++ b/kulprit/projection/search_strategies.py @@ -1,7 +1,11 @@ """This module contains the search strategies""" import numpy as np -from sklearn.linear_model import lasso_path + +try: + from sklearn.linear_model import lasso_path +except ImportError: + pass from kulprit.projection.arviz_io import compute_loo diff --git a/pyproject.toml b/pyproject.toml index d6b60cc..03d965f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,11 @@ description = "Kullback-Leibler projections for Bayesian model selection." dependencies = [ "arviz>=0.17.1", "bambi>=0.12.0", - "scikit-learn>=1.0.2", - "numba>=0.56.0",] + "numba>=0.56.0", + ] + +[project.optional-dependencies] +lasso = ["scikit-learn>=1.0.2"] [tool.flit.module] name = "kulprit"