From 52e96ea1c9f9db636b4fd88dad4fb006b43dcfca Mon Sep 17 00:00:00 2001
From: Gianluca Detommaso 
Date: Mon, 15 May 2023 19:07:17 +0200
Subject: [PATCH 1/9] edit installation instructions in readme

---
 README.rst | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/README.rst b/README.rst
index 754dc795..ffd201b7 100644
--- a/README.rst
+++ b/README.rst
@@ -150,9 +150,12 @@ If you choose to pursue this way, first install Poetry and add it to your PATH

     poetry install

-All the dependencies will be installed at their required versions.
-If you also want to install the optional Sphinx dependencies to build the documentation,
-add the flag :code:`-E docs` to the command above.
+All the dependencies will be installed at their required versions. Consider adding the following flags to the command above:
+
+- :code:`-E transformers` if you want to use models and datasets from `Hugging Face `_.
+- :code:`-E docs` if you want to install Sphinx dependencies to build the documentation.
+- :code:`-E notebooks` if you want to work with Jupyter notebooks.
+
 Finally, you can either access the virtualenv that Poetry created by typing :code:`poetry shell`,
 or execute commands within the virtualenv using the :code:`run` command, e.g. :code:`poetry run python`.

From 6cb6581b9574306dc4acfdabc4fe1b92cfcb1241 Mon Sep 17 00:00:00 2001
From: Gianluca Detommaso 
Date: Mon, 15 May 2023 21:25:25 +0200
Subject: [PATCH 2/9] bump up version

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 9adeaa0b..beedbe3e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aws-fortuna"
-version = "0.1.15"
+version = "0.1.16"
 description = "A Library for Uncertainty Quantification."
 authors = ["Gianluca Detommaso ", "Alberto Gasparin "]
 license = "Apache-2.0"

From b2540c18ef0cfaaf23b49b87f8ba157187b537f0 Mon Sep 17 00:00:00 2001
From: Gianluca Detommaso 
Date: Tue, 18 Jul 2023 10:30:59 +0200
Subject: [PATCH 3/9] make small change in readme because of publish to pypi error

---
 README.rst  |   6 +-
 poetry.lock | 274 +++++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 266 insertions(+), 14 deletions(-)

diff --git a/README.rst b/README.rst
index 54018894..b1687322 100644
--- a/README.rst
+++ b/README.rst
@@ -183,12 +183,12 @@ We offer a simple pipeline that allows you to run Fortuna on Amazon SageMaker wi
 3. Create an `S3 bucket `_. You will need this to dump the results from
    your training jobs on Amazon SageMaker.

-3. Write a configuration `yaml` file. This will include your AWS details, the path to the entrypoint script that you want
+4. Write a configuration `yaml` file. This will include your AWS details, the path to the entrypoint script that you want
    to run on Amazon SageMaker, the arguments to pass to the script, the path to the S3 bucket where you want to dump the
    results, the metrics to monitor, and more.
-   See `here `_ for an example.
+   Check `this file `_ for an example.

-4. Finally, given :code:`config_dir`, that is the absolute path to the main configuration directory,
+5. 
Finally, given :code:`config_dir`, that is the absolute path to the main configuration directory, and :code:`config_filename`, that is the name of the main configuration file (without .yaml extension), enter Python and run the following: diff --git a/poetry.lock b/poetry.lock index 5450bd50..4f6aa650 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "absl-py" version = "1.4.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -15,6 +16,7 @@ files = [ name = "absolufy-imports" version = "0.3.1" description = "A tool to automatically replace relative imports with absolute ones." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -26,6 +28,7 @@ files = [ name = "aiohttp" version = "3.8.4" description = "Async http client/server framework (asyncio)" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -134,6 +137,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -148,6 +152,7 @@ frozenlist = ">=1.1.0" name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -159,6 +164,7 @@ files = [ name = "antlr4-python3-runtime" version = "4.9.3" description = "ANTLR 4.9.3 runtime for Python 3.7" +category = "main" optional = true python-versions = "*" files = [ @@ -169,6 +175,7 @@ files = [ name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = true python-versions = ">=3.6.2" files = [ @@ -189,6 +196,7 @@ trio = ["trio (>=0.16,<0.22)"] name = "appdirs" version = "1.4.4" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = "*" files = [ @@ -200,6 +208,7 @@ files = [ name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "main" optional = false python-versions = "*" files = [ @@ -211,6 +220,7 @@ files = [ name = "argon2-cffi" version = "21.3.0" description = "The secure Argon2 password hashing algorithm." 
+category = "main" optional = true python-versions = ">=3.6" files = [ @@ -230,6 +240,7 @@ tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -267,6 +278,7 @@ tests = ["pytest"] name = "array-record" version = "0.2.0" description = "A file format that achieves a new frontier of IO efficiency" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -283,6 +295,7 @@ etils = {version = "*", extras = ["epath"]} name = "arrow" version = "1.2.3" description = "Better dates & times for Python" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -297,6 +310,7 @@ python-dateutil = ">=2.7.0" name = "asttokens" version = "2.2.1" description = "Annotate AST trees with source code positions" +category = "main" optional = false python-versions = "*" files = [ @@ -314,6 +328,7 @@ test = ["astroid", "pytest"] name = "astunparse" version = "1.6.3" description = "An AST unparser for Python" +category = "main" optional = false python-versions = "*" files = [ @@ -329,6 +344,7 @@ wheel = ">=0.23.0,<1.0" name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -340,6 +356,7 @@ files = [ name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -358,6 +375,7 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "babel" version = "2.12.1" description = "Internationalization utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -372,6 +390,7 @@ pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "main" optional = false python-versions = "*" files = [ @@ -383,6 +402,7 @@ files = [ name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" +category = "main" optional = false python-versions = ">=3.6.0" files = [ @@ -401,6 +421,7 @@ lxml = ["lxml"] name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -419,6 +440,7 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] name = "boto3" version = "1.26.145" description = "The AWS SDK for Python" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -438,6 +460,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.29.145" description = "Low-level, data-driven core of boto 3." +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -457,6 +480,7 @@ crt = ["awscrt (==0.16.9)"] name = "cached-property" version = "1.5.2" description = "A decorator for caching properties in classes." +category = "main" optional = false python-versions = "*" files = [ @@ -468,6 +492,7 @@ files = [ name = "cachetools" version = "5.3.0" description = "Extensible memoizing collections and decorators" +category = "main" optional = false python-versions = "~=3.7" files = [ @@ -479,6 +504,7 @@ files = [ name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." 
+category = "main" optional = false python-versions = ">=3.6" files = [ @@ -490,6 +516,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." +category = "main" optional = false python-versions = "*" files = [ @@ -566,6 +593,7 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -577,6 +605,7 @@ files = [ name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -661,6 +690,7 @@ files = [ name = "chex" version = "0.1.7" description = "Chex: Testing made fun, in JAX!" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -681,6 +711,7 @@ typing-extensions = {version = ">=4.2.0", markers = "python_version < \"3.11\""} name = "click" version = "8.1.3" description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -695,6 +726,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cloudpickle" version = "2.2.1" description = "Extended pickling support for Python objects" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -706,6 +738,7 @@ files = [ name = "codespell" version = "2.2.4" description = "Codespell" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -723,6 +756,7 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -734,6 +768,7 @@ files = [ name = "comm" version = "0.1.3" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -753,6 +788,7 @@ typing = ["mypy (>=0.990)"] name = "contextlib2" version = "21.6.0" description = "Backports and enhancements for the contextlib module" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -764,6 +800,7 @@ files = [ name = "contourpy" version = "1.0.7" description = "Python library for calculating contours of 2D quadrilateral grids" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -838,6 +875,7 @@ test-no-images = ["pytest"] name = "coverage" version = "7.2.5" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -904,6 +942,7 @@ toml = ["tomli"] name = "cryptography" version = "41.0.0" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
+category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -945,6 +984,7 @@ test-randomorder = ["pytest-randomly"] name = "cycler" version = "0.11.0" description = "Composable style cycles" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -956,6 +996,7 @@ files = [ name = "datasets" version = "2.12.0" description = "HuggingFace community-driven open-source library of datasets" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -999,6 +1040,7 @@ vision = ["Pillow (>=6.2.1)"] name = "debugpy" version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1026,6 +1068,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1037,6 +1080,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1048,6 +1092,7 @@ files = [ name = "dill" version = "0.3.6" description = "serialize all of python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1062,6 +1107,7 @@ graph = ["objgraph (>=1.7.2)"] name = "distlib" version = "0.3.6" description = "Distribution utilities" +category = "dev" optional = false python-versions = "*" files = [ @@ -1073,6 +1119,7 @@ files = [ name = "dm-tree" version = "0.1.8" description = "Tree is a library for working with nested data structures." +category = "main" optional = false python-versions = "*" files = [ @@ -1121,6 +1168,7 @@ files = [ name = "docutils" version = "0.19" description = "Docutils -- Python Documentation Utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1132,6 +1180,7 @@ files = [ name = "et-xmlfile" version = "1.1.0" description = "An implementation of lxml.xmlfile for the standard library" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1143,6 +1192,7 @@ files = [ name = "etils" version = "1.2.0" description = "Collection of common python utils" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1177,6 +1227,7 @@ lazy-imports = ["etils[ecolab]"] name = "exceptiongroup" version = "1.1.1" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1191,6 +1242,7 @@ test = ["pytest (>=6)"] name = "executing" version = "1.2.0" description = "Get the currently executing AST node of a frame, and other information" +category = "main" optional = false python-versions = "*" files = [ @@ -1205,6 +1257,7 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] name = "fastjsonschema" version = "2.16.3" description = "Fastest Python implementation of JSON schema" +category = "main" optional = false python-versions = "*" files = [ @@ -1219,6 +1272,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.0" description = "A platform independent file lock." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1234,6 +1288,7 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "p name = "flatbuffers" version = "23.5.9" description = "The FlatBuffers serialization format for Python" +category = "main" optional = false python-versions = "*" files = [ @@ -1245,6 +1300,7 @@ files = [ name = "flax" version = "0.6.10" description = "Flax: A neural network library for JAX designed for flexibility" +category = "main" optional = false python-versions = "*" files = [ @@ -1271,6 +1327,7 @@ testing = ["atari-py (==0.2.5)", "clu", "einops", "gym (==0.18.3)", "jaxlib", "j name = "fonttools" version = "4.39.4" description = "Tools to manipulate font files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1296,6 +1353,7 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] name = "fqdn" version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" +category = "main" optional = true python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" files = [ @@ -1307,6 +1365,7 @@ files = [ name = "frozendict" version = "2.3.8" description = "A simple immutable dictionary" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1353,6 +1412,7 @@ files = [ name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1436,6 +1496,7 @@ files = [ name = "fsspec" version = "2023.5.0" description = "File-system specification" +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -1475,6 +1536,7 @@ tqdm = ["tqdm"] name = "gast" version = "0.4.0" description = "Python AST that abstracts the underlying Python version" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1486,6 +1548,7 @@ files = [ name = "google-auth" version = "2.18.0" description = "Google Authentication Library" +category = "main" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" files = [ @@ -1511,6 +1574,7 @@ requests = ["requests (>=2.20.0,<3.0.0dev)"] name = "google-auth-oauthlib" version = "1.0.0" description = "Google Authentication Library" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1529,6 +1593,7 @@ tool = ["click (>=6.0.0)"] name = "google-pasta" version = "0.2.0" description = "pasta is an AST-based Python refactoring library" +category = "main" optional = false python-versions = "*" files = [ @@ -1544,6 +1609,7 @@ six = "*" name = "googleapis-common-protos" version = "1.59.0" description = "Common protobufs used in Google APIs" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1561,6 +1627,7 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"] name = "greenlet" version = "2.0.2" description = "Lightweight in-process concurrent programming" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*" files = [ @@ -1634,6 +1701,7 @@ test = ["objgraph", "psutil"] name = "grpcio" version = "1.54.0" description = "HTTP/2-based RPC framework" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1691,6 +1759,7 @@ protobuf = ["grpcio-tools (>=1.54.0)"] name = "h5py" version = "3.8.0" description = "Read and write HDF5 files from Python" +category = 
"main" optional = false python-versions = ">=3.7" files = [ @@ -1728,6 +1797,7 @@ numpy = ">=1.14.5" name = "html5lib" version = "1.1" description = "HTML parser based on the WHATWG HTML specification" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1749,6 +1819,7 @@ lxml = ["lxml"] name = "huggingface-hub" version = "0.14.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -1780,6 +1851,7 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "hydra-core" version = "1.3.2" description = "A framework for elegantly configuring complex applications" +category = "main" optional = true python-versions = "*" files = [ @@ -1788,7 +1860,7 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = "==4.9.*" +antlr4-python3-runtime = ">=4.9.0,<4.10.0" importlib-resources = {version = "*", markers = "python_version < \"3.9\""} omegaconf = ">=2.2,<2.4" packaging = "*" @@ -1797,6 +1869,7 @@ packaging = "*" name = "identify" version = "2.5.24" description = "File identification library for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1811,6 +1884,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1822,6 +1896,7 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1833,6 +1908,7 @@ files = [ name = "importlib-metadata" version = "4.13.0" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1852,6 +1928,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.12.0" description = "Read resources from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1870,6 +1947,7 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1881,6 +1959,7 @@ files = [ name = "ipykernel" version = "6.23.0" description = "IPython Kernel for Jupyter" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1894,7 +1973,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -1914,6 +1993,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "8.12.2" description = "IPython: Productive Interactive Computing" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1953,6 +2033,7 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" +category = "main" optional = true python-versions = "*" files = [ @@ -1964,6 +2045,7 @@ files = [ name = 
"ipywidgets" version = "8.0.6" description = "Jupyter interactive widgets" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1985,6 +2067,7 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] name = "isoduration" version = "20.11.0" description = "Operations with ISO 8601 durations" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1999,6 +2082,7 @@ arrow = ">=0.15.0" name = "jax" version = "0.4.10" description = "Differentiate, compile, and transform Numpy code." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2029,6 +2113,7 @@ tpu = ["jaxlib (==0.4.10)", "libtpu-nightly (==0.1.dev20230511)", "requests"] name = "jaxlib" version = "0.4.10" description = "XLA library for JAX" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2055,6 +2140,7 @@ scipy = ">=1.7" name = "jedi" version = "0.18.2" description = "An autocompletion tool for Python that can be used for text editors." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2074,6 +2160,7 @@ testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2091,6 +2178,7 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2102,6 +2190,7 @@ files = [ name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2113,6 +2202,7 @@ files = [ name = "jsonpointer" version = "2.3" description = "Identify specific nodes in a JSON document (RFC 6901)" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2124,6 +2214,7 @@ files = [ name = "jsonschema" version = "4.17.3" description = "An implementation of JSON Schema validation for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2153,6 +2244,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter" version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." +category = "main" optional = true python-versions = "*" files = [ @@ -2173,6 +2265,7 @@ qtconsole = "*" name = "jupyter-cache" version = "0.6.1" description = "A defined interface for working with a cache of jupyter notebooks." 
+category = "dev" optional = false python-versions = "~=3.8" files = [ @@ -2200,6 +2293,7 @@ testing = ["coverage", "ipykernel", "jupytext", "matplotlib", "nbdime", "nbforma name = "jupyter-client" version = "8.2.0" description = "Jupyter protocol implementation and client libraries" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2209,7 +2303,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -2223,6 +2317,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-console" version = "6.6.3" description = "Jupyter terminal console" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2234,7 +2329,7 @@ files = [ ipykernel = ">=6.14" ipython = "*" jupyter-client = ">=7.0.0" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" prompt-toolkit = ">=3.0.30" pygments = "*" pyzmq = ">=17" @@ -2247,6 +2342,7 @@ test = ["flaky", "pexpect", "pytest"] name = "jupyter-core" version = "5.3.0" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2267,6 +2363,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-events" version = "0.6.3" description = "Jupyter Event System library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2291,6 +2388,7 @@ test = ["click", "coverage", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>= name = "jupyter-server" version = "2.5.0" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -2303,7 +2401,7 @@ anyio = ">=3.1.0" argon2-cffi = "*" jinja2 = "*" jupyter-client = ">=7.4.4" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" jupyter-events = ">=0.4.0" jupyter-server-terminals = "*" nbconvert = ">=6.4.4" @@ -2326,6 +2424,7 @@ test = ["ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", " name = "jupyter-server-terminals" version = "0.4.4" description = "A Jupyter Server Extension Providing Terminals." +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -2345,6 +2444,7 @@ test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2356,6 +2456,7 @@ files = [ name = "jupyterlab-widgets" version = "3.0.7" description = "Jupyter interactive widgets for JupyterLab" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2367,6 +2468,7 @@ files = [ name = "jupytext" version = "1.14.5" description = "Jupyter notebooks as Markdown documents, Julia, Python or R scripts" +category = "dev" optional = false python-versions = "~=3.6" files = [ @@ -2389,6 +2491,7 @@ toml = ["toml"] name = "keras" version = "2.12.0" description = "Deep learning for humans." 
+category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2399,6 +2502,7 @@ files = [ name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2476,6 +2580,7 @@ files = [ name = "libclang" version = "16.0.0" description = "Clang Python Bindings, mirrored from the official LLVM repo: https://github.com/llvm/llvm-project/tree/main/clang/bindings/python, to make the installation process easier." +category = "main" optional = false python-versions = "*" files = [ @@ -2493,6 +2598,7 @@ files = [ name = "lxml" version = "4.9.2" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" files = [ @@ -2585,6 +2691,7 @@ source = ["Cython (>=0.29.7)"] name = "markdown" version = "3.4.3" description = "Python implementation of John Gruber's Markdown." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2602,6 +2709,7 @@ testing = ["coverage", "pyyaml"] name = "markdown-it-py" version = "2.2.0" description = "Python port of markdown-it. Markdown parsing, done right!" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2626,6 +2734,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.2" description = "Safely add untrusted strings to HTML/XML markup." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2685,6 +2794,7 @@ files = [ name = "matplotlib" version = "3.7.1" description = "Python plotting package" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2747,6 +2857,7 @@ python-dateutil = ">=2.7" name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -2761,6 +2872,7 @@ traitlets = "*" name = "mdit-py-plugins" version = "0.3.5" description = "Collection of plugins for markdown-it-py" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2780,6 +2892,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2791,6 +2904,7 @@ files = [ name = "mistune" version = "2.0.5" description = "A sane Markdown parser with useful plugins and renderers" +category = "main" optional = true python-versions = "*" files = [ @@ -2802,6 +2916,7 @@ files = [ name = "ml-dtypes" version = "0.1.0" description = "" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2837,6 +2952,7 @@ dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] name = "msgpack" version = "1.0.5" description = "MessagePack serializer" +category = "main" optional = false python-versions = "*" files = [ @@ -2909,6 +3025,7 @@ files = [ name = "multidict" version = "6.0.4" description = "multidict implementation" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2992,6 +3109,7 @@ files = [ name = "multiprocess" version = "0.70.14" description = "better multiprocessing and multithreading in python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3018,6 +3136,7 @@ dill = ">=0.3.6" name = "multitasking" version = 
"0.0.11" description = "Non-blocking Python methods using decorators" +category = "dev" optional = false python-versions = "*" files = [ @@ -3029,6 +3148,7 @@ files = [ name = "myst-nb" version = "0.17.2" description = "A Jupyter Notebook Sphinx reader built on top of the MyST markdown parser." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3057,6 +3177,7 @@ testing = ["beautifulsoup4", "coverage (>=6.4,<8.0)", "ipykernel (>=5.5,<6.0)", name = "myst-parser" version = "0.18.1" description = "An extended commonmark compliant parser, with bridges to docutils & sphinx." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3083,6 +3204,7 @@ testing = ["beautifulsoup4", "coverage[toml]", "pytest (>=6,<7)", "pytest-cov", name = "nbclassic" version = "1.0.0" description = "Jupyter Notebook as a Jupyter Server extension." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3118,6 +3240,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-jupyter", "pytest-p name = "nbclient" version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3127,7 +3250,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" nbformat = ">=5.1" traitlets = ">=5.3" @@ -3140,6 +3263,7 @@ test = ["flaky", "ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "p name = "nbconvert" version = "7.4.0" description = "Converting Jupyter Notebooks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3178,6 +3302,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.8.0" description = "The Jupyter Notebook format" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3199,6 +3324,7 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nbsphinx" version = "0.8.12" description = "Jupyter Notebook Tools for Sphinx" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3218,6 +3344,7 @@ traitlets = ">=5" name = "nbsphinx-link" version = "1.3.0" description = "A sphinx extension for including notebook files outside sphinx source root" +category = "main" optional = true python-versions = "*" files = [ @@ -3233,6 +3360,7 @@ sphinx = ">=1.8" name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -3244,6 +3372,7 @@ files = [ name = "nodeenv" version = "1.7.0" description = "Node.js virtual environment builder" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -3258,6 +3387,7 @@ setuptools = "*" name = "notebook" version = "6.5.4" description = "A web-based notebook environment for interactive computing" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3292,6 +3422,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixs name = "notebook-shim" version = "0.2.3" description = "A shim layer for notebook traits and config" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3309,6 +3440,7 @@ test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync" name = "numpy" version = "1.23.5" description = "NumPy is the fundamental package for 
array computing with Python." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3346,6 +3478,7 @@ files = [ name = "oauthlib" version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3362,6 +3495,7 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] name = "omegaconf" version = "2.3.0" description = "A flexible configuration library" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3370,13 +3504,14 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = "==4.9.*" +antlr4-python3-runtime = ">=4.9.0,<4.10.0" PyYAML = ">=5.1.0" [[package]] name = "openpyxl" version = "3.1.2" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3391,6 +3526,7 @@ et-xmlfile = "*" name = "opt-einsum" version = "3.3.0" description = "Optimizing numpys einsum function" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -3409,6 +3545,7 @@ tests = ["pytest", "pytest-cov", "pytest-pep8"] name = "optax" version = "0.1.5" description = "A gradient processing and optimisation library in JAX." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3427,6 +3564,7 @@ numpy = ">=1.18.0" name = "orbax-checkpoint" version = "0.2.2" description = "Orbax Checkpoint" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3455,6 +3593,7 @@ dev = ["flax", "pytest", "pytest-xdist"] name = "packaging" version = "23.1" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3466,6 +3605,7 @@ files = [ name = "pandas" version = "2.0.1" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3532,6 +3672,7 @@ xml = ["lxml (>=4.6.3)"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3543,6 +3684,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3558,6 +3700,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathos" version = "0.3.0" description = "parallel graph management and execution in heterogeneous computing" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3575,6 +3718,7 @@ ppft = ">=1.7.6.6" name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "main" optional = false python-versions = "*" files = [ @@ -3589,6 +3733,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "main" optional = false python-versions = "*" files = [ @@ -3600,6 +3745,7 @@ files = [ name = "pillow" version = "9.5.0" description = "Python Imaging Library (Fork)" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3679,6 +3825,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." 
+category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3690,6 +3837,7 @@ files = [ name = "platformdirs" version = "3.5.1" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3705,6 +3853,7 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3720,6 +3869,7 @@ testing = ["pytest", "pytest-benchmark"] name = "pox" version = "0.3.2" description = "utilities for filesystem exploration and automated builds" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3731,6 +3881,7 @@ files = [ name = "ppft" version = "1.7.6.6" description = "distributed and parallel python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3745,6 +3896,7 @@ dill = ["dill (>=0.3.6)"] name = "pre-commit" version = "3.3.1" description = "A framework for managing and maintaining multi-language pre-commit hooks." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3763,6 +3915,7 @@ virtualenv = ">=20.10.0" name = "prometheus-client" version = "0.16.0" description = "Python client for the Prometheus monitoring system." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3777,6 +3930,7 @@ twisted = ["twisted"] name = "promise" version = "2.3" description = "Promises/A+ implementation for Python" +category = "dev" optional = false python-versions = "*" files = [ @@ -3793,6 +3947,7 @@ test = ["coveralls", "futures", "mock", "pytest (>=2.7.3)", "pytest-benchmark", name = "prompt-toolkit" version = "3.0.38" description = "Library for building powerful interactive command lines in Python" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3807,6 +3962,7 @@ wcwidth = "*" name = "protobuf" version = "3.20.3" description = "Protocol Buffers" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3838,6 +3994,7 @@ files = [ name = "protobuf3-to-dict" version = "0.1.5" description = "Ben Hodgson: A teeny Python library for creating Python dicts from protocol buffers and the reverse. Useful as an intermediate step before serialisation (e.g. to JSON). Kapor: upgrade it to PB3 and PY3, rename it to protobuf3-to-dict" +category = "main" optional = true python-versions = "*" files = [ @@ -3852,6 +4009,7 @@ six = "*" name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3878,6 +4036,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "main" optional = false python-versions = "*" files = [ @@ -3889,6 +4048,7 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" +category = "main" optional = false python-versions = "*" files = [ @@ -3903,6 +4063,7 @@ tests = ["pytest"] name = "pyarrow" version = "12.0.0" description = "Python library for Apache Arrow" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3940,6 +4101,7 @@ numpy = ">=1.16.6" name = "pyasn1" version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3951,6 +4113,7 @@ files = [ name = "pyasn1-modules" version = "0.3.0" description = "A collection of ASN.1-based protocols modules" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3965,6 +4128,7 @@ pyasn1 = ">=0.4.6,<0.6.0" name = "pycparser" version = "2.21" description = "C parser in Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3976,6 +4140,7 @@ files = [ name = "pydata-sphinx-theme" version = "0.12.0" description = "Bootstrap-based Sphinx theme from the PyData community" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4000,6 +4165,7 @@ test = ["pydata-sphinx-theme[doc]", "pytest"] name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4014,6 +4180,7 @@ plugins = ["importlib-metadata"] name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -4028,6 +4195,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.3" description = "Persistent/Functional/Immutable data structures" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4064,6 +4232,7 @@ files = [ name = "pytest" version = "7.3.1" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4086,6 +4255,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "4.0.0" description = "Pytest plugin for measuring coverage." 
+category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4104,6 +4274,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -4118,6 +4289,7 @@ six = ">=1.5" name = "python-json-logger" version = "2.0.7" description = "A python library adding a json log formatter" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4129,6 +4301,7 @@ files = [ name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -4140,6 +4313,7 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" files = [ @@ -4163,6 +4337,7 @@ files = [ name = "pywinpty" version = "2.0.10" description = "Pseudo terminal support for Windows from Python." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4178,6 +4353,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4227,6 +4403,7 @@ files = [ name = "pyzmq" version = "25.0.2" description = "Python bindings for 0MQ" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4316,6 +4493,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "qtconsole" version = "5.4.3" description = "Jupyter Qt console" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -4342,6 +4520,7 @@ test = ["flaky", "pytest", "pytest-qt"] name = "qtpy" version = "2.3.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4359,6 +4538,7 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] name = "regex" version = "2023.5.5" description = "Alternative regular expression module, to replace re." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4456,6 +4636,7 @@ files = [ name = "requests" version = "2.31.0" description = "Python HTTP for Humans." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4477,6 +4658,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "requests-oauthlib" version = "1.3.1" description = "OAuthlib authentication support for Requests." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -4495,6 +4677,7 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] name = "responses" version = "0.18.0" description = "A utility library for mocking out the `requests` Python library." 
+category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4513,6 +4696,7 @@ tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=4.6)", "pytest-cov", name = "rfc3339-validator" version = "0.1.4" description = "A pure python RFC3339 validator" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -4527,6 +4711,7 @@ six = "*" name = "rfc3986-validator" version = "0.1.1" description = "Pure python rfc3986 validator" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -4538,6 +4723,7 @@ files = [ name = "rich" version = "13.3.5" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -4557,6 +4743,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" +category = "main" optional = false python-versions = ">=3.6,<4" files = [ @@ -4571,6 +4758,7 @@ pyasn1 = ">=0.1.3" name = "s3transfer" version = "0.6.1" description = "An Amazon S3 Transfer Manager" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -4588,6 +4776,7 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] name = "safetensors" version = "0.3.1" description = "Fast and Safe Tensor serialization" +category = "main" optional = true python-versions = "*" files = [ @@ -4648,6 +4837,7 @@ torch = ["torch (>=1.10)"] name = "sagemaker" version = "2.161.0" description = "Open source library for training and deploying models on Amazon SageMaker." +category = "main" optional = true python-versions = ">= 3.6" files = [ @@ -4683,6 +4873,7 @@ test = ["Jinja2 (==3.0.3)", "PyYAML (==6.0)", "apache-airflow (==2.6.0)", "apach name = "sagemaker-utils" version = "0.3.6" description = "Helper functions to work with SageMaker" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4700,6 +4891,7 @@ yaspin = "*" name = "schema" version = "0.7.5" description = "Simple data validation library" +category = "main" optional = true python-versions = "*" files = [ @@ -4714,6 +4906,7 @@ contextlib2 = ">=0.5.5" name = "scikit-learn" version = "1.2.2" description = "A set of python modules for machine learning and data mining" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4756,6 +4949,7 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( name = "scipy" version = "1.10.1" description = "Fundamental algorithms for scientific computing in Python" +category = "main" optional = false python-versions = "<3.12,>=3.8" files = [ @@ -4794,6 +4988,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "send2trash" version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" +category = "main" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -4810,6 +5005,7 @@ win32 = ["pywin32"] name = "setuptools" version = "67.7.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4826,6 +5022,7 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = 
">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4837,6 +5034,7 @@ files = [ name = "smdebug-rulesconfig" version = "1.0.1" description = "SMDebug RulesConfig" +category = "main" optional = true python-versions = ">=2.7" files = [ @@ -4848,6 +5046,7 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4859,6 +5058,7 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." +category = "main" optional = false python-versions = "*" files = [ @@ -4870,6 +5070,7 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4881,6 +5082,7 @@ files = [ name = "sphinx" version = "5.3.0" description = "Python documentation generator" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4916,6 +5118,7 @@ test = ["cython", "html5lib", "pytest (>=4.6)", "typed_ast"] name = "sphinx-autodoc-typehints" version = "1.23.0" description = "Type hints (PEP 484) support for the Sphinx autodoc extension" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4935,6 +5138,7 @@ type-comment = ["typed-ast (>=1.5.4)"] name = "sphinx-gallery" version = "0.11.1" description = "A Sphinx extension that builds an HTML version of any Python script and puts it into an examples gallery." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4949,6 +5153,7 @@ sphinx = ">=3" name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4964,6 +5169,7 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -4979,6 +5185,7 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4994,6 +5201,7 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5008,6 +5216,7 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5023,6 +5232,7 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." 
+category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5038,6 +5248,7 @@ test = ["pytest"] name = "sqlalchemy" version = "2.0.13" description = "Database Abstraction Library" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5085,7 +5296,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} typing-extensions = ">=4.2.0" [package.extras] @@ -5115,6 +5326,7 @@ sqlcipher = ["sqlcipher3-binary"] name = "stack-data" version = "0.6.2" description = "Extract data from python stack frames and tracebacks for informative displays" +category = "main" optional = false python-versions = "*" files = [ @@ -5134,6 +5346,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5148,6 +5361,7 @@ widechars = ["wcwidth"] name = "tblib" version = "1.7.0" description = "Traceback serialization library." +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -5159,6 +5373,7 @@ files = [ name = "tensorboard" version = "2.12.3" description = "TensorBoard lets you watch Tensors Flow" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5183,6 +5398,7 @@ wheel = ">=0.26" name = "tensorboard-data-server" version = "0.7.0" description = "Fast data loading for TensorBoard" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5195,6 +5411,7 @@ files = [ name = "tensorflow-cpu" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5240,6 +5457,7 @@ wrapt = ">=1.11.0,<1.15" name = "tensorflow-datasets" version = "4.9.2" description = "tensorflow/datasets is a library of datasets ready to use with TensorFlow." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -5306,6 +5524,7 @@ youtube-vis = ["pycocotools"] name = "tensorflow-estimator" version = "2.12.0" description = "TensorFlow Estimator." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5316,6 +5535,7 @@ files = [ name = "tensorflow-io-gcs-filesystem" version = "0.32.0" description = "TensorFlow IO" +category = "main" optional = false python-versions = ">=3.7, <3.12" files = [ @@ -5346,13 +5566,12 @@ tensorflow-rocm = ["tensorflow-rocm (>=2.12.0,<2.13.0)"] name = "tensorflow-macos" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." 
+category = "main" optional = false python-versions = ">=3.8" files = [ {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:db464c88e10e927725997f9b872a21c9d057789d3b7e9a26e4ef1af41d0bcc8c"}, {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:172277c33cb1ae0da19f98c5bcd4946149cfa73c8ea05c6ba18365d58dd3c6f2"}, - {file = "tensorflow_macos-2.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:9c9b14fbb73ec4cb0f209722a1489020fd8614c92ae22589f2309c48cefdf21f"}, - {file = "tensorflow_macos-2.12.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:6a54539bd076746f69ae8bef7282f981674fe4dbf59c3a84c4af86ae6bae9d5c"}, {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e3fa53e63672fd71998bbd71cc5478c74dbe5a2d9291d1801c575358c28403c2"}, {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:5499312c21ed3ed47cc6b4cf861896e9564c2c32d8d3c2ef1437c5ca31adfc73"}, {file = "tensorflow_macos-2.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:84cb873c90be63efabfecca53fdc48b734a037d0750532b55cb7ce7c343b5cac"}, @@ -5387,6 +5606,7 @@ wrapt = ">=1.11.0,<1.15" name = "tensorflow-metadata" version = "1.13.1" description = "Library and standards for schema and statistics." +category = "dev" optional = false python-versions = ">=3.8,<4" files = [ @@ -5402,6 +5622,7 @@ protobuf = ">=3.20.3,<5" name = "tensorstore" version = "0.1.36" description = "Read and write large, multi-dimensional arrays" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5431,6 +5652,7 @@ numpy = ">=1.16.0" name = "termcolor" version = "2.3.0" description = "ANSI color formatting for output in terminal" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5445,6 +5667,7 @@ tests = ["pytest", "pytest-cov"] name = "terminado" version = "0.17.1" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5465,6 +5688,7 @@ test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -5476,6 +5700,7 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5494,6 +5719,7 @@ test = ["flake8", "isort", "pytest"] name = "tokenizers" version = "0.13.3" description = "Fast and Customizable Tokenizers" +category = "main" optional = true python-versions = "*" files = [ @@ -5548,6 +5774,7 @@ testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -5559,6 +5786,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5570,6 +5798,7 @@ files = [ name = "toolz" version = "0.12.0" description = "List processing tools and functional utilities" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5581,6 +5810,7 @@ files = [ name = "tornado" version = "6.3.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
+category = "main" optional = false python-versions = ">= 3.8" files = [ @@ -5601,6 +5831,7 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5621,6 +5852,7 @@ telegram = ["requests"] name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5636,6 +5868,7 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "transformers" version = "4.30.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -5705,6 +5938,7 @@ vision = ["Pillow"] name = "typing-extensions" version = "4.5.0" description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5716,6 +5950,7 @@ files = [ name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" +category = "main" optional = false python-versions = ">=2" files = [ @@ -5727,6 +5962,7 @@ files = [ name = "uri-template" version = "1.2.0" description = "RFC 6570 URI Template Processor" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -5741,6 +5977,7 @@ dev = ["flake8 (<4.0.0)", "flake8-annotations", "flake8-bugbear", "flake8-commas name = "urllib3" version = "1.26.15" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -5757,6 +5994,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "virtualenv" version = "20.23.0" description = "Virtual Python Environment builder" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5777,6 +6015,7 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "coverage-enable-subprocess name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" +category = "main" optional = false python-versions = "*" files = [ @@ -5788,6 +6027,7 @@ files = [ name = "webcolors" version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5803,6 +6043,7 @@ tests = ["pytest", "pytest-cov"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "main" optional = false python-versions = "*" files = [ @@ -5814,6 +6055,7 @@ files = [ name = "websocket-client" version = "1.5.1" description = "WebSocket client for Python with low level API options" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5830,6 +6072,7 @@ test = ["websockets"] name = "werkzeug" version = "2.3.4" description = "The comprehensive WSGI web application library." 
+category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5847,6 +6090,7 @@ watchdog = ["watchdog (>=2.3)"] name = "wheel" version = "0.40.0" description = "A built-package format for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5861,6 +6105,7 @@ test = ["pytest (>=6.0.0)"] name = "widgetsnbextension" version = "4.0.7" description = "Jupyter interactive widgets for Jupyter Notebook" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5872,6 +6117,7 @@ files = [ name = "wrapt" version = "1.14.1" description = "Module for decorators, wrappers and monkey patching." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -5945,6 +6191,7 @@ files = [ name = "xlrd" version = "2.0.1" description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -5961,6 +6208,7 @@ test = ["pytest", "pytest-cov"] name = "xxhash" version = "3.2.0" description = "Python binding for xxHash" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -6068,6 +6316,7 @@ files = [ name = "yarl" version = "1.9.2" description = "Yet another URL library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -6155,6 +6404,7 @@ multidict = ">=4.0" name = "yaspin" version = "2.3.0" description = "Yet Another Terminal Spinner" +category = "main" optional = true python-versions = ">=3.7.2,<4.0.0" files = [ @@ -6169,6 +6419,7 @@ termcolor = ">=2.2,<3.0" name = "yfinance" version = "0.2.18" description = "Download market data from Yahoo! Finance API" +category = "dev" optional = false python-versions = "*" files = [ @@ -6193,6 +6444,7 @@ requests = ">=2.26" name = "zipp" version = "3.15.0" description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" optional = false python-versions = ">=3.7" files = [ From bc64a0179383a26a08ab7fa2a75ef6e966b474ef Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Sun, 30 Jul 2023 18:24:25 +0200 Subject: [PATCH 4/9] bump up version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1687732c..4915c415 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws-fortuna" -version = "0.1.22" +version = "0.1.23" description = "A Library for Uncertainty Quantification." 
authors = ["Gianluca Detommaso ", "Alberto Gasparin "] license = "Apache-2.0" From 591d8425ebbca4c039ac08f73637027f8e126431 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Wed, 13 Sep 2023 15:39:17 +0200 Subject: [PATCH 5/9] refactor tabular analysis of benchmarks --- benchmarks/tabular/analysis.py | 757 ++++++--------------------------- 1 file changed, 137 insertions(+), 620 deletions(-) diff --git a/benchmarks/tabular/analysis.py b/benchmarks/tabular/analysis.py index 2608da52..464f0652 100644 --- a/benchmarks/tabular/analysis.py +++ b/benchmarks/tabular/analysis.py @@ -6,21 +6,17 @@ with open("tabular_results.json", "r") as j: metrics = json.loads(j.read()) +TOL = 1e-4 + # ~~~REGRESSION~~~ # MAP map_nlls = [ metrics["regression"][k]["map"]["nll"] for k in metrics["regression"].keys() ] -map_quantiles_nlls = np.percentile(map_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["map"]["picp"]) - for k in metrics["regression"].keys() + np.abs(metrics["regression"][k]["map"]["picp"] - 0.95) for k in metrics["regression"].keys() ] -map_quantiles_picp_errors = np.percentile( - map_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) map_times = [ metrics["regression"][k]["map"]["time"] for k in metrics["regression"].keys() @@ -31,185 +27,83 @@ metrics["regression"][k]["temp_scaling"]["nll"] for k in metrics["regression"].keys() ] -temp_scaling_quantiles_nlls = np.percentile( - temp_scaling_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_nlls = np.sum(np.array(temp_scaling_nlls) / np.array(map_nlls) <= 1) -winlose_temp_scaling_nlls = ( - f"{win_temp_scaling_nlls} / {len(map_nlls) - win_temp_scaling_nlls}" -) -rel_improve_temp_scaling_nlls = ( - np.array(map_nlls) - np.array(temp_scaling_nlls) -) / np.array(map_nlls) -max_loss_temp_scaling_nlls = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_temp_scaling_nlls[rel_improve_temp_scaling_nlls < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_nlls = ( - f"{np.round(np.median(rel_improve_temp_scaling_nlls), 2)}" -) - +win_temp_scaling_nlls = np.array(temp_scaling_nlls) < np.array(map_nlls) - TOL +lose_temp_scaling_nlls = np.array(temp_scaling_nlls) > np.array(map_nlls) + TOL temp_scaling_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["temp_scaling"]["picp"]) + np.abs(metrics["regression"][k]["temp_scaling"]["picp"] - 0.95) for k in metrics["regression"].keys() ] -temp_scaling_quantiles_picp_errors = np.percentile( - temp_scaling_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_picp_errors = np.sum( - np.array(temp_scaling_picp_errors) / np.array(map_picp_errors) <= 1 -) -winlose_temp_scaling_picp_errors = f"{win_temp_scaling_picp_errors} / {len(map_picp_errors) - win_temp_scaling_picp_errors}" -rel_improve_temp_scaling_picp_errors = ( - np.array(map_picp_errors) - np.array(temp_scaling_picp_errors) -) / np.array(map_picp_errors) -max_loss_temp_scaling_picp_errors = ( - str( - np.round( - 100 - * np.abs( - np.max( - rel_improve_temp_scaling_picp_errors[ - rel_improve_temp_scaling_picp_errors < 0 - ] - ) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_picp_errors = ( - f"{np.round(np.median(rel_improve_temp_scaling_picp_errors), 2)}" -) +win_temp_scaling_picp_errors = np.array(temp_scaling_picp_errors) < np.array(map_picp_errors) - TOL +lose_temp_scaling_picp_errors = np.array(temp_scaling_picp_errors) > np.array(map_picp_errors) + TOL temp_scaling_times = [ metrics["regression"][k]["temp_scaling"]["time"] for k 
in metrics["regression"].keys() ] +temp_scaling_best_win = np.max(np.array(map_picp_errors) - np.array(temp_scaling_picp_errors)) +temp_scaling_worst_loss = np.min(np.array(map_picp_errors) - np.array(temp_scaling_picp_errors)) + # CQR cqr_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["cqr"]["picp"]) + np.abs(metrics["regression"][k]["cqr"]["picp"] - 0.95) for k in metrics["regression"].keys() ] -cqr_quantiles_picp_errors = np.percentile( - cqr_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_cqr_picp_errors = np.sum(np.array(cqr_picp_errors) / np.array(map_picp_errors) <= 1) -winlose_cqr_picp_errors = ( - f"{win_cqr_picp_errors} / {len(map_picp_errors) - win_cqr_picp_errors}" -) -rel_improve_cqr_picp_errors = ( - np.array(map_picp_errors) - np.array(cqr_picp_errors) -) / np.array(map_picp_errors) -max_loss_cqr_picp_errors = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_cqr_picp_errors[rel_improve_cqr_picp_errors < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_cqr_picp_errors = f"{np.round(np.median(rel_improve_cqr_picp_errors), 2)}" +win_cqr_picp_errors = np.array(cqr_picp_errors) < np.array(map_picp_errors) - TOL +lose_cqr_picp_errors = np.array(cqr_picp_errors) > np.array(map_picp_errors) + TOL cqr_times = [ - metrics["regression"][k]["cqr"]["time"] for k in metrics["regression"].keys() -] - -# # TEMPERED CQR -temp_cqr_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["temp_cqr"]["picp"]) + metrics["regression"][k]["cqr"]["time"] for k in metrics["regression"].keys() ] -temp_cqr_quantiles_picp_errors = np.percentile( - temp_cqr_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_temp_cqr_picp_errors = f"{np.sum(np.array(temp_cqr_picp_errors) / np.array(map_picp_errors) <= 1)} / {len(map_picp_errors)}" -med_improv_temp_cqr_picp_errors = f"{np.round(np.median((np.array(map_picp_errors) - np.array(temp_cqr_picp_errors)) / np.array(map_picp_errors)), 2)}" - -temp_cqr_times = [ - metrics["regression"][k]["temp_cqr"]["time"] for k in metrics["regression"].keys() -] -plt.figure(figsize=(8, 6)) -plt.suptitle("Quantile-quantile plots of metrics on regression datasets") +cqr_best_win = np.max(np.array(map_picp_errors) - np.array(cqr_picp_errors)) +cqr_worst_loss = np.min(np.array(map_picp_errors) - np.array(cqr_picp_errors)) -plt.subplot(2, 2, 1) -plt.title("NLL") -plt.scatter(map_quantiles_nlls, temp_scaling_quantiles_nlls, s=3) -_min, _max = min(map_quantiles_nlls.min(), temp_scaling_quantiles_nlls.min()), max( - map_quantiles_nlls.max(), temp_scaling_quantiles_nlls.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") -plt.grid() +### Regression plots -plt.subplot(2, 2, 2) -plt.title("PICP absolute error") -plt.scatter(map_quantiles_picp_errors, temp_scaling_quantiles_picp_errors, s=3) -_min, _max = min( - map_quantiles_picp_errors.min(), temp_scaling_quantiles_picp_errors.min() -), max(map_quantiles_picp_errors.max(), temp_scaling_quantiles_picp_errors.max()) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") +plt.figure(figsize=(10, 3)) +plt.suptitle("Scatter plots for regression datasets") +plt.subplot(1, 2, 1) +plt.title("PICP errors") plt.grid() - -plt.subplot(2, 2, 4) -plt.title("PICP absolute error") -plt.scatter(map_quantiles_picp_errors, cqr_quantiles_picp_errors, s=3) -_min, _max = min(map_quantiles_picp_errors.min(), 
cqr_quantiles_picp_errors.min()), max( - map_quantiles_picp_errors.max(), cqr_quantiles_picp_errors.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_picp_errors).min(), np.array(temp_scaling_picp_errors).min()), max(np.array(map_picp_errors).max(), np.array(temp_scaling_picp_errors).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("CQR quantiles") +plt.scatter(map_picp_errors, temp_scaling_picp_errors, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_picp_errors, lose_temp_scaling_picp_errors)]) +plt.xscale("log") +plt.yscale("log") + +plt.subplot(1, 2, 2) +plt.title("PICP errors") plt.grid() +plt.xlabel("MAP") +plt.ylabel("CQR") +_min, _max = min(np.array(map_picp_errors).min(), np.array(cqr_picp_errors).min()), max(np.array(map_picp_errors).max(), np.array(cqr_picp_errors).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_picp_errors, cqr_picp_errors, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_cqr_picp_errors, lose_cqr_picp_errors)]) +plt.xscale("log") +plt.yscale("log") plt.tight_layout() - plt.show() -print("~~~REGRESSION~~~\n") -print("## TEMPERATURE SCALING ##") -print( - f"Fraction of times temp_scaling is at least on a par w.r.t. the NLL: {winlose_temp_scaling_nlls}" -) -print( - f"Fraction of times temp_scaling is at least on a par w.r.t. the PICP error: {winlose_temp_scaling_picp_errors}" -) -print() -print( - f"Median of relative NLL improvement given by temp_scaling: {med_improv_temp_scaling_nlls}" -) -print( - f"Median of relative PICP error improvement given by temp_scaling: {med_improv_temp_scaling_picp_errors}" -) -print() -print() -print("## CQR ##") -print( - f"Fraction of times CQR is at least on a par w.r.t. 
the PICP error: {winlose_cqr_picp_errors}" -) -print() -print( - f"Median of relative PICP error improvement given by temp_scaling: {med_improv_cqr_picp_errors}" -) +plt.figure(figsize=(5, 3)) +plt.suptitle("Scatter plots for regression datasets on other metrics") +plt.title("NLL") +plt.grid() +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_nlls).min(), np.array(temp_scaling_nlls).min()), max(np.array(map_nlls).max(), np.array(temp_scaling_nlls).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_nlls, temp_scaling_nlls, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_nlls, lose_temp_scaling_nlls)]) +plt.xscale("log") +plt.yscale("log") + +plt.tight_layout() +plt.show() # ~~~CLASSIFICATION~~~ @@ -218,32 +112,23 @@ map_nlls = [ metrics["classification"][k]["map"]["nll"] for k in metrics["classification"].keys() ] -map_quantiles_nlls = np.percentile(map_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90]) map_mse = [ metrics["classification"][k]["map"]["mse"] for k in metrics["classification"].keys() ] -map_quantiles_mse = np.percentile(map_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_ece = [ metrics["classification"][k]["map"]["ece"] for k in metrics["classification"].keys() ] -map_quantiles_ece = np.percentile(map_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_rocauc = [ metrics["classification"][k]["map"]["rocauc"] for k in metrics["classification"].keys() if "rocauc" in metrics["classification"][k]["map"] ] -map_quantiles_rocauc = np.percentile(map_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_prauc = [ metrics["classification"][k]["map"]["prauc"] for k in metrics["classification"].keys() if "prauc" in metrics["classification"][k]["map"] ] -map_quantiles_prauc = np.percentile(map_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_acc = [ metrics["classification"][k]["map"]["accuracy"] for k in metrics["classification"].keys() @@ -258,155 +143,31 @@ metrics["classification"][k]["temp_scaling"]["nll"] for k in metrics["classification"].keys() ] -temp_scaling_quantiles_nlls = np.percentile( - temp_scaling_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_nlls = np.sum(np.array(temp_scaling_nlls) / np.array(map_nlls) <= 1) -winlose_temp_scaling_nlls = ( - f"{win_temp_scaling_nlls} / {len(map_nlls) - win_temp_scaling_nlls}" -) -rel_improve_temp_scaling_nlls = ( - np.array(map_nlls) - np.array(temp_scaling_nlls) -) / np.array(map_nlls) -max_loss_temp_scaling_nlls = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_temp_scaling_nlls[rel_improve_temp_scaling_nlls < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_nlls = ( - f"{np.round(np.median(rel_improve_temp_scaling_nlls), 2)}" -) +win_temp_scaling_nlls = np.array(temp_scaling_nlls) < np.array(map_nlls) - TOL +lose_temp_scaling_nlls = np.array(temp_scaling_nlls) > np.array(map_nlls) + TOL temp_scaling_mse = [ metrics["classification"][k]["temp_scaling"]["mse"] for k in metrics["classification"].keys() ] -temp_scaling_quantiles_mse = np.percentile( - temp_scaling_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_mse = np.sum(np.array(temp_scaling_mse) / np.array(map_mse) <= 1) -winlose_temp_scaling_mse = ( - f"{win_temp_scaling_mse} / {len(map_mse) - win_temp_scaling_mse}" -) -rel_improve_temp_scaling_mse = ( - np.array(map_mse) - np.array(temp_scaling_mse) -) / np.array(map_mse) -max_loss_temp_scaling_mse = ( - str( - np.round( - 100 - * np.abs( - 
np.max(rel_improve_temp_scaling_mse[rel_improve_temp_scaling_mse < 0])
-            ),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_temp_scaling_mse = f"{np.round(np.median(rel_improve_temp_scaling_mse), 2)}"
-
-temp_scaling_ece = [
-    metrics["classification"][k]["temp_scaling"]["ece"]
-    for k in metrics["classification"].keys()
-]
-temp_scaling_quantiles_ece = np.percentile(
-    temp_scaling_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-win_temp_scaling_ece = np.sum(np.array(temp_scaling_ece) / np.array(map_ece) <= 1)
-winlose_temp_scaling_ece = (
-    f"{win_temp_scaling_ece} / {len(map_ece) - win_temp_scaling_ece}"
-)
-rel_improve_temp_scaling_ece = (
-    np.array(map_ece) - np.array(temp_scaling_ece)
-) / np.array(map_ece)
-max_loss_temp_scaling_ece = (
-    str(
-        np.round(
-            100
-            * np.abs(
-                np.max(rel_improve_temp_scaling_ece[rel_improve_temp_scaling_ece < 0])
-            ),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_temp_scaling_ece = f"{np.round(np.median(rel_improve_temp_scaling_ece), 2)}"
+win_temp_scaling_mse = np.array(temp_scaling_mse) < np.array(map_mse) - TOL
+lose_temp_scaling_mse = np.array(temp_scaling_mse) > np.array(map_mse) + TOL

 temp_scaling_rocauc = [
     metrics["classification"][k]["temp_scaling"]["rocauc"]
     for k in metrics["classification"].keys()
     if "rocauc" in metrics["classification"][k]["temp_scaling"]
 ]
-temp_scaling_quantiles_rocauc = np.percentile(
-    temp_scaling_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-win_temp_scaling_rocauc = np.sum(
-    np.array(temp_scaling_rocauc) / np.array(map_rocauc) <= 1
-)
-winlose_temp_scaling_rocauc = (
-    f"{win_temp_scaling_rocauc} / {len(map_rocauc) - win_temp_scaling_rocauc}"
-)
-rel_improve_temp_scaling_rocauc = (
-    np.array(map_rocauc) - np.array(temp_scaling_rocauc)
-) / np.array(map_rocauc)
-max_loss_temp_scaling_rocauc = (
-    str(
-        np.round(
-            100
-            * np.abs(
-                np.max(
-                    rel_improve_temp_scaling_rocauc[rel_improve_temp_scaling_rocauc < 0]
-                )
-            ),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_temp_scaling_rocauc = (
-    f"{np.round(np.median(rel_improve_temp_scaling_rocauc), 2)}"
-)
+win_temp_scaling_rocauc = np.array(temp_scaling_rocauc) > np.array(map_rocauc) + TOL  # higher ROC-AUC is better
+lose_temp_scaling_rocauc = np.array(temp_scaling_rocauc) < np.array(map_rocauc) - TOL

 temp_scaling_prauc = [
     metrics["classification"][k]["temp_scaling"]["prauc"]
     for k in metrics["classification"].keys()
     if "prauc" in metrics["classification"][k]["temp_scaling"]
 ]
-temp_scaling_quantiles_prauc = np.percentile(
-    temp_scaling_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-win_temp_scaling_prauc = np.sum(np.array(temp_scaling_prauc) / np.array(map_prauc) <= 1)
-winlose_temp_scaling_prauc = (
-    f"{win_temp_scaling_prauc} / {len(map_prauc) - win_temp_scaling_prauc}"
-)
-rel_improve_temp_scaling_prauc = (
-    np.array(map_prauc) - np.array(temp_scaling_prauc)
-) / np.array(map_prauc)
-max_loss_temp_scaling_prauc = (
-    str(
-        np.round(
-            100
-            * np.abs(
-                np.max(
-                    rel_improve_temp_scaling_prauc[rel_improve_temp_scaling_prauc < 0]
-                )
-            ),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_temp_scaling_prauc = (
-    f"{np.round(np.median(rel_improve_temp_scaling_prauc), 2)}"
-)
+win_temp_scaling_prauc = np.array(temp_scaling_prauc) > np.array(map_prauc) + TOL  # higher PR-AUC is better
+lose_temp_scaling_prauc = np.array(temp_scaling_prauc) < np.array(map_prauc) - TOL

 temp_scaling_acc = [
     metrics["classification"][k]["temp_scaling"]["accuracy"]
@@ -418,132 +179,39 @@
 for k in metrics["classification"].keys()
 ]

+temp_scaling_best_win = np.max(np.array(map_mse) - np.array(temp_scaling_mse))
+temp_scaling_worst_loss = np.min(np.array(map_mse) - np.array(temp_scaling_mse))
+
 # MULTICALIBRATE CONF
 mc_conf_nlls = [
     metrics["classification"][k]["mc_conf"]["nll"]
     for k in metrics["classification"].keys()
 ]
-mc_conf_nlls = np.array(mc_conf_nlls)
-mc_conf_quantiles_nlls = np.percentile(
-    mc_conf_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-win_mc_conf_nlls = np.sum(np.array(mc_conf_nlls) / np.array(np.array(map_nlls)) <= 1)
-winlose_mc_conf_nlls = f"{win_mc_conf_nlls} / {len(mc_conf_nlls) - win_mc_conf_nlls}"
-rel_improve_mc_conf_nlls = (np.array(map_nlls) - np.array(mc_conf_nlls)) / np.array(
-    map_nlls
-)
-max_loss_mc_conf_nlls = (
-    str(
-        np.round(
-            100
-            * np.abs(np.max(rel_improve_mc_conf_nlls[rel_improve_mc_conf_nlls < 0])),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_mc_conf_nlls = f"{np.round(np.median(rel_improve_mc_conf_nlls), 2)}"
+win_mc_conf_nlls = np.array(mc_conf_nlls) < np.array(map_nlls) - TOL
+lose_mc_conf_nlls = np.array(mc_conf_nlls) > np.array(map_nlls) + TOL

 mc_conf_mse = [
     metrics["classification"][k]["mc_conf"]["mse"]
     for k in metrics["classification"].keys()
 ]
-mc_conf_mse = np.array(mc_conf_mse)
-mc_conf_quantiles_mse = np.percentile(mc_conf_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90])
-win_mc_conf_mse = np.sum(np.array(mc_conf_mse) / np.array(np.array(map_mse)) <= 1)
-winlose_mc_conf_mse = f"{win_mc_conf_mse} / {len(mc_conf_mse) - win_mc_conf_mse}"
-rel_improve_mc_conf_mse = (np.array(map_mse) - np.array(mc_conf_mse)) / np.array(
-    map_mse
-)
-max_loss_mc_conf_mse = (
-    str(
-        np.round(
-            100 * np.abs(np.max(rel_improve_mc_conf_mse[rel_improve_mc_conf_mse < 0])),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_mc_conf_mse = f"{np.round(np.median(rel_improve_mc_conf_mse), 2)}"
-
-mc_conf_ece = [
-    metrics["classification"][k]["mc_conf"]["ece"]
-    for k in metrics["classification"].keys()
-]
-mc_conf_quantiles_ece = np.percentile(mc_conf_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90])
-win_mc_conf_ece = np.sum(np.array(mc_conf_ece) / np.array(map_ece) <= 1)
-winlose_mc_conf_ece = f"{win_mc_conf_ece} / {len(map_ece) - win_mc_conf_ece}"
-rel_improve_mc_conf_ece = (np.array(map_ece) - np.array(mc_conf_ece)) / np.array(
-    map_ece
-)
-max_loss_mc_conf_ece = (
-    str(
-        np.round(
-            100 * np.abs(np.max(rel_improve_mc_conf_ece[rel_improve_mc_conf_ece < 0])),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_mc_conf_ece = f"{np.round(np.median(rel_improve_mc_conf_ece), 2)}"
+win_mc_conf_mse = np.array(mc_conf_mse) < np.array(map_mse) - TOL
+lose_mc_conf_mse = np.array(mc_conf_mse) > np.array(map_mse) + TOL

 mc_conf_rocauc = [
     metrics["classification"][k]["mc_conf"]["rocauc"]
     for k in metrics["classification"].keys()
+    if "rocauc" in metrics["classification"][k]["mc_conf"]
 ]
-mc_conf_rocauc = np.array(mc_conf_rocauc)
-mc_conf_quantiles_rocauc = np.percentile(
-    mc_conf_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-win_mc_conf_rocauc = np.sum(
-    np.array(mc_conf_rocauc) / np.array(np.array(map_rocauc)) <= 1
-)
-winlose_mc_conf_rocauc = (
-    f"{win_mc_conf_rocauc} / {len(mc_conf_rocauc) - win_mc_conf_rocauc}"
-)
-rel_improve_mc_conf_rocauc = (
-    np.array(map_rocauc) - np.array(mc_conf_rocauc)
-) / np.array(map_rocauc)
-max_loss_mc_conf_rocauc = (
-    str(
-        np.round(
-            100
-            * np.abs(
-                np.max(rel_improve_mc_conf_rocauc[rel_improve_mc_conf_rocauc < 0])
-            ),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_mc_conf_rocauc = f"{np.round(np.median(rel_improve_mc_conf_rocauc), 2)}"
+win_mc_conf_rocauc = np.array(mc_conf_rocauc) > np.array(map_rocauc) + TOL  # higher ROC-AUC is better
+lose_mc_conf_rocauc = np.array(mc_conf_rocauc) < np.array(map_rocauc) - TOL

 mc_conf_prauc = [
     metrics["classification"][k]["mc_conf"]["prauc"]
     for k in metrics["classification"].keys()
+    if "prauc" in metrics["classification"][k]["mc_conf"]
 ]
-mc_conf_prauc = np.array(mc_conf_prauc)
-mc_conf_quantiles_prauc = np.percentile(
-    mc_conf_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-win_mc_conf_prauc = np.sum(np.array(mc_conf_prauc) / np.array(np.array(map_prauc)) <= 1)
-winlose_mc_conf_prauc = (
-    f"{win_mc_conf_prauc} / {len(mc_conf_prauc) - win_mc_conf_prauc}"
-)
-rel_improve_mc_conf_prauc = (np.array(map_prauc) - np.array(mc_conf_prauc)) / np.array(
-    map_prauc
-)
-max_loss_mc_conf_prauc = (
-    str(
-        np.round(
-            100
-            * np.abs(np.max(rel_improve_mc_conf_prauc[rel_improve_mc_conf_prauc < 0])),
-            2,
-        )
-    )
-    + "%"
-)
-med_improv_mc_conf_prauc = f"{np.round(np.median(rel_improve_mc_conf_prauc), 2)}"
+win_mc_conf_prauc = np.array(mc_conf_prauc) > np.array(map_prauc) + TOL  # higher PR-AUC is better
+lose_mc_conf_prauc = np.array(mc_conf_prauc) < np.array(map_prauc) - TOL

 mc_conf_acc = [
     metrics["classification"][k]["mc_conf"]["accuracy"]
@@ -555,244 +223,93 @@ for k in metrics["classification"].keys()
 ]

-# TEMPERED MULTICALIBRATE CONF
-temp_mc_conf_nlls = [
-    metrics["classification"][k]["temp_mc_conf"]["nll"]
-    for k in metrics["classification"].keys()
-]
-temp_mc_conf_nlls = np.array(temp_mc_conf_nlls)
-temp_mc_conf_quantiles_nlls = np.percentile(
-    temp_mc_conf_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-winlose_temp_mc_conf_nlls = f"{np.sum(np.array(temp_mc_conf_nlls) / np.array(np.array(map_nlls)) <= 1)} / {len(temp_mc_conf_nlls)}"
-med_improv_temp_mc_conf_nlls = f"{np.round(np.median((np.array(map_nlls) - np.array(temp_mc_conf_nlls)) / np.array(map_nlls)), 2)}"
-
-temp_mc_conf_mse = [
-    metrics["classification"][k]["temp_mc_conf"]["mse"]
-    for k in metrics["classification"].keys()
-]
-temp_mc_conf_mse = np.array(temp_mc_conf_mse)
-temp_mc_conf_quantiles_mse = np.percentile(
-    temp_mc_conf_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-winlose_temp_mc_conf_mse = f"{np.sum(np.array(temp_mc_conf_mse) / np.array(np.array(map_mse)) <= 1)} / {len(temp_mc_conf_mse)}"
-med_improv_temp_mc_conf_mse = f"{np.round(np.median((np.array(map_mse) - np.array(temp_mc_conf_mse)) / np.array(map_mse)), 2)}"
-
-temp_mc_conf_ece = [
-    metrics["classification"][k]["temp_mc_conf"]["ece"]
-    for k in metrics["classification"].keys()
-]
-temp_mc_conf_quantiles_ece = np.percentile(
-    temp_mc_conf_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-winlose_temp_mc_conf_ece = (
-    f"{np.sum(np.array(temp_mc_conf_ece) / np.array(map_ece) <= 1)} / {len(map_ece)}"
-)
-med_improv_temp_mc_conf_ece = f"{np.round(np.median((np.array(map_ece) - np.array(temp_mc_conf_ece)) / np.array(map_ece)), 2)}"
-
-temp_mc_conf_rocauc = [
-    metrics["classification"][k]["temp_mc_conf"]["rocauc"]
-    for k in metrics["classification"].keys()
-]
-temp_mc_conf_rocauc = np.array(temp_mc_conf_rocauc)
-temp_mc_conf_quantiles_rocauc = np.percentile(
-    temp_mc_conf_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-)
-winlose_temp_mc_conf_rocauc = f"{np.sum(np.array(temp_mc_conf_rocauc) / np.array(np.array(map_rocauc)) <= 1)} / {len(temp_mc_conf_rocauc)}"
-med_improv_temp_mc_conf_rocauc = f"{np.round(np.median((np.array(map_rocauc) - np.array(temp_mc_conf_rocauc)) / np.array(map_rocauc)), 2)}"
-
-temp_mc_conf_prauc = [
-    metrics["classification"][k]["temp_mc_conf"]["prauc"]
-    for k in metrics["classification"].keys()
-]
-temp_mc_conf_prauc = np.array(temp_mc_conf_prauc)
-temp_mc_conf_quantiles_prauc = np.percentile(
-    temp_mc_conf_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]
-) -winlose_temp_mc_conf_prauc = f"{np.sum(np.array(temp_mc_conf_prauc) / np.array(np.array(map_prauc)) <= 1)} / {len(temp_mc_conf_prauc)}" -med_improv_temp_mc_conf_prauc = f"{np.round(np.median((np.array(map_prauc) - np.array(temp_mc_conf_prauc)) / np.array(map_prauc)), 2)}" - -temp_mc_conf_acc = [ - metrics["classification"][k]["temp_mc_conf"]["accuracy"] - for k in metrics["classification"].keys() -] - -temp_mc_conf_times = [ - metrics["classification"][k]["temp_mc_conf"]["time"] - for k in metrics["classification"].keys() -] - -# MULTICALIBRATE PROB -idx_overlap = [ - i - for i, k in enumerate(metrics["classification"]) - if len(metrics["classification"][k]["mc_prob"]) -] +mc_conf_best_win = np.max(np.array(map_mse) - np.array(mc_conf_mse)) +mc_conf_worst_loss = np.min(np.array(map_mse) - np.array(mc_conf_mse)) -mc_prob_nlls = [ - metrics["classification"][k]["mc_prob"]["nll"] - for k in metrics["classification"].keys() - if "nll" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_nlls = np.percentile( - mc_prob_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_mc_prob_nlls = f"{np.sum(np.array(mc_prob_nlls) / np.array(map_nlls)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_nlls = f"{np.round(np.median((np.array(map_nlls)[idx_overlap] - np.array(mc_prob_nlls)) / np.array(map_nlls)[idx_overlap]), 2)}" - -mc_prob_mse = [ - metrics["classification"][k]["mc_prob"]["mse"] - for k in metrics["classification"].keys() - if "mse" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_mse = np.percentile(mc_prob_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90]) -winlose_mc_prob_mse = f"{np.sum(np.array(mc_prob_mse) / np.array(map_mse)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_mse = f"{np.round(np.median((np.array(map_mse)[idx_overlap] - np.array(mc_prob_mse)) / np.array(map_mse)[idx_overlap]), 2)}" +### Classification plots -mc_prob_ece = [ - metrics["classification"][k]["mc_prob"]["ece"] - for k in metrics["classification"].keys() - if "ece" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_ece = np.percentile(mc_prob_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90]) -winlose_mc_prob_ece = f"{np.sum(np.array(mc_prob_ece) / np.array(map_ece)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_ece = f"{np.round(np.median((np.array(map_ece)[idx_overlap] - np.array(mc_prob_ece)) / np.array(map_ece)[idx_overlap]), 2)}" +plt.figure(figsize=(10, 3)) +plt.suptitle("Scatter plots for classification datasets") +plt.subplot(1, 2, 1) +plt.title("MSE") +plt.grid() +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_mse).min(), np.array(temp_scaling_mse).min()), max(np.array(map_mse).max(), np.array(temp_scaling_mse).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_mse, temp_scaling_mse, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_mse, lose_temp_scaling_mse)]) -mc_prob_rocauc = [ - metrics["classification"][k]["mc_prob"]["rocauc"] - for k in metrics["classification"].keys() - if "rocauc" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_rocauc = np.percentile( - mc_prob_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_mc_prob_rocauc = f"{np.sum(np.array(mc_prob_rocauc) / np.array(map_rocauc)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_rocauc = f"{np.round(np.median((np.array(map_rocauc)[idx_overlap] - np.array(mc_prob_rocauc)) / np.array(map_rocauc)[idx_overlap]), 2)}" - 
-mc_prob_prauc = [ - metrics["classification"][k]["mc_prob"]["prauc"] - for k in metrics["classification"].keys() - if "prauc" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_prauc = np.percentile( - mc_prob_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_mc_prob_prauc = f"{np.sum(np.array(mc_prob_prauc) / np.array(map_prauc)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_prauc = f"{np.round(np.median((np.array(map_prauc)[idx_overlap] - np.array(mc_prob_prauc)) / np.array(map_prauc)[idx_overlap]), 2)}" - -mc_prob_acc = [ - metrics["classification"][k]["mc_prob"]["accuracy"] - for k in metrics["classification"].keys() - if "accuracy" in metrics["classification"][k]["mc_prob"] -] +plt.subplot(1, 2, 2) +plt.title("MSE") +plt.grid() +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_mse).min(), np.array(mc_conf_mse).min()), max(np.array(map_mse).max(), np.array(mc_conf_mse).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_mse, mc_conf_mse, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_mse, lose_mc_conf_mse)]) -mc_prob_times = [ - metrics["classification"][k]["mc_prob"]["time"] - for k in metrics["classification"].keys() - if "time" in metrics["classification"][k]["mc_prob"] -] +plt.tight_layout() +plt.show() plt.figure(figsize=(10, 6)) -plt.suptitle("Quantile-quantile plots of metrics on classification datasets") - -plt.subplot(2, 4, 1) +plt.suptitle("Scatter plots for classification datasets on other metrics") +plt.subplot(3, 2, 1) plt.title("NLL") -plt.scatter(map_quantiles_nlls, temp_scaling_quantiles_nlls, s=3) -_min, _max = min(map_quantiles_nlls.min(), temp_scaling_quantiles_nlls.min()), max( - map_quantiles_nlls.max(), temp_scaling_quantiles_nlls.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") plt.grid() - -plt.subplot(2, 4, 2) -plt.title("MSE") -plt.scatter(map_quantiles_mse, temp_scaling_quantiles_mse, s=3) -_min, _max = min(map_quantiles_mse.min(), temp_scaling_quantiles_mse.min()), max( - map_quantiles_mse.max(), temp_scaling_quantiles_mse.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_nlls).min(), np.array(temp_scaling_nlls).min()), max(np.array(map_nlls).max(), np.array(temp_scaling_nlls).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") +plt.scatter(map_nlls, temp_scaling_nlls, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_nlls, lose_temp_scaling_nlls)]) +plt.xscale("log") +plt.yscale("log") + +plt.subplot(3, 2, 2) +plt.title("NLL") plt.grid() +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_nlls).min(), np.array(mc_conf_nlls).min()), max(np.array(map_nlls).max(), np.array(mc_conf_nlls).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_nlls, mc_conf_nlls, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_nlls, lose_mc_conf_nlls)]) +plt.xscale("log") +plt.yscale("log") -plt.subplot(2, 4, 3) +plt.subplot(3, 2, 3) plt.title("ROCAUC") -plt.scatter(map_quantiles_rocauc, temp_scaling_quantiles_rocauc, s=3) -_min, _max = min(map_quantiles_rocauc.min(), temp_scaling_quantiles_rocauc.min()), max( - map_quantiles_rocauc.max(), temp_scaling_quantiles_rocauc.max() -) 
-plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") plt.grid() - -plt.subplot(2, 4, 4) -plt.title("PRAUC") -plt.scatter(map_quantiles_prauc, temp_scaling_quantiles_prauc, s=3) -_min, _max = min(map_quantiles_prauc.min(), temp_scaling_quantiles_prauc.min()), max( - map_quantiles_prauc.max(), temp_scaling_quantiles_prauc.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_rocauc).min(), np.array(temp_scaling_rocauc).min()), max(np.array(map_rocauc).max(), np.array(temp_scaling_rocauc).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") -plt.grid() +plt.scatter(map_rocauc, temp_scaling_rocauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_rocauc, lose_temp_scaling_rocauc)]) -plt.subplot(2, 4, 5) -plt.title("NLL") -plt.scatter(map_quantiles_nlls, mc_conf_quantiles_nlls, s=3) -_min, _max = min(map_quantiles_nlls.min(), mc_conf_quantiles_nlls.min()), max( - map_quantiles_nlls.max(), mc_conf_quantiles_nlls.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") +plt.subplot(3, 2, 4) +plt.title("ROCAUC") plt.grid() - -plt.subplot(2, 4, 6) -plt.title("ECE") -plt.scatter(map_quantiles_ece, mc_conf_quantiles_ece, s=3) -_min, _max = min(map_quantiles_ece.min(), mc_conf_quantiles_ece.min()), max( - map_quantiles_ece.max(), mc_conf_quantiles_ece.max() -) +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_rocauc).min(), np.array(mc_conf_rocauc).min()), max(np.array(map_rocauc).max(), np.array(mc_conf_rocauc).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") -plt.grid() +plt.scatter(map_rocauc, mc_conf_rocauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_rocauc, lose_mc_conf_rocauc)]) -plt.subplot(2, 4, 6) -plt.title("MSE") -plt.scatter(map_quantiles_mse, mc_conf_quantiles_mse, s=3) -_min, _max = min(map_quantiles_mse.min(), mc_conf_quantiles_mse.min()), max( - map_quantiles_mse.max(), mc_conf_quantiles_mse.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") +plt.subplot(3, 2, 5) +plt.title("PRAUC") plt.grid() - -plt.subplot(2, 4, 7) -plt.title("ROCAUC") -plt.scatter(map_quantiles_rocauc, mc_conf_quantiles_rocauc, s=3) -_min, _max = min(map_quantiles_rocauc.min(), mc_conf_quantiles_rocauc.min()), max( - map_quantiles_rocauc.max(), mc_conf_quantiles_rocauc.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_prauc).min(), np.array(temp_scaling_prauc).min()), max(np.array(map_prauc).max(), np.array(temp_scaling_prauc).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") -plt.grid() +plt.scatter(map_prauc, temp_scaling_prauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_prauc, lose_temp_scaling_prauc)]) -plt.subplot(2, 4, 8) +plt.subplot(3, 2, 6) plt.title("PRAUC") -plt.scatter(map_quantiles_prauc, mc_conf_quantiles_prauc, s=3) -_min, _max = min(map_quantiles_prauc.min(), mc_conf_quantiles_prauc.min()), max( - map_quantiles_prauc.max(), 
mc_conf_quantiles_prauc.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") plt.grid() +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_prauc).min(), np.array(mc_conf_prauc).min()), max(np.array(map_prauc).max(), np.array(mc_conf_prauc).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_prauc, mc_conf_prauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_prauc, lose_mc_conf_prauc)]) plt.tight_layout() plt.show() From 411a88b0ee833af7f9e87a91b9c8155be8cb50a7 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Wed, 27 Sep 2023 14:09:31 +0200 Subject: [PATCH 6/9] add CV+ Simple & Adaptive Prediction Conformal Classifiers --- fortuna/conformal/__init__.py | 4 +- .../classification/adaptive_prediction.py | 147 +++--------------- fortuna/conformal/classification/base.py | 93 ++++++++++- .../classification/simple_prediction.py | 131 +++------------- tests/fortuna/test_conformal_methods.py | 39 +++-- 5 files changed, 170 insertions(+), 244 deletions(-) diff --git a/fortuna/conformal/__init__.py b/fortuna/conformal/__init__.py index 28526699..e4def604 100644 --- a/fortuna/conformal/__init__.py +++ b/fortuna/conformal/__init__.py @@ -2,13 +2,13 @@ AdaptiveConformalClassifier, ) from fortuna.conformal.classification.adaptive_prediction import ( - AdaptivePredictionConformalClassifier, + AdaptivePredictionConformalClassifier, CVPlusAdaptivePredictionConformalClassifier ) from fortuna.conformal.classification.maxcovfixprec_binary_classfication import ( MaxCoverageFixedPrecisionBinaryClassificationCalibrator, ) from fortuna.conformal.classification.simple_prediction import ( - SimplePredictionConformalClassifier, + SimplePredictionConformalClassifier, CVPlusSimplePredictionConformalClassifier ) from fortuna.conformal.multivalid.iterative.classification.binary_multicalibrator import ( BinaryClassificationMulticalibrator, diff --git a/fortuna/conformal/classification/adaptive_prediction.py b/fortuna/conformal/classification/adaptive_prediction.py index 301d8f88..07d1dc5a 100644 --- a/fortuna/conformal/classification/adaptive_prediction.py +++ b/fortuna/conformal/classification/adaptive_prediction.py @@ -1,134 +1,37 @@ -from typing import ( - List, - Optional, -) - from jax import vmap import jax.numpy as jnp -import numpy as np -from fortuna.conformal.classification.base import ConformalClassifier +from fortuna.conformal.classification.base import SplitConformalClassifier, CVPlusConformalClassifier from fortuna.typing import Array -class AdaptivePredictionConformalClassifier(ConformalClassifier): - def score( - self, - val_probs: Array, - val_targets: Array, - ) -> jnp.ndarray: - """ - Compute score function. - - Parameters - ---------- - val_probs: Array - A two-dimensional array of class probabilities for each validation data point. - val_targets: Array - A one-dimensional array of validation target variables. - - Returns - ------- - jnp.ndarray - The conformal scores. - """ - if val_probs.ndim != 2: - raise ValueError( - """`val_probs` must be a two-dimensional array. The first dimension is over the validation - inputs. 
The second is over the classes.""" - ) - - perms = jnp.argsort(val_probs, axis=1)[:, ::-1] - inv_perms = jnp.argsort(perms, axis=1) - - @vmap - def score_fn(prob, perm, inv_perm, target): - sorted_prob = prob[perm] - return jnp.cumsum(sorted_prob)[inv_perm[target]] - - return score_fn(val_probs, perms, inv_perms, val_targets) - - def quantile( - self, - val_probs: Array, - val_targets: Array, - error: float = 0.05, - scores: Optional[Array] = None, - ) -> Array: - """ - Compute a quantile of the scores. - - Parameters - ---------- - val_probs: Array - A two-dimensional array of class probabilities for each validation data point. - val_targets: Array - A one-dimensional array of validation target variables. - error: float - Coverage error. This must be a scalar between 0 and 1, extremes included. - scores: Optional[Array] - The conformal scores. This should be the output of - :meth:`~fortuna.conformal.classification.adaptive_prediction.AdaptivePredictionConformalClassifier.score`. - - Returns - ------- - float - The conformal quantiles. - """ - if error < 0 or error > 1: - raise ValueError("""`error` must be a scalar between 0 and 1.""") +@vmap +def _score_fn(probs: Array, perm: Array, inv_perm: Array, targets: Array): + return jnp.cumsum(probs[perm])[inv_perm[targets]] - if scores is None: - scores = self.score(val_probs, val_targets) - n = scores.shape[0] - return jnp.quantile(scores, jnp.ceil((n + 1) * (1 - error)) / n) - def conformal_set( - self, - val_probs: Array, - test_probs: Array, - val_targets: Array, - error: float = 0.05, - quantile: Optional[float] = None, - ) -> List[List[int]]: - """ - Coverage set of each of the test inputs, at the desired coverage error. +def score_fn( + probs: Array, + targets: Array, +): + perms = jnp.argsort(probs, axis=1)[:, ::-1] + inv_perms = jnp.argsort(perms, axis=1) + return _score_fn(probs, perms, inv_perms, targets) - Parameters - ---------- - val_probs: Array - A two-dimensional array of class probabilities for each validation data point. - test_probs: Array - A two-dimensional array of class probabilities for each test data point. - val_targets: Array - A one-dimensional array of validation target variables. - error: float - The coverage error. This must be a scalar between 0 and 1, extremes included. - quantile: Optional[float] - Conformal quantiles. This should be the output of - :meth:`~fortuna.conformal.classification.adaptive_prediction.AdaptivePredictionConformalClassifier.quantile`. - Returns - ------- - List[List[int, ...]] - The coverage sets. - """ - if test_probs.ndim != 2: - raise ValueError( - """`test_probs` must be a two-dimensional array. The first dimension is over the validation - inputs. 
The second is over the classes."""
-            )
-
-        if quantile is None:
-            quantile = self.quantile(val_probs, val_targets, error)
-        test_perms = jnp.argsort(test_probs, axis=1)[:, ::-1]
-        test_sorted_probs = vmap(lambda prob, perm: prob[perm])(test_probs, test_perms)
-        sizes = (
-            (test_sorted_probs.cumsum(axis=1) > quantile).astype("int32").argmax(axis=1)
-        )
-        sets = np.zeros(len(sizes), dtype=object)
-        for s in jnp.unique(sizes):
-            idx = jnp.where(sizes == s)[0]
-            sets[idx] = test_perms[idx, : s + 1].tolist()
-        return sets.tolist()
+class AdaptivePredictionConformalClassifier(SplitConformalClassifier):
+    def score_fn(
+        self,
+        probs: Array,
+        targets: Array,
+    ):
+        return score_fn(probs=probs, targets=targets)
+
+
+class CVPlusAdaptivePredictionConformalClassifier(CVPlusConformalClassifier):
+    def score_fn(
+        self,
+        probs: Array,
+        targets: Array,
+    ):
+        return score_fn(probs=probs, targets=targets)
diff --git a/fortuna/conformal/classification/base.py b/fortuna/conformal/classification/base.py
index 0716975a..b1de6451 100644
--- a/fortuna/conformal/classification/base.py
+++ b/fortuna/conformal/classification/base.py
@@ -1,6 +1,9 @@
-from typing import List
+import abc
+from typing import List, Tuple

 import jax.numpy as jnp
+from jax import vmap
+import numpy as np

 from fortuna.typing import Array

@@ -27,3 +30,91 @@ def is_in(self, values: Array, conformal_sets: List) -> Array:
             An array of ones or zero, indicating whether the values lie within their respective conformal sets.
         """
         return jnp.array([v in s for v, s in zip(values.tolist(), conformal_sets)])
+
+    @abc.abstractmethod
+    def score_fn(
+        self,
+        probs: Array,
+        targets: Array,
+    ):
+        pass
+
+    @staticmethod
+    def _get_conformal_sets_from_scores(
+        val_scores: Array,
+        test_scores: Array,
+        error: float,
+    ) -> List[List[int]]:
+        conds = jnp.sum(val_scores[:, None, None] > test_scores[None], axis=0) < (1 - error) * (len(val_scores) + 1)
+        sizes = conds.sum(1)
+
+        sets = np.zeros(len(test_scores), dtype=object)
+        for us in jnp.unique(sizes):
+            idx = jnp.where(sizes == us)[0]  # test points whose conformal set has size `us`
+            if us == 0:
+                for i in idx: sets[i] = []  # one empty conformal set per test point
+            else:
+                sets[idx] = np.where(conds[idx])[1].reshape(-1, us).tolist()
+
+        return sets.tolist()
+
+    @abc.abstractmethod
+    def get_scores(
+        self,
+        *args,
+        **kwargs
+    ) -> Tuple[Array, Array]:
+        pass
+
+
+class SplitConformalClassifier(ConformalClassifier, abc.ABC):
+    def get_scores(
+        self,
+        val_probs: Array,
+        val_targets: Array,
+        test_probs: Array
+    ) -> Tuple[Array, Array]:
+        val_scores = self.score_fn(val_probs, val_targets)
+        test_scores = vmap(
+            lambda i: self.score_fn(test_probs, i * jnp.ones(len(test_probs), dtype="int32")), out_axes=1
+        )(jnp.arange(val_probs.shape[1]))
+        return val_scores, test_scores
+
+    def conformal_set(
+        self,
+        val_probs: Array,
+        val_targets: Array,
+        test_probs: Array,
+        error: float
+    ) -> List[List[int]]:
+        val_scores, test_scores = self.get_scores(val_probs=val_probs, val_targets=val_targets, test_probs=test_probs)
+        return super()._get_conformal_sets_from_scores(val_scores=val_scores, test_scores=test_scores, error=error)
+
+
+class CVPlusConformalClassifier(ConformalClassifier):
+    def conformal_set(
+        self,
+        cross_val_probs: List[Array],
+        cross_val_targets: List[Array],
+        cross_test_probs: List[Array],
+        error: float
+    ) -> List[List[int]]:
+        val_scores, test_scores = self.get_scores(cross_val_probs=cross_val_probs, cross_val_targets=cross_val_targets, cross_test_probs=cross_test_probs)
+        return super()._get_conformal_sets_from_scores(val_scores=val_scores, test_scores=test_scores, 
error=error) + + def get_scores( + self, + cross_val_probs: List[Array], + cross_val_targets: List[Array], + cross_test_probs: List[Array] + ) -> Tuple[Array, Array]: + val_scores, test_scores = [], [] + for val_probs, val_targets, test_probs in zip(cross_val_probs, cross_val_targets, cross_test_probs): + val_scores.append(self.score_fn(val_probs, val_targets)) + test_scores.append( + vmap( + lambda i: self.score_fn(test_probs, i * jnp.ones(len(test_probs), dtype="int32")), out_axes=1 + )(jnp.arange(cross_val_probs[0].shape[1])) + ) + + return jnp.concatenate(val_scores), jnp.concatenate(test_scores, axis=0) diff --git a/fortuna/conformal/classification/simple_prediction.py b/fortuna/conformal/classification/simple_prediction.py index 8be61d43..6fe3227e 100644 --- a/fortuna/conformal/classification/simple_prediction.py +++ b/fortuna/conformal/classification/simple_prediction.py @@ -1,119 +1,34 @@ -from typing import ( - List, - Optional, -) - from jax import vmap -import jax.numpy as jnp -from fortuna.conformal.classification.base import ConformalClassifier +from fortuna.conformal.classification.base import SplitConformalClassifier, CVPlusConformalClassifier from fortuna.typing import Array -class SimplePredictionConformalClassifier(ConformalClassifier): - def score( - self, - val_probs: Array, - val_targets: Array, - ) -> jnp.ndarray: - """ - Compute score function. - - Parameters - ---------- - val_probs: Array - A two-dimensional array of class probabilities for each validation data point. - val_targets: Array - A one-dimensional array of validation target variables. - - Returns - ------- - jnp.ndarray - The conformal scores. - """ - if val_probs.ndim != 2: - raise ValueError( - """`val_probs` must be a two-dimensional array. The first dimension is over the validation - inputs. The second is over the classes.""" - ) - - @vmap - def score_fn(prob, target): - return 1 - prob[target] - - return score_fn(val_probs, val_targets) - - def quantile( - self, - val_probs: Array, - val_targets: Array, - error: float = 0.05, - scores: Optional[Array] = None, - ) -> Array: - """ - Compute a quantile of the scores. - - Parameters - ---------- - val_probs: Array - A two-dimensional array of class probabilities for each validation data point. - val_targets: Array - A one-dimensional array of validation target variables. - error: float - Coverage error. This must be a scalar between 0 and 1, extremes included. - scores: Optional[Array] - The conformal scores. This should be the output of - :meth:`~fortuna.conformal.classification.simple_prediction.SimplePredictionConformalClassifier.score`. +@vmap +def _score_fn(probs: Array, target: Array): + return 1 - probs[target] - Returns - ------- - float - The conformal quantiles. - """ - if error < 0 or error > 1: - raise ValueError("""`error` must be a scalar between 0 and 1.""") - if scores is None: - scores = self.score(val_probs, val_targets) - n = scores.shape[0] - return jnp.quantile(scores, jnp.ceil((n + 1) * (1 - error)) / n) +def score_fn( + probs: Array, + targets: Array, +): + return _score_fn(probs, targets) - def conformal_set( - self, - val_probs: Array, - test_probs: Array, - val_targets: Array, - error: float = 0.05, - quantile: Optional[float] = None, - ) -> List[List[int]]: - """ - Coverage set of each of the test inputs, at the desired coverage error. - Parameters - ---------- - val_probs: Array - A two-dimensional array of class probabilities for each validation data point. 
- test_probs: Array - A two-dimensional array of class probabilities for each test data point. - val_targets: Array - A one-dimensional array of validation target variables. - error: float - The coverage error. This must be a scalar between 0 and 1, extremes included. - quantile: Optional[float] - Conformal quantiles. This should be the output of - :meth:`~fortuna.conformal.classification.simple_prediction.SimplePredictionConformalClassifier.quantile`. +class SimplePredictionConformalClassifier(SplitConformalClassifier): + def score_fn( + self, + probs: Array, + targets: Array, + ): + return score_fn(probs=probs, targets=targets) - Returns - ------- - List[List[int, ...]] - The coverage sets. - """ - if test_probs.ndim != 2: - raise ValueError( - """`test_probs` must be a two-dimensional array. The first dimension is over the validation - inputs. The second is over the classes.""" - ) - if quantile is None: - quantile = self.quantile(val_probs, val_targets, error) - return [jnp.where(prob > 1 - quantile)[0].tolist() for prob in test_probs] +class CVPlusSimplePredictionConformalClassifier(CVPlusConformalClassifier): + def score_fn( + self, + probs: Array, + targets: Array, + ): + return score_fn(probs=probs, targets=targets) diff --git a/tests/fortuna/test_conformal_methods.py b/tests/fortuna/test_conformal_methods.py index 5378b8ad..7d0f664e 100755 --- a/tests/fortuna/test_conformal_methods.py +++ b/tests/fortuna/test_conformal_methods.py @@ -7,6 +7,8 @@ from fortuna.conformal import ( AdaptiveConformalClassifier, AdaptiveConformalRegressor, + CVPlusSimplePredictionConformalClassifier, + CVPlusAdaptivePredictionConformalClassifier, AdaptivePredictionConformalClassifier, BatchMVPConformalClassifier, BatchMVPConformalRegressor, @@ -38,14 +40,11 @@ def test_prediction_conformal_classifier(self): test_probs = np.array([[0.5, 0.5], [0.8, 0.2]]) conformal = SimplePredictionConformalClassifier() - scores = conformal.score(val_probs, val_targets) - assert scores.shape == (3,) - quantile = conformal.quantile(val_probs, val_targets, 0.5, scores=scores) coverage_sets = conformal.conformal_set( val_probs=val_probs, test_probs=test_probs, val_targets=val_targets, - quantile=quantile, + error=0.05 ) assert ( (0 in coverage_sets[0]) @@ -53,21 +52,39 @@ def test_prediction_conformal_classifier(self): and (0 in coverage_sets[1]) ) + def test_simple_prediction_conformal_classifier(self): + val_probs = np.array([[0.1, 0.7, 0.2], [0.5, 0.4, 0.1], [0.15, 0.6, 0.35]]) + val_targets = np.array([1, 1, 2]) + test_probs = np.array([[0.2, 0.5, 0.3], [0.6, 0.01, 0.39]]) + + conformal = SimplePredictionConformalClassifier() + coverage_sets = conformal.conformal_set( + val_probs, val_targets, test_probs, error=0.05 + ) + assert len(coverage_sets) == len(test_probs) + + conformal = CVPlusSimplePredictionConformalClassifier() + coverage_sets = conformal.conformal_set( + [val_probs, val_probs], [val_targets, val_targets], [test_probs, test_probs], error=0.05 + ) + assert len(coverage_sets) == 2 * len(test_probs) + def test_adaptive_prediction_conformal_classifier(self): val_probs = np.array([[0.1, 0.7, 0.2], [0.5, 0.4, 0.1], [0.15, 0.6, 0.35]]) val_targets = np.array([1, 1, 2]) test_probs = np.array([[0.2, 0.5, 0.3], [0.6, 0.01, 0.39]]) conformal = AdaptivePredictionConformalClassifier() - scores = conformal.score(val_probs, val_targets) - assert jnp.allclose(scores, jnp.array([0.7, 0.9, 0.95])) - quantile = conformal.quantile(val_probs, val_targets, 0.5, scores=scores) - assert (quantile > 0.9) * (quantile < 0.95) 
coverage_sets = conformal.conformal_set( - val_probs, test_probs, val_targets, quantile=quantile + val_probs, val_targets, test_probs, error=0.05 + ) + assert len(coverage_sets) == len(test_probs) + + conformal = CVPlusAdaptivePredictionConformalClassifier() + coverage_sets = conformal.conformal_set( + [val_probs, val_probs], [val_targets, val_targets], [test_probs, test_probs], error=0.05 ) - assert np.allclose(coverage_sets[0], [1, 2, 0]) - assert np.allclose(coverage_sets[1], [0, 2]) + assert len(coverage_sets) == 2 * len(test_probs) def test_quantile_conformal_regressor(self): n_val_inputs = 100 From 270eb52a5a38013093d1822d9f71c3e975e41113 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Wed, 27 Sep 2023 14:15:37 +0200 Subject: [PATCH 7/9] pre-commit and bump up version --- fortuna/conformal/__init__.py | 6 +- .../classification/adaptive_prediction.py | 21 +++-- fortuna/conformal/classification/base.py | 90 +++++++++++-------- .../classification/simple_prediction.py | 21 +++-- pyproject.toml | 2 +- tests/fortuna/test_conformal_methods.py | 16 ++-- 6 files changed, 91 insertions(+), 65 deletions(-) diff --git a/fortuna/conformal/__init__.py b/fortuna/conformal/__init__.py index e4def604..5d4a191a 100644 --- a/fortuna/conformal/__init__.py +++ b/fortuna/conformal/__init__.py @@ -2,13 +2,15 @@ AdaptiveConformalClassifier, ) from fortuna.conformal.classification.adaptive_prediction import ( - AdaptivePredictionConformalClassifier, CVPlusAdaptivePredictionConformalClassifier + AdaptivePredictionConformalClassifier, + CVPlusAdaptivePredictionConformalClassifier, ) from fortuna.conformal.classification.maxcovfixprec_binary_classfication import ( MaxCoverageFixedPrecisionBinaryClassificationCalibrator, ) from fortuna.conformal.classification.simple_prediction import ( - SimplePredictionConformalClassifier, CVPlusSimplePredictionConformalClassifier + CVPlusSimplePredictionConformalClassifier, + SimplePredictionConformalClassifier, ) from fortuna.conformal.multivalid.iterative.classification.binary_multicalibrator import ( BinaryClassificationMulticalibrator, diff --git a/fortuna/conformal/classification/adaptive_prediction.py b/fortuna/conformal/classification/adaptive_prediction.py index 07d1dc5a..46ab5060 100644 --- a/fortuna/conformal/classification/adaptive_prediction.py +++ b/fortuna/conformal/classification/adaptive_prediction.py @@ -1,7 +1,10 @@ from jax import vmap import jax.numpy as jnp -from fortuna.conformal.classification.base import SplitConformalClassifier, CVPlusConformalClassifier +from fortuna.conformal.classification.base import ( + CVPlusConformalClassifier, + SplitConformalClassifier, +) from fortuna.typing import Array @@ -11,8 +14,8 @@ def _score_fn(probs: Array, perm: Array, inv_perm: Array, targets: Array): def score_fn( - probs: Array, - targets: Array, + probs: Array, + targets: Array, ): perms = jnp.argsort(probs, axis=1)[:, ::-1] inv_perms = jnp.argsort(perms, axis=1) @@ -21,17 +24,17 @@ def score_fn( class AdaptivePredictionConformalClassifier(SplitConformalClassifier): def score_fn( - self, - probs: Array, - targets: Array, + self, + probs: Array, + targets: Array, ): return score_fn(probs=probs, targets=targets) class CVPlusAdaptivePredictionConformalClassifier(CVPlusConformalClassifier): def score_fn( - self, - probs: Array, - targets: Array, + self, + probs: Array, + targets: Array, ): return score_fn(probs=probs, targets=targets) diff --git a/fortuna/conformal/classification/base.py b/fortuna/conformal/classification/base.py index b1de6451..73aa3b24 
100644 --- a/fortuna/conformal/classification/base.py +++ b/fortuna/conformal/classification/base.py @@ -1,8 +1,11 @@ import abc -from typing import List, Tuple +from typing import ( + List, + Tuple, +) -import jax.numpy as jnp from jax import vmap +import jax.numpy as jnp import numpy as np from fortuna.typing import Array @@ -33,19 +36,21 @@ def is_in(self, values: Array, conformal_sets: List) -> Array: @abc.abstractmethod def score_fn( - self, - probs: Array, - targets: Array, + self, + probs: Array, + targets: Array, ): pass @staticmethod def _get_conformal_sets_from_scores( - val_scores: Array, - test_scores: Array, - error: float, + val_scores: Array, + test_scores: Array, + error: float, ) -> List[List[int]]: - conds = jnp.sum(val_scores[:, None, None] > test_scores[None], axis=0) < (1 - error) * (len(val_scores) + 1) + conds = jnp.sum(val_scores[:, None, None] > test_scores[None], axis=0) < ( + 1 - error + ) * (len(val_scores) + 1) sizes = conds.sum(1) sets = np.zeros(len(test_scores), dtype=object) @@ -59,61 +64,68 @@ def _get_conformal_sets_from_scores( return sets.tolist() @abc.abstractmethod - def get_scores( - self, - *args, - **kwargs - ) -> Tuple[Array, Array]: + def get_scores(self, *args, **kwargs) -> Tuple[Array, Array]: pass class SplitConformalClassifier(ConformalClassifier, abc.ABC): def get_scores( - self, - val_probs: Array, - val_targets: Array, - test_probs: Array + self, val_probs: Array, val_targets: Array, test_probs: Array ) -> Tuple[Array, Array]: val_scores = self.score_fn(val_probs, val_targets) test_scores = vmap( - lambda i: self.score_fn(test_probs, i * jnp.ones(len(test_probs), dtype="int32")), out_axes=1 + lambda i: self.score_fn( + test_probs, i * jnp.ones(len(test_probs), dtype="int32") + ), + out_axes=1, )(jnp.arange(val_probs.shape[1])) return val_scores, test_scores def conformal_set( - self, - val_probs: Array, - val_targets: Array, - test_probs: Array, - error: float + self, val_probs: Array, val_targets: Array, test_probs: Array, error: float ) -> List[List[int]]: - val_scores, test_scores = self.get_scores(val_probs=val_probs, val_targets=val_targets, test_probs=test_probs) - return super()._get_conformal_sets_from_scores(val_scores=val_scores, test_scores=test_scores, error=error) + val_scores, test_scores = self.get_scores( + val_probs=val_probs, val_targets=val_targets, test_probs=test_probs + ) + return super()._get_conformal_sets_from_scores( + val_scores=val_scores, test_scores=test_scores, error=error + ) class CVPlusConformalClassifier(ConformalClassifier): def conformal_set( - self, - cross_val_probs: List[Array], - cross_val_targets: List[Array], - cross_test_probs: List[Array], - error: float + self, + cross_val_probs: List[Array], + cross_val_targets: List[Array], + cross_test_probs: List[Array], + error: float, ) -> List[List[int]]: - val_scores, test_scores = self.get_scores(cross_val_probs=cross_val_probs, cross_val_targets=cross_val_targets, cross_test_probs=cross_test_probs) - return super()._get_conformal_sets_from_scores(val_scores=val_scores, test_scores=test_scores, error=error) + val_scores, test_scores = self.get_scores( + cross_val_probs=cross_val_probs, + cross_val_targets=cross_val_targets, + cross_test_probs=cross_test_probs, + ) + return super()._get_conformal_sets_from_scores( + val_scores=val_scores, test_scores=test_scores, error=error + ) def get_scores( - self, - cross_val_probs: List[Array], - cross_val_targets: List[Array], - cross_test_probs: List[Array] + self, + cross_val_probs: List[Array], + 
cross_val_targets: List[Array], + cross_test_probs: List[Array], ) -> Tuple[Array, Array]: val_scores, test_scores = [], [] - for val_probs, val_targets, test_probs in zip(cross_val_probs, cross_val_targets, cross_test_probs): + for val_probs, val_targets, test_probs in zip( + cross_val_probs, cross_val_targets, cross_test_probs + ): val_scores.append(self.score_fn(val_probs, val_targets)) test_scores.append( vmap( - lambda i: self.score_fn(test_probs, i * jnp.ones(len(test_probs), dtype="int32")), out_axes=1 + lambda i: self.score_fn( + test_probs, i * jnp.ones(len(test_probs), dtype="int32") + ), + out_axes=1, )(jnp.arange(cross_val_probs[0].shape[1])) ) diff --git a/fortuna/conformal/classification/simple_prediction.py b/fortuna/conformal/classification/simple_prediction.py index 6fe3227e..5dfd98d9 100644 --- a/fortuna/conformal/classification/simple_prediction.py +++ b/fortuna/conformal/classification/simple_prediction.py @@ -1,6 +1,9 @@ from jax import vmap -from fortuna.conformal.classification.base import SplitConformalClassifier, CVPlusConformalClassifier +from fortuna.conformal.classification.base import ( + CVPlusConformalClassifier, + SplitConformalClassifier, +) from fortuna.typing import Array @@ -10,25 +13,25 @@ def _score_fn(probs: Array, target: Array): def score_fn( - probs: Array, - targets: Array, + probs: Array, + targets: Array, ): return _score_fn(probs, targets) class SimplePredictionConformalClassifier(SplitConformalClassifier): def score_fn( - self, - probs: Array, - targets: Array, + self, + probs: Array, + targets: Array, ): return score_fn(probs=probs, targets=targets) class CVPlusSimplePredictionConformalClassifier(CVPlusConformalClassifier): def score_fn( - self, - probs: Array, - targets: Array, + self, + probs: Array, + targets: Array, ): return score_fn(probs=probs, targets=targets) diff --git a/pyproject.toml b/pyproject.toml index 3db0a1f1..d17b9d3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws-fortuna" -version = "0.1.34" +version = "0.1.35" description = "A Library for Uncertainty Quantification." 
authors = ["Gianluca Detommaso ", "Alberto Gasparin "] license = "Apache-2.0" diff --git a/tests/fortuna/test_conformal_methods.py b/tests/fortuna/test_conformal_methods.py index 7d0f664e..2f815856 100755 --- a/tests/fortuna/test_conformal_methods.py +++ b/tests/fortuna/test_conformal_methods.py @@ -7,13 +7,13 @@ from fortuna.conformal import ( AdaptiveConformalClassifier, AdaptiveConformalRegressor, - CVPlusSimplePredictionConformalClassifier, - CVPlusAdaptivePredictionConformalClassifier, AdaptivePredictionConformalClassifier, BatchMVPConformalClassifier, BatchMVPConformalRegressor, BinaryClassificationMulticalibrator, + CVPlusAdaptivePredictionConformalClassifier, CVPlusConformalRegressor, + CVPlusSimplePredictionConformalClassifier, EnbPI, JackknifeMinmaxConformalRegressor, JackknifePlusConformalRegressor, @@ -44,7 +44,7 @@ def test_prediction_conformal_classifier(self): val_probs=val_probs, test_probs=test_probs, val_targets=val_targets, - error=0.05 + error=0.05, ) assert ( (0 in coverage_sets[0]) @@ -65,7 +65,10 @@ def test_simple_prediction_conformal_classifier(self): conformal = CVPlusSimplePredictionConformalClassifier() coverage_sets = conformal.conformal_set( - [val_probs, val_probs], [val_targets, val_targets], [test_probs, test_probs], error=0.05 + [val_probs, val_probs], + [val_targets, val_targets], + [test_probs, test_probs], + error=0.05, ) assert len(coverage_sets) == 2 * len(test_probs) @@ -82,7 +85,10 @@ def test_adaptive_prediction_conformal_classifier(self): conformal = CVPlusAdaptivePredictionConformalClassifier() coverage_sets = conformal.conformal_set( - [val_probs, val_probs], [val_targets, val_targets], [test_probs, test_probs], error=0.05 + [val_probs, val_probs], + [val_targets, val_targets], + [test_probs, test_probs], + error=0.05, ) assert len(coverage_sets) == 2 * len(test_probs) From 3a3f8a5b82ee1bf106f588413cfba3ded7cd7875 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Wed, 27 Sep 2023 14:29:52 +0200 Subject: [PATCH 8/9] make error explicit in notebook --- examples/mnist_classification.pct.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/mnist_classification.pct.py b/examples/mnist_classification.pct.py index 37200bbb..08468afd 100644 --- a/examples/mnist_classification.pct.py +++ b/examples/mnist_classification.pct.py @@ -153,6 +153,7 @@ def download(split_range, shuffle=False): val_probs=val_means, test_probs=test_means, val_targets=val_data_loader.to_array_targets(), + error=0.05 ) # %% [markdown] From acf08e7db79c776dd778dc868165ae92f9033371 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Wed, 27 Sep 2023 14:45:46 +0200 Subject: [PATCH 9/9] add error explitly in doc and notebooks --- docs/source/usage_modes/flax_models.rst | 3 ++- docs/source/usage_modes/model_outputs.rst | 3 ++- docs/source/usage_modes/uncertainty_estimates.rst | 3 ++- examples/mnist_classification_sghmc.pct.py | 1 + 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/usage_modes/flax_models.rst b/docs/source/usage_modes/flax_models.rst index 1591b47c..debace0c 100644 --- a/docs/source/usage_modes/flax_models.rst +++ b/docs/source/usage_modes/flax_models.rst @@ -195,7 +195,8 @@ but a new one could be used. conformal_sets = AdaptivePredictionConformalClassifier().conformal_set( val_probs=calib_means, test_probs=test_means, - val_targets=calib_targets + val_targets=calib_targets, + error=0.05 ) .. 
_flax_models_regression: diff --git a/docs/source/usage_modes/model_outputs.rst b/docs/source/usage_modes/model_outputs.rst index 388a4a4b..bc60b5f7 100644 --- a/docs/source/usage_modes/model_outputs.rst +++ b/docs/source/usage_modes/model_outputs.rst @@ -116,7 +116,8 @@ and :code:`val_targets` to be the corresponding validation target variables. conformal_sets = AdaptivePredictionConformalClassifier().conformal_set( val_probs=val_means, test_probs=test_means, - val_targets=val_targets + val_targets=val_targets, + error=0.05 ) .. _model_outputs_regression: diff --git a/docs/source/usage_modes/uncertainty_estimates.rst b/docs/source/usage_modes/uncertainty_estimates.rst index 058d6402..f8615061 100644 --- a/docs/source/usage_modes/uncertainty_estimates.rst +++ b/docs/source/usage_modes/uncertainty_estimates.rst @@ -36,7 +36,8 @@ Please check :class:`~fortuna.conformal.classification.adaptive_prediction.Adapt conformal_sets = AdaptivePredictionConformalClassifier().conformal_set( val_probs=val_probs, test_probs=test_probs, - val_targets=val_targets + val_targets=val_targets, + error=0.05 ) You should usually expect your test predictions to be included in the conformal sets, as they contain the most probable diff --git a/examples/mnist_classification_sghmc.pct.py b/examples/mnist_classification_sghmc.pct.py index f1914185..0f133d4c 100644 --- a/examples/mnist_classification_sghmc.pct.py +++ b/examples/mnist_classification_sghmc.pct.py @@ -143,6 +143,7 @@ def download(split_range, shuffle=False): val_probs=val_means, test_probs=test_means, val_targets=val_data_loader.to_array_targets(), + error=0.05 ) # %% [markdown]
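For reference, a minimal sketch of how the refactored conformal API that these patches converge on is meant to be called. Class names, method signatures, and the now-explicit :code:`error` argument are taken from the diffs above; the toy probability arrays mirror those in :code:`tests/fortuna/test_conformal_methods.py` and are for illustration only.

.. code-block:: python

    import numpy as np

    from fortuna.conformal import (
        CVPlusSimplePredictionConformalClassifier,
        SimplePredictionConformalClassifier,
    )

    # Toy class probabilities; in practice these come from a calibrated model.
    val_probs = np.array([[0.1, 0.7, 0.2], [0.5, 0.4, 0.1], [0.15, 0.6, 0.35]])
    val_targets = np.array([1, 1, 2])
    test_probs = np.array([[0.2, 0.5, 0.3], [0.6, 0.01, 0.39]])

    # Split conformal prediction: `error=0.05` targets 95% coverage and is
    # now a required argument rather than a hidden default.
    split_sets = SimplePredictionConformalClassifier().conformal_set(
        val_probs=val_probs,
        val_targets=val_targets,
        test_probs=test_probs,
        error=0.05,
    )
    # One candidate label set per test input.
    assert len(split_sets) == len(test_probs)

    # CV+: pass one array per cross-validation fold.
    cv_sets = CVPlusSimplePredictionConformalClassifier().conformal_set(
        [val_probs, val_probs],
        [val_targets, val_targets],
        [test_probs, test_probs],
        error=0.05,
    )
    # Per the tests above, K folds currently yield K * len(test_probs) sets.
    assert len(cv_sets) == 2 * len(test_probs)

Making :code:`error` explicit at every call site, as patches 8 and 9 do for the notebooks and documentation, keeps the target coverage level visible to the reader instead of relying on the former default of 0.05.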