diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 4cf53bb0..0ac16104 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,6 +3,16 @@ on: # Disabled for now essentially workflow_dispatch: + #push: + # branches: + # - main + #tags: + # - "*.*.*" + + #pull_request: + #branches: + # - main + permissions: contents: write @@ -18,5 +28,4 @@ jobs: with: key: ${{ github.ref }} path: .cache - - run: pip install "mkdocs-material" "mkdocs-autorefs" "mkdocstrings[python]" - - run: mkdocs gh-deploy --force + - run: python -m pip install -e ".[dev]" diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 6989a8c3..8a154ac2 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -14,10 +14,10 @@ jobs: with: submodules: recursive - - name: Setup Python 3.8 + - name: Setup Python 3.10 uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.10 - name: Install pre-commit run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5e80dbca..dac32640 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,20 +1,24 @@ name: Tests on: - # Disabled for now essentially workflow_dispatch: + push: + branches: + - main + tags: + - "*.*.*" + + pull_request: + branches: + - main + env: - package-name: byop + package-name: amltk test-dir: tests extra-requires: "[dev]" # "" for no extra_requires - # Arguments used for pytest - pytest-args: >- - -v - --log-level=DEBUG - jobs: test: @@ -28,7 +32,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] os: ["ubuntu-latest", "macos-latest", "windows-latest"] steps: @@ -48,4 +52,4 @@ jobs: - name: Tests run: | - pytest ${{ env.pytest-args }} ${{ env.test-dir }} + pytest ${{ env.test-dir }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3092c5b3..068e62b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ files: | )/.*\.py$ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-added-large-files files: ".*" @@ -26,25 +26,19 @@ repos: - id: debug-statements files: '^src/.*\.py$' - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + rev: v1.7.0 hooks: - id: mypy exclude: "test_comm_task" # Pre-commit mypy hates this one, crashes on (l106) additional_dependencies: - - "attrs" - "types-pyyaml" - "types-psutil" args: - "--no-warn-return-any" # Disable this because it doesn't know about 3rd party imports - "--ignore-missing-imports" - "--show-traceback" - - repo: https://github.com/psf/black - rev: 23.7.0 - hooks: - - id: black - args: ["--config=pyproject.toml"] - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.23.3 + rev: 0.27.1 hooks: - id: check-github-workflows files: '^github/workflows/.*\.ya?ml$' @@ -52,11 +46,12 @@ repos: - id: check-dependabot files: '^\.github/dependabot\.ya?ml$' - repo: https://github.com/commitizen-tools/commitizen - rev: 3.5.3 + rev: 3.12.0 hooks: - id: commitizen - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.278 + rev: v0.1.5 hooks: - id: ruff - args: [--fix, --exit-non-zero-on-fix, --no-cache] \ No newline at end of file + args: [--fix, --exit-non-zero-on-fix, --no-cache] + - id: ruff-format \ No newline at end of file diff --git a/docs/api_generator.py b/docs/api_generator.py index f07bf28d..25d7492a 100644 --- 
a/docs/api_generator.py +++ b/docs/api_generator.py @@ -23,6 +23,9 @@ if parts[-1] in ("__main__", "__version__", "__init__"): continue + if any(part.startswith("_") for part in parts): + continue + nav[parts] = doc_path.as_posix() with mkdocs_gen_files.open(full_doc_path, "w") as fd: diff --git a/docs/example_runner.py b/docs/example_runner.py index d09209e9..dff2fb71 100644 --- a/docs/example_runner.py +++ b/docs/example_runner.py @@ -8,6 +8,7 @@ from itertools import takewhile from pathlib import Path from typing import Any +from typing_extensions import override import mkdocs_gen_files from more_itertools import first_true, peekable @@ -15,7 +16,7 @@ logger = logging.getLogger(__name__) logging.basicConfig(level=logging.WARNING) -nav = mkdocs_gen_files.Nav() # pyright: reportPrivateImportUsage=false +nav = mkdocs_gen_files.Nav() # pyright: ignore[reportPrivateImportUsage]=false ENV_VAR = "AMLTK_DOC_RENDER_EXAMPLES" @@ -119,6 +120,7 @@ def code(self, code: list[str]) -> str: body = "\n".join(s) return body + @override def __str__(self) -> str: return self.code(self.lines) @@ -127,6 +129,7 @@ def __str__(self) -> str: class CommentSegment: lines: list[str] + @override def __str__(self) -> str: return "\n".join(self.lines) @@ -289,8 +292,5 @@ def copy_section(self) -> str: mkdocs_gen_files.set_edit_path(full_doc_path, path) lines = list(nav.build_literate_nav()) -with mkdocs_gen_files.open("examples/SUMMARY.md", "w") as nav_file: # +with mkdocs_gen_files.open("examples/index.md", "w") as nav_file: # nav_file.writelines(lines) # - -with mkdocs_gen_files.open("examples/index.md", "w") as index_file: - index_file.writelines(lines) # diff --git a/docs/guides/index.md b/docs/guides/index.md index 1b4094a9..9633ca15 100644 --- a/docs/guides/index.md +++ b/docs/guides/index.md @@ -3,76 +3,108 @@ of AutoML-Toolkit. Notably, we have three core concepts at the heart of AutoML-Toolkit, with supporting types and auxiliary functionality to enable these concepts. -These take the form of a [`Task`][amltk.scheduling.Task], a [`Pipeline`][amltk.pipeline.Pipeline] -and an [`Optimizer`][amltk.optimization.Optimizer] which combines the two -to create the most flexible optimization framework we could imagine. +These take the form of a **scheduling**, a **pipeline construction** +and **optimization**. By combining these concepts, we provide an extensive +framework from which to do AutoML research, utilize AutoML for you +task or build brand new AutoML systems. --- -- **Task** +- **Scheduling** - A `Task` is a function which we want to run _somewhere_, whether it be a local - process, on some node of a cluster or out in the cloud. Equipped with an - [`asyncio`][asyncio] **event-system** and a [`Scheduler`][amltk.scheduling.Scheduler] - to drive the gears of the system, we can provide a truly flexible and performant framework - upon to which to build an AutoML system. + Dealing with multiple processes and simultaneous compute, + can be both difficult in terms of understanding and utilization. + Often a prototype script just doesn't work when you need to run + larger experiments. + + We provide an **event-driven system** with a flexible **backend**, + to help you write code that scales from just a few more cores on your machine + to utilizing an entire cluster. + + This guide introduces `Task`s and the `Scheduler` in which they run, as well + as `@events` which you can subscribe callbacks to. Define what should run, when + it should run and simply define a callback to say what should happen when it's done. 
+ + This framework allows you to write code that simply scales, with as little + code change required as possible. Go from a single local process to an entire + cluster with the same script and 5 lines of code. + + Check out the [Scheduling guide!](./scheduling.md) for the full guide. + We also cover some of these topics in brief detail in the reference pages. !!! tip "Notable Features" * A system that allows incremental and encapsulated feature addition. - * An event-driven system with easy to use _callbacks_. - * Place constraints on your `Task`. - * Integrations for different backends for where to run your tasks. + * An [`@event`](site:reference/scheduling/events.md) system with easy to use _callbacks_. + * Place constraints on and modify your [`Task`](site:reference/scheduling/task.md) + with [`Plugins`](site:reference/scheduling/plugins.md). + * Integrations for different [backends](site:reference/scheduling/executors.md) for where + to run your tasks. * A wide set of events to plug into. - * An easy to extend system to create your own specialized events and tasks. - - Checkout the [Scheduling guide](./scheduling.md) + * An easy way to extend the functionality provided with your own set of domain or task + specific events. --- -- **Pipeline** +- **Pipelines** + + Optimizers require some _search space_ to optimize over, yet provide no utility to actually + define these search spaces. When scaling beyond a simple single model, these search spaces + become harder to define, difficult to extend and are often disjoint from the actual pipeline + creation. When you want to create search spaces that can have choices between models, parametrized + pre-processing and a method to quickly change these setups, it can often feel tedious + and error-prone. + + By piecing together the `Node`s of a pipeline, utilizing a set of different building blocks such + as a `Component`, `Sequential`, `Choice` and more, you can abstractly define your entire pipeline. + Once you're done, we'll stitch together the entire `search_space()`, allow you to + easily `configure()` it and finally `build()` it into a concrete object you can use, + all in the same place. - A [`Pipeline`][amltk.pipeline.Pipeline] is a definition, - defining what your **pipeline** will do and how - it can be parametrized. By piecing together [`steps`][amltk.pipeline.api.step], - [`choices`][amltk.pipeline.api.choice] and [`splits`][amltk.pipeline.api.split], you can - say how your pipeline should look and how it's parametrized. We'll take care - of creating the search space to optimize over, configuring it and finally assembling - it into something you can actually use. + Check out the [Pipeline guide!](./pipelines.md) + We also cover some of these topics in brief detail in the reference pages. !!! tip "Notable Features" - * An easy to edit pipeline structure, allowing for rapid addition, deletion and + * An easy, declarative pipeline structure, allowing for rapid addition, deletion and modification during experimentation. * A flexible pipeline capable of handling complex structures and subpipelines. - * Easily attachable modules for things close to your pipeline but not a direct - part of the main structure. + * Multiple component types to help you define your pipeline. * Exporting of pipelines into concrete implementations like an [sklearn.pipeline.Pipeline][] for use in your downstream tasks. + * Extensible to add your own component types and `builder=`s to use.
- Checkout the [Pipeline guide](./pipelines.md) --- -- **Optimizer** +- **Optimization** - An [`Optimizer`][amltk.optimization.Optimizer] is the capstone of the preceding two - fundamental systems. By leveraging an _"ask-and-tell"_ interface, we put you back - in control of how your system interacts with the optimizer. You run what you want, - wherever you want, telling the optimizer what you want and you storing what you want, - wherever you want. - This makes leveraging different optimizers easier than ever. By capturing the high-level - core loop of black box optimization into a simple [`Trial`][amltk.optimization.Trial] and - a [`Report`][amltk.optimization.Trial.Report], integrating your own optimizer is easy and - provides the entire system that AutoML-Toolkit offers with little cost. + An optimizer is the backbone behind many AutoML systems and the quickest way + to improve the performance of your current pipelines. However, optimizers vary + in how they expect you to write code, in how much control they + take of your code, and they can be quite difficult to interact with other than + through their `run()` function. + + By setting a simple expectation on an `Optimizer`, e.g. that it should have + an `ask()` and `tell()`, you are put back in control of defining the loop, + deciding what happens and when, and you can store whatever you'd like to record, + wherever you'd like to put it. + + By unifying their suggestions as a `Trial` and a convenient `Report` to hand back + to them, you can switch between optimizers with minimal changes required. We have + added a load of utility to `Trial`s, such that you can easily profile sections, + add extra summary information, store artifacts and export DataFrames. + + Check out the [Optimization guide](./optimization.md). We recommend reading the previous + two guides to fully understand the possibilities with optimization. + We also cover some of these topics in brief detail in the reference pages. !!! tip "Notable Features" * An assortment of different optimizers for you to swap in an out with relative ease through a unified interface. + * A suite of utilities to help you record the data you want from your HPO experiments. * Full control of how you interact with it, allowing for easy warm-starting, complex swapping mechanisms or custom stopping criterion. * A simple interface to integrate in your own optimizer. - Checkout the [Optimization guide](./optimization.md). We recommend reading the previous - two guides to fully understand the possibilities with optimization. diff --git a/docs/guides/optimization.md b/docs/guides/optimization.md index f43c1d60..419db739 100644 --- a/docs/guides/optimization.md +++ b/docs/guides/optimization.md @@ -1,10 +1,16 @@ # Optimization Guide -!!! todo "This guide needs to be revisited" +!!! todo "Under Construction" - Sorry ... + Need to document more about `Trial` objects, `Report` and `History`. + Please see the following references for the time being: + + * [`Optimizer` reference](site:reference/optimization/optimizers.md) + * [`Trial` reference](site:reference/optimization/trials.md) + * [`Profiler` reference](site:reference/optimization/profiling.md) + * [`History` reference](site:reference/optimization/history.md) One of the core tasks of any AutoML system is to optimize some objective, -whether it be some [`Pipeline`][amltk.pipeline.Pipeline], a black-box or even a toy function. +whether it be some pipeline, a black-box or even a toy function.
For this we require an [`Optimizer`][amltk.optimization.Optimizer]. We integrate several optimizers and integrating your own is very straightforward, under one @@ -26,6 +32,8 @@ very central premise. ``` +You can check out the integrated optimizers in our [optimizer reference](site:reference/optimization/optimizers.md) + ??? note "Why?" 1. **Easy Parallelization**: Many optimizers handle running the function to optimize and hence roll out their own @@ -43,259 +51,285 @@ very central premise. to worry that the internal state of the optimizer is updated accordingly to these two _"Ask"_ and _"Tell"_ events and that's it. -This guide relies lightly on topics covered in the [Pipeline Guide](./pipelines.md) for -creating a `Pipeline` but also the [Scheduling guide](./scheduling.md) for creating a +This guide relies lightly on topics covered in the [Pipeline Guide](site:guides/pipelines.md) for +creating a pipeline but also the [Scheduling guide](site:guides/scheduling.md) for creating a [`Scheduler`][amltk.scheduling.Scheduler] and a [`Task`][amltk.scheduling.Task]. These aren't required but if something is not clear or you'd like to know **how** something -works, please refer to these guides +works, please refer to these guides or the reference! -## Optimizating a simple function +## Optimizing a simple function We'll start with a simple example of optimizing a simple polynomial function The first thing to do is define the function we want to optimize. -=== "Polynomial" - - ```python - def poly(x): - return (x**2 + 4*x + 3) / x - ``` - -=== "Typed" - - ```python - def poly(x: float) -> float: - return (x**2 + 4*x + 3) / x - ``` +```python +def poly(x: float) -> float: + return (x**2 + 4*x + 3) / x +``` Our next step is to define the search range over which we want to optimize, in this case, the range of values `x` can take. We cover this in more detail -in the [Pipeline guide](./pipelines.md). - -=== "Defining a Search Space" - - ```python hl_lines="6" - from amltk.pipeline import searchable - - def poly(x): - return (x**2 + 4*x + 3) / x +in the [Pipeline guide](site:guides/pipelines.md). - s = searchable("parameters", space={"x": (-10.0, 10.0)}) # (1)! - ``` - - 1. Here we say that there is a collection of `#!python "parameters"` - which has one called `#!python "x"` which is in the range `#!python [-10.0, 10.0]`. - -=== "Typed" - - ```python hl_lines="6" - from amltk.pipeline import searchable - - def poly(x: float) -> float: - return (x**2 + 4*x + 3) / x - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) # (1)! - ``` - - 1. Here we say that there is a collection of `#!python "parameters"` - which has one called `#!python "x"` which is in the range `#!python [-10.0, 10.0]`. - -## Creating an optimizer +```python exec="true" source="material-block" html="true" hl_lines="6" +from amltk.pipeline import Searchable -We'll start by using [`RandomSearch`][amltk.optimization.RandomSearch] to search -for an optimal value for `#!python "x"` but later on we'll switch to using -[SMAC](https://github.com/automl/SMAC3) which is a much smarter optimizer. 
+def poly(x: float) -> float: + return (x**2 + 4*x + 3) / x -=== "Creating an optmizer" +s = Searchable(space={"x": (-10.0, 10.0)}, name="my-searchable") +from amltk._doc import doc_print; doc_print(print, s) +``` - ```python hl_lines="9 10" - from amltk.optimization import RandomSearch - from amltk.pipeline import searchable +Here we say that there is a collection of `#!python "parameters"` +which has one called `#!python "x"` which is in the range `#!python [-10.0, 10.0]`. - def poly(x): - return (x**2 + 4*x + 3) / x - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - - space = s.space() - random_search = RandomSearch(space=space, seed=42) - ``` +## Creating an Optimizer -=== "Typed" +We'll utilize by using [SMAC](https://github.com/automl/SMAC3) +here for optimization as an example but you can find other available +optimizers [here](site:reference/optimization/optimizers.md). - ```python hl_lines="9 10" - from amltk.optimization import RandomSearch - from amltk.pipeline import searchable +??? info "Requirements" - def poly(x: float) -> float: - return (x**2 + 4*x + 3) / x + This requires `smac` which can be installed with: - s = searchable("parameters", space={"x": (-10.0, 10.0)}) + ```bash + pip install amltk[smac] - space = s.space() - random_search = RandomSearch(space=space, seed=42) + # Or directly + pip install smac ``` -Some of the integrated optimizers: +Our first step is to actually get a search space from +our _pipeline_. +```python exec="true" result="python" source="material-block" hl_lines="9" +from amltk.optimization.optimizers.smac import SMACOptimizer +from amltk.pipeline import Searchable -### Random Search +def poly(x: float) -> float: + return (x**2 + 4*x + 3) / x -A custom implementation of Random Search which randomly selects -configurations from the `space` to evaluate. +s = Searchable(space={"x": (-10.0, 10.0)}, name="my-searchable") -??? note "Usage" +space = s.search_space(parser="configspace") +print(space) +``` - You can use [`RandomSearch`][amltk.optimization.RandomSearch] - by simply passing in the `space` and optionally a `seed` which - will be used for sampling. +Here we chose that the search space should be a `#!python "configspace"`, +which is the kind of search space that +[`SMACOptimizer`][amltk.optimization.optimizers.smac.SMACOptimizer] expects. - ```python exec="true" source="material-block" result="python" title="RandomSearch Construction" - from amltk.optimization import RandomSearch - from amltk.pipeline import searchable +!!! info inline end "Available Optimizers" - my_searchable = searchable("myspace", space={"x": (-10.0, 10.0)}) - space = my_searchable.space() + To see a list of available optimizers and their usage, please + see the [optimizer reference](site:reference/optimization/optimizers.md). - random_search = RandomSearch(space=space, seed=42) +```python exec="true" result="python" source="material-block" +from amltk.optimization.optimizers.smac import SMACOptimizer +from amltk.pipeline import Searchable - trial = random_search.ask() - print(trial) - ``` +def poly(x: float) -> float: + return (x**2 + 4*x + 3) / x - --- +s = Searchable(space={"x": (-10.0, 10.0)}, name="my-searchable") - By default, [`RandomSearch`][amltk.optimization.RandomSearch] - does not allow duplicates. If it can't sample a unique config - in `max_samples_attempts=` (default: `#!python 50`), then - it will deem that there are no more unique configs and raise - a [`RandomSearch.ExhaustedError`][amltk.optimization.RandomSearch.ExhaustedError]. 
+space = s.search_space(parser="configspace") +print(space) +``` - ```python exec="true" source="tabbed-left" result="python" returncode="1" title="RandomSearch Exhausted" tabs="Source | Error" hl_lines="5 10 11 14" - import traceback - from amltk.optimization import RandomSearch - from amltk.pipeline import searchable - my_searchable = searchable("myspace", space={"x": ["apple", "pear"]}) # Only 2 valid configs - space = my_searchable.space() - random_search = RandomSearch(space=space, seed=42) +The [`ask`][amltk.optimization.Optimizer.ask] method should return a +new [`Trial`][amltk.optimization.Trial] object, and the [`tell`][amltk.optimization.Optimizer.tell] +method should update the optimizer with the result of the trial. A [`Trial`][amltk.optimization.Trial] +should have a unique `name`, a `config` and whatever optimizer specific +information you want to store should be stored in the `trial.info` property. - random_search.ask() # Fine - random_search.ask() # Fine +## Running an Optimizer +Now that we have an optimizer that knows the `space` to search, we can begin to +actually [`ask()`][amltk.optimization.Optimizer.ask] the optimizer for the next +[`Trial`][amltk.optimization.Trial], run our function and return +a [`Trial.Report`][amltk.optimization.Trial.Report]. - - try: - random_search.ask() # ...Error - except RandomSearch.ExhaustedError as e: - print(traceback.format_exc()) - ``` +First we need to modify the function we wish to optimize to actually accept +the `Trial` and return the `Report`. - If you allow for duplicates in your sampling, simply set `duplicates=True`. +```python hl_lines="4 5 6 7 8 9 10 19 20 21 22 24 25" title="Running the Optimizer" +from amltk.optimization import RandomSearch, Trial, RSTrialInfo +from amltk.pipeline import searchable - --- +def poly(trial: Trial[RSTrialInfo]) -> Trial.Report[RSTrialInfo]: # (4)! + x = trial.config["x"] + with trial.begin(): # (1)! + y = (x**2 + 4*x + 3) / x + return trial.success(cost=y) # (2)! - If you want to use a particular [`Sampler`][amltk.pipeline.Sampler] - you can pass it in as well. + trial.fail() # (3)! - ```python exec="true" source="material-block" result="python" title="RandomSearch Specific Sampler" hl_lines="3 8" - from amltk.optimization import RandomSearch - from amltk.pipeline import searchable - from amltk.configspace import ConfigSpaceSampler +s = searchable("parameters", space={"x": (-10.0, 10.0)}) - my_searchable = searchable("myspace", space={"x": (-10.0, 10.0)}) - space = my_searchable.space() +space = s.space() +random_search = RandomSearch(space=space, seed=42) - random_search = RandomSearch(space=space, seed=42, sampler=ConfigSpaceSampler) +results: list[float] = [] +for _ in range(20): trial = random_search.ask() - print(trial) - ``` - - --- - - Or you can even pass in your own custom sampling function. + report = poly(trial) + random_search.tell(report) + + cost = report.results["cost"] + results.append(cost) +``` + +1. Using the [`with trial.begin():`][amltk.optimization.Trial.begin], +you let us know where exactly your trial begins and we can handle +all things related to exception handling and timing. +2. If you can return a success, then do so with +[`trial.success()`][amltk.optimization.Trial.success]. +3. If you can't return a success, then do so with [`trial.fail()`][amltk.optimization.Trial.fail]. +4. Here the inner type parameter `RSTrialInfo` is the type of `trial.info` which +contains the object returned by the ask of the wrapped `optimizer`.
We'll +see this in [integrating your own Optimizer](#integrating-your-own-optimizer). - ```python exec="true" source="material-block" result="python" title="RandomSearch Custom Sample Function" hl_lines="9 10 11 12 13 14 15 16 21" - import numpy as np - from amltk.optimization import RandomSearch - - search_space = { - "x": (-10.0, 10.0), - "y": ["cat", "dog", "fish"] - } - - def my_sampler(space, seed: int): - rng = np.random.RandomState(seed) +### Running the Optimizer in a parallel fashion - xlow, xhigh = space["x"] - x = rng.uniform(xlow, xhigh) - y = np.random.choice(space["y"]) +Now that we've seen the basic optimization loop, it's time to parallelize it with +a [`Scheduler`][amltk.scheduling.Scheduler] and the [`Task`][amltk.Task]. +We cover the [`Scheduler`][amltk.scheduling.Scheduler] and [`Tasks`][amltk.scheduling.Task] +in the [Scheduling guide](./scheduling.md) if you'd like to know more about how this works. - return {"x": x, "y": y} +We first create a [`Scheduler`][amltk.scheduling.Scheduler] to run with `#!python 1` +process and run it for `#!python 5` seconds. +Using the event system of AutoML-Toolkit, +we define what happens through _callbacks_, registering to certain events, such +as launch a single trial on `@scheduler.on_start`, _tell_ the optimizer whenever we get +something returned with [`@task.on_result`][amltk.Task.on_result]. - random_search = RandomSearch( - space=search_space, - seed=42, - sampler=my_sampler - ) - trial = random_search.ask() - print(trial) - ``` +```python hl_lines="19 23 24 25 26 28 29 30 32 33 34 35 37 38 39 40 42" title="Creating a Task for a Trial" +from amltk.optimization import RandomSearch, Trial, RSTrialInfo +from amltk.pipeline import searchable +from amltk.scheduling import Scheduler - !!! warning "Determinism with the `seed` argument" +def poly(trial: Trial[RSTrialInfo]) -> Trial.Report[RSTrialInfo]: + x = trial.config["x"] + with trial.begin(): + y = (x**2 + 4*x + 3) / x + return trial.success(cost=y) - Given the same `int` seed integer, your sampling function - should return the same set of configurations. + trial.fail() +s = searchable("parameters", space={"x": (-10.0, 10.0)}) +space = s.space() +random_search = RandomSearch(space=space, seed=42) +scheduler = Scheduler.with_processes(1) -### SMAC +task = scheduler.task(poly) # (5)! -[SMAC](https://github.com/automl/SMAC3) is a collection of methods from -[automl.org](https://www.automl.org) for hyperparameter optimization. -Notably the library focuses on many Bayesian Optimization methods with highly -configurable spaces. +results: list[float] = [] -!!! example "Integration" +@scheduler.on_start # (1)! +def launch_trial() -> None: + trial = random_search.ask() + task(trial) - Check out the [SMAC integration page](site:reference/smac.md) +@task.on_result # (2)! +def tell_optimizer(report: Trial.Report) -> None: + random_search.tell(report) - Install with `pip install smac` +@task.on_result +def launch_another_trial(_: Trial.Report) -> None: + trial = random_search.ask() + task(trial) + +@task.on_result # (3)! +def save_result(report: Trial.Report) -> None: + cost = report["cost"] + results.append(cost) # (4)! + +scheduler.run(timeout=5) +``` + +1. The function `launch_trial()` gets called when the `scheduler` starts, +asking the optimizer for a trial and launching the `task` with the `trial`. +`launch_trial()` gets called in the main process but `task(trial)` will get +called in a seperate process. +2. The function `tell_optimizer` gets called whenever the `task` returns a +report. 
We should tell the optimizer about this report. +3. This function `save_result` gets called whenever we have a successful +trial. +4. We don't store anything more than the optimizer needs. Saving results +that you wish to access later is up to you. +5. Here we wrap the function we want to run in another process in a +[`Task`][amltk.scheduling.Task]. There are other backends than +processes, e.g. Clusters for which you should check out the +[Scheduling guide](./scheduling.md). -### Optuna +Now, to scale up, we trivially increase the number of initial trials launched with `@scheduler.on_start` +and the number of processes in our `Scheduler`. That's it. -[Optuna](https://optuna.org/) is a hyperparameter optimization library which focuses on -Tree-Parzan Estimators (TPE) for finding optimal configurations. -There are some currentl limitations, such as heirarchical spaces but -is widely used and popular. +```python hl_lines="18 19 25" +from amltk.optimization import RandomSearch, Trial, RSTrialInfo +from amltk.pipeline import searchable +from amltk.scheduling import Scheduler -!!! example "Integration" +def poly(trial: Trial[RSTrialInfo]) -> Trial.Report[RSTrialInfo]: + x = trial.config["x"] + with trial.begin(): + y = (x**2 + 4*x + 3) / x + return trial.success(cost=y) - Check out the [Optuna integration page](site:reference/optuna.md#optimizer). + trial.fail() - Install with `pip install optuna` +s = searchable("parameters", space={"x": (-10.0, 10.0)}) +space = s.space() -### NePS +random_search = RandomSearch(space=space, seed=42) -[NePS](https://automl.github.io/neps/latest/) -is an optimization framework from [automl.org](https://www.automl.org) focusing -on optimizing Neural Architectures. +n_workers = 4 +scheduler = Scheduler.with_processes(n_workers) -!!! example "NePS" +task = scheduler.task(poly) - Check out the [NePS reference page](site:reference/neps.md). +results: list[float] = [] - Install with `pip install neural-pipeline-search +@scheduler.on_start(repeat=n_workers) +def launch_trial() -> None: + trial = random_search.ask() + task(trial) +@task.on_result +def tell_optimizer(report: Trial.Report) -> None: + random_search.tell(report) -### HEBO +@task.on_result +def launch_another_trial(_: Trial.Report) -> None: + trial = random_search.ask() + task(trial) -[HEBO](https://github.com/huawei-noah/HEBO) +@task.on_result +def save_result(report: Trial.Report) -> None: + cost = report.results["cost"] + results.append(cost) -!!! info "Planned" +scheduler.run(timeout=5) +``` +That concludes the main portion of our `Optimization` guide. AutoML-Toolkit provides +a host of more useful options, such as: +* Setting constraints on your evaluation function, such as memory, wall time and cpu time, concurrency limits +and call limits. Please refer to the [Scheduling guide](./scheduling.md) for more information. +* Stop the scheduler with whatever stopping criterion you wish. Please refer to the [Scheduling guide](./scheduling.md) for more information. +* Optimize over complex pipelines. Please refer to the [Pipeline guide](./pipelines.md) for more information. +* Using different parallelization strategies, such as [Dask](https://dask.org/), [Ray](https://ray.io/), +[Slurm](https://slurm.schedmd.com/), and [Apache Airflow](https://airflow.apache.org/). +* Use a whole host of more callbacks to control your system, check out the [Scheduling guide](./scheduling.md) for more information. +* Run the scheduler using `asyncio` to allow interactivity, run as a server or other more advanced use cases.
### Integrating your own Optimizer Integrating in your own optimizer is fairly straightforward. @@ -338,15 +372,7 @@ To integrate you own optimizer, you'll need to implement the following interface type inference. If you are interested in learning more, check out the [Python typing documentation](https://docs.python.org/3/library/typing.html). ---- - -The [`ask`][amltk.optimization.Optimizer.ask] method should return a -new [`Trial`][amltk.optimization.Trial] object, and the [`tell`][amltk.optimization.Optimizer.tell] -method should update the optimizer with the result of the trial. A [`Trial`][amltk.optimization.Trial] -should have a unique `name`, a `config` and whatever optimizer specific -information you want to store should be stored in the `trial.info` property. - -??? example "A simplified version of SMAC integration" +??? example "A Simplified Version of SMAC integration" Here is a simplified example of wrapping [`SMAC`](https://automl.github.io/SMAC3/stable/). The real implementation is more complex, but this should give you an idea of how to @@ -433,344 +459,3 @@ information you want to store should be stored in the `trial.info` property. and [`Trial.Report`][amltk.optimization.Trial.Report]. This is how we tell type checking analyzers that the _thing_ stored in `trial.info` will be a `TrialInfo` object from SMAC. - ---- - -If there is an optimizer you would like integrated, please let us know! - -## Running an Optimizer -Now that we have an optimizer that knows the `space` to search, we can begin to -actually [`ask()`][amltk.optimization.Optimizer.ask] the optimizer for a next -[`Trial`][amltk.optimization.Trial], run our function and return -a [`Trial.Report`][amltk.optimization.Trial.Report]. - -First we need to modify our function we wish to optimize to actually accept -the `Trial` and return the `Report`. - -=== "Running the optmizer" - - ```python hl_lines="4 5 6 7 8 9 10 19 20 21 22 24 25" - from amltk.optimization import RandomSearch - from amltk.pipeline import searchable - - def poly(trial): - x = trial.config["x"] - with trial.begin(): # (1)! - y = (x**2 + 4*x + 3) / x - return trial.success(cost=y) # (2)! - - trial.fail() # (3)! - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - - space = s.space() - random_search = RandomSearch(space=space, seed=42) - - results = [] - - for _ in range(20): - trial = random_search.ask() - report = qaudratic(trial) - random_search.tell(trial) - - cost = report.results["cost"] - results.append(cost) - ``` - - 1. Using the [`with trial.begin():`][amltk.optimization.Trial.begin], - you let us know where exactly your trial begins and we can handle - all things related to exception handling and timing. - 2. If you can return a success, then do so with - [`trial.success()`][amltk.optimization.Trial.success]. - 3. If you can't return a success, then do so with [`trial.fail()`][amltk.optimization.Trial.fail]. - -=== "Typed" - - ```python hl_lines="4 5 6 7 8 9 10 19 20 21 22 24 25" - from amltk.optimization import RandomSearch, Trial - from amltk.pipeline import searchable - - def poly(trial: Trial[RSTrialInfo]) -> Trial.Report[RSTrialInfo]: # (4)! - x = trial.config["x"] - with trial.begin(): # (1)! - y = (x**2 + 4*x + 3) / x - return trial.success(cost=y) # (2)! - - trial.fail() # (3)! 
- - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - - space = s.space() - random_search = RandomSearch(space=space, seed=42) - - results: list[float] = [] - - for _ in range(20): - trial = random_search.ask() - report = qaudratic(trial) - random_search.tell(trial) - - cost = report.results["cost"] - results.append(cost) - ``` - - 1. Using the [`with trial.begin():`][amltk.optimization.Trial.begin], - you let us know where exactly your trial begins and we can handle - all things related to exception handling and timing. - 2. If you can return a success, then do so with - [`trial.success()`][amltk.optimization.Trial.success]. - 3. If you can't return a success, then do so with [`trial.fail()`][amltk.optimization.Trial.fail]. - 4. Here the inner type parameter `RSTrial` is the type of `trial.info` which - contains the object returned by the ask of the wrapped `optimizer`. We'll - see this in [integrating your own Optimizer](#integrating-your-own-optimizer). - -### Running the Optimizer in a parallel fashion - -Now that we've seen the basic optimization loop, it's time to parallelize it with -a [`Scheduler`][amltk.scheduling.Scheduler] and the [`Task`][amltk.Task]. -We cover the [`Scheduler`][amltk.scheduling.Scheduler] and [`Tasks`][amltk.scheduling.Task] -in the [Scheduling guide](./scheduling.md) if you'd like to know more about how this works. - -We first create a [`Scheduler`][amltk.scheduling.Scheduler] to run with `#!python 1` -process and run it for `#!python 5` seconds. -Using the event system of AutoML-Toolkit, -we define what happens through _callbacks_, registering to certain events, such -as launch a single trial on `@scheduler.on_start`, _tell_ the optimizer whenever we get -something returned with [`@task.on_result`][amltk.Task.on_result]. - -=== "Creating a `Task` for a trial" - - ```python hl_lines="19 23 24 25 26 28 29 30 32 33 34 35 37 38 39 40 42" - from amltk.optimization import RandomSearch, Trial - from amltk.pipeline import searchable - from amltk.scheduling import Scheduler - - def poly(trial): - x = trial.config["x"] - with trial.begin(): - y = (x**2 + 4*x + 3) / x - return trial.success(cost=y) - - trial.fail() - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - space = s.space() - - random_search = RandomSearch(space=space, seed=42) - scheduler = Scheduler.with_processes(1) - - task = scheduler.task(poly) # (5)! - - results = [] - - @scheduler.on_start # (1)! - def launch_trial(): - trial = random_search.ask() - task(trial) - - @task.on_result # (2)! - def tell_optimizer(report): - random_search.tell(report) - - @task.on_result - def launch_another_trial(_): - trial = random_search.ask() - task(trial) - - @task.on_result # (3)! - def save_result(report): - cost = report["cost"] - results.append(cost) # (4)! - - scheduler.run(timeout=5) - ``` - - 1. The function `launch_trial()` gets called when the `scheduler` starts, - asking the optimizer for a trial and launching the `task` with the `trial`. - `launch_trial()` gets called in the main process but `task(trial)` will get - called in a seperate process. - 2. The function `tell_optimizer` gets called whenever the `task` returns a - report. We should tell the optimizer about this report. - 3. This function `save_result` gets called whenever we have a successful - trial. - 4. We don't store anything more than the optmimizer needs. Saving results - that you wish to access later is up to you. - 5. Here we wrap the function we want to run in another process in a - [`Task`][amltk.Trial]. 
There are other backends than - processes, e.g. Clusters for which you should check out the - [Scheduling guide](./scheduling.md). - -=== "Typed" - - ```python hl_lines="19 23 24 25 26 28 29 30 32 33 34 35 37 38 39 40 42" - from amltk.optimization import RandomSearch, Trial, RSTrialInfo - from amltk.pipeline import searchable - from amltk.scheduling import Scheduler - - def poly(trial: Trial[RSTrialInfo]) -> Trial.Report[RSTrialInfo]: - x = trial.config["x"] - with trial.begin(): - y = (x**2 + 4*x + 3) / x - return trial.success(cost=y) - - trial.fail() - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - space = s.space() - - random_search = RandomSearch(space=space, seed=42) - scheduler = Scheduler.with_processes(1) - - task = scheduler.task(poly) # (5)! - - results: list[float] = [] - - @scheduler.on_start # (1)! - def launch_trial() -> None: - trial = random_search.ask() - task(trial) - - @task.on_result # (2)! - def tell_optimizer(report: Trial.Report) -> None: - random_search.tell(report) - - @task.on_result - def launch_another_trial(_: Trial.Report) -> None: - trial = random_search.ask() - task(trial) - - @task.on_result # (3)! - def save_result(report: Trial.Report) -> None: - cost = report["cost"] - results.append(cost) # (4)! - - scheduler.run(timeout=5) - ``` - - 1. The function `launch_trial()` gets called when the `scheduler` starts, - asking the optimizer for a trial and launching the `task` with the `trial`. - `launch_trial()` gets called in the main process but `task(trial)` will get - called in a seperate process. - 2. The function `tell_optimizer` gets called whenever the `task` returns a - report. We should tell the optimizer about this report. - 3. This function `save_result` gets called whenever we have a successful - trial. - 4. We don't store anything more than the optmimizer needs. Saving results - that you wish to access later is up to you. - 5. Here we wrap the function we want to run in another process in a - [`Task`][amltk.optimization.Trial]. There are other backends than - processes, e.g. Clusters for which you should check out the - [Scheduling guide](./scheduling.md). - -Now, to scale up, we trivially increase the number of initial trails launched with `@scheduler.on_start` -and the number of processes in our `Scheduler`. That's it. 
- -=== "Scaling Up" - - ```python hl_lines="18 19 25" - from amltk.optimization import RandomSearch, Trial - from amltk.pipeline import searchable - from amltk.scheduling import Scheduler - - def poly(trial): - x = trial.config["x"] - with trial.begin(): - y = (x**2 + 4*x + 3) / x - return trial.success(cost=y) - - trial.fail() - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - space = s.space() - - random_search = RandomSearch(space=space, seed=42) - - n_workers = 4 - scheduler = Scheduler.with_processes(n_workers) - - task = scheduler.task(poly) - - results = [] - - @scheduler.on_start(repeat=n_workers) - def launch_trial(): - trial = random_search.ask() - task(trial) - - @task.on_result - def tell_optimizer(report): - random_search.tell(report) - - @task.on_result - def launch_another_trial(_): - trial = random_search.ask() - task(trial) - - @task.on_result - def save_result(report): - cost = report["cost"] - results.append(cost) - - scheduler.run(timeout=5) - ``` - -=== "Typed" - - ```python hl_lines="18 19 25" - from amltk.optimization import RandomSearch, Trial, RSTrialInfo - from amltk.pipeline import searchable - from amltk.scheduling import Scheduler - - def poly(trial: Trial[RSTrialInfo]) -> Trial.Report[RSTrialInfo]: - x = trial.config["x"] - with trial.begin(): - y = (x**2 + 4*x + 3) / x - return trial.success(cost=y) - - trial.fail() - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - space = s.space() - - random_search = RandomSearch(space=space, seed=42) - - n_workers = 4 - scheduler = Scheduler.with_processes(n_workers) - - task = Trial.Task(poly) - - results: list[float] = [] - - @scheduler.on_start(repeat=n_workers) - def launch_trial() -> None: - trial = random_search.ask() - task(trial) - - @task.on_result - def tell_optimizer(report: Trial.Report) -> None: - random_search.tell(report) - - @task.on_result - def launch_another_trial(_: Trial.Report) -> None: - trial = random_search.ask() - task(trial) - - @task.on_result - def save_result(report: Trial.Report) -> None: - cost = report["cost"] - results.append(cost) - - scheduler.run(timeout=5) - ``` - -That concludes the main portion of our `Optimization` guide. AutoML-Toolkit provides -a host of more useful options, such as: - -* Setting constraints on your evaluation function, such as memory, wall time and cpu time, concurrency limits -and call limits. Please refer to the [Scheduling guide](./scheduling.md) for more information. -* Stop the scheduler with whatever stopping criterion you wish. Please refer to the [Scheduling guide](./scheduling.md) for more information. -* Optimize over complex pipelines. Please refer to the [Pipeline guide](./pipelines.md) for more information. -* Using different parallelization strategies, such as [Dask](https://dask.org/), [Ray](https://ray.io/), -[Slurm](https://slurm.schedmd.com/), and [Apache Airflow](https://airflow.apache.org/). -* Use a whole host of more callbacks to control you system, check out the [Scheduling guide](./scheduling.md) for more information. -* Run the scheduler using `asyncio` to allow interactivity, run as a server or other more advanced use cases. diff --git a/docs/guides/pipelines.md b/docs/guides/pipelines.md index 1ae5709f..18b31643 100644 --- a/docs/guides/pipelines.md +++ b/docs/guides/pipelines.md @@ -1,56 +1,75 @@ # Pipelines Guide AutoML-toolkit was built to support future development of AutoML systems and -a central part of an AutoML system is its Pipeline. 
The purpose of this +guide is to help you understand all the utility AutoML-toolkit can provide to help you define your pipeline. We will do this by introducing concepts from the ground up, rather than top down. -Please see [examples](site:examples/index.md) if you would rather see copy-pastable examples. +Please see [the reference](site:reference/pipelines/pipeline.md) +if you just want to quickly look something up. --- ## Introduction +The kinds of pipelines that exist in an AutoML system come in many different +forms. For example, one might be an [sklearn.pipeline.Pipeline][], others +might be some deep-learning pipeline while some might even stand for some +real-life machinery process and the settings of these machines. -At the core of a [`Pipeline`][amltk.pipeline.Pipeline] definition -is the many [`Steps`][amltk.pipeline.Step] it consists of. -By combining these together, you can define a _directed acyclic graph_ (DAG), -that represents the structure of your [`Pipeline`][amltk.pipeline.Pipeline]. -Here is one such example that we will build up towards. +To accommodate this, what AutoML-Toolkit provides is an **abstract** representation +of a pipeline, to help you define its search space and also to build concrete +objects in code if possible (see [builders](site:reference/pipelines/builders.md)). + +We categorize this into 4 steps: + +1. Parametrize your pipeline using the various [components](site:reference/pipelines/pipeline.md), + including the kinds of items in the pipeline, the search spaces and any additional configuration. + Each of the various types of components gives a syntactic meaning when performing the next steps. + +2. [`pipeline.search_space(parser=...)`][amltk.pipeline.Node.search_space], + Get a usable search space out of the pipeline. This can then be passed to an + [`Optimizer`](site:reference/optimization/optimizers.md). + +3. [`pipeline.configure(config=...)`][amltk.pipeline.Node.configure], + Configure your pipeline, either manually or using a configuration suggested by + an optimizer. + +4. [`pipeline.build(builder=...)`][amltk.pipeline.Node.build], + Build your configured pipeline definition into something usable, i.e. + an [`sklearn.pipeline.Pipeline`][sklearn.pipeline.Pipeline] or a + `torch.nn.Module` (_todo_). + +At the core of these definitions are the many [`Nodes`][amltk.pipeline.node.Node] +a pipeline consists of. By combining these together, you can define a _directed acyclic graph_ (DAG), +that represents the structure of your pipeline. +Here is one such sklearn example that we will build up towards.
```python exec="true" source="tabbed-right" html="True" title="Pipeline" -from sklearn.compose import ColumnTransformer, make_column_selector +from sklearn.compose import make_column_selector from sklearn.ensemble import RandomForestClassifier from sklearn.impute import SimpleImputer -from sklearn.metrics import accuracy_score -from sklearn.preprocessing import LabelEncoder, OneHotEncoder +from sklearn.preprocessing import OneHotEncoder import numpy as np -from amltk import step, split, group, Pipeline +from amltk.pipeline import Component, Split, Sequential -categorical_preprocessing = ( - step("categorical_imputer", SimpleImputer, config={"strategy": "constant", "fill_value": "missing"}) - | step("one_hot_encoding", OneHotEncoder, config={"drop": "first"}) +feature_preprocessing = Split( + { + "categoricals": [SimpleImputer(strategy="constant", fill_value="missing"), OneHotEncoder(drop="first")], + "numerics": Component(SimpleImputer, space={"strategy": ["mean", "median"]}), + }, + config={ + "categoricals": make_column_selector(dtype_include=object), + "numerics": make_column_selector(dtype_include=np.number), + }, + name="preprocessing", ) -numerical_preprocessing = step("numeric_imputer", SimpleImputer, space={"strategy": ["mean", "median"]}) - -pipeline = Pipeline.create( - split( - "feature_preprocessing", - group("categoricals", categorical_preprocessing), - group("numerics", numerical_preprocessing), - item=ColumnTransformer, - config={ - "categoricals": make_column_selector(dtype_include=object), - "numerics": make_column_selector(dtype_include=np.number), - }, - ), - step( - "rf", - RandomForestClassifier, - space={"n_estimators": (10, 100), "criterion": ["gini", "entropy", "log_loss"]}, - ), - name="My Classification Pipeline" + +pipeline = Sequential( + feature_preprocessing, + Component(RandomForestClassifier, space={"n_estimators": (10, 100), "criterion": ["gini", "log_loss"]}), + name="Classy Pipeline", ) -from amltk._doc import doc_print; doc_print(print, pipeline, output="html", fontsize="small", width=120) # markdown-exec: hide +from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide ``` ??? tip "`rich` printing" @@ -62,32 +81,38 @@ from amltk._doc import doc_print; doc_print(print, pipeline, output="html", font Once we have our pipeline definition, extracting a search space, configuring it and building it into something useful can be done with the methods. -* [`pipeline.space()`][amltk.pipeline.Pipeline.space], - Get a useable search space out of the pipeline to pass to an optimizer. - -* [`pipeline.sample()`][amltk.pipeline.Pipeline.sample], - Sample a valid configuration from the pipeline. +!!! tip "Guide Requirements" -* [`pipeline.configure(config=...)`][amltk.pipeline.Pipeline.configure], - Configure a pipeline with a given config + For this guide, we will be using `ConfigSpace` and `scikit-learn`, you can + install them manually or as so: -* [`pipeline.build()`][amltk.pipeline.Pipeline.build], - Build a configured pipeline into some useable object. + ```bash + pip install "amltk[sklearn, configspace]" + ``` ## Component -A `Pipeline` consists of building blocks which we can combine together +A pipeline consists of building blocks which we can combine together to create a DAG. We will start by introducing the `Component`, the common operations, and then show how to combine them together. -A [`Component`][amltk.pipeline.Component] is a single atomic step in a pipeline. 
While -you can construct a component directly, it is recommended to use the -[`step()`][amltk.pipeline.api.step] function to create one. +A [`Component`][amltk.pipeline.Component] is the most common kind of node a pipeline. +Like all parts of the pipeline, they subclass [`Node`][amltk.pipeline.Node] but a +`Component` signifies this is some concrete object, with a possible +[`.space`][amltk.pipeline.Node.space] and [`.config`][amltk.pipeline.Node.config]. + ### Definition + +??? tip inline end "Naming Nodes" + + By default, a `Component` (or any `Node` for that matter), will use the function/classname + for the [`.name`][amltk.pipeline.Node.name] of the `Node`. You can explicitly pass + a `name=` **as a keyword argument** when constructing these. + ```python exec="true" source="material-block" html="true" session="Pipeline-Component" from dataclasses import dataclass -from amltk import step +from amltk.pipeline import Component @dataclass class MyModel: @@ -95,15 +120,14 @@ class MyModel: i: int c: str -mystep = step( - "mystep", +my_component = Component( MyModel, space={"f": (0.0, 1.0), "i": (0, 10), "c": ["red", "green", "blue"]}, ) -from amltk._doc import doc_print; doc_print(print, mystep, output="html", fontsize="small") # markdown-exec: hide +from amltk._doc import doc_print; doc_print(print, my_component, output="html", fontsize="small") # markdown-exec: hide ``` -You can also use a **function** instead of a class if that is preffered. +You can also use a **function** instead of a class if that is preferred. ```python exec="true" source="material-block" html="true" session="Pipeline-Component" def myfunc(f: float, i: int, c: str) -> MyModel: @@ -111,210 +135,206 @@ def myfunc(f: float, i: int, c: str) -> MyModel: c = "red" return MyModel(f=f, i=i, c=c) -step_with_function = step( - "step_with_function", +component_with_function = Component( myfunc, space={"f": (0.0, 1.0), "i": (0, 10), "c": ["red", "green", "blue"]}, ) -from amltk._doc import doc_print; doc_print(print, step_with_function, output="html", fontsize="small") # markdown-exec: hide -``` - -### Sample -We now have a basic `Component` that parametrizes the class `MyModel`. What can be quite useful -is to now [`sample()`][amltk.pipeline.Step.sample] from it to get a valid configuration. - -```python exec="true" source="material-block" result="python" session="Pipeline-Component" -config = mystep.sample(seed=1) -print(config) +from amltk._doc import doc_print; doc_print(print, component_with_function, output="html", fontsize="small") # markdown-exec: hide ``` -### Space -If interacting with an `Optimizer`, you'll often require some search space object to pass to it. -To extract a search space from a `Component`, we can call [`space()`][amltk.pipeline.Step.space]. +### Search Space +If interacting with an [`Optimizer`](site:reference/optimization/optimizers.md), you'll often require some +search space object to pass to it. +To extract a search space from a `Component`, we can call [`search_space(parser=)`][amltk.pipeline.Node.search_space], +passing in the kind of search space you'd like to get out of it. ```python exec="true" source="material-block" result="python" session="Pipeline-Component" -space = mystep.space(seed=1) +space = my_component.search_space("configspace") print(space) ``` -??? tip "What type of space is this?" - - Depending on the libraries you have installed and the values inside `space`, we will attempt - to produce a valid search space for you. 
In this case, we have a `ConfigSpace` implementation - installed and so we get a `ConfigSpace.ConfigurationSpace` object. If you wish to use a different - space, you can always pass a specific [`step.space(parser=...)`][amltk.pipeline.Step.space]. Do - note that not all spaces support all features. +!!! tip inline end "Available Search Spaces" - === "`ConfigSpace`" + Please see the [spaces reference](site:reference/pipelines/spaces.md) - ```python exec="true" source="material-block" result="python" session="Pipeline-Component" - from amltk.configspace import ConfigSpaceAdapter +Depending on what you pass as the `parser=` to `search_space(parser=...)`, we'll attempt +to give you a valid search space. In this case, we specified `#!python "configspace"` and +so we get a `ConfigSpace` implementation. - configspace_space = mystep.space(parser=ConfigSpaceAdapter) - print(configspace_space) - ``` +You may also define your own `parser=` and use that if desired. - === "`Optuna`" - - ```python exec="true" source="material-block" result="python" session="Pipeline-Component" - from amltk.optuna import OptunaSpaceAdapter - - optuna_space = mystep.space(parser=OptunaSpaceAdapter) - print(optuna_space) - ``` - - You may also construct your own parser and use that if desired. ### Configure Pretty straight forward but what do we do with this `config`? Well we can -[`configure(config=...)`][amltk.pipeline.Step.configure] the component with it. +[`configure(config=...)`][amltk.pipeline.Node.configure] the component with it. ```python exec="true" source="material-block" html="true" session="Pipeline-Component" -configured_step = mystep.configure(config) -from amltk._doc import doc_print; doc_print(print, configured_step, output="html", fontsize="small") # markdown-exec: hide +config = space.sample_configuration() +configured_component = my_component.configure(config) +from amltk._doc import doc_print; doc_print(print, configured_component) # markdown-exec: hide ``` You'll notice that each variable in the space has been set to some value. We could also manually define a config and pass that in. You are **not** obliged to fully specify this either. ```python exec="true" source="material-block" html="true" session="Pipeline-Component" -manually_configured_step = mystep.configure({"f": 0.5, "i": 1}) -from amltk._doc import doc_print; doc_print(print, manually_configured_step, output="html") # markdown-exec: hide +manually_configured_component = my_component.configure({"f": 0.5, "i": 1}) +from amltk._doc import doc_print; doc_print(print, manually_configured_component, output="html") # markdown-exec: hide ``` !!! tip "Immutable methods!" - One thing you may have noticed is that we assigned the result of `configure(...)` to a new - variable. This is because we do not mutate the original `mystep` and instead return a copy + One thing you may have noticed is that we assigned the result of `configure(config=...)` to a new + variable. This is because we do not mutate the original `my_component` and instead return a copy with all of the `config` variables set. ### Build -The last important thing we can do with a `Component` is to [`build()`][amltk.pipeline.Component.build] -it. Thisa step is very straight-forward for a `Component` and it simply calls the `.item` with the -config we have set. +To build the individual item of a `Component` we can use [`build_item()`][amltk.pipeline.Component.build_item] +and it simply calls the `.item` with the config we have set. 
```python exec="true" source="material-block" result="python" session="Pipeline-Component" -the_built_model = configured_step.build() +the_built_model = configured_component.build_item() -# Same as if we did `configured_step.item(**configured_step.config)` +# Same as if we did `configured_component.item(**configured_component.config)` print(the_built_model) ``` -You may also pass additional items to `build()` which will overwrite any config values set. +However, as we'll see later, we often have multiple steps of a pipeline joined together and so +we need some way to get a full object out of it that takes into account all of these items +joined together. We can do this with [`build(builder=...)`][amltk.pipeline.Node.build]. ```python exec="true" source="material-block" result="python" session="Pipeline-Component" -the_built_model = configured_step.build(f=0.5, i=1) +the_built_model = configured_component.build(builder="sklearn") print(the_built_model) ``` +For a look at the available arguments to pass to `builder=`, see the +[builder reference](site:reference/pipelines/builders.md) + +### Fixed + +Sometimes we just have some part of the pipeline with no search space and +no configuration required, i.e. just some prebuilt thing. We can +use the [`Fixed`][amltk.pipeline.Fixed] node type to signify this. + +```python exec="true" source="material-block" result="python" +from amltk.pipeline import Fixed +from sklearn.ensemble import RandomForestClassifier + +frozen_rf = Fixed(RandomForestClassifier(n_estimators=5)) +``` + ### Parameter Requests Sometimes you may wish to explicitly specify some value should be added to the `.config` during -`configure()` which would be difficult to include in the `config`, for example the `random_state` +`configure()` which would be difficult to include in the `config` directly, for example the `random_state` of an sklearn estimator. You can pass these extra parameters into `configure(params={...})`, which do not require any namespace prefixing. -For this reason, we have the concept of a [`request()`][amltk.pipeline.request], allowing +For this reason, we introduce the concept of a [`request()`][amltk.pipeline.request], allowing you to specify that a certain parameter should be added to the config during `configure()`. 
```python exec="true" hl_lines="14 17 18" source="material-block" html="true" session="Pipeline-Parameter-Request" from dataclasses import dataclass -from amltk import step, request +from amltk import Component, request @dataclass class MyModel: f: float random_state: int -mystep = step( - "mystep", +my_component = Component( MyModel, space={"f": (0.0, 1.0)}, config={"random_state": request("seed", default=42)} ) -configured_step_with_seed = mystep.configure({"f": 0.5}, params={"seed": 1337}) -configured_step_no_seed = mystep.configure({"f": 0.5}) -from amltk._doc import doc_print; doc_print(print, configured_step_with_seed, output="html", fontsize="small") # markdown-exec: hide -doc_print(print, configured_step_no_seed, output="html", fontsize="small") # markdown-exec: hide +# Without passing the params +configured_component_no_seed = my_component.configure({"f": 0.5}) + +# With passing the params +configured_component_with_seed = my_component.configure({"f": 0.5}, params={"seed": 1337}) +from amltk._doc import doc_print; doc_print(print, configured_component_no_seed) # markdown-exec: hide +doc_print(print, configured_component_with_seed) # markdown-exec: hide ``` -If you explicitly require a parameter to be set, you can pass `required=True` to the request. +If you explicitly require a parameter to be set, just do not set a `default=`. ```python exec="true" source="material-block" result="python" session="Pipeline-Parameter-Request" -mystep = step( - "mystep", +my_component = Component( MyModel, space={"f": (0.0, 1.0)}, - config={"random_state": request("seed", required=True)} + config={"random_state": request("seed")} ) -mystep.configure({"f": 0.5}, params={"seed": 5}) # All good +my_component.configure({"f": 0.5}, params={"seed": 5}) # All good try: - mystep.configure({"f": 0.5}) # Missing required parameter + my_component.configure({"f": 0.5}) # Missing required parameter except ValueError as e: print(e) ``` ### Config Transform Some search space and optimizers may have limitations in terms of the kinds of parameters they -can support, one notable example is tuple parameters. To get around this, we can pass -a `config_transform` to `step` which will transform the config before it is passed to the +can support, one notable example is **tuple** parameters. To get around this, we can pass +a `config_transform=` to `component` which will transform the config before it is passed to the `.item` during `build()`. 
```python exec="true" hl_lines="9-13 19" source="material-block" html="true" from dataclasses import dataclass -from amltk import step +from amltk import Component @dataclass class MyModel: dimensions: tuple[int, int] def config_transform(config: dict, _) -> dict: + """Convert "dim1" and "dim2" into a tuple.""" dim1 = config.pop("dim1") dim2 = config.pop("dim2") config["dimensions"] = (dim1, dim2) return config -mystep = step( - "mystep", +my_component = Component( MyModel, space={"dim1": (1, 10), "dim2": (1, 10)}, config_transform=config_transform, ) -configured_step = mystep.configure({"dim1": 5, "dim2": 5}) -from amltk._doc import doc_print; doc_print(print, configured_step, output="html", fontsize="small") # markdown-exec: hide +configured_component = my_component.configure({"dim1": 5, "dim2": 5}) +from amltk._doc import doc_print; doc_print(print, configured_component, fontsize="small") # markdown-exec: hide ``` -Lastly, there may be times where you may have some additional context which you may only -know at configuration time, you may pass this to `configure(..., transform_context=...)` -which will be forwarded as the second argument to your `.config_transform`. +!!! tip inline end "Transform Context" + + There may be times where you may have some additional context which you may only + know at configuration time, you may pass this to `configure(..., transform_context=...)` + which will be forwarded as the second argument to your `.config_transform`. -## Pipelines -A single step might be enough for some basic definitions but generally we need to combine multiple -steps. AutoML-Toolkit is designed for large and more complex structures which can be made from -simple atomic steps. +## Sequential +A single component might be enough for some basic definitions but generally we need to combine multiple +components together. AutoML-Toolkit is designed for large and more complex structures which can be +made from simple atomic [`Node`][amltk.pipeline.Node]s. -### Joining Steps +### Chaining Together Nodes We'll begin by creating two components that wrap scikit-learn estimators. -```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" +```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Nodes" from sklearn.impute import SimpleImputer from sklearn.ensemble import RandomForestClassifier -from amltk import step +from amltk.pipeline import Component -imputer_step = step("imputer", SimpleImputer, space={"strategy": ["median", "mean"]}) -rf_step = step("random_forest", RandomForestClassifier, space={"n_estimators": (10, 100)}) +imputer = Component(SimpleImputer, space={"strategy": ["median", "mean"]}) +rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) -from amltk._doc import doc_print; doc_print(print, imputer_step, output="html", fontsize="small") # markdown-exec: hide -doc_print(print, rf_step, output="html", fontsize="small") # markdown-exec: hide +from amltk._doc import doc_print; doc_print(print, imputer) # markdown-exec: hide +doc_print(print, rf) # markdown-exec: hide ``` -!!! tip "Modifying Display Output" +!!! info inline end "Modifying Display Output" By default, `amltk` will show full function signatures, including a link to their documentation if available. @@ -329,146 +349,72 @@ doc_print(print, rf_step, output="html", fontsize="small") # markdown-exec: hid You can find the [available options here][amltk.options.AMLTKOptions]. 
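+As mentioned in the "Naming Nodes" tip earlier, both components above default to their class
+names (`SimpleImputer` and `RandomForestClassifier`) for their `.name`. If you would rather
+refer to them by explicit names, you can pass `name=` as a keyword argument. A small
+illustrative sketch (the names here are just examples):
+
+```python
+from sklearn.impute import SimpleImputer
+from amltk.pipeline import Component
+
+# Pass `name=` as a keyword argument to control how this node is referred to.
+named_imputer = Component(SimpleImputer, space={"strategy": ["median", "mean"]}, name="imputer")
+```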
-To join these two steps together, we can either use the infix notation using `|` (reminiscent of a bash pipe) -or directly call [`append(nxt)`][amltk.pipeline.Step.append] on the first step. - -```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" -joined_steps = imputer_step | rf_step -from amltk._doc import doc_print; doc_print(print, joined_steps, output="html", fontsize="small") # markdown-exec: hide -``` - -We should point out two key things here: - -* You are always returned the _head_ of the steps, i.e. the first step in the list -* You can see the `rf_step` is now attached to the `imputer_step` as its `nxt` attribute. -However viewing only one step at a time is not so useful. We can get a [`Pipeline`][amltk.pipeline.Pipeline] -out of these steps quite easily which will display a lot more nicely and allow you to perform operations - -```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" -pipeline = joined_steps.as_pipeline(name="My Pipeline") -from amltk._doc import doc_print; doc_print(print, pipeline, output="html", fontsize="small") # markdown-exec: hide +```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Nodes" +from amltk.pipeline import Sequential +pipeline = Sequential(imputer, rf, name="My Pipeline") +from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide ``` -??? tip "Using `Pipeline.create(...)` instead" - - You can also use [`Pipeline.create(...)`][amltk.pipeline.Pipeline.create] to create a pipeline - from a set of steps. +!!! info inline end "Infix `>>`" - ```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" - from amltk import Pipeline + To join these two components together, we can either use the infix notation using `>>`, + or passing them directly to a [`Sequential`][amltk.pipeline.Sequential]. However + a random name will be given. - pipeline2 = Pipeline.create(joined_steps, name="My Pipeline") - from amltk._doc import doc_print; doc_print(print, pipeline2, output="html", fontsize="small") # markdown-exec: hide + ```python + joined_components = imputer >> rf ``` -### Pipeline Usage - -You can perform much of the same operations as we did for the individual step but now taking into account +### Operations +You can perform much of the same operations as we did for the individual node but now taking into account everything in the pipeline. 
-```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" -space = pipeline.space() -config = pipeline.sample(seed=1337) +```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Nodes" +space = pipeline.search_space("configspace") +config = space.sample_configuration() configured_pipeline = pipeline.configure(config) -from amltk._doc import doc_print; doc_print(print, space, output="html") # markdown-exec: hide -doc_print(print, config, output="html") # markdown-exec: hide -doc_print(print, configured_pipeline, output="html", fontsize="small") # markdown-exec: hide +from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide +doc_print(print, config) # markdown-exec: hide +doc_print(print, configured_pipeline) # markdown-exec: hide ``` -Pipelines also support a number of other operations such as traversal with [`iter()`][amltk.pipeline.Pipeline.iter], -[`traverse()`][amltk.pipeline.Pipeline.traverse] and [`walk()`][amltk.pipeline.Pipeline.walk], -search with [`find()`][amltk.pipeline.Pipeline.find], modification with [`remove()`][amltk.pipeline.Pipeline.remove], -[`apply()`][amltk.pipeline.Pipeline.apply] and [`replace()`][amltk.pipeline.Pipeline.replace]. +!!! inline end "Other notions of Sequential" -### Pipeline Building -Perhaps the most significant difference when working with a `Pipeline` is what should something -like [`build()`][amltk.pipeline.Pipeline.build] do? Well, there are perhaps multiple steps and perhaps -even nested `choice` and `split` components which we will introduce later. + We'll see this later but wherever we expect a `Node`, for example as an argument to + `Sequential` or any other type of pipeline component, a list, i.e. `[node_1, node_2]`, + will automatically be joined together and interpreted as a `Sequential`. -The answer depends on what is contained within your steps. For this example, using sklearn, we can -directly return an sklearn [`Pipeline`][sklearn.pipeline.Pipeline] object. This is auto detected -based on the `.item` contained in each step. +To build a pipeline of nodes, we simply call [`build(builder=)`][amltk.pipeline.Node.build]. We +explicitly pass the builder we want to use, which informs `build()` how to go from the abstract +pipeline definition you've defined to something concrete you can use. +You can find the [available builders here](site:reference/pipelines/builders.md). -```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" +```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Nodes" from sklearn.pipeline import Pipeline as SklearnPipeline -built_pipeline = configured_pipeline.build() +built_pipeline = configured_pipeline.build("sklearn") assert isinstance(built_pipeline, SklearnPipeline) print(built_pipeline._repr_html_()) # markdown-exec: hide ``` -We currently support the following builders which are auto detected based on the `.item` contained: - -=== "Scikit-Learn" - - Using [`sklearn_pipeline()`][amltk.sklearn.sklearn_pipeline] will builds - an [`SklearnPipeline`][sklearn.pipeline.Pipeline] from the steps. The - possible pipelines allowed follow the rules of an sklearn pipeline, i.e. - only the final step can be an estimator and everything else before it must - be a transformer. 
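+Since what comes out is an ordinary scikit-learn [`Pipeline`][sklearn.pipeline.Pipeline], you can
+use it like any other estimator. A minimal usage sketch (the synthetic data below is purely
+illustrative):
+
+```python
+from sklearn.datasets import make_classification
+
+# Any tabular dataset would do; this one is generated just for illustration.
+X, y = make_classification(n_samples=100, n_features=5, random_state=0)
+
+built_pipeline.fit(X, y)
+predictions = built_pipeline.predict(X)
+```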
- - ```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" - from amltk.sklearn import sklearn_pipeline - built_pipeline = configured_pipeline.build(builder=sklearn_pipeline) - print(built_pipeline._repr_html_()) - ``` - - If using something like `imblearn` components, you will need to have an - `imblearn.pipeline.Pipeline` as output type. We can pass this directly to - the builder. - - ```python - from amltk.sklearn import sklearn_pipeline - from imblearn.pipeline import Pipeline as ImblearnPipeline - - built_pipeline = configured_pipeline.build( - builder=sklearn_pipeline, - pipeline_type=ImblearnPipeline - ) - ``` - -=== "`pytorch_builder()`" - - !!! todo "TODO" - - This is currently in progress and will be available soon. Please - feel free to reach out to help - -=== "Custom Builder" - - You can also provide your own `builder=` function which has a very basic premise that - you must be able to parse the pipeline and return _something_. Here is a basic example - which will just return a dict with the step names as keys and the values as the built - components. - - ```python exec="true" source="material-block" html="true" session="Pipeline-Connecting-Steps" - def mybuilder(pipeline: Pipeline, **kwargs) -> dict: - return {step.name: step.build() for step in pipeline.traverse()} - - components = configured_pipeline.build(builder=mybuilder) - ``` - -## Building blocks +## Other Building blocks We saw the basic building block of a `Component` but AutoML-Toolkit also provides support -for some other kinds of building blocks. These building blocks can be attached and appended -just like a `Component` can and allow for much more complex pipeline structures. +for some other kinds of building blocks. These building blocks can be attached and joined +together just like a `Component` can and allow for much more complex pipeline structures. ### Choice A [`Choice`][amltk.pipeline.Choice] is a way to define a choice between multiple components. This is useful when you want to search over multiple algorithms, which may each have their own hyperparameters. -The preferred way to create a `Choice` is to use the [`choice(...)`][amltk.pipeline.choice] -function. 
- -We'll start again by creating two steps: +We'll start again by creating two nodes: ```python exec="true" source="material-block" html="true" session="Pipeline-Choice" from dataclasses import dataclass -from amltk import step +from amltk.pipeline import Component @dataclass class ModelA: @@ -478,8 +424,8 @@ class ModelA: class ModelB: c: str -model_a = step("model_a", ModelA, space={"i": (0, 100)}) -model_b = step("model_b", ModelB, space={"c": ["red", "blue"]}) +model_a = Component(ModelA, space={"i": (0, 100)}) +model_b = Component(ModelB, space={"c": ["red", "blue"]}) from amltk._doc import doc_print; doc_print(print, model_a, output="html", fontsize="small") # markdown-exec: hide doc_print(print, model_b, output="html", fontsize="small") # markdown-exec: hide ``` @@ -487,248 +433,104 @@ doc_print(print, model_b, output="html", fontsize="small") # markdown-exec: hid Now combining them into a choice is rather straight forward: ```python exec="true" source="material-block" html="true" session="Pipeline-Choice" -from amltk import choice +from amltk.pipeline import Choice -model_choice = choice("model", model_a, model_b) +model_choice = Choice(model_a, model_b, name="estimator") from amltk._doc import doc_print; doc_print(print, model_choice, output="html", fontsize="small") # markdown-exec: hide ``` -Just as we did with a `Component`, we can also get a `space()` from the choice. If the space -parser supports conditionals from a space, it will even add conditionals to the space to -account for the choice and that some hyperparameters are only active depending on if the -model is chosen. +Just as we did with a `Component`, we can also get a [`search_space()`][amltk.pipeline.Node.search_space] +from the choice. ```python exec="true" source="material-block" html="true" session="Pipeline-Choice" -space = model_choice.space() +space = model_choice.search_space("configspace") from amltk._doc import doc_print; doc_print(print, space, output="html") # markdown-exec: hide ``` -When we `configure()` a choice, we will collapse it down to a single component. This is -done according to what is set in the config. - -```python exec="true" source="material-block" html="true" session="Pipeline-Choice" -config = model_choice.sample(seed=1) -configured_model = model_choice.configure(config) -from amltk._doc import doc_print; doc_print(print, configured_model, output="html") # markdown-exec: hide -``` - -### Group -The purpose of a [`Group`][amltk.pipeline.Group] is to _"draw a box"_ around a certain -subsection of a pipeline. This essentially acts as a namespacing mechanism for the -config and space of the steps contained within it. This can be useful -when you need to refer to a `Choice` in part of a `Pipeline`, where when configured, -this `Choice` will disappear and be replaced by a single component. - -To illustrate this, let's revisit what happens when we `configure()` a choice. First we'll -build a small pipeline. 
- -```python exec="true" source="material-block" html="true" session="Pipeline-Group" -from amltk import step, choice, group - -model_a = step("model_a", object) -model_b = step("model_b", object) - -preprocessing = step("preprocessing", object) -model_choice = choice("classifier_choice", model_a, model_b) - -pipeline = (preprocessing | model_choice).as_pipeline(name="My Pipeline") - -from amltk._doc import doc_print; doc_print(print, pipeline, output="html", fontsize="small") # markdown-exec: hide -``` - -Now let's sample a config from the pipeline space and `configure()` it to see what we get out. - - -```python exec="true" source="material-block" html="true" session="Pipeline-Group" -space = pipeline.space() -config = pipeline.sample(seed=1) -configured_pipeline = pipeline.configure(config) -doc_print(print, space, output="html", fontsize="small") # markdown-exec: hide -doc_print(print, config, output="html", fontsize="small") # markdown-exec: hide -doc_print(print, configured_pipeline, output="html", fontsize="small") # markdown-exec: hide -``` - -We can't know ahead of time whether we need to refer to `#!python "model_a"` -or `#!python "model_b"` and we can no longer refer to `#!python "classifier_choice"` as this -has been configured away. - -To circumvent this, we can use a [`Group`][amltk.pipeline.Group] to wrap the choice, with -the preferred way to create one being [`group(...)`][amltk.pipeline.group]. +??? warning inline end "Conditionals and Search Spaces" -```python exec="true" source="material-block" html="true" session="Pipeline-Group" -from amltk import step, choice, group + Not all search space implementations support conditionals and so some + `parser=` may not be able to handle this. In this case, there won't be + any conditionality in the search space. -model_a = step("model_a", object) -model_b = step("model_b", object) - -preprocessing = step("preprocessing", object) -classifier_group = group( - "classifier", - choice("classifier_choice", model_a, model_b) -) - -pipeline = (preprocessing | classifier_group).as_pipeline(name="My Pipeline") -from amltk._doc import doc_print; doc_print(print, pipeline, output="html", fontsize="small") # markdown-exec: hide -``` + Check out the [parser reference](site:reference/pipelines/spaces.md) + for more information. -Now let's configure it: -```python exec="true" source="material-block" html="true" session="Pipeline-Group" -config = pipeline.sample(seed=1) -configured_pipeline = pipeline.configure(config) -doc_print(print, config, output="html", fontsize="small") # markdown-exec: hide -doc_print(print, configured_pipeline, output="html", fontsize="small") # markdown-exec: hide -``` +When we `configure()` a choice, we will collapse it down to a single component. This is +done according to what is set in the config. 
-
 
-If we need to access the chosen classifier, we can do so in a straightforward manner:
-```python exec="true" source="material-block" html="true" session="Pipeline-Group"
-chosen_classifier = configured_pipeline.find("classifier").first()
-doc_print(print, chosen_classifier, output="html", fontsize="small") # markdown-exec: hide
+```python exec="true" source="material-block" html="true" session="Pipeline-Choice"
+config = space.sample_configuration()
+configured_choice = model_choice.configure(config)
+from amltk._doc import doc_print; doc_print(print, configured_choice, output="html") # markdown-exec: hide
 ```
 
+You'll notice that it set the `.config` of the `Choice` to `#!python {"__choice__": "model_a"}` or
+`#!python {"__choice__": "model_b"}`. This lets a builder know which of these two to build.
 
 ### Split
-A [`Split`][amltk.pipeline.Split] is a way to signify a split in the dataflow of a pipeline,
-with the preferred way to create one being [`split(...)`][amltk.pipeline.split]. This `Split`
-by itself will not do anything but it informs the builder about what to do. Each builder
-will have if it's own specific strategy for dealing with one.
-
-Before we go ahead with a full scikit-learn example and build it, we'll start with
-an abstract representation of a `Split`.
-
-```python exec="true" source="material-block" html="True" session="Pipeline-Split1"
-from amltk import step, split, group, Pipeline
-
-preprocesser = split(
-    "preprocesser",
-    step("cat_imputer", object) | step("cat_encoder", object),
-    step("num_imputer", object),
-    config={"cat_imputer": object, "num_imputer": object}
-)
-from amltk._doc import doc_print; doc_print(print, preprocesser, output="html", fontsize="small", width=120) # markdown-exec: hide
-```
+A [`Split`][amltk.pipeline.Split] is a way to signify a split in the dataflow of a pipeline.
+This `Split` by itself will not do anything, but it informs the builder about what to do.
+Each builder will have its own specific strategy for dealing with one.
 
-You'll notice that if we have any hope to configure this `Split` which normally requires
-mentioning each of it's paths, we can only reference the first step of the path, in this
-case `#!python "cat_imputer"` and `#!python "num_imputer"`. In the case of the first step
-being a `Choice`, we may not even have a name we can refer to!
-
-We fix this situation by giving each split path its own name. We can either do this manually
-with a `Group` or we can simply pass a `dict` of paths to a sequence of steps.
-
-```python exec="true" source="material-block" html="True" session="Pipeline-Split2"
-from amltk import step, split, group, Pipeline
-
-preprocesser = split(
-    "preprocesser",
-    {
-        "categories": step("cat_imputer", object) | step("cat_encoder", object),
-        "numericals": step("num_imputer", object),
-    },
-    config={"categories": object, "numericals": object}
-)
-from amltk._doc import doc_print; doc_print(print, preprocesser, output="html", fontsize="small", width=120) # markdown-exec: hide
-```
-
-This construction will use a `Group` around each of the paths, which will allow us to refer
-to the different paths, regardless of what happens to the path.
-
-Now this time we will use a scikit-learn example, as an example.
+Let's go ahead with a scikit-learn example, where we'll split the data into categorical
+and numerical features and then perform some preprocessing on each of them.
 ```python exec="true" source="material-block" html="True" session="Pipeline-Split3"
-from sklearn.compose import ColumnTransformer, make_column_selector
+from sklearn.compose import make_column_selector
 from sklearn.impute import SimpleImputer
-from sklearn.preprocessing import LabelEncoder, OneHotEncoder
+from sklearn.preprocessing import OneHotEncoder
 import numpy as np
 
-from amltk import step, split, group, Pipeline
-
-# We'll impute categorical features and then OneHotEncode them
-category_pipeline = step("categorical_imputer", SimpleImputer) | step("one_hot_encoding", OneHotEncoder)
+from amltk.pipeline import Component, Split
 
-# We just impute numerical features
-numerical_pipeline = step("numeric_imputer", SimpleImputer, config={"strategy": "median"})
+select_categories = make_column_selector(dtype_include=object)
+select_numerical = make_column_selector(dtype_include=np.number)
 
-feature_preprocessing = split(
-    "feature_preprocessing",
-    group("categoricals", category_pipeline),
-    group("numerics", numerical_pipeline),
-    item=ColumnTransformer,
-    config={
-        # Here we specify which columns should be passed to which group
-        "categoricals": make_column_selector(dtype_include=object),
-        "numerics": make_column_selector(dtype_include=np.number),
+preprocessor = Split(
+    {
+        "categories": [SimpleImputer(strategy="constant", fill_value="missing"), OneHotEncoder(drop="first")],
+        "numerics": Component(SimpleImputer, space={"strategy": ["mean", "median"]}),
     },
+    config={"categories": select_categories, "numerics": select_numerical},
+    name="feature_preprocessing",
 )
-from amltk._doc import doc_print; doc_print(print, feature_preprocessing, output="html", fontsize="small", width=120) # markdown-exec: hide
+from amltk._doc import doc_print; doc_print(print, preprocessor) # markdown-exec: hide
 ```
 
-Our last step is just to convert this into a `Pipeline` and `build()` it. First,
-to convert it into a pipeline with a classifier at the end.
+The first thing to note here is that we passed a `dict` to `Split`, so that
+we can name the individual paths. This is important because we need some name to refer
+to them when configuring the `Split`. It does this by simply wrapping
+each of the paths in a [`Sequential`][amltk.pipeline.Sequential].
 
-```python exec="true" source="material-block" html="True" session="Pipeline-Split3"
-from sklearn.ensemble import RandomForestClassifier
+The second thing is that the keys set in the `.config` match the names of the
+paths. This lets the `Split` know which data should be sent where. Each `builder=`
+will have its own way of setting up a `Split`, so you should refer to
+the [builders reference](site:reference/pipelines/builders.md) for more information.
 
-classifier = step("random_forest", RandomForestClassifier)
-pipeline = (feature_preprocessing | classifier).as_pipeline(name="Classification Pipeline")
-from amltk._doc import doc_print; doc_print(print, pipeline, output="html", fontsize="small", width=120) # markdown-exec: hide
-```
+Our last step is just to convert this into a usable object, and so once again
+we use [`build()`][amltk.pipeline.Node.build].
-And finally, to build it: ```python exec="true" source="material-block" html="True" session="Pipeline-Split3" -from sklearn.pipeline import Pipeline as SklearnPipeline - -sklearn_pipeline = pipeline.build() -assert isinstance(sklearn_pipeline, SklearnPipeline) -print(sklearn_pipeline._repr_html_()) # markdown-exec: hide +built_pipeline = preprocessor.build("sklearn") +from amltk._doc import doc_print; doc_print(print, built_pipeline) # markdown-exec: hide ``` -### Option +### Join !!! todo "TODO" - Please feel free to provide a contribution! - -## Modules -A pipeline is often not sufficient to represent everything surrounding the pipeline -that you'd wish to associate with it. For that reason we introduce the concept -of _module_. -These are components or pipelines that you [`attach()`][amltk.pipeline.Pipeline.attach] -to your main pipeline, but are not directly part of the dataflow. - -For example, we can create a simple [`searchable()`][amltk.pipeline.api.searchable] -which we `attach()` to our pipeline. -This will be included in the `space()` that it outputed from the `Pipeline. - -```python exec="true" source="material-block" html="True" session="Pipeline-Modules" -from amltk.pipeline import step, searchable + TODO + -# Some extra things we want to include in the search space of the pipeline -params_a = searchable("params_a", space={"a": (1, 10), "b": ["apple", "frog"]}) -params_b = searchable("params_b", space={"c": (1.5, 1.8)}) +### Searchable -# Create a basic pipeline of two steps -pipeline = (step("step1", object) | step("step2", object)).as_pipeline() -pipeline = pipeline.attach(modules=[params_a, params_b]) - -space = pipeline.space() -from amltk._doc import doc_print; doc_print(print, pipeline, output="html", fontsize="small") # markdown-exec: hide -doc_print(print, space, output="html", fontsize="small") # markdown-exec: hide -``` +!!! todo "TODO" -These will also be included in any configurations `sample()`'ed and will be configured with -`configure()`. + TODO -```python exec="true" source="material-block" html="True" session="Pipeline-Modules" -config = pipeline.sample() -pipeline = pipeline.configure(config) +### Option -doc_print(print, config, output="html", fontsize="small") # markdown-exec: hide -doc_print(print, pipeline, output="html", fontsize="small") # markdown-exec: hide -``` +!!! todo "TODO" -Lastly, we can access the config directly through the pipelines `.modules` + Please feel free to provide a contribution! 
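+While these sections are still being written, the general pattern follows the other node types.
+For example, a `Searchable` is intended for the case where you only want some extra
+hyperparameters in the search space, without any concrete `.item` to build. A rough sketch,
+assuming `Searchable` accepts `space=` and `name=` keywords in the same way `Component` does:
+
+```python
+from amltk.pipeline import Searchable
+
+# Hypothetical extra hyperparameters to expose in the search space,
+# not tied to any particular estimator.
+extra_params = Searchable(space={"a": (1, 10), "b": ["apple", "frog"]}, name="extra_params")
+
+space = extra_params.search_space("configspace")
+```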
-```python exec="true" source="material-block" html="True" session="Pipeline-Modules" -module_config = pipeline.modules["params_a"].config -doc_print(print, module_config, output="html", fontsize="small") # markdown-exec: hide -``` diff --git a/docs/guides/scheduling.md b/docs/guides/scheduling.md index e92ec84f..86662376 100644 --- a/docs/guides/scheduling.md +++ b/docs/guides/scheduling.md @@ -27,19 +27,19 @@ answers = [] @scheduler.on_start def start_computing() -> None: answers.append(12) - task(12) # Launch the task with the argument 12 + task.submit(12) # Launch the task with the argument 12 # Tell the scheduler what to do when the task returns @task.on_result def compute_next(_, next_n: int) -> None: answers.append(next_n) - if next_n != 1: - task(next_n) + if scheduler.running() or next_n != 1: + task.submit(next_n) # Run the scheduler scheduler.run(timeout=1) # One second timeout print(answers) -from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide +from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide ``` @@ -51,15 +51,15 @@ However, the `Scheduler` is rather useless without some fuel. For this, we present [`Tasks`][amltk.scheduling.Task], the computational task to perform with the `Scheduler` and start the system's gears turning. -Finally, we show the [`Event`][amltk.events.Event] and how you can use this with -an [`Emitter`][amltk.events.Emitter] to create your own event-driven systems. - ??? tip "`rich` printing" To get the same output locally (terminal or Notebook), you can either call `thing.__rich()__`, use `from rich import print; print(thing)` or in a Notebook, simply leave it as the last object of a cell. + You'll have to install with `amltk[jupyter]` or + `pip install rich[jupyter]` manually.k + ## Scheduler The core engine of the AutoML-Toolkit is the [`Scheduler`][amltk.scheduling.Scheduler]. It purpose it to allow you to create workflows in an event driven manner. It does @@ -86,278 +86,25 @@ computation is actually performed. This allows us to easily switch between different backends, such as threads, processes, clusters, cloud resources, or even custom backends. +!!! info inline end "Available Executors" + + You can find a list of these in our + [executor reference](site:reference/scheduling/executors.md). + The simplest one is a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] -which will create a pool of processes to run the compute in parallel. +which will create a pool of processes to run the compute in parallel. We provide +a convenience function for this as +[`Scheduler.with_processes()`][amltk.scheduling.Scheduler.with_processes] +well as some other builder ```python exec="true" source="material-block" html="True" from concurrent.futures import ProcessPoolExecutor from amltk.scheduling import Scheduler -scheduler = Scheduler( - executor=ProcessPoolExecutor(max_workers=2), -) -from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide +scheduler = Scheduler.with_processes(2) +from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: hide ``` -We provide convenience functions for some common backends, such as -[`with_processes(max_workers=2)`][amltk.scheduling.Scheduler.with_processes] -which does exactly this. - -!!! 
tip "Builtin backends" - - If there's any executor background you wish to integrate, we would - be happy to consider it and greatly appreciate a PR! - - === ":material-language-python: `Python`" - - Python supports the `Executor` interface natively with the - [`concurrent.futures`][concurrent.futures] module for processes with the - [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] and - [`ThreadPoolExecutor`][concurrent.futures.ThreadPoolExecutor] for threads. - - !!! example - - === "Process Pool Executor" - - ```python - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_processes(2) # (1)! - ``` - - 1. Explicitly use the `with_processes` method to create a `Scheduler` with - a `ProcessPoolExecutor` with 2 workers. - ```python - from concurrent.futures import ProcessPoolExecutor - from amltk.scheduling import Scheduler - - executor = ProcessPoolExecutor(max_workers=2) - scheduler = Scheduler(executor=executor) - ``` - - === "Thread Pool Executor" - - ```python - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_threads(2) # (1)! - ``` - - 1. Explicitly use the `with_threads` method to create a `Scheduler` with - a `ThreadPoolExecutor` with 2 workers. - ```python - from concurrent.futures import ThreadPoolExecutor - from amltk.scheduling import Scheduler - - executor = ThreadPoolExecutor(max_workers=2) - scheduler = Scheduler(executor=executor) - ``` - - !!! danger "Why to not use threads" - - Python also defines a [`ThreadPoolExecutor`][concurrent.futures.ThreadPoolExecutor] - but there are some known drawbacks to offloading heavy compute to threads. Notably, - there's no way in python to terminate a thread from the outside while it's running. - - === ":simple-dask: `dask`" - - [Dask](https://distributed.dask.org/en/stable/) and the supporting extension [`dask.distributed`](https://distributed.dask.org/en/stable/) - provide a robust and flexible framework for scheduling compute across workers. - - !!! example - - ```python hl_lines="5" - from dask.distributed import Client - from amltk.scheduling import Scheduler - - client = Client(...) - executor = client.get_executor() - scheduler = Scheduler(executor=executor) - ``` - - === ":simple-dask: `dask-jobqueue`" - - [`dask-jobqueue`](https://jobqueue.dask.org/en/latest/) is a package - for scheduling jobs across common clusters setups such as - PBS, Slurm, MOAB, SGE, LSF, and HTCondor. - - - !!! example - - Please see the `dask-jobqueue` [documentation](https://jobqueue.dask.org/en/latest/) - In particular, we only control the parameter `#!python n_workers=` to - use the [`adapt()`](https://jobqueue.dask.org/en/latest/index.html?highlight=adapt#adaptivity) - method, every other keyword is forwarded to the relative - [cluster implementation](https://jobqueue.dask.org/en/latest/api.html). - - In general, you should specify the requirements of each individual worker and - and tune your load with the `#!python n_workers=` parameter. - - If you have any tips, tricks, working setups, gotchas, please feel free - to leave a PR or simply an issue! - - === "Slurm" - - ```python hl_lines="3 4 5 6 7 8 9" - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_slurm( - n_workers=10, # (1)! - queue=..., - cores=4, - memory="6 GB", - walltime="00:10:00" - ) - ``` - - 1. The `n_workers` parameter is used to set the number of workers - to start with. 
- The [`adapt()`](https://jobqueue.dask.org/en/latest/index.html?highlight=adapt#adaptivity) - method will be called on the cluster to dynamically scale up to `#!python n_workers=` based on - the load. - The `with_slurm` method will create a [`SLURMCluster`][dask_jobqueue.SLURMCluster] - and pass it to the `Scheduler` constructor. - ```python hl_lines="10" - from dask_jobqueue import SLURMCluster - from amltk.scheduling import Scheduler - - cluster = SLURMCluster( - queue=..., - cores=4, - memory="6 GB", - walltime="00:10:00" - ) - cluster.adapt(max_workers=10) - executor = cluster.get_client().get_executor() - scheduler = Scheduler(executor=executor) - ``` - - !!! warning "Running outside the login node" - - If you're running the scheduler itself in a job, this may not - work on some cluster setups. The scheduler itself is lightweight - and can run on the login node without issue. - However you should make sure to offload heavy computations - to a worker. - - If you get it to work, for example in an interactive job, please - let us know! - - !!! info "Modifying the launch command" - - On some cluster commands, you'll need to modify the launch command. - You can use the following to do so: - - ```python - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_slurm(n_workers=..., submit_command="sbatch --extra" - ``` - - === "Others" - - Please see the `dask-jobqueue` [documentation](https://jobqueue.dask.org/en/latest/) - and the following methods: - - * [`Scheduler.with_pbs()`][amltk.scheduling.Scheduler.with_pbs] - * [`Scheduler.with_lsf()`][amltk.scheduling.Scheduler.with_lsf] - * [`Scheduler.with_moab()`][amltk.scheduling.Scheduler.with_moab] - * [`Scheduler.with_sge()`][amltk.scheduling.Scheduler.with_sge] - * [`Scheduler.with_htcondor()`][amltk.scheduling.Scheduler.with_htcondor] - - === ":octicons-gear-24: `loky`" - - [Loky](https://loky.readthedocs.io/en/stable/API.html) is the default backend executor behind - [`joblib`](https://joblib.readthedocs.io/en/stable/), the parallelism that - powers scikit-learn. - - !!! example "Scheduler with Loky Backend" - - === "Simple" - - ```python - from amltk import Scheduler - - # Pass any arguments you would pass to `loky.get_reusable_executor` - scheduler = Scheduler.with_loky(...) - ``` - - - === "Explicit" - - ```python - import loky - from amltk import Scheduler - - scheduler = Scheduler(executor=loky.get_reusable_executor(...)) - ``` - - !!! warning "BLAS numeric backend" - - The loky executor seems to pick up on a different BLAS library (from scipy) - which is different than those used by jobs from something like a `ProcessPoolExecutor`. - - This is likely not to matter for a majority of use-cases. - - === ":simple-ray: `ray`" - - [Ray](https://docs.ray.io/en/master/) is an open-source unified compute framework that makes it easy - to scale AI and Python workloads - — from reinforcement learning to deep learning to tuning, - and model serving. - - !!! info "In progress" - - Ray is currently in the works of supporting the Python - `Executor` interface. See this [PR](https://github.com/ray-project/ray/pull/30826) - for more info. - - === ":simple-apacheairflow: `airflow`" - - [Airflow](https://airflow.apache.org/) is a platform created by the community to programmatically author, - schedule and monitor workflows. Their list of integrations to platforms is endless - but features compute platforms such as Kubernetes, AWS, Microsoft Azure and - GCP. - - !!! info "Planned" - - We plan to support `airflow` in the future. 
If you'd like to help - out, please reach out to us! - - === ":material-debug-step-over: Debugging" - - Sometimes you'll need to debug what's going on and remove the noise - of processes and parallelism. For this, we have implemented a very basic - [`SequentialExecutor`][amltk.scheduling.SequentialExecutor] to run everything - in a sequential manner! - - === "Easy" - - ```python - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_sequential() - ``` - - === "Explicit" - - ```python - from amltk.scheduling import Scheduler, SequetialExecutor - - scheduler = Scheduler(executor=SequentialExecutor()) - ``` - - !!! warning "Recursion" - - If you use The `SequentialExecutor`, be careful that the stack - of function calls can get quite large, quite quick. If you are - using this for debugging, keep the number of submitted tasks - from callbacks small and focus in on debugging. If using this - for sequential ordering of operations, prefer to use - `with_processes(1)` as this will still maintain order but not - have these stack issues. - - ### Running the Scheduler You may have noticed from the above example that there are many events the shceduler will emit, @@ -419,6 +166,13 @@ scheduler.run() from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide ``` +!!! tip "Determinism" + + It's worth noting that even though we are using an event based system, we + are still guaranteed deterministic execution of the callbacks for any given + event. The source of indeterminism is the order in which events are emitted, + this is determined entirely by your compute functions themselves. + ### Submitting Compute The `Scheduler` exposes a simple [`submit()`][amltk.scheduling.Scheduler.submit] method which allows you to submit compute to be performed **while the scheduler is running**. @@ -462,7 +216,9 @@ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fon its result/exception of some computation which may not have completed yet. ### Scheduler Events -We won't cover all possible scheduler events but we provide the complete list here: +Here are some of the possible `@events` a `Scheduler` can emit, but +please visit the [scheduler reference](site:reference/scheduling/scheduler.md) +for a complete list. === "`@on_start`" @@ -509,25 +265,17 @@ We won't cover all possible scheduler events but we provide the complete list he ::: amltk.scheduling.Scheduler.on_empty We can access all the counts of all events through the -[`scheduler.event_counts`][amltk.events.Emitter.event_counts] property. +[`scheduler.event_counts`][amltk.scheduling.events.Emitter.event_counts] property. This is a `dict` which has the events as keys and the amount of times it was emitted as the values. -!!! tip "Determinism" - - It's worth noting that even though we are using an event based system, we - are still guaranteed deterministic execution of the callbacks for any given - event. The source of indeterminism is the order in which events are emitted, - this is determined entirely by your compute functions themselves. - - ### Controlling Callbacks There's a few parameters you can pass to any event subscriber such as `@on_start` or `@on_future_result`. These control the behavior of what happens when its event is fired and can be used to control the flow of your system. -You can find their docs here [`Emitter.on()`][amltk.events.Emitter.on]. +These are covered more extensively in our [events reference](site:reference/scheduling/events.md). 
=== "`repeat=`" @@ -691,7 +439,8 @@ However there are more explicit methods. # The will endlessly loop the scheduler @scheduler.on_future_done def submit_again(future: Future) -> None: - scheduler.submit(expensive_function) + if scheduler.running(): + scheduler.submit(expensive_function) scheduler.run(timeout=1) # End after 1 second from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide @@ -730,7 +479,7 @@ def submit_calculations() -> None: @scheduler.on_future_exception def stop_the_scheduler(future: Future, exception: Exception) -> None: - print("Got exception {exception}") + print(f"Got exception {exception}") scheduler.stop() # You can optionally pass `exception=` for logging purposes. scheduler.run(on_exception="ignore") # Scheduler will not stop because of the error @@ -740,7 +489,7 @@ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fon The second kind of exception that can happen is one that happens in the main process. For example this could happen in one of your callbacks or in the `Scheduler` itself (please raise an issue if this occurs!). By default when you call [`run()`][amltk.scheduling.Scheduler.run] it will set -`#!python run(on_exception="raise")` and raise the exception that occured, with its traceback. +`#!python run(on_exception="raise")` and raise the exception that occurred, with its traceback. This is to help you debug your program. You may also use `#!python run(on_exception="end")` which will just end the `Scheduler` and raise no exception, @@ -750,7 +499,7 @@ are next to process. ## Tasks Now that we have seen how the [`Scheduler`][amltk.scheduling.Scheduler] works, we can look at the [`Task`][amltk.scheduling.Task], a wrapper around a function -that you'll want to submit to the `Scheduler`. The preffered way to create one +that you'll want to submit to the `Scheduler`. The preferred way to create one of these `Tasks` is to use [`scheduler.task(function)`][amltk.scheduling.Scheduler.task]. ### Running a task @@ -772,7 +521,7 @@ scheduler = Scheduler.with_processes(1) collatz_task = scheduler.task(collatz) try: - collatz_task(5) + collatz_task.submit(5) except Exception as e: print(f"{type(e)}: {e}") ``` @@ -798,7 +547,7 @@ collatz_task = scheduler.task(collatz) @scheduler.on_start def launch_initial_task() -> None: - collatz_task(5) + collatz_task.submit(5) scheduler.run() from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide @@ -827,7 +576,7 @@ echo_task = scheduler.task(echo) # Launch the task and do a raw `submit()` with the Scheduler @scheduler.on_start def launch_initial_task() -> None: - echo_task("hello") + echo_task.submit("hello") scheduler.submit(echo, "hi") # Callback for anything result from the scheduler @@ -846,7 +595,7 @@ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fon We can see in the output of the above code that the `@scheduler.on_future_result` was called twice, meaning our callback `#!python def from_scheduler()` was called twice, -one for the result of `#!python echo_task("hello")` and the other +one for the result of `#!python echo_task.submit("hello")` and the other from `#!python scheduler.submit(echo, "hi")`. On the other hand, the event `@task.on_result` was only called once, meaning our callback `#!python def from_task()` was only called once. 
@@ -880,19 +629,19 @@ items = iter([1, 2, 3]) @scheduler.on_start def submit_initial() -> None: next_item = next(items) - task_1(next_item) + task_1.submit(next_item) @task_1.on_result def submit_task_2_with_results_of_task_1(_, result: int) -> None: """When task_1 returns, send the result to task_2""" - task_2(result) + task_2.submit(result) @task_1.on_result def submit_task_1_with_next_item(_, result: int) -> None: """When task_1 returns, launch it again with the next items""" next_item = next(items, None) if next_item is not None: - task_1(next_item) + task_1.submit(next_item) return print("Done!") @@ -909,15 +658,15 @@ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fon ## Task Plugins Another benefit of [`Task`][amltk.scheduling.Task] objects is that we can attach -a [`TaskPlugin`][amltk.scheduling.TaskPlugin] to them. These plugins can automate control +a [`Plugin`][amltk.scheduling.Plugin] to them. These plugins can automate control behaviour of tasks, either through preventing their execution, -modifying the function and its arugments or even attaching plugin specific events! +modifying the function and its arguments or even attaching plugin specific events! For a complete reference, please see the [plugin reference page](site:reference/plugins). ### Call Limiter Perhaps one of the more useful plugins, at least when designing an AutoML System is the -[`CallLimiter`][amltk.scheduling.task_plugin.CallLimiter] plugin. This can help you control +[`Limiter`][amltk.scheduling.plugins.Limiter] plugin. This can help you control both it's concurrency or the absolute limit of how many times a certain task can be successfully submitted. @@ -926,7 +675,7 @@ to submit a `Task` 4 times in rapid succession. However we have the constraint t only ever want 2 of these tasks running at a given time. Let's see how we could achieve that. ```python exec="true" source="material-block" html="True" hl_lines="9" -from amltk import Scheduler, CallLimiter +from amltk.scheduling import Scheduler, Limiter def my_func(x: int) -> int: return x @@ -935,7 +684,7 @@ from amltk._doc import make_picklable; make_picklable(my_func) # markdown-exec: scheduler = Scheduler.with_processes(2) # Specify a concurrency limit of 2 -task = scheduler.task(my_func, plugins=CallLimiter(max_concurrent=2)) +task = scheduler.task(my_func, plugins=Limiter(max_concurrent=2)) # A list of 10 things we want to compute items = iter(range(10)) @@ -944,7 +693,7 @@ results = [] @scheduler.on_start(repeat=4) # Repeat callback 4 times def submit() -> None: next_item = next(items) - task(next_item) + task.submit(next_item) @task.on_result def record_result(_, result: int) -> None: @@ -954,7 +703,7 @@ def record_result(_, result: int) -> None: def launch_another(_, result: int) -> None: next_item = next(items, None) if next_item is not None: - task(next_item) + task.submit(next_item) scheduler.run() print(results) @@ -963,7 +712,7 @@ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fon You can notice that this limiting worked, given the numbers `#!python 2` and `#!python 3` were skipped and not printed. As expected, we successfully launched the task with both -`#!python 0` and `#!python 1` but as these tasks were not done processing, the `CallLimiter` +`#!python 0` and `#!python 1` but as these tasks were not done processing, the `Limiter` kicks in and prevents the other two. A natural extension to ask is then, "how do we requeue these?". 
Well lets take a look at the above @@ -976,7 +725,7 @@ method. Below is the same example except here we respond to `@call-limit-reached and requeue the submissions that failed. ```python exec="true" source="material-block" html="True" hl_lines="11 19-21" -from amltk import Scheduler, CallLimiter +from amltk.scheduling import Scheduler, Limiter, Task from amltk.types import Requeue def my_func(x: int) -> int: @@ -984,7 +733,7 @@ def my_func(x: int) -> int: from amltk._doc import make_picklable; make_picklable(my_func) # markdown-exec: hide scheduler = Scheduler.with_processes(2) -task = scheduler.task(my_func, plugins=CallLimiter(max_concurrent=2)) +task = scheduler.task(my_func, plugins=Limiter(max_concurrent=2)) # A list of 10 things we want to compute items = Requeue(range(10)) # A convenience type that you can requeue/append to @@ -993,7 +742,7 @@ results = [] @scheduler.on_start(repeat=4) # Repeat callback 4 times def submit() -> None: next_item = next(items) - task(next_item) + task.submit(next_item) @task.on("concurrent-limit-reached") def add_back_to_queue(task: Task, x: int) -> None: @@ -1007,40 +756,24 @@ def record_result(_, result: int) -> None: def launch_another(_, result: int) -> None: next_item = next(items, None) if next_item is not None: - task(next_item) + task.submit(next_item) scheduler.run() print(results) from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide ``` -### Creating Your Own Task Plugins -The [`CallLimiter`][amltk.scheduling.task_plugin.CallLimiter] plugin is a good example -of what you can achieve with a [`TaskPlugin`][amltk.scheduling.task_plugin.TaskPlugin] -and serves as a reference point for how you can add new events and control -task submission. - -Another good example is the [`PynisherPlugin`][amltk.pynisher.pynisher_task_plugin.PynisherPlugin] which -wraps a `Task` when it's submitted, allowing you to limit memory and wall clock time -of your compute functions, in a cross-platform manner. - -If you have any cool new plugins, we'd love to hear about them! -Please see the [plugin reference page](site:reference/plugins.md) for more. - -## Emitters and the Event System - -!!! todo "TODO" - - This section should contain a little overview of how the - [`Emitter`][amltk.events.Emitter] class works as it's the main - layer through which objects register and emit events, often - by creating a [`subscriber()`][amltk.events.Emitter.subscriber]. - - This should also briefly mention what an [`Event`][amltk.events.Event] - is. +### Under Construction -### Events -TODO + Please see the following reference pages in the meantime: -### Emitters -TODO + * [scheduler reference](site:reference/scheduling/scheduler.md) - A slighltly + more condensed version of how to use the `Scheduler`. + * [task reference](site:reference/scheduling/task.md) - A more comprehensive + explanation of `Task`s and their `@events`. + * [plugin reference](site:reference/scheduling/plugins.md) - An intro to plugins + and how to create your own. + * [executors reference](site:reference/scheduling/executors.md) - A list of + executors and how to use them. + * [events reference](site:reference/scheduling/events.md) - A more comprehensive + look at the event system in AutoML-Toolkit and how to work with them or extend them. diff --git a/docs/index.md b/docs/index.md index 27c3bc37..e3d2af21 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,12 +4,10 @@ Welcome to the AutoML-Toolkit framework docs. 
See the navigation links in the header or side-bars. Click the :octicons-three-bars-16: button (top left) on mobile. -Check out the bottom of this page for a [quick start](#quick-start), -or for a more thorough understanding of -all that AutoML-Toolkit has to offer, check out our [guides](guides/index.md). - -You can also check out [examples](examples/index.md) for copy-pastable -snippets to start from. +For a quick-start, check out [examples](site:examples/index.md) for copy-pastable +snippets to start from. For a more guided tour through what AutoML-Toolkit can offer, please check +out our [guides](site:guides/index.md). If you've used AutoML-Toolkit before but need some refreshers, you can look +through our [reference pages](site:reference/index.md) or the [API docs](site:api). ## What is AutoML-Toolkit? @@ -22,14 +20,14 @@ allowing you to define, search and build machine learning systems. Use the programming language that defines modern machine learning research. We use [mypy](https://mypy.readthedocs.io/en/stable/) internally and for external - API so you can identifiy and fix errors before a single line of code runs. + API so you can identify and fix errors before a single line of code runs. --- - :octicons-package-dependents-16: __Minimal Dependencies__ AutoML-Toolkit was designed to not introduce dependencies on your code. - We support some [integrations](reference/index.md) but only if they are optionally installed!. + We support some tool integrations but only if they are optionally installed!. --- @@ -37,11 +35,14 @@ allowing you to define, search and build machine learning systems. We can't support all frameworks, and thankfully we don't have to. AutoML-Toolkit was designed to be plug-and-play. Integrate in your own - [optimizers](reference/index.md#optimizers), - [search spaces](reference/index.md#search-spaces), - [backends](reference/index.md#scheduler-executors), - [builders](reference/index.md#pipeline-builders) - and more. All of our [reference](reference/index.md) are built using this same API. + [optimizers](site:reference/optimization/optimizers.md), + [search spaces](site:reference/pipelines/spaces.md), + [execution backends](site:reference/scheduling/executors.md), + [builders](site:reference/pipelines/builders.md) + and more. + + We've worked hard to make sure that how we integrate tools can be done for + your own tools we don't cover. --- @@ -50,6 +51,7 @@ allowing you to define, search and build machine learning systems. AutoML-Toolkit is event driven, meaning you write code that reacts to events as they happen. You can ignore, extend and create new events that have meaning to the systems you build. + This enables tools built from AutoML-Toolkit to support greater forms of interaction, automation and deployment. @@ -58,7 +60,7 @@ allowing you to define, search and build machine learning systems. - :material-directions-fork: __Task Agnostic__ AutoML-Toolkit is task agnostic, meaning you can use it for any machine learning task. - We provide a base [Task](guides/scheduling.md) which you can extend with + We provide a base [Task](site:guides/scheduling.md) which you can extend with events and functionality specific to the tasks you care about. --- @@ -67,170 +69,4 @@ allowing you to define, search and build machine learning systems. AutoML-Toolkit is a community driven project, and we want to hear from you. We are always looking for new contributors, so if you have an idea or want to - contribute, please [get in touch](contributing.md). 
- ---- - -## Quick Start -What you can use it for depends on what you want to do. - -=== "Create Machine Learning Pipelines" - - We provide a __declarative__ way to define entire machine learning pipelines and any - hyperparameters that go along with it. Rapidly experiment with different setups, - get their search [`space()`][amltk.Pipeline.space], get concrete configurations with a quick - [`configure()`][amltk.Pipeline.configure] - and finally [`build()`][amltk.Pipeline.build] out a real - [sklearn.pipeline.Pipeline][], [torch.nn.Sequential][] or - your own custom pipeline objects. - - Here's a brief example of how you can use AutoML-Toolkit to define a pipeline, - with its hyperparameters, sample from that space and build out a sklearn pipeline - with minimal amounts of code. For a more in-depth look at pipelines and its features, - check out the [Pipelines guide](./guides/pipelines.md) documentation. - - ```python - from amltk.pipeline import Pipeline, step, split, choice - - from sklearn.preprocessing import StandardScaler, MinMaxScaler - from sklearn.svm import SVC - from sklearn.ensemble import RandomForestClassifier - - pipeline = Pipeline.create( - choice( - "scaler", # (1)! - step("standard", StandardScaler), - step("minmax", MinMaxScaler)) - ), - choice( - "algorithm", - step( - "rf", - RandomForestClassifier, - space={"n_estimators": [10, 100] } # (2)! - ), - step( - "svm", - SVC, - space={"C": (0.1, 10.0), "kernel": ["linear", "rbf"]}, - config={"kernel": "rbf"} # (3)! - ), - ), - ) - - space = pipeline.space() # (4)! - config = pipeline.sample(space) # (6)! - configured_pipeline = pipeline.configure(config) # (7)! - sklearn_pipeline = pipeline.build() # (5)! - ``` - - 1. Define choices between steps in your pipeline, `amltk` will figure out how to encode this choice into - the search space. - 2. Decide what hyperparameters to search for for your steps. - 3. Want to quickly set something constant? Use the `config` argument to set a value and remove it from the space - automatically. - 4. Parse out the search space for the pipeline, let `amltk` figure it out - or choose your own [`parse(parser=...)`](reference/index.md) - 5. Let `amltk` figure out what kind of pipeline you want, but you can also - specify your own [`build(builder=...)`](reference/index.md) - 6. Sample a configuration from the search space. - 7. Configure the pipeline with a configuration. - -=== "Optimize Machine Learning Pipelines" - - AutoML-Toolkit integrates with a variety of optimization frameworks, allowing you to - quickly optimize your machine learning pipelines with your favourite optimizer. - We leave the optimization flow, the target function, when to stop and even what you want - the tell the optimizer, completely up to you. - - We do however provide all the tools necessary to express exactly what you want - to have happen. - - Below is a short showcase of the many ways you can define how you want to - optimize and control the optimization process. For a more in-depth look at the full set - of features, follow the [Optimization](./guides/optimization.md) documentation. - - ```python - from amltk.pipeline import Pipeline - from amltk.optimization import Trial - from amltk.scheduling import Scheduler - from amltk.smac import SMACOptimizer - - def evaluate(trial: Trial, pipeline: Pipeline) -> Trial.Report: - model = pipeline.configure(trial.config).build() - - with trial.begin(): # (1)! - # Train and evaluate the model - - if not trial.exception: - return trial.success(cost=...) # (2)! 
- - return trial.fail() - - my_pipeline = Pipeline.create(...) - - optimizer = SMACOptimizer.create(pipeline.space(), seed=42) # (4)! - - n_workers = 8 - scheduler = Scheduler.with_processes(n_workers) # (3)! - task = scheduler.task(evaluate) - - @scheduler.on_start(repeat=n_workers) # (6)! - def start_optimizing(): - trial = optimizer.ask() - task(trial=trial, pipeline=my_pipeline) # (5)! - - @task.on_done - def start_another_trial(_): - trial = optimizer.ask() - task(trial=trial, pipeline=my_pipeline) - - @task.on_result - def tell_optimizer(report: Trial.Report): - optimizer.tell(report) - - @task.on_result - def store_result(report: Trial.Report): - ... # (8)! - - @task.on_exception - @task.on_cancelled - def stop_optimizing(exception): - print(exception) - scheduler.stop() # (9)! - - scheduler.run(timeout=60) # (10)! - ``` - - 1. We take care of the busy work, just let us know when the trial starts. - 2. We automatically fill in the reports for the optimizer, just let us - know the cost and any other additional info. - 3. Create a scheduler with your own custom backend. We provide a few out of the box, - but you can also [integrate your own](site:guides/scheduling.md). - 4. Create an optimizer over your search space, - we provide a few optimizers of the box, but you can also [integrate your own](site:guides/optimization.md#integrating-your-own-optimizer). - 5. Calling the task runs it in a worker, whether it be a process, cluster node, AWS or - whatever backend you decide to use. - 6. Say _what_ happens and _when_, when the scheduler says it's started, this function - gets called `n_workers` times. - 7. Inform the optimizer of the report ... if you want. - 8. We don't know what data you want and where, that's up to you. - 9. Stop the whole scheduler whenever you like under whatever conditions make sense to you. - 10. And let the system run! - - You can wrap this in a class, create more complicated control flows and even utilize - some more of the functionality of a [`Task`][amltk.Task] to do - much more. We don't tell you how the control flow should or where data goes, this gives - you as much flexibility as you need to get your research done. - - -=== "Build Machine Learning Tools" - - AutoML-Toolkit is a set of tools that are for the purpose of building an AutoML system, - it is not an AutoML system itself. With the variety of AutoML systems out there, we - decided to build this framework as an event driven system. The cool part is, you can - define your own events, your own tasks and how the scheduler should operate. - - !!! info "TODO" - - Come up with a nice example of defining your own task and events + contribute, please get in touch! diff --git a/docs/reference/call_limiter.md b/docs/reference/call_limiter.md deleted file mode 100644 index afb93da5..00000000 --- a/docs/reference/call_limiter.md +++ /dev/null @@ -1,33 +0,0 @@ -# Call Limiter -We can limit the number of times a function is called or how many concurrent -instances of it can be running. To do so, we create the [`CallLimiter`][amltk.CallLimiter] -and pass it in as a plugin. This plugin also introduces some new events that can be listened to. 
- -```python -from amltk.scheduling import CallLimiter - -task = scheduler.task(..., plugins=CallLimiter(max_calls=10, max_concurrent=2)) - -@task.on("call-limit-reached") -def print_it(task: Task, *args, **kwargs) -> None: - print(f"Task {task.name} was already called {task.n_called} times") - -@task.on("concurrent-limit-reached") -def print_it(task: Task, *args, **kwargs) -> None: - print(f"Task {task.name} already running at max concurrency") -``` - -You can also prevent a task launching while another task is currently running: - -```python -task1 = scheduler.task(...) - -task2 = scheduler.task(..., plugins=CallLimiter(not_while_running=task1)) - -@task2.on("disabled-due-to-running-task") -def on_disabled_due_to_running_task(other_task: Task, task: Task, *args, **kwargs): - print( - f"Task {task.name} was not submitted because {other_task.name} is currently" - " running" - ) -``` diff --git a/docs/reference/comms.md b/docs/reference/comms.md deleted file mode 100644 index af860928..00000000 --- a/docs/reference/comms.md +++ /dev/null @@ -1,100 +0,0 @@ -## Comm Plugin -!!! todo "Update these docs" - - Sorry - -The [`Comm`][amltk.scheduling.Comm] is the access point to to enable -two way communication between a worker and the server. - -??? warning "Local Processes Only" - - We currently use [`multiprocessing.Pipe`][multiprocessing.Pipe] to communicate - between the worker and the scheduler. This means we are limited to local processes - only. - - If there is interest, we could extend this to be interfaced and provide web socket - communication as well. Please open an issue if you are interested in this or if you - would like to contribute. - -### Usage of Comm Plugin -A [`Comm`][amltk.Comm] facilitate the communication between the worker and the scheduler. -By using this `Comm`, we can [`send()`][amltk.Comm.send] and -[`request()`][amltk.Comm.request] messages from the workers point of view. -These messages are then received by the scheduler and emitted as the -[`MESSAGE`][amltk.Comm.MESSAGE] and [`REQUEST`][amltk.Comm.REQUEST] -events respectively which both pass a [`Comm.Msg`][amltk.Comm.Msg] object -to the callback. This object contains the `data` that was transmitted. - -Below we show an example of both `send()` and -`request()` in action and how to use the plugin. - -=== "`send()`" - - ```python hl_lines="7 9 12 16 17 18 19 20 21" - from amltk import Scheduler, Comm - - # The function must accept an optional `Comm` keyword argument - def echoer(xs: list[int], comm: Comm | None = None): - assert comm is not None - - with comm: # (1)! - for x in xs: - comm.send(x) # (2)! - - scheduler = Scheduler.with_processes(1) - task = scheduler.task(echoer, plugin=Comm.Plugin()) - - @scheduler.on_start - def start(): - task.submit([1, 2, 3, 4, 5]) - - @task.on(Comm.MESSAGE) - def on_message(msg: Comm.Msg): # (3)! - print(f"Recieved a message {msg=}") - print(msg.data) - - scheduler.run() - ``` - - 1. The `Comm` object should be used as a context manager. This is to ensure - that the `Comm` object is closed correctly when the function exits. - 2. Here we use the [`send()`][amltk.Comm.send] method to send a message - to the scheduler. - 3. We can also do `#!python Comm.Msg[int]` to specify the type of data - we expect to receive. 
- -=== "`request()`" - - ```python hl_lines="7 16 17 18 19" - from amltk.scheduling import Scheduler, Comm - - # The function must accept an optional `Comm` keyword argument - def requester(xs: list[int], comm: Comm | None = None): - with comm: - for _ in range(n): - response = comm.request(n) # (1)! - - scheduler = Scheduler(...) - task = scheduler.task(requester, plugin=Comm.Plugin()) - - @scheduler.on_start - def start(): - task.submit([1, 2, 3, 4, 5]) - - @task.on_request - def handle_request(msg: Comm.Msg): - print(f"Recieved request {msg=}") - msg.respond(msg.data * 2) # (2)! - - scheduler.run() - ``` - - 1. Here we use the [`request()`][amltk.Comm.request] method to send a request - to the scheduler with some data. - 2. We can use the [`respond()`][amltk.Comm.Msg.respond] method to - respond to the request with some data. - -!!! tip "Identifying Workers" - - The [`Comm.Msg`][amltk.Comm.Msg] object also has the `identifier` - attribute, which is a unique identifier for the worker. diff --git a/docs/reference/configspace.md b/docs/reference/configspace.md deleted file mode 100644 index 19c83736..00000000 --- a/docs/reference/configspace.md +++ /dev/null @@ -1,213 +0,0 @@ -# ConfigSpace -[ConfigSpace](https://automl.github.io/ConfigSpace/master/) is a library for -representing and sampling configurations for hyperparameter optimization. -It features a straightforward API for defining hyperparameters, their ranges -and even conditional dependencies. - -It is generally flexible enough for more complex use cases, even -handling the complex pipelines of [AutoSklearn](https://automl.github.io/auto-sklearn/master/). -and [AutoPyTorch](https://automl.github.io/Auto-PyTorch/master/), large -scale hyperparameter spaces over which to optimize entire -pipelines at a time. - -We integrate [ConfigSpace](https://automl.github.io/ConfigSpace/master/) with -AutoML-Toolkit by allowing you to parse out entire spaces -from a [Pipeline][amltk.pipeline.Pipeline] and sample from -these spaces. - -Check out the [API doc][amltk.configspace.ConfigSpaceAdapter] for more info. - -!!! note "Space Adapter Interface" - - This integration is provided by implementing the - [SpaceAdapater][amltk.pipeline.SpaceAdapter] interface. - Check out its documentation for implementing your own. - - -## Parser -In general, you should consult the -[ConfigSpace documentation](https://automl.github.io/ConfigSpace/master/). -Anything you can insert into a `ConfigurationSpace` object is valid. - -Here's an example of a simple space using pure python objects. - -```python exec="true" source="material-block" result="python" title="A simple space" -from amltk.configspace import ConfigSpaceAdapter - -search_space = { - "a": (1, 10), - "b": (0.5, 9.0), - "c": ["apple", "banana", "carrot"], -} - -adapter = ConfigSpaceAdapter() -space = adapter.parse(search_space) -print(space) -``` - -You can specify more complex spaces using the `Integer`, `Float` and -`Categorical` functions from ConfigSpace. 
- -```python exec="true" source="material-block" result="python" title="A more complicated space" -from ConfigSpace import Integer, Float, Categorical, Normal -from amltk.configspace import ConfigSpaceAdapter - -search_space = { - "a": Integer("a", bounds=(1, 1000), log=True), - "b": Float("b", bounds=(2.0, 3.0), distribution=Normal(2.5, 0.1)), - "c": Categorical("c", ["small", "medium", "large"], ordered=True), -} - -adapter = ConfigSpaceAdapter() -space = adapter.parse(search_space) -print(space) -``` - -Lastly, this [`parse()`][amltk.pipeline.Parser.parse] method is also -able to parse more complicated objects, such as a [`Step`][amltk.pipeline.Step] -or even entire [`Pipelines`][amltk.pipeline.Pipeline]. - -```python exec="true" source="material-block" result="python" title="Parsing Steps" -from amltk.configspace import ConfigSpaceAdapter -from amltk.pipeline import step - -my_step = step( - "mystep", - item=object(), - space={"a": (1, 10), "b": (2.0, 3.0), "c": ["cat", "dog"]} -) - -adapter = ConfigSpaceAdapter() -space = adapter.parse(my_step) - -print(space) -``` - -```python exec="true" source="material-block" result="python" title="Parsing a Pipeline" -from ConfigSpace import Float - -from amltk.configspace import ConfigSpaceAdapter -from amltk.pipeline import step, choice, Pipeline - -my_pipeline = Pipeline.create( - choice( - "algorithm", - step("A", item=object(), space={"C": (0.0, 1.0), "initial": (1, 10)}), - step("B", item=object(), space={"lr": Float("lr", (1e-5, 1), log=True)}), - ) -) - -adapter = ConfigSpaceAdapter() -space = adapter.parse(my_pipeline) - -print(space) -``` - -## Sampler -As [ConfigSpaceAdapter][amltk.configspace.ConfigSpaceAdapter] implements the -[Sampler][amltk.pipeline.Sampler] interface, you can also [`sample()`][amltk.pipeline.Sampler.sample] -from these spaces. - -```python exec="true" source="material-block" result="python" title="Sampling from a space" -from amltk.configspace import ConfigSpaceAdapter - -search_space = { - "a": (1, 10), - "b": (0.5, 9.0), - "c": ["apple", "banana", "carrot"], -} - -adapter = ConfigSpaceAdapter() -space = adapter.parse(search_space) -sample = adapter.sample(space) - -print(sample) -``` - -### For use with Step, Pipeline -The [`Pipeline`][amltk.pipeline.Pipeline] and [`Step`][amltk.pipeline.Step] objects -have a [`space()`][amltk.pipeline.Pipeline.space] and -[`sample()`][amltk.pipeline.Pipeline.sample] method. -These accept a [`Parser`][amltk.pipeline.Parser] and a [`Sampler`][amltk.pipeline.Sampler] -interface, for which [`ConfigSpaceAdapter`][amltk.configspace.ConfigSpaceAdapter] -supports poth of these interfaces. 
- -```python exec="true" source="material-block" result="python" title="Using ConfigSpace with a Step" -from amltk.configspace import ConfigSpaceAdapter -from amltk.pipeline import step - -my_step = step( - "mystep", - item=object(), - space={"a": (1, 10), "b": (2.0, 3.0), "c": ["cat", "dog"]} -) - -space = my_step.space(parser=ConfigSpaceAdapter) -print(space) - -sample = my_step.sample(sampler=ConfigSpaceAdapter) -print(sample) -``` - -```python exec="true" source="material-block" result="python" title="Using ConfigSpace with a Pipeline" -from ConfigSpace import Float - -from amltk.configspace import ConfigSpaceAdapter -from amltk.pipeline import step, choice, Pipeline - -my_pipeline = Pipeline.create( - choice( - "algorithm", - step("A", item=object(), space={"C": (0.0, 1.0), "initial": [1, 10]}), - step("B", item=object(), space={"lr": Float("lr", (1e-5, 1), log=True)}), - ) -) - -space = my_pipeline.space(parser=ConfigSpaceAdapter) -print(space) - -sample = my_pipeline.sample(sampler=ConfigSpaceAdapter) -print(sample) -``` - -### For use with RandomSearch -The [`RandomSearch`][amltk.optimization.RandomSearch] object accepts a -[`Sampler`][amltk.pipeline.Sampler] interface, for which -[`ConfigSpaceAdapter`][amltk.configspace.ConfigSpaceAdapter] supports. - -```python exec="true" source="material-block" result="python" title="Using ConfigSpace with RandomSearch" -from ConfigSpace import Float - -from amltk.configspace import ConfigSpaceAdapter -from amltk.pipeline import step, choice, Pipeline -from amltk.optimization import RandomSearch - -my_pipeline = Pipeline.create( - choice( - "algorithm", - step("A", item=object(), space={"C": (0.0, 1.0), "initial": [1, 10]}), - step("B", item=object(), space={"lr": Float("lr", (1e-5, 1), log=True)}), - ) -) -space = my_pipeline.space(parser=ConfigSpaceAdapter) - -random_search_optimizer = RandomSearch( - space=space, - sampler=ConfigSpaceAdapter, - seed=10 -) - -for i in range(3): - trial = random_search_optimizer.ask() - print(trial) - - with trial.begin(): - # Run experiment here - pass - - report = trial.success(cost=1) - print(report) - - random_search_optimizer.tell(report) -``` - diff --git a/docs/reference/dask-jobqueue.md b/docs/reference/dask-jobqueue.md deleted file mode 100644 index f7a84227..00000000 --- a/docs/reference/dask-jobqueue.md +++ /dev/null @@ -1,90 +0,0 @@ -# Dask JobQueue -[`dask-jobqueue`](https://jobqueue.dask.org/en/latest/) is a package for -scheduling jobs across common clusters setups such as PBS, Slurm, MOAB, -SGE, LSF, and HTCondor. - -You can access most of these directly through the _factory_ methods -of the [`Scheduler`][amltk.Scheduler], forwarding on arguments to them. - -!!! note "Factory Methods" - - * [`Scheduler.with_pbs()`][amltk.scheduling.Scheduler.with_pbs] - * [`Scheduler.with_lsf()`][amltk.scheduling.Scheduler.with_lsf] - * [`Scheduler.with_moab()`][amltk.scheduling.Scheduler.with_moab] - * [`Scheduler.with_sge()`][amltk.scheduling.Scheduler.with_sge] - * [`Scheduler.with_htcondor()`][amltk.scheduling.Scheduler.with_htcondor] - -Please see the `dask-jobqueue` [documentation](https://jobqueue.dask.org/en/latest/) -In particular, we only control the parameter `#!python n_workers=` to -use the [`adapt()`](https://jobqueue.dask.org/en/latest/index.html?highlight=adapt#adaptivity) -method, every other keyword is forwarded to the relative -[cluster implementation](https://jobqueue.dask.org/en/latest/api.html). 
- -In general, you should specify the requirements of each individual worker and -and tune your load with the `#!python n_workers=` parameter. - -If you have any tips, tricks, working setups, gotchas, please feel free -to leave a PR or simply an issue! - -=== "Slurm" - - ```python hl_lines="3 4 5 6 7 8 9" - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_slurm( - n_workers=10, # (1)! - queue=..., - cores=4, - memory="6 GB", - walltime="00:10:00" - ) - ``` - - 1. The `n_workers` parameter is used to set the number of workers - to start with. - The [`adapt()`](https://jobqueue.dask.org/en/latest/index.html?highlight=adapt#adaptivity) - method will be called on the cluster to dynamically scale up to `#!python n_workers=` based on - the load. - The `with_slurm` method will create a [`SLURMCluster`][dask_jobqueue.SLURMCluster] - and pass it to the `Scheduler` constructor. - ```python hl_lines="10" - from dask_jobqueue import SLURMCluster - from amltk.scheduling import Scheduler - - cluster = SLURMCluster( - queue=..., - cores=4, - memory="6 GB", - walltime="00:10:00" - ) - cluster.adapt(max_workers=10) - executor = cluster.get_client().get_executor() - scheduler = Scheduler(executor=executor) - ``` - - !!! warning "Running inside a job" - - Some cluster setups do not allow jobs to launch jobs themselves. - The scheduler itself is lightweight and can run on the - login node without issue. However you should make sure to offload - heavy computations to a worker. - - If you get it to work, for example in an interactive job, please - let us know! - - !!! info "Modifying the launch command" - - On some cluster commands, you'll need to modify the launch command. - You can use the following to do so: - - ```python - from amltk.scheduling import Scheduler - - scheduler = Scheduler.with_slurm(n_workers=..., submit_command="sbatch --extra" - ``` - -=== "Others" - - Please see the `dask-jobqueue` [documentation](https://jobqueue.dask.org/en/latest/) - and the following methods: - diff --git a/docs/reference/buckets.md b/docs/reference/data/buckets.md similarity index 98% rename from docs/reference/buckets.md rename to docs/reference/data/buckets.md index ea83bb17..aa523b61 100644 --- a/docs/reference/buckets.md +++ b/docs/reference/data/buckets.md @@ -14,7 +14,7 @@ import numpy as np import pandas as pd from sklearn.linear_model import LinearRegression -bucket = PathBucket("path/to/bucket") +bucket = PathBucket("./path/to/bucket") array = np.array([1, 2, 3]) dataframe = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) diff --git a/docs/reference/data.md b/docs/reference/data/index.md similarity index 100% rename from docs/reference/data.md rename to docs/reference/data/index.md diff --git a/docs/reference/index.md b/docs/reference/index.md index 7f10c8fa..6c75bfd3 100644 --- a/docs/reference/index.md +++ b/docs/reference/index.md @@ -1,52 +1,11 @@ -## Search Spaces -* [ConfigSpace](./configspace.md) - A serializable search space definition which - supports choices and search space constraints. A great default go to! -* [Optuna](./optuna.md) - A search space for using Optuna as your optimizer. Only - done as a static definition and currently does not support Optuna's define-by-run. +# Reference +Here you'll find a non-exhaustive but quick reference to +many of the core types and utilities available to AMLTK. +Please use the Table of Contents on the left to browse them. 
-## Optimizers +If you're looking for a more in depth understanding of how +automl-toolkit works, please take a look at the +[guides section](site:guides/index.md). -* [SMAC](./smac.md) - A powerful Bayesian-Optimization framework, primarly based on a custom - Random Forest, supporting complex conditionals in a bayesian manner. -* [Optuna](./optuna.md) - A highly flexible Optimization framework based on Tree-Parzan - Estimators. -* [NEPS](./neps.md) - An optimizer focused on optimizing neural architectures, allowing for - continuatios and graph based search spaces - -## Pipeline Builders - -* [sklearn](./sklearn.md) - Export your pipelines to a pure [sklearn.pipeline.Pipeline][] - and some utility to ease data splitting. - -## Prebuilt Pipelines - -* [Prebuilt pipelines][./prebuilt_pipelines.md] - A collection of pre-built pipelines - - * XGBoost - -## Scheduler Executors - -* [DaskJobQueue](./dask-jobqueue.md) - A set of [`Executors`][concurrent.futures.Executor] - usable with the [`Scheduler`][amltk.Scheduler] for different cluster setups. - -## Plugins - -* [CallLimiter](./call_limiter.md) - A simple plugin to limit how many times your task - can be called, how many concurrent instances of it can be run and prevent a task being - submitted while another task is running. -* [pynisher](./pynisher.md) - A plugin to limit the maximum time or memory a task can - use, highly suitable for creating AutoML systems. -* [wandb](./wandb.md) - A plugin that automatically logs your runs to - [weights and biases](https://wandb.ai/site)! -* [threadpoolctl](./threadpoolctl.md) - A plugin that uses -[`threadpoolctl`](https://github.com/joblib/threadpoolctl) to limit the number of threads used -by certain numerical libraries within a tasks execution. - -## Utility - -* [Buckets](./buckets.md) - A nice utility to view the file system in a dictionary like - fashion, enabling quick and easy storing of many file types at once. -* [History](./history.md) - A datastructure to house the results of an optimization run and - pull out information after. -* [data](./data.md) - Utilities for working with data containers like numpy arrays, pandas - dataframes and series. +If you're looking for signatures and specific function +documentation, check out the API docs. diff --git a/docs/reference/metalearning.md b/docs/reference/metalearning.md deleted file mode 100644 index 0ee7e418..00000000 --- a/docs/reference/metalearning.md +++ /dev/null @@ -1,341 +0,0 @@ -# Metalearning -An important part of AutoML systems is to perform well on new unseen data. -There are a variety of methods to do so but we provide some building blocks -to help implement these methods. - -## MetaFeatures -Calculating meta-features of a dataset is quite straight foward. - -```python exec="true" source="material-block" result="python" title="Metafeatures" hl_lines="10" -import openml -from amltk.metalearning import compute_metafeatures - -dataset = openml.datasets.get_dataset( - 31, # credit-g - download_data=True, - download_features_meta_data=False, - download_qualities=False, -) -X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, -) - -mfs = compute_metafeatures(X, y) - -print(mfs) -``` - -By default [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures] will -calculate all the [`MetaFeature`][amltk.metalearning.MetaFeature] implemented, -iterating through their subclasses to do so. You can pass an explicit list -as well to `compute_metafeatures(X, y, features=[...])`. 
- -To implement your own is also quite straight forward: - -```python exec="true" source="material-block" result="python" title="Create Metafeature" hl_lines="10 11 12 13 14 15 16 17 18 19" -from amltk.metalearning import MetaFeature, compute_metafeatures -import openml - -dataset = openml.datasets.get_dataset( - 31, # credit-g - download_data=True, - download_features_meta_data=False, - download_qualities=False, -) -X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, -) - -class TotalValues(MetaFeature): - - @classmethod - def compute( - cls, - x: pd.DataFrame, - y: pd.Series | pd.DataFrame, - dependancy_values: dict, - ) -> int: - return int(x.shape[0] * x.shape[1]) - -mfs = compute_metafeatures(X, y, features=[TotalValues]) -print(mfs) -``` - -As many metafeatures rely on pre-computed dataset statistics, and they do not -need to be calculated more than once, you can specify the dependancies of -a meta feature. When a metafeature would return something other than a single -value, i.e. a `dict` or a `pd.DataFrame`, we instead call those a -[`DatasetStatistic`][amltk.metalearning.DatasetStatistic]. These will -**not** be included in the result of [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures]. -These `DatasetStatistic`s will only be calculated once on a call to `compute_metafeatures()` so -they can be re-used across all `MetaFeature`s that require that dependancy. - -```python exec="true" source="material-block" result="python" title="Metafeature Dependancy" hl_lines="10 11 12 13 14 15 16 17 18 19 20 23 26 35" -from amltk.metalearning import MetaFeature, DatasetStatistic, compute_metafeatures -import openml - -dataset = openml.datasets.get_dataset( - 31, # credit-g - download_data=True, - download_features_meta_data=False, - download_qualities=False, -) -X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, -) - -class NAValues(DatasetStatistic): - """A mask of all NA values in a dataset""" - - @classmethod - def compute( - cls, - x: pd.DataFrame, - y: pd.Series | pd.DataFrame, - dependancy_values: dict, - ) -> pd.DataFrame: - return x.isna() - - -class PercentageNA(MetaFeature): - """The percentage of values missing""" - - dependencies = (NAValues,) - - @classmethod - def compute( - cls, - x: pd.DataFrame, - y: pd.Series | pd.DataFrame, - dependancy_values: dict, - ) -> int: - na_values = dependancy_values[NAValues] - n_na = na_values.sum().sum() - n_values = int(x.shape[0] * x.shape[1]) - return float(n_na / n_values) - -mfs = compute_metafeatures(X, y, features=[PercentageNA]) -print(mfs) -``` - -To view the description of a particular `MetaFeature`, you can call -[`.description()`][amltk.metalearning.DatasetStatistic.description] -on it. Otherwise you can access all of them in the following way: - -```python exec="true" source="tabbed-left" result="python" title="Metafeature Descriptions" hl_lines="4" -from pprint import pprint -from amltk.metalearning import metafeature_descriptions - -descriptions = metafeature_descriptions() -for name, description in descriptions.items(): - print("---") - print(name) - print("---") - print(" * " + description) -``` - -## Dataset Distances -One common way to define how similar two datasets are is to compute some "similarity" -between them. This notion of "similarity" requires computing some features of a dataset -(**metafeatures**) first, such that we can numerically compute some distance function. 
- -Let's see how we can quickly compute the distance between some datasets with -[`dataset_distance()`][amltk.metalearning.dataset_distance]! - -```python exec="true" source="material-block" result="python" title="Dataset Distances P.1" session='dd' -import pandas as pd -import openml - -from amltk.metalearning import compute_metafeatures - -def get_dataset(dataset_id: int) -> tuple[pd.DataFrame, pd.Series]: - dataset = openml.datasets.get_dataset( - dataset_id, - download_data=True, - download_features_meta_data=False, - download_qualities=False, - ) - X, y, _, _ = dataset.get_data( - dataset_format="dataframe", - target=dataset.default_target_attribute, - ) - return X, y - -d31 = get_dataset(31) -d3 = get_dataset(3) -d4 = get_dataset(4) - -metafeatures_dict = { - "dataset_31": compute_metafeatures(*d31), - "dataset_3": compute_metafeatures(*d3), - "dataset_4": compute_metafeatures(*d4), -} - -metafeatures = pd.DataFrame(metafeatures_dict) -print(metafeatures) -``` - -Now we want to know which one of `#!python "dataset_3"` or `#!python "dataset_4"` is -more _similar_ to `#!python "dataset_31"`. - -```python exec="true" source="material-block" result="python" title="Dataset Distances P.2" session='dd' -from amltk.metalearning import dataset_distance - -target = metafeatures_dict.pop("dataset_31") -others = metafeatures_dict - -distances = dataset_distance(target, others, distance_metric="l2") -print(distances) -``` - -Seems like `#!python "dataset_3"` is some notion of closer to `#!python "dataset_31"` -than `#!python "dataset_4"`. However the scale of the metafeatures are not exactly all close. -For example, many lie between `#!python (0, 1)` but some like `instance_count` can completely -dominate the show. - -Lets repeat the computation but specify that we should apply a `#!python "minmax"` scaling -across the rows. - -```python exec="true" source="material-block" result="python" title="Dataset Distances P.3" session='dd' hl_lines="5" -distances = dataset_distance( - target, - others, - distance_metric="l2", - scaler="minmax" -) -print(distances) -``` - -Now `#!python "dataset_3"` is considered more similar but the difference between the two is a lot less -dramatic. In general, applying some scaling to values of different scales is required for metalearning. - -You can also use an [sklearn.preprocessing.MinMaxScaler][] or anything other scaler from scikit-learn -for that matter. - -```python exec="true" source="material-block" result="python" title="Dataset Distances P.3" session='dd' hl_lines="7" -from sklearn.preprocessing import MinMaxScaler - -distances = dataset_distance( - target, - others, - distance_metric="l2", - scaler=MinMaxScaler() -) -print(distances) -``` - -## Portfolio Selection -Another common trick in meta-learning is to define a portfolio of configurations that maximize some -notion of converage across those datasets. The intution here is that this also means that any -new dataset is also covered! - -Suppose we hade the given performances of some configurations across some datasets. 
-```python exec="true" source="material-block" result="python" title="Initial Portfolio" -import pandas as pd - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) -print(portfolio) -``` - -If we could only choose `#!python k=3` of these configurations on some new given dataset, which ones would -you choose and in what priority? -Here is where we can apply [`portfolio_selection()`][amltk.metalearning.portfolio_selection]! - -The idea is that we pick a subset of these algorithms that maximise some value of utility for -the portfolio. We do this by adding a single configuration from the entire set, 1-by-1 until -we reach `k`, beggining with the empty portfolio. - -Let's see this in action! - -```python exec="true" source="material-block" result="python" title="Portfolio Selection" hl_lines="12 13 14 15 16" -import pandas as pd -from amltk.metalearning import portfolio_selection - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) - -selected_portfolio, trajectory = portfolio_selection( - portfolio, - k=3, - scaler="minmax" -) - -print(selected_portfolio) -print() -print(trajectory) -``` - -The trajectory tells us which configuration was added at each time stamp along with the utility -of the portfolio with that configuration added. However we havn't specified how _exactly_ we defined the -utility of a given portfolio. We could define our own function to do so: - -```python exec="true" source="material-block" result="python" title="Portfolio Selection Custom" hl_lines="12 13 14 20" -import pandas as pd -from amltk.metalearning import portfolio_selection - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) - -def my_function(p: pd.DataFrame) -> float: - """Take the maximum score for each dataset and then take the mean across them.""" - return p.max(axis=1).mean() - -selected_portfolio, trajectory = portfolio_selection( - portfolio, - k=3, - scaler="minmax", - portfolio_value=my_function, -) - -print(selected_portfolio) -print() -print(trajectory) -``` - -This notion of reducing across all configurations for a dataset and then aggregating these is common -enough that we can also directly just define these operations and we will perform the rest. 
- -```python exec="true" source="material-block" result="python" title="Portfolio Selection With Reduction" hl_lines="17 18" -import pandas as pd -import numpy as np -from amltk.metalearning import portfolio_selection - -performances = { - "c1": [90, 60, 20, 10], - "c2": [20, 10, 90, 20], - "c3": [10, 20, 40, 90], - "c4": [90, 10, 10, 10], -} -portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) - -selected_portfolio, trajectory = portfolio_selection( - portfolio, - k=3, - scaler="minmax", - row_reducer=np.max, # This is actually the default - aggregator=np.mean, # This is actually the default -) - -print(selected_portfolio) -print() -print(trajectory) -``` diff --git a/docs/reference/metalearning/index.md b/docs/reference/metalearning/index.md new file mode 100644 index 00000000..5592cb96 --- /dev/null +++ b/docs/reference/metalearning/index.md @@ -0,0 +1,28 @@ +# Metalearning +An important part of AutoML systems is to perform well on new unseen data. +There are a variety of methods to do so but we provide some building blocks +to help implement these methods. + +!!! warning "API" + + The meta-learning features have not been extensively used yet + and such no solid API has been developed yet. We will + deprecate any API subject to change before changing them. + +## MetaFeatures + +::: amltk.metalearning.metafeatures + options: + members: false + +## Dataset Distances + +::: amltk.metalearning.dataset_distances + options: + members: false + +## Portfolio Selection + +::: amltk.metalearning.portfolio + options: + members: false diff --git a/docs/reference/neps.md b/docs/reference/neps.md deleted file mode 100644 index 18662355..00000000 --- a/docs/reference/neps.md +++ /dev/null @@ -1,80 +0,0 @@ -# NEPS - -The below example shows how you can use neps to optimize an sklearn pipeline. - -!!! todo "Deep Learning" - - Write an example demonstrating NEPS with continuations - -!!! 
todo "Graph Search Spaces" - - Write an example demonstrating NEPS with its graph search spaces - -```python -from __future__ import annotations - -import logging - -from sklearn.datasets import load_iris -from sklearn.ensemble import RandomForestClassifier -from sklearn.metrics import accuracy_score -from sklearn.model_selection import train_test_split - -from amltk import History, Pipeline, Trial, step -from amltk.neps import NEPSOptimizer, NEPSTrialInfo -from amltk.scheduling.scheduler import Scheduler - -logging.basicConfig(level=logging.DEBUG) - - -def target_function(trial: Trial[NEPSTrialInfo], pipeline: Pipeline) -> Trial.Report: - X, y = load_iris(return_X_y=True) - X_train, X_test, y_train, y_test = train_test_split(X, y) - clf = pipeline.configure(trial.config).build() - - with trial.begin(): - clf.fit(X_train, y_train) - y_pred = clf.predict(X_test) - accuracy = accuracy_score(y_test, y_pred) - loss = 1 - accuracy - return trial.success(loss=loss, accuracy=accuracy) - - return trial.fail() - - -pipeline = Pipeline.create( - step("rf", RandomForestClassifier, space={"n_estimators": (10, 100)}), -) -optimizer = NEPSOptimizer.create(space=pipeline.space(), overwrite=True) - - -N_WORKERS = 4 -scheduler = Scheduler.with_processes(N_WORKERS) -task = scheduler.task(target_function) - -history = History() - - -@scheduler.on_start(repeat=N_WORKERS) -def on_start(): - trial = optimizer.ask() - task.submit(trial, pipeline) - - -@task.on_result -def tell_and_launch_trial(_, report: Trial.Report): - optimizer.tell(report) - trial = optimizer.ask() - task.submit(trial, pipeline) - - -@task.on_result -def add_to_history(_, report: Trial.Report): - history.add(report) - - -scheduler.run(timeout=5, wait=False) - -print(history.df()) -history.to_csv("history.csv") -``` diff --git a/docs/reference/history.md b/docs/reference/optimization/history.md similarity index 98% rename from docs/reference/history.md rename to docs/reference/optimization/history.md index 9e7cdd59..a8604fee 100644 --- a/docs/reference/history.md +++ b/docs/reference/optimization/history.md @@ -38,7 +38,7 @@ print(history.df()) Typically, to use this inside of an optimization run, you would add the reports inside of a callback from your [`Task`][amltk.Task]s. Please -see the [optimization guide](../guides/optimization.md) for more details. +see the [optimization guide](site:guides/optimization.md) for more details. ??? 
example "With an Optimizer and Scheduler"
diff --git a/docs/reference/optimization/optimizers.md b/docs/reference/optimization/optimizers.md
new file mode 100644
index 00000000..eddf4891
--- /dev/null
+++ b/docs/reference/optimization/optimizers.md
@@ -0,0 +1,27 @@
+## Optimizers
+An [`Optimizer`][amltk.optimization.Optimizer]'s goal is to maximize/minimize some objective,
+iteratively suggesting configurations in an attempt to find the optimum.
+
+## SMAC
+
+::: amltk.optimization.optimizers.smac
+    options:
+        members: false
+
+## NePs
+
+::: amltk.optimization.optimizers.neps
+    options:
+        members: false
+
+## Optuna
+
+::: amltk.optimization.optimizers.optuna
+    options:
+        members: false
+
+## Integrating your own
+
+::: amltk.optimization.optimizer
+    options:
+        members: false
diff --git a/docs/reference/optimization/profiling.md b/docs/reference/optimization/profiling.md
new file mode 100644
index 00000000..7a2b3147
--- /dev/null
+++ b/docs/reference/optimization/profiling.md
@@ -0,0 +1,5 @@
+## Profiling
+
+::: amltk.profiling.profiler
+    options:
+        members: False
diff --git a/docs/reference/optimization/trials.md b/docs/reference/optimization/trials.md
new file mode 100644
index 00000000..b71a41a0
--- /dev/null
+++ b/docs/reference/optimization/trials.md
@@ -0,0 +1,11 @@
+## Trial
+
+::: amltk.optimization.trial
+    options:
+        members: False
+
+### History
+
+::: amltk.optimization.history
+    options:
+        members: False
diff --git a/docs/reference/optuna.md b/docs/reference/optuna.md
deleted file mode 100644
index 6435223b..00000000
--- a/docs/reference/optuna.md
+++ /dev/null
@@ -1,223 +0,0 @@
-# Optuna
-[Optuna](https://optuna.org/) is an automatic hyperparameter optimization
-software framework, particularly designed for machine learning.
-
-We provide the follow integrations for Optuna:
-
-* A [`OptunaSpaceAdapter`][amltk.optuna.OptunaSpaceAdapter] for [parsing](#parsing-spaces)
-an Optuna search space and for [sampling](#sampling)
-from an Optuna search space.
-
-??? example "SpaceAdapter Interface"
-
-    This is an implementation of the
-    [`SpaceAdapter`][amltk.pipeline.space.SpaceAdapter] interface which
-    can be used for parsing or sampling anything in AutoML-Toolkit.
-
-* An [`OptunaOptimizer`][amltk.optuna.OptunaOptimizer] for optimizing
-some given function. See the [Optimizer](#Optimizer)
-
-??? example "Optimizer Interface"
-
-    This is an implementation of the [`Optimizer`][amltk.optimization.Optimizer]
-    interface which offers an _ask-and-tell_ interface to some underlying optimizer.
-
-
-## Parser
-In general, you should consult the [Optuna documentation](https://optuna.org/).
-
-You can encode the following things:
-
-```python exec="true" source="material-block" result="python" title="A simple space"
-from optuna.distributions import FloatDistribution
-
-from amltk.optuna import OptunaSpaceAdapter
-
-search_space = {
-    "a": (1, 10),  # An int
-    "b": (2.5, 10.0),  # A float
-    "c": ["apple", "banana", "carrot"],  # A categorical
-    "d": FloatDistribution(1e-5, 1, log=True),  # An Optuna log float distribution
-}
-
-adapter = OptunaSpaceAdapter()
-
-space = adapter.parse(search_space)
-print(space)
-```
-
-In general, any of these simple types or anything inheriting from
-`optuna.BaseDistribution` can be used.
-
-This [`parse()`][amltk.pipeline.Parser.parse] method is also
-able to parse more complicated objects, such as a [`Step`][amltk.pipeline.Step]
-or even entire [`Pipelines`][amltk.pipeline.Pipeline].
- -```python exec="true" source="material-block" result="python" title="Parsing Steps" -from amltk.optuna import OptunaSpaceAdapter -from amltk.pipeline import step - -my_step = step( - "mystep", - item=object(), - space={"a": (1, 10), "b": (2.0, 3.0), "c": ["cat", "dog"]} -) - -adapter = OptunaSpaceAdapter() -space = adapter.parse(my_step) - -print(space) -``` - -```python exec="true" source="material-block" result="python" title="Parsing a Pipeline" -from optuna.distributions import FloatDistribution - -from amltk.optuna import OptunaSpaceAdapter -from amltk.pipeline import step, Pipeline - -my_pipeline = Pipeline.create( - step("A", item=object(), space={"C": (0.0, 1.0), "initial": (1, 10)}), - step("B", item=object(), space={"lr": FloatDistribution(1e-5, 1, log=True)}), -) - -adapter = OptunaSpaceAdapter() -space = adapter.parse(my_pipeline) - -print(space) -``` - -## Sampler -As [OptunaSpaceAdapter][amltk.optuna.OptunaSpaceAdapter] implements the -[Sampler][amltk.pipeline.Sampler] interface, you can also [`sample()`][amltk.pipeline.Sampler.sample] -from these spaces. - -```python exec="true" source="material-block" result="python" title="Sampling from a space" -from amltk.optuna import OptunaSpaceAdapter - -search_space = { - "a": (1, 10), - "b": (0.5, 9.0), - "c": ["apple", "banana", "carrot"], -} - -adapter = OptunaSpaceAdapter() -space = adapter.parse(search_space) -sample = adapter.sample(space) - -print(sample) -``` - -### For use with Step, Pipeline -The [`Pipeline`][amltk.pipeline.Pipeline] and [`Step`][amltk.pipeline.Step] objects -have a [`space()`][amltk.pipeline.Pipeline.space] and -[`sample()`][amltk.pipeline.Pipeline.sample] method. -These accept a [`Parser`][amltk.pipeline.Parser] and a [`Sampler`][amltk.pipeline.Sampler] -interface, for which [`OptunaSpaceAdapter`][amltk.optuna.OptunaSpaceAdapter] -supports poth of these interfaces. - -```python exec="true" source="material-block" result="python" title="Using Optuna with a Step" -from amltk.optuna import OptunaSpaceAdapter -from amltk.pipeline import step - -my_step = step( - "mystep", - item=object(), - space={"a": (1, 10), "b": (2.0, 3.0), "c": ["cat", "dog"]} -) - -space = my_step.space(parser=OptunaSpaceAdapter) -print(space) - -sample = my_step.sample(sampler=OptunaSpaceAdapter) -print(sample) -``` - -```python exec="true" source="material-block" result="python" title="Using Optuna with a Pipeline" -from optuna.distributions import FloatDistribution - -from amltk.optuna import OptunaSpaceAdapter -from amltk.pipeline import step, Pipeline - -my_pipeline = Pipeline.create( - step("A", item=object(), space={"C": (0.0, 1.0), "initial": [1, 10]}), - step("B", item=object(), space={"lr": FloatDistribution(1e-5, 1, log=True)}), -) - -space = my_pipeline.space(parser=OptunaSpaceAdapter) -print(space) - -sample = my_pipeline.sample(sampler=OptunaSpaceAdapter) -print(sample) -``` - -### For use with RandomSearch -The [`RandomSearch`][amltk.optimization.RandomSearch] object accepts a -[`Sampler`][amltk.pipeline.Sampler] interface, for which -[`OptunaSpaceAdapter`][amltk.optuna.OptunaSpaceAdapter] supports. 
- -```python exec="true" source="material-block" result="python" title="Using Optuna with RandomSearch" -from optuna.distributions import FloatDistribution - -from amltk.optuna import OptunaSpaceAdapter -from amltk.pipeline import step, Pipeline -from amltk.optimization import RandomSearch - -my_pipeline = Pipeline.create( - step("A", item=object(), space={"C": (0.0, 1.0), "initial": [1, 10]}), - step("B", item=object(), space={"lr": FloatDistribution(1e-5, 1, log=True)}), -) -space = my_pipeline.space(parser=OptunaSpaceAdapter) - -random_search_optimizer = RandomSearch( - space=space, - sampler=OptunaSpaceAdapter, - seed=10 -) - -for i in range(3): - trial = random_search_optimizer.ask() - print(trial) - - with trial.begin(): - # Run experiment here - pass - - report = trial.success(cost=1) - print(report) - - random_search_optimizer.tell(report) -``` - -## Optimizer -We also integrate Optuna using the [`Optimizer`][amltk.optimization.Optimizer] interface. -This requires us to support two keys methods, [`ask()`][amltk.optimization.Optimizer.ask] -and [`tell()`][amltk.optimization.Optimizer.tell]. - -```python -from amltk.optuna import OptunaOptimizer, OptunaSpaceAdapter -from amltk.pipeline import step - -item = step( - "mystep", - item=object(), - space={"a": (1, 10), "b": (2.0, 3.0), "c": ["cat", "dog"]} -) - -space = item.space(parser=OptunaSpaceAdapter) - -# You can forward **kwargs here to `optuna.create_study()` -optimizer = OptunaOptimizer.create(space=space) - -for i in range(3): - trial = optimizer.ask() - print(trial) - - with trial.begin(): - # Run experiment here - pass - - report = trial.success(cost=1) - print(report) - - optimizer.tell(report) -``` diff --git a/docs/reference/pipelines/builders.md b/docs/reference/pipelines/builders.md new file mode 100644 index 00000000..838d5ff9 --- /dev/null +++ b/docs/reference/pipelines/builders.md @@ -0,0 +1,33 @@ +## Builders +A [pipeline](site:reference/pipelines/pipeline.md) of [`Node`][amltk.pipeline.Node]s +is just an abstract representation of some implementation of a pipeline that will actually do +things, for example an sklearn [`Pipeline`][sklearn.pipeline.Pipeline] or a +Pytorch `Sequential`. + +To facilitate custom builders and to allow you to customize building, +there is a explicit argument `builder=` required when +calling [`.build(builder=...)`][amltk.pipeline.Node] on your pipeline. + +Each builder gives the [various kinds of components](site:reference/pipelines/components.md) +an actual meaning, for example the [`Split`][amltk.pipeline.Split] with +the sklearn [`builder()`][amltk.pipeline.builders.sklearn.build], +translates to a [`ColumnTransformer`][sklearn.compose.ColumnTransformer] and +a [`Sequential`][amltk.pipeline.Sequential] translates to an sklearn +[`Pipeline`][sklearn.pipeline.Pipeline]. + + +## Scikit-learn + +::: amltk.pipeline.builders.sklearn + options: + members: False + +## PyTorch +??? todo "Planned" + + If anyone has good knowledge of building pytorch networks in a more functional + manner and would like to contribute, please feel free to reach out! + +At the moment, we do not provide any native support for `torch`. You can +however make use of `skorch` to convert your networks to a scikit-learn interface, +using the scikit-learn builder instead. 
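
To make the suggestion above concrete, here is a rough sketch of the `skorch` route. It is illustrative only:
it assumes `torch` and `skorch` are installed, uses a made-up two-layer network, and assumes the
`Component`/`Sequential` nodes and the sklearn builder function referenced above can be combined in this way.

```python
from skorch import NeuralNetClassifier
from torch import nn

from amltk.pipeline import Component, Sequential
from amltk.pipeline.builders.sklearn import build as sklearn_builder

# A hypothetical two-layer network; LazyLinear infers its input size on first fit.
net = nn.Sequential(nn.LazyLinear(32), nn.ReLU(), nn.LazyLinear(2))

# Wrap the torch module in skorch's sklearn-compatible estimator and use it
# like any other estimator when defining the pipeline.
pipeline = Sequential(
    Component(
        NeuralNetClassifier,
        name="net",
        config={"module": net, "max_epochs": 5, "lr": 0.01},
    ),
    name="torch_via_skorch",
)

# Build into a regular sklearn Pipeline with the scikit-learn builder.
sklearn_pipeline = pipeline.build(builder=sklearn_builder)
```

Since `NeuralNetClassifier` behaves like an sklearn estimator, the resulting object can be fit and
scored like any other sklearn pipeline.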
diff --git a/docs/reference/pipelines/pipeline.md b/docs/reference/pipelines/pipeline.md
new file mode 100644
index 00000000..9f563119
--- /dev/null
+++ b/docs/reference/pipelines/pipeline.md
@@ -0,0 +1,34 @@
+## Pieces of a Pipeline
+A pipeline is a collection of [`Node`][amltk.pipeline.node.Node]s
+that are connected together to form a directed acyclic graph, where the nodes
+follow a parent-child relationship. The purpose of these is to form some _abstract_
+representation of what you want to search over/optimize and then build into a concrete object.
+
+These [`Node`][amltk.pipeline.node.Node]s allow you to specify the function/object that
+will be used there, its search space and any configuration you want to explicitly apply.
+There are various components listed below which give these nodes extra syntactic meaning,
+e.g. a [`Choice`](#choice) which represents some choice between its children while
+a [`Sequential`](#sequential) indicates that each child follows one after the other.
+
+Once a pipeline is created, you can perform three critical operations on it:
+
+* [`search_space(parser=...)`][amltk.pipeline.node.Node.search_space] - This will return the
+  search space of the pipeline, as defined by its nodes. You can find the reference to
+  the [available parsers and search spaces here](site:reference/pipelines/spaces.md).
+* [`configure(config=...)`][amltk.pipeline.node.Node.configure] - This will return a
+  new pipeline where each node is configured correctly.
+* [`build(builder=...)`][amltk.pipeline.node.Node.build] - This will return some
+  concrete object from a configured pipeline. You can find the reference to
+  the [available builders here](site:reference/pipelines/builders.md).
+
+### Components
+
+::: amltk.pipeline.components
+    options:
+        members: false
+
+### Node
+
+::: amltk.pipeline.node
+    options:
+        members: false
diff --git a/docs/reference/pipelines/spaces.md b/docs/reference/pipelines/spaces.md
new file mode 100644
index 00000000..14508156
--- /dev/null
+++ b/docs/reference/pipelines/spaces.md
@@ -0,0 +1,45 @@
+## Spaces
+A common requirement when performing optimization of some pipeline
+is to be able to parametrize it. To do so, we often think about parametrizing
+each component separately, with the structure of the pipeline adding additional
+constraints.
+
+To facilitate this, we allow the construction of
+[pipelines](site:reference/pipelines/pipeline.md), where each part
+of the pipeline can contain a [`.space`][amltk.pipeline.node.Node.space].
+When we wish to extract the entire search space from the pipeline, we can
+call [`search_space(parser=...)`][amltk.pipeline.node.Node.search_space] on the root node
+of our pipeline, returning some sort of _space_ object.
+
+Now there are unfortunately quite a few search space implementations out there.
+Some support concepts such as forbidden combinations, conditionals and
+functional constraints, while others are constrained to just numerical
+parameters. Another reason to choose a particular space representation is
+the [`Optimizer`](site:reference/optimization/optimizers.md)
+you may wish to use, which will typically have one preferred search
+space representation.
+
+To generalize over this, AMLTK itself will not care what is in a `.space`
+of each part of the pipeline, i.e.
+ +```python exec="true" source="material-block" result="python" +from amltk.pipeline import Component + +c = Component(object, space="hmmm, a str space?") +from amltk._doc import doc_print; doc_print(print, c) # markdown-exec: hide +``` + +What follow's below is a list of supported parsers you could pass `parser=` +to extract a search space representation. + +## ConfigSpace + +::: amltk.pipeline.parsers.configspace + options: + members: false + +## Optuna + +::: amltk.pipeline.parsers.optuna + options: + members: false diff --git a/docs/reference/plugins.md b/docs/reference/plugins.md deleted file mode 100644 index 6bb3301a..00000000 --- a/docs/reference/plugins.md +++ /dev/null @@ -1,2 +0,0 @@ -# TODO -Most of this info is in other reference pages for now. We should condense them here. Sorry :/ diff --git a/docs/reference/prebuilt_pipelines.md b/docs/reference/prebuilt_pipelines.md deleted file mode 100644 index 58145637..00000000 --- a/docs/reference/prebuilt_pipelines.md +++ /dev/null @@ -1,5 +0,0 @@ -# Prebuilt Pipelines -We provide some prebuilt pipelines for either testing or just to get off the ground -running. - -* [XGBoost][amltk.pipeline.xgboost] diff --git a/docs/reference/pynisher.md b/docs/reference/pynisher.md deleted file mode 100644 index 362d2450..00000000 --- a/docs/reference/pynisher.md +++ /dev/null @@ -1,105 +0,0 @@ -# Pynisher -The plugin uses [pynisher](https://github.com/automl/pynisher) to place memory, cpu and walltime -constraints on processes, crashing them if these limits are reached. - -It's best use is when used with [`Scheduler.with_processes()`][amltk.Scheduler.with_processes] to have -work performed in processes. - -??? warning "Scheduler Executor" - - This will place process limits on the task as soon as it starts - running, whever it may be running. If you are using - [`Scheduler.with_sequential()`][amltk.Scheduler.with_sequential] - then this will place limits on the main process, likely not what you - want. This also does not work with a - [`ThreadPoolExecutor`][concurrent.futures.ThreadPoolExecutor]. - - If using this with something like [`dask-jobqueue`](./dask-jobqueue.md), - then this will place limits on the workers it spawns. It would be better - to place limits directly through dask job-queue then. - -??? warning "Platform Limitations" - - Pynisher has some limitations with memory on Mac and Windows: - https://github.com/automl/pynisher#features - - -## Setting limits -To limit a task, we can create a [`PynisherPlugin`][amltk.pynisher.PynisherPlugin] and -pass that to our [`Task`][amltk.Task]. Each of the limits has an associated event -that can be listened to and acted upon if needed. - -### Wall time - -The maximum amount of wall clock time this task can use. -If the wall clock time limit triggered and the function crashes as a result, -the [`@pynisher-timeout`][amltk.pynisher.PynisherPlugin.TIMEOUT] and -[`@pynisher-wall-time-limit`][amltk.pynisher.PynisherPlugin.WALL_TIME_LIMIT_REACHED] events -will be emitted. - -```python -from amltk.scheduling import Scheduler -from amltk.pynisher import PynisherPlugin - -scheduler = Scheduler.with_processes(1) -task = scheduler.task(..., plugins=PynisherPlugin(wall_time_limit=(5, "m")) # (1)! - -@task.on("pynisher-wall-time-limit") -def print_it(exception): - print(f"Failed with {exception=}") -``` - -1. Possible units are `#!python "s", "m", "h"`, defaults to `#!python "s"` - -### Memory - -The maximum amount of memory this task can use. 
-If the memory limit is triggered, the function crashes as a result, -emitting the [`@pynisher-memory-limit`][amltk.pynisher.PynisherPlugin.MEMORY_LIMIT_REACHED] event. - -```python -from amltk.scheduling import Scheduling -from amltk.pynisher import PynisherPlugin - -scheduler = Scheduler.with_processes(1) -task = scheduler.task(..., plugins=PynisherPlugin(memory_limit=(2, "gb")) # (1)! - -@task.on("pynisher-memory-limit") -def print_it(exception): - print(f"Failed with {exception=}") -``` - -1. Possible units are `#!python "b", "kb", "mb", "gb"`, defaults to `#!python "b"` - -!!! warning "Memory Limits with Pynisher" - - Pynisher has some limitations with memory on Mac and Windows: - https://github.com/automl/pynisher#features - -### CPU time - -The maximum amount of CPU time this task can use. -If the CPU time limit triggered and the function crashes as a result, -the [`@pynisher-timeout`][amltk.pynisher.PynisherPlugin.TIMEOUT] and -[`@pynisher-cpu-time-limit`][amltk.pynisher.PynisherPlugin.CPU_TIME_LIMIT_REACHED] -events will be emitted. - -```python -from amltk.scheduling import Scheduler -from amltk.pynisher import PynisherPlugin - - -scheduler = Scheduler.with_processes(1) -task = scheduler.task(..., plugins=PynisherPlugin(cpu_time_limit=(60, "s"))) # (1)! - -@task.on("pynisher-cpu-time-limit") -def print_it(exception): - print(f"Failed with {exception=}") -``` - -1. Possible units are `#!python "s", "m", "h"`, defaults to `#!python "s"` - -!!! warning "CPU Time Limits with Pynisher" - - Pynisher has some limitations with cpu timing on Mac and Windows: - https://github.com/automl/pynisher#features diff --git a/docs/reference/scheduling/events.md b/docs/reference/scheduling/events.md new file mode 100644 index 00000000..53f61ac8 --- /dev/null +++ b/docs/reference/scheduling/events.md @@ -0,0 +1,5 @@ +## Events + +::: amltk.scheduling.events + options: + members: False diff --git a/docs/reference/scheduling/executors.md b/docs/reference/scheduling/executors.md new file mode 100644 index 00000000..8ce6c9fc --- /dev/null +++ b/docs/reference/scheduling/executors.md @@ -0,0 +1,268 @@ +## Executors + +The [`Scheduler`][amltk.scheduling.Scheduler] uses +an [`Executor`][concurrent.futures.Executor], a builtin python native to +`#!python submit(f, *args, **kwargs)` to be computed +else where, whether it be locally or remotely. + +```python +from amltk.scheduling import Scheduler + +scheduler = Scheduler(executor=...) +``` + +Some parallelism libraries natively support this interface while we can +wrap others. You can also wrap you own custom backend by using +the `Executor` interface, which is relatively simple to implement. + +If there's any executor background you wish to integrate, we would +be happy to consider it and greatly appreciate a PR! + +### :material-language-python: `Python` + +Python supports the `Executor` interface natively with the +[`concurrent.futures`][concurrent.futures] module for processes with the +[`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] and +[`ThreadPoolExecutor`][concurrent.futures.ThreadPoolExecutor] for threads. + +??? tip "Usage" + + === "Process Pool Executor" + + ```python + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(2) # (1)! + ``` + + 1. Explicitly use the `with_processes` method to create a `Scheduler` with + a `ProcessPoolExecutor` with 2 workers. 
+ ```python + from concurrent.futures import ProcessPoolExecutor + from amltk.scheduling import Scheduler + + executor = ProcessPoolExecutor(max_workers=2) + scheduler = Scheduler(executor=executor) + ``` + + === "Thread Pool Executor" + + ```python + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_threads(2) # (1)! + ``` + + 1. Explicitly use the `with_threads` method to create a `Scheduler` with + a `ThreadPoolExecutor` with 2 workers. + ```python + from concurrent.futures import ThreadPoolExecutor + from amltk.scheduling import Scheduler + + executor = ThreadPoolExecutor(max_workers=2) + scheduler = Scheduler(executor=executor) + ``` + + !!! danger "Why to not use threads" + + Python also defines a [`ThreadPoolExecutor`][concurrent.futures.ThreadPoolExecutor] + but there are some known drawbacks to offloading heavy compute to threads. Notably, + there's no way in python to terminate a thread from the outside while it's running. + +### :simple-dask: `dask` + +[Dask](https://distributed.dask.org/en/stable/) and the supporting extension [`dask.distributed`](https://distributed.dask.org/en/stable/) +provide a robust and flexible framework for scheduling compute across workers. + +!!! example + + ```python hl_lines="5" + from dask.distributed import Client + from amltk.scheduling import Scheduler + + client = Client(...) + executor = client.get_executor() + scheduler = Scheduler(executor=executor) + ``` + +### :simple-dask: `dask-jobqueue` + +[`dask-jobqueue`](https://jobqueue.dask.org/en/latest/) is a package +for scheduling jobs across common clusters setups such as +PBS, Slurm, MOAB, SGE, LSF, and HTCondor. + +Please see the `dask-jobqueue` [documentation](https://jobqueue.dask.org/en/latest/) +In particular, we only control the parameter `#!python n_workers=` to +use the [`adapt()`](https://jobqueue.dask.org/en/latest/index.html?highlight=adapt#adaptivity) +method, every other keyword is forwarded to the relative +[cluster implementation](https://jobqueue.dask.org/en/latest/api.html). + +In general, you should specify the requirements of each individual worker +and tune your load with the `#!python n_workers=` parameter. + +If you have any tips, tricks, working setups, gotchas, please feel free +to leave a PR or simply an issue! + +??? tip "Usage" + + + === "Slurm" + + ```python hl_lines="3 4 5 6 7 8 9" + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_slurm( + n_workers=10, # (1)! + queue=..., + cores=4, + memory="6 GB", + walltime="00:10:00" + ) + ``` + + 1. The `n_workers` parameter is used to set the number of workers + to start with. + The [`adapt()`](https://jobqueue.dask.org/en/latest/index.html?highlight=adapt#adaptivity) + method will be called on the cluster to dynamically scale up to `#!python n_workers=` based on + the load. + The `with_slurm` method will create a [`SLURMCluster`][dask_jobqueue.SLURMCluster] + and pass it to the `Scheduler` constructor. + ```python hl_lines="10" + from dask_jobqueue import SLURMCluster + from amltk.scheduling import Scheduler + + cluster = SLURMCluster( + queue=..., + cores=4, + memory="6 GB", + walltime="00:10:00" + ) + cluster.adapt(max_workers=10) + executor = cluster.get_client().get_executor() + scheduler = Scheduler(executor=executor) + ``` + + !!! warning "Running outside the login node" + + If you're running the scheduler itself in a job, this may not + work on some cluster setups. The scheduler itself is lightweight + and can run on the login node without issue. 
+
+        However, you should make sure to offload heavy computations
+        to a worker.
+
+        If you get it to work, for example in an interactive job, please
+        let us know!
+
+    !!! info "Modifying the launch command"
+
+        On some clusters, you'll need to modify the launch command.
+        You can use the following to do so:
+
+        ```python
+        from amltk.scheduling import Scheduler
+
+        scheduler = Scheduler.with_slurm(n_workers=..., submit_command="sbatch --extra")
+        ```
+
+    === "Others"
+
+        Please see the `dask-jobqueue` [documentation](https://jobqueue.dask.org/en/latest/)
+        and the following methods:
+
+        * [`Scheduler.with_pbs()`][amltk.scheduling.Scheduler.with_pbs]
+        * [`Scheduler.with_lsf()`][amltk.scheduling.Scheduler.with_lsf]
+        * [`Scheduler.with_moab()`][amltk.scheduling.Scheduler.with_moab]
+        * [`Scheduler.with_sge()`][amltk.scheduling.Scheduler.with_sge]
+        * [`Scheduler.with_htcondor()`][amltk.scheduling.Scheduler.with_htcondor]
+
+### :octicons-gear-24: `loky`
+
+[Loky](https://loky.readthedocs.io/en/stable/API.html) is the default backend executor behind
+[`joblib`](https://joblib.readthedocs.io/en/stable/), the parallelism that
+powers scikit-learn.
+
+??? tip "Usage"
+
+    === "Simple"
+
+        ```python
+        from amltk import Scheduler
+
+        # Pass any arguments you would pass to `loky.get_reusable_executor`
+        scheduler = Scheduler.with_loky(...)
+        ```
+
+
+    === "Explicit"
+
+        ```python
+        import loky
+        from amltk import Scheduler
+
+        scheduler = Scheduler(executor=loky.get_reusable_executor(...))
+        ```
+
+??? warning "BLAS numeric backend"
+
+    The loky executor seems to pick up a different BLAS library (from scipy)
+    than the one used by jobs from something like a `ProcessPoolExecutor`.
+
+    This is unlikely to matter for the majority of use-cases.
+
+### :simple-ray: `ray`
+
+[Ray](https://docs.ray.io/en/master/) is an open-source unified compute framework that makes it easy
+to scale AI and Python workloads
+— from reinforcement learning to deep learning to tuning,
+and model serving.
+
+!!! todo "In progress"
+
+    Ray is currently working on supporting the Python
+    `Executor` interface. See this [PR](https://github.com/ray-project/ray/pull/30826)
+    for more info.
+
+### :simple-apacheairflow: `airflow`
+
+[Airflow](https://airflow.apache.org/) is a platform created by the community to programmatically author,
+schedule and monitor workflows. Their list of integrations is extensive,
+featuring compute platforms such as Kubernetes, AWS, Microsoft Azure and
+GCP.
+
+!!! todo "In progress"
+
+    We plan to support `airflow` in the future. If you'd like to help
+    out, please reach out to us!
+
+### :material-debug-step-over: Debugging
+
+Sometimes you'll need to debug what's going on and remove the noise
+of processes and parallelism. For this, we have implemented a very basic
+[`SequentialExecutor`][amltk.scheduling.SequentialExecutor] to run everything
+in a sequential manner!
+
+=== "Easy"
+
+    ```python
+    from amltk.scheduling import Scheduler
+
+    scheduler = Scheduler.with_sequential()
+    ```
+
+=== "Explicit"
+
+    ```python
+    from amltk.scheduling import Scheduler, SequentialExecutor
+
+    scheduler = Scheduler(executor=SequentialExecutor())
+    ```
+
+!!! warning "Recursion"
+
+    If you use the `SequentialExecutor`, be careful that the stack
+    of function calls can get quite large, quite quickly. If you are
+    using this for debugging, keep the number of tasks submitted
+    from callbacks small and focus on debugging.
If using this + for sequential ordering of operations, prefer to use + `with_processes(1)` as this will still maintain order but not + have these stack issues. diff --git a/docs/reference/scheduling/plugins.md b/docs/reference/scheduling/plugins.md new file mode 100644 index 00000000..80e94020 --- /dev/null +++ b/docs/reference/scheduling/plugins.md @@ -0,0 +1,58 @@ +## Plugins + +Plugins are a way to modify a [`Task`][amltk.scheduling.task.Task], to add new functionality +or change the behaviour of what goes on in the function that is dispatched to the +[`Scheduler`][amltk.scheduling.Scheduler]. + +Some plugins will also add new `@event`s to a task, which can be used to respond accordingly to +something that may have occured with your task. + +You can add a plugin to a [`Task`](site:reference/tasks/index.md) as so: + +```python exec="true" html="true" source="material-block" +from amltk.scheduling import Task, Scheduler +from amltk.scheduling.plugins import Limiter + +def some_function(x: int) -> int: + return x * 2 + +scheduler = Scheduler.with_processes(1) + +# When creating a task with the scheduler +task = scheduler.task(some_function, plugins=[Limiter(max_calls=10)]) + + +# or directly to a Task +task = Task(some_function, scheduler=scheduler, plugins=[Limiter(max_calls=10)]) +from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide +``` + +### Limiter +::: amltk.scheduling.plugins.limiter + options: + members: False + +### Pynisher +::: amltk.scheduling.plugins.pynisher + options: + members: False + +### Comm +::: amltk.scheduling.plugins.comm + options: + members: False + +### ThreadPoolCTL +::: amltk.scheduling.plugins.threadpoolctl + options: + members: False + +### Warning Filter +::: amltk.scheduling.plugins.warning_filter + options: + members: False + +### Creating Your Own Plugin +::: amltk.scheduling.plugins.plugin + options: + members: False diff --git a/docs/reference/scheduling/scheduler.md b/docs/reference/scheduling/scheduler.md new file mode 100644 index 00000000..d5a6a2f9 --- /dev/null +++ b/docs/reference/scheduling/scheduler.md @@ -0,0 +1,5 @@ +## Scheduler + +::: amltk.scheduling.scheduler + options: + members: False diff --git a/docs/reference/scheduling/task.md b/docs/reference/scheduling/task.md new file mode 100644 index 00000000..f7fdd111 --- /dev/null +++ b/docs/reference/scheduling/task.md @@ -0,0 +1,5 @@ +## Tasks + +::: amltk.scheduling.task + options: + members: False diff --git a/docs/reference/sklearn.md b/docs/reference/sklearn.md deleted file mode 100644 index 9116ff76..00000000 --- a/docs/reference/sklearn.md +++ /dev/null @@ -1,236 +0,0 @@ -# Scikit-learn -Scikit-learn is a library for classical machine learning, -implementing many of the time-tested, non-deep, methods for -machine learning. It includes many models, hyperparameters -for these models and is its own toolkit for evaluating -these models. - -We extend these capabilities with what we found helpful -during development of AutoML tools such as -[AutoSklearn](https://automl.github.io/auto-sklearn/master/). - -!!! note "Threads and Multiprocessing" - - If running multiple trainings across multiple processes, please - also check out [`ThreadPoolCTL`](../reference/threadpoolctl.md) - -## Pipeline Builder -The `amltk.sklearn` module provides a `build_pipeline` function -that can be passed to [`Pipeline.build()`][amltk.pipeline.Pipeline.build] -to create a pure [sklearn.pipeline.Pipeline][] from your definition. 
- -### A simple Pipeline - -```python exec="true" source="material-block" result="python" title="A simple Pipeline" hl_lines="12" -from sklearn.impute import SimpleImputer -from sklearn.ensemble import RandomForestClassifier - -from amltk.pipeline import step, Pipeline -from amltk.sklearn import sklearn_pipeline - -pipeline = Pipeline.create( - step("imputer", SimpleImputer, config={"strategy": "median"}), - step("rf", RandomForestClassifier, config={"n_estimators": 10}), -) - -sklearn_pipeline = pipeline.build(builder=sklearn_pipeline) -print(sklearn_pipeline) -``` - -!!! note "Implicit building" - - By default, AutoML-Toolkit will try to infer how to build your - pipeline when you call [`Pipeline.build()`][amltk.pipeline.Pipeline.build]. - If all the components contained in the `Pipeline` are from - `sklearn`, then it will use the `sklearn_pipeline` automatically. - - You will rarely have to explicitly pass the `builder` argument. - -### Data Preprocessing -Below is a fairly complex pipeline which handles data-preprocessing, -feeding `#!python "categoricals"` through a -[SimpleImputer][sklearn.impute.SimpleImputer] and a -[OneHotEncoder][sklearn.preprocessing.OneHotEncoder] and -`#!python "numerics"` through a -[SimpleImputer][sklearn.impute.SimpleImputer], -[VarianceThreshold][sklearn.feature_selection.VarianceThreshold] -and possibly a [StandardScaler][sklearn.preprocessing.StandardScaler]. - -This is done using the [`split()`][amltk.pipeline.split] operator -from AutoML-toolkit, which allows you to split your data into -multiple branches and then combine them back together. - -You will notice for `#!python "feature_preprocessing"` split, we -pass the `item=` as a [ColumnTransformer][sklearn.compose.ColumnTransformer] -and for the `config=` parameter, two [make_column_selector][sklearn.compose.make_column_selector] -functions whose names match those of the two split paths, `#!python "categoricals"` -and `#!python "numerics"`. - -!!! quote "No Custom `amltk` Components" - - To keep things as compatible as possible with `sklearn`, we - do not provide any custom components. This lets use export - things easily and allows you to include your own sklearn - components in your pipeline without us getting in the way. 
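Because the built object is a plain [sklearn.pipeline.Pipeline][], the usual sklearn tooling applies to it directly. As a small illustrative sketch (assuming `sklearn_pipeline` was built as in the simple example above, and that `X` and `y` hold your training data, neither of which is defined here):

```python
import pickle

from sklearn.model_selection import cross_val_score

# The built pipeline is an ordinary sklearn estimator, so it can be
# cross-validated, cloned, or pickled like any other.
scores = cross_val_score(sklearn_pipeline, X, y, cv=5)
print(scores.mean())

with open("pipeline.pkl", "wb") as f:
    pickle.dump(sklearn_pipeline, f)
```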
- - -```python exec="true" source="material-block" result="python" title="A complex Pipeline" hl_lines="53 54 55 56 57" -import numpy as np -from sklearn.compose import ColumnTransformer, make_column_selector -from sklearn.ensemble import RandomForestClassifier -from sklearn.feature_selection import VarianceThreshold -from sklearn.impute import SimpleImputer -from sklearn.preprocessing import ( - FunctionTransformer, - MinMaxScaler, - OneHotEncoder, - StandardScaler, -) -from sklearn.svm import SVC - -from amltk.sklearn import sklearn_pipeline -from amltk.pipeline import step, split, choice, group, Pipeline - -pipeline = Pipeline.create( - split( - "feature_preprocessing", - group( - "categoricals", - step( - "categorical_imputer", - SimpleImputer, - space={ - "strategy": ["most_frequent", "constant"], - "fill_value": ["missing"], - }, - ) - | step( - "ohe", - OneHotEncoder, - space={ - "min_frequency": (0.01, 0.1), - "handle_unknown": ["ignore", "infrequent_if_exist"], - }, - config={"drop": "first"}, - ) - ), - group( - "numericals", - step("numerical_imputer", SimpleImputer, space={"strategy": ["mean", "median"]}) - | step( - "variance_threshold", - VarianceThreshold, - space={"threshold": (0.0, 0.2)}, - ) - | choice( - "scaler", - step("standard", StandardScaler), - step("minmax", MinMaxScaler), - step("passthrough", FunctionTransformer), - ) - ), - item=ColumnTransformer, - config={ - "categoricals": make_column_selector(dtype_include=object), - "numericals": make_column_selector(dtype_include=np.number), - }, - ), - choice( - "algorithm", - step("svm", SVC, space={"C": (0.1, 10.0)}, config={"probability": True}), - step( - "rf", - RandomForestClassifier, - space={ - "n_estimators": [10, 100], - "criterion": ["gini", "entropy", "log_loss"], - }, - ), - ) -) - -config = pipeline.sample() -configured_pipeline = pipeline.configure(config) - -# `builder=` is optional, we can detect it's an sklearn pipeline. -sklearn_pipeline = configured_pipeline.build(builder=sklearn_pipeline) -print(sklearn_pipeline) -``` - -## Data Splitting -We also provide two convenience functions often required in AutoML -systems, namely [`train_val_test_split()`][amltk.sklearn.train_val_test_split] -for creating three splits of your data and -[`split_data()`][amltk.sklearn.split_data] for creating an arbitrary number -of splits. - -### Train, Val, Test Split -This functions much similar to the sklearn -[`train_test_split()`][sklearn.model_selection.train_test_split] but produces one more -split, the validation split. - -Instead of passing in a `test_size=` parameter, you pass in a -`splits=` parameter, which declares the percentages of splits you -would like, e.g. `(0.5, 0.3, 0.2)` would indicate a train size of `50%`, -a val size of `30%` and a test size of `20%`. - -```python exec="true" source="material-block" result="python" title="Train, Val, Test Split" -from amltk.sklearn.data import train_val_test_split - -x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1] -train_x, train_y, val_x, val_y, test_x, test_y = train_val_test_split( - x, y, splits=(0.5, 0.3, 0.2), seed=42 -) - -print(train_x, train_y) -print(val_x, val_y) -print(test_x, test_y) -``` - -You may also use the `shuffle=` and `stratify=` parameters to -shuffle and stratify your data respectively. The `stratify=` argument -will respect the stratification across all 3 splits, ensuring they each -have a proportionate amount of each value in `stratify=`. 
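To convince yourself of this, a quick check (a sketch that reuses the `train_y`, `val_y` and `test_y` variables produced by the example below) is to count the label frequencies in each split:

```python
from collections import Counter

# With stratify=y, each split should contain the labels of `y`
# in roughly the same proportions.
for name, labels in {"train": train_y, "val": val_y, "test": test_y}.items():
    print(name, Counter(labels))
```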
- -```python exec="true" source="material-block" result="python" title="Train, Val, Test Split with Shuffle and Stratify" hl_lines="10 11" - -from amltk.sklearn.data import train_val_test_split - -x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1] - -train_x, train_y, val_x, val_y, test_x, test_y = train_val_test_split( - x, y, - splits=(0.5, 0.3, 0.2), - stratify=y, - shuffle=True, - seed=42, -) - -print(train_x, train_y) -print(val_x, val_y) -print(test_x, test_y) -``` - -### Arbitrary Data Splitting -Sometimes you need to create more than 3 splits. For this we provide -[`split_data()`][amltk.sklearn.split_data], which has an identical function -signature, except the `splits=` you specify is a dictionary from the name -of the split to the percentage you wish. - -```python exec="true" source="material-block" result="python" title="Arbitrary Data Splitting" -from amltk.sklearn import split_data - -x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1] -splits = split_data(x, y, splits={"train": 0.5, "val": 0.3, "test": 0.2}, seed=42) - -train_x, train_y = splits["train"] -val_x, val_y = splits["val"] -test_x, test_y = splits["test"] - -print(train_x, train_y) -print(val_x, val_y) -print(test_x, test_y) -``` diff --git a/docs/reference/smac.md b/docs/reference/smac.md deleted file mode 100644 index e69de29b..00000000 diff --git a/docs/reference/threadpoolctl.md b/docs/reference/threadpoolctl.md deleted file mode 100644 index 6dae070e..00000000 --- a/docs/reference/threadpoolctl.md +++ /dev/null @@ -1,40 +0,0 @@ -# ThreadPoolCTL -Performing numerical operations in while multi-processing can create over-subscription -to threads by each process, especially when using numerical libraries like numpy, -scipy and sklearn. Specifically when training many sklearn models in different processes, -this can slow down training significantly with smaller datasets. - -!!! note - - Plugin is only available if `threadpoolctl` is installed. You can so - with `pip install amltk[threadpoolctl]`. - -```python exec="true" source="material-block" result="python" title="ThreadPoolCTLPlugin example" -from amltk.scheduling import Scheduler -from amltk.threadpoolctl import ThreadPoolCTLPlugin - -# Only used to verify, not needed if running -import threadpoolctl -import sklearn - -print("------ Before") -print(threadpoolctl.threadpool_info()) - -scheduler = Scheduler.with_processes(1) - -def f() -> None: - print("------ Inside") - print(threadpoolctl.threadpool_info()) -from amltk._doc import make_picklable; make_picklable(f) # markdown-exec: hide - -task = scheduler.task(f, plugins=ThreadPoolCTLPlugin(max_threads=1)) - -@scheduler.on_start -def start_task() -> None: - task() - -scheduler.run() - -print("------ After") -print(threadpoolctl.threadpool_info()) -``` diff --git a/docs/reference/wandb.md b/docs/reference/wandb.md deleted file mode 100644 index 3ceebe66..00000000 --- a/docs/reference/wandb.md +++ /dev/null @@ -1,182 +0,0 @@ -# Weights And Biases -[Weights and Biases](https://www.wandb.com/) is a tool for visualizing and tracking -your machine learning experiments. You can log gradients, metrics, model topology, and -more. - -This lets you easily track individual trials during -[optimization](../guides/optimization.md) of your model/pipeline. - -While there is no need to explicitly use the plugin, i.e. 
you can use -wandb however you're used to, we do provide a plugin that tracks -everything that's normally reported on a trial and construct the `run` for -you, making everything that small bit easier. - -!!! note "Note" - - You do not need to explicitly use our integration but we automate a lot - of the process. - -!!! warning "Overhead" - - While wandb does not have too much overhead, it does have some. These are - mostly on the order of 1-2 seconds which can be insignificant for longer - running trials but can be significant for short running trials. Particularly - avoid this if you're running many small evaluations. - -## Basic Usage without Plugin -To use wandb without the plugin, you can use the following code in your function -you wish to track: - -```python hl_lines="8 13 18 20" -import wandb - -from amltk import Scheduler, Trial - -def target_function(trial: Trial) -> Trial.Report: - x, y, z = trial.config["x"], trial.config["y"], trial.config["z"] - - run = wandb.init(project="my-project", config=trial.config, name=trial.name) # (1)! - - with trial.begin(): - for i in range(10): - loss = (i * x) + 3 * (i * y) - (i * z) ** 2 - run.log({"loss": loss}) # (2)! - - cost = y + z - trial.summary = {"cost": cost} - - run.summary["cost"] = cost # (3)! - - run.finish() # (4)! - - # Finally report the success - return trial.success(cost=cost) - -scheduler = Scheduler.with_processes(1) - -task = scheduler.task(target_function) - -configs = enumerate( - [ - {"x": 1.0, "y": 2.0, "z": 3.0}, - {"x": -1.0, "y": 2.0, "z": 3.0}, - {"x": -1.0, "y": 3.0, "z": 4.0}, - ], -) - - -@scheduler.on_start() -def launch() -> None: - i, config = next(configs) - trial = Trial(name=f"trial-{i}", config=config) - task(trial) - - -@task.on_done -def launch_next(_: Trial.Report) -> None: - i, config = next(configs, (None, None)) - if config is not None: - trial = Trial(name=f"trial-{i}", config=config) - task(trial) - - -scheduler.run() -``` - -1. Create a wandb run as you normally would with `init()`. -2. Log custom metrics. -3. Log any summary metrics. -4. Make sure to tell `wandb` that you've finished the run. - -This will create some basic wandb output for you - -![Image of wandb dashboard without plugin](../images/wandb_simple.jpg) - -!!! note "One Run Per Process" - - You can only create one run per Process as per - [wandb documentation.][https://docs.wandb.ai/guides/track/log/distributed-training#method-2-many-processes] - When you use any [`Scheduler`][amltk.scheduling.Scheduler] that utilizes - multiple processes, you should be fine, the one notable exception - is using a Scheduler with threads. - -## Basic Usage with Plugin -To use the wandb plugin, the only thing we need to do is create a -[`WandbPlugin`][amltk.wandb.WandbPlugin] and attach it to the actual -[`Trial`][amltk.Trial] with the following: - -```python hl_lines="3 4 5 6 7 8 13" -from amltk.wandb import WandbPlugin -from amltk.scheduling import Scheduler - -wandb_plugin = WandbPlugin( - project="amltk-test2", - group=..., - entity=..., - mode=..., -) - -scheduler = Scheduler.with_processes(1) -task = scheduler.task(target_function, plugins=wandb_plugin.trial_tracker()) -``` - -These lines above will automatically attach the wandb run to the trial under -`.plugins["wandb"]` if you need to access it explicitly. - -Finally, to use it, this function is much like the previous but with many of the -steps automated for you, adding in additonal information. 
- -```python hl_lines="5 10 13" -def target_function(trial: Trial) -> Trial.Report: - x, y, z = trial.config["x"], trial.config["y"], trial.config["z"] - - # The plugin will automatically attach the run to the trial - run = trial.plugins["wandb"] - - with trial.begin(): - for i in range(10): - loss = (i * x) + 3 * (i * y) - (i * z) ** 2 - run.log({"loss": loss}) # (1)! - - cost = y + z - trial.summary = {"cost": cost} # (2)! - - # Will automatically log the report to the run and finish it up - return trial.success(cost=cost) -``` -1. You can still manually log what you like to the run as you normally would do - with weights and biases. -2. Any numeric values stored in the summary will be automatically put in the - summary report of the `wandb.Run`. - - -This will additionally report a lot of the metric generated in the report for further -possible analysis. - -![Image of wandb dashboard with plugin](../images/wandb_plugin.jpg) - -## Advanced Usage -Weights and biases gives you the ability to tag, add notes and additionally mark -specific runs. -To do so, you can pass in a function to modify the -[`WandbParams`][amltk.wandb.WandbParams] that are used to create the run. - -The below example will add tag runs based on if the `x` value is positive or negative. - -!!! example "Modifying the run" - - ```python hl_lines="1 2 3 7" - def modify_run(trial: Trial, params: WandbParams) -> WandbParams: - tags = ["positive"] if trial.config["x"] > 0 else ["negative"] - return params.modify(tags=tags) - - # Later - task = scheduler.task( - target_function, - plugins=[wandb_plugin.trial_tracker(modify=modify_run)], - ) - ``` - -## More -If you'd like to see more functionality, we'd love to hear from you, please -open an issue on our Github or submit a PR! diff --git a/examples/hpo_with_ensembling.py b/examples/hpo_with_ensembling.py index 48c4fba6..f032c851 100644 --- a/examples/hpo_with_ensembling.py +++ b/examples/hpo_with_ensembling.py @@ -38,20 +38,20 @@ import shutil from asyncio import Future +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable +from typing import Any import numpy as np import openml -from sklearn.compose import ColumnTransformer, make_column_selector +from sklearn.compose import make_column_selector from sklearn.ensemble import RandomForestClassifier from sklearn.feature_selection import VarianceThreshold from sklearn.impute import SimpleImputer from sklearn.metrics import accuracy_score from sklearn.neural_network import MLPClassifier from sklearn.preprocessing import ( - FunctionTransformer, LabelEncoder, MinMaxScaler, OneHotEncoder, @@ -63,10 +63,10 @@ from amltk.data.conversions import probabilities_to_classes from amltk.ensembling.weighted_ensemble_caruana import weighted_ensemble_caruana from amltk.optimization import History, Trial -from amltk.pipeline import Pipeline, choice, group, split, step +from amltk.optimization.optimizers.smac import SMACOptimizer +from amltk.pipeline import Choice, Component, Sequential, Split from amltk.scheduling import Scheduler from amltk.sklearn.data import split_data -from amltk.smac import SMACOptimizer from amltk.store import PathBucket """ @@ -117,68 +117,40 @@ def get_dataset(seed: int) -> tuple[np.ndarray, ...]: For more on definitions of pipelines, see the [Pipeline](site:guides/pipeline.md) guide. """ -pipeline = Pipeline.create( - split( - "feature_preprocessing", - group( # (3)! 
- "categoricals", - step( - "category_imputer", - SimpleImputer, - space={ - "strategy": ["most_frequent", "constant"], - "fill_value": ["missing"], - }, - ) - | step( - "ohe", - OneHotEncoder, - space={ - "min_frequency": (0.01, 0.1), - "handle_unknown": ["ignore", "infrequent_if_exist"], - }, - config={"drop": "first"}, - ), - ), - group( # (2)! - "numerics", - step( - "numerical_imputer", - SimpleImputer, - space={"strategy": ["mean", "median"]}, - ) - | step( - "variance_threshold", - VarianceThreshold, - space={"threshold": (0.0, 0.2)}, - ) - | choice( - "scaler", - step("standard", StandardScaler), - step("minmax", MinMaxScaler), - step("robust", RobustScaler), - step("passthrough", FunctionTransformer), - ), - ), - item=ColumnTransformer, +pipeline = ( + Sequential(name="Pipeline") + >> Split( + { + "categories": [ + SimpleImputer(strategy="constant", fill_value="missing"), + Component( + OneHotEncoder, + space={ + "min_frequency": (0.01, 0.1), + "handle_unknown": ["ignore", "infrequent_if_exist"], + }, + config={"drop": "first"}, + ), + ], + "numbers": [ + Component(SimpleImputer, space={"strategy": ["mean", "median"]}), + Component(VarianceThreshold, space={"threshold": (0.0, 0.2)}), + Choice(StandardScaler, MinMaxScaler, RobustScaler, name="scaler"), + ], + }, + name="feature_preprocessing", config={ - "categoricals": make_column_selector(dtype_include=object), - "numerics": make_column_selector(dtype_include=np.number), + "categories": make_column_selector(dtype_include=object), + "numbers": make_column_selector(dtype_include=np.number), }, - ), - choice( # (1)! - "algorithm", - step("svm", SVC, space={"C": (0.1, 10.0)}, config={"probability": True}), - step( - "rf", + ) + >> Choice( # (1)! + Component(SVC, space={"C": (0.1, 10.0)}, config={"probability": True}), + Component( RandomForestClassifier, - space={ - "n_estimators": [10, 100], - "criterion": ["gini", "entropy", "log_loss"], - }, + space={"n_estimators": (10, 100), "criterion": ["gini", "log_loss"]}, ), - step( - "mlp", + Component( MLPClassifier, space={ "activation": ["identity", "logistic", "relu"], @@ -186,11 +158,11 @@ def get_dataset(seed: int) -> tuple[np.ndarray, ...]: "learning_rate": ["constant", "invscaling", "adaptive"], }, ), - ), + ) ) print(pipeline) -print(pipeline.space()) +print(pipeline.search_space("configspace")) # 1. Here we define a choice of algorithms to use where each entry is a possible # algorithm to use. Each algorithm is defined by a step, which is a @@ -231,7 +203,7 @@ def target_function( trial: Trial, /, bucket: PathBucket, - pipeline: Pipeline, + pipeline: Sequential, ) -> Trial.Report: X_train, X_val, X_test, y_train, y_val, y_test = ( # (1)! bucket["X_train.csv"].load(), @@ -242,7 +214,7 @@ def target_function( bucket["y_test.npy"].load(), ) pipeline = pipeline.configure(trial.config) # (2)! - sklearn_pipeline = pipeline.build() # + sklearn_pipeline = pipeline.build("sklearn") # with trial.begin(): # (3)! sklearn_pipeline.fit(X_train, y_train) @@ -312,9 +284,9 @@ def target_function( of each trial in the ensemble. We could of course add extra functionality to the Ensemble, give it references -to the [`PathBucket`][amltk.store.PathBucket] and [`Pipeline`][amltk.pipeline.Pipeline] -objects, and even add methods to train the ensemble, but for the sake of -simplicity we will leave it as is. +to the [`PathBucket`][amltk.store.PathBucket] and the pipeline objects, +and even add methods to train the ensemble, but for the sake of simplicity we will +leave it as is. 
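For illustration only, here is a minimal sketch of what such an extension could look like.
It assumes the `Ensemble` keeps a `weights: dict[str, float]` mapping from trial names to
ensemble weights and that each trial stored its validation probabilities in the bucket
under a `"{name}_val_probabilities.npy"` key; both are assumptions made for this sketch
and may not match the exact fields used in this example.

```python
import numpy as np

from amltk.store import PathBucket


def ensemble_probabilities(weights: dict[str, float], bucket: PathBucket) -> np.ndarray:
    # Hypothetical helper: weighted average of each member's stored probabilities.
    weighted = [
        weight * bucket[f"{name}_val_probabilities.npy"].load()
        for name, weight in weights.items()
    ]
    return np.sum(weighted, axis=0)
```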
""" @@ -388,7 +360,10 @@ def _score(_targets: np.ndarray, ensembled_probabilities: np.ndarray) -> float: ) scheduler = Scheduler.with_processes() # (3)! -optimizer = SMACOptimizer.create(space=pipeline.space(), seed=seed) # (4)! +optimizer = SMACOptimizer.create( + space=pipeline.search_space("configspace"), + seed=seed, +) # (4)! task = scheduler.task(target_function) # (6)! ensemble_task = scheduler.task(create_ensemble) # (7)! @@ -420,14 +395,14 @@ def add_to_history(future: Future, report: Trial.Report) -> None: def launch_ensemble_task(future: Future, report: Trial.Report) -> None: """When a task successfully completes, launch an ensemble task.""" if report.status is Trial.Status.SUCCESS: - ensemble_task(trial_history, bucket) + ensemble_task.submit(trial_history, bucket) @task.on_result def launch_another_task(*_: Any) -> None: """When we get a report, evaluate another trial.""" trial = optimizer.ask() - task(trial, bucket=bucket, pipeline=pipeline) + task.submit(trial, bucket=bucket, pipeline=pipeline) @ensemble_task.on_result @@ -436,7 +411,6 @@ def save_ensemble(future: Future, ensemble: Ensemble) -> None: ensembles.append(ensemble) -@task.on_exception @ensemble_task.on_exception def print_ensemble_exception(future: Future[Any], exception: BaseException) -> None: """When an exception occurs, log it and stop.""" @@ -444,10 +418,17 @@ def print_ensemble_exception(future: Future[Any], exception: BaseException) -> N scheduler.stop() +@task.on_exception +def print_task_exception(future: Future[Any], exception: BaseException) -> None: + """When an exception occurs, log it and stop.""" + print(exception) + scheduler.stop() + + @scheduler.on_timeout def run_last_ensemble_task() -> None: """When the scheduler is empty, run the last ensemble task.""" - ensemble_task(trial_history, bucket) + ensemble_task.submit(trial_history, bucket) scheduler.run(timeout=5, wait=True) # (9)! @@ -467,8 +448,8 @@ def run_last_ensemble_task() -> None: # 3. We use [`Scheduler.with_processes()`][amltk.scheduling.Scheduler.with_processes] # create a [`Scheduler`][amltk.scheduling.Scheduler] that runs everything # in a different process. You can of course use a different backend if you want. -# 4. We use [`SMACOptimizer.create()`][amltk.smac.SMACOptimizer.create] to create a -# [`SMACOptimizer`][amltk.smac.SMACOptimizer] given the space from the pipeline +# 4. We use [`SMACOptimizer.create()`][amltk.optimization.optimizers.smac.SMACOptimizer.create] to create a +# [`SMACOptimizer`][amltk.optimization.optimizers.smac.SMACOptimizer] given the space from the pipeline # to optimize over. # 6. 
We create a [`Task`][amltk.scheduling.Task] that will run our objective, passing # in the function to run and the scheduler for where to run it diff --git a/examples/multifidelity.py b/examples/multifidelity.py deleted file mode 100644 index 0f15f6d4..00000000 --- a/examples/multifidelity.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Simple HPO loop -# Flags: doc-Runnable - -# TODO -""" -from __future__ import annotations - -from amltk.pipeline import choice, step -from amltk.smac.optimizer import SMACOptimizer - -pipeline = choice( - "choice", - step( - "x", - object(), - space={"a": [1, 2, 3]}, - fidelities={"b": (1, 10)}, - ), - step( - "y", - object(), - space={"a": [1, 2, 3]}, - fidelities={"b": (1.0, 10)}, - ), - step( - "z", - object(), - space={"a": [1, 2, 3]}, - fidelities={"b": (0.0, 1.0)}, - ), -) - -print(pipeline.linearized_fidelity(1)) - -optimizer = SMACOptimizer.create( - space=pipeline.space(), - seed=0, - fidelities=pipeline.fidelities(), -) - -for _i in range(8): - trial = optimizer.ask() - assert trial.fidelities is not None - budget = trial.fidelities["budget"] - print(budget) - - selected_fidelities = pipeline.linearized_fidelity(budget) - print(selected_fidelities) - - config = {**trial.config, **selected_fidelities} - selected_pipeline = pipeline.configure(config) - print(selected_pipeline) diff --git a/examples/simple_hpo.py b/examples/simple_hpo.py index 98b18158..f9b43652 100644 --- a/examples/simple_hpo.py +++ b/examples/simple_hpo.py @@ -27,20 +27,17 @@ import numpy as np import openml -from sklearn.compose import ColumnTransformer, make_column_selector +from sklearn.compose import make_column_selector from sklearn.ensemble import RandomForestClassifier from sklearn.impute import SimpleImputer from sklearn.metrics import accuracy_score -from sklearn.preprocessing import ( - LabelEncoder, - OneHotEncoder, -) +from sklearn.preprocessing import LabelEncoder, OneHotEncoder from amltk.optimization import History, Trial -from amltk.pipeline import Pipeline, split, step +from amltk.optimization.optimizers.smac import SMACOptimizer +from amltk.pipeline import Component, Node, Sequential, Split from amltk.scheduling import Scheduler from amltk.sklearn.data import split_data -from amltk.smac import SMACOptimizer from amltk.store import PathBucket """ @@ -81,46 +78,34 @@ def get_dataset( For more on definitions of pipelines, see the [Pipeline](site:guides/pipeline.md) guide. 
""" -categorical_imputer = step( - "categoricals", - SimpleImputer, - config={ - "strategy": "constant", - "fill_value": "missing", - }, -) -one_hot_encoding = step("ohe", OneHotEncoder, config={"drop": "first"}) - -numerical_imputer = step( - "numerics", - SimpleImputer, - space={"strategy": ["mean", "median"]}, -) - -pipeline = Pipeline.create( - split( - "feature_preprocessing", - categorical_imputer | one_hot_encoding, - numerical_imputer, - item=ColumnTransformer, +pipeline = ( + Sequential(name="Pipeline") + >> Split( + { + "categories": [ + SimpleImputer(strategy="constant", fill_value="missing"), + OneHotEncoder(drop="first"), + ], + "numbers": Component(SimpleImputer, space={"strategy": ["mean", "median"]}), + }, config={ - "categoricals": make_column_selector(dtype_include=object), - "numerics": make_column_selector(dtype_include=np.number), + "categories": make_column_selector(dtype_include=object), + "numbers": make_column_selector(dtype_include=np.number), }, - ), - step( - "rf", + name="feature_preprocessing", + ) + >> Component( RandomForestClassifier, space={ "n_estimators": (10, 100), "max_features": (0.0, 1.0), "criterion": ["gini", "entropy", "log_loss"], }, - ), + ) ) print(pipeline) -print(pipeline.space()) +print(pipeline.search_space("configspace")) """ ## Target Function @@ -128,8 +113,7 @@ def get_dataset( We also pass in a [`PathBucket`][amltk.store.Bucket] which is a dict-like view of the file system, where we have our dataset stored. -We also pass in our [`Pipeline`][amltk.pipeline.Pipeline] representation of our -pipeline, which we will use to build our sklearn pipeline with a specific +We also pass in our pipeline, which we will use to build our sklearn pipeline with a specific `trial.config` suggested by the [`Optimizer`][amltk.optimization.Optimizer]. """ @@ -138,7 +122,7 @@ def target_function( trial: Trial, /, bucket: PathBucket, - _pipeline: Pipeline, + _pipeline: Node, ) -> Trial.Report: # Load in data X_train, X_val, X_test, y_train, y_val, y_test = ( @@ -152,7 +136,7 @@ def target_function( # Configure the pipeline with the trial config before building it. configured_pipeline = _pipeline.configure(trial.config) - sklearn_pipeline = configured_pipeline.build() + sklearn_pipeline = configured_pipeline.build("sklearn") # Fit the pipeline, indicating when you want to start the trial timing and error # catchnig. @@ -210,7 +194,7 @@ def target_function( Now we can run the whole thing. We will use the [`Scheduler`][amltk.scheduling.Scheduler] -to run the optimization, and the [`SMACOptimizer`][amltk.smac.SMACOptimizer] to +to run the optimization, and the [`SMACOptimizer`][amltk.optimization.optimizers.smac.SMACOptimizer] to to optimize the pipeline. ### Getting and storing data @@ -251,13 +235,16 @@ def target_function( Please check out the full [guides](site:guides/index.md) to learn more! -We then create an [`SMACOptimizer`][amltk.smac.SMACOptimizer] which will +We then create an [`SMACOptimizer`][amltk.optimization.optimizers.smac.SMACOptimizer] which will optimize the pipeline. We pass in the space of the pipeline, which is the space of the hyperparameters we want to optimize. 
""" scheduler = Scheduler.with_processes(2) -optimizer = SMACOptimizer.create(space=pipeline.space(), seed=seed) +parser = SMACOptimizer.preferred_parser() +space = pipeline.search_space(parser=parser) + +optimizer = SMACOptimizer.create(space=space, seed=seed) """ Next we create a [`Task`][amltk.Task], passing in the function we @@ -282,7 +269,7 @@ def target_function( def launch_initial_tasks() -> None: """When we start, launch `n_workers` tasks.""" trial = optimizer.ask() - task(trial, bucket=bucket, _pipeline=pipeline) + task.submit(trial, bucket=bucket, _pipeline=pipeline) """ @@ -330,7 +317,7 @@ def add_to_history(future: Future, report: Trial.Report) -> None: def launch_another_task(*_: Any) -> None: """When we get a report, evaluate another trial.""" trial = optimizer.ask() - task(trial, bucket=bucket, _pipeline=pipeline) + task.submit(trial, bucket=bucket, _pipeline=pipeline) """ @@ -359,4 +346,3 @@ def stop_scheduler_on_cancelled(_: Any) -> None: print("Trial history:") history_df = trial_history.df() print(history_df) - diff --git a/justfile b/justfile index 26f80705..301745b5 100644 --- a/justfile +++ b/justfile @@ -30,7 +30,7 @@ docs exec_doc_code="true" example="None" offline="false": python -m webbrowser -t "http://127.0.0.1:8000/" AMLTK_DOC_RENDER_EXAMPLES={{example}} \ AMLTK_DOCS_OFFLINNE={{offline}} \ - AMLTK_EXEC_DOCS={{exec_doc_code}} mkdocs serve --watch-theme + AMLTK_EXEC_DOCS={{exec_doc_code}} mkdocs serve --watch-theme --dirtyreload # https://github.com/pawamoy/markdown-exec/issues/19 # Bump the version and generate the changelog based off commit messages diff --git a/mkdocs.yml b/mkdocs.yml index e108157a..a3c9b418 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -157,24 +157,32 @@ nav: - "guides/scheduling.md" - Reference: - "reference/index.md" - - "reference/buckets.md" - - "reference/call_limiter.md" - - "reference/comms.md" - - "reference/configspace.md" - - "reference/dask-jobqueue.md" - - "reference/data.md" - - "reference/history.md" - - "reference/metalearning.md" - - "reference/optuna.md" - - "reference/pynisher.md" - - "reference/sklearn.md" - - "reference/smac.md" - - "reference/threadpoolctl.md" - - "reference/wandb.md" - - "reference/prebuilt_pipelines.md" - - "reference/neps.md" - - "reference/plugins.md" - - Examples: "examples/" + - Scheduling: + - "reference/scheduling/scheduler.md" + - "reference/scheduling/executors.md" + - "reference/scheduling/task.md" + - "reference/scheduling/plugins.md" + - "reference/scheduling/events.md" + - Pipelines: + - "reference/pipelines/pipeline.md" + - "reference/pipelines/spaces.md" + - "reference/pipelines/builders.md" + # - "reference/pipelines/prebuilts.md" + - Optimization: + - "reference/optimization/optimizers.md" + - "reference/optimization/trials.md" + - "reference/optimization/history.md" + - "reference/optimization/profiling.md" + - Data: + - "reference/data/index.md" + - "reference/data/buckets.md" + - Meta-Learning: + - "reference/metalearning/index.md" + + - Examples: + - "examples/index.md" + - "examples/simple_hpo.md" + - "examples/hpo_with_ensembling.md" # Auto generated with docs/examples_runner.py - API: "api/" # Auto generated with docs/api_generator.py diff --git a/pyproject.toml b/pyproject.toml index 3b713316..21cccbdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,11 +5,10 @@ dependencies = [ "typing_extensions", # Better typing "more_itertools", # Better iteration "psutil", # Used for process termination of executors - "attrs", "pandas", "numpy", ] -requires-python = ">=3.8" 
+requires-python = ">=3.10" authors = [{ name = "Eddie Bergman", email = "eddiebergmanhs@gmail.com" }] readme = "README.md" description = "AutoML Toolkit: a toolkit for building automl system" @@ -113,15 +112,13 @@ tag_format = "v$major.$minor.$patch$prerelease" update_changelog_on_bump = true version_files = ["pyproject.toml:version", "src/amltk/__version__.py"] -[tool.black] -target-version = ['py38'] - # https://github.com/charliermarsh/ruff [tool.ruff] -target-version = "py38" +target-version = "py310" line-length = 88 show-source = true src = ["src", "tests", "examples"] +extend-safe-fixes = ["ALL"] # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" @@ -188,6 +185,7 @@ ignore = [ "W292", # No newline at end of file "PLC1901", # "" can be simplified to be falsey "TCH003", # Move stdlib import into TYPE_CHECKING + "B010", # Do not use `setattr` # These tend to be lighweight and confuse pyright ] @@ -249,7 +247,7 @@ convention = "google" max-args = 10 # Changed from default of 5 [tool.mypy] -python_version = "3.8" +python_version = "3.10" packages = ["src/amltk", "tests"] show_error_codes = true @@ -274,6 +272,7 @@ module = ["tests.*"] disallow_untyped_defs = false # Sometimes we just want to ignore verbose types disallow_untyped_decorators = false # Test decorators are not properly typed disallow_incomplete_defs = false # Sometimes we just want to ignore verbose types +disable_error_code = ["var-annotated"] [[tool.mypy.overrides]] module = [ @@ -293,7 +292,7 @@ ignore_missing_imports = true [tool.pyright] include = ["src", "tests"] -pythonVersion = "3.8" +pythonVersion = "3.10" typeCheckingMode = "strict" strictListInference = true @@ -315,3 +314,4 @@ reportUnknownVariableType = false reportUnknownArgumentType = false reportUnknownLambdaType = false reportPrivateUsage = false +reportUnnecessaryCast = false diff --git a/src/amltk/__init__.py b/src/amltk/__init__.py index 32fcbf51..790b72a2 100644 --- a/src/amltk/__init__.py +++ b/src/amltk/__init__.py @@ -1,21 +1,31 @@ from amltk import options -from amltk.events import Emitter, Subscriber from amltk.optimization import ( History, IncumbentTrace, Optimizer, - RandomSearch, Trace, Trial, ) -from amltk.pipeline import Pipeline, choice, group, request, searchable, split, step +from amltk.pipeline import ( + Choice, + Component, + Fixed, + Join, + Node, + Sequential, + Split, + request, +) from amltk.scheduling import ( - CallLimiter, Comm, + Emitter, + Event, + Limiter, + Plugin, Scheduler, SequentialExecutor, + Subscriber, Task, - TaskPlugin, ) from amltk.store import ( Bucket, @@ -33,39 +43,40 @@ ) __all__ = [ - "Pipeline", - "split", - "step", - "group", - "choice", - "searchable", - "Scheduler", - "Comm", - "Task", "Bucket", - "Drop", - "PathBucket", - "PathLoader", - "Loader", "ByteLoader", + "Choice", + "Comm", + "Component", + "Drop", + "Emitter", + "Event", + "Fixed", + "History", + "IncumbentTrace", + "Join", "JSONLoader", + "Limiter", + "Loader", + "Node", "NPYLoader", + "Optimizer", + "options", + "PathBucket", + "PathLoader", "PDLoader", "PickleLoader", - "TxtLoader", - "YAMLLoader", - "History", - "IncumbentTrace", - "Optimizer", - "RandomSearch", - "Trace", - "Trial", - "CallLimiter", + "Plugin", + "request", "Scheduler", - "TaskPlugin", - "Subscriber", + "Scheduler", + "Sequential", "SequentialExecutor", - "Emitter", - "options", - "request", + "Split", + "Subscriber", + "Task", + "Trace", + "Trial", + "TxtLoader", + "YAMLLoader", ] diff --git a/src/amltk/asyncm.py 
b/src/amltk/_asyncm.py similarity index 97% rename from src/amltk/asyncm.py rename to src/amltk/_asyncm.py index 17cfe9f1..daa9fdda 100644 --- a/src/amltk/asyncm.py +++ b/src/amltk/_asyncm.py @@ -4,6 +4,7 @@ import asyncio from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from typing_extensions import override if TYPE_CHECKING: from multiprocessing.connection import Connection @@ -65,6 +66,7 @@ def __init__(self, **kwargs: Any) -> None: self.msg: str | None = None self.exception: BaseException | None = None + @override def set( self, msg: str | None = None, @@ -80,6 +82,7 @@ def set( self.exception = exception super().set() + @override def clear(self) -> None: """Clear the event and clear the context.""" self.msg = None diff --git a/src/amltk/_doc.py b/src/amltk/_doc.py index 80e610e7..9ce294ca 100644 --- a/src/amltk/_doc.py +++ b/src/amltk/_doc.py @@ -1,13 +1,50 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Callable, Literal +from collections.abc import Callable +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: from rich.console import RenderableType -DEFAULT_MKDOCS_CODE_BLOCK_WIDTH = 82 -FONTSIZE_RICH_HTML = "0.5rem" +DEFAULT_MKDOCS_CODE_BLOCK_WIDTH = 80 + + +SKLEARN_LINK = "https://www.scikit-learn.org/stable/modules/generated/{0}.html" + + +def sklearn_link_generator(name: str) -> str: + """Generate a link for a sklearn function.""" + reduced_name = ".".join(s for s in name.split(".") if not s.startswith("_")) + return SKLEARN_LINK.format(reduced_name) + + +@lru_cache +def _try_get_link(fully_scoped_name: str) -> str | None: + """Try to get a link for a string. + + Expects fully qualified import names. + """ + from amltk.options import _amltk_options + + links = _amltk_options.get("links", {}) + + for k, v in links.items(): + if fully_scoped_name.startswith(k): + if isinstance(v, str): + return v + if callable(v): + return v(fully_scoped_name) + + return None + + +def link(obj: Any) -> str | None: + """Try to get a link for an object.""" + from amltk._functional import fullname + + return _try_get_link(fullname(obj)) def make_picklable(thing: Any, name: str | None = None) -> None: @@ -46,8 +83,8 @@ def as_rich_svg( SIZES = { "very-small": "0.5rem", - "small": "0.75rem", - "medium": "0.8rem", + "small": "0.7rem", + "medium": "0.75rem", "large": "1rem", } @@ -79,5 +116,18 @@ def doc_print( ) -> None: if output == "svg": _print(as_rich_svg(*renderable, title=title, width=width)) + elif len(renderable) == 1: + try: + from sklearn.base import BaseEstimator, TransformerMixin + from sklearn.pipeline import Pipeline + + if isinstance( + renderable[0], + Pipeline | BaseEstimator() | TransformerMixin(), + ): + _print(renderable[0]._repr_html_()) # type: ignore + return + except Exception: # noqa: BLE001 + _print(as_rich_html(*renderable, width=width, fontsize=fontsize)) else: _print(as_rich_html(*renderable, width=width, fontsize=fontsize)) diff --git a/src/amltk/functional.py b/src/amltk/_functional.py similarity index 89% rename from src/amltk/functional.py rename to src/amltk/_functional.py index ecbc41c6..1e4d3c73 100644 --- a/src/amltk/functional.py +++ b/src/amltk/_functional.py @@ -4,22 +4,17 @@ """ from __future__ import annotations +from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence from functools import partial, reduce from inspect import isclass from itertools import count from typing import ( Any, - Callable, Generic, - Hashable, - 
Iterable, - Iterator, - Mapping, - Sequence, + TypeAlias, TypeVar, Union, ) -from typing_extensions import TypeAlias T = TypeVar("T") V = TypeVar("V") @@ -33,7 +28,7 @@ def prefix_keys(d: Mapping[str, V], prefix: str) -> dict[str, V]: """Prefix the keys of a mapping. ```python exec="true" source="material-block" result="python" title="prefix_keys" - from amltk.functional import prefix_keys + from amltk._functional import prefix_keys d = {"a": 1, "b": 2} print(prefix_keys(d, "c:")) @@ -46,7 +41,7 @@ def dict_get_not_none(d: Mapping[K, V], key: K, default: V2) -> V | V2: """Get a value from a dictionary, or a default value if it is None. ```python exec="true" source="material-block" result="python" title="dict_get_not_none" - from amltk.functional import dict_get_not_none + from amltk._functional import dict_get_not_none d = {"a": None, "b": 2} print(dict_get_not_none(d, "a", 1)) # d.get("a", 1) would return None @@ -61,7 +56,7 @@ def mapping_select(d: Mapping[str, V], prefix: str) -> dict[str, V]: """Select a subset of a mapping. ```python exec="true" source="material-block" result="python" title="mapping_select" - from amltk.functional import mapping_select + from amltk._functional import mapping_select d = {"a:b:c": 1, "a:b:d": 2, "c:elephant": 3} print(mapping_select(d, "a:b:")) @@ -82,7 +77,7 @@ def flatten_dict(d: RecMapping[str, V], *, delim: str | None = None) -> dict[str """Flatten a recursive mapping. ```python exec="true" source="material-block" result="python" title="flatten_dict" - from amltk.functional import flatten_dict + from amltk._functional import flatten_dict d = {"a": 1, "b": {"c": 2, "d": 3}} print(flatten_dict(d)) @@ -118,7 +113,7 @@ def reverse_enumerate( sequence in reverse. ```python exec="true" source="material-block" result="python" title="reverse_enumerate" - from amltk.functional import reverse_enumerate + from amltk._functional import reverse_enumerate xs = ["a", "b", "c"] for i, x in reverse_enumerate(xs): @@ -144,7 +139,7 @@ def rgetattr(obj: Any, attr: str, *args: Any) -> Any: attributes using '.' notation. ```python exec="true" source="material-block" result="python" title="rgetattr" - from amltk.functional import rgetattr + from amltk._functional import rgetattr class A: x = 1 @@ -245,6 +240,33 @@ def classname(c: Any, default: str | None = None) -> str: return str(c) +def entity_name( + thing: Any, + default: str | None = None, +) -> str: + """Get the name of a thing. + + Args: + thing: The thing to get the name of. + default: The default value to return if the name cannot be + determined automatically. + + Returns: + The name of the thing. + """ + if isinstance(thing, str): + return thing + if isinstance(thing, type) or hasattr(thing, "__class__"): + return classname(thing) + if callable(thing): + return funcname(thing) + if hasattr(thing, "__name__"): + return str(thing.__name__) + if default is not None: + return default + return str(thing) + + def callstring(f: Callable, *args: Any, **kwargs: Any) -> str: """Get a string representation of a function call. @@ -279,7 +301,7 @@ def compare_accumulate( carried forward. ```python exec="true" source="material-block" result="python" title="compare_accumulate" - from amltk.functional import compare_accumulate + from amltk._functional import compare_accumulate xs = [5, 4, 6, 2, 1, 8] print(list(compare_accumulate(xs, lambda x, y: x > y))) @@ -322,7 +344,7 @@ def transformations( can safely ignore the type warnings if so. 
```python exec="true" source="material-block" result="python" title="transforms" - from amltk.functional import transformations + from amltk._functional import transformations def f(x): return x + 1 diff --git a/src/amltk/_richutil/__init__.py b/src/amltk/_richutil/__init__.py new file mode 100644 index 00000000..50e6a507 --- /dev/null +++ b/src/amltk/_richutil/__init__.py @@ -0,0 +1,11 @@ +from amltk._richutil.renderable import RichRenderable +from amltk._richutil.renderers import Function, rich_make_column_selector +from amltk._richutil.util import df_to_table, richify + +__all__ = [ + "df_to_table", + "richify", + "RichRenderable", + "Function", + "rich_make_column_selector", +] diff --git a/src/amltk/richutil/renderable.py b/src/amltk/_richutil/renderable.py similarity index 98% rename from src/amltk/richutil/renderable.py rename to src/amltk/_richutil/renderable.py index 2662df6a..c70faf96 100644 --- a/src/amltk/richutil/renderable.py +++ b/src/amltk/_richutil/renderable.py @@ -8,8 +8,6 @@ if TYPE_CHECKING: from rich.console import RenderableType - pass - class RichRenderable(ABC): """Mixin for adding rich methods to a class.""" diff --git a/src/amltk/_richutil/renderers/__init__.py b/src/amltk/_richutil/renderers/__init__.py new file mode 100644 index 00000000..391b86f6 --- /dev/null +++ b/src/amltk/_richutil/renderers/__init__.py @@ -0,0 +1,5 @@ +from amltk._richutil.renderers._make_column_selector import rich_make_column_selector +from amltk._richutil.renderers.executors import ProcessPoolExecutorRenderer +from amltk._richutil.renderers.function import Function + +__all__ = ["Function", "rich_make_column_selector", "ProcessPoolExecutorRenderer"] diff --git a/src/amltk/richutil/renderers/_make_column_selector.py b/src/amltk/_richutil/renderers/_make_column_selector.py similarity index 100% rename from src/amltk/richutil/renderers/_make_column_selector.py rename to src/amltk/_richutil/renderers/_make_column_selector.py diff --git a/src/amltk/richutil/renderers/executors.py b/src/amltk/_richutil/renderers/executors.py similarity index 93% rename from src/amltk/richutil/renderers/executors.py rename to src/amltk/_richutil/renderers/executors.py index ec555907..3cdd92de 100644 --- a/src/amltk/richutil/renderers/executors.py +++ b/src/amltk/_richutil/renderers/executors.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from typing_extensions import override -from amltk.richutil.renderable import RichRenderable +from amltk._richutil.renderable import RichRenderable if TYPE_CHECKING: from rich.panel import Panel diff --git a/src/amltk/richutil/renderers/function.py b/src/amltk/_richutil/renderers/function.py similarity index 94% rename from src/amltk/richutil/renderers/function.py rename to src/amltk/_richutil/renderers/function.py index 795f0051..b43d124b 100644 --- a/src/amltk/richutil/renderers/function.py +++ b/src/amltk/_richutil/renderers/function.py @@ -2,14 +2,15 @@ from __future__ import annotations import inspect +from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable, Literal +from typing import TYPE_CHECKING, Literal from typing_extensions import override -from amltk.functional import funcname -from amltk.links import link +from amltk._doc import link +from amltk._functional import funcname +from amltk._richutil.renderable import RichRenderable from amltk.options import _amltk_options -from amltk.richutil.renderable import RichRenderable if TYPE_CHECKING: from rich.highlighter import Highlighter diff --git 
a/src/amltk/richutil/util.py b/src/amltk/_richutil/util.py similarity index 98% rename from src/amltk/richutil/util.py rename to src/amltk/_richutil/util.py index 610320cc..77d1cbee 100644 --- a/src/amltk/richutil/util.py +++ b/src/amltk/_richutil/util.py @@ -6,7 +6,7 @@ from concurrent.futures import ProcessPoolExecutor from typing import TYPE_CHECKING, Any -from amltk.richutil.renderers import ( +from amltk._richutil.renderers import ( ProcessPoolExecutorRenderer, rich_make_column_selector, ) diff --git a/src/amltk/building/__init__.py b/src/amltk/building/__init__.py deleted file mode 100644 index d0f52dae..00000000 --- a/src/amltk/building/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from amltk.building._sklearn_builder import sklearn_builder -from amltk.building.api import BuildError, build - -__all__ = ["build", "sklearn_builder", "BuildError"] diff --git a/src/amltk/building/_sklearn_builder.py b/src/amltk/building/_sklearn_builder.py deleted file mode 100644 index f940ef0a..00000000 --- a/src/amltk/building/_sklearn_builder.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Build a pipeline into an [`sklearn.pipeline.Pipeline`][sklearn.pipeline.Pipeline].""" -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from sklearn.pipeline import Pipeline as SklearnPipeline - - from amltk.pipeline import Pipeline - - -def sklearn_builder(pipeline: Pipeline) -> SklearnPipeline: - """Build a pipeline into a usable object. - - Args: - pipeline: The pipeline to build - - Returns: - The built sklearn pipeline - """ - try: - from amltk.sklearn.builder import build - - return build(pipeline) - except (ImportError, ModuleNotFoundError) as e: - raise ImportError( - "The sklearn builder requires the sklearn package to be installed. " - "Please install it using `pip install sklearn`.", - ) from e diff --git a/src/amltk/building/api.py b/src/amltk/building/api.py deleted file mode 100644 index a67557c1..00000000 --- a/src/amltk/building/api.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Module for pipeline building. - -The `Builder` is responsble for taking a configured pipeline and -assembling it into a runnable pipeline. By default, this will try -some rough heuristics to determine what to build from your -configured [`Pipeline`][amltk.pipeline.Pipeline]. -""" -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload -from typing_extensions import override - -from more_itertools import first_true, seekable - -from amltk.building._sklearn_builder import sklearn_builder -from amltk.exceptions import safe_map -from amltk.functional import funcname - -if TYPE_CHECKING: - from amltk.pipeline.pipeline import Pipeline - - B = TypeVar("B") - -DEFAULT_BUILDERS: list[Callable[[Pipeline], Any]] = [sklearn_builder] - -logger = logging.getLogger(__name__) - - -class BuildError(Exception): - """Error when a pipeline could not be built.""" - - def __init__( - self, - builders: list[str], - err_tbs: list[tuple[Exception, str]], - ) -> None: - """Create a new BuildError. 
- - Args: - builders: The builders that were tried - err_tbs: The errors and tracebacks for each builder - """ - self.builders = builders - self.err_tbs = err_tbs - super().__init__(builders, err_tbs) - - @override - def __str__(self) -> str: - return "\n".join( - [ - "Could not build pipeline with any of the builders:", - *[ - f" - {builder}: {err}\n{tb}" - for builder, (err, tb) in zip(self.builders, self.err_tbs) - ], - ], - ) - - -@overload -def build(pipeline: Pipeline, builder: None = None, **builder_kwargs: Any) -> Any: - ... - - -@overload -def build( - pipeline: Pipeline, - builder: Callable[[Pipeline], B], - **builder_kwargs: Any, -) -> B: - ... - - -def build( - pipeline: Pipeline, - builder: Callable[[Pipeline], B] | None = None, - **builder_kwargs: Any, -) -> B | Any: - """Build a pipeline into a usable object. - - Args: - pipeline: The pipeline to build - builder: The builder to use. Defaults to `None` which will - try to determine the best builder to use. - **builder_kwargs: Any keyword arguments to pass to the builder - - Returns: - The built pipeline - """ - builders: list[Any] - - if builder is None: - builders = DEFAULT_BUILDERS - if any(builder_kwargs): - logger.warning( - f"If using `{builder_kwargs=}`, you most likely want to" - " pass an explicit `builder` argument", - ) - - elif callable(builder): - builders = [builder] - - else: - raise NotImplementedError(f"Builder {builder} is not supported") - - def _build(_builder: Callable[[Pipeline], B]) -> B: - return _builder(pipeline) - - results = seekable(safe_map(_build, builders)) - - is_result = lambda r: not (isinstance(r, tuple) and isinstance(r[0], Exception)) - - selected_built_pipeline = first_true(results, default=None, pred=is_result) - - # If we didn't manage to build a pipeline, iterate through - # the errors and raise a ValueError - if selected_built_pipeline is None: - results.seek(0) # Reset to start of the iterator - builders = [funcname(builder) for builder in builders] - errors = [(err, tb) for err, tb in results] # type: ignore - raise BuildError(builders=builders, err_tbs=errors) - - assert not isinstance(selected_built_pipeline, Exception) - return selected_built_pipeline diff --git a/src/amltk/configspace/__init__.py b/src/amltk/configspace/__init__.py deleted file mode 100644 index 2240a461..00000000 --- a/src/amltk/configspace/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from amltk.configspace.space import ConfigSpaceAdapter - -ConfigSpaceParser = ConfigSpaceAdapter -ConfigSpaceSampler = ConfigSpaceAdapter - -__all__ = ["ConfigSpaceAdapter", "ConfigSpaceParser", "ConfigSpaceSampler"] diff --git a/src/amltk/configspace/space.py b/src/amltk/configspace/space.py deleted file mode 100644 index f9986489..00000000 --- a/src/amltk/configspace/space.py +++ /dev/null @@ -1,299 +0,0 @@ -"""A module to interact with ConfigSpace.""" -from __future__ import annotations - -from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Mapping, Sequence, Union -from typing_extensions import TypeAlias, override - -import numpy as np -from ConfigSpace import Categorical, ConfigurationSpace, Constant -from ConfigSpace.hyperparameters import Hyperparameter - -from amltk.pipeline.space import SpaceAdapter -from amltk.randomness import as_int - -if TYPE_CHECKING: - from amltk.types import Seed - - -InputSpace: TypeAlias = Union[Mapping[str, Any], ConfigurationSpace] - - -class ConfigSpaceAdapter(SpaceAdapter[InputSpace, ConfigurationSpace]): - """An adapter following the 
[`SpaceAdapter`][amltk.pipeline.SpaceAdapter] interface - for interacting with ConfigSpace spaces. - - This includes parsing ConfigSpace spaces following the - [`Parser`][amltk.pipeline.Parser] interface and sampling from them with - the [`Sampler`][amltk.pipeline.Sampler] interface. - """ - - @override - def parse_space( - self, - space: Any, - config: Mapping[str, Any] | None = None, - ) -> ConfigurationSpace: - """See [`Parser.parse_space`][amltk.pipeline.Parser.parse_space]. - - ```python exec="true" source="material-block" result="python" title="A simple space" - from amltk.configspace import ConfigSpaceAdapter - - search_space = { - "a": (1, 10), - "b": (0.5, 9.0), - "c": ["apple", "banana", "carrot"], - } - - adapter = ConfigSpaceAdapter() - space = adapter.parse(search_space) - print(space) - ``` - """ # noqa: E501 - if space is None: - _space = ConfigurationSpace() - elif isinstance(space, dict): - _space = ConfigurationSpace(space) - elif isinstance(space, Hyperparameter): - _space = ConfigurationSpace({space.name: space}) - elif isinstance(space, ConfigurationSpace): - _space = self.copy(space) - else: - raise TypeError(f"{space} is not parsable as a space") - - return _space - - @override - def set_seed(self, space: ConfigurationSpace, seed: Seed) -> ConfigurationSpace: - """Set the seed for the space. - - ```python exec="true" source="material-block" result="python" title="Setting the seed" - from amltk.configspace import ConfigSpaceAdapter - - adapter = ConfigSpaceAdapter() - - space = adapter.parse({ "a": (1, 10) }) - adapter.set_seed(space, seed=42) - - seeded_value_for_a = adapter.sample(space) - print(seeded_value_for_a) - ``` - - Args: - space: The space to set the seed for. - seed: The seed to set. - """ # noqa: E501 - _seed = as_int(seed) - space.seed(_seed) - return space - - @override - def insert( - self, - space: ConfigurationSpace, - subspace: Mapping[str, Any] | ConfigurationSpace, - *, - prefix_delim: tuple[str, str] | None = None, - ) -> ConfigurationSpace: - """See [`Parser.insert`][amltk.pipeline.Parser.insert]. - - ```python exec="true" source="material-block" result="python" title="Inserting one space into another" - from amltk.configspace import ConfigSpaceAdapter - - adapter = ConfigSpaceAdapter() - - space_1 = adapter.parse({ "a": (1, 10) }) - space_2 = adapter.parse({ "b": (10.5, 100.5) }) - space_3 = adapter.parse({ "c": ["apple", "banana", "carrot"] }) - - space = adapter.empty() - adapter.insert(space, space_1) - adapter.insert(space, space_2) - adapter.insert(space, space_3, prefix_delim=("fruit", ":")) - - print(space) - ``` - """ # noqa: E501 - if prefix_delim is None: - prefix_delim = ("", "") - - prefix, delim = prefix_delim - subspace = ( - ConfigurationSpace(dict(subspace)) - if not isinstance(subspace, ConfigurationSpace) - else subspace - ) - - space.add_configuration_space( - prefix=prefix, - configuration_space=subspace, - delimiter=delim, - ) - return space - - @override - def empty(self) -> ConfigurationSpace: - """See [`Parser.empty`][amltk.pipeline.Parser.empty]. 
- - ```python exec="true" source="material-block" result="python" title="Getting an empty space" - from amltk.configspace import ConfigSpaceAdapter - - adapter = ConfigSpaceAdapter() - empty_space = adapter.empty() - print(empty_space) - ``` - """ # noqa: E501 - return ConfigurationSpace() - - @override - def condition( - self, - choice_name: str, - delim: str, - spaces: dict[str, ConfigurationSpace], - weights: Sequence[float] | None = None, - ) -> ConfigurationSpace: - """See [`Parser.condition`][amltk.pipeline.Parser.condition]. - - ```python exec="true" source="material-block" result="python" title="Conditioning a space" - from amltk.configspace import ConfigSpaceAdapter - - adapter = ConfigSpaceAdapter() - - space_a = adapter.parse({ "a": (1, 10) }) - space_b = adapter.parse({ "b": (200, 300) }) - - space = adapter.condition( - choice_name="letter", - delim=":", - spaces={ "a": space_a, "b": space_b } - ) - print(space) - ``` - """ # noqa: E501 - space = ConfigurationSpace() - - items = list(spaces.keys()) - choice = Categorical(choice_name, items=items, weights=weights) - space.add_hyperparameter(choice) - - for key, subspace in spaces.items(): - space.add_configuration_space( - prefix=choice_name, - configuration_space=subspace, - parent_hyperparameter={"parent": choice, "value": key}, - delimiter=delim, - ) - return space - - @override - def _sample( - self, - space: ConfigurationSpace, - n: int = 1, - seed: Seed | None = None, - ) -> list[Mapping[str, Any]]: - """See [`Sampler._sample`][amltk.pipeline.Sampler._sample].""" - if seed: - seed_int = as_int(seed) - self.set_seed(space, seed_int) - - if n == 1: - return [dict(space.sample_configuration())] - - return [dict(c) for c in space.sample_configuration(n)] - - @override - def copy(self, space: ConfigurationSpace) -> ConfigurationSpace: - """See [`Sampler.copy`][amltk.pipeline.Sampler.copy]. - - ```python exec="true" source="material-block" result="python" title="Copying a space" - from amltk.configspace import ConfigSpaceAdapter - - adapter = ConfigSpaceAdapter() - - space_original = adapter.parse({ "a": (1, 10) }) - space_copy = adapter.copy(space_original) - - print(space_copy) - ``` - """ # noqa: E501 - return deepcopy(space) - - @classmethod - @override - def supports_sampling(cls, space: Any) -> bool: - """See [`Sampler.supports_sampling`][amltk.pipeline.Sampler.supports_sampling]. - - Args: - space: The space to check. - - Returns: - True if the space is a ConfigurationSpace. - """ - return isinstance(space, ConfigurationSpace) - - @classmethod - def remove_hyperparameter( - cls, - name: str, - space: ConfigurationSpace, - ) -> ConfigurationSpace: - """A new configuration space with the hyperparameter removed. - - Essentially copies hp over and fails if there is conditionals or forbiddens - """ - if name not in space._hyperparameters: - raise ValueError(f"{name} not in {space}") - - # Copying conditionals only work on objects and not named entities - # Seeing as we copy objects and don't use the originals, transfering these - # to the new objects is a bit tedious, possible but not required at this time - # ... 
same goes for forbiddens - assert name not in space._conditionals, "Can't handle conditionals" - assert not any( - name != f.hyperparameter.name for f in space.get_forbiddens() - ), "Can't handle forbiddens" - - hps = [copy(hp) for hp in space.get_hyperparameters() if hp.name != name] - - if isinstance(space.random, np.random.RandomState): # type: ignore - new_seed = space.random.randint(2**32 - 1) - else: - new_seed = copy(space.random) - - new_space = ConfigurationSpace( - # TODO: not sure if this will have implications, assuming not - seed=new_seed, - name=copy(space.name), - meta=copy(space.meta), - ) - new_space.add_hyperparameters(hps) - return new_space - - @classmethod - def replace_constants( - cls, - config: Mapping[str, Any], - space: ConfigurationSpace, - ) -> ConfigurationSpace: - """Search the config for any hyperparameters that are in the space and need. - to be replaced with a constant. - - Args: - config: The configuration associated with a step, which may have - overlaps with the ConfigurationSpace - space: The space to remove overlapping parameters from - - Returns: - ConfigurationSpace: A copy of the space with the hyperparameters replaced - """ - for key, value in config.items(): - if key in space._hyperparameters: - space = cls.remove_hyperparameter(key, space) - - if not isinstance(value, bool): - hp = Constant(key, value) - space.add_hyperparameter(hp) - - return space diff --git a/src/amltk/dask_jobqueue/__init__.py b/src/amltk/dask_jobqueue/__init__.py deleted file mode 100644 index a87a60a1..00000000 --- a/src/amltk/dask_jobqueue/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from amltk.dask_jobqueue.executors import DJQ_NAMES, DaskJobqueueExecutor - -__all__ = ["DaskJobqueueExecutor", "DJQ_NAMES"] diff --git a/src/amltk/data/conversions.py b/src/amltk/data/conversions.py index 094a922c..1c38ef2b 100644 --- a/src/amltk/data/conversions.py +++ b/src/amltk/data/conversions.py @@ -58,7 +58,7 @@ def to_numpy( Returns: The converted data """ - _x = x.to_numpy() if isinstance(x, (pd.DataFrame, pd.Series)) else np.asarray(x) + _x = x.to_numpy() if isinstance(x, pd.DataFrame | pd.Series) else np.asarray(x) if ( flatten_if_1d diff --git a/src/amltk/data/dtype_reduction.py b/src/amltk/data/dtype_reduction.py index 29c559f7..31e52a69 100644 --- a/src/amltk/data/dtype_reduction.py +++ b/src/amltk/data/dtype_reduction.py @@ -2,15 +2,14 @@ from __future__ import annotations import logging -from typing import TypeVar, Union -from typing_extensions import TypeAlias +from typing import TypeAlias, TypeVar import numpy as np import pandas as pd logger = logging.getLogger(__name__) -DataContainer: TypeAlias = Union[np.ndarray, pd.DataFrame, pd.Series] +DataContainer: TypeAlias = np.ndarray | (pd.DataFrame | pd.Series) D = TypeVar("D", bound=DataContainer) @@ -95,10 +94,10 @@ def reduce_dtypes(x: D, *, reduce_int: bool = True, reduce_float: bool = True) - reduce_int: Whether to reduce integer dtypes. reduce_float: Whether to reduce floating point dtypes. 
""" - if not isinstance(x, (pd.DataFrame, pd.Series, np.ndarray)): + if not isinstance(x, pd.DataFrame | pd.Series | np.ndarray): raise TypeError(f"Cannot reduce data of type {type(x)}.") - if isinstance(x, (pd.Series, pd.DataFrame)): + if isinstance(x, pd.Series | pd.DataFrame): x = x.convert_dtypes() if reduce_int: diff --git a/src/amltk/data/measure.py b/src/amltk/data/measure.py index 4dd10435..9f9efd9f 100644 --- a/src/amltk/data/measure.py +++ b/src/amltk/data/measure.py @@ -2,7 +2,8 @@ from __future__ import annotations import sys -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any import numpy as np import pandas as pd diff --git a/src/amltk/distances.py b/src/amltk/distances.py index 9c4621aa..9bcf420e 100644 --- a/src/amltk/distances.py +++ b/src/amltk/distances.py @@ -5,9 +5,9 @@ """ from __future__ import annotations +from collections.abc import Callable from functools import partial -from typing import Any, Callable, Literal -from typing_extensions import TypeAlias +from typing import Any, Literal, TypeAlias import numpy as np import numpy.typing as npt diff --git a/src/amltk/ensembling/weighted_ensemble_caruana.py b/src/amltk/ensembling/weighted_ensemble_caruana.py index 256b97dd..bf643c3b 100644 --- a/src/amltk/ensembling/weighted_ensemble_caruana.py +++ b/src/amltk/ensembling/weighted_ensemble_caruana.py @@ -15,7 +15,8 @@ import logging from collections import Counter -from typing import TYPE_CHECKING, Callable, Hashable, Iterable, Mapping, TypeVar +from collections.abc import Callable, Hashable, Iterable, Mapping +from typing import TYPE_CHECKING, TypeVar import numpy as np diff --git a/src/amltk/exceptions.py b/src/amltk/exceptions.py index b4b93b80..bbeb215c 100644 --- a/src/amltk/exceptions.py +++ b/src/amltk/exceptions.py @@ -4,8 +4,12 @@ from __future__ import annotations import traceback -from typing import Any, Callable, Iterable, Iterator, TypeVar -from typing_extensions import ParamSpec +from collections.abc import Callable, Iterable, Iterator +from typing import TYPE_CHECKING, Any, TypeVar +from typing_extensions import ParamSpec, override + +if TYPE_CHECKING: + from amltk.pipeline.node import Node R = TypeVar("R") E = TypeVar("E") @@ -70,3 +74,35 @@ class SchedulerNotRunningError(RuntimeError): class EventNotKnownError(ValueError): """The event is not a known one.""" + + +class NoChoiceMadeError(ValueError): + """No choice was made.""" + + +class NodeNotFoundError(ValueError): + """The node was not found.""" + + +class RequestNotMetError(ValueError): + """Raised when a request is not met.""" + + +class DuplicateNamesError(ValueError): + """Raised when duplicate names are found.""" + + def __init__(self, node: Node) -> None: + """Initialize the exception. + + Args: + node: The node that has children with duplicate names. + """ + super().__init__(node) + self.node = node + + @override + def __str__(self) -> str: + return ( + f"Duplicate names found in {self.node.name} and can't be handled." + f"\nnodes: {[n.name for n in self.node.nodes]}." 
+ ) diff --git a/src/amltk/fluid.py b/src/amltk/fluid.py deleted file mode 100644 index 049a86c8..00000000 --- a/src/amltk/fluid.py +++ /dev/null @@ -1,177 +0,0 @@ -"""A module for some useful tools for creating more fluid interfaces.""" -from __future__ import annotations - -from dataclasses import dataclass -from typing import Callable, Generic, Protocol, TypeVar -from typing_extensions import ParamSpec - -from amltk.types import Comparable - -V = TypeVar("V", bound=Comparable) -P = ParamSpec("P") -R_co = TypeVar("R_co", covariant=True) - - -class Partial(Protocol[P, R_co]): - """A protocol for partial functions.""" - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co: - """Call the function.""" - ... - - -@dataclass -class ChainPredicate(Generic[P]): - """A predicate that can be chained with other predicates. - - Can be chained with other callables using `&` and `|` operators. - - ```python - from amltk.fluid import ChainPredicate - - def is_even(x: int) -> bool: - return x % 2 == 0 - - def is_odd(x: int) -> bool: - return x % 2 == 1 - - and_combined = ChainPredicate() & is_even & is_odd - assert and_combined_pred(1) is False - - or_combined = ChainPredicate() & is_even | is_odd - assert or_combined_pred(1) is True - ``` - - Attributes: - pred: The predicate to be evaluated. - Defaults to `None` which defaults to returning `True` - when called. - """ - - pred: Callable[P, bool] | None = None - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> bool: - """Evaluate the predicate chain.""" - if self.pred is None: - return True - - return self.pred(*args, **kwargs) - - def __and__(self, other: Callable[P, bool] | None) -> ChainPredicate[P]: - if other is None: - return self - - call = other - - def _pred(*args: P.args, **kwargs: P.kwargs) -> bool: - return self(*args, **kwargs) and call(*args, **kwargs) - - return ChainPredicate(_pred) - - def __or__(self, other: Callable[P, bool] | None) -> ChainPredicate[P]: - if other is None: - return self - - call = other - - def _pred(*args: P.args, **kwargs: P.kwargs) -> bool: - return self(*args, **kwargs) or call(*args, **kwargs) - - return ChainPredicate(_pred) - - @classmethod - def all(cls, *preds: Callable[P, bool]) -> ChainPredicate[P]: - """Create an all predicate from multiple predicates. - - Args: - preds: The predicates to combine. - - Returns: - The combined predicate. - """ - - def _pred(*args: P.args, **kwargs: P.kwargs) -> bool: - return all(pred(*args, **kwargs) for pred in preds) - - return ChainPredicate[P](_pred) - - @classmethod - def any(cls, *preds: Callable[P, bool]) -> ChainPredicate[P]: - """Create an any predicate from multiple predicates. - - Args: - preds: The predicates to combine. - - Returns: - The combined predicate. - """ - - def _pred(*args: P.args, **kwargs: P.kwargs) -> bool: - return any(pred(*args, **kwargs) for pred in preds) - - return ChainPredicate[P](_pred) - - -@dataclass -class DelayedOp(Generic[V, P]): - """A delayed binary operation that can be chained with other operations. - - Sometimes we want to be able to save a predicate for later evaluation but - use familiar operators to build it up. This class allows us to do that. 
- - ```python - from amltk.fluid import DelayedOp - from dataclasses import dataclass - - @dataclass - class DynamicThing: - _x: int - - def value(self) -> int: - return self._x * 2 - - dynamo = DynamicThing(2) - - delayed = DelayedOp(dynamo.value) < 5 - assert delayed() is True - - dynamo._x = 3 - assert delayed() is False - ``` - - - Attributes: - left: The left-hand side of the operation to be evaluated later. - """ - - left: Callable[P, V] - - def __lt__(self, right: V) -> ChainPredicate[P]: - def op(*args: P.args, **kwargs: P.kwargs) -> bool: - return self.left(*args, **kwargs) < right - - return ChainPredicate(op) - - def __le__(self, right: V) -> ChainPredicate[P]: - def op(*args: P.args, **kwargs: P.kwargs) -> bool: - return self.left(*args, **kwargs) <= right # type: ignore - - return ChainPredicate(op) - - def __gt__(self, right: V) -> ChainPredicate[P]: - def op(*args: P.args, **kwargs: P.kwargs) -> bool: - return self.left(*args, **kwargs) > right - - return ChainPredicate(op) - - def __ge__(self, right: V) -> ChainPredicate[P]: - def op(*args: P.args, **kwargs: P.kwargs) -> bool: - return self.left(*args, **kwargs) >= right # type: ignore - - return ChainPredicate(op) - - def __eq__(self, right: V) -> ChainPredicate[P]: # type: ignore - def op(*args: P.args, **kwargs: P.kwargs) -> bool: - return self.left(*args, **kwargs) == right - - return ChainPredicate(op) diff --git a/src/amltk/links.py b/src/amltk/links.py deleted file mode 100644 index d292e8a8..00000000 --- a/src/amltk/links.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Links to documentation pages.""" -from __future__ import annotations - -from functools import lru_cache -from typing import Any - -SKLEARN_LINK = "https://www.scikit-learn.org/stable/modules/generated/{0}.html" - - -def sklearn_link_generator(name: str) -> str: - """Generate a link for a sklearn function.""" - reduced_name = ".".join(s for s in name.split(".") if not s.startswith("_")) - return SKLEARN_LINK.format(reduced_name) - - -@lru_cache -def _try_get_link(fully_scoped_name: str) -> str | None: - """Try to get a link for a string. - - Expects fully qualified import names. - """ - from amltk.options import _amltk_options - - links = _amltk_options.get("links", {}) - - for k, v in links.items(): - if fully_scoped_name.startswith(k): - if isinstance(v, str): - return v - if callable(v): - return v(fully_scoped_name) - - return None - - -def link(obj: Any) -> str | None: - """Try to get a link for an object.""" - from amltk.functional import fullname - - return _try_get_link(fullname(obj)) diff --git a/src/amltk/metalearning/dataset_distances.py b/src/amltk/metalearning/dataset_distances.py index 1124c9ab..a7939ffc 100644 --- a/src/amltk/metalearning/dataset_distances.py +++ b/src/amltk/metalearning/dataset_distances.py @@ -1,22 +1,107 @@ -"""Calculate distances between datasets based on metafeatures. +"""One common way to define how similar two datasets are is to compute some "similarity" +between them. This notion of "similarity" requires computing some features of a dataset +(**metafeatures**) first, such that we can numerically compute some distance function. -Please see the reference section on -[Metalearning](site:reference/metalearning.md) for more! -""" +Let's see how we can quickly compute the distance between some datasets with +[`dataset_distance()`][amltk.metalearning.dataset_distance]! 
+ +```python exec="true" source="material-block" result="python" title="Dataset Distances P.1" session='dd' +import pandas as pd +import openml + +from amltk.metalearning import compute_metafeatures + +def get_dataset(dataset_id: int) -> tuple[pd.DataFrame, pd.Series]: + dataset = openml.datasets.get_dataset( + dataset_id, + download_data=True, + download_features_meta_data=False, + download_qualities=False, + ) + X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, + ) + return X, y + +d31 = get_dataset(31) +d3 = get_dataset(3) +d4 = get_dataset(4) + +metafeatures_dict = { + "dataset_31": compute_metafeatures(*d31), + "dataset_3": compute_metafeatures(*d3), + "dataset_4": compute_metafeatures(*d4), +} + +metafeatures = pd.DataFrame(metafeatures_dict) +print(metafeatures) +``` + +Now we want to know which one of `#!python "dataset_3"` or `#!python "dataset_4"` is +more _similar_ to `#!python "dataset_31"`. + +```python exec="true" source="material-block" result="python" title="Dataset Distances P.2" session='dd' +from amltk.metalearning import dataset_distance + +target = metafeatures_dict.pop("dataset_31") +others = metafeatures_dict + +distances = dataset_distance(target, others, distance_metric="l2") +print(distances) +``` + +Seems like `#!python "dataset_3"` is some notion of closer to `#!python "dataset_31"` +than `#!python "dataset_4"`. However the scale of the metafeatures are not exactly all close. +For example, many lie between `#!python (0, 1)` but some like `instance_count` can completely +dominate the show. + +Lets repeat the computation but specify that we should apply a `#!python "minmax"` scaling +across the rows. + +```python exec="true" source="material-block" result="python" title="Dataset Distances P.3" session='dd' hl_lines="5" +distances = dataset_distance( + target, + others, + distance_metric="l2", + scaler="minmax" +) +print(distances) +``` + +Now `#!python "dataset_3"` is considered more similar but the difference between the two is a lot less +dramatic. In general, applying some scaling to values of different scales is required for metalearning. + +You can also use an [sklearn.preprocessing.MinMaxScaler][] or anything other scaler from scikit-learn +for that matter. + +```python exec="true" source="material-block" result="python" title="Dataset Distances P.3" session='dd' hl_lines="7" +from sklearn.preprocessing import MinMaxScaler + +distances = dataset_distance( + target, + others, + distance_metric="l2", + scaler=MinMaxScaler() +) +print(distances) +``` +""" # noqa: E501 from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Callable, Literal, Mapping, TypeVar +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Literal, TypeVar import pandas as pd +from amltk._functional import funcname from amltk.distances import ( DistanceMetric, NamedDistance, NearestNeighborsDistance, distance_metrics, ) -from amltk.functional import funcname from amltk.types import safe_isinstance if TYPE_CHECKING: diff --git a/src/amltk/metalearning/metafeatures.py b/src/amltk/metalearning/metafeatures.py index 227db6ed..6bea8a6c 100644 --- a/src/amltk/metalearning/metafeatures.py +++ b/src/amltk/metalearning/metafeatures.py @@ -1,26 +1,148 @@ -"""Metafeatures for use in metalearning. +'''A [`MetaFeature`][amltk.metalearning.MetaFeature] is some +statistic about a dataset/task, that can be used to make datasets or +tasks more comparable, thus enabling meta-learning methods. 
-Please see the reference section on -[Metalearning](site:reference/metalearning.md) for more! -""" +Calculating meta-features of a dataset is quite straight foward. + +```python exec="true" source="material-block" result="python" title="Metafeatures" hl_lines="10" +import openml +from amltk.metalearning import compute_metafeatures + +dataset = openml.datasets.get_dataset( + 31, # credit-g + download_data=True, + download_features_meta_data=False, + download_qualities=False, +) +X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, +) + +mfs = compute_metafeatures(X, y) + +print(mfs) +``` + +By default [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures] will +calculate all the [`MetaFeature`][amltk.metalearning.MetaFeature] implemented, +iterating through their subclasses to do so. You can pass an explicit list +as well to `compute_metafeatures(X, y, features=[...])`. + +To implement your own is also quite straight forward: + +```python exec="true" source="material-block" result="python" title="Create Metafeature" hl_lines="10 11 12 13 14 15 16 17 18 19" +from amltk.metalearning import MetaFeature, compute_metafeatures +import openml + +dataset = openml.datasets.get_dataset( + 31, # credit-g + download_data=True, + download_features_meta_data=False, + download_qualities=False, +) +X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, +) + +class TotalValues(MetaFeature): + + @classmethod + def compute( + cls, + x: pd.DataFrame, + y: pd.Series | pd.DataFrame, + dependancy_values: dict, + ) -> int: + return int(x.shape[0] * x.shape[1]) + +mfs = compute_metafeatures(X, y, features=[TotalValues]) +print(mfs) +``` + +As many metafeatures rely on pre-computed dataset statistics, and they do not +need to be calculated more than once, you can specify the dependancies of +a meta feature. When a metafeature would return something other than a single +value, i.e. a `dict` or a `pd.DataFrame`, we instead call those a +[`DatasetStatistic`][amltk.metalearning.DatasetStatistic]. These will +**not** be included in the result of [`compute_metafeatures()`][amltk.metalearning.compute_metafeatures]. +These `DatasetStatistic`s will only be calculated once on a call to `compute_metafeatures()` so +they can be re-used across all `MetaFeature`s that require that dependancy. 
+ +```python exec="true" source="material-block" result="python" title="Metafeature Dependancy" hl_lines="10 11 12 13 14 15 16 17 18 19 20 23 26 35" +from amltk.metalearning import MetaFeature, DatasetStatistic, compute_metafeatures +import openml + +dataset = openml.datasets.get_dataset( + 31, # credit-g + download_data=True, + download_features_meta_data=False, + download_qualities=False, +) +X, y, _, _ = dataset.get_data( + dataset_format="dataframe", + target=dataset.default_target_attribute, +) + +class NAValues(DatasetStatistic): + """A mask of all NA values in a dataset""" + + @classmethod + def compute( + cls, + x: pd.DataFrame, + y: pd.Series | pd.DataFrame, + dependancy_values: dict, + ) -> pd.DataFrame: + return x.isna() + + +class PercentageNA(MetaFeature): + """The percentage of values missing""" + + dependencies = (NAValues,) + + @classmethod + def compute( + cls, + x: pd.DataFrame, + y: pd.Series | pd.DataFrame, + dependancy_values: dict, + ) -> int: + na_values = dependancy_values[NAValues] + n_na = na_values.sum().sum() + n_values = int(x.shape[0] * x.shape[1]) + return float(n_na / n_values) + +mfs = compute_metafeatures(X, y, features=[PercentageNA]) +print(mfs) +``` + +To view the description of a particular `MetaFeature`, you can call +[`.description()`][amltk.metalearning.DatasetStatistic.description] +on it. Otherwise you can access all of them in the following way: + +```python exec="true" source="tabbed-left" result="python" title="Metafeature Descriptions" hl_lines="4" +from pprint import pprint +from amltk.metalearning import metafeature_descriptions + +descriptions = metafeature_descriptions() +for name, description in descriptions.items(): + print("---") + print(name) + print("---") + print(" * " + description) +``` +''' # noqa: E501 from __future__ import annotations import logging import re from abc import ABC, abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - Generic, - Iterable, - Iterator, - Mapping, - Tuple, - TypeVar, -) -from typing_extensions import TypeAlias, override +from collections.abc import Iterable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar +from typing_extensions import override import numpy as np import pandas as pd @@ -30,7 +152,7 @@ CAMEL_CASE_PATTERN = re.compile(r"(? Balanced. 1 most imbalanced. No categories implies perfectly balanced. @@ -726,7 +848,7 @@ def compute( return float(np.std(list(imbalances.values()))) -class SkewnessPerNumericalColumn(DatasetStatistic[Dict[str, float]]): +class SkewnessPerNumericalColumn(DatasetStatistic[dict[str, float]]): """Skewness of each numerical feature.""" dependencies = (NumericalColumns,) @@ -827,7 +949,7 @@ def compute( return float(np.max(list(skews.values()))) -class KurtosisPerNumericalColumn(DatasetStatistic[Dict[str, float]]): +class KurtosisPerNumericalColumn(DatasetStatistic[dict[str, float]]): """Kurtosis of each numerical feature.""" dependencies = (NumericalColumns,) diff --git a/src/amltk/metalearning/portfolio.py b/src/amltk/metalearning/portfolio.py index 91abc1f1..7a282a87 100644 --- a/src/amltk/metalearning/portfolio.py +++ b/src/amltk/metalearning/portfolio.py @@ -1,7 +1,120 @@ -"""Portfolio selection for meta-learning.""" +"""A portfolio in meta-learning is to a set (ordered or not) of configurations +that maximize some notion of coverage across datasets or tasks. +The intuition here is that this also means that any new dataset is also covered! 
+ +Suppose we have the given performances of some configurations across some datasets. +```python exec="true" source="material-block" result="python" title="Initial Portfolio" +import pandas as pd + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) +print(portfolio) +``` + +If we could only choose `#!python k=3` of these configurations on some new given dataset, which ones would +you choose and in what priority? +Here is where we can apply [`portfolio_selection()`][amltk.metalearning.portfolio_selection]! + +The idea is that we pick a subset of these algorithms that maximise some value of utility for +the portfolio. We do this by adding a single configuration from the entire set, 1-by-1 until +we reach `k`, beginning with the empty portfolio. + +Let's see this in action! + +```python exec="true" source="material-block" result="python" title="Portfolio Selection" hl_lines="12 13 14 15 16" +import pandas as pd +from amltk.metalearning import portfolio_selection + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) + +selected_portfolio, trajectory = portfolio_selection( + portfolio, + k=3, + scaler="minmax" +) + +print(selected_portfolio) +print() +print(trajectory) +``` + +The trajectory tells us which configuration was added at each time stamp along with the utility +of the portfolio with that configuration added. However we havn't specified how _exactly_ we defined the +utility of a given portfolio. We could define our own function to do so: + +```python exec="true" source="material-block" result="python" title="Portfolio Selection Custom" hl_lines="12 13 14 20" +import pandas as pd +from amltk.metalearning import portfolio_selection + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) + +def my_function(p: pd.DataFrame) -> float: + # Take the maximum score for each dataset and then take the mean across them. + return p.max(axis=1).mean() + +selected_portfolio, trajectory = portfolio_selection( + portfolio, + k=3, + scaler="minmax", + portfolio_value=my_function, +) + +print(selected_portfolio) +print() +print(trajectory) +``` + +This notion of reducing across all configurations for a dataset and then aggregating these is common +enough that we can also directly just define these operations and we will perform the rest. 
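Before looking at the full call, it can help to see what "reduce each row, then aggregate" actually computes for the table above. The snippet below is plain pandas, nothing amltk-specific, and the `utility` helper is purely an illustration of the default `max`-then-`mean` notion of portfolio value (the real `portfolio_selection` call additionally applies the requested scaling first):

```python
import pandas as pd

performances = {
    "c1": [90, 60, 20, 10],
    "c2": [20, 10, 90, 20],
    "c3": [10, 20, 40, 90],
    "c4": [90, 10, 10, 10],
}
portfolio = pd.DataFrame(
    performances,
    index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"],
)

def utility(configs: list[str]) -> float:
    # Reduce each dataset (row) with `max` over the chosen configs,
    # then aggregate across datasets with `mean`.
    return float(portfolio[configs].max(axis=1).mean())

print(utility(["c1"]))        # 45.0 -- c1 alone
print(utility(["c1", "c3"]))  # 70.0 -- c3 covers the datasets c1 is weak on
```

The same reduction and aggregation, driven greedily by `portfolio_selection`, is shown next.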
+ +```python exec="true" source="material-block" result="python" title="Portfolio Selection With Reduction" hl_lines="17 18" +import pandas as pd +import numpy as np +from amltk.metalearning import portfolio_selection + +performances = { + "c1": [90, 60, 20, 10], + "c2": [20, 10, 90, 20], + "c3": [10, 20, 40, 90], + "c4": [90, 10, 10, 10], +} +portfolio = pd.DataFrame(performances, index=["dataset_1", "dataset_2", "dataset_3", "dataset_4"]) + +selected_portfolio, trajectory = portfolio_selection( + portfolio, + k=3, + scaler="minmax", + row_reducer=np.max, # This is actually the default + aggregator=np.mean, # This is actually the default +) + +print(selected_portfolio) +print() +print(trajectory) +``` +""" # noqa: E501 + from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Hashable, Literal, TypeVar +from collections.abc import Callable, Hashable +from typing import TYPE_CHECKING, Literal, TypeVar import numpy as np import pandas as pd @@ -161,9 +274,7 @@ def portfolio_selection( # Possible get multiple best choices, we choose one at random if so best_keys = [k for k, v in values_possible.items() if v == best_possible] best_key = ( - best_keys[0] - if len(best_keys) == 1 - else rng.choice(best_keys) # type: ignore + best_keys[0] if len(best_keys) == 1 else rng.choice(best_keys) # type: ignore ) # We found something better, add it in diff --git a/src/amltk/neps/__init__.py b/src/amltk/neps/__init__.py deleted file mode 100644 index d53efe55..00000000 --- a/src/amltk/neps/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from amltk.neps.optimizer import NEPSOptimizer, NEPSTrialInfo - -__all__ = ["NEPSOptimizer", "NEPSTrialInfo"] diff --git a/src/amltk/optimization/__init__.py b/src/amltk/optimization/__init__.py index a1a34311..ae3598d0 100644 --- a/src/amltk/optimization/__init__.py +++ b/src/amltk/optimization/__init__.py @@ -1,16 +1,11 @@ from amltk.optimization.history import History, IncumbentTrace, Trace from amltk.optimization.optimizer import Optimizer -from amltk.optimization.random_search import RandomSearch, RSTrialInfo from amltk.optimization.trial import Trial -from amltk.pipeline.api import searchable __all__ = [ "Optimizer", - "RandomSearch", "Trial", - "RSTrialInfo", "History", "Trace", "IncumbentTrace", - "searchable", ] diff --git a/src/amltk/optimization/history.py b/src/amltk/optimization/history.py index e9dd5a9b..9140d2e9 100644 --- a/src/amltk/optimization/history.py +++ b/src/amltk/optimization/history.py @@ -1,19 +1,76 @@ -"""Classes for keeping track of trials.""" +"""The [`History`][amltk.optimization.History] is +used to keep a structured record of what occured with +[`Trial`][amltk.optimization.Trial]s and their associated +[`Report`][amltk.optimization.Trial.Report]s. + +??? tip "Usage" + + ```python exec="true" source="material-block" html="true" hl_lines="19 23-24" + from amltk.optimization import Trial, History + from amltk.store import PathBucket + + def target_function(trial: Trial) -> Trial.Report: + x = trial.config["x"] + y = trial.config["y"] + trial.store({"config.json": trial.config}) + + with trial.begin(): + cost = x**2 - y + + if trial.exception: + return trial.fail() + + return trial.success(cost=cost) + + # ... 
usually obtained from an optimizer + bucket = PathBucket("all-trial-results") + history = History() + + for x, y in zip([1, 2, 3], [4, 5, 6]): + trial = Trial(name="some-unique-name", config={"x": x, "y": y}, bucket=bucket) + report = target_function(trial) + history.add(report) + + print(history.df()) + bucket.rmdir() # markdon-exec: hide + ``` + +You'll often need to perform some operations on a +[`History`][amltk.optimization.History] so we provide some utility functions here: + +* [`filter(by=...)`][amltk.optimization.History.filter] - Filters the history by some + predicate, e.g. `#!python history.filter(lambda report: report.status == "success")` +* [`groupby(key=...)`][amltk.optimization.History.groupby] - Groups the history by some + key, e.g. `#!python history.groupby(lambda report: report.config["x"] < 5)` +* [`sortby(key=...)`][amltk.optimization.History.sortby] - Sorts the history by some + key, e.g. `#!python history.sortby(lambda report: report.time.end)` + + This will return a [`Trace`][amltk.optimization.Trace] which is the same + as a `History` in many respects, other than the fact it now has a sorted order. + +There is also some serialization capabilities built in, to allow you to store +your results and load them back in later: + +* [`df(...)`][amltk.optimization.History.df] - Output a `pd.DataFrame` of all + the information available. +* [`from_df(...)`][amltk.optimization.History.from_df] - Create a `History` from + a `pd.DataFrame`. + +You can also retrieve individual reports from the history by using their +name, e.g. `#!python history["some-unique-name"]` or iterate through +the history with `#!python for report in history: ...`. +""" from __future__ import annotations import operator from collections import defaultdict +from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence from dataclasses import dataclass, field from pathlib import Path from typing import ( IO, TYPE_CHECKING, - Callable, - Hashable, - Iterable, - Iterator, Literal, - Sequence, TypeVar, overload, ) @@ -21,9 +78,9 @@ import pandas as pd -from amltk.functional import compare_accumulate +from amltk._functional import compare_accumulate +from amltk._richutil import RichRenderable from amltk.optimization.trial import Trial -from amltk.richutil import RichRenderable from amltk.types import Comparable if TYPE_CHECKING: @@ -341,10 +398,10 @@ def from_csv(cls, path: str | Path | IO[str] | pd.DataFrame) -> History: """ _df = ( pd.read_csv( - path, + path, # type: ignore float_precision="round_trip", # type: ignore ) - if isinstance(path, (IO, str, Path)) + if isinstance(path, IO | str | Path) else path ) @@ -366,7 +423,7 @@ def from_df(cls, df: pd.DataFrame) -> History: @override def __rich__(self) -> RenderableType: - from amltk.richutil import df_to_table + from amltk._richutil import df_to_table return df_to_table( self.df(configs=False, profiles=False, summary=False), @@ -578,8 +635,6 @@ def incumbents( if op not in {"min", "max"}: raise ValueError(f"Unknown op: {op}") op = operator.lt if op == "min" else operator.gt # type: ignore - else: - op = op # type: ignore if isinstance(key, str): trace = self.filter(lambda report: key in report.summary) diff --git a/src/amltk/optimization/optimizer.py b/src/amltk/optimization/optimizer.py index 16c23320..b7080157 100644 --- a/src/amltk/optimization/optimizer.py +++ b/src/amltk/optimization/optimizer.py @@ -1,13 +1,35 @@ -"""Protocols for the optimization module.""" +"""The base [`Optimizer`][amltk.optimization.optimizer.Optimizer] class, 
+defines the API we require optimizers to implement. + +* [`ask()`][amltk.optimization.optimizer.Optimizer.ask] - Ask the optimizer for a + new [`Trial`][amltk.optimization.trial.Trial] to evaluate. +* [`tell()`][amltk.optimization.optimizer.Optimizer.tell] - Tell the optimizer + the result of the sampled config. This comes in the form of a + [`Trial.Report`][amltk.optimization.trial.Trial.Report]. + +Additionally, to aid users from switching between optimizers, the +[`preferred_parser()`][amltk.optimization.optimizer.Optimizer.preferred_parser] +method should return either a `parser` function or a string that can be used +with [`node.search_space(parser=..._)`][amltk.pipeline.Node.search_space] to +extract the search space for the optimizer. + +Please see the [optimizer reference](site:reference/optimization/optimizers.md) +for more. +""" from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar if TYPE_CHECKING: from amltk.optimization.trial import Trial + from amltk.pipeline import Node + from amltk.store import Bucket I = TypeVar("I") # noqa: E741 +P = ParamSpec("P") +ParserOutput = TypeVar("ParserOutput") class Optimizer(Generic[I]): @@ -17,6 +39,19 @@ class Optimizer(Generic[I]): `tell` to inform the optimizer of the report from that trial. """ + bucket: Bucket | None + """The bucket to give to trials generated by this optimizer.""" + + def __init__(self, bucket: Bucket | None = None) -> None: + """Initialize the optimizer. + + Args: + bucket: The bucket to store results of individual trials from this + optimizer. + """ + super().__init__() + self.bucket = bucket + @abstractmethod def tell(self, report: Trial.Report[I]) -> None: """Tell the optimizer the report for an asked trial. @@ -33,3 +68,16 @@ def ask(self) -> Trial[I]: A config to sample. """ ... + + @classmethod + def preferred_parser( + cls, + ) -> str | Callable[Concatenate[Node, ...], Any] | Callable[[Node], Any] | None: + """The preferred parser for this optimizer. + + !!! note + + Subclasses should override this as required. + + """ + return None diff --git a/tests/configspace/__init__.py b/src/amltk/optimization/optimizers/__init__.py similarity index 100% rename from tests/configspace/__init__.py rename to src/amltk/optimization/optimizers/__init__.py diff --git a/src/amltk/neps/optimizer.py b/src/amltk/optimization/optimizers/neps.py similarity index 68% rename from src/amltk/neps/optimizer.py rename to src/amltk/optimization/optimizers/neps.py index ce4c22b8..a7e90790 100644 --- a/src/amltk/neps/optimizer.py +++ b/src/amltk/optimization/optimizers/neps.py @@ -1,15 +1,124 @@ -"""A thin wrapper around NEPS to make it easier to use with AutoMLToolkit. +"""The [`NEPSOptimizer`][amltk.optimization.optimizers.neps.NEPSOptimizer], +is a wrapper around the [`NePs`](https://github.com/automl/neps) optimizer. -TODO: More description and explanation with examples. -""" +!!! tip "Requirements" + + This requires `smac` which can be installed with: + + ```bash + pip install amltk[neps] + + # Or directly + pip install neural-pipeline-search + ``` + +This uses `ConfigSpace` as its [`search_space()`][amltk.pipeline.Node.search_space] to +optimize. Please see +[amltk.pipeline.parsers.configspace][amltk.pipeline.parsers.configspace.parser] +and the [search space reference](site:reference/pipeline/spaces.md) for more. 
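Stepping back for a moment to the base `Optimizer` protocol documented in `optimizer.py` above: the whole contract is an ask/tell loop. The sketch below is purely illustrative — `ToyRandomOptimizer` is not part of amltk, it just samples a value and ignores feedback — but it shows the minimal shape of `ask()` returning a `Trial` and `tell()` receiving a `Trial.Report`, using only calls that appear elsewhere in these docstrings:

```python
import random

from amltk.optimization import Optimizer, Trial


class ToyRandomOptimizer(Optimizer):
    """Illustrative only: samples `x` uniformly and ignores all feedback."""

    def ask(self) -> Trial:
        x = random.uniform(-5.0, 5.0)
        name = f"toy-trial-{random.randrange(10**6)}"
        return Trial(name=name, config={"x": x}, bucket=self.bucket)

    def tell(self, report: Trial.Report) -> None:
        # A real optimizer would update its internal model with the report here.
        pass


optimizer = ToyRandomOptimizer()
for _ in range(5):
    trial = optimizer.ask()
    with trial.begin():
        # Which key to report (`loss=`, `cost=`, ...) depends on the concrete
        # optimizer; NEPS expects `loss=` while SMAC expects `cost=`.
        cost = trial.config["x"] ** 2
    report = trial.fail() if trial.exception else trial.success(cost=cost)
    optimizer.tell(report)
    print(report)
```

The NEPS-specific reporting conventions continue below.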
+ +Users should report results using +[`trial.success(loss=...)`][amltk.optimization.Trial.success] +where `loss=` is a scaler value to **minimize**. Optionally, +you can also return a `cost=` which is used for more budget aware algorithms. +Again, please see NeP's documentation for more. + +!!! warning "Conditionals in ConfigSpace" + + NePs does not support conditionals in its search space. This is account + for when using the + [`preferred_parser()`][amltk.optimization.optimizers.neps.NEPSOptimizer.preferred_parser]. + during search space creation. In this case, it will simply remove all conditionals + from the search space, which may not be ideal for the given problem at hand. + +Visit their documentation for what you can pass to +[`NEPSOptimizer.create()`][amltk.optimization.optimizers.neps.NEPSOptimizer.create]. + +The below example shows how you can use neps to optimize an sklearn pipeline. + +```python exec="True" source="material-block" result="python" +from __future__ import annotations + +import logging + +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split + +from amltk.optimization.optimizers.neps import NEPSOptimizer +from amltk.scheduling import Scheduler +from amltk.optimization import History, Trial +from amltk.pipeline import Component + +logging.basicConfig(level=logging.INFO) + + +def target_function(trial: Trial, pipeline: Pipeline) -> Trial.Report: + X, y = load_iris(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split(X, y) + clf = pipeline.configure(trial.config).build("sklearn") + + with trial.begin(): + clf.fit(X_train, y_train) + y_pred = clf.predict(X_test) + accuracy = accuracy_score(y_test, y_pred) + loss = 1 - accuracy + return trial.success(loss=loss, accuracy=accuracy) + + return trial.fail() +from amltk._doc import make_picklable; make_picklable(target_function) # markdown-exec: hide + + +pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) +space = pipeline.search_space(parser=NEPSOptimizer.preferred_parser()) +optimizer = NEPSOptimizer.create(space=space) + +N_WORKERS = 2 +scheduler = Scheduler.with_processes(N_WORKERS) +task = scheduler.task(target_function) + +history = History() + +@scheduler.on_start(repeat=N_WORKERS) +def on_start(): + trial = optimizer.ask() + task.submit(trial, pipeline) + +@task.on_result +def tell_and_launch_trial(_, report: Trial.Report): + if scheduler.running(): + optimizer.tell(report) + trial = optimizer.ask() + task.submit(trial, pipeline) + +@task.on_result +def add_to_history(_, report: Trial.Report): + history.add(report) + +scheduler.run(timeout=3, wait=False) + +print(history.df()) +``` + +!!! todo "Deep Learning" + + Write an example demonstrating NEPS with continuations + +!!! 
todo "Graph Search Spaces" + + Write an example demonstrating NEPS with its graph search spaces +""" # noqa: E501 from __future__ import annotations import logging import shutil +from collections.abc import Mapping from copy import deepcopy from dataclasses import dataclass +from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Mapping +from typing import TYPE_CHECKING, Any, Literal, Protocol from typing_extensions import override import metahyper.api @@ -20,12 +129,30 @@ from neps.search_spaces.search_space import SearchSpace, pipeline_space_from_configspace from amltk.optimization import Optimizer, Trial +from amltk.pipeline.parsers.configspace import parser as configspace_parser if TYPE_CHECKING: from typing_extensions import Self from neps.api import BaseOptimizer + from amltk.pipeline import Node + from amltk.store import Bucket + + class NEPSPreferredParser(Protocol): + """The preferred parser call signature for NEPSOptimizer.""" + + def __call__( + self, + node: Node, + *, + seed: int | None = None, + flat: bool = False, + delim: str = ":", + ) -> ConfigurationSpace: + """See [`configspace_parser`][amltk.pipeline.parsers.configspace.parser].""" + ... + logger = logging.getLogger(__name__) @@ -113,6 +240,7 @@ def __init__( space: SearchSpace, optimizer: BaseOptimizer, working_dir: Path, + bucket: Bucket | None = None, ignore_errors: bool = True, loss_value_on_error: float | None = None, cost_value_on_error: float | None = None, @@ -123,11 +251,12 @@ def __init__( space: The space to use. optimizer: The optimizer to use. working_dir: The directory to use for the optimization. + bucket: The bucket to give to trials generated from this optimizer. ignore_errors: Whether the optimizers should ignore errors from trials. loss_value_on_error: The value to use for the loss if the trial fails. cost_value_on_error: The value to use for the cost if the trial fails. """ - super().__init__() + super().__init__(bucket=bucket) self.space = space self.optimizer = optimizer self.working_dir = working_dir @@ -143,7 +272,7 @@ def __init__( self.base_result_directory.mkdir(parents=True, exist_ok=True) @classmethod - def create( + def create( # noqa: PLR0913 cls, *, space: ( @@ -151,9 +280,10 @@ def create( | ConfigurationSpace | Mapping[str, ConfigurationSpace | Parameter] ), + bucket: Bucket | None = None, searcher: str | BaseOptimizer = "default", working_dir: str | Path = "neps", - overwrite: bool = False, + overwrite: bool = True, loss_value_on_error: float | None = None, cost_value_on_error: float | None = None, max_cost_total: float | None = None, @@ -164,6 +294,7 @@ def create( Args: space: The space to use. + bucket: The bucket to give to trials generated by this optimizer. searcher: The searcher to use. working_dir: The directory to use for the optimization. overwrite: Whether to overwrite the working directory if it exists. 
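Complementing the `create()` signature shown in the hunk above, here is a sketch of constructing a `NEPSOptimizer` with the error-handling and budget options spelled out. Only parameters visible in that signature are used; the concrete values are illustrative, not recommendations:

```python
from sklearn.ensemble import RandomForestClassifier

from amltk.optimization.optimizers.neps import NEPSOptimizer
from amltk.pipeline import Component

pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)})
space = pipeline.search_space(parser=NEPSOptimizer.preferred_parser())

optimizer = NEPSOptimizer.create(
    space=space,
    searcher="default",           # or a neps `BaseOptimizer` instance
    working_dir="neps-output",    # where neps keeps its run state
    overwrite=True,               # start fresh instead of resuming
    loss_value_on_error=1.0,      # score crashed trials with a worst-case loss
    cost_value_on_error=0.0,
    max_cost_total=60.0,          # overall cost budget, if the searcher uses one
)
```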
@@ -202,6 +333,7 @@ def create( return cls( space=space, + bucket=bucket, optimizer=searcher, working_dir=working_dir, loss_value_on_error=loss_value_on_error, @@ -242,7 +374,13 @@ def ask(self) -> Trial[NEPSTrialInfo]: pipeline_directory=pipeline_directory, previous_pipeline_directory=previous_pipeline_directory, ) - trial = Trial(name=info.name, config=info.config, info=info, seed=None) + trial = Trial( + name=info.name, + config=info.config, + info=info, + seed=None, + bucket=self.bucket, + ) logger.debug(f"Asked for trial {trial.name}") return trial @@ -312,3 +450,11 @@ def tell(self, report: Trial.Report[NEPSTrialInfo]) -> None: config_metadata = self.serializer.load(info.pipeline_directory / "metadata") config_metadata.update(metadata) self.serializer.dump(config_metadata, info.pipeline_directory / "metadata") + + @override + @classmethod + def preferred_parser(cls) -> NEPSPreferredParser: + """The preferred parser for this optimizer.""" + # TODO: We might want a custom one for neps.SearchSpace, for now we will + # use config space but without conditions as NePs doesn't support conditionals + return partial(configspace_parser, conditionals=False) diff --git a/src/amltk/optimization/optimizers/optuna.py b/src/amltk/optimization/optimizers/optuna.py new file mode 100644 index 00000000..ea27cf1c --- /dev/null +++ b/src/amltk/optimization/optimizers/optuna.py @@ -0,0 +1,279 @@ +"""[Optuna](https://optuna.org/) is an automatic hyperparameter optimization +software framework, particularly designed for machine learning. + +!!! tip "Requirements" + + This requires `Optuna` which can be installed with: + + ```bash + pip install amltk[optuna] + + # Or directly + pip install optuna + ``` + +We provide a thin wrapper called +[`OptunaOptimizer`][amltk.optimization.optimizers.optuna.OptunaOptimizer] from which +you can integrate `Optuna` into your workflow. + +This uses an Optuna-like [`search_space()`][amltk.pipeline.Node.search_space] for +its optimization, please see the +[search space reference](site:reference/pipeline/spaces.md) for more. + +Users should report results using +[`trial.success()`][amltk.optimization.Trial.success] +with either `cost=` or `values=` depending on any optimization directions +given to the underyling optimizer created. Please see their documentation +for more. + +Visit their documentation for what you can pass to +[`OptunaOptimizer.create()`][amltk.optimization.optimizers.optuna.OptunaOptimizer.create], +which is forward to [`optun.create_study()`][optuna.create_study]. 
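Because those keyword arguments are handed straight to `optuna.create_study()`, the usual Optuna study options apply. A short sketch (the study name, direction and sampler are illustrative choices, not amltk defaults):

```python
import optuna
from sklearn.ensemble import RandomForestClassifier

from amltk.optimization.optimizers.optuna import OptunaOptimizer
from amltk.pipeline import Component

pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)})
space = pipeline.search_space(parser=OptunaOptimizer.preferred_parser())

optimizer = OptunaOptimizer.create(
    space=space,
    study_name="my-study",
    direction="minimize",                         # matches reporting `cost=` on success
    sampler=optuna.samplers.TPESampler(seed=42),  # any Optuna sampler can be passed
)
```

A fuller, scheduler-driven example follows.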
+ +```python exec="True" source="material-block" result="python" +from __future__ import annotations + +import logging + +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split + +from amltk.optimization.optimizers.optuna import OptunaOptimizer +from amltk.scheduling import Scheduler +from amltk.optimization import History, Trial +from amltk.pipeline import Component + +logging.basicConfig(level=logging.INFO) + + +def target_function(trial: Trial, pipeline: Pipeline) -> Trial.Report: + X, y = load_iris(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split(X, y) + clf = pipeline.configure(trial.config).build("sklearn") + + with trial.begin(): + clf.fit(X_train, y_train) + y_pred = clf.predict(X_test) + accuracy = accuracy_score(y_test, y_pred) + + trial.summary["accuracy"] = accuracy + return trial.success(cost=1-accuracy) + + return trial.fail() +from amltk._doc import make_picklable; make_picklable(target_function) # markdown-exec: hide + + +pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) +space = pipeline.search_space(parser=OptunaOptimizer.preferred_parser()) +optimizer = OptunaOptimizer.create(space=space) + +N_WORKERS = 2 +scheduler = Scheduler.with_processes(N_WORKERS) +task = scheduler.task(target_function) + +history = History() + +@scheduler.on_start(repeat=N_WORKERS) +def on_start(): + trial = optimizer.ask() + task.submit(trial, pipeline) + +@task.on_result +def tell_and_launch_trial(_, report: Trial.Report): + if scheduler.running(): + optimizer.tell(report) + trial = optimizer.ask() + task.submit(trial, pipeline) + + +@task.on_result +def add_to_history(_, report: Trial.Report): + history.add(report) + +scheduler.run(timeout=3, wait=False) + +print(history.df()) +``` + +!!! todo "Some more documentation" + + Sorry! + +""" # noqa: E501 +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any +from typing_extensions import Self, override + +import optuna +from optuna.study import Study, StudyDirection +from optuna.trial import ( + Trial as OptunaTrial, + TrialState, +) + +from amltk.optimization import Optimizer, Trial +from amltk.pipeline.parsers.optuna import parser + +if TYPE_CHECKING: + from typing import Protocol + + from amltk.pipeline import Node + from amltk.pipeline.parsers.optuna import OptunaSearchSpace + from amltk.store import Bucket + + class OptunaParser(Protocol): + """A protocol for Optuna search space parser.""" + + def __call__( + self, + node: Node, + *, + flat: bool = False, + delim: str = ":", + ) -> OptunaSearchSpace: + """See [`optuna_parser`][amltk.pipeline.parsers.optuna.parser].""" + ... + + +class OptunaOptimizer(Optimizer[OptunaTrial]): + """An optimizer that uses Optuna to optimize a search space.""" + + @override + def __init__( + self, + *, + study: Study, + bucket: Bucket | None = None, + space: OptunaSearchSpace, + ) -> None: + """Initialize the optimizer. + + Args: + study: The Optuna Study to use. + bucket: The bucket given to trials generated by this optimizer. + space: Defines the current search space. + """ + super().__init__(bucket=bucket) + self.study = study + self.space = space + + @classmethod + def create( + cls, + *, + space: OptunaSearchSpace, + bucket: Bucket | None = None, + **kwargs: Any, + ) -> Self: + """Create a new Optuna optimizer. 
For more information, check Optuna + documentation + [here](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html#). + + Args: + space: Defines the current search space. + bucket: The bucket given to trials generated by this optimizer. + **kwargs: Additional arguments to pass to + [`optuna.create_study`][optuna.create_study]. + + Returns: + Self: The newly created optimizer. + """ + study = optuna.create_study(**kwargs) + return cls(study=study, space=space, bucket=bucket) + + @override + def ask(self) -> Trial[OptunaTrial]: + """Ask the optimizer for a new config. + + Returns: + The trial info for the new config. + """ + optuna_trial = self.study.ask(self.space) + config = optuna_trial.params + trial_number = optuna_trial.number + unique_name = f"{trial_number=}" + return Trial( + name=unique_name, + config=config, + info=optuna_trial, + bucket=self.bucket, + ) + + @override + def tell(self, report: Trial.Report[OptunaTrial]) -> None: + """Tell the optimizer the result of the sampled config. + + Args: + report: The report of the trial. + """ + trial = report.trial.info + assert trial is not None + + if report.status is Trial.Status.SUCCESS: + trial_state = TrialState.COMPLETE + values = self._verify_success_report_values(report) + else: + trial_state = TrialState.FAIL + values = None + + self.study.tell(trial=trial, values=values, state=trial_state) + + def _verify_success_report_values( + self, + report: Trial.Report[OptunaTrial], + ) -> float | Sequence[float]: + """Verify that the report is valid. + + Args: + report: The report to check. + + Raises: + ValueError: If both "cost" and "values" reported or + if the study direction is not "minimize" and "cost" is reported. + """ + if "cost" in report.results and "values" in report.results: + raise ValueError( + "Both 'cost' and 'values' were provided in the report. " + "Only one of them should be provided.", + ) + + if "cost" not in report.results and "values" not in report.results: + raise ValueError( + "Neither 'cost' nor 'values' were provided in the report. " + "At least one of them should be provided.", + ) + + directions = self.study.directions + + values = None + if "cost" in report.results: + if not all(direct == StudyDirection.MINIMIZE for direct in directions): + raise ValueError( + "The study direction is not 'minimize'," + " but 'cost' was provided in the report.", + ) + values = report.results["cost"] + else: + values = report.results["values"] + + if not ( + isinstance(values, float | int) + or ( + isinstance(values, Sequence) + and all(isinstance(value, float | int) for value in values) + ) + ): + raise ValueError( + f"Reported {values=} should be float or a sequence of floats", + ) + + return values + + @override + @classmethod + def preferred_parser(cls) -> OptunaParser: + return parser diff --git a/src/amltk/smac/optimizer.py b/src/amltk/optimization/optimizers/smac.py similarity index 68% rename from src/amltk/smac/optimizer.py rename to src/amltk/optimization/optimizers/smac.py index e2dd0f4e..c7aa329d 100644 --- a/src/amltk/smac/optimizer.py +++ b/src/amltk/optimization/optimizers/smac.py @@ -1,14 +1,104 @@ -"""A thin wrapper around SMAC to make it easier to use with AutoMLToolkit. +"""The [`SMACOptimizer`][amltk.optimization.optimizers.smac.SMACOptimizer], +is a wrapper around the [`smac`](https://github.com/automl/smac3) optimizer. -TODO: More description and explanation with examples. -""" +!!! 
tip "Requirements" + + This requires `smac` which can be installed with: + + ```bash + pip install amltk[smac] + + # Or directly + pip install smac + ``` + +This uses `ConfigSpace` as its [`search_space()`][amltk.pipeline.Node.search_space] to +optimize. Please see +the [search space reference](site:reference/pipeline/spaces.md) for more. + +Users should report results using +[`trial.success(cost=...)`][amltk.optimization.Trial.success] +where `cost=` is a scaler value or a list of scaler values for the objective +to **minimize**. + +Visit their documentation for what you can pass to +[`SMACOptimizer.create()`][amltk.optimization.optimizers.smac.SMACOptimizer.create]. + +The below example shows how you can use SMAC to optimize an sklearn pipeline. + +```python exec="True" source="material-block" result="python" +from __future__ import annotations + +import logging + +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split + +from amltk.optimization.optimizers.smac import SMACOptimizer +from amltk.scheduling import Scheduler +from amltk.optimization import History, Trial +from amltk.pipeline import Component + +logging.basicConfig(level=logging.INFO) + + +def target_function(trial: Trial, pipeline: Pipeline) -> Trial.Report: + X, y = load_iris(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split(X, y) + clf = pipeline.configure(trial.config).build("sklearn") + + with trial.begin(): + clf.fit(X_train, y_train) + y_pred = clf.predict(X_test) + accuracy = accuracy_score(y_test, y_pred) + loss = 1 - accuracy + return trial.success(loss=loss, accuracy=accuracy) + + return trial.fail() +from amltk._doc import make_picklable; make_picklable(target_function) # markdown-exec: hide + + +pipeline = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) +space = pipeline.search_space(parser=SMACOptimizer.preferred_parser()) +optimizer = SMACOptimizer.create(space=space) + +N_WORKERS = 2 +scheduler = Scheduler.with_processes(N_WORKERS) +task = scheduler.task(target_function) + +history = History() + +@scheduler.on_start(repeat=N_WORKERS) +def on_start(): + trial = optimizer.ask() + task.submit(trial, pipeline) + +@task.on_result +def tell_and_launch_trial(_, report: Trial.Report): + if scheduler.running(): + optimizer.tell(report) + trial = optimizer.ask() + task.submit(trial, pipeline) + +@task.on_result +def add_to_history(_, report: Trial.Report): + history.add(report) + +scheduler.run(timeout=3, wait=False) + +print(history.df()) +""" # noqa: E501 from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Literal from typing_extensions import override import numpy as np +from pynisher import MemoryLimitException, TimeoutException from smac import HyperparameterOptimizationFacade, MultiFidelityFacade, Scenario from smac.runhistory import ( StatusType, @@ -18,7 +108,6 @@ from amltk.optimization import Optimizer, Trial from amltk.randomness import as_int -from pynisher import MemoryLimitException, TimeoutException if TYPE_CHECKING: from pathlib import Path @@ -27,6 +116,7 @@ from ConfigSpace import ConfigurationSpace from smac.facade import AbstractFacade + from amltk.store import Bucket from amltk.types import FidT, Seed @@ -40,14 +130,17 @@ def __init__( self, *, facade: AbstractFacade, + bucket: Bucket | 
None = None, fidelities: Mapping[str, FidT] | None = None, ) -> None: """Initialize the optimizer. Args: facade: The SMAC facade to use. + bucket: The bucket given to trials generated by this optimizer. fidelities: The fidelities to use, if any. """ + super().__init__(bucket=bucket) self.facade = facade self.fidelities = fidelities @@ -56,6 +149,7 @@ def create( cls, *, space: ConfigurationSpace, + bucket: Bucket | None = None, seed: Seed | None = None, fidelities: Mapping[str, FidT] | None = None, continue_from_last_run: bool = False, @@ -66,6 +160,7 @@ def create( Args: space: The config space to optimize. + bucket: The bucket given to trials generated by this optimizer. seed: The seed to use for the optimizer. fidelities: The fidelities to use, if any. continue_from_last_run: Whether to continue from a previous run. @@ -101,7 +196,7 @@ def create( overwrite=not continue_from_last_run, logging_level=logging_level, ) - return cls(facade=facade, fidelities=fidelities) + return cls(facade=facade, fidelities=fidelities, bucket=bucket) @override def ask(self) -> Trial[SMACTrialInfo]: @@ -133,6 +228,7 @@ def ask(self) -> Trial[SMACTrialInfo]: info=smac_trial_info, seed=seed, fidelities=trial_fids, + bucket=self.bucket, ) logger.debug(f"Asked for trial {trial.name}") return trial @@ -144,6 +240,7 @@ def tell(self, report: Trial.Report[SMACTrialInfo]) -> None: # noqa: PLR0912, C Args: report: The report of the trial. """ + assert report.trial.info is not None logger.debug(f"Telling report for trial {report.trial.name}") # If we're successful, get the cost and times and report them if report.status is Trial.Status.SUCCESS: @@ -154,7 +251,7 @@ def tell(self, report: Trial.Report[SMACTrialInfo]) -> None: # noqa: PLR0912, C ) reported_costs = report.results["cost"] - if isinstance(reported_costs, (np.number, int, float)): + if isinstance(reported_costs, np.number | int | float): reported_costs = float(reported_costs) elif isinstance(reported_costs, Sequence): reported_costs = [float(c) for c in reported_costs] @@ -236,3 +333,9 @@ def tell(self, report: Trial.Report[SMACTrialInfo]) -> None: # noqa: PLR0912, C additional_info=additional_info, ) self.facade.tell(info=report.trial.info, value=trial_value, save=True) + + @override + @classmethod + def preferred_parser(cls) -> Literal["configspace"]: + """The preferred parser for this optimizer.""" + return "configspace" diff --git a/src/amltk/optimization/random_search.py b/src/amltk/optimization/random_search.py deleted file mode 100644 index 24b641f2..00000000 --- a/src/amltk/optimization/random_search.py +++ /dev/null @@ -1,205 +0,0 @@ -"""A simple random search optimizer. - -This optimizer will sample from the space provided and return the results -without doing anything with them. -""" -from __future__ import annotations - -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, TypeVar -from typing_extensions import ParamSpec, override - -from amltk.optimization.optimizer import Optimizer -from amltk.optimization.trial import Trial -from amltk.pipeline.sampler import Sampler -from amltk.randomness import as_rng -from amltk.types import Space - -if TYPE_CHECKING: - from amltk.types import Config, Seed - -P = ParamSpec("P") -Q = ParamSpec("Q") -Result = TypeVar("Result") - -MAX_INT = 2**32 - - -@dataclass -class RSTrialInfo: - """The information about a random search trial. - - Args: - name: The name of the trial. - trial_number: The number of the trial. 
- config: The configuration sampled from the space. - """ - - name: str - trial_number: int - config: Config - - -class RandomSearch(Optimizer[RSTrialInfo]): - """A random search optimizer.""" - - def __init__( - self, - *, - space: Space, - sampler: ( - Sampler[Space] - | type[Sampler[Space]] - | Callable[[Space, int], Config] - | None - ) = None, - seed: Seed | None = None, - duplicates: bool = False, - max_sample_attempts: int = 50, - ): - """Initialize the optimizer. - - Args: - space: The space to sample from. - sampler: The sampler to use to sample from the space. - If not provided, the sampler will be automatically found. - - * If a `Sampler` is provided, it will be used to sample from the - space. - * If a `Callable` is provided, it will be used to sample from the space. - - ```python - def my_callable_sampler(space, seed: int) -> Config: ... - ``` - - !!! warning "Deterministic behaviour" - - This should return the same set of configurations given the same - seed for fully defined behaviour. - - - seed: The seed to use for the sampler. - duplicates: Whether to allow duplicate configurations. - max_sample_attempts: The maximum number of attempts to sample a - unique configuration. If this number is exceeded, an - `ExhaustedError` will be raised. This parameter has no - effect when `duplicates=True`. - """ - super().__init__() - self.space = space - self.trial_count = 0 - self.seed = as_rng(seed) if seed is not None else None - self.max_sample_attempts = max_sample_attempts - - # We store any configs we've seen to prevent duplicates - self._configs_seen: list[Config] | None = [] if not duplicates else None - - if sampler is None: - sampler = Sampler.find(space) - if sampler is None: - extra = "You can also provide a custom function to `sample=`." - raise Sampler.NoSamplerFoundError(space, extra=extra) - self.sampler = sampler - - elif isinstance(sampler, type) and issubclass(sampler, Sampler): - self.sampler = sampler() - elif isinstance(sampler, Sampler): - self.sampler = sampler - elif callable(sampler): - self.sampler = FunctionalSampler(sampler) - else: - raise ValueError( - f"Expected `sampler` to be a `Sampler` or `Callable`, got {sampler=}.", - ) - - @override - def ask(self) -> Trial[RSTrialInfo]: - """Sample from the space. - - Raises: - ExhaustedError: If the sampler is exhausted of unique configs. - Only possible to raise if `duplicates=False` (default). - """ - name = f"random-{self.trial_count}" - - try: - config = self.sampler.sample( - self.space, - seed=self.seed, - duplicates=self._configs_seen, # type: ignore - max_attempts=self.max_sample_attempts, - ) - except Sampler.GenerateUniqueConfigError as e: - raise self.ExhaustedError(space=self.space) from e - - if self._configs_seen is not None: - self._configs_seen.append(config) - - info = RSTrialInfo(name, self.trial_count, config) - trial = Trial( - name=name, - config=config, - info=info, - seed=self.seed.integers(MAX_INT) if self.seed is not None else None, - ) - self.trial_count = self.trial_count + 1 - return trial - - @override - def tell(self, report: Trial.Report[RSTrialInfo]) -> None: - """Do nothing with the report. - - ???+ note - We do nothing with the report as it's random search - and does not use the report to do anything useful. 
- """ - - class ExhaustedError(RuntimeError): - """Raised when the sampler is exhausted of unique configs.""" - - def __init__(self, space: Any): - """Initialize the error.""" - super().__init__(space) - self.space = space - - @override - def __str__(self) -> str: - return ( - f"Exhausted all unique configs in the space {self.space}." - " Consider bumping up `max_sample_attempts=` or handling this" - " error case." - ) - - -@dataclass -class FunctionalSampler(Sampler[Space]): - """A wrapper for a functional sampler for use in - [`RandomSearch`][amltk.optimization.RandomSearch]. - - Attributes: - f: The functional sampler to use. - """ - - f: Callable[[Space, int], Config] - - @override - @classmethod - def supports_sampling(cls, space: Any) -> bool: - """Defaults to True for all spaces.""" - return True - - @override - def copy(self, space: Space) -> Space: - """Attempts it's best with a deepcopy.""" - return deepcopy(space) - - @override - def _sample( - self, - space: Space, - n: int = 1, - seed: Seed | None = None, - ) -> list[Config]: - rng = as_rng(seed) - return [self.f(space, rng.integers(MAX_INT)) for _ in range(n)] diff --git a/src/amltk/optimization/trial.py b/src/amltk/optimization/trial.py index 89d75b5f..92589736 100644 --- a/src/amltk/optimization/trial.py +++ b/src/amltk/optimization/trial.py @@ -1,12 +1,35 @@ -"""A trial for an optimization task. +"""A [`Trial`][amltk.optimization.Trial] is +typically the output of +[`Optimizer.ask()`][amltk.optimization.Optimizer.ask], indicating +what the optimizer would like to evaluate next. We provide a host +of convenience methods attached to the `Trial` to make it easy to +save results, store artifacts, and more. + +Paired with the `Trial` is the [`Trial.Report`][amltk.optimization.Trial.Report], +class, providing an easy way to report back to the optimizer's +[`tell()`][amltk.optimization.Optimizer.tell] with +a simple [`trial.success(cost=...)`][amltk.optimization.Trial.success] or +[`trial.fail(cost=...)`][amltk.optimization.Trial.fail] call.. + +### Trial + +::: amltk.optimization.trial.Trial + options: + members: False + +### Report + +::: amltk.optimization.trial.Trial + options: + members: False -TODO: Populate more here. """ from __future__ import annotations import copy import logging import traceback +from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum @@ -14,12 +37,8 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Generic, - Iterable, - Iterator, Literal, - Mapping, TypeVar, overload, ) @@ -27,15 +46,15 @@ import numpy as np import pandas as pd +from rich.text import Text -from amltk.functional import dict_get_not_none, mapping_select, prefix_keys +from amltk._functional import dict_get_not_none, mapping_select, prefix_keys from amltk.profiling import Memory, Profile, Profiler, Timer from amltk.store import Bucket, PathBucket if TYPE_CHECKING: from rich.console import RenderableType from rich.panel import Panel - from rich.text import Text # Inner trial info object I = TypeVar("I") # noqa: E741 @@ -56,20 +75,111 @@ @dataclass class Trial(Generic[I]): - """A trial as suggested by an optimizer. + """A [`Trial`][amltk.optimization.Trial] encapsulates some configuration + that needs to be evaluated. Typically this is what is generated by an + [`Optimizer.ask()`][amltk.optimization.Optimizer.ask] call. + + ??? 
tip "Usage" + + To begin a trial, you can use the + [`trial.begin()`][amltk.optimization.Trial.begin], which will catch + exceptions/traceback and profile the block of code. + + If all went smooth, your trial was successful and you can use + [`trial.success()`][amltk.optimization.Trial.success] to generate + a success [`Report`][amltk.optimization.Trial.Report], typically + passing what your chosen optimizer expects, e.g. `"loss"` or `"cost"`. + + If your trial failed, you can instead use the + [`trial.fail()`][amltk.optimization.Trial.fail] to generate a + failure [`Report`][amltk.optimization.Trial.Report], where + any caught exception will be attached to it. Each + [`Optimizer`][amltk.optimization.Optimizer] will take care of what to do + from here. + + ```python exec="true" source="material-block" html="true" + from amltk.optimization import Trial + from amltk.store import PathBucket + + def target_function(trial: Trial) -> Trial.Report: + x = trial.config["x"] + y = trial.config["y"] + + with trial.begin(): + cost = x**2 - y + + if trial.exception: + return trial.fail() - You can modify the Trial as you see fit, specifically - `.summary` which are for recording any information you may like. + return trial.success(cost=cost) + + # ... usually obtained from an optimizer + bucket = PathBucket("all-trial-results") + trial = Trial(name="some-unique-name", config={"x": 1, "y": 2}, bucket=bucket) + + report = target_function(trial) + print(report.df()) + bucket.rmdir() # markdon-exec: hide + ``` + + Some important properties is that they have a unique + [`.name`][amltk.optimization.Trial.name] given the optimization run, + a candidate [`.config`][amltk.optimization.Trial.config]' to evaluate, + a possible [`.seed`][amltk.optimization.Trial.seed] to use, + and an [`.info`][amltk.optimization.Trial.info] object which is the optimizer + specific information, if required by you. + + If using [`Plugins`][amltk.scheduling.plugins.Plugin], they may insert + some extra objects in the [`.extra`][amltk.optimization.Trial.extras] dict. + + To profile your trial, you can wrap the logic you'd like to check with + [`trial.begin()`][amltk.optimization.Trial.begin], which will automatically + catch any errors, record the traceback, and profile the block of code, in + terms of time and memory. + + You can access the profiled time and memory using the + [`.time`][amltk.optimization.Trial.time] and + [`.memory`][amltk.optimization.Trial.memory] attributes. + If you've [`profile()`][amltk.optimization.Trial.profile]'ed any other intervals, + you can access them by name through + [`trial.profiles`][amltk.optimization.Trial.profiles]. + Please see the [profiling reference](site:reference/optimization/profiling.md) + for more. + + ??? example "Profiling with a trial." + + ```python exec="true" source="material-block" result="python" title="profile" + from amltk.optimization import Trial + + trial = Trial(name="some-unique-name", config={}) + + # ... somewhere where you've begun your trial. + with trial.profile("some_interval"): + for work in range(100): + pass + + print(trial.profiler.df()) + ``` - The other attributes will be automatically set, such - as `.profile` and `.exception`, which are capture - using [`trial.begin()`][amltk.optimization.trial.Trial.begin]. + You can also record anything you'd like into the + [`.summary`][amltk.optimization.Trial.summary], a plain `#!python dict` + or use [`trial.store()`][amltk.optimization.Trial.store] to store artifacts + related to the trial. + + ??? 
tip "What to put in `.summary`?" + + For large items, e.g. predictions or models, these are highly advised to + [`.store()`][amltk.optimization.Trial.store] to disk, especially if using + a `Task` for multiprocessing. + + Further, if serializing the report using the + [`report.df()`][amltk.optimization.Trial.Report.df], + returning a single row, + or a [`History`][amltk.optimization.History] + with [`history.df()`][amltk.optimization.History.df] for a dataframe consisting + of many of the reports, then you'd likely only want to store things + that are scalar and can be serialised to disk by a pandas DataFrame. - Args: - name: The unique name of the trial. - config: The config for the trial. - info: The info of the trial. - seed: The seed to use if suggested by the optimizer. """ name: str @@ -78,12 +188,15 @@ class Trial(Generic[I]): config: Mapping[str, Any] """The config of the trial provided by the optimizer.""" - info: I = field(repr=False) + info: I | None = field(default=None, repr=False) """The info of the trial provided by the optimizer.""" seed: int | None = None """The seed to use if suggested by the optimizer.""" + bucket: Bucket | None = None + """The bucket to store trial related output to.""" + fidelities: dict[str, Any] | None = None """The fidelities at which to evaluate the trial, if any.""" @@ -119,8 +232,8 @@ class Trial(Generic[I]): used to retrieve them later, such as a Path. """ - plugins: dict[str, Any] = field(default_factory=dict) - """Any plugins attached to the trial.""" + extras: dict[str, Any] = field(default_factory=dict) + """Any extras attached to the trial.""" @property def profiles(self) -> Mapping[str, Profile.Interval]: @@ -314,28 +427,29 @@ def store( self, items: Mapping[str, T], *, - where: str | Path | Bucket | Callable[[str, Mapping[str, T]], None], + where: ( + str | Path | Bucket | Callable[[str, Mapping[str, T]], None] | None + ) = None, ) -> None: """Store items related to the trial. ```python exec="true" source="material-block" result="python" title="store" hl_lines="5" from amltk.optimization import Trial + from amltk.store import PathBucket - trial = Trial(name="trial", config={"x": 1}, info={}) - - trial.store({"config.json": trial.config}, where="./results") + trial = Trial(name="trial", config={"x": 1}, info={}, bucket=PathBucket("results")) + trial.store({"config.json": trial.config}) print(trial.storage) ``` - You could also create a Bucket and use that instead. + You could also specify `where=` exactly to store the thing - ```python exec="true" source="material-block" result="python" title="store-bucket" hl_lines="8" + ```python exec="true" source="material-block" result="python" title="store-bucket" hl_lines="7" from amltk.optimization import Trial from amltk.store import PathBucket bucket = PathBucket("results") - trial = Trial(name="trial", config={"x": 1}, info={}) trial.store({"config.json": trial.config}, where=bucket) @@ -345,7 +459,7 @@ def store( Args: items: The items to store, a dict from the key to store it under - to the item itself. If using a `str`, `Path` or `PathBucket`, + to the item itself.If using a `str`, `Path` or `PathBucket`, the keys of the items should be a valid filename, including the correct extension. e.g. `#!python {"config.json": trial.config}` @@ -355,12 +469,20 @@ def store( a bucket will be created at the path, and the items will be stored in a sub-bucket with the name of the trial. 
- * If a `Bucket`, will store the items in a sub-bucket with the + * If a `Bucket`, will store the items **in a sub-bucket** with the name of the trial. * If a `Callable`, will call the callable with the name of the trial and the key-valued pair of items to store. """ # noqa: E501 + if where is None: + if self.bucket is None: + raise ValueError( + "Cannot store items without a bucket. Please specify a bucket" + " or set the bucket on the trial.", + ) + where = self.bucket + # If not a Callable, we convert to a path bucket method: Callable[[str, dict[str, Any]], None] | Bucket if isinstance(where, str): @@ -551,14 +673,14 @@ def retrieve( # Store in a sub-bucket return method.sub(self.name)[key].load(check=check) - def attach_plugin_item(self, name: str, plugin_item: Any) -> None: + def attach_extra(self, name: str, plugin_item: Any) -> None: """Attach a plugin item to the trial. Args: name: The name of the plugin item. plugin_item: The plugin item. """ - self.plugins[name] = plugin_item + self.extras[name] = plugin_item def rich_renderables(self) -> Iterable[RenderableType]: """The renderables for rich for this report.""" @@ -585,23 +707,23 @@ def rich_renderables(self) -> Iterable[RenderableType]: if any(self.storage): yield Panel(Pretty(self.storage), title="Storage", title_align="left") - if any(self.plugins): - yield Panel(Pretty(self.plugins), title="Plugins", title_align="left") + if any(self.extras): + yield Panel(Pretty(self.extras), title="Plugins", title_align="left") def __rich__(self) -> RenderableType: from rich.console import Group as RichGroup from rich.panel import Panel - from amltk.rich_util import key_with_paren_text + title = Text.assemble( + ("Trial", "bold"), + ("(", "default"), + (self.name, "italic"), + (")", "default"), + ) return Panel( RichGroup(*self.rich_renderables()), - title=key_with_paren_text( - "Trial", - self.name, - key_style="bold", - val_style="italic", - ), + title=title, title_align="left", ) @@ -638,7 +760,42 @@ def __rich__(self) -> Text: @dataclass class Report(Generic[I2]): - """A report for a trial.""" + """The [`Trial.Report`][amltk.optimization.Trial.Report] encapsulates + a [`Trial`][amltk.optimization.Trial], its status and any results/exceptions + that may have occured. + + Typically you will not create these yourself, but instead use + [`trial.success()`][amltk.optimization.Trial.success] or + [`trial.fail()`][amltk.optimization.Trial.fail] to generate them. + + ```python exec="true" source="material-block" result="python" + from amltk.optimization import Trial + + trial = Trial(name="trial", config={"x": 1}, info={}) + + with trial.begin(): + # Do some work + # ... + report: Trial.Report = trial.success(cost=1) + + print(report.df()) + ``` + + These reports are used to report back results to an + [`Optimizer`][amltk.optimization.Optimizer] + with [`Optimizer.tell()`][amltk.optimization.Optimizer.tell] but can also be + stored for your own uses. + + You can access the original trial with the + [`.trial`][amltk.optimization.Trial.Report.trial] attribute, and the + [`Status`][amltk.optimization.Trial.Status] of the trial with the + [`.status`][amltk.optimization.Trial.Report.status] attribute. + + You may also want to check out the [`History`][amltk.optimization.History] class + for storing a collection of `Report`s, allowing for an easier time to convert + them to a dataframe or perform some common Hyperparameter optimization parsing + of results. 
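+
+    As a minimal sketch of that workflow (the `History` usage mirrors the
+    optimizer examples above; the trial names and configs here are made up
+    purely for illustration):
+
+    ```python
+    from amltk.optimization import History, Trial
+
+    history = History()
+
+    # Collect a few reports and then view them all as one dataframe
+    for x in range(3):
+        trial = Trial(name=f"trial-{x}", config={"x": x})
+        with trial.begin():
+            cost = x**2
+        history.add(trial.success(cost=cost))
+
+    print(history.df())
+    ```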
+ """ trial: Trial[I2] """The trial that was run.""" @@ -853,7 +1010,9 @@ def store( self, items: Mapping[str, T], *, - where: str | Path | Bucket | Callable[[str, Mapping[str, T]], None], + where: ( + str | Path | Bucket | Callable[[str, Mapping[str, T]], None] | None + ) = None, ) -> None: """Store items related to the trial. @@ -958,23 +1117,23 @@ def from_dict(cls, d: Mapping[str, Any]) -> Trial.Report: def rich_renderables(self) -> Iterable[RenderableType]: """The renderables for rich for this report.""" - from amltk.rich_util import key_val_text - - yield key_val_text("Status", self.status.__rich__()) + yield Text.assemble( + ("Status", "bold"), + ("(", "default"), + self.status.__rich__(), + (")", "default"), + ) yield from self.trial.rich_renderables() def __rich__(self) -> Panel: from rich.console import Group as RichGroup from rich.panel import Panel - from amltk.rich_util import key_with_paren_text - - return Panel( - RichGroup(*self.rich_renderables()), - title=key_with_paren_text( - "Trial", - self.name, - key_style="bold", - val_style="italic", - ), + title = Text.assemble( + ("Trial", "bold"), + ("(", "default"), + (self.name, "italic"), + (")", "default"), ) + + return Panel(RichGroup(*self.rich_renderables()), title=title) diff --git a/src/amltk/options.py b/src/amltk/options.py index ee123d2a..572a1e25 100644 --- a/src/amltk/options.py +++ b/src/amltk/options.py @@ -5,9 +5,10 @@ """ from __future__ import annotations -from typing import Any, Callable, Literal, TypedDict, TypeVar, overload +from collections.abc import Callable +from typing import Any, Literal, TypedDict, TypeVar, overload -from amltk.links import sklearn_link_generator +from amltk._doc import sklearn_link_generator class AMLTKOptions(TypedDict): diff --git a/src/amltk/optuna/__init__.py b/src/amltk/optuna/__init__.py deleted file mode 100644 index 290dc792..00000000 --- a/src/amltk/optuna/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from amltk.optuna.optimizer import OptunaOptimizer -from amltk.optuna.space import OptunaSpaceAdapter - -OptunaParser = OptunaSpaceAdapter -OptunaSampler = OptunaSpaceAdapter - -__all__ = ["OptunaSpaceAdapter", "OptunaOptimizer", "OptunaParser", "OptunaSampler"] diff --git a/src/amltk/optuna/optimizer.py b/src/amltk/optuna/optimizer.py deleted file mode 100644 index 0fd0ced7..00000000 --- a/src/amltk/optuna/optimizer.py +++ /dev/null @@ -1,137 +0,0 @@ -"""A thin wrapper around Optuna to make it easier to use with AutoMLToolkit. - -Check our [integration documentation](site:reference/optuna.md#optimizer) -for more information. -""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Sequence - -import optuna -from amltk.optimization import Optimizer, Trial -from optuna.study import Study, StudyDirection -from optuna.trial import ( - Trial as OptunaTrial, - TrialState, -) - -if TYPE_CHECKING: - from typing_extensions import Self - - from amltk.optuna.space import OptunaSearchSpace - - -class OptunaOptimizer(Optimizer[OptunaTrial]): - """An optimizer that uses Optuna to optimize a search space.""" - - def __init__(self, *, study: Study, space: OptunaSearchSpace) -> None: - """Initialize the optimizer. - - Args: - study: The Optuna Study to use. - space: Defines the current search space. - """ - self.study = study - self.space = space - - @classmethod - def create( - cls, - *, - space: OptunaSearchSpace, - **kwargs: Any, - ) -> Self: - """Create a new Optuna optimizer. 
For more information, check Optuna - documentation - [here](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html#). - - Args: - space: Defines the current search space. - **kwargs: Additional arguments to pass to - [`optuna.create_study`][optuna.create_study]. - - Returns: - Self: The newly created optimizer. - """ - study = optuna.create_study(**kwargs) - return cls(study=study, space=space) - - def ask(self) -> Trial[OptunaTrial]: - """Ask the optimizer for a new config. - - Returns: - The trial info for the new config. - """ - optuna_trial = self.study.ask(self.space) - config = optuna_trial.params - trial_number = optuna_trial.number - unique_name = f"{trial_number=}" - return Trial(name=unique_name, config=config, info=optuna_trial) - - def tell(self, report: Trial.Report[OptunaTrial]) -> None: - """Tell the optimizer the result of the sampled config. - - Args: - report: The report of the trial. - """ - trial = report.trial.info - - if report.status is Trial.Status.SUCCESS: - trial_state = TrialState.COMPLETE - values = self._verify_success_report_values(report) - else: - trial_state = TrialState.FAIL - values = None - - self.study.tell(trial=trial, values=values, state=trial_state) - - def _verify_success_report_values( - self, - report: Trial.Report[OptunaTrial], - ) -> float | Sequence[float]: - """Verify that the report is valid. - - Args: - report: The report to check. - - Raises: - ValueError: If both "cost" and "values" reported or - if the study direction is not "minimize" and "cost" is reported. - """ - if "cost" in report.results and "values" in report.results: - raise ValueError( - "Both 'cost' and 'values' were provided in the report. " - "Only one of them should be provided.", - ) - - if "cost" not in report.results and "values" not in report.results: - raise ValueError( - "Neither 'cost' nor 'values' were provided in the report. " - "At least one of them should be provided.", - ) - - directions = self.study.directions - - values = None - if "cost" in report.results: - if not all(direct == StudyDirection.MINIMIZE for direct in directions): - raise ValueError( - "The study direction is not 'minimize'," - " but 'cost' was provided in the report.", - ) - values = report.results["cost"] - else: - values = report.results["values"] - - if not ( - isinstance(values, (float, int)) - or ( - isinstance(values, Sequence) - and all(isinstance(value, (float, int)) for value in values) - ) - ): - raise ValueError( - f"Reported {values=} should be float or a sequence of floats", - ) - - return values diff --git a/src/amltk/optuna/space.py b/src/amltk/optuna/space.py deleted file mode 100644 index 8512dc6d..00000000 --- a/src/amltk/optuna/space.py +++ /dev/null @@ -1,227 +0,0 @@ -"""A module for utilities for spaces defined by Optuna. - -The notable class is [`OptunaSpaceAdapter`][amltk.optuna.OptunaSpaceAdapter]. -It implements the [`SpaceAdapter`][amltk.pipeline.SpaceAdapter] interface -to provide utility to parse and sample from an Optuna space. - - -```python hl_lines="8 9 10 13 15 16" -from amltk.pipeline import step -from amltk.optuna import OptunaSpaceAdapter - -item = step( - "name", - ..., - space={ - "myint": (1, 10), # (1)! - "myfloat": (1.0, 10.0) # (2)! - "mycategorical": ["a", "b", "c"], # (3)! - } -) -adapter = OptunaSpaceAdapter() # (6)! - -optuna_space = item.space(parser=adapter) # (4)! -config = item.sample(sampler=adapter) # (5)! - -configured_item = item.configure(config) -configured_item.build() -``` - -1. 
`myint` will be an integer between 1 and 10. -2. `myfloat` will be a float between 1.0 and 10.0. -3. `mycategorical` will be a categorical variable with values `a`, `b`, and `c`. -4. Pass the `adapter` to [`space()`][amltk.pipeline.Step.space] to get the optuna space. - It will be a dictionary mapping the name of the hyperparameter to the optuna - distribution. -5. Pass the `adapter` to [`sample()`][amltk.pipeline.Step.sample] to get a sample - config. It will be a dictionary mapping the name of the hyperparameter to the - sampled value. -6. Create an instance of [`OptunaSpaceAdapter`][amltk.optuna.OptunaSpaceAdapter] which - will be used to parse the space and sample from it. - - !!! note "Note" - - The `OptunaSpaceAdapter` is a [`SpaceAdapter`][amltk.pipeline.SpaceAdapter] - which means that it can be used to parse any space and sample from it. It - implements the [`Parser`][amltk.pipeline.Parser] and - [`Sampler`][amltk.pipeline.Sampler] interfaces. -""" -from __future__ import annotations - -from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Mapping, Sequence, Union -from typing_extensions import override - -from amltk.configspace.space import as_int -from amltk.pipeline.space import SpaceAdapter -from optuna.distributions import ( - BaseDistribution, - CategoricalDistribution, - FloatDistribution, - IntDistribution, -) -from optuna.samplers import RandomSampler - -if TYPE_CHECKING: - from typing_extensions import TypeAlias - - from amltk.types import Config, Seed - -OptunaSearchSpace: TypeAlias = Dict[str, BaseDistribution] -InputSpace: TypeAlias = Union[Mapping[str, Any], OptunaSearchSpace] - - -class OptunaSpaceAdapter(SpaceAdapter[InputSpace, OptunaSearchSpace]): - """An Optuna adapter to allow for parsing Optuna spaces and sampling from them.""" - - @override - def parse_space( - self, - space: Any, - config: Mapping[str, Any] | None = None, - ) -> OptunaSearchSpace: - """See [`Parser.parse_space`][amltk.pipeline.Parser.parse_space].""" - if not isinstance(space, Mapping): - raise ValueError("Can only parse mappings with Optuna but got {space=}") - - parsed_space = { - name: self._convert_hp_to_optuna_distribution(name=name, hp=hp) - for name, hp in space.items() - } - for name, value in (config or {}).items(): - parsed_space[name] = CategoricalDistribution([value]) - - return parsed_space - - @override - def insert( - self, - space: OptunaSearchSpace, - subspace: InputSpace, - *, - prefix_delim: tuple[str, str] | None = None, - ) -> OptunaSearchSpace: - """See [`Parser.insert`][amltk.pipeline.Parser.insert].""" - if prefix_delim is None: - prefix_delim = ("", "") - - prefix, delim = prefix_delim - - space.update({f"{prefix}{delim}{name}": hp for name, hp in subspace.items()}) - - return space - - @override - def condition( - self, - choice_name: str, - delim: str, - spaces: dict[str, OptunaSearchSpace], - weights: Sequence[float] | None = None, - ) -> OptunaSearchSpace: - """See [`Parser.condition`][amltk.pipeline.Parser.condition].""" - # TODO(eddiebergman): Might be possible to implement this but it requires some - # toying around with options to various Samplers in the Optimizer used. 
- raise NotImplementedError( - f"Conditions (from {choice_name}) not supported with Optuna", - ) - - @override - def empty(self) -> OptunaSearchSpace: - """Return an empty space.""" - return {} - - @override - def copy(self, space: OptunaSearchSpace) -> OptunaSearchSpace: - """Copy the space.""" - return deepcopy(space) - - @override - def _sample( - self, - space: OptunaSearchSpace, - n: int = 1, - seed: Seed | None = None, - ) -> list[Config]: - """Sample n configs from the space. - - Args: - space: The space to sample from. - n: The number of configs to sample. - seed: The seed to use for sampling. - - Returns: - A list of configs sampled from the space. - """ - seed_int = as_int(seed) - sampler = RandomSampler(seed=seed_int) - - # Can be used because `sample_independant` doesn't use the study or trial - study: Any = None - trial: Any = None - - # Sample n configs - configs: list[Config] = [ - { - name: sampler.sample_independent(study, trial, name, dist) - for name, dist in space.items() - } - for _ in range(n) - ] - return configs - - @classmethod - def _convert_hp_to_optuna_distribution( - cls, - name: str, - hp: tuple | Sequence | int | str | float | BaseDistribution, - ) -> BaseDistribution: - if isinstance(hp, BaseDistribution): - return hp - - # If it's an allowed type, it's a constant - # TODO: Not sure if this makes sense to be honest - if isinstance(hp, (int, str, float)): - return CategoricalDistribution([hp]) - - if isinstance(hp, tuple) and len(hp) == 2: # noqa: PLR2004 - lower, upper = hp - if type(lower) != type(upper): - raise ValueError( - f"Expected {name} to have same type for lower and upper bound," - f"got lower: {type(lower)}, upper: {type(upper)}.", - ) - - if isinstance(lower, float): - return FloatDistribution(lower, upper) - - return IntDistribution(lower, upper) - - # Sequences - if isinstance(hp, Sequence): # type: ignore - if len(hp) == 0: - raise ValueError(f"Can't have empty list for categorical {name}") - - return CategoricalDistribution(hp) - - raise ValueError( - f"Expected hyperparameter value for {name} to be one of " - "tuple | list | int | str | float | Optuna.BaseDistribution," - f" got {type(hp)}", - ) - - @classmethod - @override - def supports_sampling(cls, space: Any) -> bool: - """Supports sampling from a mapping where every value is a - [`BaseDistribution`][optuna.distributions]. - - Args: - space: The space to check. - - Returns: - Whether the space is supported. 
- """ - return isinstance(space, Mapping) and all( - isinstance(hp, BaseDistribution) for hp in space.values() - ) diff --git a/src/amltk/pipeline/__init__.py b/src/amltk/pipeline/__init__.py index 799f4221..9f6f25e4 100644 --- a/src/amltk/pipeline/__init__.py +++ b/src/amltk/pipeline/__init__.py @@ -1,28 +1,26 @@ from __future__ import annotations -from amltk.pipeline.api import choice, group, request, searchable, split, step -from amltk.pipeline.components import Choice, Component, Group, Searchable, Split -from amltk.pipeline.parser import Parser -from amltk.pipeline.pipeline import Pipeline -from amltk.pipeline.sampler import Sampler -from amltk.pipeline.space import SpaceAdapter -from amltk.pipeline.step import Step +from amltk.pipeline.components import ( + Choice, + Component, + Fixed, + Join, + Searchable, + Sequential, + Split, + as_node, +) +from amltk.pipeline.node import Node, request __all__ = [ - "Pipeline", - "split", - "step", - "choice", - "searchable", - "Step", + "Node", "Component", "Split", "Choice", - "Parser", - "Sampler", - "SpaceAdapter", "Searchable", - "group", - "Group", + "Sequential", + "Fixed", + "Join", "request", + "as_node", ] diff --git a/src/amltk/pipeline/api.py b/src/amltk/pipeline/api.py deleted file mode 100644 index 2c0b2669..00000000 --- a/src/amltk/pipeline/api.py +++ /dev/null @@ -1,401 +0,0 @@ -"""The public api for pipeline, steps and components. - -Anything changing here is considering a major change -""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload - -from amltk.pipeline.components import Choice, Component, Group, Searchable, Split -from amltk.pipeline.step import ParamRequest, Step - -if TYPE_CHECKING: - from amltk.types import FidT - -Space = TypeVar("Space") -T = TypeVar("T") - - -@overload -def searchable( - name: str, - *, - space: None = None, - config: Mapping[str, Any] | None = None, - fidelities: Mapping[str, FidT] | None = ..., - meta: Mapping[str, Any] | None = ..., -) -> Searchable[None]: - ... - - -@overload -def searchable( - name: str, - *, - space: Space, - config: Mapping[str, Any] | None = None, - fidelities: Mapping[str, FidT] | None = ..., - meta: Mapping[str, Any] | None = ..., -) -> Searchable[Space]: - ... - - -def searchable( - name: str, - *, - space: Space | None = None, - config: Mapping[str, Any] | None = None, - fidelities: Mapping[str, FidT] | None = None, - meta: Mapping[str, Any] | None = None, -) -> Searchable[Space] | Searchable[None]: - """A set of searachble items. - - ```python - from amltk.pipeline import searchable - - s = searchable("parameters", space={"x": (-10.0, 10.0)}) - ``` - - Args: - name: The unique identifier for this set of searachables. - space: A space asscoiated with this searchable. - config: - A config of set values to pass. If any parameter here is also present in - the space, this will be removed from the space. - fidelities: - A fidelity associated with this searchable. This can be a single range - indicated as a tuple, an ordered list or a mapping from a name to - any of the above. 
- meta: Any metadata to associate with this - - Returns: - Step - """ - return Searchable( - name=name, - config=config, - search_space=space, - fidelity_space=fidelities, - meta=meta, - ) - - -@overload -def step( - name: str, - item: Callable[..., T], - *, - config: Mapping[str, Any] | None = ..., - fidelities: Mapping[str, FidT] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = None, -) -> Component[T, None]: - ... - - -@overload -def step( - name: str, - item: Any, - *, - config: Mapping[str, Any] | None = ..., - fidelities: Mapping[str, FidT] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = None, -) -> Component[Any, None]: - ... - - -@overload -def step( - name: str, - item: Callable[..., T], - *, - space: Space, - config: Mapping[str, Any] | None = ..., - fidelities: Mapping[str, FidT] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = None, -) -> Component[T, Space]: - ... - - -@overload -def step( - name: str, - item: Any, - *, - space: Space, - config: Mapping[str, Any] | None = ..., - fidelities: Mapping[str, FidT] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = None, -) -> Component[Any, Space]: - ... - - -def step( - name: str, - item: Callable[..., T] | Any, - *, - space: Space | None = None, - config: Mapping[str, Any] | None = None, - fidelities: Mapping[str, FidT] | None = None, - meta: Mapping[str, Any] | None = None, - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = None, -) -> ( - Component[T, Space] - | Component[T, None] - | Component[Any, Space] - | Component[Any, None] -): - """A step in a pipeline. - - Can be joined together with the `|` operator, creating a chain and returning - a new set of steps, with the first step still at the head. - - ```python - head = step("1", 1) | step("2", 2) | step("3", 3) - ``` - Note: - These are immutable steps, where operations on them will create new instances - with references to any content they may store. Equality between steps is based - solely on the name, not their contents or the steps they are linked to. - - For this reason, Pipelines expose a `validate` method that will check that - the steps in a pipeline are all uniquenly named. - - Args: - name: The unique identifier for this step. - item: The item for this step. - space: A space with which this step can be searched over. - config: - A config of set values to pass. If any parameter here is also present in - the space, this will be removed from the space. - fidelities: - A fidelity associated with this searchable. This can be a single range - indicated as a tuple, an ordered list or a mapping from a name to - any of the above. - meta: Any metadata to associate with this - config_transform: - A function that will be applied to the config before it is passed to the - item. This can be used to transform the config into a format that the item - expects. 
- - Returns: - The component describing this step - """ - return Component( - name=name, - item=item, - config=config, - search_space=space, - fidelity_space=fidelities, - meta=meta, - config_transform=config_transform, - ) - - -def choice( - name: str, - *choices: Step, - weights: Iterable[float] | None = None, - meta: Mapping[str, Any] | None = None, -) -> Choice[None]: - """Define a choice in a pipeline. - - Args: - name: The unique name of this step - *choices: The choices that can be taken - weights: Weights to assign to each choice - meta: Any metadata to associate with this - - Returns: - Choice: Choice component with your choices as possibilities - """ - weights = list(weights) if weights is not None else None - if weights and len(weights) != len(choices): - raise ValueError("Weights must be the same length as choices") - - return Choice(name=name, paths=list(choices), weights=weights, meta=meta) - - -@overload -def split( - name: str, - *paths: Step | dict[str, Step], - config: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = ..., - meta: Mapping[str, Any] | None = ..., -) -> Split[None, None]: - ... - - -@overload -def split( - name: str, - *paths: Step | dict[str, Step], - item: Callable[..., T], - config: Mapping[str, Any] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = ..., -) -> Split[T, None]: - ... - - -@overload -def split( - name: str, - *paths: Step | dict[str, Step], - item: Callable[..., T], - space: Space, - config: Mapping[str, Any] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = ..., -) -> Split[T, Space]: - ... - - -@overload -def split( - name: str, - *paths: Step | dict[str, Step], - item: Any, - space: Space, - config: Mapping[str, Any] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = ..., -) -> Split[Any, Space]: - ... - - -@overload -def split( - name: str, - *paths: Step | dict[str, Step], - item: Any, - config: Mapping[str, Any] | None = ..., - meta: Mapping[str, Any] | None = ..., - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = ..., -) -> Split[Any, None]: - ... - - -def split( - name: str, - *paths: Step | dict[str, Step], - item: Callable[..., T] | Any | None = None, - space: Space | None = None, - config: Mapping[str, Any] | None = None, - meta: Mapping[str, Any] | None = None, - config_transform: ( - Callable[[Mapping[str, Any], Any], Mapping[str, Any]] | None - ) = None, -) -> ( - Split[T, Space] - | Split[T, None] - | Split[None, None] - | Split[Any, Space] - | Split[Any, None] -): - """Create a Split component, allowing data to flow multiple paths. - - Args: - name: The unique name of this step - *paths: The different paths - item: The item for this step. - config: - A config of set values to pass. If any parameter here is also present in - the space, this will be removed from the space. - space: A space with which this step can be searched over. - meta: Any metadata to associate with this - config_transform: - A function that will be applied to the config before it is passed to the - item. This can be used to transform the config into a format that the item - expects. 
- - Returns: - Split: Split component with your choices as possibilities - """ - if len(paths) > 1 and isinstance(paths[0], dict): - raise ValueError( - "When passing a dict as the first argument, you can't pass any other" - " positional arguments.", - ) - - if len(paths) == 1 and isinstance(paths[0], dict): - _paths = tuple(group(k, v) for k, v in paths[0].items()) - else: - _paths = paths # type: ignore - assert all(isinstance(p, Step) for p in _paths) - - return Split( - name=name, - paths=list(_paths), # type: ignore - item=item, - search_space=space, - config=config, - meta=meta, - config_transform=config_transform, - ) - - -_NotSet = object() - - -def request( - key: str, - *, - default: T = _NotSet, # type: ignore - required: bool = False, -) -> ParamRequest[T]: - """Create a new parameter request. - - Args: - key: The key to request under. - default: The default value to use if the key is not found. - If left as `__NotSet` (default) then the key will be removed from the - config once [`configure`][amltk.pipeline.Step.configure] is called and - nothing has been provided. - - required: Whether the key is required to be present. - """ - return ParamRequest(key=key, default=default, required=required) - - -def group( - name: str, - *paths: Step[Space], - meta: Mapping[str, Any] | None = None, -) -> Group[Space]: - """Create a Group component, allowing to namespace one or multiple steps. - - Args: - name: The unique name of this step - *paths: The different paths - meta: Any metadata to associate with this - - Returns: - Group component with your choices as possibilities - """ - return Group(name=name, paths=list(paths), meta=meta) diff --git a/tests/configuring/__init__.py b/src/amltk/pipeline/builders/__init__.py similarity index 100% rename from tests/configuring/__init__.py rename to src/amltk/pipeline/builders/__init__.py diff --git a/src/amltk/pipeline/builders/sklearn.py b/src/amltk/pipeline/builders/sklearn.py new file mode 100644 index 00000000..56b14f0e --- /dev/null +++ b/src/amltk/pipeline/builders/sklearn.py @@ -0,0 +1,346 @@ +"""The sklearn [`builder()`][amltk.pipeline.builders.sklearn.build], converts +a pipeline made of [`Node`][amltk.pipeline.Node]s into a sklearn +[`Pipeline`][sklearn.pipeline.Pipeline]. + +!!! tip "Requirements" + + This requires `sklearn` which can be installed with: + + ```bash + pip install "amltk[scikit-learn]" + + # Or directly + pip install scikit-learn + ``` + + +??? tip "Basic Usage" + + ```python + # TODO + ``` + + +Each _kind_ of node corresponds to a different part of the end pipeline: + +=== "`Fixed`" + + [`Fixed`][amltk.pipeline.Fixed] - The estimator will simply be cloned, allowing you + to directly configure some object in a pipeline. + + ```python exec="true" source="material-block" html="true" + from sklearn.ensemble import RandomForestClassifier + from amltk.pipeline import Fixed + + est = Fixed(RandomForestClassifier(n_estimators=25)) + built_pipeline = est.build("sklearn") + from amltk._doc import doc_print; doc_print(print, built_pipeline) # markdown-exec: hide + ``` + +=== "`Component`" + + [`Component`][amltk.pipeline.Component] - The estimator will be built from the + component's config. This is mostly useful to allow a space to be defined for + the component. + + ```python exec="true" source="material-block" html="true" + from sklearn.ensemble import RandomForestClassifier + from amltk.pipeline import Component + + est = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + + # ... 
Likely get the configuration through an optimizer or sampling + configured_est = est.configure({"n_estimators": 25}) + + built_pipeline = configured_est.build("sklearn") + from amltk._doc import doc_print; doc_print(print, built_pipeline) # markdown-exec: hide + ``` + +=== "`Sequential`" + + [`Sequential`][amltk.pipeline.Sequential] - The sequential will be converted into a + [`Pipeline`][sklearn.pipeline.Pipeline], building whatever nodes are contained + within in. + + ```python exec="true" source="material-block" html="true" + from sklearn.ensemble import RandomForestClassifier + from sklearn.decomposition import PCA + from amltk.pipeline import Component, Sequential + + pipeline = Sequential( + PCA(n_components=3), + Component(RandomForestClassifier, config={"n_estimators": 25}) + ) + built_pipeline = pipeline.build("sklearn") + from amltk._doc import doc_print; doc_print(print, built_pipeline) # markdown-exec: hide + ``` + +=== "`Split`" + + [`Split`][amltk.pipeline.Split] - The split will be converted into a + [`ColumnTransformer`][sklearn.compose.ColumnTransformer], where each path + and the data that should go through it is specified by the split's config. + You can provide a `ColumnTransformer` directly as the item to the `Split`, + or otherwise if left blank, it will default to the standard sklearn one. + + You can use a `Fixed` with the special keyword `"passthrough"` as you might normally + do with a `ColumnTransformer`. + + ```python exec="true" source="material-block" html="true" hl_lines="18-19 23-24" + from amltk.pipeline import Split + + categorical_pipeline = [ + SimpleImputer(strategy="constant", fill_value="missing"), + Component( + OneHotEncoder, + space={ + "min_frequency": (0.01, 0.1), + "handle_unknown": ["ignore", "infrequent_if_exist"], + }, + config={"drop": "first"}, + ), + ] + numerical_pipeline = [SimpleImputer(strategy="median"), StandardScaler()] + + split = Split( + { + "categories": categorical_pipeline, + "numbers": numerical_pipeline, + }, + config={ + "categories": make_column_selector(dtype_include=object), + "numbers": make_column_selector(dtype_include=np.number), + }, + ) + ``` + + +=== "`Join`" + + [`Join`][amltk.pipeline.Join] - The join will be converted into a + [`FeatureUnion`][sklearn.pipeline.FeatureUnion]. + + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Join, Component + from sklearn.decomposition import PCA + from sklearn.feature_selection import SelectKBest + + join = Join(PCA(n_components=2), SelectKBest(k=3), name="my_feature_union") + + pipeline = join.build("sklearn") + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` + +=== "`Choice`" + + [`Choice`][amltk.pipeline.Choice] - The estimator will be built from the chosen + component's config. This is very similar to [`Component`][amltk.pipeline.Component]. + + ```python exec="true" source="material-block" html="true" + from sklearn.ensemble import RandomForestClassifier + from sklearn.neural_network import MLPClassifier + from amltk.pipeline import Choice + + # The choice here is usually provided during the `.configure()` step. 
+ estimator_choice = Choice( + RandomForestClassifier(), + MLPClassifier(), + config={"__choice__": "RandomForestClassifier"} + ) + + built_pipeline = estimator_choice.build("sklearn") + from amltk._doc import doc_print; doc_print(print, built_pipeline) # markdown-exec: hide + ``` + + +""" # noqa: E501 +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any, TypeVar + +from sklearn.base import BaseEstimator, clone +from sklearn.compose import ColumnTransformer +from sklearn.pipeline import ( + FeatureUnion, + Pipeline as SklearnPipeline, +) + +from amltk.pipeline import ( + Choice, + Component, + Fixed, + Join, + Node, + Searchable, + Sequential, + Split, +) + +if TYPE_CHECKING: + from typing import TypeAlias + +COLUMN_TRANSFORMER_ARGS = [ + "remainder", + "sparse_threshold", + "n_jobs", + "transformer_weights", + "verbose", + "verbose_feature_names_out", +] +FEATURE_UNION_ARGS = ["n_jobs", "transformer_weights", "verbose"] + +# TODO: We can make this more explicit with typing out sklearn types. +# However sklearn operates in a bit more of a general level so it would +# require creating protocols to type this properly and work with sklearn's +# duck-typing. +SklearnItem: TypeAlias = Any | ColumnTransformer +SklearnPipelineT = TypeVar("SklearnPipelineT", bound=SklearnPipeline) + + +def _process_split( + node: Split, + pipeline_type: type[SklearnPipelineT] = SklearnPipeline, + **pipeline_kwargs: Any, +) -> tuple[str, ColumnTransformer]: + if node.config is None: + raise ValueError( + f"Can't handle split as it has no config attached: {node}.\n" + " Sklearn builder requires all splits to have a config to tell" + " the ColumnTransformer how to operate.", + ) + + if any(child.name in COLUMN_TRANSFORMER_ARGS for child in node.nodes): + raise ValueError( + f"Can't handle step as it has a path with a name that matches" + f" a known ColumnTransformer argument: {node}", + ) + + if any(child.name not in node.config for child in node.nodes): + raise ValueError( + f"Can't handle split {node.name=} as some path has no config associated" + " with it." 
+ "\nPlease ensure that all paths have a config associated with them.\n" + f"config={node.config}\n" + f"children={[child.name for child in node.nodes]}\n", + ) + + match node.item: + case None: + col_transform_cls = ColumnTransformer + case type() if issubclass(node.item, ColumnTransformer): + col_transform_cls = node.item + case _: + raise ValueError( + f"Can't handle: {node}.\n" + " Requires all splits to have a subclass" + " ColumnTransformer as the item, or None.", + ) + + # https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html + # list of (name, estimator, columns) + transformers: list[tuple[str, Any, Any]] = [] + config = node.config + + for child in node.nodes: + child_steps = list(_iter_steps(child)) + match child_steps: + case []: + raise ValueError(f"Can't handle child of split.\n{child=}\n{node}") + case [(name, sklearn_thing)]: + transformers.append((name, sklearn_thing, config[child.name])) + case list(): + sklearn_thing = pipeline_type(child_steps, **pipeline_kwargs) + transformers.append((child.name, sklearn_thing, config[child.name])) + + return (node.name, col_transform_cls(transformers)) + + +def _process_join( + node: Join, + pipeline_type: type[SklearnPipelineT] = SklearnPipeline, + **pipeline_kwargs: Any, +) -> tuple[str, FeatureUnion]: + if any(child.name in FEATURE_UNION_ARGS for child in node.nodes): + raise ValueError( + f"Can't handle step as it has a path with a name that matches" + f" a known FeatureUnion argument: {node}", + ) + + match node.item: + case None: + feature_union_cls = FeatureUnion + case type() if issubclass(node.item, FeatureUnion): + feature_union_cls = node.item + case _: + raise ValueError( + f"Can't handle: {node}.\n" + " Requires all splits to have a subclass" + " ColumnTransformer as the item, or None.", + ) + + # https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.FeatureUnion.html + # list of (name, estimator) + transformers: list[tuple[str, Any]] = [] + + for child in node.nodes: + child_steps = list(_iter_steps(child)) + match child_steps: + case []: + raise ValueError(f"Can't handle child of Join.\n{child=}\n{node}") + case [(name, sklearn_thing)]: + transformers.append((name, sklearn_thing)) + case list(): + sklearn_thing = pipeline_type(child_steps, **pipeline_kwargs) + transformers.append((child.name, sklearn_thing)) + + return (node.name, feature_union_cls(transformers)) + + +def _iter_steps( + node: Node, + pipeline_type: type[SklearnPipelineT] = SklearnPipeline, + **pipeline_kwargs: Any, +) -> Iterator[tuple[str, SklearnItem]]: + match node: + case Fixed(item=BaseEstimator()): + yield (node.name, clone(node.item)) + case Fixed(item=anything): + yield (node.name, anything) + case Component(): + yield (node.name, node.build_item()) + case Choice(): + yield from _iter_steps(node.chosen()) + case Sequential(): + for child in node.nodes: + yield from _iter_steps(child) + # Bit more involved, we defer to another functino + case Join(): + yield _process_join(node, pipeline_type=pipeline_type, **pipeline_kwargs) + case Split(): + yield _process_split(node, pipeline_type=pipeline_type, **pipeline_kwargs) + case Searchable(): + raise ValueError(f"Can't handle Searchable: {node}") + case _: + raise ValueError(f"Can't handle node: {node}") + + +def build( + node: Node[Any, Any], + *, + pipeline_type: type[SklearnPipelineT] = SklearnPipeline, + **pipeline_kwargs: Any, +) -> SklearnPipelineT: + """Build a pipeline into a usable object. + + Args: + node: The node from which to build a pipeline. 
+        pipeline_type: The type of pipeline to build. Defaults to the standard
+            sklearn pipeline but can be any derivative of that, e.g. ImbLearn's
+            pipeline.
+        **pipeline_kwargs: The kwargs to pass to the pipeline_type.
+
+    Returns:
+        The built pipeline
+    """
+    return pipeline_type(list(_iter_steps(node)), **pipeline_kwargs)  # type: ignore
diff --git a/src/amltk/pipeline/components.py b/src/amltk/pipeline/components.py
index ce0a9692..513c245c 100644
--- a/src/amltk/pipeline/components.py
+++ b/src/amltk/pipeline/components.py
@@ -1,643 +1,871 @@
-"""The various components that can be part of a pipeline.
+"""You can use the various different node types to build a pipeline.
-These can all be created through the functions `step`, `split`, `choice`
-exposed through the `amltk.pipeline` module and this is the preffered way to do so.
-"""
+The syntactic meaning of these components is dependent upon
+the [search space parser](site:reference/pipelines/spaces.md)
+and [builder](site:reference/pipelines/builders.md) you use.
+
+You can connect these nodes together using the constructors explicitly,
+as shown in the examples. We also provide some infix operators:
+
+* `>>` - Connect nodes together to form a [`Sequential`][amltk.pipeline.components.Sequential]
+* `&` - Connect nodes together to form a [`Join`][amltk.pipeline.components.Join]
+* `|` - Connect nodes together to form a [`Choice`][amltk.pipeline.components.Choice]
+
+There are also some other short-hands that you may find useful to know:
+
+* `{comp1, comp2, comp3}` - This will automatically be converted into a
+  [`Choice`][amltk.pipeline.Choice] between the given components.
+* `(comp1, comp2, comp3)` - This will automatically be converted into a
+  [`Join`][amltk.pipeline.Join] between the given components.
+* `[comp1, comp2, comp3]` - This will automatically be converted into a
+  [`Sequential`][amltk.pipeline.Sequential] of the given components.
+
+For each of these components we will show examples using
+the `#! "sklearn"` builder but please check out the
+[builder reference](site:reference/pipelines/builders.md)
+for more.
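To see how these pieces fit together end to end, here is a minimal, illustrative sketch; it assumes scikit-learn is installed and simply mirrors the per-component examples shown further below:

```python
from amltk.pipeline import Component, Sequential
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier

# Chain nodes together with `>>` to form a Sequential
pipeline = (
    Sequential(name="my_pipeline")
    >> PCA(n_components=3)
    >> Component(RandomForestClassifier, space={"n_estimators": (10, 100)})
)

# Extract a search space, sample a configuration and build a concrete sklearn pipeline
space = pipeline.search_space("configspace")
config = space.sample_configuration()
sklearn_pipeline = pipeline.configure(config).build("sklearn")
```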
+ +The components are: + +### Component + +::: amltk.pipeline.components.Component + options: + members: false + +### Sequential + +::: amltk.pipeline.components.Sequential + options: + members: false + +### Choice + +::: amltk.pipeline.components.Choice + options: + members: false + +### Split + +::: amltk.pipeline.components.Split + options: + members: false + +### Join + +::: amltk.pipeline.components.Join + options: + members: false + +### Fixed + +::: amltk.pipeline.components.Fixed + options: + members: false + +### Searchable + +::: amltk.pipeline.components.Searchable + options: + members: false +""" # noqa: E501 from __future__ import annotations -from contextlib import suppress -from itertools import chain, repeat -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Generic, - Iterator, - Literal, - Mapping, - Sequence, -) +import inspect +from collections.abc import Callable, Iterator, Mapping, Sequence +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload from typing_extensions import override -from attrs import field, frozen -from more_itertools import first_true +from more_itertools import all_unique, first_true -from amltk.pipeline.step import ParamRequest, Step, mapping_select, prefix_keys -from amltk.types import Config, FidT, Item, Space +from amltk._functional import entity_name +from amltk.exceptions import DuplicateNamesError, NoChoiceMadeError, NodeNotFoundError +from amltk.pipeline.node import Node, RichOptions +from amltk.randomness import randuid +from amltk.types import Config, Item, Space if TYPE_CHECKING: - from rich.console import RenderableType + from amltk.pipeline.node import NodeLike -@frozen(kw_only=True) -class Searchable(Step[Space], Generic[Space]): - """A step to be searched over. +T = TypeVar("T") +NodeT = TypeVar("NodeT", bound=Node) - See Also: - [`Step`][amltk.pipeline.step.Step] - """ - name: str - """Name of the step""" +@overload +def as_node(thing: Node, name: str | None = ...) -> Node: # type: ignore + ... - config: Mapping[str, Any] | None = field(default=None, hash=False) - """The configuration for this step""" - search_space: Space | None = field(default=None, hash=False, repr=False) - """The search space for this step""" +@overload +def as_node(thing: tuple[Node | NodeLike, ...], name: str | None = ...) -> Join: # type: ignore + ... - fidelity_space: Mapping[str, FidT] | None = field( - default=None, - hash=False, - repr=False, - ) - """The fidelities for this step""" - config_transform: ( - Callable[ - [Mapping[str, Any], Any], - Mapping[str, Any], - ] - | None - ) = field(default=None, hash=False, repr=False) - """A function that transforms the configuration of this step""" +@overload +def as_node(thing: set[Node | NodeLike], name: str | None = ...) -> Choice: # type: ignore + ... - meta: Mapping[str, Any] | None = None - """Any meta information about this step""" - RICH_PANEL_BORDER_COLOR: ClassVar[str] = "light_steel_blue" +@overload +def as_node(thing: list[Node | NodeLike], name: str | None = ...) -> Sequential: # type: ignore + ... -@frozen(kw_only=True) -class Component(Step[Space], Generic[Item, Space]): - """A Fixed component with an item attached. +@overload +def as_node( # type: ignore + thing: Callable[..., Item], + name: str | None = ..., +) -> Component[Item, None]: + ... 
- See Also: - [`Step`][amltk.pipeline.step.Step] - """ - item: Callable[..., Item] | Any = field(hash=False) - """The item attached to this step""" +@overload +def as_node(thing: Item, name: str | None = ...) -> Fixed[Item]: + ... - name: str - """Name of the step""" - config: Mapping[str, Any] | None = field(default=None, hash=False) - """The configuration for this step""" +def as_node( # noqa: PLR0911 + thing: Node | NodeLike[Item], + name: str | None = None, +) -> Node | Choice | Join | Sequential | Fixed[Item]: + """Convert a node, pipeline, set or tuple into a component, copying anything + in the process and removing all linking to other nodes. - search_space: Space | None = field(default=None, hash=False, repr=False) - """The search space for this step""" + Args: + thing: The thing to convert + name: The name of the node. If it already a node, it will be renamed to that + one. - fidelity_space: Mapping[str, FidT] | None = field( - default=None, - hash=False, - repr=False, + Returns: + The component + """ + match thing: + case set(): + return Choice(*thing, name=name) + case tuple(): + return Join(*thing, name=name) + case list(): + return Sequential(*thing, name=name) + case Node(): + name = thing.name if name is None else name + return thing.mutate(name=name) + case type(): + return Component(thing, name=name) + case thing if (inspect.isfunction(thing) or inspect.ismethod(thing)): + return Component(thing, name=name) + case _: + return Fixed(thing, name=name) + + +@dataclass(init=False, frozen=True, eq=True) +class Join(Node[Item, Space]): + """[`Join`][amltk.pipeline.Join] together different parts of the pipeline. + + This indicates the different children in + [`.nodes`][amltk.pipeline.Node.nodes] should act in tandem with one + another, for example, concatenating the outputs of the various members of the + `Join`. + + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Join, Component + from sklearn.decomposition import PCA + from sklearn.feature_selection import SelectKBest + + pca = Component(PCA, space={"n_components": (1, 3)}) + kbest = Component(SelectKBest, space={"k": (1, 3)}) + + join = Join(pca, kbest, name="my_feature_union") + from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide + + space = join.search_space("configspace") + from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide + + pipeline = join.build("sklearn") + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` + + You may also just join together nodes using an infix operator `&` if you prefer: + + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Join, Component + from sklearn.decomposition import PCA + from sklearn.feature_selection import SelectKBest + + pca = Component(PCA, space={"n_components": (1, 3)}) + kbest = Component(SelectKBest, space={"k": (1, 3)}) + + # Can not parametrize or name the join + join = pca & kbest + from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide + + # With a parametrized join + join = ( + Join(name="my_feature_union") & pca & kbest + ) + item = join.build("sklearn") + from amltk._doc import doc_print; doc_print(print, item) # markdown-exec: hide + ``` + + Whenever some other node sees a tuple, i.e. `(comp1, comp2, comp3)`, this + will automatically be converted into a `Join`. 
+ + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Sequential, Component + from sklearn.decomposition import PCA + from sklearn.feature_selection import SelectKBest + from sklearn.ensemble import RandomForestClassifier + + pca = Component(PCA, space={"n_components": (1, 3)}) + kbest = Component(SelectKBest, space={"k": (1, 3)}) + + # Can not parametrize or name the join + join = Sequential( + (pca, kbest), + RandomForestClassifier(n_estimators=5), + name="my_feature_union", ) - """The fidelities for this step""" + from amltk._doc import doc_print; doc_print(print, join) # markdown-exec: hide + ``` + + Like all [`Node`][amltk.pipeline.node.Node]s, a `Join` accepts an explicit + [`name=`][amltk.pipeline.node.Node.name], + [`item=`][amltk.pipeline.node.Node.item], + [`config=`][amltk.pipeline.node.Node.config], + [`space=`][amltk.pipeline.node.Node.space], + [`fidelities=`][amltk.pipeline.node.Node.fidelities], + [`config_transform=`][amltk.pipeline.node.Node.config_transform] and + [`meta=`][amltk.pipeline.node.Node.meta]. - config_transform: ( - Callable[ - [Mapping[str, Any], Any], - Mapping[str, Any], - ] - | None - ) = field(default=None, hash=False, repr=False) - """A function that transforms the configuration of this step""" + See Also: + * [`Node`][amltk.pipeline.node.Node] + """ - meta: Mapping[str, Any] | None = None - """Any meta information about this step""" + nodes: tuple[Node, ...] + """The nodes that this node leads to.""" - RICH_PANEL_BORDER_COLOR: ClassVar[str] = "default" + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#7E6B8F") - @override - def build(self, **kwargs: Any) -> Item: - """Build the item attached to this component. + _NODES_INIT: ClassVar = "args" - Args: - **kwargs: Any additional arguments to pass to the item + def __init__( + self, + *nodes: Node | NodeLike, + name: str | None = None, + item: Item | Callable[[Item], Item] | None = None, + config: Config | None = None, + space: Space | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, + meta: Mapping[str, Any] | None = None, + ): + """See [`Node`][amltk.pipeline.node.Node] for details.""" + _nodes = tuple(as_node(n) for n in nodes) + if not all_unique(_nodes, key=lambda node: node.name): + raise ValueError( + f"Can't handle nodes they do not all contain unique names, {nodes=}." + "\nAll nodes must have a unique name. 
Please provide a `name=` to them", + ) - Returns: - Item - The built item - """ - if callable(self.item): - config = self.config or {} - return self.item(**{**config, **kwargs}) + if name is None: + name = f"Join-{randuid(8)}" + + super().__init__( + *_nodes, + name=name, + item=item, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) - if self.config is not None: - raise ValueError(f"Can't pass config to a non-callable item in step {self}") + @override + def __and__(self, other: Node | NodeLike) -> Join: + other_node = as_node(other) + if any(other_node.name == node.name for node in self.nodes): + raise ValueError( + f"Can't handle node with name '{other_node.name} as" + f" there is already a node called '{other_node.name}' in {self.name}", + ) - return self.item + nodes = (*tuple(as_node(n) for n in self.nodes), other_node) + return self.mutate(name=self.name, nodes=nodes) - @override - def _rich_table_items(self) -> Iterator[tuple[RenderableType, ...]]: - from rich.pretty import Pretty - from amltk.richutil import Function +@dataclass(init=False, frozen=True, eq=True) +class Choice(Node[Item, Space]): + """A [`Choice`][amltk.pipeline.Choice] between different subcomponents. - if self.item is not None: - if callable(self.item): - yield "item", Function(self.item) - else: - yield "item", Pretty(self.item) + This indicates that a choice should be made between the different children in + [`.nodes`][amltk.pipeline.Node.nodes], usually done when you + [`configure()`][amltk.pipeline.node.Node.configure] with some `config` from + a [`search_space()`][amltk.pipeline.node.Node.search_space]. - yield from super()._rich_table_items() + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Choice, Component + from sklearn.ensemble import RandomForestClassifier + from sklearn.neural_network import MLPClassifier + rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) -@frozen(kw_only=True) -class Group(Mapping[str, Step], Step[Space]): - """A Fixed component with an item attached. 
+ estimator_choice = Choice(rf, mlp, name="estimator") + from amltk._doc import doc_print; doc_print(print, estimator_choice) # markdown-exec: hide - See Also: - [`Step`][amltk.pipeline.step.Step] - """ + space = estimator_choice.search_space("configspace") + from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide - paths: Sequence[Step] - """The paths that can be taken from this split""" + config = space.sample_configuration() + from amltk._doc import doc_print; doc_print(print, config) # markdown-exec: hide - name: str - """Name of the step""" + configured_choice = estimator_choice.configure(config) + from amltk._doc import doc_print; doc_print(print, configured_choice) # markdown-exec: hide - config: Mapping[str, Any] | None = field(default=None, hash=False) - """The configuration for this step""" + chosen_estimator = configured_choice.chosen() + from amltk._doc import doc_print; doc_print(print, chosen_estimator) # markdown-exec: hide - search_space: Space | None = field(default=None, hash=False, repr=False) - """The search space for this step""" + estimator = chosen_estimator.build_item() + from amltk._doc import doc_print; doc_print(print, estimator) # markdown-exec: hide + ``` - fidelity_space: Mapping[str, FidT] | None = field( - default=None, - hash=False, - repr=False, - ) - """The fidelities for this step""" + You may also just add nodes to a `Choice` using an infix operator `|` if you prefer: - config_transform: ( - Callable[ - [Mapping[str, Any], Any], - Mapping[str, Any], - ] - | None - ) = field(default=None, hash=False, repr=False) - """A function that transforms the configuration of this step""" + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Choice, Component + from sklearn.ensemble import RandomForestClassifier + from sklearn.neural_network import MLPClassifier - meta: Mapping[str, Any] | None = None - """Any meta information about this step""" + rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) - RICH_PANEL_BORDER_COLOR: ClassVar[str] = "deep_sky_blue2" + estimator_choice = ( + Choice(name="estimator") | mlp | rf + ) + from amltk._doc import doc_print; doc_print(print, estimator_choice) # markdown-exec: hide + ``` - def __attrs_post_init__(self) -> None: - """Ensure that the paths are all unique.""" - if len(self) != len(set(self)): - raise ValueError("Paths must be unique") + Whenever some other node sees a set, i.e. `{comp1, comp2, comp3}`, this + will automatically be converted into a `Choice`. 
- for path_head in self.paths: - object.__setattr__(path_head, "parent", self) - object.__setattr__(path_head, "prv", None) + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Choice, Component, Sequential + from sklearn.ensemble import RandomForestClassifier + from sklearn.neural_network import MLPClassifier + from sklearn.impute import SimpleImputer - @override - def path_to( - self, - key: str | Step | Callable[[Step], bool], - *, - direction: Literal["forward", "backward"] | None = None, - ) -> list[Step] | None: - """See [`Step.path_to`][amltk.pipeline.step.Step.path_to].""" - if callable(key): - pred = key - elif isinstance(key, Step): - pred = lambda step: step == key - else: - pred = lambda step: step.name == key + rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + mlp = Component(MLPClassifier, space={"activation": ["logistic", "relu", "tanh"]}) - # We found our target, just return now - if pred(self): - return [self] + pipeline = Sequential( + SimpleImputer(fill_value=0), + {mlp, rf}, + name="my_pipeline", + ) + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` - if direction in (None, "forward"): - for member in self.paths: - if path := member.path_to(pred, direction="forward"): - return [self, *path] + Like all [`Node`][amltk.pipeline.node.Node]s, a `Choice` accepts an explicit + [`name=`][amltk.pipeline.node.Node.name], + [`item=`][amltk.pipeline.node.Node.item], + [`config=`][amltk.pipeline.node.Node.config], + [`space=`][amltk.pipeline.node.Node.space], + [`fidelities=`][amltk.pipeline.node.Node.fidelities], + [`config_transform=`][amltk.pipeline.node.Node.config_transform] and + [`meta=`][amltk.pipeline.node.Node.meta]. - if self.nxt is not None and ( - path := self.nxt.path_to(pred, direction="forward") - ): - return [self, *path] + !!! warning "Order of nodes" - if direction in (None, "backward"): - back = self.prv or self.parent - if back and (path := back.path_to(pred, direction="backward")): - return [self, *path] + The given nodes of a choice are always ordered according + to their name, so indexing `choice.nodes` may not be reliable + if modifying the choice dynamically. - return None + Please use `choice["name"]` to access the nodes instead. - return None + See Also: + * [`Node`][amltk.pipeline.node.Node] + """ # noqa: E501 - @override - def traverse( - self, - *, - include_self: bool = True, - backwards: bool = False, - ) -> Iterator[Step]: - """See `Step.traverse`.""" - if include_self: - yield self - - # Backward mode - if backwards: - if self.prv is not None: - yield from self.prv.traverse(backwards=True) - elif self.parent is not None: - yield from self.parent.traverse(backwards=True) - - if include_self: - yield self - - return - - # Forward mode - yield from chain.from_iterable(path.traverse() for path in self.paths) - if self.nxt is not None: - yield from self.nxt.traverse() + nodes: tuple[Node, ...] 
+ """The nodes that this node leads to.""" - @override - def walk( + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#FF4500") + _NODES_INIT: ClassVar = "args" + + def __init__( self, - groups: Sequence[Group] | None = None, - parents: Sequence[Step] | None = None, - ) -> Iterator[tuple[list[Group], list[Step], Step]]: - """See `Step.walk`.""" - groups = list(groups) if groups is not None else [] - parents = list(parents) if parents is not None else [] - yield groups, parents, self - - for path in self.paths: - yield from path.walk(groups=[*groups, self], parents=[]) - - if self.nxt: - yield from self.nxt.walk( - groups=groups, - parents=[*parents, self], + *nodes: Node | NodeLike, + name: str | None = None, + item: Item | Callable[[Item], Item] | None = None, + config: Config | None = None, + space: Space | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, + meta: Mapping[str, Any] | None = None, + ): + """See [`Node`][amltk.pipeline.node.Node] for details.""" + _nodes: tuple[Node, ...] = tuple( + sorted((as_node(n) for n in nodes), key=lambda n: n.name), + ) + if not all_unique(_nodes, key=lambda node: node.name): + raise ValueError( + f"Can't handle nodes as we can not generate a __choice__ for {nodes=}." + "\nAll nodes must have a unique name. Please provide a `name=` to them", ) - @override - def replace(self, replacements: Mapping[str, Step]) -> Iterator[Step]: - """See `Step.replace`.""" - if self.name in replacements: - yield replacements[self.name] - else: - # Otherwise, we need to call replace over any paths and create a new - # split with those replacements - paths = [ - Step.join(path.replace(replacements=replacements)) - for path in self.paths - ] - yield self.mutate(paths=paths) - - if self.nxt is not None: - yield from self.nxt.replace(replacements=replacements) + if name is None: + name = f"Choice-{randuid(8)}" + + super().__init__( + *_nodes, + name=name, + item=item, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) @override - def remove(self, keys: Sequence[str]) -> Iterator[Step]: - """See `Step.remove`.""" - if self.name not in keys: - # We need to call remove on all the paths. If this removes a - # path that only has one entry, leading to an empty path, then - # we ignore any errors from joining and remove the path - paths = [] - for path in self.paths: - with suppress(ValueError): - new_path = Step.join(path.remove(keys)) - paths.append(new_path) - - yield self.mutate(paths=paths) - - if self.nxt is not None: - yield from self.nxt.remove(keys) + def __or__(self, other: Node | NodeLike) -> Choice: + other_node = as_node(other) + if any(other_node.name == node.name for node in self.nodes): + raise ValueError( + f"Can't handle node with name '{other_node.name} as" + f" there is already a node called '{other_node.name}' in {self.name}", + ) - @override - def apply(self, func: Callable[[Step], Step]) -> Step: - """Apply a function to all the steps in this group. + nodes = tuple( + sorted( + [as_node(n) for n in self.nodes] + [other_node], + key=lambda n: n.name, + ), + ) + return self.mutate(name=self.name, nodes=nodes) - Args: - func: The function to apply + def chosen(self) -> Node: + """The chosen branch. 
Returns: - Step: The new group + The chosen branch """ - new_paths = [path.apply(func) for path in self.paths] - new_nxt = self.nxt.apply(func) if self.nxt else None + match self.config: + case {"__choice__": choice}: + chosen = first_true( + self.nodes, + pred=lambda node: node.name == choice, + default=None, + ) + if chosen is None: + raise NodeNotFoundError(choice, self.name) - # NOTE: We can't be sure that the function will return a new instance of - # `self` so we have to make a copy of `self` and then apply the function - # to that copy. - new_self = func(self.copy()) + return chosen + case _: + raise NoChoiceMadeError(self.name) - if new_nxt is not None: - # HACK: Frozen objects do not allow setting attributes after - # instantiation. Join the two steps together. - object.__setattr__(new_self, "nxt", new_nxt) - object.__setattr__(new_self, "paths", new_paths) - object.__setattr__(new_nxt, "prv", new_self) - return new_self +@dataclass(init=False, frozen=True, eq=True) +class Sequential(Node[Item, Space]): + """A [`Sequential`][amltk.pipeline.Sequential] set of operations in a pipeline. - # OPTIMIZE: Unlikely to be an issue but I figure `.items()` on - # a split of size `n` will cause `n` iterations of `paths` - # Fixable by implementing more of the `Mapping` functions + This indicates the different children in + [`.nodes`][amltk.pipeline.Node.nodes] should act one after + another, feeding the output of one into the next. - @override - def __getitem__(self, key: str) -> Step: - if val := first_true(self.paths, pred=lambda p: p.name == key): - return val - raise KeyError(key) + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Component, Sequential + from sklearn.decomposition import PCA + from sklearn.ensemble import RandomForestClassifier - @override - def __len__(self) -> int: - return len(self.paths) + pipeline = Sequential( + PCA(n_components=3), + Component(RandomForestClassifier, space={"n_estimators": (10, 100)}), + name="my_pipeline" + ) + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide - @override - def __iter__(self) -> Iterator[str]: - return iter(p.name for p in self.paths) + space = pipeline.search_space("configspace") + from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide - @override - def configure( # noqa: PLR0912, C901 - self, - config: Config, - *, - prefixed_name: bool | None = None, - transform_context: Any | None = None, - params: Mapping[str, Any] | None = None, - clear_space: bool | Literal["auto"] = "auto", - ) -> Step: - """Configure this step and anything following it with the given config. + configuration = space.sample_configuration() + from amltk._doc import doc_print; doc_print(print, configuration) # markdown-exec: hide - Args: - config: The configuration to apply - prefixed_name: Whether items in the config are prefixed by the names - of the steps. - * If `None`, the default, then `prefixed_name` will be assumed to - be `True` if this step has a next step or if the config has - keys that begin with this steps name. - * If `True`, then the config will be searched for items prefixed - by the name of the step (and subsequent chained steps). - * If `False`, then the config will be searched for items without - the prefix, i.e. the config keys are exactly those matching - this steps search space. - transform_context: Any context to give to `config_transform=` of individual - steps. - params: The params to match any requests when configuring this step. 
- These will match against any ParamRequests in the config and will - be used to fill in any missing values. - clear_space: Whether to clear the search space after configuring. - If `"auto"` (default), then the search space will be cleared of any - keys that are in the config, if the search space is a `dict`. Otherwise, - `True` indicates that it will be removed in the returned step and - `False` indicates that it will remain as is. + configured_pipeline = pipeline.configure(configuration) + from amltk._doc import doc_print; doc_print(print, configured_pipeline) # markdown-exec: hide - Returns: - Step: The configured step - """ - if prefixed_name is None: - if any(key.startswith(f"{self.name}:") for key in config): - prefixed_name = True - else: - prefixed_name = self.nxt is not None - - nxt = ( - self.nxt.configure( - config, - prefixed_name=prefixed_name, - transform_context=transform_context, - params=params, - clear_space=clear_space, - ) - if self.nxt - else None - ) + sklearn_pipeline = pipeline.build("sklearn") + from amltk._doc import doc_print; doc_print(print, sklearn_pipeline) # markdown-exec: hide + ``` - # Configure all the paths, we assume all of these must - # have the prefixed name and hence use `mapping_select` - subconfig = mapping_select(config, f"{self.name}:") if prefixed_name else config - - paths = [ - path.configure( - subconfig, - prefixed_name=True, - transform_context=transform_context, - params=params, - clear_space=clear_space, - ) - for path in self.paths - ] - - this_config = subconfig if prefixed_name else config - - # The config for this step is anything that doesn't have - # another delimiter in it and is not a part of a subpath - this_config = { - k: v - for k, v in this_config.items() - if ":" not in k and not any(k.startswith(f"{p.name}") for p in self.paths) - } - - if self.config is not None: - this_config = {**self.config, **this_config} - - _params = params or {} - reqs = [(k, v) for k, v in this_config.items() if isinstance(v, ParamRequest)] - for k, request in reqs: - if request.key in _params: - this_config[k] = _params[request.key] - elif request.has_default: - this_config[k] = request.default - elif request.required: - raise ParamRequest.RequestNotMetError( - f"Missing required parameter {request.key} for step {self.name}" - " and no default was provided." - f"\nThe request given was: {request}", - f"Please use the `params=` argument to provide a value for this" - f" request. What was given was `{params=}`", - ) + You may also just chain together nodes using an infix operator `>>` if you prefer: - # If we have a `dict` for a space, then we can remove any configured keys that - # overlap it. 
- _space: Any - if clear_space == "auto": - _space = self.search_space - if isinstance(self.search_space, dict) and any(self.search_space): - _overlap = set(this_config).intersection(self.search_space) - _space = { - k: v for k, v in self.search_space.items() if k not in _overlap - } - if len(_space) == 0: - _space = None - - elif clear_space is True: - _space = None - else: - _space = self.search_space + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Join, Component, Sequential + from sklearn.decomposition import PCA + from sklearn.ensemble import RandomForestClassifier - if self.config_transform is not None: - this_config = self.config_transform(this_config, transform_context) + pipeline = ( + Sequential(name="my_pipeline") + >> PCA(n_components=3) + >> Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + ) + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` + + Whenever some other node sees a list, i.e. `[comp1, comp2, comp3]`, this + will automatically be converted into a `Sequential`. + + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Choice + from sklearn.impute import SimpleImputer + from sklearn.preprocessing import StandardScaler + from sklearn.ensemble import RandomForestClassifier + from sklearn.neural_network import MLPClassifier + + pipeline_choice = Choice( + [SimpleImputer(), RandomForestClassifier()], + [StandardScaler(), MLPClassifier()], + name="pipeline_choice" + ) + from amltk._doc import doc_print; doc_print(print, pipeline_choice) # markdown-exec: hide + ``` + + Like all [`Node`][amltk.pipeline.node.Node]s, a `Sequential` accepts an explicit + [`name=`][amltk.pipeline.node.Node.name], + [`item=`][amltk.pipeline.node.Node.item], + [`config=`][amltk.pipeline.node.Node.config], + [`space=`][amltk.pipeline.node.Node.space], + [`fidelities=`][amltk.pipeline.node.Node.fidelities], + [`config_transform=`][amltk.pipeline.node.Node.config_transform] and + [`meta=`][amltk.pipeline.node.Node.meta]. - new_self = self.mutate( - paths=paths, - config=this_config if this_config else None, - nxt=nxt, - search_space=_space, - ) + See Also: + * [`Node`][amltk.pipeline.node.Node] + """ # noqa: E501 - if nxt is not None: - # HACK: This is a hack to to modify the fact `nxt` is a frozen - # object. Frozen objects do not allow setting attributes after - # instantiation. - object.__setattr__(nxt, "prv", new_self) + nodes: tuple[Node, ...] + """The nodes in series.""" - return new_self + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions( + panel_color="#7E6B8F", + node_orientation="vertical", + ) + _NODES_INIT: ClassVar = "args" - def first(self) -> Step: - """Get the first step in this group.""" - return self.paths[0] + def __init__( + self, + *nodes: Node | NodeLike, + name: str | None = None, + item: Item | Callable[[Item], Item] | None = None, + config: Config | None = None, + space: Space | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, + meta: Mapping[str, Any] | None = None, + ): + """See [`Node`][amltk.pipeline.node.Node] for details.""" + _nodes = tuple(as_node(n) for n in nodes) + + # Perhaps we need to do a deeper check on this... 
+ if not all_unique(_nodes, key=lambda node: node.name): + raise DuplicateNamesError(self) + + if name is None: + name = f"Seq-{randuid(8)}" + + super().__init__( + *_nodes, + name=name, + item=item, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) - @override - def select(self, choices: Mapping[str, str]) -> Iterator[Step]: - """See `Step.select`.""" - if self.name in choices: - choice = choices[self.name] - chosen = first_true(self.paths, pred=lambda path: path.name == choice) - if chosen is None: - raise ValueError( - f"{self.__class__.__qualname__} {self.name} has no path '{choice}'" - f"\n{self}", - ) - yield chosen - else: - # Otherwise, we need to call select over the paths - paths = [Step.join(path.select(choices)) for path in self.paths] - yield self.mutate(paths=paths) + @property + def tail(self) -> Node: + """The last step in the pipeline.""" + return self.nodes[-1] - if self.nxt is not None: - yield from self.nxt.select(choices) + def __len__(self) -> int: + """Get the number of nodes in the pipeline.""" + return len(self.nodes) @override - def fidelities(self) -> dict[str, FidT]: - """See `Step.fidelities`.""" - fids = {} - for path in self.paths: - fids.update(prefix_keys(path.fidelities(), f"{self.name}:")) - - if self.nxt is not None: - fids.update(self.nxt.fidelities()) + def __rshift__(self, other: Node | NodeLike) -> Sequential: + other_node = as_node(other) + if any(other_node.name == node.name for node in self.nodes): + raise ValueError( + f"Can't handle node with name '{other_node.name} as" + f" there is already a node called '{other_node.name}' in {self.name}", + ) - return fids + nodes = (*tuple(as_node(n) for n in self.nodes), other_node) + return self.mutate(name=self.name, nodes=nodes) @override - def linearized_fidelity(self, value: float) -> dict[str, int | float | Any]: - """Get the linearized fidelity for this step. + def walk( + self, + path: Sequence[Node] | None = None, + ) -> Iterator[tuple[list[Node], Node]]: + """Walk the nodes in this chain. Args: - value: The value to linearize. Must be between [0, 1] + path: The current path to this node - Return: - dictionary from key to it's linearized fidelity. + Yields: + The parents of the node and the node itself """ - assert 1.0 <= value <= 100.0, f"{value=} not in [1.0, 100.0]" # noqa: PLR2004 - d = {} - if self.fidelity_space is not None: - for f_name, f_range in self.fidelity_space.items(): - low, high = f_range - fid = low + (high - low) * value - fid = low + (high - low) * (value - 1) / 100 - fid = fid if isinstance(low, float) else round(fid) - d[f_name] = fid + path = list(path) if path is not None else [] + yield path, self + + path = [*path, self] + for node in self.nodes: + yield from node.walk(path=path) + + # Append the previous node so that the next node in the sequence is + # lead to from the previous node + path = [*path, node] + + +@dataclass(init=False, frozen=True, eq=True) +class Split(Node[Item, Space]): + """A [`Split`][amltk.pipeline.Split] of data in a pipeline. + + This indicates the different children in + [`.nodes`][amltk.pipeline.Node.nodes] should + act in parallel but on different subsets of data. 
+ + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Component, Split + from sklearn.impute import SimpleImputer + from sklearn.preprocessing import OneHotEncoder + from sklearn.compose import make_column_selector + + categorical_pipeline = [ + SimpleImputer(strategy="constant", fill_value="missing"), + OneHotEncoder(drop="first"), + ] + numerical_pipeline = Component(SimpleImputer, space={"strategy": ["mean", "median"]}) + + preprocessor = Split( + { + "categories": categorical_pipeline, + "numerical": numerical_pipeline, + }, + config={ + # This is how you would configure the split for the sklearn builder in particular + "categories": make_column_selector(dtype_include="category"), + "numerical": make_column_selector(dtype_exclude="category"), + }, + name="my_split" + ) + from amltk._doc import doc_print; doc_print(print, preprocessor) # markdown-exec: hide - d = prefix_keys(d, f"{self.name}:") + space = preprocessor.search_space("configspace") + from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide - for path in self.paths: - path_fids = prefix_keys(path.linearized_fidelity(value), f"{self.name}:") - d.update(path_fids) + configuration = space.sample_configuration() + from amltk._doc import doc_print; doc_print(print, configuration) # markdown-exec: hide - if self.nxt is None: - return d + configured_preprocessor = preprocessor.configure(configuration) + from amltk._doc import doc_print; doc_print(print, configured_preprocessor) # markdown-exec: hide - nxt_fids = self.nxt.linearized_fidelity(value) - return {**d, **nxt_fids} + built_preprocessor = configured_preprocessor.build("sklearn") + from amltk._doc import doc_print; doc_print(print, built_preprocessor) # markdown-exec: hide + ``` - @override - def _rich_panel_contents(self) -> Iterator[RenderableType]: - from rich.console import Group as RichGroup - from rich.table import Table - from rich.text import Text + The split is a slight oddity when compared to the other kinds of components in that + it allows a `dict` as it's first argument, where the keys are the names of the + different paths through which data will go and the values are the actual nodes that + will receive the data. - if panel_contents := list(self._rich_table_items()): - table = Table.grid(padding=(0, 1), expand=False) - for tup in panel_contents: - table.add_row(*tup, style="default") - table.add_section() + If nodes are passed in as they are for all other components, usually the name of the + first node will be important for any [builder](site:reference/pipelines/builders.md), + trying to make sense of how to use the `Split` - yield table - if any(self.paths): - # HACK : Unless we exposed this through another function, we - # just assume this is desired behaviour. - connecter = Text("↓", style="bold", justify="center") + Like all [`Node`][amltk.pipeline.node.Node]s, a `Split` accepts an explicit + [`name=`][amltk.pipeline.node.Node.name], + [`item=`][amltk.pipeline.node.Node.item], + [`config=`][amltk.pipeline.node.Node.config], + [`space=`][amltk.pipeline.node.Node.space], + [`fidelities=`][amltk.pipeline.node.Node.fidelities], + [`config_transform=`][amltk.pipeline.node.Node.config_transform] and + [`meta=`][amltk.pipeline.node.Node.meta]. 
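As a complement to the `dict` form above, here is a minimal sketch of the same idea written with explicitly named child nodes; it assumes the `"sklearn"` builder, which matches the keys of `config=` against the names of the direct children to decide which columns each path receives:

```python
from amltk.pipeline import Component, Sequential, Split
from sklearn.compose import make_column_selector
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

preprocessor = Split(
    # The child names ("categories", "numerics") are what the builder keys on
    Sequential(
        SimpleImputer(strategy="constant", fill_value="missing"),
        OneHotEncoder(drop="first"),
        name="categories",
    ),
    Component(SimpleImputer, space={"strategy": ["mean", "median"]}, name="numerics"),
    config={
        "categories": make_column_selector(dtype_include="category"),
        "numerics": make_column_selector(dtype_exclude="category"),
    },
    name="my_split",
)
```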
- pipeline_table = Table.grid(padding=(0, 1), expand=False) - pipelines = [RichGroup(*p._rich_iter(connecter)) for p in self.paths] - pipeline_table.add_row(*pipelines) + See Also: + * [`Node`][amltk.pipeline.node.Node] + """ # noqa: E501 - yield pipeline_table + nodes: tuple[Node, ...] + """The nodes that this node leads to.""" + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions( + panel_color="#777DA7", + node_orientation="horizontal", + ) -@frozen(kw_only=True) -class Split(Group[Space], Generic[Item, Space]): - """A split in the pipeline. + _NODES_INIT: ClassVar = "args" + + def __init__( + self, + *nodes: Node | NodeLike | dict[str, Node | NodeLike], + name: str | None = None, + item: Item | Callable[[Item], Item] | None = None, + config: Config | None = None, + space: Space | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, + meta: Mapping[str, Any] | None = None, + ): + """See [`Node`][amltk.pipeline.node.Node] for details.""" + if any(isinstance(n, dict) for n in nodes): + if len(nodes) > 1: + raise ValueError( + "Can't handle multiple nodes with a dictionary as a node.\n" + f"{nodes=}", + ) + _node = nodes[0] + assert isinstance(_node, dict) + + def _construct(key: str, value: Node | NodeLike) -> Node: + match value: + case list(): + return Sequential(*value, name=key) + case set() | tuple(): + return as_node(value, name=key) + case _: + return Sequential(value, name=key) + + _nodes = tuple(_construct(key, value) for key, value in _node.items()) + else: + _nodes = tuple(as_node(n) for n in nodes) + + if not all_unique(_nodes, key=lambda node: node.name): + raise ValueError( + f"Can't handle nodes they do not all contain unique names, {nodes=}." + "\nAll nodes must have a unique name. Please provide a `name=` to them", + ) + + if name is None: + name = f"Split-{randuid(8)}" + + super().__init__( + *_nodes, + name=name, + item=item, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) - See Also: - * [`Step`][amltk.pipeline.step.Step] - * [`Group`][amltk.pipeline.components.Group] - """ - item: Callable[..., Item] | Any | None = field(default=None, hash=False) - """The item attached to this step""" +@dataclass(init=False, frozen=True, eq=True) +class Component(Node[Item, Space]): + """A [`Component`][amltk.pipeline.Component] of the pipeline with + a possible item and **no children**. - paths: Sequence[Step] - """The paths that can be taken from this split""" + This is the basic building block of most pipelines, it accepts + as it's [`item=`][amltk.pipeline.node.Node.item] some function that will be + called with [`build_item()`][amltk.pipeline.components.Component.build_item] to + build that one part of the pipeline. - name: str - """Name of the step""" + When [`build_item()`][amltk.pipeline.Component.build_item] is called, + The [`.config`][amltk.pipeline.node.Node.config] on this node will be passed + to the function to build the item. - config: Mapping[str, Any] | None = field(default=None, hash=False) - """The configuration for this step""" + A common pattern is to use a [`Component`][amltk.pipeline.Component] to + wrap a constructor, specifying the [`space=`][amltk.pipeline.node.Node.space] + and [`config=`][amltk.pipeline.node.Node.config] to be used when building the + item. 
- search_space: Space | None = field(default=None, hash=False, repr=False) - """The search space for this step""" + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Component + from sklearn.ensemble import RandomForestClassifier - fidelity_space: Mapping[str, FidT] | None = field( - default=None, - hash=False, - repr=False, + rf = Component( + RandomForestClassifier, + config={"max_depth": 3}, + space={"n_estimators": (10, 100)} ) - """The fidelities for this step""" + from amltk._doc import doc_print; doc_print(print, rf) # markdown-exec: hide - config_transform: ( - Callable[ - [Mapping[str, Any], Any], - Mapping[str, Any], - ] - | None - ) = field(default=None, hash=False, repr=False) - """A function that transforms the configuration of this step""" + config = {"n_estimators": 50} # Sample from some space or something + configured_rf = rf.configure(config) - meta: Mapping[str, Any] | None = None - """Any meta information about this step""" + estimator = configured_rf.build_item() + from amltk._doc import doc_print; doc_print(print, estimator) # markdown-exec: hide + ``` - RICH_PANEL_BORDER_COLOR: ClassVar[str] = "chartreuse4" + Whenever some other node sees a function/constructor, i.e. `RandomForestClassifier`, + this will automatically be converted into a `Component`. - @override - def build(self, **kwargs: Any) -> Item: + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Sequential + from sklearn.ensemble import RandomForestClassifier + + pipeline = Sequential(RandomForestClassifier, name="my_pipeline") + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` + + The default `.name` of a component is the name of the class/function that it will + use. You can explicitly set the `name=` if you want to when constructing the + component. + + Like all [`Node`][amltk.pipeline.node.Node]s, a `Component` accepts an explicit + [`name=`][amltk.pipeline.node.Node.name], + [`item=`][amltk.pipeline.node.Node.item], + [`config=`][amltk.pipeline.node.Node.config], + [`space=`][amltk.pipeline.node.Node.space], + [`fidelities=`][amltk.pipeline.node.Node.fidelities], + [`config_transform=`][amltk.pipeline.node.Node.config_transform] and + [`meta=`][amltk.pipeline.node.Node.meta]. + + See Also: + * [`Node`][amltk.pipeline.node.Node] + """ + + item: Callable[..., Item] + """A node which constructs an item in the pipeline.""" + + nodes: tuple[()] + """A component has no children.""" + + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#E6AF2E") + + _NODES_INIT: ClassVar = None + + def __init__( + self, + item: Callable[..., Item], + *, + name: str | None = None, + config: Config | None = None, + space: Space | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, + meta: Mapping[str, Any] | None = None, + ): + """See [`Node`][amltk.pipeline.node.Node] for details.""" + super().__init__( + name=name if name is not None else entity_name(item), + item=item, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) + + def build_item(self, **kwargs: Any) -> Item: """Build the item attached to this component. 
        Args:
@@ -647,252 +875,168 @@ def build(self, **kwargs: Any) -> Item:
            Item
                The built item
        """
-        if self.item is None:
-            raise ValueError(f"Can't build a split without an item in step {self}")
+        config = self.config or {}
+        return self.item(**{**config, **kwargs})
-        if callable(self.item):
-            config = self.config or {}
-            return self.item(**{**config, **kwargs})
-        if self.config is not None:
-            raise ValueError(f"Can't pass config to a non-callable item in step {self}")
+@dataclass(init=False, frozen=True, eq=True)
+class Searchable(Node[None, Space]):  # type: ignore
+    """A [`Searchable`][amltk.pipeline.Searchable]
+    node of the pipeline which just represents a search space, with no item attached.
-        return self.item
+    While not usually applicable to pipelines you want to build, this component
+    is useful for creating a search space, especially if the real pipeline you
+    want to optimize can not be built directly. For example, if you are optimizing
+    a script, you may wish to use a `Searchable` to represent the search space
+    of that script.
-    @override
-    def _rich_table_items(self) -> Iterator[tuple[RenderableType, ...]]:
-        from rich.pretty import Pretty
+    ```python exec="true" source="material-block" html="true"
+    from amltk.pipeline import Searchable
-        from amltk.richutil import Function
+    script_space = Searchable({"mode": ["orange", "blue", "red"], "n": (10, 100)})
+    from amltk._doc import doc_print; doc_print(print, script_space)  # markdown-exec: hide
+    ```
-        if self.item is not None:
-            if callable(self.item):
-                yield "item", Function(self.item)
-            else:
-                yield "item", Pretty(self.item)
+    A `Searchable` explicitly does not allow for `item=` to be set, nor can it have
+    any children. A `Searchable` accepts an explicit
+    [`name=`][amltk.pipeline.node.Node.name],
+    [`config=`][amltk.pipeline.node.Node.config],
+    [`space=`][amltk.pipeline.node.Node.space],
+    [`fidelities=`][amltk.pipeline.node.Node.fidelities],
+    [`config_transform=`][amltk.pipeline.node.Node.config_transform] and
+    [`meta=`][amltk.pipeline.node.Node.meta].
-        yield from super()._rich_table_items()
+    See Also:
+        * [`Node`][amltk.pipeline.node.Node]
+    """  # noqa: E501
+    item: None = None
+    """A searchable has no item."""
-@frozen(kw_only=True)
-class Choice(Group[Space]):
-    """A Choice between different subcomponents.
+ nodes: tuple[()] = () + """A component has no children.""" - See Also: - * [`Step`][amltk.pipeline.step.Step] - * [`Group`][amltk.pipeline.components.Group] - """ + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="light_steel_blue") - paths: Sequence[Step] - """The paths that can be taken from this choice""" + _NODES_INIT: ClassVar = None - weights: Sequence[float] | None = field(hash=False) - """The weights to assign to each path""" + def __init__( + self, + space: Space | None = None, + *, + name: str | None = None, + config: Config | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, + meta: Mapping[str, Any] | None = None, + ): + """See [`Node`][amltk.pipeline.node.Node] for details.""" + if name is None: + name = f"Searchable-{randuid(8)}" + + super().__init__( + name=name, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) - name: str - """Name of the step""" - config: Mapping[str, Any] | None = field(default=None, hash=False) - """The configuration for this step""" +@dataclass(init=False, frozen=True, eq=True) +class Fixed(Node[Item, None]): # type: ignore + """A [`Fixed`][amltk.pipeline.Fixed] part of the pipeline that + represents something that can not be configured and used directly as is. - search_space: Space | None = field(default=None, hash=False, repr=False) - """The search space for this step""" + It consists of an [`.item`][amltk.pipeline.node.Node.item] that is fixed, + non-configurable and non-searchable. It also has no children. - fidelity_space: Mapping[str, FidT] | None = field( - default=None, - hash=False, - repr=False, - ) - """The fidelities for this step""" + This is useful for representing parts of the pipeline that are fixed, for example + if you have a pipeline that is a `Sequential` of nodes, but you want to + fix the first component to be a `PCA` with `n_components=3`, you can use a `Fixed` + to represent that. - config_transform: ( - Callable[ - [Mapping[str, Any], Any], - Mapping[str, Any], - ] - | None - ) = field(default=None, hash=False, repr=False) - """A function that transforms the configuration of this step""" + ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Component, Fixed, Sequential + from sklearn.ensemble import RandomForestClassifier + from sklearn.decomposition import PCA - meta: Mapping[str, Any] | None = None - """Any meta information about this step""" + rf = Component(RandomForestClassifier, space={"n_estimators": (10, 100)}) + pca = Fixed(PCA(n_components=3)) - RICH_PANEL_BORDER_COLOR: ClassVar[str] = "orange4" + pipeline = Sequential(pca, rf, name="my_pipeline") + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` - def iter_weights(self) -> Iterator[tuple[Step, float]]: - """Iter over the paths with their weights.""" - return zip(self.paths, (repeat(1) if self.weights is None else self.weights)) + Whenever some other node sees an instance of something, i.e. something that can't be + called, this will automatically be converted into a `Fixed`. - @override - def configure( # noqa: PLR0912, C901 - self, - config: Config, - *, - prefixed_name: bool | None = None, - transform_context: Any | None = None, - params: Mapping[str, Any] | None = None, - clear_space: bool | Literal["auto"] = "auto", - ) -> Step: - """Configure this step and anything following it with the given config. 
+ ```python exec="true" source="material-block" html="true" + from amltk.pipeline import Sequential + from sklearn.ensemble import RandomForestClassifier + from sklearn.decomposition import PCA - Args: - config: The configuration to apply - prefixed_name: Whether items in the config are prefixed by the names - of the steps. - * If `None`, the default, then `prefixed_name` will be assumed to - be `True` if this step has a next step or if the config has - keys that begin with this steps name. - * If `True`, then the config will be searched for items prefixed - by the name of the step (and subsequent chained steps). - * If `False`, then the config will be searched for items without - the prefix, i.e. the config keys are exactly those matching - this steps search space. - transform_context: The context to pass to the config transform function. - params: The params to match any requests when configuring this step. - These will match against any ParamRequests in the config and will - be used to fill in any missing values. - clear_space: Whether to clear the search space after configuring. - If `"auto"` (default), then the search space will be cleared of any - keys that are in the config, if the search space is a `dict`. Otherwise, - `True` indicates that it will be removed in the returned step and - `False` indicates that it will remain as is. + pipeline = Sequential( + PCA(n_components=3), + RandomForestClassifier(n_estimators=50), + name="my_pipeline", + ) + from amltk._doc import doc_print; doc_print(print, pipeline) # markdown-exec: hide + ``` - Returns: - Step: The configured step - """ - if prefixed_name is None: - if any(key.startswith(self.name) for key in config): - prefixed_name = True - else: - prefixed_name = self.nxt is not None - - nxt = ( - self.nxt.configure( - config, - prefixed_name=prefixed_name, - transform_context=transform_context, - params=params, - clear_space=clear_space, - ) - if self.nxt - else None - ) + The default `.name` of a component is the class name of the item that it will + use. You can explicitly set the `name=` if you want to when constructing the + component. - # For a choice to be made, the config must have the a key - # for the name of this choice and the choice made. - chosen_path_name = config.get(self.name) + A `Fixed` accepts only an explicit [`name=`][amltk.pipeline.node.Node.name], + [`item=`][amltk.pipeline.node.Node.item], + [`meta=`][amltk.pipeline.node.Node.meta]. - if chosen_path_name is not None: - chosen_path = first_true( - self.paths, - pred=lambda path: path.name == chosen_path_name, - ) - if chosen_path is None: - raise Step.ConfigurationError( - step=self, - config=config, - reason=f"Choice {self.name} has no path '{chosen_path_name}'", - ) + See Also: + * [`Node`][amltk.pipeline.node.Node] + """ - # Configure the chosen path - subconfig = mapping_select(config, f"{self.name}:") - chosen_path = chosen_path.configure( - subconfig, - prefixed_name=prefixed_name, - transform_context=transform_context, - params=params, - clear_space=clear_space, - ) + item: Item = field() + """The fixed item that this node represents.""" - object.__setattr__(chosen_path, "old_parent", self.name) - - if nxt is not None: - # HACK: This is a hack to to modify the fact `nxt` is a frozen - # object. Frozen objects do not allow setting attributes after - # instantiation. 
- object.__setattr__(nxt, "prv", chosen_path) - object.__setattr__(chosen_path, "nxt", nxt) - - return chosen_path - - # Otherwise there is no chosen path and we simply configure what we can - # of the choices and return that - subconfig = mapping_select(config, f"{self.name}:") - paths = [ - path.configure( - subconfig, - prefixed_name=True, - transform_context=transform_context, - params=params, - clear_space=clear_space, - ) - for path in self.paths - ] - - # The config for this step is anything that doesn't have - # another delimiter in it - config_for_this_choice = {k: v for k, v in subconfig.items() if ":" not in k} - - if self.config is not None: - config_for_this_choice = {**self.config, **config_for_this_choice} - - _params = params or {} - reqs = [ - (k, v) - for k, v in config_for_this_choice.items() - if isinstance(v, ParamRequest) - ] - for k, request in reqs: - if request.key in _params: - config_for_this_choice[k] = _params[request.key] - elif request.has_default: - config_for_this_choice[k] = request.default - elif request.required: - raise ParamRequest.RequestNotMetError( - f"Missing required parameter {request.key} for step {self.name}" - " and no default was provided." - f"\nThe request given was: {request}", - f"Please use the `params=` argument to provide a value for this" - f" request. What was given was `{params=}`", - ) + space: None = None + """A frozen node has no search space.""" - # If we have a `dict` for a space, then we can remove any configured keys that - # overlap it. - _space: Any - if clear_space == "auto": - _space = self.search_space - if isinstance(self.search_space, dict) and any(self.search_space): - _overlap = set(config_for_this_choice).intersection(self.search_space) - _space = { - k: v for k, v in self.search_space.items() if k not in _overlap - } - if len(_space) == 0: - _space = None - - elif clear_space is True: - _space = None - else: - _space = self.search_space + fidelities: None = None + """A frozen node has no search space.""" - if self.config_transform is not None: - _config_for_this_choice = self.config_transform( - config_for_this_choice, - transform_context, - ) - else: - _config_for_this_choice = config_for_this_choice + config: None = None + """A frozen node has no config.""" - new_self = self.mutate( - paths=paths, - config=_config_for_this_choice if _config_for_this_choice else None, - nxt=nxt, - search_space=_space, - ) + config_transform: None = None + """A frozen node has no config so no transform.""" - if nxt is not None: - # HACK: This is a hack to to modify the fact `nxt` is a frozen - # object. Frozen objects do not allow setting attributes after - # instantiation. 
- object.__setattr__(nxt, "prv", new_self) + nodes: tuple[()] = () + """A fixed node has no children.""" - return new_self + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions(panel_color="#56351E") + + _NODES_INIT: ClassVar = None + + def __init__( + self, + item: Item, + *, + name: str | None = None, + config: None = None, + space: None = None, + fidelities: None = None, + config_transform: None = None, + meta: Mapping[str, Any] | None = None, + ): + """See [`Node`][amltk.pipeline.node.Node] for details.""" + super().__init__( + name=name if name is not None else entity_name(item), + item=item, + config=config, + space=space, + fidelities=fidelities, + config_transform=config_transform, + meta=meta, + ) diff --git a/src/amltk/pipeline/node.py b/src/amltk/pipeline/node.py new file mode 100644 index 00000000..4b8be5a5 --- /dev/null +++ b/src/amltk/pipeline/node.py @@ -0,0 +1,732 @@ +"""A pipeline consists of [`Node`][amltk.pipeline.node.Node]s, which hold +the various attributes required to build a pipeline, such as the +[`.item`][amltk.pipeline.node.Node.item], its [`.space`][amltk.pipeline.node.Node.space], +its [`.config`][amltk.pipeline.node.Node.config] and so on. + +The [`Node`][amltk.pipeline.node.Node]s are connected to each other in a parent-child +relationship, where the children are simply the [`.nodes`][amltk.pipeline.node.Node.nodes] +that the parent leads to. + +To give these attributes and relations meaning, there are various subclasses +of [`Node`][amltk.pipeline.node.Node] which give different syntactic meanings +when you want to construct something like a +[`search_space()`][amltk.pipeline.node.Node.search_space] or +[`build()`][amltk.pipeline.node.Node.build] some concrete object out of the +pipeline. + +For example, a [`Sequential`][amltk.pipeline.Sequential] node +gives the meaning that each of its children in +[`.nodes`][amltk.pipeline.node.Node.nodes] should follow one another, while +something like a [`Choice`][amltk.pipeline.Choice] +gives the meaning that only one of its children should be chosen. + +You will likely never have to create a [`Node`][amltk.pipeline.node.Node] +directly, but instead use the various components to create the pipeline. + +??? note "Hashing" + + When hashing a node, i.e. to put it in a `set` or as a key in a `dict`, + only the name of the node and the hash of its children is used. + This means that two nodes with the same connectivity will hash equally. + +??? note "Equality" + + When considering equality, this will be done by comparing all the fields + of the node. This includes even the `parent` and `branches` fields. This + means two nodes are considered equal if they look the same and they are + connected to nodes that also look the same. + + +You can find the [reference to `Component`s here](site:reference/pipeline/components.md), +while the [full guide to pipelines can be found here](site:guides/pipelines.md).
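+
+As a rough sketch of how these pieces fit together (the estimators and
+hyperparameter ranges below are purely illustrative), you might compose a few
+nodes and then work with the resulting pipeline as a whole:
+
+```python
+from amltk.pipeline import Choice, Component, Sequential
+from sklearn.decomposition import PCA
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.svm import SVC
+
+pipeline = (
+    Sequential(name="my_pipeline")
+    >> Component(PCA, space={"n_components": (1, 3)})
+    >> Choice(
+        Component(SVC, space={"C": (0.1, 10.0)}),
+        Component(RandomForestClassifier, space={"n_estimators": (10, 100)}),
+        name="estimator",
+    )
+)
+
+# Extract a search space (requires ConfigSpace to be installed), sample a
+# configuration from it and configure the pipeline with that sample.
+space = pipeline.search_space("configspace")
+config = space.sample_configuration()
+configured_pipeline = pipeline.configure(dict(config))
+```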
+""" # noqa: E501 +from __future__ import annotations + +import inspect +from collections.abc import Callable, Iterator, Mapping, Sequence +from dataclasses import dataclass, field +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Concatenate, + Generic, + Literal, + NamedTuple, + ParamSpec, + TypeAlias, + TypeVar, + overload, +) +from typing_extensions import override + +from more_itertools import first_true +from rich.text import Text +from sklearn.pipeline import Pipeline as SklearnPipeline + +from amltk._functional import classname, mapping_select, prefix_keys +from amltk._richutil import RichRenderable +from amltk.exceptions import RequestNotMetError +from amltk.types import Config, Item, Space + +if TYPE_CHECKING: + from typing_extensions import Self + + from ConfigSpace import ConfigurationSpace + from rich.console import RenderableType + from rich.panel import Panel + + from amltk.pipeline.components import Choice, Join, Sequential + from amltk.pipeline.parsers.optuna import OptunaSearchSpace + + NodeLike: TypeAlias = ( + set["Node" | "NodeLike"] + | tuple["Node" | "NodeLike", ...] + | list["Node" | "NodeLike"] + | Callable[..., Item] + | Item + ) + + SklearnPipelineT = TypeVar("SklearnPipelineT", bound=SklearnPipeline) + +T = TypeVar("T") +ParserOutput = TypeVar("ParserOutput") +BuilderOutput = TypeVar("BuilderOutput") +P = ParamSpec("P") + + +_NotSet = object() + + +class RichOptions(NamedTuple): + """Options for rich printing.""" + + panel_color: str = "default" + node_orientation: Literal["horizontal", "vertical"] = "horizontal" + + +@dataclass(frozen=True) +class ParamRequest(Generic[T]): + """A parameter request for a node. This is most useful for things like seeds.""" + + key: str + """The key to request under.""" + + default: T | object = _NotSet + """The default value to use if the key is not found. + + If left as `_NotSet` (default) then an error will be raised if the + parameter is not found during configuration with + [`configure()`][amltk.pipeline.node.Node.configure]. + """ + + @property + def has_default(self) -> bool: + """Whether this request has a default value.""" + return self.default is not _NotSet + + +def request(key: str, default: T | object = _NotSet) -> ParamRequest[T]: + """Create a new parameter request. + + Args: + key: The key to request under. + default: The default value to use if the key is not found. + If left as `_NotSet` (default) then the key will be removed from the + config once [`configure`][amltk.pipeline.Node.configure] is called and + nothing has been provided. + """ + return ParamRequest(key=key, default=default) + + +@dataclass(init=False, frozen=True, eq=True) +class Node(RichRenderable, Generic[Item, Space]): + """The core node class for the pipeline. + + These are simple objects that are named and linked together to form + a chain. They are then wrapped in a `Pipeline` object to provide + a convenient interface for interacting with the chain. + + See Also: + For creating the concrete implementations of this class, you should use + the [components available](site:reference/pipeline/components.md). + """ + + name: str = field(hash=True) + """Name of the node""" + + item: Callable[..., Item] | Item | None = field(hash=False) + """The item attached to this node""" + + nodes: tuple[Node, ...] 
= field(hash=True) + """The nodes that this node leads to.""" + + config: Config | None = field(hash=False) + """The configuration for this node""" + + space: Space | None = field(hash=False) + """The search space for this node""" + + fidelities: Mapping[str, Any] | None = field(hash=False) + """The fidelities for this node""" + + config_transform: Callable[[Config, Any], Config] | None = field(hash=False) + """A function that transforms the configuration of this node""" + + meta: Mapping[str, Any] | None = field(hash=False) + """Any meta information about this node""" + + _NODES_INIT: ClassVar[Literal["args", "kwargs"] | None] = "args" + """Whether __init__ takes nodes as positional args, kwargs or does not accept it.""" + + RICH_OPTIONS: ClassVar[RichOptions] = RichOptions( + panel_color="default", + node_orientation="horizontal", + ) + """Options for rich printing""" + + def __init__( + self, + *nodes: Node, + name: str, + item: Item | Callable[[Item], Item] | None = None, + config: Config | None = None, + space: Space | None = None, + fidelities: Mapping[str, Any] | None = None, + config_transform: Callable[[Config, Any], Config] | None = None, + meta: Mapping[str, Any] | None = None, + ): + """Initialize a choice.""" + super().__init__() + object.__setattr__(self, "name", name) + object.__setattr__(self, "item", item) + object.__setattr__(self, "config", config) + object.__setattr__(self, "space", space) + object.__setattr__(self, "fidelities", fidelities) + object.__setattr__(self, "config_transform", config_transform) + object.__setattr__(self, "meta", meta) + object.__setattr__(self, "nodes", nodes) + + def __getitem__(self, key: str) -> Node: + """Get the node with the given name.""" + found = first_true( + self.nodes, + None, + lambda node: node.name == key, + ) + if found is None: + raise KeyError( + f"Could not find node with name {key} in '{self.name}'." + f" Available nodes are: {', '.join(node.name for node in self.nodes)}", + ) + + return found + + @override + def __eq__(self, other: Any) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + + return ( + self.name == other.name + and self.item == other.item + and self.config == other.config + and self.space == other.space + and self.fidelities == other.fidelities + and self.config_transform == other.config_transform + and self.meta == other.meta + and self.nodes == other.nodes + ) + + def __or__(self, other: Node | NodeLike) -> Choice: + from amltk.pipeline.components import as_node + + return as_node({self, as_node(other)}) + + def __and__(self, other: Node | NodeLike) -> Join: + from amltk.pipeline.components import as_node + + return as_node((self, other)) + + def __rshift__(self, other: Node | NodeLike) -> Sequential: + from amltk.pipeline.components import as_node + + return as_node([self, other]) + + def configure( + self, + config: Config, + *, + prefixed_name: bool | None = None, + transform_context: Any | None = None, + params: Mapping[str, Any] | None = None, + ) -> Self: + """Configure this node and anything following it with the given config. + + Args: + config: The configuration to apply + prefixed_name: Whether items in the config are prefixed by the names + of the nodes. + * If `None`, the default, then `prefixed_name` will be assumed to + be `True` if this node has a next node or if the config has + keys that begin with this nodes name. + * If `True`, then the config will be searched for items prefixed + by the name of the node (and subsequent chained nodes). 
+ * If `False`, then the config will be searched for items without + the prefix, i.e. the config keys are exactly those matching + this node's search space. + transform_context: Any context to give to `config_transform=` of individual + nodes. + params: The params to match any requests when configuring this node. + These will match against any ParamRequests in the config and will + be used to fill in any missing values. + + Returns: + The configured node + """ + # Get the config for this node + match prefixed_name: + case True: + config = mapping_select(config, f"{self.name}:") + case False: + pass + case None if any(k.startswith(f"{self.name}:") for k in config): + config = mapping_select(config, f"{self.name}:") + case None: + pass + + _kwargs: dict[str, Any] = {} + + # Configure all the branches if they exist + if len(self.nodes) > 0: + nodes = tuple( + node.configure( + config, + prefixed_name=True, + transform_context=transform_context, + params=params, + ) + for node in self.nodes + ) + _kwargs["nodes"] = nodes + + this_config = { + hp: v + for hp, v in config.items() + if ( + ":" not in hp + and not any(hp.startswith(f"{node.name}") for node in self.nodes) + ) + } + if self.config is not None: + this_config = {**self.config, **this_config} + + this_config = dict(self._fufill_param_requests(this_config, params=params)) + + if self.config_transform is not None: + this_config = dict(self.config_transform(this_config, transform_context)) + + if len(this_config) > 0: + _kwargs["config"] = dict(this_config) + + return self.mutate(**_kwargs) + + def fidelity_space(self) -> dict[str, Any]: + """Get the fidelities for this node and any connected nodes.""" + fids = {} + for node in self.nodes: + fids.update(prefix_keys(node.fidelity_space(), f"{self.name}:")) + + return fids + + def linearized_fidelity(self, value: float) -> dict[str, int | float | Any]: + """Get the linearized fidelities for this node and any connected nodes. + + Args: + value: The value to linearize. Must be in the range [1.0, 100.0]. + + Returns: + A dictionary from key to its linearized fidelity. + """ + assert 1.0 <= value <= 100.0, f"{value=} not in [1.0, 100.0]" # noqa: PLR2004 + d = {} + for node in self.nodes: + node_fids = prefix_keys( + node.linearized_fidelity(value), + f"{self.name}:", + ) + d.update(node_fids) + + if self.fidelities is None: + return d + + for f_name, f_range in self.fidelities.items(): + match f_range: + case (int() | float(), int() | float()): + low, high = f_range + fid = low + (high - low) * (value - 1) / 100 + fid = fid if isinstance(low, float) else round(fid) + d[f_name] = fid + case _: + raise ValueError( + f"Invalid fidelities to linearize {f_range} for {f_name}" + f" in {self}. Only supports ranges of the form (low, high)", + ) + + return prefix_keys(d, f"{self.name}:") + + def iter(self) -> Iterator[Node]: + """Iterate the nodes, including this node. + + Yields: + The nodes connected to this node + """ + yield self + for node in self.nodes: + yield from node.iter() + + def mutate(self, **kwargs: Any) -> Self: + """Mutate the node with the given keyword arguments. + + Args: + **kwargs: The keyword arguments to mutate + + Returns: + Self + The mutated node + """ + _args = () + _kwargs = {**self.__dict__, **kwargs} + + # If there's nodes in kwargs, we have to check if it's + # a positional or keyword argument and handle accordingly.
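+ # For example, nodes pass their children positionally via `*nodes` by
+ # default (`_NODES_INIT = "args"`), while a leaf node such as `Fixed`
+ # sets `_NODES_INIT = None` and accepts no children at all.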
+ if (nodes := _kwargs.pop("nodes", None)) is not None: + match self._NODES_INIT: + case "args": + _args = nodes + case "kwargs": + _kwargs["nodes"] = nodes + case None if len(nodes) == 0: + pass # Just ignore it, it's popped out + case None: + raise ValueError( + "Cannot mutate nodes when __init__ does not accept nodes", + ) + + # If there's a config in kwargs, we have to check if it's actually got values + config = _kwargs.pop("config", None) + if config is not None and len(config) > 0: + _kwargs["config"] = config + + # Lastly, we remove anything that can't be passed to kwargs of the + # subclasses __init__ + _available_kwargs = inspect.signature(self.__init__).parameters.keys() # type: ignore + for k in list(_kwargs.keys()): + if k not in _available_kwargs: + _kwargs.pop(k) + + return self.__class__(*_args, **_kwargs) + + def copy(self) -> Self: + """Copy this node, removing all links in the process.""" + return self.mutate() + + def path_to(self, key: str | Node | Callable[[Node], bool]) -> list[Node] | None: + """Find a path to the given node. + + Args: + key: The key to search for or a function that returns True if the node + is the desired node + + Returns: + The path to the node if found, else None + """ + # We found our target, just return now + + match key: + case Node(): + pred = lambda node: node == key + case str(): + pred = lambda node: node.name == key + case _: + pred = key + + for path, node in self.walk(): + if pred(node): + return path + + return None + + def walk( + self, + path: Sequence[Node] | None = None, + ) -> Iterator[tuple[list[Node], Node]]: + """Walk the nodes in this chain. + + Args: + path: The current path to this node + + Yields: + The parents of the node and the node itself + """ + path = list(path) if path is not None else [] + yield path, self + + for node in self.nodes: + yield from node.walk(path=[*path, self]) + + @overload + def find(self, key: str | Node | Callable[[Node], bool], default: T) -> Node | T: + ... + + @overload + def find(self, key: str | Node | Callable[[Node], bool]) -> Node | None: + ... + + def find( + self, + key: str | Node | Callable[[Node], bool], + default: T | None = None, + ) -> Node | T | None: + """Find a node in that's nested deeper from this node. + + Args: + key: The key to search for or a function that returns True if the node + is the desired node + default: The value to return if the node is not found. Defaults to None + + Returns: + The node if found, otherwise the default value. Defaults to None + """ + itr = self.iter() + match key: + case Node(): + return first_true(itr, default, lambda node: node == key) + case str(): + return first_true(itr, default, lambda node: node.name == key) + case _: + return first_true(itr, default, key) # type: ignore + + @overload + def search_space( + self, + parser: Literal["configspace"], + *, + flat: bool = False, + conditionals: bool = True, + delim: str = ":", + ) -> ConfigurationSpace: + ... + + @overload + def search_space( + self, + parser: Literal["optuna"], + *, + seed: int | None = None, + flat: bool = False, + conditionals: bool = True, + delim: str = ":", + ) -> OptunaSearchSpace: + ... + + @overload + def search_space( + self, + parser: Callable[Concatenate[Node, P], ParserOutput], + *parser_args: P.args, + **parser_kwargs: P.kwargs, + ) -> ParserOutput: + ... 
+ + def search_space( + self, + parser: ( + Callable[Concatenate[Node, P], ParserOutput] + | Literal["configspace", "optuna"] + ), + *parser_args: P.args, + **parser_kwargs: P.kwargs, + ) -> ParserOutput | ConfigurationSpace | OptunaSearchSpace: + """Get the search space for this node.""" + match parser: + case "configspace": + from amltk.pipeline.parsers.configspace import parser as cs_parser + + return cs_parser(self, *parser_args, **parser_kwargs) # type: ignore + case "optuna": + from amltk.pipeline.parsers.optuna import parser as optuna_parser + + return optuna_parser(self, *parser_args, **parser_kwargs) # type: ignore + case str(): # type: ignore + raise ValueError( + f"Invalid str for parser {parser}. " + "Please use 'configspace' or 'optuna' or pass in your own" + " parser function", + ) + case _: + return parser(self, *parser_args, **parser_kwargs) + + @overload + def build( + self, + builder: Literal["sklearn"], + *builder_args: Any, + pipeline_type: type[SklearnPipelineT] = SklearnPipeline, + **builder_kwargs: Any, + ) -> SklearnPipelineT: + ... + + @overload + def build( + self, + builder: Literal["sklearn"], + *builder_args: Any, + **builder_kwargs: Any, + ) -> SklearnPipeline: + ... + + @overload + def build( + self, + builder: Callable[Concatenate[Node, P], BuilderOutput], + *builder_args: P.args, + **builder_kwargs: P.kwargs, + ) -> BuilderOutput: + ... + + def build( + self, + builder: Callable[Concatenate[Node, P], BuilderOutput] | Literal["sklearn"], + *builder_args: P.args, + **builder_kwargs: P.kwargs, + ) -> BuilderOutput | SklearnPipeline: + """Get the search space for this node.""" + match builder: + case "sklearn": + from amltk.pipeline.builders.sklearn import build as _build + + return _build(self, *builder_args, **builder_kwargs) # type: ignore + case _: + return builder(self, *builder_args, **builder_kwargs) + + def _rich_iter(self) -> Iterator[RenderableType]: + """Iterate the panels for rich printing.""" + yield self.__rich__() + for node in self.nodes: + yield from node._rich_iter() + + def _rich_table_items(self) -> Iterator[tuple[RenderableType, ...]]: + """Get the items to add to the rich table.""" + from rich.pretty import Pretty + + from amltk._richutil import Function + + if self.item is not None: + if isinstance(self.item, type) or callable(self.item): + yield "item", Function(self.item, signature="...") + else: + yield "item", Pretty(self.item) + + if self.config is not None: + yield "config", Pretty(self.config) + + if self.space is not None: + yield "space", Pretty(self.space) + + if self.fidelities is not None: + yield "fidelity", Pretty(self.fidelities) + + if self.config_transform is not None: + yield "transform", Function(self.config_transform, signature="...") + + if self.meta is not None: + yield "meta", Pretty(self.meta) + + def _rich_panel_contents(self) -> Iterator[RenderableType]: + from rich.table import Table + from rich.text import Text + + options = self.RICH_OPTIONS + + if panel_contents := list(self._rich_table_items()): + table = Table.grid(padding=(0, 1), expand=False) + for tup in panel_contents: + table.add_row(*tup, style="default") + table.add_section() + yield table + + if len(self.nodes) > 0: + match options: + case RichOptions(node_orientation="horizontal"): + table = Table.grid(padding=(0, 1), expand=False) + nodes = [node.__rich__() for node in self.nodes] + table.add_row(*nodes) + yield table + case RichOptions(node_orientation="vertical"): + first, *rest = self.nodes + yield first.__rich__() + for node in rest: + yield 
Text("↓", style="bold", justify="center") + yield node.__rich__() + case _: + raise ValueError(f"Invalid orientation {options.node_orientation}") + + def display(self, *, full: bool = False) -> RenderableType: + """Display this node. + + Args: + full: Whether to display the full node or just a summary + """ + if not full: + return self.__rich__() + + from rich.console import Group as RichGroup + + return RichGroup(*self._rich_iter()) + + @override + def __rich__(self) -> Panel: + from rich.console import Group as RichGroup + from rich.panel import Panel + + clr = self.RICH_OPTIONS.panel_color + title = Text.assemble( + (classname(self), f"{clr} bold"), + "(", + (self.name, f"{clr} italic"), + ")", + style="default", + end="", + ) + contents = list(self._rich_panel_contents()) + _content = contents[0] if len(contents) == 1 else RichGroup(*contents) + return Panel( + _content, + title=title, + title_align="left", + border_style=clr, + expand=False, + ) + + def _fufill_param_requests( + self, + config: Config, + params: Mapping[str, Any] | None = None, + ) -> Config: + _params = params or {} + new_config = dict(config) + + for k, request in config.items(): + match request: + case ParamRequest(key=k) if k in _params: + new_config[k] = _params[request.key] + case ParamRequest(default=default) if request.has_default: + new_config[k] = default + case ParamRequest(): + raise RequestNotMetError(f"{params=} missing {request=} for {self}") + case _: + continue + + return new_config diff --git a/src/amltk/pipeline/parser.py b/src/amltk/pipeline/parser.py deleted file mode 100644 index 50030f05..00000000 --- a/src/amltk/pipeline/parser.py +++ /dev/null @@ -1,472 +0,0 @@ -"""Parser.""" -from __future__ import annotations - -import logging -from abc import ABC, abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Generic, - Mapping, - Sequence, - TypeVar, - overload, -) -from typing_extensions import override - -from more_itertools import first_true, seekable - -from amltk.exceptions import safe_map -from amltk.pipeline.components import Choice, Group, Step -from amltk.pipeline.pipeline import Pipeline - -if TYPE_CHECKING: - from amltk.types import Seed - -logger = logging.getLogger(__name__) - -InputT = TypeVar("InputT") -OutputT = TypeVar("OutputT") - - -class ParserError(Exception): - """Error for when a Parser fails to parse a Pipeline.""" - - @overload - def __init__(self, parser: Parser, error: Exception): - ... - - @overload - def __init__(self, parser: list[Parser], error: list[Exception]): - ... - - def __init__( - self, - parser: Parser | list[Parser], - error: Exception | list[Exception], - ): - """Create a new parser error. - - Args: - parser: The parser(s) that failed. - error: The error(s) that was raised. - - Raises: - ValueError: If parser is a list, exception must be a list of the - same length. - """ - super().__init__(parser, error) - if isinstance(parser, list) and ( - not (isinstance(error, list) and len(parser) == len(error)) - ): - raise ValueError( - "If parser is a list, `error` must be a list of the same length." 
- f"Got {parser=} and {error=} .", - ) - - self.parser = parser - self.error = error - - @override - def __str__(self) -> str: - if isinstance(self.parser, list): - msg = "\n\n".join( - f"Failed to parse with {p}:" + "\n " + f"{e.__class__.__name__}: {e}" - for p, e in zip(self.parser, self.error) # type: ignore - ) - else: - msg = ( - f"Failed to parse with {self.parser}:" - + "\n" - + f"{self.error.__class__.__name__}: {self.error}" - ) - - return msg - - -class Parser(ABC, Generic[InputT, OutputT]): - """A parser to parse a Pipeline/Step's `search_space` into a Space. - - This class is a parser for a given Space type, providing functionality - for the parsing algothim to run on a given search space. To implement - a parser for a new search space, you must implement the abstract methods - in this class. - - !!! example "Abstract Methods" - - * [`parse_space`][amltk.pipeline.parser.Parser.parse_space]: - Parse a search space into a space. - * [`empty`][amltk.pipeline.parser.Parser.empty]: - Get an empty space. - * [`insert`][amltk.pipeline.parser.Parser.insert]: - Insert a space into another space, with a possible prefix + delimiter. - * [`set_seed`][amltk.pipeline.parser.Parser.set_seed]: - _(Optional)_ Set the seed of a space. - - !!! note - - If your space supports conditions, you must also implement: - - * [`condition`][amltk.pipeline.parser.Parser.condition]: - Condition a set of subspaces on their names, based on a hyperparameter - with which takes on values with these names. Must be encoded as a Space. - - Please see the respective docstrings for more. - - See Also: - * [`SpaceAdapter`][amltk.pipeline.space.SpaceAdapter] - Together with implementing the [`Sampler`][amltk.pipeline.sampler.Sampler] - interface, this class provides a complete adapter for a given search space. - """ - - ParserError: ClassVar[type[ParserError]] = ParserError - """The error to raise when parsing fails.""" - - @classmethod - def default_parsers(cls) -> list[Parser]: - """Get the default parsers.""" - parsers: list[Parser] = [] - - try: - from amltk.configspace import ConfigSpaceAdapter - - parsers.append(ConfigSpaceAdapter()) - except ImportError as e: - logger.debug( - "ConfigSpace not installed for parsing, skipping" - f"\n{e.__class__.__name__}: {e}", - ) - - try: - from amltk.optuna import OptunaSpaceAdapter - - parsers.append(OptunaSpaceAdapter()) - except ImportError as e: - logger.debug( - "Optuna not installed for parsing, skipping" - f"\n{e.__class__.__name__}: {e}", - ) - - return parsers - - @classmethod - def try_parse( - cls, - pipeline_or_step: Pipeline | Step, - parser: type[Parser[InputT, OutputT]] | Parser[InputT, OutputT] | None = None, - *, - seed: Seed | None = None, - ) -> OutputT: - """Attempt to parse a pipeline with the default parsers. - - Args: - pipeline_or_step: The pipeline or step to parse. - parser: The parser to use. If `None`, will try all default parsers that - are installed. - seed: The seed to use for the parser. - - Returns: - The parsed space. - """ - if parser is None: - parsers = cls.default_parsers() - elif isinstance(parser, Parser): - parsers = [parser] - elif isinstance(parser, type): - parsers = [parser()] - else: - parsers = [] - - if not any(parsers): - raise RuntimeError( - "Found no possible parser to use. Have you tried installing any of:" - "\n* ConfigSpace" - "\n* Optuna" - "\nPlease see the integration documentation for more info, especially" - "\nif using an optimizer which often requires a specific search space." 
- "\nUsually just installing the optimizer will work.", - ) - - def _parse(_parser: Parser[InputT, OutputT]) -> OutputT: - _parsed_space = _parser.parse(pipeline_or_step) - if seed is not None: - _parser.set_seed(_parsed_space, seed) - return _parsed_space - - # Wrap in seekable so we don't evaluate all of them, only as - # far as we need to get a succesful parse. - results_itr = seekable(safe_map(_parse, parsers)) - - is_result = lambda r: not (isinstance(r, tuple) and isinstance(r[0], Exception)) - # Progress the iterator until we get a successful parse - parsed_space = first_true(results_itr, default=False, pred=is_result) - - # If we didn't get a succesful parse, raise the appropriate error - if parsed_space is False: - results_itr.seek(0) # Reset to start of iterator - errors = list(results_itr) - raise Parser.ParserError(parser=parsers, error=errors) # type: ignore - - assert not isinstance(parsed_space, (tuple, bool)) - return parsed_space - - def parse(self, step: Pipeline | Step | Group | Choice | Any) -> OutputT: - """Parse a pipeline, step or something resembling a Space. - - Args: - step: The pipeline or step to parse. If it is not - a Pipeline object, it will be treated as a - search space and attempt to be parsed as such. - - Returns: - The space representing the pipeline or step. - """ - if isinstance(step, Pipeline): - return self.parse_pipeline(step) - - if isinstance(step, Choice): - return self.parse_choice(step) - - if isinstance(step, Group): - return self.parse_group(step) - - if isinstance(step, Step): - return self.parse_step(step) - - return self.parse_space(step) - - def parse_pipeline(self, pipeline: Pipeline) -> OutputT: - """Parse a pipeline into a space. - - Args: - pipeline: The pipeline to parse. - - Returns: - The space representing the pipeline. The pipeline will have no prefix - while any modules attached to the pipeline will have the modules - name as the prefix in the space - """ - space = self.parse(pipeline.head) - - for module in pipeline.modules.values(): - prefix_delim = (module.name, ":") if isinstance(module, Pipeline) else None - space = self.insert(space, self.parse(module), prefix_delim=prefix_delim) - - return space - - def parse_step(self, step: Step) -> OutputT: - """Parse the space from a given step. - - Args: - step: The step to parse. - - Returns: - The space for this step. - """ - space = self.empty() - - if step.search_space: - _space = self.parse_space(step.search_space, step.config) - space = self.insert(space, _space, prefix_delim=(step.name, ":")) - - if step.nxt is not None: - _space = self.parse(step.nxt) - space = self.insert(space, _space) - - return space - - def parse_group(self, step: Group) -> OutputT: - """Parse the space from a given group. - - Args: - step: The group to parse. - - Returns: - The space for this group. - """ - space = self.empty() - - if step.search_space: - _space = self.parse_space(step.search_space, step.config) - space = self.insert(space, _space, prefix_delim=(step.name, ":")) - - for path in step.paths: - _space = self.parse(path) - space = self.insert(space, _space, prefix_delim=(step.name, ":")) - - if step.nxt is not None: - _space = self.parse(step.nxt) - space = self.insert(space, _space) - - return space - - def parse_choice(self, step: Choice) -> OutputT: - """Parse the space from a given choice. - - Note: - This relies on the implementation of the `condition` method to - condition the subspaces under the choice parameter. 
Please see - the class docstring [here][amltk.pipeline.parser.Parser] for more - information. - - Args: - step: The choice to parse. - - Returns: - The space for this choice. - """ - space = self.empty() - - if step.search_space: - _space = self.parse_space(step.search_space, step.config) - space = self.insert(space, _space, prefix_delim=(step.name, ":")) - - # Get all the subspaces for each choice - subspaces = {path.name: self.parse(path) for path in step.paths} - - # Condition each subspace under some parameter "choice_name" - conditioned_space = self.condition( - choice_name=step.name, - delim=":", - spaces=subspaces, - weights=step.weights, - ) - - space = self.insert(space, conditioned_space) - - if step.nxt is not None: - _space = self.parse(step.nxt) - space = self.insert(space, _space) - - return space - - def set_seed(self, space: OutputT, seed: Seed) -> OutputT: # noqa: ARG002 - """Set the seed for the space. - - Overwrite if the can do something meaninfgul for the space. - - Args: - space: The space to set the seed for. - seed: The seed to set. - - Returns: - The space with the seed set if applicable. - """ - return space - - @abstractmethod - def empty(self) -> OutputT: - """Create an empty space. - - Returns: - An empty space. - """ - ... - - @abstractmethod - def insert( - self, - space: OutputT, - subspace: InputT | OutputT, - *, - prefix_delim: tuple[str, str] | None = None, - ) -> OutputT: - """Insert a subspace into a space. - - Args: - space: The space to insert into. - subspace: The subspace to insert. - prefix_delim: The prefix, delimiter to use for the subspace. - - Returns: - The original space with the subspace inserted. - """ - ... - - def merge(self, *spaces: InputT) -> OutputT: - """Merge a list of spaces into a single space. - - - ```python exec="true" source="material-block" result="python" title="Merging spaces" - # Note, relies on ConfigSpace being installed `pip install ConfigSpace` - from amltk.configspace import ConfigSpaceAdapter - - adapter = ConfigSpaceAdapter() - - space_1 = adapter.parse({ "a": (1, 10) }) - space_2 = adapter.parse({ "b": (10.5, 100.5) }) - space_3 = adapter.parse({ "c": ["apple", "banana", "carrot"] }) - - space = adapter.merge(space_1, space_2, space_3) - - print(space) - ``` - - Args: - spaces: The spaces to merge. - - Returns: - The merged space. - """ # noqa: E501 - space = self.empty() - - for _space in spaces: - space = self.insert(space, _space) - - return space - - @abstractmethod - def parse_space( - self, - space: Any, - config: Mapping[str, Any] | None = None, - ) -> OutputT: - """Parse a space from some object. - - Args: - space: The space to parse. - config: A possible set of concrete values to use for the space. - If provided, the space should either set these values as constant - or be excluded from the generated space. - - Returns: - The parsed space. - """ - ... - - @abstractmethod - def condition( - self, - choice_name: str, - delim: str, - spaces: dict[str, OutputT], - weights: Sequence[float] | None = None, - ) -> OutputT: - """Condition a set of spaces such that only one can be active. - - When sampling from the generated space. The choice name must be present - along with the value it takes, which is any of the names of the choice paths. - - When a given choice is sampled, the corresponding subspace is sampled - and none of the others. - - This must be encoded into the Space. - - If your space does not support conditionals, you can raise a - an Error. 
If your space does support conditionals but not in this - format, please raise an Issue! - - Args: - choice_name: The name of the choice parameter. - delim: The delimiter to use for the choice parameter. - spaces: The spaces to condition. This is a mapping from the name - of the choice to the space. - weights: The weights to use for the choice parameter. - If set and not possible, raise an Error. - """ - ... - - @override - def __repr__(self) -> str: - return f"{self.__class__.__name__}" diff --git a/tests/optuna/__init__.py b/src/amltk/pipeline/parsers/__init__.py similarity index 100% rename from tests/optuna/__init__.py rename to src/amltk/pipeline/parsers/__init__.py diff --git a/src/amltk/pipeline/parsers/configspace.py b/src/amltk/pipeline/parsers/configspace.py new file mode 100644 index 00000000..308b8075 --- /dev/null +++ b/src/amltk/pipeline/parsers/configspace.py @@ -0,0 +1,290 @@ +"""[ConfigSpace](https://automl.github.io/ConfigSpace/master/) is a library for +representing and sampling configurations for hyperparameter optimization. +It features a straightforward API for defining hyperparameters, their ranges +and even conditional dependencies. + +It is generally flexible enough for more complex use cases, even +handling the complex pipelines of [AutoSklearn](https://automl.github.io/auto-sklearn/master/) +and [AutoPyTorch](https://automl.github.io/Auto-PyTorch/master/), large +scale hyperparameter spaces over which to optimize entire +pipelines at a time. + +!!! tip "Requirements" + + This requires `ConfigSpace` which can be installed with: + + ```bash + pip install "amltk[configspace]" + + # Or directly + pip install ConfigSpace + ``` + +In general, you should have the +[ConfigSpace documentation](https://automl.github.io/ConfigSpace/master/) +ready to consult for a full understanding of how to construct +hyperparameter spaces with AMLTK. + +#### Basic Usage + +You can directly us the [`parser()`][amltk.pipeline.parsers.configspace.parser] +function and pass that into the [`search_space()`][amltk.pipeline.Node.search_space] +method of a [`Node`][amltk.pipeline.Node], however you can also simply provide +`#!python search_space(parser="configspace", ...)` for simplicity. + +```python exec="true" result="python" source="material-block" hl_lines="27" session="configspace-parser" +from amltk.pipeline import Component, Choice, Sequential +from sklearn.decomposition import PCA +from sklearn.ensemble import RandomForestClassifier +from sklearn.neural_network import MLPClassifier +from sklearn.svm import SVC + +my_pipeline = ( + Sequential(name="Pipeline") + >> Component(PCA, space={"n_components": (1, 3)}) + >> Choice( + Component( + SVC, + space={"C": (0.1, 10.0)} + ), + Component( + RandomForestClassifier, + space={"n_estimators": (10, 100), "criterion": ["gini", "log_loss"]}, + ), + Component( + MLPClassifier, + space={ + "activation": ["identity", "logistic", "relu"], + "alpha": (0.0001, 0.1), + "learning_rate": ["constant", "invscaling", "adaptive"], + }, + ), + name="estimator" + ) +) + +space = my_pipeline.search_space("configspace") +print(space) +``` + +Here we have an example of a few different kinds of hyperparmeters, + +* `PCA:n_components` is a integer with a range of 1 to 3, uniform distribution, as specified + by it's integer bounds in a tuple. +* `SVC:C` is a float with a range of 0.1 to 10.0, uniform distribution, as specified + by it's float bounds in a tuple. 
+* `RandomForestClassifier:criterion` is a categorical hyperparameter, with two choices, + `"gini"` and `"log_loss"`. + +There is also a [`Choice`][amltk.pipeline.Choice] node, which is a special node that indicates that +we could choose from one of these estimators. This leads to the conditionals that you +can see in the printed out space. + +You may wish to remove all conditionals if an `Optimizer` does not support them, or +you may wish to remove them for other reasons. You can do this by passing +`conditionals=False` to the [`parser()`][amltk.pipeline.parsers.configspace.parser] function. + +```python exec="true" result="python" source="material-block" hl_lines="27" session="configspace-parser" +print(my_pipeline.search_space("configspace", conditionals=False)) +``` + +Likewise, you can also remove all hierarchy from the space, which may make downstream tasks easier, +by passing `flat=True` to the [`parser()`][amltk.pipeline.parsers.configspace.parser] function. + +```python exec="true" result="python" source="material-block" hl_lines="27" session="configspace-parser" +print(my_pipeline.search_space("configspace", flat=True)) +``` + +#### More Specific Hyperparameters +You'll often want to be a bit more specific with your hyperparameters. Here we +show a few examples of how you can couple your pipelines more closely to `ConfigSpace`. + +```python exec="true" result="python" source="material-block" +from ConfigSpace import Float, Categorical, Normal +from amltk.pipeline import Searchable + +s = Searchable( + space={ + "lr": Float("lr", bounds=(1e-5, 1.), log=True, default=0.3), + "balance": Float("balance", bounds=(-1.0, 1.0), distribution=Normal(0.0, 0.5)), + "color": Categorical("color", ["red", "green", "blue"], weights=[2, 1, 1], default="blue"), + }, + name="Something-To-Search", +) +print(s.search_space("configspace")) +``` + +#### Conditionals and Advanced Usage +We will refer you to the +[ConfigSpace documentation](https://automl.github.io/ConfigSpace/master/) for the construction +of these. However, once you've constructed a `ConfigurationSpace` and added any forbiddens and +conditionals, you may simply set that as the `.space` attribute.
+ +```python exec="true" result="python" source="material-block" hl_lines="27" +from amltk.pipeline import Component, Choice, Sequential +from ConfigSpace import ConfigurationSpace, EqualsCondition, InCondition + +myspace = ConfigurationSpace({"A": ["red", "green", "blue"], "B": (1, 10), "C": (-100.0, 0.0)}) +myspace.add_conditions([ + EqualsCondition(myspace["B"], myspace["A"], "red"), # B is active when A is red + InCondition(myspace["C"], myspace["A"], ["green", "blue"]), # C is active when A is green or blue +]) + +component = Component(object, space=myspace, name="MyThing") + +parsed_space = component.search_space("configspace") +print(parsed_space) +``` + +""" # noqa: E501 +from __future__ import annotations + +from collections.abc import Mapping +from copy import deepcopy +from typing import Any + +from ConfigSpace import Categorical, ConfigurationSpace, Constant + +from amltk.pipeline import Choice, Node + + +def _remove_hyperparameter( + name: str, + space: ConfigurationSpace, + seed: int | None = None, +) -> ConfigurationSpace: + if name not in space._hyperparameters: + raise ValueError(f"{name} not in {space}") + + # Copying conditionals only work on objects and not named entities + # Seeing as we copy objects and don't use the originals, transfering these + # to the new objects is a bit tedious, possible but not required at this time + # ... same goes for forbiddens + if name in space._conditionals: + raise ValueError("Can't remove conditionals") + if any(name == f.hyperparameter.name for f in space.get_forbiddens()): + raise ValueError("Can't remove forbiddens") + + hps = [deepcopy(hp) for hp in space.get_hyperparameters() if hp.name != name] + + new_space = ConfigurationSpace(seed=seed, name=space.name, meta=space.meta) + new_space.add_hyperparameters(hps) + return new_space + + +def _remove_conditionals( + space: ConfigurationSpace, + seed: int | None = None, +) -> ConfigurationSpace: + new_space = ConfigurationSpace(seed=seed, name=space.name, meta=space.meta) + new_space.add_hyperparameters(space.values()) + return new_space + + +def _replace_constants( + config: Mapping[str, Any], + space: ConfigurationSpace, + seed: int | None = None, +) -> ConfigurationSpace: + for key, value in config.items(): + if key in space._hyperparameters: + space = _remove_hyperparameter(key, space, seed) + + # These are just restrictions on hyperparameters from ConfigSpace + match value: + case bool(): + space.add_hyperparameter(Constant(key, str(value))) + case int() | float() | str(): + space.add_hyperparameter(Constant(key, value)) + case _: + raise ValueError(f"Can't handle {value} from {config} as Constant") + + return space + + +def _parse_space( + node: Node, + *, + conditionals: bool = True, + seed: int | None = None, +) -> ConfigurationSpace: + space = node.space + match space: + case ConfigurationSpace(): + _space = deepcopy(space) + case Mapping(): + _space = ConfigurationSpace(dict(space)) + case None: + _space = ConfigurationSpace() + case _: + raise ValueError(f"Can't handle {space} from {node}") + + if not conditionals: + _space = _remove_conditionals(_space, seed) + + if node.config is not None: + _space = _replace_constants(node.config, _space, seed) + + if seed is not None: + _space.seed(seed) + + return _space + + +def parser( + node: Node, + *, + seed: int | None = None, + flat: bool = False, + conditionals: bool = True, + delim: str = ":", +) -> ConfigurationSpace: + """Parse a Node and its children into a ConfigurationSpace. 
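+
+    For illustration (the component and its space here are arbitrary), the
+    parser can also be called directly instead of going through
+    [`search_space()`][amltk.pipeline.node.Node.search_space]:
+
+    ```python
+    from amltk.pipeline import Component
+    from amltk.pipeline.parsers.configspace import parser
+    from sklearn.decomposition import PCA
+
+    component = Component(PCA, space={"n_components": (1, 3)})
+    space = parser(component, seed=42, flat=True)
+    print(space)
+    ```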
+ + Args: + node: The Node to parse + seed: The seed to use for the ConfigurationSpace + flat: Whether to have a heirarchical naming scheme for nodes and their children. + conditionals: Whether to include conditionals in the space from a + [`Choice`][amltk.pipeline.Choice]. If this is `False`, this will + also remove all forbidden clauses and other conditional clauses. + The primary use of this functionality is that some optimizers do not + support these features. + delim: The delimiter to use for the names of the hyperparameters + """ + space = ConfigurationSpace(seed=seed) + space.add_configuration_space( + prefix=node.name, + delimiter=delim, + configuration_space=_parse_space(node, seed=seed, conditionals=conditionals), + ) + + children = node.nodes + + choice = None + if isinstance(node, Choice) and any(children): + choice = Categorical( + name=f"{node.name}{delim}__choice__", + items=[child.name for child in children], + ) + space.add_hyperparameter(choice) + + for child in children: + space.add_configuration_space( + prefix=node.name if not flat else "", + delimiter=delim if not flat else "", + configuration_space=parser( + child, + seed=seed, + flat=flat, + conditionals=conditionals, + delim=delim, + ), + parent_hyperparameter=( + {"parent": choice, "value": child.name} + if choice and conditionals + else None + ), + ) + + return space diff --git a/src/amltk/pipeline/parsers/optuna.py b/src/amltk/pipeline/parsers/optuna.py new file mode 100644 index 00000000..f6387b55 --- /dev/null +++ b/src/amltk/pipeline/parsers/optuna.py @@ -0,0 +1,217 @@ +"""[Optuna](https://optuna.org/) parser for parsing out a +[`search_space()`][amltk.pipeline.node.Node.search_space]. +from a pipeline. + +!!! tip "Requirements" + + This requires `Optuna` which can be installed with: + + ```bash + pip install amltk[optuna] + + # Or directly + pip install optuna + ``` + +??? warning "Limitations" + + Optuna feature a very dynamic search space (_define-by-run_), + where people typically sample from some trial object and use traditional + python control flow to define conditionality. + + This means we can not trivially represent this conditionality in a static + search space. While _band-aids_ are possible, + it naturally does not sit well with the static output of a parser. + + As such, our parser **does not support conditionals or choices!**. + Users may still use the _define-by-run_ within their optimization function + itself. + + If you have experience with Optuna and have any suggestions, + please feel free to open an issue or PR on GitHub! + +### Usage +The typical way to represent a search space for Optuna is just to use a dictionary, +where the keys are the names of the hyperparameters and the values are either +integer/float tuples indicating boundaries or some discrete set of values. +It is possible to have the value directly be a +`BaseDistribution`, an optuna type, when you need to customize the distribution more. 
+ + +```python exec="true" source="material-block" html="true" session="optuna-parser" +from amltk.pipeline import Component +from optuna.distributions import FloatDistribution + +c = Component( + object, + space={ + "myint": (1, 10), + "myfloat": (1.0, 10.0), + "mycategorical": ["a", "b", "c"], + "log-scale-custom": FloatDistribution(1e-10, 1e-2, log=True), + }, + name="name", +) + +space = c.search_space(parser="optuna") +from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide +``` + +You may also just pass the `parser=` function directly if preferred. + +```python exec="true" source="material-block" html="true" session="optuna-parser" +from amltk.pipeline.parsers.optuna import parser as optuna_parser + +space = c.search_space(parser=optuna_parser) +from amltk._doc import doc_print; doc_print(print, space) # markdown-exec: hide +``` + +When using [`search_space()`][amltk.pipeline.node.Node.search_space] on some nested +structures, you may want to flatten the names of the hyperparameters. For this you +can use `flat=`. + +```python exec="true" source="material-block" html="true" +from amltk.pipeline import Searchable, Sequential + +seq = Sequential( + Searchable({"myint": (1, 10)}, name="nested_1"), + Searchable({"myfloat": (1.0, 10.0)}, name="nested_2"), + name="seq" +) + +hierarchical_space = seq.search_space(parser="optuna", flat=False) # Default +from amltk._doc import doc_print; doc_print(print, hierarchical_space) # markdown-exec: hide + +flat_space = seq.search_space(parser="optuna", flat=True) +from amltk._doc import doc_print; doc_print(print, flat_space) # markdown-exec: hide +``` + +""" # noqa: E501 + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING + +import numpy as np +from optuna.distributions import ( + BaseDistribution, + CategoricalChoiceType, + CategoricalDistribution, + FloatDistribution, + IntDistribution, +) + +from amltk._functional import prefix_keys + +if TYPE_CHECKING: + from typing import TypeAlias + + from amltk.pipeline import Node + + OptunaSearchSpace: TypeAlias = dict[str, BaseDistribution] + +PAIR = 2 + + +def _convert_hp_to_optuna_distribution( + name: str, + hp: tuple | Sequence | CategoricalChoiceType | BaseDistribution, +) -> BaseDistribution: + match hp: + case BaseDistribution(): + return hp + case None | bool() | int() | str() | float(): + return CategoricalDistribution([hp]) + case tuple() as tup if len(tup) == PAIR: + match tup: + case (int() | np.integer(), int() | np.integer()): + x, y = tup + return IntDistribution(int(x), int(y)) + case (float() | np.floating(), float() | np.floating()): + x, y = tup + return FloatDistribution(float(x), float(y)) + case (x, y): + raise ValueError( + f"Expected {name} to have same type for lower and upper bound," + f"got lower: {type(x)}, upper: {type(y)}.", + ) + case Sequence(): + if len(hp) == 0: + raise ValueError(f"Can't have empty list for categorical {name}") + + return CategoricalDistribution(hp) + case _: + raise ValueError( + f"Could not parse {name} as a valid Optuna distribution."
f"\n{hp=}", + ) + + raise ValueError(f"Could not parse {name} as a valid Optuna distribution.\n{hp=}") + + +def _parse_space(node: Node) -> OptunaSearchSpace: + match node.space: + case None: + space = {} + case Mapping(): + space = { + name: _convert_hp_to_optuna_distribution(name=name, hp=hp) + for name, hp in node.space.items() + } + case _: + raise ValueError( + f"Can only parse mappings with Optuna but got {node.space=}", + ) + + if node.config is not None: + for name, value in node.config.items(): + if name in space: + space[name] = CategoricalDistribution([value]) + + return space + + +def parser( + node: Node, + *, + flat: bool = False, + conditionals: bool = False, + delim: str = ":", +) -> OptunaSearchSpace: + """Parse a Node and its children into a ConfigurationSpace. + + Args: + node: The Node to parse + flat: Whether to have a hierarchical naming scheme for nodes and their children. + conditionals: Whether to include conditionals in the space from a + [`Choice`][amltk.pipeline.Choice]. If this is `False`, this will + also remove all forbidden clauses and other conditional clauses. + The primary use of this functionality is that some optimizers do not + support these features. + + !!! TODO "Not yet supported" + + This functionality is not yet supported as we can't encode this into + a static Optuna search space. + + delim: The delimiter to use for the names of the hyperparameters. + """ + if conditionals: + raise NotImplementedError("Conditionals are not yet supported with Optuna.") + + space = prefix_keys(_parse_space(node), prefix=f"{node.name}{delim}") + + for child in node.nodes: + subspace = parser(child, flat=flat, conditionals=conditionals, delim=delim) + if not flat: + subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}") + + for name, hp in subspace.items(): + if name in space: + raise ValueError( + f"Duplicate name {name} already in space from space of {node.name}", + f"\nCurrently parsed space: {space}", + ) + space[name] = hp + + return space diff --git a/src/amltk/pipeline/pipeline.py b/src/amltk/pipeline/pipeline.py deleted file mode 100644 index f438491a..00000000 --- a/src/amltk/pipeline/pipeline.py +++ /dev/null @@ -1,727 +0,0 @@ -"""The pipeline class used to represent a pipeline of steps. - -This module exposes a Pipeline class that wraps a chain of -[`Component`][amltk.pipeline.Component], [`Split`][amltk.pipeline.Split], -[`Group`][amltk.pipeline.Group] and [`Choice`][amltk.pipeline.Choice] -components, created through the [`step()`][amltk.pipeline.api.step], -[`choice()`][amltk.pipeline.choice], [`split()`][amltk.pipeline.split] -and [`group()`][amltk.pipeline.group] api functions from `amltk.pipeline`. 
-""" -from __future__ import annotations - -import logging -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Iterable, - Iterator, - Literal, - Mapping, - TypeVar, - overload, -) -from typing_extensions import override -from uuid import uuid4 - -from attrs import field, frozen - -from amltk.functional import classname, mapping_select -from amltk.pipeline.components import Group, Step, prefix_keys -from amltk.richutil import RichRenderable - -if TYPE_CHECKING: - from rich.console import RenderableType - from rich.text import TextType - - from amltk.pipeline.parser import Parser - from amltk.pipeline.sampler import Sampler - from amltk.types import Config, FidT, Seed, Space - -T = TypeVar("T") # Dummy typevar -B = TypeVar("B") # Built pipeline - -logger = logging.getLogger(__name__) - - -@frozen(kw_only=True) -class Pipeline(RichRenderable): - """A sequence of steps and operations on them.""" - - name: str - """The name of the pipeline""" - - steps: list[Step] - """The steps in the pipeline. - - This does not include any steps that are part of a `Split` or `Choice`. - """ - - modules: Mapping[str, Step | Pipeline] = field(factory=dict) - """Additional modules to associate with the pipeline""" - - meta: Mapping[str, Any] | None = None - """Additional meta information to associate with the pipeline""" - - RICH_PANEL_BORDER_COLOR: ClassVar[str] = "magenta" - - @property - def head(self) -> Step: - """The first step in the pipeline.""" - return self.steps[0] - - @property - def tail(self) -> Step: - """The last step in the pipeline.""" - return self.steps[-1] - - def __contains__(self, key: str | Step) -> bool: - """Check if a step is in the pipeline. - - Args: - key: The name of the step or the step itself - - Returns: - bool: True if the step is in the pipeline, False otherwise - """ - key = key.name if isinstance(key, Step) else key - return self.find(key, deep=True) is not None - - def __len__(self) -> int: - return len(self.steps) - - def __iter__(self) -> Iterator[Step]: - return self.steps.__iter__() - - def __or__(self, other: Step | Pipeline) -> Pipeline: - """Append a step or pipeline to this one and return a new one.""" - return self.append(other) - - def iter(self) -> Iterator[Step]: - """Iterate over the top layer of the pipeline. - - Yields: - Step[Key] - """ - yield from iter(self.steps) - - def traverse(self) -> Iterator[Step]: - """Traverse the pipeline in a depth-first manner. - - Yields: - Step[Key] - """ - yield from self.head.traverse() - - def walk(self) -> Iterator[tuple[list[Group], list[Step], Step]]: - """Walk the pipeline in a depth-first manner. - - This is similar to traverse, but yields the groups that lead to the step along - with any parents in a chain with that step (which does not include the groups) - - Yields: - (groups, parents, step) - """ - yield from self.head.walk(groups=[], parents=[]) - - @overload - def find( - self, - key: str | Callable[[Step], bool], - default: T, - *, - deep: bool = ..., - ) -> Step | T: - ... - - @overload - def find( - self, - key: str | Callable[[Step], bool], - *, - deep: bool = ..., - ) -> Step | None: - ... - - def find( - self, - key: str | Callable[[Step], bool], - default: T | None = None, - *, - deep: bool = True, - ) -> Step | T | None: - """Find a step in the pipeline. - - Args: - key: The key to search for or a function that returns True if the step - is the desired step - default: - The value to return if the step is not found. 
Defaults to None - deep: - Whether to search the entire pipeline or just the top layer. - - Returns: - The step if found, otherwise the default value. Defaults to None - """ - return self.head.find(key, default, deep=deep) - - def select( - self, - choices: Mapping[str, str], - *, - name: str | None = None, - ) -> Pipeline: - """Select particular choices from the pipeline. - - Args: - choices: A mapping of the choice name to the choice to select - name: A name to give to the new pipeline returned. Defaults to the current - - Returns: - A new pipeline with the selected choices - """ - return self.create( - self.head.select(choices), - name=self.name if name is None else name, - ) - - def apply(self, f: Callable[[Step], Step], *, name: str | None = None) -> Pipeline: - """Apply a function to each step in the pipeline, returning a new pipeline. - - !!! warning "Modifications to pipeline structure" - - Any modifications to pipeline structure will be ignored. This is done by - providing a `copy()` of the step to the function, rejoining each modified - step in the pipeline and then returning a new pipeline. - - Args: - f: The function to apply - name: A name to give to the new pipeline returned. Defaults to the current - - Returns: - A new pipeline with the function applied - """ - return self.create( - self.head.apply(f), - name=self.name if name is None else name, - ) - - def remove(self, step: str | list[str], *, name: str | None = None) -> Pipeline: - """Remove a step from the pipeline. - - Args: - step: The name of the step(s) to remove - name: A name to give to the new pipeline returned. Defaults to - the current pipelines name - - Returns: - A new pipeline with the step removed - """ - # NOTE: We explicitly use a list instead of a Sequence for multiple steps. - # This is because technically you could have a single Key = tuple(X, Y, Z), - # which is a Sequence. - # This problem also arises more simply in the case where Key = str. - # Hence, by explicitly checking for the concrete type, `list`, we can - # avoid this problem. - return self.create( - self.head.remove(step if isinstance(step, list) else [step]), - name=name if name is not None else self.name, - ) - - def append(self, nxt: Pipeline | Step, *, name: str | None = None) -> Pipeline: - """Append a step or pipeline to this one and return a new one. - - Args: - nxt: The step or pipeline to append - name: A name to give to the new pipeline returned. Defaults to - the current pipelines name - - Returns: - A new pipeline with the step appended - """ - if isinstance(nxt, Pipeline): - nxt = nxt.head - - return self.create( - self.steps, - nxt.iter(), - name=name if name is not None else self.name, - ) - - def replace( - self, - key: str | dict[str, Step], - step: Step | None = None, - *, - name: str | None = None, - ) -> Pipeline: - """Replace a step in the pipeline. - - Args: - key: The key of the step to replace or a dictionary of steps to replace - step: The step to replace the old step with. Only used if key is - a single key - name: A name to give to the new pipeline returned. 
Defaults to the current - - Returns: - A new pipeline with the step replaced - """ - if isinstance(key, dict) and step is not None: - raise ValueError("Cannot specify both dictionary of keys and step") - - if not isinstance(key, dict): - if step is None: - raise ValueError("Must specify step to replace with") - replacements = {key: step} - else: - replacements = key - - return self.create( - self.head.replace(replacements), - name=self.name if name is None else name, - ) - - def attach( - self, - *, - modules: Pipeline | Step | Iterable[Pipeline | Step] | None = None, - ) -> Pipeline: - """Attach modules to the pipeline. - - Args: - modules: The modules to attach - """ - if modules is None: - modules = [] - - if isinstance(modules, (Step, Pipeline)): - modules = [modules] - - return self.create( - self.head, - modules=[*self.modules.values(), *modules], - name=self.name, - ) - - def configured(self) -> bool: - """Whether the pipeline has been configured. - - Returns: - True if the pipeline has been configured, False otherwise - """ - return all(step.configured() for step in self.traverse()) and all( - module.configured() for module in self.modules.values() - ) - - def qualified_name(self, step: str | Step, *, delimiter: str = ":") -> str: - """Get the qualified name of a substep in the pipeline. - - Args: - step: The step to get the qualified name of - delimiter: The delimiter to use between the groups and the step - - Returns: - The qualified name of the step - """ - # We use the walk function to get the step along with any groups - # to get there - if isinstance(step, Step): - step = step.name - - found = self.find(step) - if found is None: - raise ValueError(f"Step {step} not found in pipeline") - - return found.qualified_name(delimiter=delimiter) - - @overload - def space( - self, - parser: type[Parser[Any, Space]] | Parser[Any, Space], - *, - seed: Seed | None = None, - ) -> Space: - ... - - @overload - def space( - self, - parser: None = None, - *, - seed: Seed | None = None, - ) -> Any: - ... - - def space( - self, - parser: type[Parser[Any, Space]] | Parser[Any, Space] | None = None, - *, - seed: Seed | None = None, - ) -> Space | Any: - """Get the space for the pipeline. - - If there are any modules attached to this pipeline, - these will also be included in the space. - - Args: - parser: The parser to use for assembling the space. Default is `None`. - * If `None` is provided, the assembler will attempt to - automatically figure out the kind of Space to extract from the pipeline. - * Otherwise we will attempt to use the given Parser. - If there are other intuitive ways to indicate the type, please open - an issue on GitHub and we will consider it! - seed: The seed to use for the space if applicable. - - Raises: - Parser.Error: If the parser fails to parse the space. - - Returns: - The space for the pipeline - """ - from amltk.pipeline.parser import Parser # Prevent circular imports - - return Parser.try_parse(pipeline_or_step=self, parser=parser, seed=seed) - - def fidelities(self) -> dict[str, FidT]: - """Get the fidelities for the pipeline. - - Returns: - The fidelities for the pipeline - """ - return self.head.fidelities() - - @overload - def sample(self) -> Config: - ... - - @overload - def sample( - self, - *, - n: None = None, - space: Space | None = ..., - sampler: type[Sampler[Space]] | Sampler[Space] | None = ..., - seed: Seed | None = ..., - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> Config: - ... 
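For reference, the `Pipeline` API being removed above (`create`, `__contains__`, `find`, `walk`, `qualified_name`) was used roughly as follows. A minimal sketch: the `step`/`split` helpers and the `amltk.pipeline` exports are assumed from the example embedded in `path_to`'s docstring further down, and the exact `qualified_name` output depends on the groups enclosing the step.

```python
from amltk.pipeline import Pipeline, step, split  # exports assumed, as in the path_to example

pipeline = Pipeline.create(
    step("scaler", 0)
    | split("branch", step("left", 1), step("right", 2))
    | step("model", 3),
    name="demo",
)

print("left" in pipeline)               # __contains__ does a deep find by name
print(pipeline.find("left"))            # the Step named "left", or None if absent
print(pipeline.qualified_name("left"))  # e.g. "branch:left", prefixed by enclosing groups

for groups, parents, s in pipeline.walk():  # depth-first walk with group/parent context
    print([g.name for g in groups], [p.name for p in parents], s.name)
```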
- - @overload - def sample( - self, - *, - n: int, - space: Space | None = ..., - sampler: type[Sampler[Space]] | Sampler[Space] | None = ..., - seed: Seed | None = ..., - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> list[Config]: - ... - - def sample( - self, - *, - n: int | None = None, - space: Space | None = None, - sampler: type[Sampler[Space]] | Sampler[Space] | None = None, - seed: Seed | None = None, - duplicates: bool | Iterable[Config] = False, - max_attempts: int | None = 10, - ) -> Config | list[Config]: - """Sample a configuration from the space of the pipeline. - - Args: - space: The space to sample from. Will be automatically inferred - if `None` is provided. - n: The number of configurations to sample. If `None`, a single - configuration will be sampled. If `n` is greater than 1, a list of - configurations will be returned. - sampler: The sampler to use. If `None`, a sampler will be automatically - chosen based on the type of space that is provided. - seed: The seed to seed the space with if applicable. - duplicates: If True, allow duplicate samples. If False, make - sure all samples are unique. If an Iterable, make sure all - samples are unique and not in the Iterable. - max_attempts: The number of times to attempt sampling unique - configurations before giving up. If `None` will keep - sampling forever until satisfied. - - Returns: - A configuration sampled from the space of the pipeline - """ - from amltk.pipeline.parser import Parser - from amltk.pipeline.sampler import Sampler - - return Sampler.try_sample( - space - if space is not None - else self.space(parser=sampler if isinstance(sampler, Parser) else None), - sampler=sampler, - n=n, - seed=seed, - duplicates=duplicates, - max_attempts=max_attempts, - ) - - def config(self) -> Config: - """Get the configuration for the pipeline. - - Returns: - The configuration for the pipeline - """ - config: dict[str, Any] = {} - for parents, _, step in self.walk(): - config.update( - **prefix_keys( - step.config, - prefix=":".join([p.name for p in parents] + [step.name]) + ":", - ) - if step.config is not None - else {}, - ) - return config - - def configure( - self, - config: Config, - *, - rename: bool | str = False, - prefixed_name: bool = False, - transform_context: Any | None = None, - params: Mapping[str, Any] | None = None, - clear_space: bool | Literal["auto"] = "auto", - ) -> Pipeline: - """Configure the pipeline with the given configuration. - - This takes a pipeline with spaces and choices and trims it down based on the - configuration. For example, choosing selected steps and setting the `config` - of steps with those present in the `config` object given to this function. - - If there are any modules attached to this pipeline, - these will also be configured for you. - - Args: - config: The configuration to use - rename: Whether to rename the pipeline. Defaults to `False`. - * If `True`, the pipeline will be renamed using a random uuid - * If a name is provided, the pipeline will be renamed to that name - * If `False`, the pipeline will not be renamed - prefixed_name: Whether the configuration is prefixed with the name of the - pipeline. Defaults to `False`. - transform_context: Any context to give to `config_transform=` of individual - steps. - params: The params to match any requests when configuring this step. - These will match against any ParamRequests in the config and will - be used to fill in any missing values. 
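The `sample`/`configure` round trip documented above, continuing the `pipeline` sketch from the previous snippet; this assumes a space can be inferred automatically (ConfigSpace or Optuna installed) and that `params=` only matters when the config contains `ParamRequest`s.

```python
config = pipeline.sample(seed=1)                               # one Config from pipeline.space()
configured = pipeline.configure(config, params={"seed": 42})   # fills any ParamRequest("seed")
print(configured.configured())                                 # True once spaces are cleared
```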
- clear_space: Whether to clear the search space after configuring. - If `"auto"` (default), then the search space will be cleared of any - keys that are in the config, if the search space is a `dict`. Otherwise, - `True` indicates that it will be removed in the returned step and - `False` indicates that it will remain as is. - - Returns: - A new pipeline with the configuration applied - """ - this_config: Config - if prefixed_name: - this_config = mapping_select(config, f"{self.name}:") - else: - this_config = config - - config = dict(config) - - new_head = self.head.configure( - this_config, - transform_context=transform_context, - params=params, - prefixed_name=True, - clear_space=clear_space, - ) - - new_modules = [ - module.configure( - this_config, - transform_context=transform_context, - params=params, - prefixed_name=True, - clear_space=clear_space, - ) - for module in self.modules.values() - ] - - if rename is True: - _rename = None - elif rename is False: - _rename = self.name - else: - _rename = rename - - return Pipeline.create(new_head, modules=new_modules, name=_rename) - - @overload - def build(self, builder: None = None, **builder_kwargs: Any) -> Any: - ... - - @overload - def build(self, builder: Callable[[Pipeline], B], **builder_kwargs: Any) -> B: - ... - - def build( - self, - builder: Callable[[Pipeline], B] | None = None, - **builder_kwargs: Any, - ) -> B | Any: - """Build the pipeline. - - Args: - builder: The builder to use. Default is `None`. - * If `None` is provided, the assembler will attempt to automatically - figure out build the pipeline as it can. - * If `builder` is a callable, we will attempt to use that. - **builder_kwargs: Any additional keyword arguments to pass to the builder. - - Returns: - The built pipeline - """ - from amltk.building import build # Prevent circular imports - - return build(self, builder=builder, **builder_kwargs) - - def copy(self, *, name: str | None = None) -> Pipeline: - """Copy the pipeline. - - Returns: - A copy of the pipeline - """ - return self.create(self, name=self.name if name is None else name) - - @classmethod - def create( - cls, - *steps: Step | Pipeline | Iterable[Step], - modules: Pipeline | Step | Iterable[Pipeline | Step] | None = None, - name: str | None = None, - meta: Mapping[str, Any] | None = None, - ) -> Pipeline: - """Create a pipeline from a sequence of steps. - - ??? warning "Using another pipeline `create()`" - - When using another pipeline as part of a substep, we handle - the parameters of subpiplines in the following ways: - - * `modules`: Any modules attached to a subpipeline will be copied - and attached to the new pipeline. If there is a naming conflict, - an error will be raised. - - * `meta`: Any metadata associated with subpiplines will be erased. - Please retrieve them an handle accordingly. - - Args: - *steps: The steps to create the pipeline from - name: The name of the pipeline. 
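`build()` above hands the whole `Pipeline` to whatever callable you pass as `builder=` (with `builder=None` the removed `amltk.building.build` tries to pick one itself); a toy builder, just to show the hook, continuing the sketch above:

```python
from amltk.pipeline import Pipeline  # assumed export, as above

def names_builder(p: Pipeline) -> list[str]:
    # A toy "builder": collect step names in traversal order instead of fitting anything.
    return [s.name for s in p.traverse()]

built = configured.build(builder=names_builder)
print(built)
```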
Defaults to a uuid - modules: The modules to use for the pipeline - meta: The meta information to attach to the pipeline - - Returns: - Pipeline - """ - # Expand out any pipelines in the init - expanded = [s.steps if isinstance(s, Pipeline) else s for s in steps] - step_sequence = list(Step.chain(*expanded)) - - if name is None: - name = str(uuid4()) - - if isinstance(modules, (Pipeline, Step)): - modules = [modules] - - # Collect all the modules, turning them into pipelines - # as required by the internal api - final_modules: dict[str, Pipeline | Step] = {} - if modules is not None: - final_modules = {module.name: module.copy() for module in modules} - - # If any of the steps are pipelines and contain modules, attach - # them to the final modules of this newly created pipeline - for step in steps: - if isinstance(step, Pipeline): - step_modules = { - module_name: module.copy() - for module_name, module in step.modules.items() - } - - # If one of the subpipelines has a duplicate module name - # then we need to raise an error - duplicates = step_modules.keys() & final_modules.keys() - if any(duplicates): - msg = ( - "Duplicate module(s) found during pipeline" - f" creation {duplicates=}." - ) - raise ValueError(msg) - - final_modules.update(step_modules) - - return cls( - name=name, - steps=step_sequence, - modules=final_modules, - meta=meta, - ) - - def _rich_iter( - self, - connect: TextType | None = None, # noqa: ARG002 - ) -> Iterator[RenderableType]: - """Used to make it more inline with steps.""" - yield self.__rich__() - - @override - def __rich__(self) -> RenderableType: - """Get the rich renderable for the pipeline.""" - from rich.console import Group as RichGroup - from rich.panel import Panel - from rich.pretty import Pretty - from rich.rule import Rule - from rich.table import Table - from rich.text import Text - - def _contents() -> Iterator[RenderableType]: - # Things for this pipeline - if self.meta is not None: - table = Table.grid(padding=(0, 1), expand=False) - table.add_row("meta", Pretty(self.meta)) - table.add_section() - yield table - - connecter = Text("↓", style="bold", justify="center") - # The main pipeline - yield from self.head._rich_iter(connect=connecter) - - if any(self.modules): - yield Rule(title="Modules", style=self.RICH_PANEL_BORDER_COLOR) - - # Any modules attached to this pipeline - for module in self.modules.values(): - yield from module._rich_iter(connect=connecter) - - clr = self.RICH_PANEL_BORDER_COLOR - title = Text.assemble( - (classname(self), f"{clr} bold"), - "(", - (self.name, f"{clr} italic"), - ")", - style="default", - end="", - ) - return Panel( - RichGroup(*_contents()), - title=title, - title_align="left", - expand=False, - border_style=clr, - ) diff --git a/src/amltk/pipeline/sampler.py b/src/amltk/pipeline/sampler.py deleted file mode 100644 index 76b4d119..00000000 --- a/src/amltk/pipeline/sampler.py +++ /dev/null @@ -1,420 +0,0 @@ -"""The base definition of a Sampler. - -It's primary role is to allow sampling from a particular Space. 
-""" -from __future__ import annotations - -import logging -from abc import ABC, abstractmethod -from typing import Any, Generic, Iterable, overload -from typing_extensions import override - -from more_itertools import first, first_true, seekable - -from amltk.exceptions import safe_map -from amltk.randomness import as_int, as_rng -from amltk.types import Config, Seed, Space - -logger = logging.getLogger(__name__) - - -class Sampler(ABC, Generic[Space]): - """A sampler to sample configs from a search space. - - This class is a sampler for a given Space type, providing functionality - to sample from the space. To implement a new sampler, subclass this class - and implement the following methods: - - !!! example "Abstract Methods" - - * [`supports_sampling`][amltk.pipeline.sampler.Sampler.supports_sampling]: - Check if the sampler supports sampling from a given Space. - * [`_sample`][amltk.pipeline.Sampler._sample]: Sample from the - given Space, given a specific seed and number of samples. Should - ideally be deterministic given a pair `(seed, n)`. - This is used in the [`sample`][amltk.pipeline.sampler.Sampler.sample] - method. - * [`copy`][amltk.pipeline.sampler.Sampler.copy]: Copy a Space to get - and identical space. - - Please see the documentation for these methods for more information. - - - See Also: - * [`SpaceAdapter`][amltk.pipeline.space.SpaceAdapter] - Together with implementing the [`Parser`][amltk.pipeline.Parser] - interface, this class provides a complete adapter for a given search space. - """ - - @classmethod - def default_samplers(cls, space: Any) -> list[Sampler]: - """Get the default samplers compatible with a given space. - - Args: - space: The space to sample from. - - Returns: - A list of samplers that can sample from the given space. - """ - samplers: list[Sampler] = [] - adapter: Sampler - try: - from amltk.configspace import ConfigSpaceAdapter - - adapter = ConfigSpaceAdapter() - if adapter.supports_sampling(space): - samplers.append(adapter) - - samplers.append(adapter) - except ImportError: - logger.debug("ConfigSpace not installed for sampling, skipping") - - try: - from amltk.optuna import OptunaSpaceAdapter - - adapter = OptunaSpaceAdapter() - if adapter.supports_sampling(space): - samplers.append(adapter) - - samplers.append(adapter) - except ImportError: - logger.debug("Optuna not installed for sampling, skipping") - - return samplers - - @classmethod - def find(cls, space: Any) -> Sampler | None: - """Find a sampler that supports the given space. - - Args: - space: The space to sample from. - - Returns: - The first sampler that supports the given space, or None if no - sampler supports the given space. - """ - return first( - ( - sampler - for sampler in Sampler.default_samplers(space) - if sampler.supports_sampling(space) - ), - default=None, - ) - - @overload - @classmethod - def try_sample( - cls, - space: Space, - sampler: type[Sampler[Space]] | Sampler[Space] | None = ..., - *, - n: None = None, - seed: Seed | None = ..., - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> Config: - ... - - @overload - @classmethod - def try_sample( - cls, - space: Space, - sampler: type[Sampler[Space]] | Sampler[Space] | None = ..., - *, - n: int, - seed: Seed | None = ..., - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> list[Config]: - ... 
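`try_sample` above can also be called directly on a space object, falling back over `default_samplers`. A hedged sketch assuming ConfigSpace is installed so that `ConfigSpaceAdapter` is discovered:

```python
from ConfigSpace import ConfigurationSpace  # assumed installed
from amltk.pipeline.sampler import Sampler

space = ConfigurationSpace({"lr": (1e-4, 1e-1), "layers": (1, 8)})
configs = Sampler.try_sample(space, n=3, seed=0)
print(configs)
```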
- - @classmethod - def try_sample( - cls, - space: Space, - sampler: type[Sampler[Space]] | Sampler[Space] | None = None, - *, - n: int | None = None, - seed: Seed | None = None, - duplicates: bool | Iterable[Config] = False, - max_attempts: int | None = 10, - ) -> Config | list[Config]: - """Attempt to sample a pipeline with the default samplers. - - Args: - space: The space to sample from. - sampler: The sampler to use. If None, the default samplers will be - used. - n: The number of samples to return. If None, a single sample will - be returned. - seed: The seed to use for sampling. - duplicates: If True, allow duplicate samples. If False, make - sure all samples are unique. If a Iterable, make sure all - samples are unique and not in the Iterable. - max_attempts: The number of times to attempt sampling unique - configurations before giving up. If `None` will keep - sampling forever until satisfied. - - Returns: - A single sample if `n` is None, otherwise a list of samples. - """ - if sampler is None: - samplers = cls.default_samplers(space) - elif isinstance(sampler, Sampler): - samplers = [sampler] - else: - samplers = [sampler()] - - if not any(samplers): - raise RuntimeError( - "Found no possible sampler to use. Have you tried installing any of:" - "\n* ConfigSpace" - "\n* Optuna" - "\nPlease see the integration documentation for more info, especially" - "\nif using an optimizer which often requires a specific search space." - "\nUsually just installing the optimizer will work.", - ) - - def _sample(_sampler: Sampler[Space]) -> Config | list[Config]: - _space = _sampler.copy(space) - return _sampler.sample( - space=_space, - n=n, - seed=seed, - duplicates=duplicates, - max_attempts=max_attempts, - ) - - # Wrap in seekable so we don't evaluate all of them, only as - # far as we need to get a succesful parse. - results_itr = seekable(safe_map(_sample, samplers)) - - is_result = lambda r: not (isinstance(r, tuple) and isinstance(r[0], Exception)) - - # Progress the iterator until we get a successful sample - samples = first_true(results_itr, default=False, pred=is_result) - - # If we didn't get a succesful parse, raise the appropriate error - if samples is False: - results_itr.seek(0) # Reset to start of iterator - sampler_strs = [str(s) for s in samplers] - errors = [(err, tb) for err, tb in results_itr] - raise Sampler.FailedSamplingError( - samplers=sampler_strs, - err_tbs=errors, # type: ignore - ) - - assert not isinstance(samples, (tuple, bool)) - return samples - - @overload - def sample( - self, - space: Space, - *, - seed: Seed | None = None, - n: None = None, - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> Config: - ... - - @overload - def sample( - self, - space: Space, - *, - seed: Seed | None = None, - n: int, - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> list[Config]: - ... - - def sample( - self, - space: Space, - *, - seed: Seed | None = None, - n: int | None = None, - duplicates: bool | Iterable[Config] = False, - max_attempts: int | None = 10, - ) -> Config | list[Config]: - """Sample a configuration from the given space. - - Args: - space: The space to sample from. - seed: The seed to use for sampling. - n: The number of configurations to sample. - duplicates: If True, allow duplicate samples. If False, make - sure all samples are unique. If a Iterable, make sure all - samples are unique and not in the Iterable. 
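`find` and `sample` above combine to give duplicate-free sampling against configs you have already seen; a sketch reusing the ConfigSpace `space` from the previous snippet:

```python
adapter = Sampler.find(space)  # first installed adapter that supports this space, or None
if adapter is not None:
    seen = adapter.sample(space, n=2, seed=1)
    fresh = adapter.sample(space, n=3, seed=2, duplicates=seen, max_attempts=50)
    assert all(c not in seen for c in fresh)
```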
- max_attempts: The number of times to attempt sampling unique - configurations before giving up. If `None` will keep - sampling forever until satisfied. - - Returns: - A single sample if `n` is None, otherwise a list of samples. - If `duplicates` is not True and we fail to sample. - """ - _n = 1 if n is None else n - rng = as_rng(seed) - - if duplicates is True: - samples = self._sample(space=space, n=_n, seed=rng) - return samples[0] if n is None else samples - - # NOTE: We use a list here as Config's could be a dict - # which are not hashable. We rely on equality checks - seen = list(duplicates) if isinstance(duplicates, Iterable) else [] - - samples = [] - _max_attempts: int = max_attempts if max_attempts is not None else 2**32 - rng = as_rng(seed) - for _ in range(_max_attempts): - next_seed = as_int(rng) - _samples = self._sample(space=space, n=_n, seed=next_seed) - - for s in _samples: - if s in seen: - continue - samples.append(s) - seen.append(s) - - if len(samples) >= _n: - break - - if len(samples) != _n: - raise Sampler.GenerateUniqueConfigError( - n=_n, - max_attempts=_max_attempts, - seen=seen, - ) - - return samples[0] if n is None else samples[:n] - - @classmethod - @abstractmethod - def supports_sampling(cls, space: Any) -> bool: - """Check if the space is supported for sampling. - - Args: - space: The space to check. - - Returns: - True if the space is supported, False otherwise. - """ - ... - - @abstractmethod - def copy(self, space: Space) -> Space: - """Copy the space. - - Args: - space: The space to copy. - - Returns: - A copy of the space. - """ - ... - - @abstractmethod - def _sample( - self, - space: Space, - n: int = 1, - seed: Seed | None = None, - ) -> list[Config]: - """Sample a configuration from the given space. - - Args: - space: The space to sample from. - n: The number of configurations to sample. - seed: The seed to use for sampling. - - Returns: - A list of samples. - """ - ... - - class NoSamplerFoundError(ValueError): - """Error when no sampler is found for a given space.""" - - def __init__(self, space: Any, extra: str | None = None): - """Create a new no sampler found error. - - Args: - space: The space that no sampler was found for - extra: Any extra information to add to the error message. - """ - self.space = space - self.extra = extra - super().__init__(space, extra) - - @override - def __str__(self) -> str: - msg = ( - f"No sampler found for space of type={type(self.space)}." - " Do you have the correct integrations installed? If none" - " exist for your space, you can create your own sampler." - ) - if self.extra: - msg += f" {self.extra}" - - return msg - - class GenerateUniqueConfigError(RuntimeError): - """Error when a Sampler fails to sample a unique configuration.""" - - def __init__(self, n: int, max_attempts: int, seen: list[Config]): - """Create a new sample unique config error. - - Args: - n: The number of unique configs that were to sample. - max_attempts: The maximum number of attempts made to sample - `n` unique configurations. - seen: The configs seen during sampling. - """ - self.n = n - self.max_attempts = max_attempts - self.seen = seen - super().__init__(n, max_attempts, seen) - - @override - def __str__(self) -> str: - n = self.n - max_attempts = self.max_attempts - seen = self.seen - return ( - f"Could not find {n=} unique configs after {max_attempts=} attempts." - "\nYou could try increasing the `max_attempts=` parameter." 
- f"\n {len(seen)} seen: {seen=}" - ) - - class FailedSamplingError(RuntimeError): - """Error for when a Sampler fails to sample from a Space.""" - - def __init__(self, samplers: list[str], err_tbs: list[tuple[Exception, str]]): - """Create a new sampler error. - - Args: - samplers: The sampler(s) that failed. - err_tbs: The errors and tracebacks for each sampler - """ - self.samplers = samplers - self.err_tbs = err_tbs - super().__init__(samplers, err_tbs) - - @override - def __str__(self) -> str: - return "\n".join( - [ - "Could not sample with any of the samplers:", - *[ - f" - {sampler}: {err}\n{tb}" - for sampler, (err, tb) in zip(self.samplers, self.err_tbs) - ], - ], - ) diff --git a/src/amltk/pipeline/space.py b/src/amltk/pipeline/space.py deleted file mode 100644 index 89726207..00000000 --- a/src/amltk/pipeline/space.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Space adapter.""" -from __future__ import annotations - -import logging -from typing import Generic, TypeVar - -from amltk.pipeline.parser import Parser -from amltk.pipeline.sampler import Sampler - -logger = logging.getLogger(__name__) - -InputT = TypeVar("InputT") -OutputT = TypeVar("OutputT") - - -class SpaceAdapter(Parser[InputT, OutputT], Sampler[OutputT], Generic[InputT, OutputT]): - """Space adapter. - - This interfaces combines the utility to parse and sample from a given - type of Space. - It is a combination of the [`Parser`][amltk.pipeline.parser.Parser] and - [`Sampler`][amltk.pipeline.sampler.Sampler] interfaces, such that - we can perform operations on a Space without knowing its type. - - To implement a new SpaceAdapter, you must implement the methods - described in the [`Parser`][amltk.pipeline.parser.Parser] and - [`Sampler`][amltk.pipeline.sampler.Sampler] interfaces. - - !!! example "Example Adapaters" - - We have integrated adapters for the following libraries which - you can use as full reference guide. - - * [`OptunaSpaceAdapter`][amltk.optuna.space.OptunaSpaceAdapter] for - [Optuna](https://optuna.org/) - * [`ConfigSpaceAdapter`][amltk.configspace.space.ConfigSpaceAdapter] - for [ConfigSpace](https://automl.github.io/ConfigSpace/master/) - - """ - - @classmethod - def default_adapters(cls) -> list[SpaceAdapter]: - """Get the default adapters. - - Returns: - A list of default adapters. - """ - adapters: list[SpaceAdapter] = [] - try: - from amltk.optuna.space import OptunaSpaceAdapter - - adapters.append(OptunaSpaceAdapter()) - except ImportError: - logger.debug("Optuna not installed, skipping adapter") - - try: - from amltk.configspace.space import ConfigSpaceAdapter - - adapters.append(ConfigSpaceAdapter()) - except ImportError: - logger.debug("ConfigSpace not installed, skipping adapter") - - return adapters diff --git a/src/amltk/pipeline/step.py b/src/amltk/pipeline/step.py deleted file mode 100644 index 983bbd6b..00000000 --- a/src/amltk/pipeline/step.py +++ /dev/null @@ -1,1049 +0,0 @@ -"""The core step class for the pipeline. - -These objects act as a doubly linked list to connect steps into a chain which -are then convenientyl wrapped in a `Pipeline` object. Their concrete implementations -can be found in the `amltk.pipeline.components` module. 
-""" -from __future__ import annotations - -from copy import deepcopy -from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Generic, - Iterable, - Iterator, - Literal, - Mapping, - Sequence, - TypeVar, - overload, -) -from typing_extensions import override - -from attrs import evolve, field, frozen -from more_itertools import consume, first_true, last, peekable, triplewise -from rich.text import Text - -from amltk.functional import classname, mapping_select, prefix_keys -from amltk.richutil import RichRenderable -from amltk.types import Config, FidT, Seed, Space - -if TYPE_CHECKING: - from typing_extensions import Self - - from rich.console import RenderableType - from rich.panel import Panel - from rich.text import TextType - - from amltk.pipeline.components import Group - from amltk.pipeline.parser import Parser - from amltk.pipeline.pipeline import Pipeline - from amltk.pipeline.sampler import Sampler - -T = TypeVar("T") -ParserOutput = TypeVar("ParserOutput") - -_NotSet = object() - - -class ParamRequest(Generic[T]): - """A parameter request for a step. This is most useful for things like seeds.""" - - def __init__( - self, - key: str, - *, - default: T = _NotSet, # type: ignore - required: bool = False, - ) -> None: - """Create a new parameter request. - - Args: - key: The key to request under. - default: The default value to use if the key is not found. - If left as `_NotSet` (default) then the key will be removed from the - config once [`configure`][amltk.pipeline.Step.configure] is called and - nothing has been provided. - - required: Whether the key is required to be present. - """ - super().__init__() - self.key = key - self.default = default - self.required = required - self.has_default = default is not _NotSet - - @override - def __repr__(self) -> str: - default = self.default if self.default is not _NotSet else "_NotSet" - required = self.required - return f"ParamRequest({self.key}, {default=}, {required=})" - - class RequestNotMetError(ValueError): - """Raised when a request is not met.""" - - -@frozen(kw_only=True) -class Step(RichRenderable, Generic[Space]): - """The core step class for the pipeline. - - These are simple objects that are named and linked together to form - a chain. They are then wrapped in a `Pipeline` object to provide - a convenient interface for interacting with the chain. - - See Also: - For creating the concrete implementations of this class, you can use these - convenience methods. - - * [`step()`][amltk.pipeline.api.step] - * [`choice()`][amltk.pipeline.api.choice] - * [`group()`][amltk.pipeline.api.group] - * [`split()`][amltk.pipeline.api.split] - * [`searchable()`][amltk.pipeline.api.searchable] - """ - - name: str - """Name of the step""" - - prv: Step | None = field(default=None, eq=False, repr=False) - """The previous step in the chain""" - - nxt: Step | None = field(default=None, eq=False, repr=False) - """The next step in the chain""" - - parent: Step | None = field(default=None, eq=False, repr=False) - """Any [`Group`][amltk.pipeline.components.Group] or - [`Choice`][amltk.pipeline.components.Choice] that this step is a part of - and is the head of the chain. 
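A sketch of the `ParamRequest` mechanism above, deferring a value such as a seed until `configure(..., params=...)` supplies it; the `step` helper and the import path follow the removed module layout:

```python
from amltk.pipeline import step
from amltk.pipeline.step import ParamRequest  # path per the removed module

component = step(
    "model",
    object(),
    config={"random_state": ParamRequest("seed", default=None)},
)
configured = component.configure({}, params={"seed": 42})
print(configured.config)  # {'random_state': 42}
```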
- """ - - config: Mapping[str, Any] | None = field(default=None, hash=False) - """The configuration for this step""" - - search_space: Space | None = field(default=None, hash=False, repr=False) - """The search space for this step""" - - fidelity_space: Mapping[str, FidT] | None = field( - default=None, - hash=False, - repr=False, - ) - """The fidelities for this step""" - - config_transform: ( - Callable[ - [Mapping[str, Any], Any], - Mapping[str, Any], - ] - | None - ) = field(default=None, hash=False, repr=False) - """A function that transforms the configuration of this step""" - - meta: Mapping[str, Any] | None = None - """Any meta information about this step""" - - old_parent: str | None = field(default=None, hash=False, repr=False, eq=False) - """The original parent of this step, used in case of being chosen from a `Choice`""" - - RICH_PANEL_BORDER_COLOR: ClassVar[str] = "default" - - def __or__(self, nxt: Any) -> Step: - """Append a step on this one, return the head of a new chain of steps. - - Args: - nxt: The next step in the chain - - Returns: - Step: The head of the new chain of steps - """ - if not isinstance(nxt, Step): - return NotImplemented - - return self.append(nxt) - - def append(self, nxt: Step) -> Step: - """Append a step on this one, return the head of a new chain of steps. - - Args: - nxt: The next step in the chain - - Returns: - Step: The head of the new chain of steps - """ - return Step.join(self, nxt) - - def extend(self, nxt: Iterable[Step]) -> Step: - """Extend many steps on to this one, return the head of a new chain of steps. - - Args: - nxt: The next steps in the chain - - Returns: - Step: The head of the new chain of steps - """ - return Step.join(self, nxt) - - def iter( - self, - *, - backwards: bool = False, - include_self: bool = True, - to: str | Step | None = None, - ) -> Iterator[Step]: - """Iterate the linked-list of steps. - - Args: - backwards: Traversal order. Defaults to False - include_self: Whether to include self in iterator. Default True - to: Stop iteration at this step. Defaults to None - - Yields: - Step[Key]: The steps in the chain - """ - # Break out if current step is `to - if to is not None: - if isinstance(to, Step): - to = to.name - if self.name == to: - return - - if include_self: - yield self - - if backwards: - if self.prv is not None: - yield from self.prv.iter(backwards=True, to=to) - elif self.nxt is not None: - yield from self.nxt.iter(backwards=False, to=to) - - def qualified_name(self, *, delimiter: str = ":") -> str: - """Get the qualified name of this step. - - This is the name of the step prefixed by the names of all the previous - groups taken to reach this step in the chain. - - Args: - delimiter: The delimiter to use between names. Defaults to ":" - - Returns: - The qualified name - """ - from amltk.pipeline.components import Group - - groups = [ - s.parent - for s in self.climb(include_self=True) - if s.parent is not None and isinstance(s.parent, Group) - ] - names = [*reversed([group.name for group in groups]), self.name] - return delimiter.join(names) - - def configured(self) -> bool: - """Check if this searchable is configured.""" - return self.search_space is None and self.config is not None - - def configure( # noqa: C901, PLR0912 - self, - config: Config, - *, - prefixed_name: bool | None = None, - transform_context: Any = None, - params: Mapping[str, Any] | None = None, - clear_space: bool | Literal["auto"] = "auto", - ) -> Step: - """Configure this step and anything following it with the given config. 
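The `__or__`/`append`/`iter` plumbing above is what gives the pipe syntax for chains; a minimal sketch:

```python
from amltk.pipeline import step  # assumed export, as in the path_to example

head = step("scale", 0) | step("select", 1) | step("model", 2)
print([s.name for s in head.iter()])  # ['scale', 'select', 'model']
print(head.tail().name)               # 'model'
```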
- - Args: - config: The configuration to apply - prefixed_name: Whether items in the config are prefixed by the names - of the steps. - * If `None`, the default, then `prefixed_name` will be assumed to - be `True` if this step has a next step or if the config has - keys that begin with this steps name. - * If `True`, then the config will be searched for items prefixed - by the name of the step (and subsequent chained steps). - * If `False`, then the config will be searched for items without - the prefix, i.e. the config keys are exactly those matching - this steps search space. - transform_context: Any context to give to `config_transform=` of individual - steps. This will apply once the config has been fully built. - params: The params to match any requests when configuring this step. - These will match against any ParamRequests in the config and will - be used to fill in any missing values. - clear_space: Whether to clear the search space after configuring. - If `"auto"` (default), then the search space will be cleared of any - keys that are in the config, if the search space is a `dict`. Otherwise, - `True` indicates that it will be removed in the returned step and - `False` indicates that it will remain as is. - - Returns: - Step: The configured step - """ - if prefixed_name is None: - if any(key.startswith(self.name) for key in config): - prefixed_name = True - else: - prefixed_name = self.nxt is not None - - nxt = ( - self.nxt.configure( - config, - prefixed_name=prefixed_name, - transform_context=transform_context, - params=params, - clear_space=clear_space, - ) - if self.nxt - else None - ) - - this_config: dict[str, Any] - if prefixed_name: - this_config = mapping_select(config, f"{self.name}:") - else: - this_config = dict(deepcopy(config)) - - if self.config is not None: - this_config = {**self.config, **this_config} - - _params = params or {} - reqs = [(k, v) for k, v in this_config.items() if isinstance(v, ParamRequest)] - for k, request in reqs: - if request.key in _params: - this_config[k] = _params[request.key] - elif request.has_default: - this_config[k] = request.default - elif request.required: - raise ParamRequest.RequestNotMetError( - f"Missing required parameter {request.key} for step {self.name}" - " and no default was provided." - f"\nThe request given was: {request}", - f"Please use the `params=` argument to provide a value for this" - f" request. What was given was `{params=}`", - ) - - # If we have a `dict` for a space, then we can remove any configured keys that - # overlap it. - _space: Any - if clear_space == "auto": - if isinstance(self.search_space, dict) and any(self.search_space): - _lap = set(this_config).intersection(self.search_space) - _space = {k: v for k, v in self.search_space.items() if k not in _lap} - if len(_space) == 0: - _space = None - else: - _space = self.search_space - - elif clear_space is True: - _space = None - else: - _space = self.search_space - - if self.config_transform is not None: - this_config = dict(self.config_transform(this_config, transform_context)) - - new_self = self.mutate( - config=this_config if this_config else None, - search_space=_space, - nxt=nxt, - ) - - if nxt is not None: - # HACK: This is a hack to to modify the fact `nxt` is a frozen - # object. Frozen objects do not allow setting attributes after - # instantiation. - object.__setattr__(nxt, "prv", new_self) - - return new_self - - def apply(self, func: Callable[[Step], Step]) -> Step: - """Apply a function to this step and all following steps. 
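`apply` above copies each step, applies the function and re-links the chain; together with `mutate` (defined further down) it can rewrite an attribute across every step. Continuing the `head` chain from the previous snippet:

```python
tagged = head.apply(lambda s: s.mutate(meta={"stage": s.name}))
print([s.meta for s in tagged.iter()])  # [{'stage': 'scale'}, {'stage': 'select'}, {'stage': 'model'}]
```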
- - Args: - func: The function to apply - - Returns: - Step: The step with the function applied - """ - new_nxt = self.nxt.apply(func) if self.nxt is not None else None - - # NOTE: We can't be sure that the function will return a new instance of - # `self` so we have to make a copy of `self` and then apply the function - # to that copy. - new_self = func(self.copy()) - - if new_nxt is not None: - # HACK: Frozen objects do not allow setting attributes after - # instantiation. Join the two steps together. - object.__setattr__(new_self, "nxt", new_nxt) - object.__setattr__(new_nxt, "prv", new_self) - - return new_self - - def head(self) -> Step: - """Get the first step of this chain.""" - return last(self.iter(backwards=True)) - - def tail(self) -> Step: - """Get the last step of this chain.""" - return last(self.iter()) - - def root(self) -> Step: - """Climb to the first step of this chain.""" - return last(self.climb()) - - def climb(self, *, include_self: bool = True) -> Iterator[Step]: - """Iterate the steps required to reach the root.""" - if include_self: - yield self - - if self.prv is not None: - yield from self.prv.climb() - elif self.parent is not None: - yield from self.parent.climb() - - def proceeding(self) -> Iterator[Step]: - """Iterate the steps that follow this one.""" - return self.iter(include_self=False) - - def preceeding(self) -> Iterator[Step]: - """Iterate the steps that preceed this one.""" - head = self.head() - if self != head: - yield from head.iter(to=self) - - def mutate(self, **kwargs: Any) -> Self: - """Mutate this step with the given kwargs, creating a copy. - - !!! warning "Warning" - - Will remove any existing `nxt` or `prv` to prevent `nxt` and `prv` - pointing to the old step while the new version of this step points to - those old `nxt` and `prv` steps. - - ``` - Before: - - ---[prv]--[self, x=4]--[nxt]--- - - After Mutation: - - ----------[self, x=5]---------- - ``` - - To overwrite this behaviour, please explicitly pass `prv=` and `nxt=`. - - Args: - **kwargs: The attributes to mutate - - Returns: - Self: The mutated step - """ - # NOTE: To prevent the confusion that this instance of `step` would link to - # `prv` and `nxt` while the steps `prv` and `nxt` would not link to this - # *new* mutated step, we explicitly remove the "prv" and "nxt" attributes - # This is unlikely to be very useful for the base Step class other than - # to rename it. - # However this can overwritten by passing "nxt" or "prv" explicitly. - return evolve(self, **{"prv": None, "nxt": None, **kwargs}) # type: ignore - - def copy(self: Self) -> Self: - """Copy this step. - - Returns: - Self: The copied step - """ - return deepcopy(self) # type: ignore - - def remove(self, keys: Sequence[str]) -> Iterator[Step]: - """Remove the given steps from this chain. 
- - Args: - keys: The name of the steps to remove - - Yields: - Step[Key]: The steps in the chain unless it was one to remove - """ - if self.name not in keys: - yield self - - if self.nxt is not None: - yield from self.nxt.remove(keys) - - def walk( - self, - groups: Sequence[Group] | None = None, - parents: Sequence[Step] | None = None, - ) -> Iterator[tuple[list[Group], list[Step], Step]]: - """See `Step.walk`.""" - groups = list(groups) if groups is not None else [] - parents = list(parents) if parents is not None else [] - yield groups, parents, self - - if self.nxt is not None: - yield from self.nxt.walk(groups=groups, parents=[*parents, self]) - - def traverse( - self, - *, - include_self: bool = True, - backwards: bool = False, - ) -> Iterator[Step]: - """Traverse any sub-steps associated with this step. - - Subclasses should overwrite as required - - Args: - include_self: Whether to include this step. Defaults to True - backwards: Whether to traverse backwards. This will - climb linearly until it reaches some head. - - Returns: - Iterator[Step[Key]]: The iterator over steps - """ - if include_self: - yield self - - if backwards: - if self.prv is not None: - yield from self.prv.traverse(backwards=True) - elif self.parent is not None: - yield from self.parent.traverse(backwards=True) - - if self.nxt is not None: - yield from self.nxt.traverse() - - @overload - def find( - self, - key: str | Callable[[Step], bool], - default: T, - *, - deep: bool = ..., - ) -> Step | T: - ... - - @overload - def find( - self, - key: str | Callable[[Step], bool], - *, - deep: bool = ..., - ) -> Step | None: - ... - - def find( - self, - key: str | Callable[[Step], bool], - default: T | None = None, - *, - deep: bool = True, - ) -> Step | T | None: - """Find a step in that's nested deeper from this step. - - Args: - key: The key to search for or a function that returns True if the step - is the desired step - default: - The value to return if the step is not found. Defaults to None - deep: - Whether to search the entire pipeline or just the top layer. - - Returns: - The step if found, otherwise the default value. Defaults to None - """ - pred: Callable[[Step], bool] - pred = key if callable(key) else (lambda step: step.name == key) - if deep: - all_steps = chain(self.traverse(), self.traverse(backwards=True)) - return first_true(all_steps, default, pred) # type: ignore - - all_steps = chain(self.iter(), self.iter(backwards=True)) - return first_true(all_steps, default, pred) # type: ignore - - def path_to( - self, - key: str | Step | Callable[[Step], bool], - *, - direction: Literal["forward", "backward"] | None = None, - ) -> list[Step] | None: - """Get the path to the given step. - - This includes the path to the step itself. - - ```python exec="true" source="material-block" result="python" title="path_to" - from amltk.pipeline import step, split - - head = ( - step("head", 42) - | step("middle", 0) - | split( - "split", - step("left", 0), - step("right", 0), - ) - | step("tail", 0) - ) - - path = head.path_to("left") - print([s.name for s in path]) - - left = head.find("left") - path = left.path_to("head") - print([s.name for s in path]) - ``` - - Args: - key: The step to find - direction: Specify a particular direction to search in. Defaults to None - which means search both directions, starting with forwards. 
- - Returns: - Iterator[Step[Key]]: The path to the step - """ - if isinstance(key, Step): - pred = lambda step: step == key - elif isinstance(key, str): - pred = lambda step: step.name == key - else: - pred = key - - # We found our target, just return now - if pred(self): - return [self] - - if direction in (None, "forward") and self.nxt is not None: # noqa: SIM102 - if path := self.nxt.path_to(pred, direction="forward"): - return [self, *path] - - if direction in (None, "backward"): - back = self.prv or self.parent - if back and (path := back.path_to(pred, direction="backward")): - return [self, *path] - - return None - - return None - - def replace(self, replacements: Mapping[str, Step]) -> Iterator[Step]: - """Replace the given step with a new one. - - Args: - replacements: The steps to replace - - Yields: - step: The steps in the chain, replaced if in replacements - """ - yield replacements.get(self.name, self) - - if self.nxt is not None: - yield from self.nxt.replace(replacements=replacements) - - def select(self, choices: Mapping[str, str]) -> Iterator[Step]: - """Replace the current step with the chosen step if it's a choice. - - Args: - choices: Mapping of choice names to the path to pick - - Yields: - Step[Key]: The unmodified step if not a choice, else the chosen choice - if applicable - """ - yield self - - if self.nxt is not None: - yield from self.nxt.select(choices) - - @overload - def space( - self, - *, - parser: type[Parser[Space, ParserOutput]] | Parser[Space, ParserOutput], - seed: Seed | None = ..., - ) -> ParserOutput: - ... - - @overload - def space(self, *, seed: Seed | None = ...) -> Any: - ... - - def space( - self, - parser: ( - type[Parser[Space, ParserOutput]] | Parser[Space, ParserOutput] | None - ) = None, - *, - seed: Seed | None = None, - ) -> ParserOutput | Any: - """Get the search space for this step.""" - from amltk.pipeline.parser import Parser - - return Parser.try_parse(self, parser=parser, seed=seed) - - def fidelities(self) -> dict[str, FidT]: - """Get the fidelities for this step.""" - fids = prefix_keys(self.fidelity_space or {}, f"{self.name}:") - - if self.nxt is None: - return fids - - nxt_fids = self.nxt.fidelities() - return {**fids, **nxt_fids} - - def linearized_fidelity(self, value: float) -> dict[str, int | float | Any]: - """Get the linearized fidelity for this step. - - Args: - value: The value to linearize. Must be between [0, 1] - - Return: - dictionary from key to it's linearized fidelity. - """ - assert 1.0 <= value <= 100.0, f"{value=} not in [1.0, 100.0]" # noqa: PLR2004 - d = {} - if self.fidelity_space is not None: - for f_name, f_range in self.fidelity_space.items(): - low, high = f_range - fid = low + (high - low) * (value - 1) / 100 - fid = fid if isinstance(low, float) else round(fid) - d[f_name] = fid - - d = prefix_keys(d, f"{self.name}:") - - if self.nxt is None: - return d - - nxt_fids = self.nxt.linearized_fidelity(value) - return {**d, **nxt_fids} - - def build(self) -> Any: - """Build the step. - - Subclasses should overwrite as required, by default for a Step, - this will raise an Error - - Raises: - NotImplementedError: If not overwritten - """ - raise NotImplementedError(f"`build()` is not implemented for {type(self)}") - - @overload - def sample( - self, - *, - n: None = None, - space: Space | None = ..., - sampler: type[Sampler[Space]] | Sampler[Space] | None = ..., - seed: Seed | None = ..., - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> Config: - ... 
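`linearized_fidelity` above maps one value in `[1, 100]` linearly onto every declared `(low, high)` fidelity range, rounding whenever the lower bound is an integer; the arithmetic in isolation:

```python
low, high, value = 8, 512, 50.0
fid = low + (high - low) * (value - 1) / 100
fid = fid if isinstance(low, float) else round(fid)
print(fid)  # 255
```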
- - @overload - def sample( - self, - *, - n: int, - space: Space | None = ..., - sampler: type[Sampler[Space]] | Sampler[Space] | None = ..., - seed: Seed | None = ..., - duplicates: bool | Iterable[Config] = ..., - max_attempts: int | None = ..., - ) -> list[Config]: - ... - - def sample( - self, - *, - n: int | None = None, - space: Space | None = None, - sampler: type[Sampler[Space]] | Sampler[Space] | None = None, - seed: Seed | None = None, - duplicates: bool | Iterable[Config] = False, - max_attempts: int | None = 10, - ) -> Config | list[Config]: - """Sample a configuration from the space of the pipeline. - - Args: - space: The space to sample from. Will default to it's own space if - not provided. - n: The number of configurations to sample. If `None`, a single - configuration will be sampled. If `n` is greater than 1, a list of - configurations will be returned. - sampler: The sampler to use. If `None`, a sampler will be automatically - chosen based on the type of space that is provided. - seed: The seed to seed the space with if applicable. - duplicates: If True, allow duplicate samples. If False, make - sure all samples are unique. If a Iterable, make sure all - samples are unique and not in the Iterable. - max_attempts: The number of times to attempt sampling unique - configurations before giving up. If `None` will keep - sampling forever until satisfied. - - Returns: - A configuration sampled from the space of the pipeline - """ - from amltk.pipeline.parser import Parser - from amltk.pipeline.sampler import Sampler - - if space is None: - # Make sure if the sampler is also a space parser and we recieved - # no space, that we use this for parsing the space - if ( - isinstance(sampler, type) and issubclass(sampler, Parser) - ) or isinstance(sampler, Parser): - space = self.space(parser=sampler) - else: - space = self.space() - else: - space = space - - return Sampler.try_sample( - space, - sampler=sampler, - n=n, # type: ignore - seed=seed, - duplicates=duplicates, - max_attempts=max_attempts, - ) - - @classmethod - def join(cls, *steps: Step | Iterable[Step]) -> Step: - """Join together a collection of steps, returning the head. - - This is essentially a shortform of Step.chain(*steps) that returns - the head of the chain. See `Step.chain` for more description. - - Args: - *steps : Any amount of steps or iterables of steps - - Returns: - Step[Key] - The head of the chain of steps - """ - itr = cls.chain(*steps) - head = next(itr, None) - if head is None: - raise ValueError(f"Recieved no values for {steps=}") - - consume(itr) - return head - - def as_pipeline( - self, - modules: Pipeline | Step | Iterable[Pipeline | Step] | None = None, - name: str | None = None, - meta: Mapping[str, Any] | None = None, - ) -> Pipeline: - """Wrap this step in a pipeline. - - See Also: - * [`Pipeline.create()`][amltk.pipeline.pipeline.Pipeline.create] - - Args: - name: The name of the pipeline. Defaults to a uuid - modules: The modules to use for the pipeline - meta: The meta information to attach to the pipeline - - Returns: - Pipeline: The pipeline - """ - from amltk.pipeline.pipeline import Pipeline - - return Pipeline.create(self, name=name, modules=modules, meta=meta) - - @classmethod - def chain( - cls, - *steps: Step | Iterable[Step], - expand: bool = True, - ) -> Iterator[Step]: - """Chain together a collection of steps into an iterable. - - Args: - *steps : Any amount of steps or iterable of steps. 
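`join` and `as_pipeline` above are the functional route to the same chains, accepting both steps and iterables of steps; a sketch, with the `Step` import path following the removed module layout:

```python
from amltk.pipeline import step
from amltk.pipeline.step import Step  # path per the removed module

head = Step.join(step("a", 1), [step("b", 2), step("c", 3)])
pipeline = head.as_pipeline(name="abc")
print(pipeline.name, [s.name for s in pipeline.iter()])  # abc ['a', 'b', 'c']
```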
- expand: Individual steps will be expanded with `step.iter()` while - Iterables will remain as is, defaults to True - - Returns: - An iterator over the steps joined together - """ - expanded = chain.from_iterable( - (s.iter() if expand else [s]) if isinstance(s, Step) else s for s in steps - ) - - # We use a `peekable` to check if there's actually anything to chain - # In the off case we got nothing in `*steps` but empty iterables - new_steps = peekable(s.copy() for s in expanded) - if not new_steps: - return - - # Used to check if we have a duplicate name, - # if so get that step and raise an error - seen_steps: dict[str, Step] = {} - - # As these Steps are frozen, we break the frozen api to build a doubly linked - # list of steps. - # ? Is it possible to build a doubly linked list where each node is immutable? - itr = chain([None], new_steps, [None]) - for prv, cur, nxt in triplewise(itr): - assert cur is not None - - if cur.name in seen_steps: - duplicates = (cur, seen_steps[cur.name]) - raise Step.DuplicateNameError(duplicates) - - seen_steps[cur.name] = cur - - object.__setattr__(cur, "prv", prv) - object.__setattr__(cur, "nxt", nxt) - yield cur - - def _rich_iter(self, connect: TextType | None = None) -> Iterator[RenderableType]: - """Iterate the panels for rich printing.""" - yield self.__rich__() - if self.nxt is not None: - if connect is not None: - yield connect - yield from self.nxt._rich_iter(connect=connect) - - def _rich_table_items(self) -> Iterator[tuple[RenderableType, ...]]: - """Get the items to add to the rich table.""" - from rich.pretty import Pretty - from rich.text import Text - - from amltk.richutil import Function, richify - - if self.config is not None: - _config = {k: richify(v) for k, v in self.config.items()} - yield Text("config"), Pretty(_config) - - if self.search_space is not None: - yield "space", richify(self.search_space, otherwise=Pretty) - - if self.fidelity_space is not None: - yield "fidelities", richify(self.fidelity_space, otherwise=Pretty) - - if self.config_transform is not None: - yield "transform", Function(self.config_transform) - - if self.meta is not None: - yield "meta", Pretty(self.meta) - - def _rich_panel_contents(self) -> Iterator[RenderableType]: - from rich.table import Table - - table = Table.grid(padding=(0, 1), expand=False) - for tup in self._rich_table_items(): - table.add_row(*tup) - table.add_section() - - yield table - - def display( - self, - *, - full: bool = False, - connect: TextType | None = None, - ) -> RenderableType: - """Display this step. - - Args: - full: Whether to display the full step or just a summary - connect: The text to connect the steps together. 
Defaults to None - """ - if not full: - return self.__rich__() - - from rich.console import Group as RichGroup - - return RichGroup(*self._rich_iter(connect=connect)) - - @override - def __rich__(self) -> Panel: - from rich.console import Group as RichGroup - from rich.panel import Panel - - clr = self.RICH_PANEL_BORDER_COLOR - title = Text.assemble( - (classname(self), f"{clr} bold"), - "(", - (self.name, f"{clr} italic"), - ")", - style="default", - end="", - ) - contents = list(self._rich_panel_contents()) - _content = contents[0] if len(contents) == 1 else RichGroup(*contents) - return Panel( - _content, - title=title, - title_align="left", - border_style=clr, - expand=False, - ) - - class DelimiterInNameError(ValueError): - """Raise when a delimiter is found in a name.""" - - def __init__(self, step: Step, delimiter: str = ":"): - """Initialize the exception. - - Args: - step: The step that contains the delimiter - delimiter: The delimiter that was found - """ - super().__init__(step, delimiter) - self.step = step - self.delimiter = delimiter - - @override - def __str__(self) -> str: - delimiter = self.delimiter - return f"Delimiter ({delimiter=}) in name: {self.step.name} for {self.step}" - - class DuplicateNameError(ValueError): - """Raise when a duplicate name is found.""" - - def __init__(self, steps: tuple[Step, Step]): - """Initialize the exception. - - Args: - steps: The steps that have the same name - """ - super().__init__(steps) - self.steps = steps - - @override - def __str__(self) -> str: - s1, s2 = self.steps - return f"Duplicate names ({s1.name}) for\n\n{s1}\n\n{s2}" - - class ConfigurationError(ValueError): - """Raise when a configuration is invalid.""" - - def __init__(self, step: Step, config: Config, reason: str): - """Initialize the exception. 
- - Args: - step: The step that has the invalid configuration - config: The invalid configuration - reason: The reason the configuration is invalid - """ - super().__init__() - self.step = step - self.config = config - self.reason = reason - - @override - def __str__(self) -> str: - return ( - f"Invalid configuration: {self.reason}" - f" - Given by: {self.step}" - f" - With config: {self.config}" - ) diff --git a/src/amltk/pipeline/xgboost.py b/src/amltk/pipeline/xgboost.py index af795518..cf126ac4 100644 --- a/src/amltk/pipeline/xgboost.py +++ b/src/amltk/pipeline/xgboost.py @@ -13,14 +13,11 @@ import warnings from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal from xgboost import XGBClassifier, XGBRegressor -from amltk.pipeline.api import step - -if TYPE_CHECKING: - from amltk.pipeline.components import Component +from amltk.pipeline import Component def xgboost_component( @@ -94,12 +91,7 @@ def xgboost_component( f"Space and kwargs overlap: {overlap}, please remove one of them", ) - return step( - name=name, - item=estimator_type, - config=config, - space=space, - ) + return Component(name=name, item=estimator_type, config=config, space=space) def xgboost_large_space( diff --git a/src/amltk/profiling/memory.py b/src/amltk/profiling/memory.py index 149f7b75..6c89244e 100644 --- a/src/amltk/profiling/memory.py +++ b/src/amltk/profiling/memory.py @@ -1,17 +1,18 @@ """Module to measure memory.""" from __future__ import annotations +from collections.abc import Iterator, Mapping from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Iterator, Literal, Mapping +from typing import TYPE_CHECKING, Any, Literal from typing_extensions import override import numpy as np import pandas as pd import psutil -from amltk.functional import dict_get_not_none +from amltk._functional import dict_get_not_none if TYPE_CHECKING: from pandas._libs.missing import NAType diff --git a/src/amltk/profiling/profiler.py b/src/amltk/profiling/profiler.py index acc35626..2b98ede7 100644 --- a/src/amltk/profiling/profiler.py +++ b/src/amltk/profiling/profiler.py @@ -1,24 +1,93 @@ -"""Module to measure memory.""" +"""Whether for debugging, building an AutoML system or for optimization +purposes, we provide a powerful [`Profiler`][amltk.profiling.Profiler], +which can generate a [`Profile`][amltk.profiling.Profile] of different sections +of code. This is particularly useful with [`Trial`][amltk.optimization.Trial]s, +so much so that we attach one to every `Trial` made as +[`trial.profiler`][amltk.optimization.Trial.profiler]. + +When done profiling, you can export all generated profiles as a dataframe using +[`profiler.df()`][amltk.profiling.Profiler.df]. + +```python exec="true" result="python" source="material-block" +from amltk.profiling import Profiler +import numpy as np + +profiler = Profiler() + +with profiler("loading-data"): + X = np.random.rand(1000, 1000) + +with profiler("training-model"): + model = np.linalg.inv(X) + +with profiler("predicting"): + y = model @ X + +print(profiler.df()) +``` + +You'll find these profiles as keys in the [`Profiler`][amltk.profiling.Profiler], +e.g. `#! python profiler["loading-data"]`. + +This will measure both the time it took within the block but also +the memory consumed before and after the block finishes, allowing +you to get an estimate of the memory consumed. + + +??? 
tip "Memory, vms vs rms" + + While not entirely accurate, this should be enough for info + for most use cases. + + Given the main process uses 2GB of memory and the process + then spawns a new process in which you are profiling, as you + might do from a [`Task`][amltk.scheduling.Task]. In this new + process you use another 2GB on top of that, then: + + * The virtual memory size (**vms**) will show 4GB as the + new process will share the 2GB with the main process and + have it's own 2GB. + + * The resident set size (**rss**) will show 2GB as the + new process will only have 2GB of it's own memory. + + +If you need to profile some iterator, like a for loop, you can use +[`Profiler.each()`][amltk.profiling.Profiler.each] which will measure +the entire loop but also each individual iteration. This can be useful +for iterating batches of a deep-learning model, splits of a cross-validator +or really any loop with work you want to profile. + +```python exec="true" result="python" source="material-block" +from amltk.profiling import Profiler +import numpy as np + +profiler = Profiler() + +for i in profiler.each(range(3), name="for-loop"): + X = np.random.rand(1000, 1000) + +print(profiler.df()) +``` + +Lastly, to disable profiling without editing much code, +you can always use [`Profiler.disable()`][amltk.profiling.Profiler.disable] +and [`Profiler.enable()`][amltk.profiling.Profiler.enable] to toggle +profiling on and off. +""" + from __future__ import annotations from collections import deque +from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import contextmanager from dataclasses import dataclass, field -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Iterator, - Literal, - Mapping, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Literal, TypeVar from typing_extensions import override import pandas as pd -from amltk.optimization.trial import mapping_select +from amltk._functional import mapping_select from amltk.profiling.memory import Memory from amltk.profiling.timing import Timer @@ -262,7 +331,7 @@ def df(self) -> pd.DataFrame: def __rich__(self) -> RenderableType: """Render the profiler.""" - from amltk.richutil import df_to_table + from amltk._richutil import df_to_table _df = self.df() return df_to_table(_df, title="Profiler", index_style="bold") diff --git a/src/amltk/profiling/timing.py b/src/amltk/profiling/timing.py index a1f53644..7b957c46 100644 --- a/src/amltk/profiling/timing.py +++ b/src/amltk/profiling/timing.py @@ -2,16 +2,17 @@ from __future__ import annotations import time +from collections.abc import Iterator, Mapping from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Iterator, Literal, Mapping +from typing import TYPE_CHECKING, Any, Literal from typing_extensions import override import numpy as np import pandas as pd -from amltk.functional import dict_get_not_none +from amltk._functional import dict_get_not_none if TYPE_CHECKING: from pandas._libs.missing import NAType diff --git a/src/amltk/pynisher/__init__.py b/src/amltk/pynisher/__init__.py deleted file mode 100644 index 1d01e7d8..00000000 --- a/src/amltk/pynisher/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from amltk.pynisher.pynisher_task_plugin import PynisherPlugin - -__all__ = ["PynisherPlugin"] diff --git a/src/amltk/pynisher/pynisher_task_plugin.py b/src/amltk/pynisher/pynisher_task_plugin.py deleted file mode 100644 index 6edeef49..00000000 --- 
a/src/amltk/pynisher/pynisher_task_plugin.py +++ /dev/null @@ -1,242 +0,0 @@ -"""A plugin that wraps a task in a pynisher to enforce limits on it. - -Please note that this plugin requires the `pynisher` package to be installed. - -Documentation for the `pynisher` package can be found [here](https://github.com/automl/pynisher). - -Notably, there are some limitations to the `pynisher` package with Mac and Windows -which are [listed here](https://github.com/automl/pynisher#features). -""" -from __future__ import annotations - -from multiprocessing.context import BaseContext -from typing import TYPE_CHECKING, Callable, TypeVar -from typing_extensions import ParamSpec, Self, override - -from amltk.events import Event -from amltk.scheduling.task_plugin import TaskPlugin -from pynisher import Pynisher - -if TYPE_CHECKING: - import asyncio - - from amltk.scheduling.task import Task - -P = ParamSpec("P") -R = TypeVar("R") - - -class PynisherPlugin(TaskPlugin): - """A plugin that wraps a task in a pynisher to enforce limits on it. - - This plugin wraps a task function in a `Pynisher` instance to enforce - limits on the task. The limits are set by any of `memory_limit=`, - `cpu_time_limit=` and `wall_time_limit=`. - - Adds four new events to the task - - * [`TIMEOUT`][amltk.pynisher.PynisherPlugin.TIMEOUT] - - subscribe with `@task.on("pynisher-timeout")` - * [`MEMORY_LIMIT_REACHED`][amltk.pynisher.PynisherPlugin.MEMORY_LIMIT_REACHED] - - subscribe with `@task.on("pynisher-memory-limit")` - * [`CPU_TIME_LIMIT_REACHED`][amltk.pynisher.PynisherPlugin.CPU_TIME_LIMIT_REACHED] - - subscribe with `@task.on("pynisher-cputime-limit")` - * [`WALL_TIME_LIMIT_REACHED`][amltk.pynisher.PynisherPlugin.WALL_TIME_LIMIT_REACHED] - - subscribe with `@task.on("pynisher-walltime-limit")` - - - ```python exec="true" source="material-block" result="python" title="PynisherPlugin" - from amltk.scheduling import Task, Scheduler - from amltk.pynisher import PynisherPlugin - import time - - def f(x: int) -> int: - time.sleep(x) - return "yay" - - scheduler = Scheduler.with_sequential() - task = scheduler.task(f, plugins=PynisherPlugin(wall_time_limit=(1, "s"))) - - @scheduler.on_start - def on_start(): - task(3) - - @task.on("pynisher-wall-time-limit") - def on_wall_time_limit(exception): - print(f"Wall time limit reached!") - - end_status = scheduler.run(on_exception="end") - print(end_status) - ``` - - Attributes: - memory_limit: The memory limit of the task. - cpu_time_limit: The cpu time limit of the task. - wall_time_limit: The wall time limit of the task. - """ - - name = "pynisher-plugin" - """The name of the plugin.""" - - TIMEOUT: Event[Pynisher.TimeoutException] = Event("pynisher-timeout") - """A Task timed out. - - Will call any subscribers with the exception as the argument. - - ```python - @task.on("pynisher-timeout") - def on_timeout(exception: PynisherPlugin.TimeoutException): - ... - ``` - """ - - MEMORY_LIMIT_REACHED: Event[Pynisher.MemoryLimitException] = Event( - "pynisher-memory-limit", - ) - """A Task was submitted but reached it's memory limit. - - Will call any subscribers with the exception as the argument. - - ```python - @task.on("pynisher-memory-limit") - def on_memout(exception: PynisherPlugin.MemoryLimitException): - ... - ``` - """ - - CPU_TIME_LIMIT_REACHED: Event[Pynisher.CpuTimeoutException] = Event( - "pynisher-cpu-time-limit", - ) - """A Task was submitted but reached it's cpu time limit. - - Will call any subscribers with the exception as the argument. 
- - ```python - @task.on("pynisher-cpu-time-limit") - def on_cpu_time_limit(exception: PynisherPlugin.TimeoutException): - ... - ``` - """ - - WALL_TIME_LIMIT_REACHED: Event[Pynisher.WallTimeoutException] = Event( - "pynisher-wall-time-limit", - ) - """A Task was submitted but reached it's wall time limit. - - Will call any subscribers with the exception as the argument. - - ```python - @task.on("pynisher-wall-time-limit") - def on_wall_time_limit(exception: PynisherPlugin.TimeoutException): - ... - ``` - """ - - TimeoutException = Pynisher.TimeoutException - """The exception that is raised when a task times out.""" - - MemoryLimitException = Pynisher.MemoryLimitException - """The exception that is raised when a task reaches it's memory limit.""" - - CpuTimeoutException = Pynisher.CpuTimeoutException - """The exception that is raised when a task reaches it's cpu time limit.""" - - WallTimeoutException = Pynisher.WallTimeoutException - """The exception that is raised when a task reaches it's wall time limit.""" - - def __init__( - self, - *, - memory_limit: int | tuple[int, str] | None = None, - cpu_time_limit: int | tuple[float, str] | None = None, - wall_time_limit: int | tuple[float, str] | None = None, - context: BaseContext | None = None, - ): - """Initialize a `PynisherPlugin` instance. - - Args: - memory_limit: The memory limit to wrap the task in. Base unit is in bytes - but you can specify `(value, unit)` where `unit` is one of - `("B", "KB", "MB", "GB")`. Defaults to `None` - cpu_time_limit: The cpu time limit to wrap the task in. Base unit is in - seconds but you can specify `(value, unit)` where `unit` is one of - `("s", "m", "h")`. Defaults to `None` - wall_time_limit: The wall time limit for the task. Base unit is in seconds - but you can specify `(value, unit)` where `unit` is one of - `("s", "m", "h")`. Defaults to `None`. - context: The context to use for multiprocessing. Defaults to `None`. - See [`multiprocessing.get_context()`][multiprocessing.get_context] - """ - super().__init__() - self.memory_limit = memory_limit - self.cpu_time_limit = cpu_time_limit - self.wall_time_limit = wall_time_limit - self.context = context - - self.task: Task - - @override - def pre_submit( - self, - fn: Callable[P, R], - *args: P.args, - **kwargs: P.kwargs, - ) -> tuple[Callable[P, R], tuple, dict]: - """Wrap a task function in a `Pynisher` instance.""" - # If any of our limits is set, we need to wrap it in Pynisher - # to enfore these limits. - if any( - limit is not None - for limit in (self.memory_limit, self.cpu_time_limit, self.wall_time_limit) - ): - fn = Pynisher( - fn, - memory=self.memory_limit, - cpu_time=self.cpu_time_limit, - wall_time=self.wall_time_limit, - terminate_child_processes=True, - context=self.context, - ) - - return fn, args, kwargs - - @override - def attach_task(self, task: Task) -> None: - """Attach the plugin to a task.""" - self.task = task - task.emitter.add_event( - self.TIMEOUT, - self.MEMORY_LIMIT_REACHED, - self.CPU_TIME_LIMIT_REACHED, - self.WALL_TIME_LIMIT_REACHED, - ) - - # Check the exception and emit pynisher specific ones too - task.on_exception(self._check_to_emit_pynisher_exception, hidden=True) - - @override - def copy(self) -> Self: - """Return a copy of the plugin. - - Please see [`TaskPlugin.copy()`][amltk.TaskPlugin.copy]. 
- """ - return self.__class__( - memory_limit=self.memory_limit, - cpu_time_limit=self.cpu_time_limit, - wall_time_limit=self.wall_time_limit, - ) - - def _check_to_emit_pynisher_exception( - self, - _: asyncio.Future, - exception: BaseException, - ) -> None: - """Check if the exception is a pynisher exception and emit it.""" - if isinstance(exception, Pynisher.CpuTimeoutException): - self.task.emitter.emit(self.TIMEOUT, exception) - self.task.emitter.emit(self.CPU_TIME_LIMIT_REACHED, exception) - elif isinstance(exception, self.WallTimeoutException): - self.task.emitter.emit(self.TIMEOUT) - self.task.emitter.emit(self.WALL_TIME_LIMIT_REACHED, exception) - elif isinstance(exception, self.MemoryLimitException): - self.task.emitter.emit(self.MEMORY_LIMIT_REACHED, exception) diff --git a/src/amltk/randomness.py b/src/amltk/randomness.py index ced3e6eb..c10eb355 100644 --- a/src/amltk/randomness.py +++ b/src/amltk/randomness.py @@ -1,6 +1,7 @@ """Utilities for dealing with randomness.""" from __future__ import annotations +from collections.abc import Sequence from typing import TYPE_CHECKING import numpy as np @@ -9,6 +10,7 @@ from amltk.types import Seed MAX_INT = np.iinfo(np.int32).max +ALPHABET = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") def as_rng(seed: Seed | None = None) -> np.random.Generator: @@ -20,19 +22,16 @@ def as_rng(seed: Seed | None = None) -> np.random.Generator: Returns: A valid np.random.Generator object to use """ - if isinstance(seed, np.random.Generator): - return seed + match seed: + case None | int() | np.integer(): + return np.random.default_rng(seed) + case np.random.Generator(): + return seed + case np.random.RandomState(): + _seed = seed.randint(0, MAX_INT) + return np.random.default_rng(_seed) - if isinstance(seed, np.random.RandomState): - seed = seed.randint(0, MAX_INT) - - if isinstance(seed, np.integer): - seed = int(seed) - - if seed is None or isinstance(seed, int): - return np.random.default_rng(seed) - - raise ValueError(f"Can't use {seed=} to create a numpy.random.Generator instance") + raise ValueError(f"Can't {seed=} ({type(seed)}) to create numpy.random.Generator") def as_randomstate(seed: Seed | None = None) -> np.random.RandomState: @@ -44,19 +43,16 @@ def as_randomstate(seed: Seed | None = None) -> np.random.RandomState: Returns: A valid np.random.RandomSTate object to use """ - if isinstance(seed, np.random.RandomState): - return seed - - if isinstance(seed, np.random.Generator): - seed = seed.integers(0, MAX_INT) - - if isinstance(seed, np.integer): - seed = int(seed) - - if seed is None or isinstance(seed, int): - return np.random.RandomState(seed) + match seed: + case None | int() | np.integer(): + return np.random.RandomState(seed) + case np.random.RandomState(): + return seed + case np.random.Generator(): + _seed = seed.integers(0, MAX_INT) + return np.random.RandomState(_seed) - raise ValueError(f"Can't use {seed=} to create a numpy.random.Generator instance") + raise ValueError(f"Can't {seed=} ({type(seed)}) to create numpy.random.RandomState") def as_int(seed: Seed | None = None) -> int: @@ -68,16 +64,36 @@ def as_int(seed: Seed | None = None) -> int: Returns: A valid integer to use as a seed """ - if isinstance(seed, (int, np.integer)): - return int(seed) + match seed: + case None: + return np.random.default_rng().integers(0, MAX_INT) + case np.integer() | int(): + return int(seed) + case np.random.Generator(): + return seed.integers(0, MAX_INT) + case np.random.RandomState(): + return seed.randint(0, MAX_INT) - 
if seed is None: - return np.random.default_rng().integers(0, MAX_INT) + raise ValueError(f"Can't {seed=} ({type(seed)}) to create int") - if isinstance(seed, np.random.Generator): - return seed.integers(0, MAX_INT) - if isinstance(seed, np.random.RandomState): - return seed.randint(0, MAX_INT) +def randuid( + k: int = 8, + *, + charset: Sequence[str] = ALPHABET, + seed: Seed | None = None, +) -> str: + """Generate a random alpha-numeric uuid of a specified length. - raise ValueError(f"Can't use {seed=} to create an integer seed") + See: https://stackoverflow.com/a/56398787/5332072 + + Args: + k: The length of the uuid to generate + charset: The charset to use + seed: The seed to use + + Returns: + A random uid + """ + rng = as_rng(seed) + return "".join(rng.choice(np.asarray(charset), size=k)) diff --git a/src/amltk/richutil/__init__.py b/src/amltk/richutil/__init__.py deleted file mode 100644 index fa6e8614..00000000 --- a/src/amltk/richutil/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from amltk.richutil.renderable import RichRenderable -from amltk.richutil.renderers import Function, rich_make_column_selector -from amltk.richutil.util import df_to_table, richify - -__all__ = [ - "df_to_table", - "richify", - "RichRenderable", - "Function", - "rich_make_column_selector", -] diff --git a/src/amltk/richutil/renderers/__init__.py b/src/amltk/richutil/renderers/__init__.py deleted file mode 100644 index a78e64cf..00000000 --- a/src/amltk/richutil/renderers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from amltk.richutil.renderers._make_column_selector import rich_make_column_selector -from amltk.richutil.renderers.executors import ProcessPoolExecutorRenderer -from amltk.richutil.renderers.function import Function - -__all__ = ["Function", "rich_make_column_selector", "ProcessPoolExecutorRenderer"] diff --git a/src/amltk/scheduling/__init__.py b/src/amltk/scheduling/__init__.py index 5ee4a8d0..99d7662e 100644 --- a/src/amltk/scheduling/__init__.py +++ b/src/amltk/scheduling/__init__.py @@ -1,15 +1,19 @@ -from amltk.scheduling.comms import Comm +from amltk.scheduling.events import Emitter, Event, Subscriber +from amltk.scheduling.executors import SequentialExecutor +from amltk.scheduling.plugins import Comm, Limiter, Plugin from amltk.scheduling.scheduler import ExitState, Scheduler -from amltk.scheduling.sequential_executor import SequentialExecutor from amltk.scheduling.task import Task -from amltk.scheduling.task_plugin import CallLimiter, TaskPlugin __all__ = [ "Scheduler", "Comm", "Task", "SequentialExecutor", - "TaskPlugin", - "CallLimiter", + "Plugin", + "Limiter", "ExitState", + "Comm", + "Emitter", + "Subscriber", + "Event", ] diff --git a/src/amltk/scheduling/comms.py b/src/amltk/scheduling/comms.py deleted file mode 100644 index f624874f..00000000 --- a/src/amltk/scheduling/comms.py +++ /dev/null @@ -1,522 +0,0 @@ -"""A module containing the Comm class. - -???+ note - - Please see the documentation for the [`Task`][amltk.scheduling.task.Task] - for basics of a task. 
-""" -from __future__ import annotations - -import asyncio -import logging -from dataclasses import dataclass, field -from enum import Enum, auto -from multiprocessing import Pipe -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Generic, - Literal, - TypeVar, - overload, -) -from typing_extensions import ParamSpec, TypeAlias, override - -from more_itertools import first_true - -from amltk.asyncm import AsyncConnection -from amltk.events import Event -from amltk.scheduling.task_plugin import TaskPlugin - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from multiprocessing.connection import Connection - from typing_extensions import Self - - from amltk.scheduling.task import Task - - CommID: TypeAlias = int - - -T = TypeVar("T") -P = ParamSpec("P") -R = TypeVar("R") - - -class Comm: - """A communication channel between a worker and scheduler. - - For duplex connections, such as returned by python's builtin - [`Pipe`][multiprocessing.Pipe], use the - [`create(duplex=...)`][amltk.Comm.create] class method. - - Adds three new events to the task: - - * [`MESSAGE`][amltk.scheduling.comms.Comm.MESSAGE] - - subscribe with `@task.on("comm-message")` - * [`REQUEST`][amltk.scheduling.comms.Comm.REQUEST] - - subscribe with `@task.on("comm-request")` - * [`CLOSE`][amltk.scheduling.comms.Comm.CLOSE] - - subscribe with `@task.on("comm-close")` - - Attributes: - connection: The underlying Connection - id: The id of the comm. - """ - - MESSAGE: Event[Comm.Msg] = Event("comm-message") - """A Task has sent a message. - - ```python - @task.on("comm-message") - def on_message(msg: Comm.Msg[int]): - task: Task = msg.task - data: int = msg.data - identifier: int = msg.identifier - - msg.respond("hello") - ``` - """ - - REQUEST: Event[Comm.Msg] = Event("comm-request") - """A Task is waiting for a response to this message. - - ```python - @task.on("comm-message") - def on_request(msg: Comm.Msg[int]): - task: Task = msg.task - data: int = msg.data - identifier: int = msg.identifier - - msg.respond(data * 2) - ``` - """ - - CLOSE: Event[[]] = Event("comm-close") - """The task has signalled it's close. - - ```python - @task.on("comm-close") - def on_close(): - ... - ``` - """ - - def __init__(self, connection: Connection) -> None: - """Initialize the Comm. - - Args: - connection: The underlying Connection - """ - super().__init__() - self.connection = connection - self.id: CommID = id(self) - - def send(self, obj: Any) -> None: - """Send a message. - - Args: - obj: The object to send. - """ - try: - self.connection.send(obj) - except BrokenPipeError: - # It's possble that the connection was closed by the other end - # before we could send the message. - logger.warning(f"Broken pipe error while sending message {obj}") - - def close(self, *, wait_for_ack: bool = False) -> None: - """Close the connection. - - Args: - wait_for_ack: If `True`, wait for an acknowledgement from the - other end before closing the connection. - """ - if not self.connection.closed: - try: - self.connection.send(Comm.Msg.Kind.CLOSE) - except BrokenPipeError: - # It's possble that the connection was closed by the other end - # before we could close it. 
- pass - except Exception as e: # noqa: BLE001 - logger.error(f"Error sending close signal: {type(e)}{e}") - - if wait_for_ack: - try: - logger.debug("Waiting for ACK") - self.connection.recv() - logger.debug("Recieved ACK") - except Exception as e: # noqa: BLE001 - logger.error(f"Error waiting for ACK: {type(e)}{e}") - - try: - self.connection.close() - except OSError: - # It's possble that the connection was closed by the other end - # before we could close it. - pass - except Exception as e: # noqa: BLE001 - logger.error(f"Error closing connection: {type(e)}{e}") - - @classmethod - def create(cls, *, duplex: bool = True) -> tuple[Self, Self]: - """Create a pair of communication channels. - - Wraps the output of - [`multiprocessing.Pipe(duplex=duplex)`][multiprocessing.Pipe]. - - Args: - duplex: Whether to allow for two-way communication - - Returns: - A pair of communication channels. - """ - reader, writer = Pipe(duplex=duplex) - return cls(reader), cls(writer) - - @property - def as_async(self) -> AsyncComm: - """Return an async version of this comm.""" - return AsyncComm(self) - - # No block with a default - @overload - def request( - self, - msg: Any | None = ..., - *, - block: Literal[False] | float, - default: T, - ) -> Comm.Msg | T: - ... - - # No block with no default - @overload - def request( - self, - msg: Any | None = ..., - *, - block: Literal[False] | float, - default: None = None, - ) -> Comm.Msg | None: - ... - - # Block - @overload - def request( - self, - msg: Any | None = ..., - *, - block: Literal[True] = True, - ) -> Comm.Msg: - ... - - def request( - self, - msg: Any | None = None, - *, - block: bool | float = True, - default: T | None = None, - ) -> Comm.Msg | T | None: - """Receive a message. - - Args: - msg: The message to send to the other end of the connection. - If left empty, will be `None`. - block: Whether to block until a message is received. If False, return - default. - default: The default value to return if block is False and no message - is received. Defaults to None. - - Returns: - The received message or the default. - """ - if block is False: - response = self.connection.poll() # Non blocking poll - return default if not response else self.connection.recv() - - # None indicates blocking poll - poll_timeout = None if block is True else block - self.send((Comm.Msg.Kind.REQUEST, msg)) - response = self.connection.poll(timeout=poll_timeout) - return default if not response else self.connection.recv() - - def __enter__(self) -> Self: - return self - - def __exit__(self, *_: Any) -> None: - self.close(wait_for_ack=False) - - @dataclass - class Msg(Generic[T]): - """A message sent over a communication channel. - - Attributes: - task: The task that sent the message. - comm: The communication channel. - future: The future of the task. - data: The data sent by the task. - """ - - task: Task = field(repr=False) - comm: Comm = field(repr=False) - future: asyncio.Future = field(repr=False) - data: T - identifier: CommID - - def respond(self, response: Any) -> None: - """Respond to the message. - - Args: - response: The response to send back to the task. - """ - self.comm.send(response) - - class Kind(Enum): - """The kind of message.""" - - CLOSE = auto() - MESSAGE = auto() - REQUEST = auto() - - class Plugin(TaskPlugin): - """A plugin that handles communication with a worker.""" - - name: ClassVar[str] = "comm-plugin" - - def __init__( - self, - create_comms: Callable[[], tuple[Comm, Comm]] | None = None, - ) -> None: - """Initialize the plugin. 
- - Args: - create_comms: A function that creates a pair of communication - channels. Defaults to `Comm.create`. - """ - super().__init__() - if create_comms is None: - create_comms = Comm.create - - self.create_comms = create_comms - self.comms: dict[CommID, tuple[Comm, Comm]] = {} - self.communication_tasks: dict[asyncio.Future, asyncio.Task] = {} - self.task: Task - - @override - def attach_task(self, task: Task) -> None: - """Attach the plugin to a task. - - This method is called when the plugin is attached to a task. This - is the place to subscribe to events on the task, create new subscribers - for people to use or even store a reference to the task for later use. - - Args: - task: The task the plugin is being attached to. - """ - self.task = task - task.on_submitted(self._establish_connection) - - @override - def pre_submit( - self, - fn: Callable[P, R], - *args: P.args, - **kwargs: P.kwargs, - ) -> tuple[Callable[P, R], tuple, dict] | None: - """Pre-submit hook. - - This method is called before the task is submitted. - - Args: - fn: The task function. - *args: The arguments to the task function. - **kwargs: The keyword arguments to the task function. - - Returns: - A tuple of the task function, arguments and keyword arguments - if the task should be submitted, or `None` if the task should - not be submitted. - """ - from amltk.optimization.trial import Trial - - host_comm, worker_comm = self.create_comms() - # NOTE: This works but not sure why pyright is complaining - - trial = first_true( - (a for a in args if isinstance(a, Trial)), - default=None, - ) - if trial is None: - if "comm" in kwargs: - raise ValueError( - "Can't attach a comm as there is already a kwarg named `comm`.", - ) - kwargs.update({"comm": worker_comm}) - - # We don't necessarily know if the future will be submitted. If so, - # we will use this index later to retrieve the host_comm - self.comms[worker_comm.id] = (host_comm, worker_comm) - return fn, args, kwargs - - @override - def copy(self) -> Self: - """Return a copy of the plugin. - - Please see [`TaskPlugin.copy()`][amltk.TaskPlugin.copy]. - """ - return self.__class__(create_comms=self.create_comms) - - def _establish_connection( - self, - f: asyncio.Future, - *args: Any, - **kwargs: Any, - ) -> Any: - from amltk.optimization.trial import Trial - - trial = first_true( - (a for a in args if isinstance(a, Trial)), - default=None, - ) - if trial is None: - if "comm" not in kwargs: - raise ValueError( - "Cannot find a comm as there is no kwarg named `comm`.", - "and cannot find comm from a trial as there is no trial in" - " the arguments." - f"\nArgs: {args} kwargs: {kwargs}", - ) - worker_comm = kwargs["comm"] - else: - worker_comm = trial.plugins["comm"] - - host_comm, worker_comm = self.comms[worker_comm.id] - self.communication_tasks[f] = asyncio.create_task( - self._communicate(f, host_comm, worker_comm), - ) - - async def _communicate( - self, - future: asyncio.Future, - host_comm: Comm, - worker_comm: Comm, - ) -> None: - """Communicate with the task. - - This is a coroutine that will run until the scheduler is stopped or - the comms have finished. - """ - worker_id = worker_comm.id - task_name = self.task.unique_ref - name = f"{task_name}({worker_id})" - - while True: - try: - data = await host_comm.as_async.request() - logger.debug(f"{self.name}: receieved {data=}") - - # When we recieve CLOSE, the task has signalled it's - # close and we emit a CLOSE event. 
This should break out - # of the loop as we expect no more signals after this point - if data is Comm.Msg.Kind.CLOSE: - self.task.emitter.emit(Comm.CLOSE) - break - - # When we recieve (REQUEST, data), this was sent with - # `request` and we emit a REQUEST event - if ( - isinstance(data, tuple) - and len(data) == 2 # noqa: PLR2004 - and data[0] == Comm.Msg.Kind.REQUEST - ): - _, real_data = data - msg = Comm.Msg( - self.task, - host_comm, - future, - real_data, - identifier=worker_id, - ) - self.task.emitter.emit(Comm.REQUEST, msg) - - # Otherwise it's just a simple `send` with some data we - # emit as a MESSAGE event - else: - msg = Comm.Msg( - self.task, - host_comm, - future, - data, - identifier=worker_id, - ) - self.task.emitter.emit(Comm.MESSAGE, msg) - - except EOFError: - logger.debug(f"{name}: closed connection") - break - - logger.debug(f"{name}: finished communication, closing comms") - - # When the loop is finished, we can't communicate, close the comm - # We explicitly don't wait for any acknowledgment from the worker - host_comm.close(wait_for_ack=False) - worker_comm.close() - - # Remove the reference to the work comm so it gets garbarged - del self.comms[worker_id] - - -@dataclass -class AsyncComm: - """A async wrapper of a Comm.""" - - comm: Comm - - @overload - async def request( - self, - *, - timeout: float, - default: None = None, - ) -> Comm.Msg | None: - ... - - @overload - async def request(self, *, timeout: float, default: T) -> Comm.Msg | T: - ... - - @overload - async def request(self, *, timeout: None = None) -> Comm.Msg: - ... - - async def request( - self, - *, - timeout: float | None = None, - default: T | None = None, - ) -> Comm.Msg | T | None: - """Recieve a message. - - Args: - timeout: The timeout in seconds to wait for a message. - default: The default value to return if the timeout is reached. - - Returns: - The message from the worker or the default value. - """ - connection = AsyncConnection(self.comm.connection) - result = await asyncio.wait_for(connection.recv(), timeout=timeout) - return default if result is None else result - - async def send(self, obj: Comm.Msg) -> None: - """Send a message. - - Args: - obj: The message to send. - """ - return await AsyncConnection(self.comm.connection).send(obj) diff --git a/src/amltk/events.py b/src/amltk/scheduling/events.py similarity index 54% rename from src/amltk/events.py rename to src/amltk/scheduling/events.py index 885bb599..fc1aefb3 100644 --- a/src/amltk/events.py +++ b/src/amltk/scheduling/events.py @@ -1,29 +1,257 @@ -"""All code for allowing an event system.""" +"""One of the primary ways to respond to `@events` emitted +with by a [`Task`](site:reference/scheduling/task.md) or +the [`Scheduler`](site:reference/scheduling/scheduler.md) +is through use of a **callback**. + +The reason for this is to enable an easier time for API's to utilize +multiprocessing and remote compute from the `Scheduler`, without having +to burden users with knowing the details of how to use multiprocessing. + +A callback subscribes to some event using a decorator but can also be done in +a functional style if preferred. The below example is based on the +event [`@scheduler.on_start`][amltk.scheduling.Scheduler.on_start] but +the same applies to all events. 
+ +=== "Decorators" + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + @scheduler.on_start + def print_hello() -> None: + print("hello") + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide + ``` + +=== "Functional" + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + def print_hello() -> None: + print("hello") + + scheduler.on_start(print_hello) + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide + ``` + +There are a number of ways to customize the behaviour of these callbacks, notably +to control how often they get called and when they get called. + +??? tip "Callback customization" + + + === "`on('event', repeat=...)`" + + This will cause the callback to be called `repeat` times successively. + This is most useful in combination with + [`@scheduler.on_start`][amltk.scheduling.Scheduler.on_start] to launch + a number of tasks at the start of the scheduler. + + ```python exec="true" source="material-block" html="true" hl_lines="11" + from amltk import Scheduler + + N_WORKERS = 2 + + def f(x: int) -> int: + return x * 2 + from amltk._doc import make_picklable; make_picklable(f) # markdown-exec: hide + + scheduler = Scheduler.with_processes(N_WORKERS) + task = scheduler.task(f) + + @scheduler.on_start(repeat=N_WORKERS) + def on_start(): + task.submit(1) + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, fontsize="small") # markdown-exec: hide + ``` + + === "`on('event', limit=...)`" + + Limit the number of times a callback can be called, after which, the callback + will be ignored. + + ```python exec="true" source="material-block" html="True" hl_lines="13" + from asyncio import Future + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(2) + + def expensive_function(x: int) -> int: + return x ** 2 + from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide + + @scheduler.on_start + def submit_calculations() -> None: + scheduler.submit(expensive_function, 2) + + @scheduler.on_future_result(limit=3) + def print_result(future, result) -> None: + scheduler.submit(expensive_function, 2) + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + + === "`on('event', when=...)`" + + A callable which takes no arguments and returns a `bool`. The callback + will only be called when the `when` callable returns `True`. + + Below is a rather contrived example, but it shows how we can use the + `when` parameter to control when the callback is called. 
+
+ ```python exec="true" source="material-block" html="True" hl_lines="8 12"
+ import random
+ from amltk.scheduling import Scheduler
+
+ LOCALE = random.choice(["English", "German"])
+
+ scheduler = Scheduler.with_processes(1)
+
+ @scheduler.on_start(when=lambda: LOCALE == "English")
+ def print_hello() -> None:
+     print("hello")
+
+ @scheduler.on_start(when=lambda: LOCALE == "German")
+ def print_guten_tag() -> None:
+     print("guten tag")
+
+ scheduler.run()
+ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small")  # markdown-exec: hide
+ ```
+
+ === "`on('event', every=...)`"
+
+ Only call the callback every `every` times the event is emitted. This
+ includes the first time it's called.
+
+ ```python exec="true" source="material-block" html="True" hl_lines="6"
+ from amltk.scheduling import Scheduler
+
+ scheduler = Scheduler.with_processes(1)
+
+ # Print "hello" only every 2 times the scheduler starts.
+ @scheduler.on_start(every=2)
+ def print_hello() -> None:
+     print("hello")
+
+ # Run the scheduler 5 times
+ scheduler.run()
+ scheduler.run()
+ scheduler.run()
+ scheduler.run()
+ scheduler.run()
+ from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small")  # markdown-exec: hide
+ ```
+
+### Emitter, Subscribers and Events
+This part of the documentation is not necessary to understand in order to use AMLTK. People
+wishing to build tools upon AMLTK may still find this a useful component to add to their
+arsenal.
+
+The core of making this functionality work is the [`Emitter`][amltk.scheduling.events.Emitter].
+Its purpose is to have `@events` that can be emitted and subscribed to. Classes like the
+[`Scheduler`][amltk.scheduling.Scheduler] and [`Task`][amltk.scheduling.Task] carry
+around with them an `Emitter` to enable all of this functionality.
+
+Creating an `Emitter` is rather straightforward, but we must also create
+[`Events`][amltk.scheduling.events.Event] that people can subscribe to.
+
+```python
+from amltk.scheduling import Emitter, Event
+emitter = Emitter("my-emitter")
+
+event: Event[int] = Event("my-event")  # (1)!
+
+@emitter.on(event)
+def my_callback(x: int) -> None:
+    print(f"Got {x}!")
+
+emitter.emit(event, 42)  # (2)!
+```
+
+1. The typing `#!python Event[int]` is used to indicate that the event will be emitting
+    an integer. This is not necessary, but it is useful for type-checking and
+    documentation.
+2. The `#!python emitter.emit(event, 42)` is used to emit the event. This will call
+    all the callbacks registered for the event, i.e. `#!python my_callback()`.
+
+!!! warning "Independent Events"
+
+    Given a single `Emitter` and a single instance of an `Event`, there is no way to
+    have different `@events` for callbacks. There are two options, both used extensively
+    in AMLTK.
+
+    The first is to have different `Events` quite naturally, i.e. you distinguish
+    between different things that can happen. However, you often want to have different
+    objects emit the same `Event` but have different callbacks for each object.
+
+    This makes the most sense in the context of a `Task`: the `Event` instances are shared as
+    class variables on the `Task` class, however a user likely wants to subscribe to
+    the `Event` of a specific instance of the `Task`.
+
+    This is where the second option comes in, in which each object carries around its
+    own `Emitter` instance. This is how a user can subscribe to the same kind of `Event`
+    but individually for each `Task`.
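To make that second option concrete, here is a minimal sketch using only the `Emitter`, `Event` and functional `emitter.on(event, callback)` registration shown above; the `Counter` class, its `TICK` event and `tick()` method are invented purely for illustration and are not part of this diff:

```python
from amltk.scheduling import Emitter, Event

class Counter:
    # One Event shared as a class variable by every Counter instance...
    TICK: Event[int] = Event("tick")

    def __init__(self, name: str) -> None:
        # ...but each instance carries its own Emitter, so callbacks
        # registered on one instance never fire for another.
        self.emitter = Emitter(f"counter-{name}")

    def tick(self, value: int) -> None:
        self.emitter.emit(Counter.TICK, value)

a = Counter("a")
b = Counter("b")

def only_for_a(value: int) -> None:
    print(f"a ticked with {value}")

a.emitter.on(Counter.TICK, only_for_a)

a.tick(1)  # prints "a ticked with 1"
b.tick(2)  # prints nothing; b's Emitter has no callbacks registered
```

This mirrors how each `Task` can expose the same class-level `Event` while keeping its subscriptions separate per instance.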
+
+
+However, to shield users from this and to create named access points for users to
+subscribe to, we can use the [`Subscriber`][amltk.scheduling.events.Subscriber] class,
+conveniently created by the [`Emitter.subscriber()`][amltk.scheduling.events.Emitter.subscriber]
+method.
+
+```python
+from amltk.scheduling import Emitter, Event, Subscriber
+emitter = Emitter("my-emitter")
+
+class GPT:
+
+    event: Event[str] = Event("my-event")
+
+    def __init__(self) -> None:
+        self.on_answer: Subscriber[str] = emitter.subscriber(self.event)
+
+    def ask(self, question: str) -> None:
+        emitter.emit(self.event, "hello world!")
+
+gpt = GPT()
+
+@gpt.on_answer
+def print_answer(answer: str) -> None:
+    print(answer)
+
+gpt.ask("What is the canonical way for an AI to greet someone?")
+```
+
+Typically, these event-based systems make little sense in a synchronous context; however,
+with the [`Scheduler`][amltk.scheduling.Scheduler] and [`Task`][amltk.scheduling.Task]
+classes, they are used to enable a simple way to use multiprocessing and remote compute.
+"""  # noqa: E501
 from __future__ import annotations

 import logging
 import math
 import time
 from collections import Counter, defaultdict
+from collections.abc import Callable, Iterable, Iterator, Mapping
 from dataclasses import dataclass, field
 from functools import partial
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Generic,
-    Iterable,
-    Iterator,
-    List,
-    Mapping,
-    TypeVar,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
 from typing_extensions import ParamSpec, override
 from uuid import uuid4

-from amltk.functional import callstring, funcname
-from amltk.richutil.renderers.function import Function
+from amltk._functional import callstring, funcname
+from amltk._richutil.renderers.function import Function

 if TYPE_CHECKING:
     from rich.text import Text
@@ -97,7 +325,7 @@ class Subscriber(Generic[P]):
     """An object that can be used to easily subscribe to a certain event.

     ```python
-    from amltk.events import Event, EventManager, Subscriber
+    from amltk.scheduling.events import Event, Subscriber

     test_event: Event[[int, str]] = Event("test")
@@ -161,29 +389,16 @@ def __call__(
     ) -> Callable[P, Any]:
         ...

-    @overload
-    def __call__(
-        self,
-        callback: Iterable[Callable[P, Any]],
-        *,
-        when: Callable[[], bool] | None = ...,
-        limit: int | None = ...,
-        repeat: int = ...,
-        every: int = ...,
-        hidden: bool = ...,
-    ) -> None:
-        ...
-
     def __call__(
         self,
-        callback: Callable[P, Any] | Iterable[Callable[P, Any]] | None = None,
+        callback: Callable[P, Any] | None = None,
         *,
         when: Callable[[], bool] | None = None,
         limit: int | None = None,
         repeat: int = 1,
         every: int = 1,
         hidden: bool = False,
-    ) -> Callable[P, Any] | partial[Callable[P, Any]] | None:
+    ) -> Callable[P, Any] | partial[Callable[P, Any]]:
         """Subscribe to the event associated with this object.

         Args:
@@ -193,7 +408,7 @@ def __call__(
             repeat: The callback will be called `repeat` times successively.
             limit: The maximum number of times the callback can be called.
             hidden: Whether to hide the callback in visual output.
-                This is mainly used to facilitate TaskPlugins who
+                This is mainly used to facilitate Plugins who
                 act upon events but don't want to be seen, primarily
                 as they are just book-keeping callbacks.
@@ -219,8 +434,6 @@ def __call__( every=every, hidden=hidden, ) - if isinstance(callback, Iterable): - return None return callback def emit(self, *args: P.args, **kwargs: P.kwargs) -> None: @@ -270,7 +483,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: def __rich__(self) -> Text: from rich.text import Text - f_rich = Function(self.callback, signature=False).__rich__() + f_rich = Function(self.callback).__rich__() if self.n_calls_to_callback == 0: return f_rich @@ -282,15 +495,14 @@ def __rich__(self) -> Text: ) -class Emitter(Mapping[Event, List[Handler]]): +class Emitter(Mapping[Event, list[Handler]]): """An event emitter. - This is a convenience class that wraps an event manager and provides - a way to emit events. The events emitter and subscribed to will be - identified by a UUID, such that two objects emitting the same event - will have a different set of listeners who will be called. For - downstream users, this means they must subscribe to events directly - from the object they are using. + This class is used to emit events and register callbacks for those events. + It also provides a convenience function + [`subscriber()`][amltk.scheduling.events.Emitter.subscriber] such + that objects using an `Emitter` can easily create access points for users + to directly subscribe to their [`Events`][amltk.scheduling.events.Event]. """ name: str | None @@ -416,7 +628,7 @@ def subscriber( def on( self, event: Event[P], - callback: Callable | Iterable[Callable], + callback: Callable, *, when: Callable[[], bool] | None = None, every: int = 1, @@ -434,7 +646,7 @@ def on( repeat: The callback will be called `repeat` times successively. limit: The maximum number of times the callback can be called. hidden: Whether to hide the callback in visual output. - This is mainly used to facilitate TaskPlugins who + This is mainly used to facilitate Plugins who act upon events but don't want to be seen, primarily as they are just book-keeping callbacks. """ @@ -450,30 +662,28 @@ def on( # This hackery is just to get down to a flat list of events that need # to be set up - callbacks = [callback] if callable(callback) else list(callback) - for c in callbacks: - self.handlers[event].append( - Handler( - c, - when=when, - every=every, - repeat=repeat, - limit=limit, - hidden=hidden, - ), - ) - - _name = funcname(c) - msg = f"{self.name}: Registered callback '{_name}' for event {event}" - if every: - msg += f" every {every} times" - if when: - msg += f" with predicate ({funcname(when)})" - if repeat > 1: - msg += f" called {repeat} times successively" - if hidden: - msg += " (hidden from visual output)" - logger.debug(msg) + self.handlers[event].append( + Handler( + callback, + when=when, + every=every, + repeat=repeat, + limit=limit, + hidden=hidden, + ), + ) + + _name = funcname(callback) + msg = f"{self.name}: Registered callback '{_name}' for event {event}" + if every > 1: + msg += f" every {every} times" + if when: + msg += f" with predicate ({funcname(when)})" + if repeat > 1: + msg += f" called {repeat} times successively" + if hidden: + msg += " (hidden from visual output)" + logger.debug(msg) def add_event(self, *event: Event) -> None: """Add an event to the event manager so that it shows up in visuals. 
@@ -491,7 +701,11 @@ def __rich__(self) -> Tree: tree = Tree(self.name or "", hide_root=self.name is None) # This just groups events with callbacks together - handler_items = sorted(self.handlers.items(), key=lambda item: not any(item[1])) + handler_items = sorted( + self.handlers.items(), + key=lambda item: not any(item[1]) + or not all(handler.hidden for handler in item[1]), + ) for event, _handlers in handler_items: event_text = event.__rich__() diff --git a/src/amltk/scheduling/executors/__init__.py b/src/amltk/scheduling/executors/__init__.py new file mode 100644 index 00000000..9ee75ff1 --- /dev/null +++ b/src/amltk/scheduling/executors/__init__.py @@ -0,0 +1,3 @@ +from amltk.scheduling.executors.sequential_executor import SequentialExecutor + +__all__ = ["SequentialExecutor"] diff --git a/src/amltk/dask_jobqueue/executors.py b/src/amltk/scheduling/executors/dask_jobqueue.py similarity index 84% rename from src/amltk/dask_jobqueue/executors.py rename to src/amltk/scheduling/executors/dask_jobqueue.py index f95e8555..57bbcd1d 100644 --- a/src/amltk/dask_jobqueue/executors.py +++ b/src/amltk/scheduling/executors/dask_jobqueue.py @@ -18,18 +18,10 @@ import logging import pprint +from collections.abc import Callable, Iterable, Iterator from concurrent.futures import Executor, Future -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Iterable, - Iterator, - Literal, - TypeVar, -) -from typing_extensions import ParamSpec, Self, TypeAlias +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar +from typing_extensions import ParamSpec, Self, override from dask_jobqueue import ( HTCondorCluster, @@ -75,13 +67,13 @@ def __init__( Prefer to use the class methods to create an instance of this class. - * [`DaskJobqueueExecutor.SLURM()`][amltk.dask_jobqueue.DaskJobqueueExecutor.SLURM] - * [`DaskJobqueueExecutor.HTCondor()`][amltk.dask_jobqueue.DaskJobqueueExecutor.HTCondor] - * [`DaskJobqueueExecutor.LSF()`][amltk.dask_jobqueue.DaskJobqueueExecutor.LSF] - * [`DaskJobqueueExecutor.OAR()`][amltk.dask_jobqueue.DaskJobqueueExecutor.OAR] - * [`DaskJobqueueExecutor.PBS()`][amltk.dask_jobqueue.DaskJobqueueExecutor.PBS] - * [`DaskJobqueueExecutor.SGE()`][amltk.dask_jobqueue.DaskJobqueueExecutor.SGE] - * [`DaskJobqueueExecutor.Moab()`][amltk.dask_jobqueue.DaskJobqueueExecutor.Moab] + * [`DaskJobqueueExecutor.SLURM()`][amltk.scheduling.executors.dask_jobqueue.DaskJobqueueExecutor.SLURM] + * [`DaskJobqueueExecutor.HTCondor()`][amltk.scheduling.executors.dask_jobqueue.DaskJobqueueExecutor.HTCondor] + * [`DaskJobqueueExecutor.LSF()`][amltk.scheduling.executors.dask_jobqueue.DaskJobqueueExecutor.LSF] + * [`DaskJobqueueExecutor.OAR()`][amltk.scheduling.executors.dask_jobqueue.DaskJobqueueExecutor.OAR] + * [`DaskJobqueueExecutor.PBS()`][amltk.scheduling.executors.dask_jobqueue.DaskJobqueueExecutor.PBS] + * [`DaskJobqueueExecutor.SGE()`][amltk.scheduling.executors.dask_jobqueue.DaskJobqueueExecutor.SGE] + * [`DaskJobqueueExecutor.Moab()`][amltk.scheduling.executors.dask_jobqueue.DaskJobqueueExecutor.Moab] Args: cluster: The implementation of a @@ -90,30 +82,35 @@ def __init__( submit_command: To overwrite the submission command if necessary. cancel_command: To overwrite the cancel command if necessary. 
""" + super().__init__() self.cluster = cluster if submit_command: - self.cluster.job_cls.submit_command = submit_command + self.cluster.job_cls.submit_command = submit_command # type: ignore if cancel_command: - self.cluster.job_cls.cancel_command = cancel_command + self.cluster.job_cls.cancel_command = cancel_command # type: ignore self.cluster.adapt(minimum=0, maximum=n_workers) self.n_workers = n_workers self.executor: ClientExecutor = self.cluster.get_client().get_executor() + @override def __enter__(self) -> Self: - configuration = { - "header": self.cluster.job_header, - "script": self.cluster.job_script(), - "job_name": self.cluster.job_name, - } - config_str = pprint.pformat(configuration) - logger.debug(f"Launching script with configuration:\n {config_str}") - return self - + with super().__enter__(): + configuration = { + "header": self.cluster.job_header, + "script": self.cluster.job_script(), + "job_name": self.cluster.job_name, + } + config_str = pprint.pformat(configuration) + logger.debug(f"Launching script with configuration:\n {config_str}") + return self + + @override def submit( self, fn: Callable[P, R], + /, *args: P.args, **kwargs: P.kwargs, ) -> Future[R]: @@ -122,6 +119,7 @@ def submit( assert isinstance(future, Future) return future + @override def map( self, fn: Callable[..., R], @@ -137,6 +135,7 @@ def map( chunksize=chunksize, ) + @override def shutdown( self, wait: bool = True, # noqa: FBT001, FBT002 @@ -159,7 +158,7 @@ def SLURM( See the [dask_jobqueue.SLURMCluster documentation][dask_jobqueue.SLURMCluster] for more information on the available keyword arguments. """ - return cls( + return cls( # type: ignore SLURMCluster(**kwargs), submit_command=submit_command, cancel_command=cancel_command, @@ -180,7 +179,7 @@ def HTCondor( See the [dask_jobqueue.HTCondorCluster documentation][dask_jobqueue.HTCondorCluster] for more information on the available keyword arguments. """ - return cls( + return cls( # type: ignore HTCondorCluster(**kwargs), submit_command=submit_command, cancel_command=cancel_command, @@ -201,7 +200,7 @@ def LSF( See the [dask_jobqueue.LSFCluster documentation][dask_jobqueue.LSFCluster] for more information on the available keyword arguments. """ - return cls( + return cls( # type: ignore LSFCluster(**kwargs), submit_command=submit_command, cancel_command=cancel_command, @@ -222,7 +221,7 @@ def OAR( See the [dask_jobqueue.OARCluster documentation][dask_jobqueue.OARCluster] for more information on the available keyword arguments. """ - return cls( + return cls( # type: ignore OARCluster(**kwargs), submit_command=submit_command, cancel_command=cancel_command, @@ -243,7 +242,7 @@ def PBS( See the [dask_jobqueue.PBSCluster documentation][dask_jobqueue.PBSCluster] for more information on the available keyword arguments. """ - return cls( + return cls( # type: ignore PBSCluster(**kwargs), submit_command=submit_command, cancel_command=cancel_command, @@ -264,7 +263,7 @@ def SGE( See the [dask_jobqueue.SGECluster documentation][dask_jobqueue.SGECluster] for more information on the available keyword arguments. """ - return cls( + return cls( # type: ignore SGECluster(**kwargs), submit_command=submit_command, cancel_command=cancel_command, @@ -285,7 +284,7 @@ def Moab( See the [dask_jobqueue.MoabCluster documentation][dask_jobqueue.MoabCluster] for more information on the available keyword arguments. 
""" - return cls( + return cls( # type: ignore MoabCluster(**kwargs), submit_command=submit_command, cancel_command=cancel_command, diff --git a/src/amltk/scheduling/sequential_executor.py b/src/amltk/scheduling/executors/sequential_executor.py similarity index 78% rename from src/amltk/scheduling/sequential_executor.py rename to src/amltk/scheduling/executors/sequential_executor.py index 5b401a09..a14635b3 100644 --- a/src/amltk/scheduling/sequential_executor.py +++ b/src/amltk/scheduling/executors/sequential_executor.py @@ -1,9 +1,10 @@ """A concurrent.futures.Executor interface that forces sequential execution.""" from __future__ import annotations +from collections.abc import Callable from concurrent.futures import Executor, Future -from typing import Callable, TypeVar -from typing_extensions import ParamSpec +from typing import TypeVar +from typing_extensions import ParamSpec, override R = TypeVar("R") P = ParamSpec("P") @@ -12,9 +13,11 @@ class SequentialExecutor(Executor): """A [Executor][concurrent.futures.Executor] interface for sequential execution.""" + @override def submit( self, fn: Callable[P, R], + /, *args: P.args, **kwargs: P.kwargs, ) -> Future[R]: @@ -28,6 +31,8 @@ def submit( Returns: A future that is already resolved with the result/exception of the function. """ + # TODO: It would be good if we can somehow wrap this in some sort + # of async context such that it allows other callbacks to operate. future: Future[R] = Future() future.set_running_or_notify_cancel() diff --git a/src/amltk/scheduling/plugins/__init__.py b/src/amltk/scheduling/plugins/__init__.py new file mode 100644 index 00000000..d4ee12fc --- /dev/null +++ b/src/amltk/scheduling/plugins/__init__.py @@ -0,0 +1,11 @@ +from amltk.scheduling.plugins.comm import Comm +from amltk.scheduling.plugins.limiter import Limiter +from amltk.scheduling.plugins.plugin import Plugin +from amltk.scheduling.plugins.warning_filter import WarningFilter + +__all__ = [ + "Limiter", + "WarningFilter", + "Plugin", + "Comm", +] diff --git a/src/amltk/scheduling/plugins/comm.py b/src/amltk/scheduling/plugins/comm.py new file mode 100644 index 00000000..76898d02 --- /dev/null +++ b/src/amltk/scheduling/plugins/comm.py @@ -0,0 +1,740 @@ +"""The [`Comm.Plugin`][amltk.scheduling.plugins.comm.Comm.Plugin] enables +two way-communication with running [`Task`][amltk.scheduling.task.Task]. + +The [`Comm`][amltk.scheduling.plugins.comm.Comm] provides an easy interface to +communicate while the [`Comm.Msg`][amltk.scheduling.plugins.comm.Comm.Msg] encapsulates +messages between the main process and the `Task`. + +??? tip "Usage" + + To setup a `Task` to work with a `Comm`, the `Task` **must accept a `comm` as + it's first argument**. 
+ + ```python exec="true" source="material-block" result="python" hl_lines="4-7 10 17-19 21-23" + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import Comm + + def powers_of_two(comm: Comm, start: int, n: int) -> None: + with comm.open(): + for i in range(n): + comm.send(start ** (i+1)) + from amltk._doc import make_picklable; make_picklable(powers_of_two) # markdown-exec: hide + + scheduler = Scheduler.with_processes(1) + task = scheduler.task(powers_of_two, plugins=Comm.Plugin()) + results = [] + + @scheduler.on_start + def on_start(): + task.submit(2, 5) + + @task.on("comm-open") + def on_open(msg: Comm.Msg): + print(f"Task has opened | {msg}") + + @task.on("comm-message") + def on_message(msg: Comm.Msg): + results.append(msg.data) + + scheduler.run() + print(results) + ``` + + You can also block a worker, waiting for a response from the main process, allowing for the + worker to [`request()`][amltk.scheduling.plugins.comm.Comm.request] data from the main process. + + ```python exec="true" source="material-block" result="python" hl_lines="7 20-23" + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import Comm + + def my_worker(comm: Comm, n_tasks: int) -> None: + with comm.open(): + for task_number in range(n_tasks): + task = comm.request(task_number) + comm.send(f"Task recieved {task} for {task_number}") + from amltk._doc import make_picklable; make_picklable(my_worker) # markdown-exec: hide + + scheduler = Scheduler.with_processes(1) + task = scheduler.task(my_worker, plugins=Comm.Plugin()) + + items = ["A", "B", "C"] + results = [] + + @scheduler.on_start + def on_start(): + task.submit(n_tasks=3) + + @task.on("comm-request") + def on_request(msg: Comm.Msg): + task_number = msg.data + msg.respond(items[task_number]) + + @task.on("comm-message") + def on_message(msg: Comm.Msg): + results.append(msg.data) + + scheduler.run() + print(results) + ``` + +??? example "`@events`" + + Check out the [`@events`](site:reference/scheduling/events.md) + reference for more on how to customize these callbacks. + + === "`@comm-message`" + + ::: amltk.scheduling.plugins.comm.Comm.MESSAGE + + === "`@comm-request`" + + ::: amltk.scheduling.plugins.comm.Comm.REQUEST + + === "`@comm-open`" + + ::: amltk.scheduling.plugins.comm.Comm.OPEN + + === "`@comm-close`" + + ::: amltk.scheduling.plugins.comm.Comm.CLOSE + +??? warning "Supported Backends" + + The current implementation relies on [`Pipe`][multiprocessing.Pipe] which only + works between processes on the same system/cluster. There is also limited support + with `dask` backends. + + This could be extended to allow for web sockets or other forms of connections + but requires time. Please let us know in the Github issues if this is something + you are interested in! 
+""" # noqa: E501 +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from multiprocessing import ( + Pipe, + TimeoutError as MPTimeoutError, +) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + TypeAlias, + TypeVar, +) +from typing_extensions import ParamSpec, override + +from amltk._asyncm import AsyncConnection +from amltk.scheduling.events import Event +from amltk.scheduling.plugins.plugin import Plugin as TaskPlugin + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from multiprocessing.connection import Connection + from typing_extensions import Self + + from rich.panel import Panel + + from amltk.scheduling.task import Task + + CommID: TypeAlias = int + + +T = TypeVar("T") +M = TypeVar("M") +P = ParamSpec("P") +R = TypeVar("R") + + +@dataclass +class AsyncComm: + """A async wrapper of a Comm.""" + + comm: Comm + + async def request( + self, + *, + timeout: float | None = None, + ) -> Any: + """Recieve a message. + + Args: + timeout: The timeout in seconds to wait for a message, raises + a [`Comm.TimeoutError`][amltk.scheduling.plugins.comm.Comm.TimeoutError] + if the timeout is reached. + If `None`, will wait forever. + + Returns: + The message from the worker or the default value. + """ + connection = AsyncConnection(self.comm.connection) + try: + return await asyncio.wait_for(connection.recv(), timeout=timeout) + except asyncio.TimeoutError as e: + raise Comm.TimeoutError( + f"Timed out waiting for response from {self.comm}", + ) from e + + async def send(self, obj: Any) -> None: + """Send a message. + + Args: + obj: The message to send. + """ + return await AsyncConnection(self.comm.connection).send(obj) + + +class Comm: + """A communication channel between a worker and scheduler. + + For duplex connections, such as returned by python's builtin + [`Pipe`][multiprocessing.Pipe], use the + [`create(duplex=...)`][amltk.Comm.create] class method. + + Adds three new events to the task: + + * [`@comm-message`][amltk.scheduling.plugins.comm.Comm.MESSAGE] + * [`@comm-request`][amltk.scheduling.plugins.comm.Comm.REQUEST] + * [`@comm-close`][amltk.scheduling.plugins.comm.Comm.CLOSE] + * [`@comm-open`][amltk.scheduling.plugins.comm.Comm.OPEN] + + Attributes: + connection: The underlying Connection + id: The id of the comm. + """ + + MESSAGE: Event[Comm.Msg] = Event("comm-message") + """A Task has sent a message to the main process. + + ```python exec="true" source="material-block" html="true" hl_lines="6 11-13" + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import Comm + + def fn(comm: Comm, x: int) -> int: + with comm.open(): + comm.send(x + 1) + + scheduler = Scheduler.with_processes(1) + task = scheduler.task(fn, plugins=Comm.Plugin()) + + @task.on("comm-message") + def callback(msg: Comm.Msg): + print(msg.data) + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + REQUEST: Event[Comm.Msg] = Event("comm-request") + """A Task has sent a request. 
+
+    ```python exec="true" source="material-block" html="true" hl_lines="6 16-18"
+    from amltk.scheduling import Scheduler
+    from amltk.scheduling.plugins import Comm
+
+    def greeter(comm: Comm, greeting: str) -> None:
+        with comm.open():
+            name = comm.request()
+            comm.send(f"{greeting} {name}!")
+    from amltk._doc import make_picklable; make_picklable(greeter)  # markdown-exec: hide
+
+    scheduler = Scheduler.with_processes(1)
+    task = scheduler.task(greeter, plugins=Comm.Plugin())
+
+    @scheduler.on_start
+    def on_start():
+        task.submit("Hello")
+
+    @task.on("comm-request")
+    def on_request(msg: Comm.Msg):
+        msg.respond("Alice")
+
+    @task.on("comm-message")
+    def on_msg(msg: Comm.Msg):
+        print(msg.data)
+
+    scheduler.run()
+    from amltk._doc import doc_print; doc_print(print, task)  # markdown-exec: hide
+    ```
+    """  # noqa: E501
+
+    OPEN: Event[Comm.Msg] = Event("comm-open")
+    """The task has signalled that it is open.
+
+    ```python exec="true" source="material-block" html="true" hl_lines="5 15-17"
+    from amltk.scheduling import Scheduler
+    from amltk.scheduling.plugins import Comm
+
+    def fn(comm: Comm) -> None:
+        with comm.open():
+            pass
+    from amltk._doc import make_picklable; make_picklable(fn)  # markdown-exec: hide
+
+    scheduler = Scheduler.with_processes(1)
+    task = scheduler.task(fn, plugins=Comm.Plugin())
+
+    @scheduler.on_start
+    def on_start():
+        task.submit()
+
+    @task.on("comm-open")
+    def callback(msg: Comm.Msg):
+        print("Comm has just used comm.open()")
+
+    scheduler.run()
+    from amltk._doc import doc_print; doc_print(print, task)  # markdown-exec: hide
+    ```
+    """
+
+    CLOSE: Event[Comm.Msg] = Event("comm-close")
+    """The task has signalled that it has closed.
+
+    ```python exec="true" source="material-block" html="true" hl_lines="7 17-19"
+    from amltk.scheduling import Scheduler
+    from amltk.scheduling.plugins import Comm
+
+    def fn(comm: Comm) -> None:
+        with comm.open():
+            pass
+        # Will send a close signal to the main process as it exits this block
+
+        print("Done")
+    from amltk._doc import make_picklable; make_picklable(fn)  # markdown-exec: hide
+    scheduler = Scheduler.with_processes(1)
+    task = scheduler.task(fn, plugins=Comm.Plugin())
+
+    @scheduler.on_start
+    def on_start():
+        task.submit()
+
+    @task.on("comm-close")
+    def on_close(msg: Comm.Msg):
+        print(f"Worker closed with {msg}")
+
+    scheduler.run()
+    from amltk._doc import doc_print; doc_print(print, task)  # markdown-exec: hide
+    ```
+    """
+
+    def __init__(self, connection: Connection) -> None:
+        """Initialize the Comm.
+
+        Args:
+            connection: The underlying Connection
+        """
+        super().__init__()
+        self.connection = connection
+        self.id: CommID = id(self)
+
+    def _send_pipe(self, obj: Any) -> None:
+        self.connection.send(obj)
+
+    def send(self, obj: Any) -> None:
+        """Send a message.
+
+        Args:
+            obj: The object to send.
+        """
+        self._send_pipe((Comm.Msg.Kind.MESSAGE, obj))
+
+    def close(  # noqa: PLR0912, C901
+        self,
+        msg: Any | None = None,
+        *,
+        wait_for_ack: bool = False,
+        okay_if_broken_pipe: bool = False,
+        side: str = "",
+    ) -> None:
+        """Close the connection.
+
+        Args:
+            msg: The message to send to the other end of the connection.
+            wait_for_ack: If `True`, wait for an acknowledgement from the
+                other end before closing the connection.
+            okay_if_broken_pipe: If `True`, will not log an error if the
+                connection is already closed.
+            side: The side of the connection for naming purposes.
+ """ + if not self.connection.closed: + kind = Comm.Msg.Kind.CLOSE_WITH_ACK if wait_for_ack else Comm.Msg.Kind.CLOSE + try: + self._send_pipe((kind, msg)) + except BrokenPipeError as e: + if not okay_if_broken_pipe: + logger.error(f"{side} - Error sending close signal: {type(e)}{e}") + except Exception as e: # noqa: BLE001 + logger.error(f"{side} - Error sending close signal: {type(e)}{e}") + else: + if wait_for_ack: + logger.debug(f"{side} - Waiting for ACK") + try: + recieved_msg = self.connection.recv() + except Exception as e: # noqa: BLE001 + logger.error( + f"{side} - Error waiting for ACK, closing: {type(e)}{e}", + ) + else: + match recieved_msg: + case Comm.Msg.Kind.WORKER_CLOSE_REQUEST: + logger.error( + f"{side} - Worker recieved request to close!", + ) + case Comm.Msg.Kind.ACK: + logger.debug(f"{side} - Recieved ACK, closing") + case _: + logger.warning( + f"{side} - Expected ACK but {recieved_msg=}", + ) + finally: + try: + self.connection.close() + except OSError: + # It's possble that the connection was closed by the other end + # before we could close it. + pass + except Exception as e: # noqa: BLE001 + logger.error(f"{side} - Error closing connection: {type(e)}{e}") + + @classmethod + def create(cls, *, duplex: bool = True) -> tuple[Self, Self]: + """Create a pair of communication channels. + + Wraps the output of + [`multiprocessing.Pipe(duplex=duplex)`][multiprocessing.Pipe]. + + Args: + duplex: Whether to allow for two-way communication + + Returns: + A pair of communication channels. + """ + reader, writer = Pipe(duplex=duplex) + return cls(reader), cls(writer) + + @property + def as_async(self) -> AsyncComm: + """Return an async version of this comm.""" + return AsyncComm(self) + + def request( + self, + msg: Any | None = None, + *, + timeout: None | float = None, + ) -> Any: + """Receive a message. + + Args: + msg: The message to send to the other end of the connection. + If left empty, will be `None`. + timeout: If float, will wait for that many seconds, raising an exception + if exceeded. Otherwise, None will wait forever. + + Raises: + Comm.TimeoutError: If the timeout is reached. + Comm.CloseRequestError: If the other end needs to abruptly end and + can not fufill the request. If thise error is thrown, the worker + should finish as soon as possible. + + Returns: + The received message or the default. + """ + self._send_pipe((Comm.Msg.Kind.REQUEST, msg)) + if not self.connection.poll(timeout): + raise Comm.TimeoutError(f"Timed out waiting for response for {msg}") + + response = self.connection.recv() + if response == Comm.Msg.Kind.WORKER_CLOSE_REQUEST: + logger.error("Worker recieved request to close!") + raise Comm.CloseRequestError() + + return response + + @contextmanager + def open( + self, + opening_msg: Any | None = None, + *, + wait_for_ack: bool = False, + side: str = "worker", + ) -> Iterator[Self]: + """Open the connection. + + Args: + opening_msg: The message to send to the main process + when the connection is opened. + wait_for_ack: If `True`, wait for an acknowledgement from the + other end before closing the connection and exiting the + context manager. + side: The side of the connection for naming purposes. + Usually this is only done on the `"worker"` side. + + Yields: + The comm. 
+ """ + self._send_pipe((Comm.Msg.Kind.OPEN, opening_msg)) + yield self + self.close(wait_for_ack=wait_for_ack, side=side) + + class Plugin(TaskPlugin): + """A plugin that handles communication with a worker.""" + + name: ClassVar[str] = "comm-plugin" + + def __init__( + self, + create_comms: Callable[[], tuple[Comm, Comm]] | None = None, + ) -> None: + """Initialize the plugin. + + Args: + create_comms: A function that creates a pair of communication + channels. Defaults to `Comm.create`. + """ + super().__init__() + if create_comms is None: + create_comms = Comm.create + + self.create_comms = create_comms + self.comms: dict[CommID, tuple[Comm, Comm]] = {} + self.communication_tasks: list[asyncio.Task] = [] + self.task: Task + self.open_comms: set[CommID] = set() + + @override + def attach_task(self, task: Task) -> None: + """Attach the plugin to a task. + + This method is called when the plugin is attached to a task. This + is the place to subscribe to events on the task, create new subscribers + for people to use or even store a reference to the task for later use. + + Args: + task: The task the plugin is being attached to. + """ + self.task = task + task.emitter.add_event(Comm.MESSAGE, Comm.REQUEST, Comm.OPEN, Comm.CLOSE) + task.on_submitted(self._begin_listening, hidden=True) + + @override + def pre_submit( + self, + fn: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> tuple[Callable[P, R], tuple, dict] | None: + """Pre-submit hook. + + This method is called before the task is submitted. + + Args: + fn: The task function. + *args: The arguments to the task function. + **kwargs: The keyword arguments to the task function. + + Returns: + A tuple of the task function, arguments and keyword arguments + if the task should be submitted, or `None` if the task should + not be submitted. + """ + host_comm, worker_comm = self.create_comms() + + # We don't necessarily know if the future will be submitted. If so, + # we will use this index later to retrieve the host_comm + self.comms[worker_comm.id] = (host_comm, worker_comm) + + # Make sure to include the Comm + return fn, (worker_comm, *args), kwargs + + @override + def copy(self) -> Self: + """Return a copy of the plugin. + + Please see [`Plugin.copy()`][amltk.scheduling.Plugin.copy]. + """ + return self.__class__(create_comms=self.create_comms) + + def _begin_listening(self, f: asyncio.Future, *args: Any, **_: Any) -> Any: + match args: + case (worker_comm, *_) if isinstance(worker_comm, Comm): + worker_comm = args[0] + case _: + raise ValueError(f"Expected first arg to be a Comm, got {args[0]}") + + host_comm, worker_comm = self.comms[worker_comm.id] + + coroutine = asyncio.create_task( + self._communicate(f, host_comm, worker_comm), + ) + coroutine.add_done_callback(self._deregister_comm_coroutine) + + # NOTE: Asyncio coroutines must have a reference stored somewhere so + # we need to hold on to it until it's done. + self.communication_tasks.append(coroutine) + + def _deregister_comm_coroutine(self, coroutine: asyncio.Task) -> None: + if coroutine in self.communication_tasks: + self.communication_tasks.remove(coroutine) + else: + logger.warning(f"Communication coroutine {coroutine} not found!") + + if (exception := coroutine.exception()) is not None: + raise exception + + async def _communicate( + self, + future: asyncio.Future, + host_comm: Comm, + worker_comm: Comm, + ) -> None: + """Communicate with the task. + + This is a coroutine that will run until the scheduler is stopped or + the comms have finished. 
+ """ + worker_id = worker_comm.id + task_name = self.task.unique_ref + name = f"Task [{task_name}] (worker_id: {worker_id})" + closed = False + + try: + while not closed and (_msg := await host_comm.as_async.request()): + assert isinstance(_msg, tuple), "Expected (msg_kind, data)!" + msg_kind, data = _msg + logger.debug(f"{self.name}: receieved {msg_kind} with {data=}") + + match msg_kind: + # Other side has closed the connection, break out of coroutine + case Comm.Msg.Kind.CLOSE: + closed = True + case Comm.Msg.Kind.CLOSE_WITH_ACK: + host_comm._send_pipe(Comm.Msg.Kind.ACK) + closed = True + case Comm.Msg.Kind.OPEN: + self.open_comms.add(worker_id) + case _: + pass + + event = EVENT_LOOKUP[msg_kind] + msg = Comm.Msg( + comm=host_comm, + data=data, + kind=msg_kind, + future=future, + task=self.task, + ) + self.task.emitter.emit(event, msg) + + except EOFError: + # This means the connection dropped to the worker, however this is not + # an error in the main process and so we can safely ignore that. + logger.debug(f"{name}: closed connection") + except Exception as e: + # Something unexpected happened in the main process, either from us or + # from a users callback. In this case we want to raise the exception + logger.error( + f"{name}: Exception occured in scheduler or callbacks!", + exc_info=e, + ) + + # NOTE: It's important that we let the worker know that something went + # wrong, especially if it's requesting things. The worker will only + # see this msg when it does a `request()` + host_comm._send_pipe(Comm.Msg.Kind.WORKER_CLOSE_REQUEST) + raise e + finally: + # Make sure we do all the clean up! + logger.debug(f"{name}: finished communication, closing comms") + + # We don't necessarily know how we got here but + host_comm.close( + wait_for_ack=False, + okay_if_broken_pipe=True, + side="host", + ) + worker_comm.close( + wait_for_ack=False, + okay_if_broken_pipe=True, + side="host-on-worker-comm", + ) + + if worker_id in self.open_comms: + self.open_comms.remove(worker_id) + + # Remove the reference to the work comm so it gets garbarged + del self.comms[worker_id] + logger.debug(f"{name}: finished and cleaned") + + @override + def __rich__(self) -> Panel: + from rich.panel import Panel + from rich.text import Text + + return Panel( + Text("Open Connections: ").append(str(len(self.open_comms)), "yellow"), + title=f"Plugin {self.name}", + ) + + @dataclass + class Msg(Generic[T]): + """A message sent over a communication channel. + + Attributes: + task: The task that sent the message. + comm: The communication channel. + future: The future of the task. + data: The data sent by the task. + """ + + kind: Kind + data: T + comm: Comm = field(repr=False) + future: asyncio.Future = field(repr=False) + task: Task = field(repr=False) + + def respond(self, response: Any) -> None: + """Respond to the message. + + Args: + response: The response to send back to the task. + """ + self.comm._send_pipe(response) + + class Kind(str, Enum): + """The kind of message.""" + + CLOSE = "close" + CLOSE_WITH_ACK = "close-with-ack" + WORKER_CLOSE_REQUEST = "worker-close-request" + OPEN = "open" + MESSAGE = "message" + REQUEST = "request" + ACK = "ack" + + @override + def __str__(self) -> str: + return self.value + + class TimeoutError(MPTimeoutError): # noqa: A001 + """A timeout error for communications.""" + + class CloseRequestError(RuntimeError): + """An exception happened in the main process and it send + a response to the worker to raise this exception. 
+ """ + + +EVENT_LOOKUP = { + Comm.Msg.Kind.CLOSE: Comm.CLOSE, + Comm.Msg.Kind.CLOSE_WITH_ACK: Comm.CLOSE, + Comm.Msg.Kind.OPEN: Comm.OPEN, + Comm.Msg.Kind.MESSAGE: Comm.MESSAGE, + Comm.Msg.Kind.REQUEST: Comm.REQUEST, +} diff --git a/src/amltk/scheduling/plugins/limiter.py b/src/amltk/scheduling/plugins/limiter.py new file mode 100644 index 00000000..d0083cd1 --- /dev/null +++ b/src/amltk/scheduling/plugins/limiter.py @@ -0,0 +1,300 @@ +"""The [`Limiter`][amltk.scheduling.plugins.Limiter] can limit the number of +times a function is called, how many concurrent instances of it can be running, +or whether it can run while another task is running. + +The functionality of the `Limiter` could also be implemented without a plugin but +it gives some nice utility. + +??? tip "Usage" + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import Limiter + + def fn(x: int) -> int: + return x + 1 + + scheduler = Scheduler.with_processes(1) + + task = scheduler.task(fn, plugins=[Limiter(max_calls=2)]) + + @task.on("call-limit-reached") + def callback(task: Task, *args, **kwargs): + pass + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + +??? example "`@events`" + + Check out the [`@events`](site:reference/scheduling/events.md) + reference for more on how to customize these callbacks. + + === "`@call-limit-reached`" + + ::: amltk.scheduling.plugins.Limiter.CALL_LIMIT_REACHED + + === "`@concurrent-limit-reached`" + + ::: amltk.scheduling.plugins.Limiter.CONCURRENT_LIMIT_REACHED + + === "`@disabled-due-to-running-task`" + + ::: amltk.scheduling.plugins.Limiter.DISABLED_DUE_TO_RUNNING_TASK +""" +from __future__ import annotations + +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar +from typing_extensions import ParamSpec, Self, override + +from amltk.scheduling.events import Event +from amltk.scheduling.plugins.plugin import Plugin + +if TYPE_CHECKING: + from rich.panel import Panel + + from amltk.scheduling.task import Task + +P = ParamSpec("P") +R = TypeVar("R") +TrialInfo = TypeVar("TrialInfo") + + +class Limiter(Plugin): + """A plugin that limits the submission of a task. + + Adds three new events to the task: + + * [`@call-limit-reached`][amltk.scheduling.plugins.Limiter.CALL_LIMIT_REACHED] + * [`@concurrent-limit-reached`][amltk.scheduling.plugins.Limiter.CONCURRENT_LIMIT_REACHED] + * [`@disabled-due-to-running-task`][amltk.scheduling.plugins.Limiter.DISABLED_DUE_TO_RUNNING_TASK] + """ # noqa: E501 + + name: ClassVar = "limiter" + """The name of the plugin.""" + + CALL_LIMIT_REACHED: Event[...] = Event("call-limit-reached") + """The event emitted when the task has reached its call limit. + + Will call any subscribers with the task as the first argument, + followed by the arguments and keyword arguments that were passed to the task. + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import Limiter + + def fn(x: int) -> int: + return x + 1 + + scheduler = Scheduler.with_processes(1) + + task = scheduler.task(fn, plugins=[Limiter(max_calls=2)]) + + @task.on("call-limit-reached") + def callback(task: Task, *args, **kwargs): + pass + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + CONCURRENT_LIMIT_REACHED: Event[...] 
= Event("concurrent-limit-reached") + """The event emitted when the task has reached its concurrent call limit. + + Will call any subscribers with the task as the first argument, followed by the + arguments and keyword arguments that were passed to the task. + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import Limiter + + def fn(x: int) -> int: + return x + 1 + + scheduler = Scheduler.with_processes(2) + + task = scheduler.task(fn, plugins=[Limiter(max_concurrent=2)]) + + @task.on("concurrent-limit-reached") + def callback(task: Task, *args, **kwargs): + pass + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + DISABLED_DUE_TO_RUNNING_TASK: Event[...] = Event("disabled-due-to-running-task") + """The event emitter when the task was not submitted due to some other + running task. + + Will call any subscribers with the task as first argument, followed by + the arguments and keyword arguments that were passed to the task. + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import Limiter + + def fn(x: int) -> int: + return x + 1 + + scheduler = Scheduler.with_processes(2) + + other_task = scheduler.task(fn) + task = scheduler.task(fn, plugins=[Limiter(not_while_running=other_task)]) + + @task.on("disabled-due-to-running-task") + def callback(task: Task, *args, **kwargs): + pass + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + def __init__( + self, + *, + max_calls: int | None = None, + max_concurrent: int | None = None, + not_while_running: Task | Iterable[Task] | None = None, + ): + """Initialize the plugin. + + Args: + max_calls: The maximum number of calls to the task. + max_concurrent: The maximum number of calls of this task that can + be in the queue. + not_while_running: A task or iterable of tasks that if active, will prevent + this task from being submitted. + """ + super().__init__() + + if not_while_running is None: + not_while_running = [] + elif isinstance(not_while_running, Iterable): + not_while_running = list(not_while_running) + else: + not_while_running = [not_while_running] + + self.max_calls = max_calls + self.max_concurrent = max_concurrent + self.not_while_running = not_while_running + self.task: Task | None = None + + if isinstance(max_calls, int) and not max_calls > 0: + raise ValueError("max_calls must be greater than 0") + + if isinstance(max_concurrent, int) and not max_concurrent > 0: + raise ValueError("max_concurrent must be greater than 0") + + self._calls = 0 + + @override + def attach_task(self, task: Task) -> None: + """Attach the plugin to a task.""" + self.task = task + + if self.task in self.not_while_running: + raise ValueError( + f"Task {self.task} was found in the {self.not_while_running=}" + " list. 
This is disabled but please raise an issue if you think this" + " has sufficient use case.", + ) + + task.emitter.add_event( + self.CALL_LIMIT_REACHED, + self.CONCURRENT_LIMIT_REACHED, + self.DISABLED_DUE_TO_RUNNING_TASK, + ) + + # Make sure to increment the count when a task was submitted + task.on_submitted(self._increment_call_count, hidden=True) + + @property + def n_running(self) -> int: + """Return the number of running tasks.""" + assert self.task is not None + return len(self.task.queue) + + @override + def pre_submit( + self, + fn: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> tuple[Callable[P, R], tuple, dict] | None: + """Pre-submit hook. + + Prevents submission of the task if it exceeds any of the set limits. + """ + assert self.task is not None + + if self.max_calls is not None and self._calls >= self.max_calls: + self.task.emitter.emit(self.CALL_LIMIT_REACHED, self.task, *args, **kwargs) + return None + + if self.max_concurrent is not None and self.n_running >= self.max_concurrent: + self.task.emitter.emit( + self.CONCURRENT_LIMIT_REACHED, + self.task, + *args, + **kwargs, + ) + return None + + for other_task in self.not_while_running: + if other_task.running(): + self.task.emitter.emit( + self.DISABLED_DUE_TO_RUNNING_TASK, + other_task, + self.task, + *args, + **kwargs, + ) + return None + + return fn, args, kwargs + + @override + def copy(self) -> Self: + """Return a copy of the plugin.""" + return self.__class__( + max_calls=self.max_calls, + max_concurrent=self.max_concurrent, + ) + + def _increment_call_count(self, *_: Any, **__: Any) -> None: + self._calls += 1 + + @override + def __rich__(self) -> Panel: + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + + from amltk._richutil import Function + + table = Table.grid(padding=(0, 1)) + if self.max_calls is not None: + table.add_row("Calls", f"{self._calls}/{self.max_calls}") + + if self.max_concurrent is not None: + table.add_row("Concurrent", f"{self.n_running}/{self.max_concurrent}") + + for task in self.not_while_running: + f = Function( + task.function, + signature="...", + link=False, + ) + if task.running(): + table.add_row( + "Not While", + f, + Text(task.unique_ref, "italic, yellow"), + Text("Running", style="bold green"), + ) + else: + table.add_row( + "Not While", + f, + Text("Ref: ").append(task.unique_ref, "italic yellow"), + ) + + return Panel(table, title=f"Plugin {self.name}") diff --git a/src/amltk/scheduling/plugins/plugin.py b/src/amltk/scheduling/plugins/plugin.py new file mode 100644 index 00000000..18d121b7 --- /dev/null +++ b/src/amltk/scheduling/plugins/plugin.py @@ -0,0 +1,198 @@ +r"""A plugin that can be attached to a Task. + +By inheriting from a `Plugin`, you can hook into a +[`Task`][amltk.scheduling.Task]. A plugin can affect, modify and extend its +behaviours. Please see the documentation of the methods for more information. +Creating a plugin is only necesary if you need to modify actual behaviour of +the task. For siply hooking into the lifecycle of a task, you can use the `@events` +that a `Task` emits. + +??? example "Creating a Plugin" + + For a full example of a simple plugin, see the + [`Limiter`][amltk.scheduling.plugins.Limiter] plugin which prevents + the task being submitted if for example, it has already been submitted + too many times. + + The below example shows how to create a plugin that prints the task name + before submitting it. It also emits an event when the task is submitted. 
+
+    ```python exec="true" source="material-block" html="true"
+    from __future__ import annotations
+    from typing import Callable
+
+    from amltk.scheduling import Scheduler
+    from amltk.scheduling.plugins import Plugin
+    from amltk.scheduling.events import Event
+
+    # A simple plugin that prints the task name before submitting
+    class Printer(Plugin):
+        name = "my-plugin"
+
+        # Define an event the plugin will emit
+        # Event[str] indicates the callback for the event will be called with a str
+        PRINTED: Event[str] = Event("printer-msg")
+
+        def __init__(self, greeting: str):
+            self.greeting = greeting
+            self.n_greetings = 0
+
+        def attach_task(self, task) -> None:
+            self.task = task
+            # Register an event with the task, this lets the task know valid events
+            # people can subscribe to and helps it show up in visuals
+            task.emitter.add_event(self.PRINTED)
+            task.on_submitted(self._print_submitted, hidden=True)  # You can hide this callback from visuals
+
+        def pre_submit(self, fn, *args, **kwargs) -> tuple[Callable, tuple, dict]:
+            print(f"{self.greeting} for {self.task} {args} {kwargs}")
+            self.n_greetings += 1
+            return fn, args, kwargs
+
+        def _print_submitted(self, future, *args, **kwargs) -> None:
+            msg = f"Task was submitted {self.task} {args} {kwargs}"
+            self.task.emitter.emit(self.PRINTED, msg)  # Emit the event with a msg
+
+        def copy(self) -> Printer:
+            # Plugins need to be able to copy themselves as if fresh
+            return self.__class__(self.greeting)
+
+        def __rich__(self):
+            # Customize how the plugin is displayed in rich (Optional)
+            # rich is an optional dependency of amltk so we move the imports into here
+            from rich.panel import Panel
+
+            return Panel(
+                f"Greeting: {self.greeting} ({self.n_greetings})",
+                title=f"Plugin {self.name}"
+            )
+
+    def fn(x: int) -> int:
+        return x + 1
+    from amltk._doc import make_picklable; make_picklable(fn)  # markdown-exec: hide
+
+    scheduler = Scheduler.with_processes(1)
+    task = scheduler.task(fn, plugins=[Printer("Hello")])
+
+    @scheduler.on_start
+    def on_start():
+        task.submit(15)
+
+    @task.on("printer-msg")
+    def callback(msg: str):
+        print(f"\n{msg}")
+
+    scheduler.run()
+    from amltk._doc import doc_print; doc_print(print, task)  # markdown-exec: hide
+    ```
+
+    All methods are optional, and you can choose to implement only the ones
+    you need. Most plugins will likely need to implement the
+    [`attach_task()`][amltk.scheduling.Plugin.attach_task] method, which is called
+    when the plugin is attached to a task. In this method, you can for
+    example subscribe to events on the task, create new subscribers for people
+    to use or even store a reference to the task for later use.
+
+    Plugins are also encouraged to utilize the events of a
+    [`Task`][amltk.scheduling.Task] to further hook into the lifecycle of the task.
+    For example, by saving a reference to the task in the `attach_task()` method, you
+    can use the [`emit()`][amltk.scheduling.Task] method of the task to emit
+    your own specialized events.
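+
+??? tip "A minimal veto-style plugin (sketch)"
+
+    A minimal sketch, assuming only the documented `Plugin` interface: returning
+    `None` from `pre_submit()` is enough to prevent a submission, similar in
+    spirit to the [`Limiter`][amltk.scheduling.plugins.Limiter]. The class name
+    and its rule are purely illustrative.
+
+    ```python
+    from amltk.scheduling.plugins import Plugin
+
+    class EvenArgsOnly(Plugin):
+        """Only allow submission when the first argument is an even int."""
+
+        name = "even-args-only"
+
+        def pre_submit(self, fn, *args, **kwargs):
+            if args and isinstance(args[0], int) and args[0] % 2 != 0:
+                return None  # Returning None prevents the submission
+            return fn, args, kwargs
+
+        def copy(self) -> "EvenArgsOnly":
+            # Plugins must be copyable as if freshly created
+            return self.__class__()
+    ```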
+""" # noqa: E501 +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from itertools import chain +from typing import TYPE_CHECKING, ClassVar, TypeVar +from typing_extensions import ParamSpec, Self, override + +from amltk._richutil.renderable import RichRenderable +from amltk.scheduling.events import Event + +if TYPE_CHECKING: + from rich.panel import Panel + + from amltk.scheduling import Task + +logger = logging.getLogger(__name__) + + +P = ParamSpec("P") +P2 = ParamSpec("P2") + +R = TypeVar("R") +R2 = TypeVar("R2") +CallableT = TypeVar("CallableT", bound=Callable) + + +class Plugin(RichRenderable, ABC): + """A plugin that can be attached to a Task.""" + + name: ClassVar[str] + """The name of the plugin. + + This is used to identify the plugin during logging. + """ + + def attach_task(self, task: Task) -> None: + """Attach the plugin to a task. + + This method is called when the plugin is attached to a task. This + is the place to subscribe to events on the task, create new subscribers + for people to use or even store a reference to the task for later use. + + Args: + task: The task the plugin is being attached to. + """ + + def pre_submit( + self, + fn: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> tuple[Callable[P, R], tuple, dict] | None: + """Pre-submit hook. + + This method is called before the task is submitted. + + Args: + fn: The task function. + *args: The arguments to the task function. + **kwargs: The keyword arguments to the task function. + + Returns: + A tuple of the task function, arguments and keyword arguments + if the task should be submitted, or `None` if the task should + not be submitted. + """ + return fn, args, kwargs + + def events(self) -> list[Event]: + """Return a list of events that this plugin emits. + + Likely no need to override this method, as it will automatically + return all events defined on the plugin. + """ + inherited_attrs = chain.from_iterable( + vars(cls).values() for cls in self.__class__.__mro__ + ) + return [attr for attr in inherited_attrs if isinstance(attr, Event)] + + @abstractmethod + def copy(self) -> Self: + """Return a copy of the plugin. + + This method is used to create a copy of the plugin when a task is + copied. This is useful if the plugin stores a reference to the task + it is attached to, as the copy will need to store a reference to the + copy of the task. + """ + ... + + @override + def __rich__(self) -> Panel: + from rich.panel import Panel + + return Panel("", title=f"Plugin {self.name}") diff --git a/src/amltk/scheduling/plugins/pynisher.py b/src/amltk/scheduling/plugins/pynisher.py new file mode 100644 index 00000000..8e34bf04 --- /dev/null +++ b/src/amltk/scheduling/plugins/pynisher.py @@ -0,0 +1,365 @@ +"""The [`PynisherPlugin`][amltk.scheduling.plugins.pynisher.PynisherPlugin] +uses [pynisher](https://github.com/automl/pynisher) to place **memory**, **walltime** +and **cputime** constraints on processes, crashing them if these limits are reached. +These default units are `bytes ("B")` and `seconds ("s")` but you can also use other +units, please see the relevant API doc. + +It's best use is when used with +[`Scheduler.with_processes()`][amltk.scheduling.Scheduler.with_processes] to have work +performed in processes. + +!!! tip "Requirements" + + This required `pynisher` which can be installed with: + + ```bash + pip install amltk[pynisher] + + # Or directly + pip install pynisher + ``` + +??? 
tip "Usage" + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Task, Scheduler + from amltk.scheduling.plugins.pynisher import PynisherPlugin + import time + + def f(x: int) -> int: + time.sleep(x) + return 42 + + scheduler = Scheduler.with_processes() + task = scheduler.task(f, plugins=PynisherPlugin(walltime_limit=(1, "s"))) + + @task.on("pynisher-timeout") + def callback(exception): + pass + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + +??? example "`@events`" + + Check out the [`@events`](site:reference/scheduling/events.md) + reference for more on how to customize these callbacks. + + === "`@pynisher-timeout`" + + ::: amltk.scheduling.plugins.pynisher.PynisherPlugin.TIMEOUT + + === "`@pynisher-memory-limit`" + + ::: amltk.scheduling.plugins.pynisher.PynisherPlugin.MEMORY_LIMIT_REACHED + + === "`@pynisher-cputime-limit`" + + ::: amltk.scheduling.plugins.pynisher.PynisherPlugin.CPU_TIME_LIMIT_REACHED + + === "`@pynisher-walltime-limit`" + + ::: amltk.scheduling.plugins.pynisher.PynisherPlugin.WALL_TIME_LIMIT_REACHED + +??? warning "Scheduler Executor" + + This will place process limits on the task as soon as it starts + running, whever it may be running. If you are using + [`Scheduler.with_sequential()`][amltk.Scheduler.with_sequential] + then this will place limits on the main process, likely not what you + want. This also does not work with a + [`ThreadPoolExecutor`][concurrent.futures.ThreadPoolExecutor]. + + If using this with something like [`dask-jobqueue`], + then this will place limits on the workers it spawns. It would be better + to place limits directly through dask job-queue then. + +??? warning "Platform Limitations (Mac, Windows)" + + Pynisher has some limitations with memory on Mac and Windows: + https://github.com/automl/pynisher#features +""" +from __future__ import annotations + +from collections.abc import Callable +from multiprocessing.context import BaseContext +from typing import TYPE_CHECKING, ClassVar, TypeAlias, TypeVar +from typing_extensions import ParamSpec, Self, override + +import pynisher + +from amltk.scheduling.events import Event +from amltk.scheduling.plugins.plugin import Plugin + +if TYPE_CHECKING: + import asyncio + + from rich.panel import Panel + + from amltk.scheduling.task import Task + +P = ParamSpec("P") +R = TypeVar("R") + + +class PynisherPlugin(Plugin): + """A plugin that wraps a task in a pynisher to enforce limits on it. + + This plugin wraps a task function in a `Pynisher` instance to enforce + limits on the task. The limits are set by any of `memory_limit=`, + `cpu_time_limit=` and `wall_time_limit=`. + + Adds four new events to the task + + * [`@pynisher-timeout`][amltk.scheduling.plugins.pynisher.PynisherPlugin.TIMEOUT] + * [`@pynisher-memory-limit`][amltk.scheduling.plugins.pynisher.PynisherPlugin.MEMORY_LIMIT_REACHED] + * [`@pynisher-cputime-limit`][amltk.scheduling.plugins.pynisher.PynisherPlugin.CPU_TIME_LIMIT_REACHED] + * [`@pynisher-walltime-limit`][amltk.scheduling.plugins.pynisher.PynisherPlugin.WALL_TIME_LIMIT_REACHED] + + Attributes: + memory_limit: The memory limit of the task. + cpu_time_limit: The cpu time limit of the task. + wall_time_limit: The wall time limit of the task. + """ # noqa: E501 + + name: ClassVar = "pynisher-plugin" + """The name of the plugin.""" + + TIMEOUT: Event[PynisherPlugin.TimeoutException] = Event("pynisher-timeout") + """A Task timed out, either due to the wall time or cpu time limit. 
+ + Will call any subscribers with the exception as the argument. + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Task, Scheduler + from amltk.scheduling.plugins.pynisher import PynisherPlugin + import time + + def f(x: int) -> int: + time.sleep(x) + return 42 + + scheduler = Scheduler.with_processes() + task = scheduler.task(f, plugins=PynisherPlugin(walltime_limit=(1, "s"))) + + @task.on("pynisher-timeout") + def callback(exception): + pass + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + MEMORY_LIMIT_REACHED: Event[pynisher.MemoryLimitException] = Event( + "pynisher-memory-limit", + ) + """A Task was submitted but reached it's memory limit. + + Will call any subscribers with the exception as the argument. + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Task, Scheduler + from amltk.scheduling.plugins.pynisher import PynisherPlugin + import numpy as np + + def f(x: int) -> int: + x = np.arange(100000000) + time.sleep(x) + return 42 + + scheduler = Scheduler.with_processes() + task = scheduler.task(f, plugins=PynisherPlugin(memory_limit=(1, "KB"))) + + @task.on("pynisher-memory-limit") + def callback(exception): + pass + + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + CPU_TIME_LIMIT_REACHED: Event[PynisherPlugin.CpuTimeoutException] = Event( + "pynisher-cputime-limit", + ) + """A Task was submitted but reached it's cpu time limit. + + Will call any subscribers with the exception as the argument. + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Task, Scheduler + from amltk.scheduling.plugins.pynisher import PynisherPlugin + import time + + def f(x: int) -> int: + i = 0 + while True: + # Keep busying computing the answer to everything + i += 1 + + return 42 + + scheduler = Scheduler.with_processes() + task = scheduler.task(f, plugins=PynisherPlugin(cputime_limit=(1, "s"))) + + @task.on("pynisher-cputime-limit") + def callback(exception): + pass + + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + WALL_TIME_LIMIT_REACHED: Event[PynisherPlugin.WallTimeoutException] = Event( + "pynisher-walltime-limit", + ) + """A Task was submitted but reached it's wall time limit. + + Will call any subscribers with the exception as the argument. 
+ + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Task, Scheduler + from amltk.scheduling.plugins.pynisher import PynisherPlugin + import time + + def f(x: int) -> int: + time.sleep(x) + return 42 + + scheduler = Scheduler.with_processes() + task = scheduler.task(f, plugins=PynisherPlugin(walltime_limit=(1, "s"))) + + @task.on("pynisher-walltime-limit") + def callback(exception): + pass + + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` + """ + + TimeoutException: TypeAlias = pynisher.TimeoutException + """The exception that is raised when a task times out.""" + + MemoryLimitException: TypeAlias = pynisher.MemoryLimitException + """The exception that is raised when a task reaches it's memory limit.""" + + CpuTimeoutException: TypeAlias = pynisher.CpuTimeoutException + """The exception that is raised when a task reaches it's cpu time limit.""" + + WallTimeoutException: TypeAlias = pynisher.WallTimeoutException + """The exception that is raised when a task reaches it's wall time limit.""" + + def __init__( + self, + *, + memory_limit: int | tuple[int, str] | None = None, + cputime_limit: int | tuple[float, str] | None = None, + walltime_limit: int | tuple[float, str] | None = None, + context: BaseContext | None = None, + ): + """Initialize a `PynisherPlugin` instance. + + Args: + memory_limit: The memory limit to wrap the task in. Base unit is in bytes + but you can specify `(value, unit)` where `unit` is one of + `("B", "KB", "MB", "GB")`. Defaults to `None` + cputime_limit: The cpu time limit to wrap the task in. Base unit is in + seconds but you can specify `(value, unit)` where `unit` is one of + `("s", "m", "h")`. Defaults to `None` + walltime_limit: The wall time limit for the task. Base unit is in seconds + but you can specify `(value, unit)` where `unit` is one of + `("s", "m", "h")`. Defaults to `None`. + context: The context to use for multiprocessing. Defaults to `None`. + See [`multiprocessing.get_context()`][multiprocessing.get_context] + """ + super().__init__() + self.memory_limit = memory_limit + self.cputime_limit = cputime_limit + self.walltime_limit = walltime_limit + self.context = context + + self.task: Task + + @override + def pre_submit( + self, + fn: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> tuple[Callable[P, R], tuple, dict]: + """Wrap a task function in a `Pynisher` instance.""" + # If any of our limits is set, we need to wrap it in Pynisher + # to enfore these limits. + if any( + limit is not None + for limit in (self.memory_limit, self.cputime_limit, self.walltime_limit) + ): + fn = pynisher.Pynisher( + fn, + memory=self.memory_limit, + cpu_time=self.cputime_limit, + wall_time=self.walltime_limit, + terminate_child_processes=True, + context=self.context, + ) + + return fn, args, kwargs + + @override + def attach_task(self, task: Task) -> None: + """Attach the plugin to a task.""" + self.task = task + task.emitter.add_event( + self.TIMEOUT, + self.MEMORY_LIMIT_REACHED, + self.CPU_TIME_LIMIT_REACHED, + self.WALL_TIME_LIMIT_REACHED, + ) + + # Check the exception and emit pynisher specific ones too + task.on_exception(self._check_to_emit_pynisher_exception, hidden=True) + + @override + def copy(self) -> Self: + """Return a copy of the plugin. + + Please see [`Plugin.copy()`][amltk.Plugin.copy]. 
+        """
+        return self.__class__(
+            memory_limit=self.memory_limit,
+            cputime_limit=self.cputime_limit,
+            walltime_limit=self.walltime_limit,
+        )
+
+    def _check_to_emit_pynisher_exception(
+        self,
+        _: asyncio.Future,
+        exception: BaseException,
+    ) -> None:
+        """Check if the exception is a pynisher exception and emit it."""
+        if isinstance(exception, pynisher.CpuTimeoutException):
+            self.task.emitter.emit(self.TIMEOUT, exception)
+            self.task.emitter.emit(self.CPU_TIME_LIMIT_REACHED, exception)
+        elif isinstance(exception, pynisher.WallTimeoutException):
+            self.task.emitter.emit(self.TIMEOUT, exception)
+            self.task.emitter.emit(self.WALL_TIME_LIMIT_REACHED, exception)
+        elif isinstance(exception, pynisher.MemoryLimitException):
+            self.task.emitter.emit(self.MEMORY_LIMIT_REACHED, exception)
+
+    @override
+    def __rich__(self) -> Panel:
+        from rich.panel import Panel
+        from rich.pretty import Pretty
+        from rich.table import Table
+
+        table = Table(
+            "Memory",
+            "Wall Time",
+            "CPU Time",
+            padding=(0, 1),
+            show_edge=False,
+            box=None,
+        )
+        table.add_row(
+            Pretty(self.memory_limit),
+            Pretty(self.walltime_limit),
+            Pretty(self.cputime_limit),
+        )
+        return Panel(table, title=f"Plugin {self.name}")
diff --git a/src/amltk/threadpoolctl/threadpoolctl_plugin.py b/src/amltk/scheduling/plugins/threadpoolctl.py
similarity index 64%
rename from src/amltk/threadpoolctl/threadpoolctl_plugin.py
rename to src/amltk/scheduling/plugins/threadpoolctl.py
index b7ba5226..655293b4 100644
--- a/src/amltk/threadpoolctl/threadpoolctl_plugin.py
+++ b/src/amltk/scheduling/plugins/threadpoolctl.py
@@ -1,22 +1,53 @@
-"""Plugin for threadpoolctl.
-
-This plugin is used to make utilize threadpoolctl with tasks,
-useful for parallel training of models. Without limiting with
+"""The
+[`ThreadPoolCTLPlugin`][amltk.scheduling.plugins.threadpoolctl.ThreadPoolCTLPlugin]
+is useful for parallel training of models. Without limiting with
 threadpoolctl, the number of threads used by a given model may
 oversubscribe to resources and cause significant slowdowns.
 
+This is the mechanism employed by scikit-learn to limit the number of
+threads used by a given model.
+
 See [threadpoolctl documentation](https://github.com/joblib/threadpoolctl).
-"""
 
+!!! tip "Requirements"
+
+    This requires `threadpoolctl` which can be installed with:
+
+    ```bash
+    pip install amltk[threadpoolctl]
+
+    # Or directly
+    pip install threadpoolctl
+    ```
+
+??? tip "Usage"
+
+    ```python exec="true" source="material-block" html="true"
+    from amltk.scheduling import Scheduler
+    from amltk.scheduling.plugins.threadpoolctl import ThreadPoolCTLPlugin
+
+    scheduler = Scheduler.with_processes(1)
+
+    def f() -> None:
+        # ...
some task that respects the limits set by threadpoolctl + pass + + task = scheduler.task(f, plugins=ThreadPoolCTLPlugin(max_threads=1)) + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` +""" from __future__ import annotations import logging -from typing import TYPE_CHECKING, Callable, ClassVar, Generic, TypeVar +from collections.abc import Callable +from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar from typing_extensions import ParamSpec, Self, override -from amltk.scheduling.task_plugin import TaskPlugin +from amltk.scheduling.plugins.plugin import Plugin if TYPE_CHECKING: + from rich.panel import Panel + from amltk.scheduling.task import Task P = ParamSpec("P") @@ -46,7 +77,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return self.fn(*args, **kwargs) -class ThreadPoolCTLPlugin(TaskPlugin): +class ThreadPoolCTLPlugin(Plugin): """A plugin that limits the usage of threads in a task. This plugin is used to make utilize threadpoolctl with tasks, @@ -71,8 +102,6 @@ def __init__( ): """Initialize the plugin. - See [threadpoolctl documentation](https://github.com/joblib/threadpoolctl). - Args: max_threads: The maximum number of threads to use. user_api: The user API to limit. @@ -112,6 +141,22 @@ def pre_submit( def copy(self) -> Self: """Return a copy of the plugin. - Please see [`TaskPlugin.copy()`][amltk.TaskPlugin.copy]. + Please see [`Plugin.copy()`][amltk.Plugin.copy]. """ return self.__class__(max_threads=self.max_threads, user_api=self.user_api) + + @override + def __rich__(self) -> Panel: + from rich.panel import Panel + from rich.pretty import Pretty + from rich.table import Table + + table = Table( + "Max Threads", + "User-API", + padding=(0, 1), + show_edge=False, + box=None, + ) + table.add_row(Pretty(self.max_threads), Pretty(self.user_api)) + return Panel(table, title=f"Plugin {self.name}") diff --git a/src/amltk/wandb/wandb.py b/src/amltk/scheduling/plugins/wandb.py similarity index 90% rename from src/amltk/wandb/wandb.py rename to src/amltk/scheduling/plugins/wandb.py index 0a903571..205faf59 100644 --- a/src/amltk/wandb/wandb.py +++ b/src/amltk/scheduling/plugins/wandb.py @@ -1,32 +1,40 @@ -"""Wandb plugin.""" +"""Wandb plugin. + +!!! todo + + This plugin is experimental and out of date. + +""" from __future__ import annotations import logging +from collections.abc import Callable, Mapping from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, replace from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, + Concatenate, Generic, Literal, - Mapping, + TypeAlias, TypeVar, - Union, ) -from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, override +from typing_extensions import ParamSpec, Self, override import numpy as np import wandb from wandb.sdk import wandb_run from wandb.sdk.lib import RunDisabled -from amltk.scheduling import Scheduler, SequentialExecutor, Task, TaskPlugin +from amltk.scheduling.executors import SequentialExecutor +from amltk.scheduling.plugins.plugin import Plugin if TYPE_CHECKING: from amltk.optimization import Trial + from amltk.scheduling import Scheduler, Task P = ParamSpec("P") R = TypeVar("R") @@ -34,7 +42,7 @@ logger = logging.getLogger(__name__) -WRun: TypeAlias = Union[wandb_run.Run, RunDisabled] +WRun: TypeAlias = wandb_run.Run | RunDisabled @dataclass @@ -43,8 +51,9 @@ class WandbParams: This class is a dataclass that contains all the parameters that are used to initialize a wandb run. 
It is used by the - [`WandbPlugin`][amltk.wandb.WandbPlugin] to initialize a run. It can be - modified using the [`modify()`][amltk.wandb.WandbParams.modify] method. + [`WandbPlugin`][amltk.scheduling.plugins.wandb.WandbPlugin] to initialize a run. + It can be modified using the + [`modify()`][amltk.scheduling.plugins.wandb.WandbParams.modify] method. Please refer to the documentation of the [`wandb.init()`](https://docs.wandb.ai/ref/python/init) method for more information @@ -122,7 +131,7 @@ class WandbLiveRunWrap(Generic[P]): This class is used to wrap a function that returns a report to log the results to a wandb run. It is used by the - [`WandbTrialTracker`][amltk.wandb.WandbTrialTracker] to wrap + [`WandbTrialTracker`][amltk.scheduling.plugins.wandb.WandbTrialTracker] to wrap the target function. """ @@ -151,7 +160,7 @@ def __call__(self, trial: Trial, *args: P.args, **kwargs: P.kwargs) -> Trial.Rep params = self.params if self.modify is None else self.modify(trial, self.params) with params.run(name=trial.name, config=trial.config) as run: # Make sure the run is available from the trial - trial.attach_plugin_item("wandb", run) + trial.extras["wandb"] = run report = self.fn(trial, *args, **kwargs) @@ -160,7 +169,7 @@ def __call__(self, trial: Trial, *args: P.args, **kwargs: P.kwargs) -> Trial.Rep wandb_summary = { k: v for k, v in report.summary.items() - if isinstance(v, (int, float, np.number)) + if isinstance(v, int | float | np.number) } run.summary.update(wandb_summary) @@ -168,7 +177,7 @@ def __call__(self, trial: Trial, *args: P.args, **kwargs: P.kwargs) -> Trial.Rep return report -class WandbTrialTracker(TaskPlugin): +class WandbTrialTracker(Plugin): """Track trials using wandb. This class is a task plugin that tracks trials using wandb. @@ -237,7 +246,7 @@ def _check_explicit_reinit_arg_with_executor( Args: scheduler: The scheduler to check. """ - if isinstance(scheduler.executor, (SequentialExecutor, ThreadPoolExecutor)): + if isinstance(scheduler.executor, SequentialExecutor | ThreadPoolExecutor): if self.params.reinit is False: raise ValueError( "WandbPlugin reinit argument is not compatible with" @@ -258,10 +267,11 @@ class WandbPlugin: """Log trials using wandb. This class is the entry point to log trials using wandb. It - can be used to create a [`trial_tracker()`][amltk.wandb.WandbPlugin.trial_tracker] + can be used to create a + [`trial_tracker()`][amltk.scheduling.plugins.wandb.WandbPlugin.trial_tracker] to pass into a [`Task(plugins=...)`][amltk.Task] or to create `wandb.Run`'s for custom purposes with - [`run()`][amltk.wandb.WandbPlugin.run]. + [`run()`][amltk.scheduling.plugins.wandb.WandbPlugin.run]. """ def __init__( diff --git a/src/amltk/scheduling/plugins/warning_filter.py b/src/amltk/scheduling/plugins/warning_filter.py new file mode 100644 index 00000000..c998c735 --- /dev/null +++ b/src/amltk/scheduling/plugins/warning_filter.py @@ -0,0 +1,126 @@ +"""The +[`WarningFilter`][amltk.scheduling.plugins.warning_filter.WarningFilter] +if used to automatically filter out warnings from a [`Task`][amltk.scheduling.task.Task] +as it runs. + +This wraps your function in context manager +[`warnings.catch_warnings()`][warnings.catch_warnings] +and applies your arguments to [`warnings.filterwarnings()`][warnings.filterwarnings], +as you would normally filter warnings in Python. + +??? 
tip "Usage" + + ```python exec="true" source="material-block" html="true" + import warnings + from amltk.scheduling import Scheduler + from amltk.scheduling.plugins import WarningFilter + + def f() -> None: + warnings.warn("This is a warning") + + scheduler = Scheduler.with_processes(1) + task = scheduler.task(f, plugins=WarningFilter("ignore")) + from amltk._doc import doc_print; doc_print(print, task) # markdown-exec: hide + ``` +""" +from __future__ import annotations + +import warnings +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar +from typing_extensions import ParamSpec, Self, override + +from amltk.scheduling.plugins.plugin import Plugin + +if TYPE_CHECKING: + from rich.panel import Panel + + from amltk.scheduling.task import Task + +P = ParamSpec("P") +R = TypeVar("R") +TrialInfo = TypeVar("TrialInfo") + + +class _IgnoreWarningWrapper(Generic[P, R]): + """A wrapper to ignore warnings.""" + + def __init__( + self, + fn: Callable[P, R], + *warning_args: Any, + **warning_kwargs: Any, + ): + """Initialize the wrapper. + + Args: + fn: The function to wrap. + *warning_args: arguments to pass to + [`warnings.filterwarnings()`][warnings.filterwarnings]. + **warning_kwargs: keyword arguments to pass to + [`warnings.filterwarnings()`][warnings.filterwarnings]. + """ + super().__init__() + self.fn = fn + self.warning_args = warning_args + self.warning_kwargs = warning_kwargs + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + with warnings.catch_warnings(): + warnings.filterwarnings(*self.warning_args, **self.warning_kwargs) + return self.fn(*args, **kwargs) + + +class WarningFilter(Plugin): + """A plugin that disables warnings emitted from tasks.""" + + name: ClassVar = "warning-filter" + """The name of the plugin.""" + + def __init__(self, *args: Any, **kwargs: Any): + """Initialize the plugin. + + Args: + *args: arguments to pass to + [`warnings.filterwarnings`][warnings.filterwarnings]. + **kwargs: keyword arguments to pass to + [`warnings.filterwarnings`][warnings.filterwarnings]. + """ + super().__init__() + self.task: Task | None = None + self.warning_args = args + self.warning_kwargs = kwargs + + @override + def attach_task(self, task: Task) -> None: + """Attach the plugin to a task.""" + self.task = task + + @override + def pre_submit( + self, + fn: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> tuple[Callable[P, R], tuple, dict]: + """Pre-submit hook. + + Wraps the function to ignore warnings. + """ + wrapped_f = _IgnoreWarningWrapper(fn, *self.warning_args, **self.warning_kwargs) + return wrapped_f, args, kwargs + + @override + def copy(self) -> Self: + """Return a copy of the plugin.""" + return self.__class__(*self.warning_args, **self.warning_kwargs) + + @override + def __rich__(self) -> Panel: + from rich.panel import Panel + from rich.pretty import Pretty + from rich.table import Table + + table = Table("Args", "Kwargs", padding=(0, 1), show_edge=False, box=None) + table.add_row(Pretty(self.warning_args), Pretty(self.warning_kwargs)) + return Panel(table, title=f"Plugin {self.name}") diff --git a/src/amltk/scheduling/scheduler.py b/src/amltk/scheduling/scheduler.py index ddb9d556..7728f073 100644 --- a/src/amltk/scheduling/scheduler.py +++ b/src/amltk/scheduling/scheduler.py @@ -1,84 +1,273 @@ -"""A scheduler which uses asyncio and an executor to run tasks concurrently. 
+"""The [`Scheduler`][amltk.scheduling.Scheduler] uses +an [`Executor`][concurrent.futures.Executor], a builtin python native with +a `#!python submit(f, *args, **kwargs)` function to submit compute to +be compute else where, whether it be locally or remotely. -It's primary use is to dispatch tasks to an executor and manage callbacks -for when they complete. -""" +The `Scheduler` is primarily used to dispatch compute to an `Executor` and +emit [`@events`](site:reference/scheduling/events.md), which can +trigger user callbacks. + +Typically you should not use the `Scheduler` directly for dispatching and +responding to computed functions, but rather use a [`Task`](site:reference/scheduling/task.md). + +For a complete walk-through please check out the +[Scheduling Guide](site:guides/scheduling.md), this page is more for a quick reference. + +!!! note "Jupyter Notebook" + + If you are using a Jupyter Notebook, you _might_ need to replace usages + of [`#!python scheduler.run()`][amltk.scheduling.Scheduler.run]. + + You should instead use [`#!python scheduler.run_in_notebook()`][amltk.scheduling.Scheduler.run_in_notebook]. + This is most likely is only necessary for using the [`run(display=...)`][amltk.scheduling.Scheduler.run] feature. + +??? tip "Basic Usage" + + In this example, we create a scheduler that uses local processes as + workers. We then create a task that will run a function `fn` and submit it + to the scheduler. Lastly, a callback is registered to `@on_future_result` to print the + result when the compute is done. + + ```python exec="true" source="material-block" html="true" + from amltk.scheduling import Scheduler + + def fn(x: int) -> int: + return x + 1 + from amltk._doc import make_picklable; make_picklable(fn) # markdown-exec: hide + + scheduler = Scheduler.with_processes(1) + + @scheduler.on_start + def launch_the_compute(): + scheduler.submit(fn, 1) + + @scheduler.on_future_result + def callback(future, result): + print(f"Result: {result}") + + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: hide + ``` + + The last line in the previous example called + [`scheduler.run()`][amltk.scheduling.Scheduler.run] is what starts the scheduler + running, in which it will first emit the `@on_start` event. This triggered the + callback `launch_the_compute()` which submitted the function `fn` with the + arguments `#!python 1`. + + The scheduler then ran the compute and waited for it to complete, emitting the + `@on_future_result` event when it was done successfully. This triggered the callback + `callback()` which printed the result. + + At this point, there is no more compute happening and no more events to respond to + so the scheduler will halt. + +There are many `@events` emitted by the `Scheduler`, with the most important being +[`@on_start`][amltk.scheduling.Scheduler.on_start]. There are however many more events +you can respond to. + +??? example "`@events`" + + Check out the [`@events`](site:reference/scheduling/events.md) + reference for more on how to customize these callbacks. + + === "Scheduler Status Events" + + When the scheduler enters some important state, it will emit an event + to let you know. 
+ + === "`@on_start`" + + ::: amltk.scheduling.Scheduler.on_start + + === "`@on_finishing`" + + ::: amltk.scheduling.Scheduler.on_finishing + + === "`@on_finished`" + + ::: amltk.scheduling.Scheduler.on_finished + + === "`@on_stop`" + + ::: amltk.scheduling.Scheduler.on_stop + + === "`@on_timeout`" + + ::: amltk.scheduling.Scheduler.on_timeout + + === "`@on_empty`" + + ::: amltk.scheduling.Scheduler.on_empty + + === "Submitted Compute Events" + + When any compute goes through the `Scheduler`, it will emit an event + to let you know. You should however prefer to use a + [`Task`](site:reference/scheduling/task.md) as it will emit specific events + for the task at hand, and not all compute. + + === "`@on_future_submitted`" + + ::: amltk.scheduling.Scheduler.on_future_submitted + + === "`@on_future_result`" + + ::: amltk.scheduling.Scheduler.on_future_result + + === "`@on_future_exception`" + + ::: amltk.scheduling.Scheduler.on_future_exception + + === "`@on_future_done`" + + ::: amltk.scheduling.Scheduler.on_future_done + + === "`@on_future_cancelled`" + + ::: amltk.scheduling.Scheduler.on_future_cancelled + +There are various ways to [`run()`][amltk.scheduling.Scheduler.run] the +scheduler, notably how long it should run with `timeout=` and also how +it should react to any exception that may have occurred within the `Scheduler` +itself or your callbacks. + +??? tip "Usage of `run()`" + + Please see the [`run()`][amltk.scheduling.Scheduler.run] API doc for more + details and features, however we show two common use cases of using the `timeout=` + parameter. + + === "`run(timeout=...)`" + + You can tell the `Scheduler` to stop after a certain amount of time + with the `timeout=` argument to [`run()`][amltk.scheduling.Scheduler.run]. + + This will also trigger the `@on_timeout` event as seen in the `Scheduler` output. + + ```python exec="true" source="material-block" html="True" hl_lines="19" + import time + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + def expensive_function() -> int: + time.sleep(0.1) + return 42 + from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide + + @scheduler.on_start + def submit_calculations() -> None: + scheduler.submit(expensive_function) + + # The will endlessly loop the scheduler + @scheduler.on_future_done + def submit_again(future: Future) -> None: + if scheduler.running(): + scheduler.submit(expensive_function) + + scheduler.run(timeout=1) # End after 1 second + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + + === "`run(timeout=..., wait=False)`" + + By specifying that the `Scheduler` should not wait for ongoing tasks + to finish, the `Scheduler` will attempt to cancel and possibly terminate + any running tasks. + + ```python exec="true" source="material-block" html="True" + import time + from amltk.scheduling import Scheduler + + scheduler = Scheduler.with_processes(1) + + def expensive_function() -> None: + time.sleep(10) + + from amltk._doc import make_picklable; make_picklable(expensive_function) # markdown-exec: hide + + @scheduler.on_start + def submit_calculations() -> None: + scheduler.submit(expensive_function) + + scheduler.run(timeout=1, wait=False) # End after 1 second + from amltk._doc import doc_print; doc_print(print, scheduler, output="html", fontsize="small") # markdown-exec: hide + ``` + + ??? 
info "Forcibly Terminating Workers" + + As an `Executor` does not provide an interface to forcibly + terminate workers, we provide `Scheduler(terminate=...)` as a custom + strategy for cleaning up a provided executor. It is not possible + to terminate running thread based workers, for example using + `ThreadPoolExecutor` and any Executor using threads to spawn + tasks will have to wait until all running tasks are finish + before python can close. + + It's likely `terminate` will trigger the `EXCEPTION` event for + any tasks that are running during the shutdown, **not*** + a cancelled event. This is because we use a + [`Future`][concurrent.futures.Future] + under the hood and these can not be cancelled once running. + However there is no guarantee of this and is up to how the + `Executor` handles this. + +Lastly, the `Scheduler` can render a live display using +[`run(display=...)`][amltk.scheduling.Scheduler.run]. This +require [`rich`](https://github.com/Textualize/rich) to be installed. You +can install this with `#!bash pip install rich` or `#!bash pip install amltk[rich]`. +""" # noqa: E501 from __future__ import annotations import asyncio import logging import warnings from asyncio import Future +from collections.abc import Callable, Iterable from concurrent.futures import Executor, ProcessPoolExecutor from dataclasses import dataclass from enum import Enum, auto from threading import Timer -from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Literal, + ParamSpec, + TypeVar, + overload, +) +from typing_extensions import Self from uuid import uuid4 -from amltk.asyncm import ContextEvent -from amltk.events import Emitter, Event, Subscriber +from amltk._asyncm import ContextEvent +from amltk._functional import Flag from amltk.exceptions import SchedulerNotRunningError -from amltk.functional import Flag -from amltk.scheduling.sequential_executor import SequentialExecutor +from amltk.scheduling.events import Emitter, Event, Subscriber +from amltk.scheduling.executors import SequentialExecutor from amltk.scheduling.task import Task from amltk.scheduling.termination_strategies import termination_strategy if TYPE_CHECKING: from multiprocessing.context import BaseContext - from typing_extensions import ParamSpec, Self from rich.console import RenderableType from rich.live import Live - from amltk.dask_jobqueue import DJQ_NAMES - from amltk.scheduling.task_plugin import TaskPlugin + from amltk.scheduling.executors.dask_jobqueue import DJQ_NAMES + from amltk.scheduling.plugins import Plugin + from amltk.scheduling.plugins.comm import Comm P = ParamSpec("P") R = TypeVar("R") - CallableT = TypeVar("CallableT", bound=Callable) logger = logging.getLogger(__name__) -@dataclass -class ExitState: - """The exit state of a scheduler. - - Attributes: - reason: The reason for the exit. - exception: The exception that caused the exit, if any. - """ - - code: Scheduler.ExitCode - exception: BaseException | None = None - - class Scheduler: - """A scheduler for submitting tasks to an Executor. - - ```python - from amltk.scheduling import Scheduler - - # For your own custom Executor - scheduler = Scheduler(executor=...) 
- - # Create a scheduler which uses local processes as workers - scheduler = Scheduler.with_processes(2) - - # Run a function when the scheduler starts, twice - @scheduler.on_start(repeat=2) - def say_hello_world(): - print("hello world") - - @scheduler.on_finish - def say_goodbye_world(): - print("goodbye world") - - scheduler.run(timeout=10) - ``` - """ + """A scheduler for submitting tasks to an Executor.""" executor: Executor """The executor to use to run tasks.""" @@ -90,8 +279,9 @@ def say_goodbye_world(): """The queue of tasks running.""" on_start: Subscriber[[]] - """A [`Subscriber`][amltk.events.Subscriber] which is called when the - scheduler starts. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the + scheduler starts. This is the first event emitted by the scheduler and + one of the only ways to submit the initial compute to the scheduler. ```python @scheduler.on_start @@ -100,18 +290,18 @@ def my_callback(): ``` """ on_future_submitted: Subscriber[Future] - """A [`Subscriber`][amltk.events.Subscriber] which is called when - a future is submitted. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when + some compute is submitted. ```python - @scheduler.on_submission + @scheduler.on_future_submitted def my_callback(future: Future): ... ``` """ on_future_done: Subscriber[Future] - """A [`Subscriber`][amltk.events.Subscriber] which is called when - a future is done. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when + some compute is done, regardless of whether it was successful or not. ```python @scheduler.on_future_done @@ -120,8 +310,8 @@ def my_callback(future: Future): ``` """ on_future_result: Subscriber[Future, Any] - """A [`Subscriber`][amltk.events.Subscriber] which is called when - a future returned with a result. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when + a future returned with a result, no exception raise. ```python @scheduler.on_future_result @@ -130,8 +320,8 @@ def my_callback(future: Future, result: Any): ``` """ on_future_exception: Subscriber[Future, BaseException] - """A [`Subscriber`][amltk.events.Subscriber] which is called when - a future has an exception. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when + some compute raised an uncaught exception. ```python @scheduler.on_future_exception @@ -140,8 +330,9 @@ def my_callback(future: Future, exception: BaseException): ``` """ on_future_cancelled: Subscriber[Future] - """A [`Subscriber`][amltk.events.Subscriber] which is called - when a future is cancelled. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called + when a future is cancelled. This usually occurs due to the underlying Scheduler, + and is not something we do directly, other than when shutting down the scheduler. ```python @scheduler.on_future_cancelled @@ -150,8 +341,9 @@ def my_callback(future: Future): ``` """ on_finishing: Subscriber[[]] - """A [`Subscriber`][amltk.events.Subscriber] which is called when the - scheduler is finishing up. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the + scheduler is finishing up. This occurs right before the scheduler shuts down + the executor. ```python @scheduler.on_finishing @@ -160,8 +352,9 @@ def my_callback(): ``` """ on_finished: Subscriber[[]] - """A [`Subscriber`][amltk.events.Subscriber] which is called when - the scheduler finishes. 
+ """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when + the scheduler is finished, has shutdown the executor and possibly + terminated any remaining compute. ```python @scheduler.on_finished @@ -169,18 +362,19 @@ def my_callback(): ... ``` """ - on_stop: Subscriber[[]] - """A [`Subscriber`][amltk.events.Subscriber] which is called when the - scheduler is stopped. + on_stop: Subscriber[str, BaseException | None] + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the + scheduler is has been stopped due to the [`stop()`][amltk.scheduling.Scheduler.stop] + method being called. ```python @scheduler.on_stop - def my_callback(): + def my_callback(stop_msg: str, exception: BaseException | None): ... ``` """ on_timeout: Subscriber[[]] - """A [`Subscriber`][amltk.events.Subscriber] which is called when + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the scheduler reaches the timeout. ```python @@ -190,8 +384,9 @@ def my_callback(): ``` """ on_empty: Subscriber[[]] - """A [`Subscriber`][amltk.events.Subscriber] which is called when the - queue is empty. + """A [`Subscriber`][amltk.scheduling.events.Subscriber] which is called when the + queue is empty. This can be useful to re-fill the queue and prevent the + scheduler from exiting. ```python @scheduler.on_empty @@ -203,7 +398,7 @@ def my_callback(): STARTED: Event[[]] = Event("on_start") FINISHING: Event[[]] = Event("on_finishing") FINISHED: Event[[]] = Event("on_finished") - STOP: Event[[]] = Event("on_stop") + STOP: Event[str, BaseException | None] = Event("on_stop") TIMEOUT: Event[[]] = Event("on_timeout") EMPTY: Event[[]] = Event("on_empty") FUTURE_SUBMITTED: Event[Future] = Event("on_future_submitted") @@ -220,24 +415,6 @@ def __init__( ) -> None: """Initialize a scheduler. - !!! note "Forcibully Terminating Workers" - - As an `Executor` does not provide an interface to forcibly - terminate workers, we provide `terminate` as a custom - strategy for cleaning up a provided executor. It is not possible - to terminate running thread based workers, for example using - `ThreadPoolExecutor` and any Executor using threads to spawn - tasks will have to wait until all running tasks are finish - before python can close. - - It's likely `terminate` will trigger the `EXCEPTION` event for - any tasks that are running during the shutdown, **not*** - a cancelled event. This is because we use a - [`Future`][concurrent.futures.Future] - under the hood and these can not be cancelled once running. - However there is no gaurantee of this and is up to how the - `Executor` handles this. - Args: executor: The dispatcher to use for submitting tasks. terminate: Whether to call shutdown on the executor when @@ -640,7 +817,7 @@ class in `dask_jobqueue` to use. For example, to use A new scheduler with a `dask_jobqueue` executor. """ try: - from amltk.dask_jobqueue import DaskJobqueueExecutor + from amltk.scheduling.executors.dask_jobqueue import DaskJobqueueExecutor except ImportError as e: raise ImportError( @@ -675,47 +852,70 @@ def running(self) -> bool: def submit( self, - function: Callable[P, R], + fn: Callable[P, R], + /, *args: P.args, **kwargs: P.kwargs, ) -> Future[R]: """Submits a callable to be executed with the given arguments. Args: - function: The callable to be executed as + fn: The callable to be executed as fn(*args, **kwargs) that returns a Future instance representing the execution of the callable. 
args: positional arguments to pass to the function kwargs: keyword arguments to pass to the function Raises: - Scheduler.NotRunningError: If the scheduler is not running. + SchedulerNotRunningError: If the scheduler is not running. + You can protect against this using, + [`scheduler.running()`][amltk.scheduling.scheduler.Scheduler.running]. Returns: A Future representing the given call. """ if not self.running(): msg = ( - f"Scheduler is not running, cannot submit task {function}" + f"Scheduler is not running, cannot submit task {fn}" f" with {args=}, {kwargs=}" ) raise SchedulerNotRunningError(msg) try: - sync_future = self.executor.submit(function, *args, **kwargs) + sync_future = self.executor.submit(fn, *args, **kwargs) future = asyncio.wrap_future(sync_future) except Exception as e: - logger.exception(f"Could not submit task {function}", exc_info=e) + logger.exception(f"Could not submit task {fn}", exc_info=e) raise e - self._register_future(future, function, *args, **kwargs) + self._register_future(future, fn, *args, **kwargs) return future + @overload + def task( + self, + function: Callable[Concatenate[Comm, P], R], + *, + plugins: Comm.Plugin | Iterable[Comm.Plugin | Plugin] = ..., + init_plugins: bool = ..., + ) -> Task[P, R]: + ... + + @overload def task( self, function: Callable[P, R], *, - plugins: TaskPlugin | Iterable[TaskPlugin] = (), + plugins: Plugin | Iterable[Plugin] = (), + init_plugins: bool = True, + ) -> Task[P, R]: + ... + + def task( + self, + function: Callable[P, R] | Callable[Concatenate[Comm, P], R], + *, + plugins: Plugin | Iterable[Plugin] = (), init_plugins: bool = True, ) -> Task[P, R]: """Create a new task. @@ -728,9 +928,13 @@ def task( Returns: A new task. """ - task = Task(function, self, plugins=plugins, init_plugins=init_plugins) + # HACK: Not that the type: ignore is due to the fact that we can't use type + # checking to enforce that + # A. `function` is a callable with the first arg being a Comm + # B. 
`plugins` + task = Task(function, self, plugins=plugins, init_plugins=init_plugins) # type: ignore self.add_renderable(task) - return task + return task # type: ignore def _register_future( self, @@ -814,7 +1018,7 @@ async def _run_scheduler( # noqa: C901, PLR0912, PLR0915 timeout: float | None = None, end_on_empty: bool = True, wait: bool = True, - ) -> ExitCode | BaseException: + ) -> ExitState.Code | BaseException: self.executor.__enter__() self._stop_event = ContextEvent() @@ -853,7 +1057,7 @@ async def _run_scheduler( # noqa: C901, PLR0912, PLR0915 if end_on_empty: self.on_empty(lambda: monitor_empty.cancel(), hidden=True) - # The timeout criterion is satisifed by the `timeout` arg + # The timeout criterion is satisfied by the `timeout` arg await asyncio.wait( [stop_triggered, monitor_empty], timeout=timeout, @@ -861,31 +1065,31 @@ async def _run_scheduler( # noqa: C901, PLR0912, PLR0915 ) # Determine the reason for stopping - stop_reason: BaseException | Scheduler.ExitCode + stop_reason: BaseException | ExitState.Code if stop_triggered.done() and self._stop_event.is_set(): - stop_reason = Scheduler.ExitCode.STOPPED + stop_reason = ExitState.Code.STOPPED msg, exception = self._stop_event.context _log = logger.exception if exception else logger.debug _log(f"Stop Message: {msg}", exc_info=exception) - self.on_stop.emit() + self.on_stop.emit(str(msg), exception) if self._end_on_exception_flag and exception: stop_reason = exception else: - stop_reason = Scheduler.ExitCode.STOPPED + stop_reason = ExitState.Code.STOPPED elif monitor_empty.done(): logger.debug("Scheduler stopped due to being empty.") - stop_reason = Scheduler.ExitCode.EXHAUSTED + stop_reason = ExitState.Code.EXHAUSTED elif timeout is not None: logger.debug(f"Scheduler stopping as {timeout=} reached.") - stop_reason = Scheduler.ExitCode.TIMEOUT + stop_reason = ExitState.Code.TIMEOUT self.on_timeout.emit() else: logger.warning("Scheduler stopping for unknown reason!") - stop_reason = Scheduler.ExitCode.UNKNOWN + stop_reason = ExitState.Code.UNKNOWN - # Stop all runnings async tasks, i.e. monitoring the queue to trigger an event + # Stop all running async tasks, i.e. monitoring the queue to trigger an event tasks = [monitor_empty, stop_triggered] for task in tasks: task.cancel() @@ -943,20 +1147,34 @@ def run( Args: timeout: The maximum time to run the scheduler for in seconds. Defaults to `None` which means no timeout and it - will end once the queue becomes empty. - end_on_empty: Whether to end the scheduler when the - queue becomes empty. Defaults to `True`. - wait: Whether to wait for the executor to shutdown. - on_exception: What to do when an exception occurs. - If "raise", the exception will be raised. - If "ignore", the scheduler will continue running. - If "end", the scheduler will end but not raise. + will end once the queue is empty if `end_on_empty=True`. + end_on_empty: Whether to end the scheduler when the queue becomes empty. + wait: Whether to wait for currently running compute to finish once + the `Scheduler` is shutting down. + + * If `#!python True`, will wait for all currently running compute. + * If `#!python False`, will attempt to cancel/terminate all currently + running compute and shutdown the executor. This may be useful + if you want to end the scheduler as quickly as possible or + respect the `timeout=` more precisely. 
+ on_exception: What to do when an exception occurs in the scheduler + or callbacks (**Does not apply to submitted compute!**) + + * If `#!python "raise"`, the scheduler will stop and raise the + exception at the point where you called `run()`. + * If `#!python "ignore"`, the scheduler will continue running, + ignoring the exception. This may be useful when requiring more + robust execution. + * If `#!python "end"`, similar to `#!python "raise"`, the scheduler + will stop but no exception will occur and the control flow + will return gracefully to the point where you called `run()`. asyncio_debug_mode: Whether to run the async loop in debug mode. Defaults to `False`. Please see [asyncio.run][] for more. - display: Whether to display things in the console. - If `True`, will display the scheduler and all its - renderables. If a list of renderables, will display - the scheduler itself plus those renderables. + display: Whether to display the scheduler live in the console. + + * If `#!python True`, will display the scheduler and all its tasks. + * If a `#!python list[RenderableType]` , will display the scheduler + itself plus those renderables. Returns: The reason for the scheduler ending. @@ -986,23 +1204,10 @@ async def async_run( ) -> ExitState: """Async version of `run`. - Args: - timeout: The maximum time to run the scheduler for. - Defaults to `None` which means no timeout. - end_on_empty: Whether to end the scheduler when the - queue becomes empty. Defaults to `True`. - wait: Whether to wait for the executor to shutdown. - on_exception: Whether to end if an exception occurs. - if "raise", the exception will be raised. - If "ignore", the scheduler will continue running. - If "end", the scheduler will end but not raise. - display: Whether to display things in the console. - If `True`, will display the scheduler and all its - renderables. If a list of renderables, will display - the scheduler itself plus those renderables. + This can be useful if you are already running in an async context, + such as in a web server or Jupyter notebook. - Returns: - The reason for the scheduler ending. + Please see [`run()`][amltk.Scheduler.run] for more details. """ if self.running(): raise RuntimeError("Scheduler already seems to be running") @@ -1038,7 +1243,6 @@ def custom_exception_handler( ) -> None: exception = context.get("exception") message = context.get("message") - self.stop(stop_msg=message, exception=exception) # handle with previous handler if previous_exception_handler: @@ -1046,6 +1250,8 @@ def custom_exception_handler( else: loop.default_exception_handler(context) + self.stop(stop_msg=message, exception=exception) + loop.set_exception_handler(custom_exception_handler) # Run the actual scheduling loop @@ -1069,14 +1275,24 @@ def custom_exception_handler( if on_exception == "raise": raise result - return ExitState(code=Scheduler.ExitCode.EXCEPTION, exception=result) + return ExitState(code=ExitState.Code.EXCEPTION, exception=result) return ExitState(code=result) run_in_notebook = async_run - """Alias for [`async_run()`][amltk.Scheduler.async_run]""" + """Alias for [`async_run()`][amltk.Scheduler.async_run]. - def stop(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + This allows the `Scheduler` to be run in a Jupyter notebook, which + happens to be inside an async context. + """ + + def stop( + self, + *args: Any, + stop_msg: str | None = None, + exception: BaseException | None = None, + **kwargs: Any, + ) -> None: """Stop the scheduler. 
The scheduler will stop, finishing currently running tasks depending @@ -1088,21 +1304,20 @@ def stop(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 Args: *args: Logged in a debug message **kwargs: Logged in a debug message - - * **stop_msg**: The message to pass to the stop event which - gets logged as the stop reason. - - * **exception**: The exception to pass to the stop event which - gets logged as the stop reason. + stop_msg: The message to log when stopping the scheduler. + exception: The exception which incited `stop()` to be called. + Will be used by the `Scheduler` to possibly raise the exception + to the user. """ if not self.running(): return assert self._stop_event is not None - msg = kwargs.get("stop_msg", "stop() called") + msg = stop_msg if stop_msg is not None else "scheduler.stop() was called." + logger.debug(f"Stopping scheduler: {msg} {args=} {kwargs=}") - self._stop_event.set(msg=f"{msg}", exception=kwargs.get("exception")) + self._stop_event.set(msg=msg, exception=exception) self._running_event.clear() @staticmethod @@ -1141,7 +1356,7 @@ def _end_pending( executor.shutdown(wait=wait) def add_renderable(self, renderable: RenderableType) -> None: - """Add a renderable to the scheduler. + """Add a renderable object to the scheduler. This will be displayed whenever the scheduler is displayed. """ @@ -1154,8 +1369,8 @@ def __rich__(self) -> RenderableType: from rich.text import Text from rich.tree import Tree - from amltk.richutil import richify - from amltk.richutil.renderers.function import Function + from amltk._richutil import richify + from amltk._richutil.renderers.function import Function MAX_FUTURE_ITEMS = 5 OFFSETS = 1 + 1 + 2 # Header + ellipses space + panel borders @@ -1196,14 +1411,14 @@ def __rich__(self) -> RenderableType: ) layout_table.add_row(richify(self.executor), future_table) - title = Panel( + panel = Panel( layout_table, title=title, title_align="left", border_style="magenta", height=MAX_FUTURE_ITEMS + OFFSETS, ) - tree = Tree(title, guide_style="magenta bold") + tree = Tree(panel, guide_style="magenta bold") for renderable in self._renderables: tree.add(renderable) @@ -1213,7 +1428,20 @@ def __rich__(self) -> RenderableType: return Group(tree, *self._extra_renderables) - class ExitCode(Enum): + +@dataclass +class ExitState: + """The exit state of a scheduler. + + Attributes: + reason: The reason for the exit. + exception: The exception that caused the exit, if any. + """ + + code: ExitState.Code + exception: BaseException | None = None + + class Code(Enum): """The reason the scheduler ended.""" STOPPED = auto() diff --git a/src/amltk/scheduling/task.py b/src/amltk/scheduling/task.py index 9c0b0c52..0d911c92 100644 --- a/src/amltk/scheduling/task.py +++ b/src/amltk/scheduling/task.py @@ -1,117 +1,117 @@ -"""This module holds the definition of a Task. +"""A [`Task`][amltk.scheduling.task.Task] is a unit of work that can be scheduled by the +[`Scheduler`][amltk.scheduling.Scheduler]. -A Task is a unit of work that can be scheduled by the scheduler. It is -defined by its name, its function, and it's `Future` representing the -final outcome of the task. -""" -from __future__ import annotations +It is defined by its `function=` to call. Whenever a `Task` +has its [`submit()`][amltk.scheduling.task.Task.submit] method called, +the function will be dispatched to run by a `Scheduler`. 
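A minimal sketch of this flow, assuming the `@on_exception` callback receives the `Future` and the raised exception (by analogy with `Scheduler.on_future_exception`), and using an illustrative function `may_fail`:

```python
from asyncio import Future

from amltk.scheduling import Scheduler

def may_fail(x: int) -> int:
    """Raises for negative inputs, otherwise doubles the input."""
    if x < 0:
        raise ValueError("x must be non-negative")
    return x * 2

scheduler = Scheduler.with_processes(1)
task = scheduler.task(may_fail)

@scheduler.on_start
def submit_inputs() -> None:
    # Submit one input that succeeds and one that raises.
    task.submit(2)
    task.submit(-1)

@task.on_result
def print_result(future: Future, result: int) -> None:
    print(f"Result: {result}")

@task.on_exception
def print_error(future: Future, exception: BaseException) -> None:
    print(f"Task failed: {exception}")

scheduler.run()
```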
-import logging -from asyncio import Future -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Iterable, - TypeVar, - overload, -) -from typing_extensions import Concatenate, ParamSpec, Self, override -from uuid import uuid4 as uuid +When a task has returned, either successfully, or with an exception, +it will emit `@events` to indicate so. You can subscribe to these events +with callbacks and act accordingly. -from more_itertools import first_true -from amltk.events import Emitter, Event, Subscriber -from amltk.exceptions import EventNotKnownError, SchedulerNotRunningError -from amltk.functional import callstring -from amltk.scheduling.task_plugin import TaskPlugin +??? example "`@events`" -if TYPE_CHECKING: - from rich.panel import Panel + Check out the [`@events`](site:reference/scheduling/events.md) reference + for more on how to customize these callbacks. You can also take a look + at the API of [`on()`][amltk.scheduling.task.Task.on] for more information. - from amltk.scheduling.scheduler import Scheduler + === "`@on_result`" -logger = logging.getLogger(__name__) + ::: amltk.scheduling.task.Task.on_result + === "`@on_exception`" -P = ParamSpec("P") -P2 = ParamSpec("P2") + ::: amltk.scheduling.task.Task.on_exception -R = TypeVar("R") -R2 = TypeVar("R2") -CallableT = TypeVar("CallableT", bound=Callable) + === "`@on_done`" + + ::: amltk.scheduling.task.Task.on_done + + === "`@on_submitted`" + ::: amltk.scheduling.task.Task.on_submitted -class Task(Generic[P, R]): - """A task is a unit of work that can be scheduled by the scheduler. + === "`@on_cancelled`" - It is defined by its `function` to call. Whenever a task - has its `__call__` method called, the function will be dispatched to run - by a [`Scheduler`][amltk.scheduling.scheduler.Scheduler]. + ::: amltk.scheduling.task.Task.on_cancelled - To interact with the results of these tasks, you must subscribe to to these - events and provide callbacks. +??? tip "Usage" The usual way to create a task is with - [`Scheduler.task()`][amltk.scheduling.scheduler.Scheduler.task]. + [`Scheduler.task()`][amltk.scheduling.scheduler.Scheduler.task], + where you provide the `function=` to call. - ```python - from amltk import Task, Scheduler + ```python exec="true" source="material-block" html="true" + from amltk import Scheduler - # Define some function to run def f(x: int) -> int: return x * 2 + from amltk._doc import make_picklable; make_picklable(f) # markdown-exec: hide - # And a scheduler to run it on scheduler = Scheduler.with_processes(2) + task = scheduler.task(f) - # Create the task object - my_task = scheduler.task(f) + @scheduler.on_start + def on_start(): + task.submit(1) - # Subscribe to events - @my_task.on_result - def print_result(future: Future[int], result: int): - print(f"Future {future} returned {result}") + @task.on_result + def on_result(future: Future[int], result: int): + print(f"Task {future} returned {result}") - @my_task.on_exception - def print_exception(future: Future[int], result: int): - print(error) + scheduler.run() + from amltk._doc import doc_print; doc_print(print, scheduler) # markdown-exec: hide ``` - If providing `plugins=` to the task, these may add new events that will be emitted - from the task. Do listen to these events, you must use the `on` method. Please - see their respective documentation. + If you'd like to simply just call the original function, without submitting it to + the scheduler, you can always just call the task directly, i.e. `#!python task(1)`. 
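Beyond a single submission, a common pattern is to resubmit from the `@on_result` callback so the queue never empties; a minimal sketch, assuming a one second timeout is acceptable and using the illustrative function `count_up`:

```python
from asyncio import Future

from amltk.scheduling import Scheduler

def count_up(x: int) -> int:
    return x + 1

scheduler = Scheduler.with_processes(1)
task = scheduler.task(count_up)

@scheduler.on_start
def start() -> None:
    task.submit(0)

@task.on_result
def resubmit(future: Future, result: int) -> None:
    # Keep the queue busy until run(timeout=...) below stops the scheduler.
    if scheduler.running():
        task.submit(result)

scheduler.run(timeout=1)
```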
- ```python - from amltk import Scheduler, Task - from amltk.scheduling import CallLimiter +You can also provide [`Plugins`](site:reference/scheduling/plugins.md) to the task, +to modify tasks, add functionality and add new events. + +Please check out the [Scheduling Guide](site:guides/scheduling.md) +for a more detailed walkthrough. +""" +from __future__ import annotations + +import logging +from asyncio import Future +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, Concatenate, Generic, TypeVar, overload +from typing_extensions import ParamSpec, Self, override +from more_itertools import first_true - def f() -> None: - print("task ran") +from amltk._functional import callstring +from amltk._richutil.renderable import RichRenderable +from amltk.exceptions import EventNotKnownError, SchedulerNotRunningError +from amltk.randomness import randuid +from amltk.scheduling.events import Emitter, Event, Subscriber +from amltk.scheduling.plugins.plugin import Plugin +if TYPE_CHECKING: + from rich.panel import Panel - scheduler = Scheduler.with_processes(1) - task = scheduler.task(f, plugins=[CallLimiter(max_calls=1)]) + from amltk.scheduling.scheduler import Scheduler - @scheduler.on_start(repeat=3) - def start(): - task() +logger = logging.getLogger(__name__) - @task.on(CallLimiter.CALL_LIMIT_REACHED) - def on_limit_reached(task: Task, *args, **kwargs): - print(f"Task {task} reached its call limit with {args=} and {kwargs=}") - scheduler.run() - ``` - """ +P = ParamSpec("P") +P2 = ParamSpec("P2") + +R = TypeVar("R") +R2 = TypeVar("R2") +CallableT = TypeVar("CallableT", bound=Callable) + + +class Task(RichRenderable, Generic[P, R]): + """The task class.""" - uuid: str - """A unique identifier for this task.""" unique_ref: str """A unique reference to this task.""" - plugins: list[TaskPlugin] + plugins: list[Plugin] """The plugins to use for this task.""" function: Callable[P, R] """The function of this task""" @@ -183,7 +183,7 @@ def __init__( function: Callable[P, R], scheduler: Scheduler, *, - plugins: TaskPlugin | Iterable[TaskPlugin] = (), + plugins: Plugin | Iterable[Plugin] = (), init_plugins: bool = True, ) -> None: """Initialize a task. @@ -195,12 +195,12 @@ def __init__( init_plugins: Whether to initialize the plugins or not. """ super().__init__() - self.unique_ref = str(uuid()) + self.unique_ref = randuid(8) self.emitter = Emitter() self.event_counts = self.emitter.event_counts - self.plugins: list[TaskPlugin] = ( - [plugins] if isinstance(plugins, TaskPlugin) else list(plugins) + self.plugins: list[Plugin] = ( + [plugins] if isinstance(plugins, Plugin) else list(plugins) ) self.function: Callable[P, R] = function self.scheduler: Scheduler = scheduler @@ -252,24 +252,11 @@ def on( ) -> Subscriber[...]: ... - @overload - def on( - self, - event: Event[P2], - callback: Callable[P2, Any] | Iterable[Callable[P2, Any]], - *, - when: Callable[[], bool] | None = ..., - limit: int | None = ..., - repeat: int = ..., - every: int = ..., - ) -> None: - ... - @overload def on( self, event: str, - callback: Callable | Iterable[Callable], + callback: Callable, *, when: Callable[[], bool] | None = ..., limit: int | None = ..., @@ -281,7 +268,7 @@ def on( def on( self, event: Event[P2] | str, - callback: Callable[P2, Any] | Iterable[Callable[P2, Any]] | None = None, + callback: Callable[P2, Any] | None = None, *, when: Callable[[], bool] | None = None, limit: int | None = None, @@ -300,7 +287,7 @@ def on( repeat: The number of times to repeat the subscription. 
every: The number of times to wait between repeats. hidden: Whether to hide the callback in visual output. - This is mainly used to facilitate TaskPlugins who + This is mainly used to facilitate Plugins who act upon events but don't want to be seen, primarily as they are just book-keeping callbacks. @@ -348,32 +335,12 @@ def submit(self, *args: P.args, **kwargs: P.kwargs) -> Future[R] | None: Returns: The future of the task, or `None` if the limit was reached. - """ - return self.__call__(*args, **kwargs) - - def copy(self, *, init_plugins: bool = True) -> Self: - """Create a copy of this task. - - Will use the same scheduler and function, but will have a different - event manager such that any events listend to on the old task will - **not** trigger with the copied task. - Args: - init_plugins: Whether to initialize the copied plugins on the copied - task. Usually you will want to leave this as `True`. - - Returns: - A copy of this task. + Raises: + SchedulerNotRunningError: If the scheduler is not running. + You can protect against this using, + [`scheduler.running()`][amltk.scheduling.scheduler.Scheduler.running]. """ - return self.__class__( - self.function, - self.scheduler, - plugins=tuple(p.copy() for p in self.plugins), - init_plugins=init_plugins, - ) - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Future[R] | None: - """Please see [`Task.submit()`][amltk.Task.submit].""" # Inform all plugins that the task is about to be called # They have chance to cancel submission based on their return # value. @@ -413,6 +380,27 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Future[R] | None: future.add_done_callback(self._process_future) return future + def copy(self, *, init_plugins: bool = True) -> Self: + """Create a copy of this task. + + Will use the same scheduler and function, but will have a different + event manager such that any events listend to on the old task will + **not** trigger with the copied task. + + Args: + init_plugins: Whether to initialize the copied plugins on the copied + task. Usually you will want to leave this as `True`. + + Returns: + A copy of this task. + """ + return self.__class__( + self.function, + self.scheduler, + plugins=tuple(p.copy() for p in self.plugins), + init_plugins=init_plugins, + ) + def _process_future(self, future: Future[R]) -> None: try: self.queue.remove(future) @@ -432,7 +420,7 @@ def _process_future(self, future: Future[R]) -> None: result = future.result() self.on_result.emit(future, result) - def attach_plugin(self, plugin: TaskPlugin) -> None: + def attach_plugin(self, plugin: Plugin) -> None: """Attach a plugin to this task. 
Args: @@ -462,17 +450,29 @@ def __repr__(self) -> str: kwargs_str = ", ".join(f"{k}={v}" for k, v in kwargs.items()) return f"{self.__class__.__name__}({kwargs_str})" + @override def __rich__(self) -> Panel: + from rich.console import Group as RichGroup from rich.panel import Panel + from rich.text import Text from rich.tree import Tree - from amltk.richutil import Function + from amltk._richutil import Function + + items: list[RichRenderable | Tree] = [] + + if any(self.plugins): + for plugin in self.plugins: + items.append(plugin) tree = Tree(label="", hide_root=True) tree.add(self.emitter) + items.append(tree) + return Panel( - tree, + RichGroup(*items), title=Function(self.function, prefix="Task").__rich__(), title_align="left", border_style="deep_sky_blue2", + subtitle=Text("Ref: ").append(self.unique_ref, "yellow italic"), ) diff --git a/src/amltk/scheduling/task_plugin.py b/src/amltk/scheduling/task_plugin.py deleted file mode 100644 index 2fa2cef5..00000000 --- a/src/amltk/scheduling/task_plugin.py +++ /dev/null @@ -1,357 +0,0 @@ -"""This module contains the TaskPlugin class.""" -from __future__ import annotations - -import warnings -from abc import ABC, abstractmethod -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar -from typing_extensions import ParamSpec, Self, override - -from amltk.events import Event - -if TYPE_CHECKING: - from amltk.scheduling.task import Task - -P = ParamSpec("P") -R = TypeVar("R") -TrialInfo = TypeVar("TrialInfo") - - -class TaskPlugin(ABC): - """A plugin that can be attached to a Task. - - By inheriting from a `TaskPlugin`, you can hook into a - [`Task`][amltk.scheduling.Task]. A plugin can affect, modify and extend its - behaviours. Please see the documentation of the methods for more information. - Creating a plugin is only necesary if you need to modify actual behaviour of - the task. For siply hooking into the lifecycle of a task, you can use the events - that a [`Task`][amltk.scheduling.Task] emits. - - For an example of a simple plugin, see the - [`CallLimiter`][amltk.scheduling.CallLimiter] plugin which prevents - the task being submitted if for example, it has already been submitted - too many times. - - All methods are optional, and you can choose to implement only the ones - you need. Most plugins will likely need to implement the - [`attach_task()`][amltk.scheduling.TaskPlugin.attach_task] method, which is called - when the plugin is attached to a task. In this method, you can for - example subscribe to events on the task, create new subscribers for people - to use or even store a reference to the task for later use. - - Plugins are also encouraged to utilize the events of a - [`Task`][amltk.scheduling.Task] to further hook into the lifecycle of the task. - For exampe, by saving a reference to the task in the `attach_task()` method, you - can use the [`emit()`][amltk.scheduling.Task] method of the task to emit - your own specialized events. - - !!! note "Methods" - - * [`attach_task()`][amltk.scheduling.TaskPlugin.attach_task] - * [`pre_submit()`][amltk.scheduling.TaskPlugin.pre_submit] - """ - - name: ClassVar[str] - """The name of the plugin. - - This is used to identify the plugin during logging. - """ - - def attach_task(self, task: Task) -> None: # noqa: B027 - """Attach the plugin to a task. - - This method is called when the plugin is attached to a task. 
This - is the place to subscribe to events on the task, create new subscribers - for people to use or even store a reference to the task for later use. - - Args: - task: The task the plugin is being attached to. - """ - - def pre_submit( - self, - fn: Callable[P, R], - *args: P.args, - **kwargs: P.kwargs, - ) -> tuple[Callable[P, R], tuple, dict] | None: - """Pre-submit hook. - - This method is called before the task is submitted. - - Args: - fn: The task function. - *args: The arguments to the task function. - **kwargs: The keyword arguments to the task function. - - Returns: - A tuple of the task function, arguments and keyword arguments - if the task should be submitted, or `None` if the task should - not be submitted. - """ - return fn, args, kwargs - - def events(self) -> list[Event]: - """Return a list of events that this plugin emits. - - Likely no need to override this method, as it will automatically - return all events defined on the plugin. - """ - inherited_attrs = chain.from_iterable( - vars(cls).values() for cls in self.__class__.__mro__ - ) - return [attr for attr in inherited_attrs if isinstance(attr, Event)] - - @abstractmethod - def copy(self) -> Self: - """Return a copy of the plugin. - - This method is used to create a copy of the plugin when a task is - copied. This is useful if the plugin stores a reference to the task - it is attached to, as the copy will need to store a reference to the - copy of the task. - """ - ... - - -class CallLimiter(TaskPlugin): - """A plugin that limits the submission of a task. - - Adds three new events to the task: - - * [`CALL_LIMIT_REACHED`][amltk.scheduling.CallLimiter.CALL_LIMIT_REACHED] - - subscribe with `@task.on("call-limit-reached")` - * [`CONCURRENT_LIMIT_REACHED`][amltk.scheduling.CallLimiter.CONCURRENT_LIMIT_REACHED] - - subscribe with `@task.on("concurrent-limit-reached")` - * [`DISABLED_DUE_TO_RUNNING_TASK`][amltk.scheduling.CallLimiter.DISABLED_DUE_TO_RUNNING_TASK] - - subscribe with `@task.on("disabled-due-to-running-task")` - """ # noqa: E501 - - name: ClassVar = "call-limiter" - """The name of the plugin.""" - - CALL_LIMIT_REACHED: Event[...] = Event("call-limit-reached") - """The event emitted when the task has reached its call limit. - - Will call any subscribers with the task as the first argument, - followed by the arguments and keyword arguments that were passed to the task. - - ```python - @task.on("call-limit-reached") - def on_call_limit_reached(task: Task, *args, **kwargs): - ... - ``` - """ - - CONCURRENT_LIMIT_REACHED: Event[...] = Event("concurrent-limit-reached") - """The event emitted when the task has reached its concurrent call limit. - - Will call any subscribers with the task as the first argument, followed by the - arguments and keyword arguments that were passed to the task. - - ```python - @task.on("concurrent-limit-reached") - def on_concurrent_limit_reached(task: Task, *args, **kwargs): - ... - ``` - """ - - DISABLED_DUE_TO_RUNNING_TASK: Event[...] = Event("disabled-due-to-running-task") - """The event emitter when the task was not submitted due to some other - running task. - - Will call any subscribers with the task as first argument, followed by - the arguments and keyword arguments that were passed to the task. - - ```python - @task.on("disabled-due-to-running-task") - def on_disabled_due_to_running_task(task: Task, *args, **kwargs): - ... 
- ``` - """ - - def __init__( - self, - *, - max_calls: int | None = None, - max_concurrent: int | None = None, - not_while_running: Task | Iterable[Task] | None = None, - ): - """Initialize the plugin. - - Args: - max_calls: The maximum number of calls to the task. - max_concurrent: The maximum number of calls of this task that can - be in the queue. - not_while_running: A task or iterable of tasks that if active, will prevent - this task from being submitted. - """ - super().__init__() - - if not_while_running is None: - not_while_running = [] - elif isinstance(not_while_running, Iterable): - not_while_running = list(not_while_running) - else: - not_while_running = [not_while_running] - - self.max_calls = max_calls - self.max_concurrent = max_concurrent - self.not_while_running = not_while_running - self.task: Task | None = None - - if isinstance(max_calls, int) and not max_calls > 0: - raise ValueError("max_calls must be greater than 0") - - if isinstance(max_concurrent, int) and not max_concurrent > 0: - raise ValueError("max_concurrent must be greater than 0") - - self._calls = 0 - self._concurrent = 0 - - @override - def attach_task(self, task: Task) -> None: - """Attach the plugin to a task.""" - self.task = task - - if self.task in self.not_while_running: - raise ValueError( - f"Task {self.task} was found in the {self.not_while_running=}" - " list. This is disabled but please raise an issue if you think this" - " has sufficient use case.", - ) - - task.emitter.add_event( - self.CALL_LIMIT_REACHED, - self.CONCURRENT_LIMIT_REACHED, - self.DISABLED_DUE_TO_RUNNING_TASK, - ) - - # Make sure to increment the count when a task was submitted - task.on_submitted(self._increment_call_count, hidden=True) - - @override - def pre_submit( - self, - fn: Callable[P, R], - *args: P.args, - **kwargs: P.kwargs, - ) -> tuple[Callable[P, R], tuple, dict] | None: - """Pre-submit hook. - - Prevents submission of the task if it exceeds any of the set limits. - """ - assert self.task is not None - - if self.max_calls is not None and self._calls >= self.max_calls: - self.task.emitter.emit(self.CALL_LIMIT_REACHED, self.task, *args, **kwargs) - return None - - if ( - self.max_concurrent is not None - and len(self.task.queue) >= self.max_concurrent - ): - self.task.emitter.emit( - self.CONCURRENT_LIMIT_REACHED, - self.task, - *args, - **kwargs, - ) - return None - - for other_task in self.not_while_running: - if other_task.running(): - self.task.emitter.emit( - self.DISABLED_DUE_TO_RUNNING_TASK, - other_task, - self.task, - *args, - **kwargs, - ) - return None - - return fn, args, kwargs - - @override - def copy(self) -> Self: - """Return a copy of the plugin.""" - return self.__class__( - max_calls=self.max_calls, - max_concurrent=self.max_concurrent, - ) - - def _increment_call_count(self, *_: Any, **__: Any) -> None: - self._calls += 1 - - -class _IgnoreWarningWrapper(Generic[P, R]): - """A wrapper to ignore warnings.""" - - def __init__( - self, - fn: Callable[P, R], - *warning_args: Any, - **warning_kwargs: Any, - ): - """Initialize the wrapper. - - Args: - fn: The function to wrap. - *warning_args: arguments to pass to - [`warnings.filterwarnings`][warnings.filterwarnings]. - **warning_kwargs: keyword arguments to pass to - [`warnings.filterwarnings`][warnings.filterwarnings]. 
- """ - super().__init__() - self.fn = fn - self.warning_args = warning_args - self.warning_kwargs = warning_kwargs - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - with warnings.catch_warnings(): - warnings.filterwarnings(*self.warning_args, **self.warning_kwargs) - return self.fn(*args, **kwargs) - - -class WarningFilterPlugin(TaskPlugin): - """A plugin that disables warnings emitted from tasks.""" - - name: ClassVar = "warning-filter" - """The name of the plugin.""" - - def __init__(self, *args: Any, **kwargs: Any): - """Initialize the plugin. - - Args: - *args: arguments to pass to - [`warnings.filterwarnings`][warnings.filterwarnings]. - **kwargs: keyword arguments to pass to - [`warnings.filterwarnings`][warnings.filterwarnings]. - """ - super().__init__() - self.task: Task | None = None - self.warning_args = args - self.warning_kwargs = kwargs - - @override - def attach_task(self, task: Task) -> None: - """Attach the plugin to a task.""" - self.task = task - - @override - def pre_submit( - self, - fn: Callable[P, R], - *args: P.args, - **kwargs: P.kwargs, - ) -> tuple[Callable[P, R], tuple, dict]: - """Pre-submit hook. - - Wraps the function to ignore warnings. - """ - wrapped_f = _IgnoreWarningWrapper(fn, *self.warning_args, **self.warning_kwargs) - return wrapped_f, args, kwargs - - @override - def copy(self) -> Self: - """Return a copy of the plugin.""" - return self.__class__(*self.warning_args, **self.warning_kwargs) diff --git a/src/amltk/scheduling/termination_strategies.py b/src/amltk/scheduling/termination_strategies.py index 08e978d5..d93dcf92 100644 --- a/src/amltk/scheduling/termination_strategies.py +++ b/src/amltk/scheduling/termination_strategies.py @@ -13,9 +13,10 @@ """ from __future__ import annotations +from collections.abc import Callable from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor from contextlib import suppress -from typing import Callable, TypeVar +from typing import TypeVar import psutil diff --git a/src/amltk/sklearn/__init__.py b/src/amltk/sklearn/__init__.py index ac6ed6e3..24949eaf 100644 --- a/src/amltk/sklearn/__init__.py +++ b/src/amltk/sklearn/__init__.py @@ -1,8 +1,14 @@ -from amltk.sklearn.builder import build as sklearn_pipeline from amltk.sklearn.data import split_data, train_val_test_split +from amltk.sklearn.estimators import ( + StoredPredictionClassifier, + StoredPredictionRegressor, +) +from amltk.sklearn.voting import voting_with_preffited_estimators __all__ = [ "train_val_test_split", "split_data", - "sklearn_pipeline", + "StoredPredictionRegressor", + "StoredPredictionClassifier", + "voting_with_preffited_estimators", ] diff --git a/src/amltk/sklearn/builder.py b/src/amltk/sklearn/builder.py deleted file mode 100644 index bf4ee9e6..00000000 --- a/src/amltk/sklearn/builder.py +++ /dev/null @@ -1,229 +0,0 @@ -"""Builds an sklearn.pipeline.Pipeline from a amltk.pipeline.Pipeline.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Iterable, TypeVar, Union - -from sklearn.compose import ColumnTransformer -from sklearn.pipeline import Pipeline as SklearnPipeline - -from amltk.pipeline.components import Component, Group, Split - -if TYPE_CHECKING: - from typing_extensions import TypeAlias - - from amltk.pipeline.pipeline import Pipeline - from amltk.pipeline.step import Step - -COLUMN_TRANSFORMER_ARGS = [ - "remainder", - "sparse_threshold", - "n_jobs", - "transformer_weights", - "verbose", - "verbose_feature_names_out", -] - -# TODO: We can make this more 
explicit with typing out sklearn types. -# However sklearn operates in a bit more of a general level so it would -# require creating protocols to type this properly and work with sklearn's -# duck-typing. -SklearnItem: TypeAlias = Union[Any, ColumnTransformer] -SklearnPipelineT = TypeVar("SklearnPipelineT", bound=SklearnPipeline) - - -def process_component( - step: Component[SklearnItem, Any], -) -> Iterable[tuple[str, SklearnItem]]: - """Process a single step into a tuple of (name, component) for sklearn. - - Args: - step: The step to process - - Returns: - tuple[str, SklearnComponent]: The name and component for sklearn - """ - yield (str(step.name), step.build()) - - if step.nxt is not None: - yield from process_from(step.nxt) - - -def process_group(step: Group[Any]) -> Iterable[tuple[str, SklearnItem]]: - """Process a single group into a tuple of (name, component) for sklearn. - - !!! warning - - Only works for groups with a single item. - - Args: - step: The step to process - - Returns: - tuple[str, SklearnComponent]: The name and component for sklearn - """ - if len(step) > 1: - raise ValueError( - f"Can't handle groups with more than 1 item: {step}." - "\nCurrently they are simply removed and replaced with their one item." - " If you inteded some other functionality with inclduing more than" - " one item in a group, please raise a ticket or implement your own" - " builder.", - ) - - single_path = step.paths[0] - yield from process_from(single_path) - - if step.nxt is not None: - yield from process_from(step.nxt) - - -def process_split( # noqa: C901 - split: Split[ColumnTransformer, Any], -) -> Iterable[tuple[str, SklearnItem]]: - """Process a single split into a tuple of (name, component) for sklearn. - - Args: - split: The step to process - - Returns: - tuple[str, SklearnComponent]: The name and component for sklearn - """ - if split.item is None: - raise NotImplementedError( - f"Can't handle split as it has no item attached: {split}.", - " Sklearn builder requires all splits to have a ColumnTransformer", - " as the item.", - ) - - if isinstance(split.item, type) and not issubclass(split.item, ColumnTransformer): - raise NotImplementedError( - f"Can't handle split as it has a ColumnTransformer as the item: {split}.", - " Sklearn builder requires all splits to have a subclass ColumnTransformer", - " as the item.", - ) - - if split.config is None: - raise NotImplementedError( - f"Can't handle split as it has no config attached: {split}.", - " Sklearn builder requires all splits to have a config to tell", - " the ColumnTransformer how to operate.", - ) - - if any(path.name in COLUMN_TRANSFORMER_ARGS for path in split.paths): - raise ValueError( - f"Can't handle step as it has a path with a name that matches" - f" a known ColumnTransformer argument: {split}", - ) - - path_names = {path.name for path in split.paths} - - # NOTE: If the path was previously under a `Choice`, then the choices name - # will be in the config while the selected choice's name will not. We add - # them here so they're added to the config and later we will remove the unsued - # one. - path_names.update( - {path.old_parent for path in split.paths if path.old_parent is not None}, - ) - - # Get the config values for the column transformer, and the paths - ct_config = {k: v for k, v in split.config.items() if k in COLUMN_TRANSFORMER_ARGS} - ct_paths = {k: v for k, v in split.config.items() if k in path_names} - - # ... 
if theirs any other values in the config that isn't these, raise an error - if any(split.config.keys() - ct_config.keys() - ct_paths.keys()): - raise ValueError( - "Can't handle split as it has a config with keys that aren't" - " ColumnTransformer arguments or paths" - "\nPlease ensure that all keys in the config are either ColumnTransformer" - " arguments or paths.\n" - f"\nSplit '{split.name}': {split.config}" - f"\nColumnTransformer arguments: {COLUMN_TRANSFORMER_ARGS}" - f"\nPaths: {path_names}" - f"\nPath config: {ct_paths}" - f"\n{split.paths}" - "\n", - ) - - for path in split.paths: - if path.name not in ct_paths and path.old_parent in ct_paths: - ct_paths[path.name] = ct_paths.pop(path.old_parent) - - transformers: list = [] - for path in split.paths: - if path.name not in ct_paths: - raise ValueError( - f"Can't handle split {split.name=} as it has a path {path.name=}" - " with no config associated with it." - "\nPlease ensure that all paths have a config associated with them." - f"\nSplit '{split.name}': {ct_paths}", - ) - - assert isinstance(path, (Component, Split, Group)) - steps = list(process_from(path)) - - sklearn_step: SklearnItem - - sklearn_step = steps[0][1] if len(steps) == 1 else SklearnPipeline(steps) - - split_config = ct_paths[path.name] - - split_item = (path.name, sklearn_step, split_config) - transformers.append(split_item) - - column_transformer_cls = split.item - column_transformer = column_transformer_cls(transformers, **ct_config) - yield (split.name, column_transformer) - - if split.nxt is not None: - yield from process_from(split.nxt) - - -def process_from(step: Step) -> Iterable[tuple[str, SklearnItem]]: - """Process a chain of steps into tuples of (name, component) for sklearn. - - Args: - step: The head of the chain of steps to process - - Yields: - tuple[str, SklearnComponent]: The name and component for sklearn - """ - if isinstance(step, Split): - yield from process_split(step) - elif isinstance(step, Group): - yield from process_group(step) - elif isinstance(step, Component): - yield from process_component(step) - else: - raise NotImplementedError(f"Can't handle step: {step}") - - -def build( - pipeline: Pipeline, - pipeline_type: type[SklearnPipelineT] = SklearnPipeline, - **pipeline_kwargs: Any, -) -> SklearnPipelineT: - """Build a pipeline into a usable object. - - Args: - pipeline: The pipeline to build - pipeline_type: The type of pipeline to build. Defaults to the standard - sklearn pipeline but can be any deritiative of that, i.e. imblearn's - pipeline. - **pipeline_kwargs: The kwargs to pass to the pipeline_type. - - Returns: - The built pipeline - """ - pipeline_kwargs = pipeline_kwargs or {} - - for step in pipeline.traverse(): - if not isinstance(step, (Component, Group, Split)): - msg = ( - f"Can't build pipeline with step {step}." - " Only Components and Splits are supported." 
- ) - raise ValueError(msg) - - assert isinstance(pipeline.head, (Component, Split, Group)) - steps = list(process_from(pipeline.head)) - return pipeline_type(steps, **pipeline_kwargs) # type: ignore diff --git a/src/amltk/sklearn/data.py b/src/amltk/sklearn/data.py index f53f3bdd..7520c907 100644 --- a/src/amltk/sklearn/data.py +++ b/src/amltk/sklearn/data.py @@ -1,8 +1,9 @@ """Data utilities for scikit-learn.""" from __future__ import annotations +from collections.abc import Sequence from itertools import chain -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING from more_itertools import last from sklearn.model_selection import train_test_split diff --git a/src/amltk/sklearn/voting.py b/src/amltk/sklearn/voting.py index 53281997..c8df3e0e 100644 --- a/src/amltk/sklearn/voting.py +++ b/src/amltk/sklearn/voting.py @@ -68,7 +68,7 @@ def voting_with_preffited_estimators( if is_classification: est0_classes_ = est0.classes_ # type: ignore - _voter.classes_ = est0_classes_ + _voter.classes_ = est0_classes_ # type: ignore if np.ndim(est0_classes_) > 1: est0_classes_ = est0_classes_[0] _voter.le_ = MultiLabelBinarizer().fit(est0_classes_) # type: ignore diff --git a/src/amltk/smac/__init__.py b/src/amltk/smac/__init__.py deleted file mode 100644 index f6bccaf0..00000000 --- a/src/amltk/smac/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from smac.runhistory import TrialInfo as SMACTrialInfo - -from amltk.smac.optimizer import SMACOptimizer - -__all__ = ["SMACOptimizer", "SMACTrialInfo"] diff --git a/src/amltk/store/bucket.py b/src/amltk/store/bucket.py index 23595636..8b274e95 100644 --- a/src/amltk/store/bucket.py +++ b/src/amltk/store/bucket.py @@ -12,20 +12,15 @@ import re from abc import ABC, abstractmethod -from typing import ( - TYPE_CHECKING, - Any, +from collections.abc import ( Callable, - Generic, Hashable, Iterable, Iterator, - Literal, Mapping, MutableMapping, - TypeVar, - overload, ) +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload from typing_extensions import override from more_itertools import ilen @@ -191,7 +186,7 @@ def find( matches = {match.groups(): self[key] for key, match in keys} # If it's a tuple of length 1, we expand it - one_group = len(list(matches.keys())[0]) == 1 + one_group = len(next(iter(matches.keys()))) == 1 if one_group: if multi_key: raise ValueError( diff --git a/src/amltk/store/drop.py b/src/amltk/store/drop.py index 8e443caf..de10d07e 100644 --- a/src/amltk/store/drop.py +++ b/src/amltk/store/drop.py @@ -4,12 +4,13 @@ from __future__ import annotations import logging +from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload from more_itertools.more import first -from amltk.functional import funcname +from amltk._functional import funcname from amltk.types import StoredValue if TYPE_CHECKING: diff --git a/src/amltk/store/paths/path_bucket.py b/src/amltk/store/paths/path_bucket.py index b22c788a..577ae765 100644 --- a/src/amltk/store/paths/path_bucket.py +++ b/src/amltk/store/paths/path_bucket.py @@ -4,9 +4,10 @@ from __future__ import annotations import shutil +from collections.abc import Iterator, Sequence from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator, Sequence +from typing import TYPE_CHECKING, Any from typing_extensions import override from amltk.store.bucket import Bucket diff --git 
a/src/amltk/store/paths/path_loaders.py b/src/amltk/store/paths/path_loaders.py index bccb1de1..046dfd06 100644 --- a/src/amltk/store/paths/path_loaders.py +++ b/src/amltk/store/paths/path_loaders.py @@ -16,14 +16,7 @@ import pickle from abc import abstractmethod from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Literal, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar from typing_extensions import override import numpy as np @@ -126,7 +119,7 @@ class NPYLoader(PathLoader[np.ndarray]): * [`np.ndarray`][numpy.ndarray] """ - name: ClassVar[Literal["np"]] = "np" + name: ClassVar = "np" """::: amltk.store.paths.path_loaders.PathLoader.name""" @override @@ -161,7 +154,7 @@ def save(cls, obj: np.ndarray, key: Path, /) -> None: np.save(key, obj, allow_pickle=False) -class PDLoader(PathLoader[Union[pd.DataFrame, pd.Series]]): +class PDLoader(PathLoader[pd.DataFrame | pd.Series]): """A [`Loader`][amltk.store.paths.path_loaders.PathLoader] for loading and saving [`pd.DataFrame`][pandas.DataFrame]s. @@ -192,7 +185,7 @@ class PDLoader(PathLoader[Union[pd.DataFrame, pd.Series]]): Please consider using `".pdpickle"` instead. """ - name: ClassVar[Literal["pd"]] = "pd" + name: ClassVar = "pd" """::: amltk.store.paths.path_loaders.PathLoader.name""" @override @@ -212,7 +205,7 @@ def can_load(cls, key: Path, /, *, check: type | None = None) -> bool: def can_save(cls, obj: Any, key: Path, /) -> bool: """::: amltk.store.paths.path_loaders.PathLoader.can_save""" # noqa: D415 if key.suffix == ".pdpickle": - return isinstance(obj, (pd.Series, pd.DataFrame)) + return isinstance(obj, pd.Series | pd.DataFrame) if key.suffix == ".parquet": return isinstance(obj, pd.DataFrame) @@ -236,7 +229,7 @@ def load(cls, key: Path, /) -> pd.DataFrame | pd.Series: if key.suffix == ".pdpickle": obj = pd.read_pickle(key) # noqa: S301 - if not isinstance(obj, (pd.Series, pd.DataFrame)): + if not isinstance(obj, pd.Series | pd.DataFrame): msg = ( f"Expected `pd.Series | pd.DataFrame` from {key=}" f" but got `{type(obj).__name__}`." @@ -272,7 +265,7 @@ def save(cls, obj: pd.Series | pd.DataFrame, key: Path, /) -> None: raise ValueError(f"Unknown extension {key.suffix=}") -class JSONLoader(PathLoader[Union[dict, list]]): +class JSONLoader(PathLoader[dict | list]): """A [`Loader`][amltk.store.paths.path_loaders.PathLoader] for loading and saving [`dict`][dict]s and [`list`][list]s to JSON. 
@@ -286,7 +279,7 @@ class JSONLoader(PathLoader[Union[dict, list]]): * [`list`][list] """ - name: ClassVar[Literal["json"]] = "json" + name: ClassVar = "json" """::: amltk.store.paths.path_loaders.PathLoader.name""" @override @@ -299,7 +292,7 @@ def can_load(cls, key: Path, /, *, check: type | None = None) -> bool: @classmethod def can_save(cls, obj: Any, key: Path, /) -> bool: """::: amltk.store.paths.path_loaders.PathLoader.can_save""" # noqa: D415 - return isinstance(obj, (dict, list)) and key.suffix == ".json" + return isinstance(obj, dict | list) and key.suffix == ".json" @override @classmethod @@ -309,7 +302,7 @@ def load(cls, key: Path, /) -> dict | list: with key.open("r") as f: item = json.load(f) - if not isinstance(item, (dict, list)): + if not isinstance(item, dict | list): msg = f"Expected `dict | list` from {key=} but got `{type(item).__name__}`" raise TypeError(msg) @@ -324,7 +317,7 @@ def save(cls, obj: dict | list, key: Path, /) -> None: json.dump(obj, f) -class YAMLLoader(PathLoader[Union[dict, list]]): +class YAMLLoader(PathLoader[dict | list]): """A [`Loader`][amltk.store.paths.path_loaders.PathLoader] for loading and saving [`dict`][dict]s and [`list`][list]s to YAML. @@ -339,7 +332,7 @@ class YAMLLoader(PathLoader[Union[dict, list]]): * [`list`][list] """ - name: ClassVar[Literal["yaml"]] = "yaml" + name: ClassVar = "yaml" """::: amltk.store.paths.path_loaders.PathLoader.name""" @override @@ -352,7 +345,7 @@ def can_load(cls, key: Path, /, *, check: type | None = None) -> bool: @classmethod def can_save(cls, obj: Any, key: Path, /) -> bool: """::: amltk.store.paths.path_loaders.PathLoader.can_save""" # noqa: D415 - return isinstance(obj, (dict, list)) and key.suffix in (".yaml", ".yml") + return isinstance(obj, dict | list) and key.suffix in (".yaml", ".yml") @override @classmethod @@ -365,7 +358,7 @@ def load(cls, key: Path, /) -> dict | list: with key.open("r") as f: item = yaml.safe_load(f) - if not isinstance(item, (dict, list)): + if not isinstance(item, dict | list): msg = f"Expected `dict | list` from {key=} but got `{type(item).__name__}`" raise TypeError(msg) @@ -403,7 +396,7 @@ class PickleLoader(PathLoader[Any]): attempting to save or load the object. 
""" - name: ClassVar[Literal["pickle"]] = "pickle" + name: ClassVar = "pickle" """::: amltk.store.paths.path_loaders.PathLoader.name""" @override @@ -455,7 +448,7 @@ class TxtLoader(PathLoader[str]): * [`str`][str] """ - name: ClassVar[Literal["text"]] = "text" + name: ClassVar = "text" """::: amltk.store.paths.path_loaders.PathLoader.name""" @override @@ -501,7 +494,7 @@ class ByteLoader(PathLoader[bytes]): * [`bytes`][bytes] """ - name: ClassVar[Literal["bytes"]] = "bytes" + name: ClassVar = "bytes" @override @classmethod @@ -513,7 +506,7 @@ def can_load(cls, key: Path, /, *, check: type | None = None) -> bool: @classmethod def can_save(cls, obj: Any, key: Path, /) -> bool: """::: amltk.store.paths.path_loaders.PathLoader.can_save""" # noqa: D415 - return isinstance(obj, (dict, list)) and key.suffix in (".bin", ".bytes") + return isinstance(obj, dict | list) and key.suffix in (".bin", ".bytes") @override @classmethod diff --git a/src/amltk/threadpoolctl/__init__.py b/src/amltk/threadpoolctl/__init__.py deleted file mode 100644 index 81c61ebc..00000000 --- a/src/amltk/threadpoolctl/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from amltk.threadpoolctl.threadpoolctl_plugin import ThreadPoolCTLPlugin - -__all__ = ["ThreadPoolCTLPlugin"] diff --git a/src/amltk/types.py b/src/amltk/types.py index 0269593d..aee88fb9 100644 --- a/src/amltk/types.py +++ b/src/amltk/types.py @@ -2,24 +2,11 @@ from __future__ import annotations from abc import abstractmethod +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass from itertools import chain, repeat -from typing import ( - Any, - Callable, - Generic, - Iterable, - Iterator, - List, - Mapping, - NoReturn, - Protocol, - Sequence, - Tuple, - TypeVar, - Union, -) -from typing_extensions import TypeAlias, override +from typing import Any, Generic, NoReturn, Protocol, TypeAlias, TypeVar +from typing_extensions import override import numpy as np @@ -45,10 +32,11 @@ Space = TypeVar("Space") """Generic for objects that are aware of a space but not the specific kind""" -Seed: TypeAlias = Union[int, np.random.RandomState, np.random.Generator] -"""Type alias for kinds of Seeded objects""" +Seed: TypeAlias = int | np.integer | (np.random.RandomState | np.random.Generator) +"""Type alias for kinds of Seeded objects.""" -FidT = Union[Tuple[int, int], Tuple[float, float], List[Any]] +FidT: TypeAlias = tuple[int, int] | tuple[float, float] | list[Any] +"""Type alias for a fidelity bound.""" class Comparable(Protocol): diff --git a/src/amltk/wandb/__init__.py b/src/amltk/wandb/__init__.py deleted file mode 100644 index 73f73b27..00000000 --- a/src/amltk/wandb/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from amltk.wandb.wandb import WandbParams, WandbPlugin, WandbTrialTracker - -__all__ = ["WandbPlugin", "WandbTrialTracker", "WandbParams"] diff --git a/tests/configspace/test_parsing.py b/tests/configspace/test_parsing.py deleted file mode 100644 index ecf279ef..00000000 --- a/tests/configspace/test_parsing.py +++ /dev/null @@ -1,421 +0,0 @@ -from __future__ import annotations - -import pytest - -from amltk.configspace import ConfigSpaceParser - -try: - from ConfigSpace import ConfigurationSpace, EqualsCondition, ForbiddenEqualsClause -except ImportError: - pytest.skip("ConfigSpace not installed", allow_module_level=True) - -from pytest_cases import case, parametrize_with_cases - -from amltk.pipeline import Choice, Pipeline, Split, Step, choice, split, step - - -@case -def case_single_step() -> tuple[Step, 
ConfigurationSpace]: - item = step("a", object, space={"hp": [1, 2, 3]}) - expected = ConfigurationSpace({"a:hp": [1, 2, 3]}) - return item, expected - - -@case -def case_steps_with_embedded_forbiddens() -> tuple[Step, ConfigurationSpace]: - space = ConfigurationSpace({"hp": [1, 2, 3], "hp_other": ["a", "b", "c"]}) - space.add_forbidden_clause(ForbiddenEqualsClause(space["hp"], 2)) - - item = step("a", object, space=space) - expected = ConfigurationSpace({"a:hp": [1, 2, 3], "a:hp_other": ["a", "b", "c"]}) - expected.add_forbidden_clause(ForbiddenEqualsClause(expected["a:hp"], 2)) - - return item, expected - - -@case -def case_single_step_two_hp() -> tuple[Step, ConfigurationSpace]: - item = step("a", object, space={"hp": [1, 2, 3], "hp2": [1, 2, 3]}) - expected = ConfigurationSpace({"a:hp": [1, 2, 3], "a:hp2": [1, 2, 3]}) - return item, expected - - -@case -def case_single_step_two_hp_different_types() -> tuple[Step, ConfigurationSpace]: - item = step("a", object, space={"hp": [1, 2, 3], "hp2": (1, 10)}) - expected = ConfigurationSpace({"a:hp": [1, 2, 3], "a:hp2": (1, 10)}) - return item, expected - - -@case -def case_choice() -> tuple[Choice, ConfigurationSpace]: - item = choice( - "choice1", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ) - expected = ConfigurationSpace( - {"choice1:a:hp": [1, 2, 3], "choice1:b:hp2": (1, 10), "choice1": ["a", "b"]}, - ) - expected.add_conditions( - [ - EqualsCondition(expected["choice1:a:hp"], expected["choice1"], "a"), - EqualsCondition(expected["choice1:b:hp2"], expected["choice1"], "b"), - ], - ) - return item, expected - - -@case -def case_nested_choices() -> tuple[Choice, ConfigurationSpace]: - item = choice( - "choice1", - choice( - "choice2", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ) - expected = ConfigurationSpace( - { - "choice1:choice2:a:hp": [1, 2, 3], - "choice1:choice2:b:hp2": (1, 10), - "choice1:c:hp3": (1, 10), - "choice1": ["choice2", "c"], - "choice1:choice2": ["a", "b"], - }, - ) - expected.add_conditions( - [ - EqualsCondition( - expected["choice1:choice2"], - expected["choice1"], - "choice2", - ), - EqualsCondition(expected["choice1:c:hp3"], expected["choice1"], "c"), - EqualsCondition( - expected["choice1:choice2:a:hp"], - expected["choice1:choice2"], - "a", - ), - EqualsCondition( - expected["choice1:choice2:b:hp2"], - expected["choice1:choice2"], - "b", - ), - EqualsCondition(expected["choice1:c:hp3"], expected["choice1"], "c"), - ], - ) - return item, expected - - -@case -def case_nested_choices_with_split() -> tuple[Choice, ConfigurationSpace]: - item = choice( - "choice1", - split( - "split2", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ) - expected = ConfigurationSpace( - { - "choice1:split2:a:hp": [1, 2, 3], - "choice1:split2:b:hp2": (1, 10), - "choice1:c:hp3": (1, 10), - "choice1": ["split2", "c"], - }, - ) - expected.add_conditions( - [ - EqualsCondition(expected["choice1:c:hp3"], expected["choice1"], "c"), - EqualsCondition( - expected["choice1:split2:a:hp"], - expected["choice1"], - "split2", - ), - EqualsCondition( - expected["choice1:split2:b:hp2"], - expected["choice1"], - "split2", - ), - EqualsCondition(expected["choice1:c:hp3"], expected["choice1"], "c"), - ], - ) - return item, expected - - -@case -def case_nested_choices_with_split_and_choice() -> tuple[Choice, 
ConfigurationSpace]: - item = choice( - "choice1", - split( - "split2", - choice( - "choice3", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ), - step("d", object, space={"hp4": (1, 10)}), - ) - expected = ConfigurationSpace( - { - "choice1:split2:choice3:a:hp": [1, 2, 3], - "choice1:split2:choice3:b:hp2": (1, 10), - "choice1:split2:c:hp3": (1, 10), - "choice1:d:hp4": (1, 10), - "choice1": ["split2", "d"], - "choice1:split2:choice3": ["a", "b"], - }, - ) - - expected.add_conditions( - [ - EqualsCondition(expected["choice1:d:hp4"], expected["choice1"], "d"), - EqualsCondition( - expected["choice1:split2:choice3"], - expected["choice1"], - "split2", - ), - EqualsCondition( - expected["choice1:split2:c:hp3"], - expected["choice1"], - "split2", - ), - EqualsCondition( - expected["choice1:split2:choice3:a:hp"], - expected["choice1:split2:choice3"], - "a", - ), - EqualsCondition( - expected["choice1:split2:choice3:b:hp2"], - expected["choice1:split2:choice3"], - "b", - ), - EqualsCondition(expected["choice1:d:hp4"], expected["choice1"], "d"), - ], - ) - - return item, expected - - -@case -def case_choice_pipeline() -> tuple[Pipeline, ConfigurationSpace]: - pipeline = Pipeline.create( - choice( - "choice", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": [1, 2, 3]}), - ), - ) - expected = ConfigurationSpace( - {"choice:a:hp": [1, 2, 3], "choice:b:hp": [1, 2, 3], "choice": ["a", "b"]}, - ) - expected.add_conditions( - [ - EqualsCondition(expected["choice:a:hp"], expected["choice"], "a"), - EqualsCondition(expected["choice:b:hp"], expected["choice"], "b"), - ], - ) - return pipeline, expected - - -@case -def case_pipeline_with_choice_modules() -> tuple[Pipeline, ConfigurationSpace]: - pipeline = Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - choice( - "choice", - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ), - ], - ) - expected = ConfigurationSpace( - { - "a:hp": [1, 2, 3], - "b:hp": (1, 10), - "c:hp": (1.0, 10.0), - "choice:d:hp": (1.0, 10.0), - "choice:e:hp": (1.0, 10.0), - "choice": ["d", "e"], - }, - ) - - expected.add_conditions( - [ - EqualsCondition(expected["choice:d:hp"], expected["choice"], "d"), - EqualsCondition(expected["choice:e:hp"], expected["choice"], "e"), - ], - ) - return pipeline, expected - - -@case -def case_joint_steps() -> tuple[Step, ConfigurationSpace]: - item = step("a", object, space={"hp": [1, 2, 3]}) | step( - "b", - object, - space={"hp2": (1, 10)}, - ) - expected = ConfigurationSpace({"a:hp": [1, 2, 3], "b:hp2": (1, 10)}) - return item, expected - - -@case -def case_split_steps() -> tuple[Step, ConfigurationSpace]: - item = split( - "split", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ) - expected = ConfigurationSpace({"split:a:hp": [1, 2, 3], "split:b:hp2": (1, 10)}) - return item, expected - - -@case -def case_nested_splits() -> tuple[Split, ConfigurationSpace]: - item = split( - "split1", - split( - "split2", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ) - expected = ConfigurationSpace( - { - "split1:split2:a:hp": [1, 2, 3], - "split1:split2:b:hp2": (1, 10), - "split1:c:hp3": (1, 10), - }, - ) - return item, expected - - 
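For orientation: the parsing cases removed here are re-expressed against the new node API in tests/pipeline/parsing/test_configspace_parsing.py later in this diff. A minimal sketch of that mapping (not part of the patch; assumes ConfigSpace is installed):

```python
from amltk.pipeline import Choice, Component

# Old style (removed above):
#   choice("choice1",
#          step("a", object, space={"hp": [1, 2, 3]}),
#          step("b", object, space={"hp2": (1, 10)}))
# New style (added further down in this diff):
choice1 = Choice(
    Component(object, name="a", space={"hp": [1, 2, 3]}),
    Component(object, name="b", space={"hp2": (1, 10)}),
    name="choice1",
)

# The parser is now named explicitly through `search_space`; the choice
# hyperparameter appears as "choice1:__choice__" rather than "choice1".
space = choice1.search_space("configspace", flat=False, conditionals=True)
print(space)
```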
-@case -def case_simple_linear_pipeline() -> tuple[Pipeline, ConfigurationSpace]: - pipeline = Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - ) - expected = ConfigurationSpace( - { - "a:hp": [1, 2, 3], - "b:hp": (1, 10), - "c:hp": (1.0, 10.0), - }, - ) - return pipeline, expected - - -@case -def case_split_pipeline() -> tuple[Pipeline, ConfigurationSpace]: - pipeline = Pipeline.create( - split( - "split", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": [1, 2, 3]}), - ), - ) - expected = ConfigurationSpace( - { - "split:a:hp": [1, 2, 3], - "split:b:hp": [1, 2, 3], - }, - ) - return pipeline, expected - - -@case -def case_pipeline_with_step_modules() -> tuple[Pipeline, ConfigurationSpace]: - pipeline = Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ], - ) - expected = ConfigurationSpace( - { - "a:hp": [1, 2, 3], - "b:hp": (1, 10), - "c:hp": (1.0, 10.0), - "d:hp": (1.0, 10.0), - "e:hp": (1.0, 10.0), - }, - ) - return pipeline, expected - - -@case -def case_pipeline_with_pipeline_modules() -> tuple[Pipeline, ConfigurationSpace]: - pipeline = Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - Pipeline.create( - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - name="subpipeline", - ), - ], - ) - expected = ConfigurationSpace( - { - "a:hp": [1, 2, 3], - "b:hp": (1, 10), - "c:hp": (1.0, 10.0), - "subpipeline:d:hp": (1.0, 10.0), - "subpipeline:e:hp": (1.0, 10.0), - }, - ) - return pipeline, expected - - -@parametrize_with_cases("pipeline, expected", cases=".") -def test_parsing_pipeline(pipeline: Pipeline, expected: ConfigurationSpace) -> None: - parsed_space = pipeline.space(parser=ConfigSpaceParser()) - assert parsed_space == expected - - -@parametrize_with_cases("pipeline, expected", cases=".") -def test_parsing_pipeline_does_not_mutate_space( - pipeline: Pipeline, - expected: ConfigurationSpace, # noqa: ARG001 -) -> None: - spaces_before = { - step.qualified_name(): step.search_space for step in pipeline.traverse() - } - pipeline.space(parser=ConfigSpaceParser()) - - spaces_after = { - step.qualified_name(): step.search_space for step in pipeline.traverse() - } - assert spaces_before == spaces_after - - -@parametrize_with_cases("pipeline, expected", cases=".") -def test_parsing_twice_produces_same_space( - pipeline: Pipeline, - expected: ConfigurationSpace, -) -> None: - parsed_space = pipeline.space(parser=ConfigSpaceParser()) - parsed_space2 = pipeline.space(parser=ConfigSpaceParser()) - - assert parsed_space == expected - assert parsed_space2 == expected - assert parsed_space == parsed_space2 diff --git a/tests/configspace/test_sampling.py b/tests/configspace/test_sampling.py deleted file mode 100644 index 935e4dd6..00000000 --- a/tests/configspace/test_sampling.py +++ /dev/null @@ -1,265 +0,0 @@ -from __future__ import annotations - -from more_itertools import all_unique -from pytest_cases import case, parametrize, parametrize_with_cases - -from amltk.configspace import ConfigSpaceAdapter -from amltk.pipeline import Choice, Pipeline, Split, Step, choice, split, step - - -@case 
-def case_single_step() -> Step: - return step("a", object, space={"hp": [1, 2, 3]}) - - -@case -def case_single_step_two_hp() -> Step: - return step("a", object, space={"hp": [1, 2, 3], "hp2": [1, 2, 3]}) - - -@case -def case_single_step_two_hp_different_types() -> Step: - return step("a", object, space={"hp": [1, 2, 3], "hp2": (1, 10)}) - - -@case -def case_joint_steps() -> Step: - return step("a", object, space={"hp": [1, 2, 3]}) | step( - "b", - object, - space={"hp2": (1, 10)}, - ) - - -@case -def case_split_steps() -> Step: - return split( - "split", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ) - - -@case -def case_nested_splits() -> Split: - return split( - "split1", - split( - "split2", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ) - - -@case -def case_choice() -> Choice: - return choice( - "choice1", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ) - - -@case -def case_nested_choices() -> Choice: - return choice( - "choice1", - choice( - "choice2", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ) - - -@case -def case_nested_choices_with_split() -> Choice: - return choice( - "choice1", - split( - "split2", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ) - - -@case -def case_nested_choices_with_split_and_choice() -> Choice: - return choice( - "choice1", - split( - "split2", - choice( - "choice3", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ), - step("d", object, space={"hp4": (1, 10)}), - ) - - -@case -def case_simple_linear_pipeline() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - ) - - -@case -def case_split_pipeline() -> Pipeline: - return Pipeline.create( - split( - "split", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": [1, 2, 3]}), - ), - ) - - -@case -def case_choice_pipeline() -> Pipeline: - return Pipeline.create( - choice( - "choice", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": [1, 2, 3]}), - ), - ) - - -@case -def case_pipeline_with_step_modules() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ], - ) - - -@case -def case_pipeline_with_choice_modules() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - choice( - "choice", - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ), - ], - ) - - -@case -def case_pipeline_with_pipeline_modules() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - Pipeline.create( - step("d", object, space={"hp": (1.0, 
10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ), - ], - ) - - -@case -def case_pipeline_with_pipeline_choice_modules() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - Pipeline.create( - choice( - "choice", - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ), - ), - ], - ) - - -@parametrize("n", [None, 5, 10]) -@parametrize_with_cases("item", cases=".", prefix="case_") -def test_sample_with_seed_returns_same_results( - item: Pipeline | Step, - n: int | None, -) -> None: - space = item.space(parser=ConfigSpaceAdapter) - - configs_1 = item.sample( - space=space, - sampler=ConfigSpaceAdapter, - seed=1, - n=n, - duplicates=True, - ) - configs_2 = item.sample( - space=space, - sampler=ConfigSpaceAdapter, - seed=1, - n=n, - duplicates=True, - ) - - assert configs_1 == configs_2 - - -def test_sampling_no_duplicates() -> None: - values = list(range(10)) - n = len(values) - - item: Step = step("x", object, space={"a": values}) - - adapter = ConfigSpaceAdapter - item.space(parser=adapter) - - configs = item.sample( - sampler=adapter, - n=n, - duplicates=False, - seed=42, - ) - - assert all_unique(configs) - - -def test_sampling_no_duplicates_with_seen_values() -> None: - values = list(range(10)) - n = len(values) - - item: Step = step("x", object, space={"a": values}) - - adapter = ConfigSpaceAdapter() - item.space(parser=adapter) - - seen_config = item.sample(sampler=adapter, seed=42) - - configs = item.sample( - sampler=adapter, - n=n - 1, - duplicates=[seen_config], - seed=42, - ) - - assert all_unique(configs) - assert seen_config not in configs diff --git a/tests/configuring/test_configuring.py b/tests/configuring/test_configuring.py deleted file mode 100644 index 3f17e47d..00000000 --- a/tests/configuring/test_configuring.py +++ /dev/null @@ -1,204 +0,0 @@ -from __future__ import annotations - -from typing import Any, Mapping - -from amltk import Pipeline, choice, searchable, split, step - - -def test_heirarchical_str() -> None: - pipeline = Pipeline.create( - step("one", 1, space={"v": [1, 2, 3]}), - split( - "split", - step("x", 1, space={"v": [4, 5, 6]}), - step("y", 1, space={"v": [4, 5, 6]}), - ), - choice( - "choice", - step("a", 1, space={"v": [4, 5, 6]}), - step("b", 1, space={"v": [4, 5, 6]}), - ), - ) - config = { - "one:v": 1, - "split:x:v": 4, - "split:y:v": 5, - "choice": "a", - "choice:a:v": 6, - } - result = pipeline.configure(config) - - expected = Pipeline.create( - step("one", 1, config={"v": 1}), - split( - "split", - step("x", 1, config={"v": 4}), - step("y", 1, config={"v": 5}), - ), - step("a", 1, config={"v": 6}), - name=pipeline.name, - ) - - assert result == expected - - -def test_heirarchical_str_with_predefined_configs() -> None: - pipeline = Pipeline.create( - step("one", 1, config={"v": 1}), - split( - "split", - step("x", 1), - step("y", 1, space={"v": [4, 5, 6]}), - ), - choice( - "choice", - step("a", 1), - step("b", 1), - ), - ) - - config = { - "one:v": 2, - "one:w": 3, - "split:x:v": 4, - "split:x:w": 42, - "choice": "a", - "choice:a:v": 3, - } - - expected = Pipeline.create( - step("one", 1, config={"v": 2, "w": 3}), - split( - "split", - step("x", 1, config={"v": 4, "w": 42}), - step("y", 1, config=None, space={"v": [4, 5, 6]}), - ), - step("a", 1, config={"v": 3}), - name=pipeline.name, - ) - - result = pipeline.configure(config) - assert result == expected 
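For orientation: the choice-selection behaviour covered by these removed configuration tests is exercised through the new node API in tests/pipeline/test_choice.py later in this diff. A minimal sketch of that pattern (not part of the patch; `Estimator` is a hypothetical stand-in item):

```python
from dataclasses import dataclass

from amltk.pipeline import Choice, Component


@dataclass
class Estimator:
    """Hypothetical stand-in item; any class or function can back a Component."""

    x: int = 1


choice = Choice(
    Component(Estimator, name="comp1", config={"x": 1}),
    Component(Estimator, name="comp2", config={"x": 2}),
    name="my_choice",
)

# The old API selected a branch by the choice's name, e.g. {"choice": "a"};
# the new API uses a "__choice__" key and exposes the result via `chosen()`.
configured = choice.configure({"__choice__": "comp2"})
assert configured.chosen() == Component(Estimator, name="comp2", config={"x": 2})
```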
- - -def test_configuration_with_nested_submodules() -> None: - pipeline = Pipeline.create( - step("1", 1, space={"a": [1, 2, 3]}), - step("2", 1, space={"b": [4, 5, 6]}), - ) - - module1 = Pipeline.create( - step("3", 1, space={"c": [7, 8, 9]}), - step("4", 1, space={"d": [10, 11, 12]}), - name="module1", - ) - - module2 = Pipeline.create( - choice( - "choice", - step("6", 1, space={"e": [13, 14, 15]}), - step("7", 1, space={"f": [16, 17, 18]}), - ), - name="module2", - ) - - module3 = Pipeline.create( - step("8", 1, space={"g": [19, 20, 21]}), - name="module3", - ) - - module2 = module2.attach(modules=(module3)) - - pipeline = pipeline.attach(modules=(module1, module2)) - - config = { - "1:a": 1, - "2:b": 4, - "module1:3:c": 7, - "module1:4:d": 10, - "module2:choice": "6", - "module2:choice:6:e": 13, - "module2:module3:8:g": 19, - } - - expected_module1 = Pipeline.create( - step("3", 1, config={"c": 7}), - step("4", 1, config={"d": 10}), - name="module1", - ) - - expected_module2 = Pipeline.create( - step("6", 1, config={"e": 13}), - name="module2", - ) - - expected_module3 = Pipeline.create( - step("8", 1, config={"g": 19}), - name="module3", - ) - - expected_pipeline = Pipeline.create( - step("1", 1, config={"a": 1}), - step("2", 1, config={"b": 4}), - name=pipeline.name, - ) - - expected_module2 = expected_module2.attach(modules=(expected_module3)) - - expected_pipeline = expected_pipeline.attach( - modules=(expected_module1, expected_module2), - ) - - assert expected_pipeline == pipeline.configure(config) - - -def test_heirachical_str_with_searchables() -> None: - pipeline = Pipeline.create( - step("1", 1, space={"a": [1, 2, 3]}), - step("2", 1, space={"b": [4, 5, 6]}), - ) - - extra = searchable("searchables", space={"a": [1, 2, 3], "b": [4, 5, 6]}) - pipeline = pipeline.attach(modules=extra) - - config = { - "1:a": 1, - "2:b": 4, - "searchables:a": 1, - "searchables:b": 4, - } - - expected = Pipeline.create( - step("1", 1, config={"a": 1}), - step("2", 1, config={"b": 4}), - name=pipeline.name, - ) - expected_extra = searchable("searchables", config={"a": 1, "b": 4}) - - expected = expected.attach(modules=expected_extra) - - assert expected == pipeline.configure(config) - - -def test_config_transform() -> None: - def _transformer_1(_: Mapping, __: Any) -> Mapping: - return {"hello": "world"} - - def _transformer_2(_: Mapping, __: Any) -> Mapping: - return {"hi": "mars"} - - pipeline = Pipeline.create( - step("1", 1, space={"a": [1, 2, 3]}, config_transform=_transformer_1), - step("2", 1, space={"b": [1, 2, 3]}, config_transform=_transformer_2), - ) - config = { - "1:a": 1, - "2:b": 1, - } - - expected = Pipeline.create( - step("1", 1, config={"hello": "world"}, config_transform=_transformer_1), - step("2", 1, config={"hi": "mars"}, config_transform=_transformer_2), - name=pipeline.name, - ) - assert expected == pipeline.configure(config) diff --git a/tests/conftest.py b/tests/conftest.py index 6a937fda..b4e6baef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,9 @@ import asyncio import re +from collections.abc import Iterator from pathlib import Path -from typing import Any, Iterator +from typing import Any import pytest diff --git a/tests/metalearning/test_dataset_distances.py b/tests/metalearning/test_dataset_distances.py index d433b16a..b5b4d285 100644 --- a/tests/metalearning/test_dataset_distances.py +++ b/tests/metalearning/test_dataset_distances.py @@ -5,13 +5,13 @@ import pandas as pd from pytest_cases import case, parametrize, parametrize_with_cases +from 
amltk._functional import funcname from amltk.distances import ( DistanceMetric, NamedDistance, NearestNeighborsDistance, distance_metrics, ) -from amltk.functional import funcname from amltk.metalearning.dataset_distances import dataset_distance @@ -65,25 +65,25 @@ def test_distance_to_itself_is_zero( other: npt.ArrayLike, metric: DistanceMetric | NearestNeighborsDistance, ) -> None: - target = np.asarray(target) - target = pd.Series( - target, + _target = np.asarray(target) + starget = pd.Series( + _target, name="target", - index=[f"mf{i}" for i in range(len(target))], + index=[f"mf{i}" for i in range(len(_target))], ) - other = np.asarray(other) - other = pd.Series( - other, + _other = np.asarray(other) + sother = pd.Series( + _other, name="other", - index=[f"mf{i}" for i in range(len(other))], + index=[f"mf{i}" for i in range(len(_other))], ) - other2 = other.copy() + sother2 = sother.copy() # We use 2 here to make sure the ordering remains correct expected = pd.Series([0, 0], index=["other", "other2"], dtype=float) distances = dataset_distance( - target=target, - dataset_metafeatures={"other": other, "other2": other2}, + target=starget, + dataset_metafeatures={"other": sother, "other2": sother2}, distance_metric=metric, ) diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py index a819fb66..70ebb9ec 100644 --- a/tests/optimizers/test_optimizers.py +++ b/tests/optimizers/test_optimizers.py @@ -6,14 +6,14 @@ import pytest from pytest_cases import case, parametrize, parametrize_with_cases -from amltk.optimization import Optimizer, RandomSearch, Trial -from amltk.pipeline import Pipeline, step +from amltk.optimization import Optimizer, Trial +from amltk.pipeline import Component from amltk.profiling import Memory, Timer if TYPE_CHECKING: - from amltk.neps import NEPSOptimizer - from amltk.optuna import OptunaOptimizer - from amltk.smac import SMACOptimizer + from amltk.optimization.optimizers.neps import NEPSOptimizer + from amltk.optimization.optimizers.optuna import OptunaOptimizer + from amltk.optimization.optimizers.smac import SMACOptimizer logger = logging.getLogger(__name__) @@ -48,45 +48,41 @@ def valid_time_interval(interval: Timer.Interval) -> bool: return interval.start <= interval.end -@case -def opt_random_search() -> tuple[RandomSearch, str]: - s = step("hi", _A, space={"a": (1, 10)}) - pipeline = Pipeline.create(s) - return RandomSearch(space=pipeline.space()), "cost" - - @case def opt_smac_hpo() -> tuple[SMACOptimizer, str]: try: - from amltk.smac import SMACOptimizer + from amltk.optimization.optimizers.smac import SMACOptimizer except ImportError: pytest.skip("SMAC is not installed") - pipeline = Pipeline.create(step("hi", _A, space={"a": (1, 10)})) - return SMACOptimizer.create(space=pipeline.space(), seed=2**32 - 1), "cost" + pipeline = Component(_A, name="hi", space={"a": (1, 10)}) + return SMACOptimizer.create( + space=pipeline.search_space(SMACOptimizer.preferred_parser()), + seed=2**32 - 1, + ), "cost" @case def opt_optuna() -> tuple[OptunaOptimizer, str]: try: - from amltk.optuna import OptunaOptimizer, OptunaParser + from amltk.optimization.optimizers.optuna import OptunaOptimizer except ImportError: pytest.skip("Optuna is not installed") - pipeline = Pipeline.create(step("hi", _A, space={"a": (1, 10)})) - space = pipeline.space(parser=OptunaParser()) + pipeline = Component(_A, name="hi", space={"a": (1, 10)}) + space = pipeline.search_space(parser=OptunaOptimizer.preferred_parser()) return OptunaOptimizer.create(space=space), 
"cost" @case def opt_neps() -> tuple[NEPSOptimizer, str]: try: - from amltk.neps import NEPSOptimizer + from amltk.optimization.optimizers.neps import NEPSOptimizer except ImportError: pytest.skip("NEPS is not installed") - pipeline = Pipeline.create(step("hi", _A, space={"a": (1, 10)})) - space = pipeline.space() + pipeline = Component(_A, name="hi", space={"a": (1, 10)}) + space = pipeline.search_space(parser=NEPSOptimizer.preferred_parser()) return NEPSOptimizer.create(space=space, overwrite=True), "loss" diff --git a/tests/optimizers/test_random_search.py b/tests/optimizers/test_random_search.py deleted file mode 100644 index 7f1a8b66..00000000 --- a/tests/optimizers/test_random_search.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Callable - -import pytest -from ConfigSpace import Configuration, ConfigurationSpace -from more_itertools import all_unique -from pytest_cases import case, parametrize, parametrize_with_cases - -from amltk.configspace import ConfigSpaceAdapter -from amltk.optimization import RandomSearch -from amltk.optuna import OptunaSpaceAdapter -from amltk.pipeline import Parser, Sampler, Step, searchable - -logger = logging.getLogger(__name__) - - -@case -def case_int_searchable() -> Step: - return searchable("my_space", space={"a": (1, 1_000)}) - - -@case -def case_mixed_searchable() -> Step: - return searchable("my_space", space={"a": (1, 10), "b": (1.0, 10.0)}) - - -def custom_sampler(space: ConfigurationSpace, seed: int) -> Configuration: - space.seed(seed) - return space.sample_configuration() - - -@parametrize( - "parser, sampler", - [ - (None, None), - (OptunaSpaceAdapter, OptunaSpaceAdapter), - (ConfigSpaceAdapter, ConfigSpaceAdapter), - (ConfigSpaceAdapter, custom_sampler), - ], -) -@parametrize_with_cases("step", cases=".", prefix="case_") -def test_random_search_space( - step: Step, - parser: type[Parser], - sampler: type[Sampler] | Callable, -) -> None: - """Test that the random search space is correct.""" - space = step.space(parser=parser) - optimizer1 = RandomSearch(space=space, sampler=sampler, seed=42) - trials1 = [optimizer1.ask() for _ in range(10)] - - assert all_unique(trials1) - - optimizer2 = RandomSearch(space=space, sampler=sampler, seed=42) - trials2 = [optimizer2.ask() for _ in range(10)] - - assert all_unique(trials2) - - assert trials1 == trials2 - assert (t1.config == t2.config for t1, t2 in zip(trials1, trials2)) - - -def test_random_search_exhausted_with_limited_space() -> None: - limited_space = searchable( - "my_space", - space={"a": ["cat", "dog", "elephant"], "b": ["apple", "honey", "spice"]}, - ).space() - - optimizer = RandomSearch(space=limited_space, seed=42) - - for _ in range(3 * 3): - optimizer.ask() - - with pytest.raises(RandomSearch.ExhaustedError): - optimizer.ask() diff --git a/tests/optuna/test_parsing.py b/tests/optuna/test_parsing.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/optuna/test_sampling.py b/tests/optuna/test_sampling.py deleted file mode 100644 index b08c8d12..00000000 --- a/tests/optuna/test_sampling.py +++ /dev/null @@ -1,159 +0,0 @@ -from __future__ import annotations - -from more_itertools import all_unique -from pytest_cases import case, parametrize, parametrize_with_cases - -from amltk.optuna.space import OptunaSpaceAdapter -from amltk.pipeline import Pipeline, Split, Step, split, step - - -@case -def case_single_step() -> Step: - return step("a", object, space={"hp": [1, 2, 3]}) - - -@case -def 
case_single_step_two_hp() -> Step: - return step("a", object, space={"hp": [1, 2, 3], "hp2": [1, 2, 3]}) - - -@case -def case_single_step_two_hp_different_types() -> Step: - return step("a", object, space={"hp": [1, 2, 3], "hp2": (1, 10)}) - - -@case -def case_joint_steps() -> Step: - return step("a", object, space={"hp": [1, 2, 3]}) | step( - "b", - object, - space={"hp2": (1, 10)}, - ) - - -@case -def case_split_steps() -> Step: - return split( - "split", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ) - - -@case -def case_nested_splits() -> Split: - return split( - "split1", - split( - "split2", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp2": (1, 10)}), - ), - step("c", object, space={"hp3": (1, 10)}), - ) - - -@case -def case_simple_linear_pipeline() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - ) - - -@case -def case_split_pipeline() -> Pipeline: - return Pipeline.create( - split( - "split", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": [1, 2, 3]}), - ), - ) - - -@case -def case_pipeline_with_step_modules() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ], - ) - - -@case -def case_pipeline_with_pipeline_modules() -> Pipeline: - return Pipeline.create( - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": (1, 10)}), - step("c", object, space={"hp": (1.0, 10.0)}), - modules=[ - Pipeline.create( - step("d", object, space={"hp": (1.0, 10.0)}), - step("e", object, space={"hp": (1.0, 10.0)}), - ), - ], - ) - - -@parametrize("n", [None, 5, 10]) -@parametrize_with_cases("item", cases=".", prefix="case_") -def test_sample_with_seed_returns_same_results( - item: Pipeline | Step, - n: int | None, -) -> None: - configs_1 = item.sample( - sampler=OptunaSpaceAdapter(), - seed=1, - n=n, - duplicates=True, - ) - configs_2 = item.sample( - sampler=OptunaSpaceAdapter(), - seed=1, - n=n, - duplicates=True, - ) - - assert configs_1 == configs_2 - - -def test_sampling_no_duplicates() -> None: - values = list(range(10)) - n = len(values) - - item: Step = step("x", object, space={"a": values}) - - configs = item.sample( - sampler=OptunaSpaceAdapter, - n=n, - duplicates=False, - seed=42, - ) - - assert all_unique(configs) - - -def test_sampling_no_duplicates_with_seen_values() -> None: - values = list(range(10)) - n = len(values) - - item: Step = step("x", object, space={"a": values}) - - adapter = OptunaSpaceAdapter() - seen_config = item.sample(sampler=adapter, seed=42) - - configs = item.sample( - sampler=adapter, - n=n - 1, - duplicates=[seen_config], - seed=42, - ) - - assert all_unique(configs) - assert seen_config not in configs diff --git a/tests/pynisher/__init__.py b/tests/pipeline/parsing/__init__.py similarity index 100% rename from tests/pynisher/__init__.py rename to tests/pipeline/parsing/__init__.py diff --git a/tests/pipeline/parsing/test_configspace_parsing.py b/tests/pipeline/parsing/test_configspace_parsing.py new file mode 100644 index 00000000..df140224 --- /dev/null +++ b/tests/pipeline/parsing/test_configspace_parsing.py @@ -0,0 +1,780 @@ +from __future__ import annotations + +from 
dataclasses import dataclass + +import pytest +from pytest_cases import case, parametrize_with_cases + +from amltk.pipeline import Choice, Component, Fixed, Node, Sequential, Split + +try: + from ConfigSpace import ConfigurationSpace, EqualsCondition, ForbiddenEqualsClause +except ImportError: + pytest.skip("ConfigSpace not installed", allow_module_level=True) + + +FLAT = True +NOT_FLAT = False +CONDITIONED = True +NOT_CONDITIONED = False + + +@dataclass +class Params: + """A test case for parsing a Node into a ConfigurationSpace.""" + + root: Node + expected: dict[tuple[bool, bool], ConfigurationSpace] + + +@case +def case_single_frozen() -> Params: + item = Fixed(object(), name="a") + space = ConfigurationSpace() + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_single_component() -> Params: + item = Component(object, name="a", space={"hp": [1, 2, 3]}) + space = ConfigurationSpace({"a:hp": [1, 2, 3]}) + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_steps_with_embedded_forbiddens() -> Params: + space = ConfigurationSpace({"hp": [1, 2, 3], "hp_other": ["a", "b", "c"]}) + space.add_forbidden_clause(ForbiddenEqualsClause(space["hp"], 2)) + + item = Component(object, name="a", space=space) + + with_conditions = ConfigurationSpace( + {"a:hp": [1, 2, 3], "a:hp_other": ["a", "b", "c"]}, + ) + with_conditions.add_forbidden_clause( + ForbiddenEqualsClause(with_conditions["a:hp"], 2), + ) + + without_conditions = ConfigurationSpace( + {"a:hp": [1, 2, 3], "a:hp_other": ["a", "b", "c"]}, + ) + + expected = { + (NOT_FLAT, CONDITIONED): with_conditions, + (NOT_FLAT, NOT_CONDITIONED): without_conditions, + (FLAT, CONDITIONED): with_conditions, + (FLAT, NOT_CONDITIONED): without_conditions, + } + return Params(item, expected) # type: ignore + + +@case +def case_single_step_two_hp() -> Params: + item = Component(object, name="a", space={"hp": [1, 2, 3], "hp2": [1, 2, 3]}) + space = ConfigurationSpace({"a:hp": [1, 2, 3], "a:hp2": [1, 2, 3]}) + + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_single_step_two_hp_different_types() -> Params: + item = Component(object, name="a", space={"hp": [1, 2, 3], "hp2": (1, 10)}) + space = ConfigurationSpace({"a:hp": [1, 2, 3], "a:hp2": (1, 10)}) + expected = { + (NOT_FLAT, CONDITIONED): space, + (NOT_FLAT, NOT_CONDITIONED): space, + (FLAT, CONDITIONED): space, + (FLAT, NOT_CONDITIONED): space, + } + return Params(item, expected) # type: ignore + + +@case +def case_choice() -> Params: + item = Choice( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp2": (1, 10)}), + name="choice1", + space={"hp3": (1, 10)}, + ) + + expected = {} + + # Not flat and with conditions + space = ConfigurationSpace( + { + "choice1:a:hp": [1, 2, 3], + "choice1:b:hp2": (1, 10), + "choice1:hp3": (1, 10), + "choice1:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition(space["choice1:a:hp"], space["choice1:__choice__"], "a"), + EqualsCondition(space["choice1:b:hp2"], space["choice1:__choice__"], "b"), + ], + ) + 
expected[(NOT_FLAT, CONDITIONED)] = space + + # Flat and with conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "choice1:hp3": (1, 10), + "choice1:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition(space["a:hp"], space["choice1:__choice__"], "a"), + EqualsCondition(space["b:hp2"], space["choice1:__choice__"], "b"), + ], + ) + expected[(FLAT, CONDITIONED)] = space + + # Not Flat and without conditions + space = ConfigurationSpace( + { + "choice1:a:hp": [1, 2, 3], + "choice1:b:hp2": (1, 10), + "choice1:hp3": (1, 10), + "choice1:__choice__": ["a", "b"], + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "choice1:hp3": (1, 10), + "choice1:__choice__": ["a", "b"], + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + return Params(item, expected) # type: ignore + + +@case +def case_nested_choices() -> Params: + item = Choice( + Choice( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp2": (1, 10)}), + name="choice2", + ), + Component(object, name="c", space={"hp3": (1, 10)}), + name="choice1", + ) + + expected = {} + + # Not flat and with conditions + space = ConfigurationSpace( + { + "choice1:choice2:a:hp": [1, 2, 3], + "choice1:choice2:b:hp2": (1, 10), + "choice1:c:hp3": (1, 10), + "choice1:__choice__": ["c", "choice2"], + "choice1:choice2:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition( + space["choice1:choice2:__choice__"], + space["choice1:__choice__"], + "choice2", + ), + EqualsCondition(space["choice1:c:hp3"], space["choice1:__choice__"], "c"), + EqualsCondition( + space["choice1:choice2:a:hp"], + space["choice1:choice2:__choice__"], + "a", + ), + EqualsCondition( + space["choice1:choice2:b:hp2"], + space["choice1:choice2:__choice__"], + "b", + ), + EqualsCondition(space["choice1:c:hp3"], space["choice1:__choice__"], "c"), + ], + ) + expected[(NOT_FLAT, CONDITIONED)] = space + + # flat and with conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + "choice1:__choice__": ["c", "choice2"], + "choice2:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition( + space["choice2:__choice__"], + space["choice1:__choice__"], + "choice2", + ), + EqualsCondition(space["c:hp3"], space["choice1:__choice__"], "c"), + EqualsCondition( + space["a:hp"], + space["choice2:__choice__"], + "a", + ), + EqualsCondition( + space["b:hp2"], + space["choice2:__choice__"], + "b", + ), + EqualsCondition(space["c:hp3"], space["choice1:__choice__"], "c"), + ], + ) + expected[(FLAT, CONDITIONED)] = space + + # Not flat and without conditions + space = ConfigurationSpace( + { + "choice1:choice2:a:hp": [1, 2, 3], + "choice1:choice2:b:hp2": (1, 10), + "choice1:c:hp3": (1, 10), + "choice1:__choice__": ["c", "choice2"], + "choice1:choice2:__choice__": ["a", "b"], + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + "choice1:__choice__": ["c", "choice2"], + "choice2:__choice__": ["a", "b"], + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + + return Params(item, expected) # type: ignore + + +@case +def case_nested_choices_with_split() -> Params: + item = Choice( + Split( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", 
space={"hp2": (1, 10)}), + name="split2", + ), + Component(object, name="c", space={"hp3": (1, 10)}), + name="choice1", + ) + expected = {} + + # Not flat and with conditions + space = ConfigurationSpace( + { + "choice1:split2:a:hp": [1, 2, 3], + "choice1:split2:b:hp2": (1, 10), + "choice1:c:hp3": (1, 10), + "choice1:__choice__": ["c", "split2"], + }, + ) + space.add_conditions( + [ + EqualsCondition(space["choice1:c:hp3"], space["choice1:__choice__"], "c"), + EqualsCondition( + space["choice1:split2:a:hp"], + space["choice1:__choice__"], + "split2", + ), + EqualsCondition( + space["choice1:split2:b:hp2"], + space["choice1:__choice__"], + "split2", + ), + EqualsCondition(space["choice1:c:hp3"], space["choice1:__choice__"], "c"), + ], + ) + expected[(NOT_FLAT, CONDITIONED)] = space + + # flat and with conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + "choice1:__choice__": ["c", "split2"], + }, + ) + space.add_conditions( + [ + EqualsCondition(space["c:hp3"], space["choice1:__choice__"], "c"), + EqualsCondition(space["a:hp"], space["choice1:__choice__"], "split2"), + EqualsCondition(space["b:hp2"], space["choice1:__choice__"], "split2"), + EqualsCondition(space["c:hp3"], space["choice1:__choice__"], "c"), + ], + ) + expected[(FLAT, CONDITIONED)] = space + + # not flat and without conditions + space = ConfigurationSpace( + { + "choice1:split2:a:hp": [1, 2, 3], + "choice1:split2:b:hp2": (1, 10), + "choice1:c:hp3": (1, 10), + "choice1:__choice__": ["c", "split2"], + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + "choice1:__choice__": ["c", "split2"], + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + + return Params(item, expected) + + +@case +def case_nested_choices_with_split_and_choice() -> Params: + item = Choice( + Split( + Choice( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp2": (1, 10)}), + name="choice3", + ), + Component(object, name="c", space={"hp3": (1, 10)}), + name="split2", + ), + Component(object, name="d", space={"hp4": (1, 10)}), + name="choice1", + ) + expected = {} + + # Not flat and with conditions + space = ConfigurationSpace( + { + "choice1:split2:choice3:a:hp": [1, 2, 3], + "choice1:split2:choice3:b:hp2": (1, 10), + "choice1:split2:c:hp3": (1, 10), + "choice1:d:hp4": (1, 10), + "choice1:__choice__": ["d", "split2"], + "choice1:split2:choice3:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition(space["choice1:d:hp4"], space["choice1:__choice__"], "d"), + EqualsCondition( + space["choice1:split2:choice3:__choice__"], + space["choice1:__choice__"], + "split2", + ), + EqualsCondition( + space["choice1:split2:c:hp3"], + space["choice1:__choice__"], + "split2", + ), + EqualsCondition( + space["choice1:split2:choice3:a:hp"], + space["choice1:split2:choice3:__choice__"], + "a", + ), + EqualsCondition( + space["choice1:split2:choice3:b:hp2"], + space["choice1:split2:choice3:__choice__"], + "b", + ), + ], + ) + expected[(NOT_FLAT, CONDITIONED)] = space + + # Flat and with conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + "d:hp4": (1, 10), + "choice1:__choice__": ["d", "split2"], + "choice3:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition(space["d:hp4"], space["choice1:__choice__"], "d"), + EqualsCondition( + 
space["choice3:__choice__"], + space["choice1:__choice__"], + "split2", + ), + EqualsCondition(space["c:hp3"], space["choice1:__choice__"], "split2"), + EqualsCondition(space["a:hp"], space["choice3:__choice__"], "a"), + EqualsCondition(space["b:hp2"], space["choice3:__choice__"], "b"), + ], + ) + expected[(FLAT, CONDITIONED)] = space + + # Not flat and without conditions + space = ConfigurationSpace( + { + "choice1:split2:choice3:a:hp": [1, 2, 3], + "choice1:split2:choice3:b:hp2": (1, 10), + "choice1:split2:c:hp3": (1, 10), + "choice1:d:hp4": (1, 10), + "choice1:__choice__": ["d", "split2"], + "choice1:split2:choice3:__choice__": ["a", "b"], + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + "d:hp4": (1, 10), + "choice1:__choice__": ["d", "split2"], + "choice3:__choice__": ["a", "b"], + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + + return Params(item, expected) + + +@case +def case_sequential_with_choice() -> Params: + item = Sequential( + Choice( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp": [1, 2, 3]}), + name="choice", + ), + name="pipeline", + ) + expected = {} + + # Not flat and with conditions + space = ConfigurationSpace( + { + "pipeline:choice:a:hp": [1, 2, 3], + "pipeline:choice:b:hp": [1, 2, 3], + "pipeline:choice:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition( + space["pipeline:choice:a:hp"], + space["pipeline:choice:__choice__"], + "a", + ), + EqualsCondition( + space["pipeline:choice:b:hp"], + space["pipeline:choice:__choice__"], + "b", + ), + ], + ) + expected[(NOT_FLAT, CONDITIONED)] = space + + # Flat and with conditions + # Note: For flat configuration, the namespace does not include "pipeline:choice", + # but just the "choice" to reflect the selected component. 
+ space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp": [1, 2, 3], + "choice:__choice__": ["a", "b"], + }, + ) + space.add_conditions( + [ + EqualsCondition(space["a:hp"], space["choice:__choice__"], "a"), + EqualsCondition(space["b:hp"], space["choice:__choice__"], "b"), + ], + ) + expected[(FLAT, CONDITIONED)] = space + + # Not flat and without conditions + space = ConfigurationSpace( + { + "pipeline:choice:a:hp": [1, 2, 3], + "pipeline:choice:b:hp": [1, 2, 3], + "pipeline:choice:__choice__": ["a", "b"], + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp": [1, 2, 3], + "choice:__choice__": ["a", "b"], + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + + return Params(item, expected) + + +@case +def case_sequential_with_own_search_space() -> Params: + item = Sequential( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp": (1, 10)}), + Component(object, name="c", space={"hp": (1.0, 10.0)}), + name="pipeline", + space={"something": ["a", "b", "c"]}, + ) + expected = {} + + # Not flat and without conditions + space = ConfigurationSpace( + { + "pipeline:a:hp": [1, 2, 3], + "pipeline:b:hp": (1, 10), + "pipeline:c:hp": (1.0, 10.0), + "pipeline:something": ["a", "b", "c"], + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Not flat and with conditions - although it doesn't logically apply here, + # we still add a dummy condition for the sake of consistency + space = ConfigurationSpace( + { + "pipeline:a:hp": [1, 2, 3], + "pipeline:b:hp": (1, 10), + "pipeline:c:hp": (1.0, 10.0), + "pipeline:something": ["a", "b", "c"], + }, + ) + space.add_conditions([]) + expected[(NOT_FLAT, CONDITIONED)] = space + + # Flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp": (1, 10), + "c:hp": (1.0, 10.0), + "pipeline:something": ["a", "b", "c"], + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + + # Flat and with conditions - similarly, dummy conditions are added + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp": (1, 10), + "c:hp": (1.0, 10.0), + "pipeline:something": ["a", "b", "c"], + }, + ) + space.add_conditions([]) + expected[(FLAT, CONDITIONED)] = space + + return Params(item, expected) + + +@case +def case_nested_splits() -> Params: + item = Split( + Split( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp2": (1, 10)}), + name="split2", + ), + Component(object, name="c", space={"hp3": (1, 10)}), + name="split1", + ) + expected = {} + + # Not flat and without conditions + space = ConfigurationSpace( + { + "split1:split2:a:hp": [1, 2, 3], + "split1:split2:b:hp2": (1, 10), + "split1:c:hp3": (1, 10), + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Not flat and with conditions - Splits do not have conditions + space = ConfigurationSpace( + { + "split1:split2:a:hp": [1, 2, 3], + "split1:split2:b:hp2": (1, 10), + "split1:c:hp3": (1, 10), + }, + ) + space.add_conditions([]) # No conditions for splits + expected[(NOT_FLAT, CONDITIONED)] = space + + # Flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + + # Flat and with conditions - Conditions would be empty as no __choice__ + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp2": (1, 10), + "c:hp3": (1, 10), + }, + ) + space.add_conditions([]) # 
No conditions for splits + expected[(FLAT, CONDITIONED)] = space + + return Params(item, expected) + + +@case +def case_sequential_with_split() -> Params: + pipeline = Sequential( + Split( + Component(object, name="a", space={"hp": [1, 2, 3]}), + Component(object, name="b", space={"hp": [1, 2, 3]}), + name="split", + ), + name="pipeline", + space={"something": ["a", "b", "c"]}, + ) + + expected = {} + + # Not flat and without conditions + space = ConfigurationSpace( + { + "pipeline:split:a:hp": [1, 2, 3], + "pipeline:split:b:hp": [1, 2, 3], + "pipeline:something": ["a", "b", "c"], + }, + ) + expected[(NOT_FLAT, NOT_CONDITIONED)] = space + + # Flat and without conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp": [1, 2, 3], + "pipeline:something": ["a", "b", "c"], + }, + ) + expected[(FLAT, NOT_CONDITIONED)] = space + + # Not flat and with conditions + space = ConfigurationSpace( + { + "pipeline:split:a:hp": [1, 2, 3], + "pipeline:split:b:hp": [1, 2, 3], + "pipeline:something": ["a", "b", "c"], + }, + ) + space.add_conditions([]) # No conditions to add, but placeholder for consistency + expected[(NOT_FLAT, CONDITIONED)] = space + + # Flat and with conditions + space = ConfigurationSpace( + { + "a:hp": [1, 2, 3], + "b:hp": [1, 2, 3], + "pipeline:something": ["a", "b", "c"], + }, + ) + space.add_conditions([]) # No conditions to add, but placeholder for consistency + expected[(FLAT, CONDITIONED)] = space + + return Params(pipeline, expected) + + +@parametrize_with_cases("test_case", cases=".") +def test_parsing_pipeline(test_case: Params) -> None: + pipeline = test_case.root + + for (flat, conditioned), expected in test_case.expected.items(): + parsed_space = pipeline.search_space( + "configspace", + flat=flat, + conditionals=conditioned, + ) + assert ( + parsed_space == expected + ), f"Failed for {flat=}, {conditioned=}.\n{parsed_space}\n{expected}" + + +@parametrize_with_cases("test_case", cases=".") +def test_parsing_does_not_mutate_space_of_nodes(test_case: Params) -> None: + pipeline = test_case.root + spaces_before = {tuple(path): step.space for path, step in pipeline.walk()} + + for (flat, conditioned), _ in test_case.expected.items(): + pipeline.search_space( + "configspace", + flat=flat, + conditionals=conditioned, + ) + spaces_after = {tuple(path): step.space for path, step in pipeline.walk()} + assert spaces_before == spaces_after + + +@parametrize_with_cases("test_case", cases=".") +def test_parsing_twice_produces_same_space(test_case: Params) -> None: + pipeline = test_case.root + + for (flat, conditioned), _ in test_case.expected.items(): + parsed_space = pipeline.search_space( + "configspace", + flat=flat, + conditionals=conditioned, + ) + parsed_space2 = pipeline.search_space( + "configspace", + flat=flat, + conditionals=conditioned, + ) + assert parsed_space == parsed_space2 diff --git a/tests/pipeline/parsing/test_optuna_parser.py b/tests/pipeline/parsing/test_optuna_parser.py new file mode 100644 index 00000000..ba098ff7 --- /dev/null +++ b/tests/pipeline/parsing/test_optuna_parser.py @@ -0,0 +1 @@ +# TODO: Fill this in diff --git a/tests/pipeline/test_as_node.py b/tests/pipeline/test_as_node.py new file mode 100644 index 00000000..7ac5a379 --- /dev/null +++ b/tests/pipeline/test_as_node.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from amltk.pipeline import Choice, Component, Fixed, Join, Sequential, as_node + + +def test_as_node_returns_copy_of_node() -> None: + c = Component(int) + + out = 
as_node(c) + assert c == out + + # Should be a copied object + assert id(out) != id(c) + + +def test_as_node_with_tuple_returns_join() -> None: + c1 = Component(int) + c2 = Component(str) + c3 = Component(bool) + + out = as_node((c1, c2, c3)) + expected = Join(c1, c2, c3) + assert out.nodes == expected.nodes + + +def test_as_node_with_set_returns_choice() -> None: + c1 = Component(int) + c2 = Component(str) + c3 = Component(bool) + + # A choice will sort the nodes such that their + # order is always consistent, even though we provided + # a set. + out = as_node({c1, c2, c3}) + expected = Choice(c1, c2, c3) + + assert out.nodes == expected.nodes + + +def test_as_node_with_list_returns_sequential() -> None: + c1 = Component(int) + c2 = Component(str) + c3 = Component(bool) + + out = as_node([c1, c2, c3]) + expected = Sequential(c1, c2, c3) + assert out.nodes == expected.nodes + + +@dataclass +class MyThing: + """Docstring.""" + + x: int = 1 + + def __call__(self) -> None: + """Dummy to try trick frozen into thinking it's a function, it should not.""" + + +def test_as_node_with_constructed_object_returns_frozen() -> None: + thing = MyThing(1) + out = as_node(thing) + assert out == Fixed(thing) + + +def create_a_thing() -> MyThing: + return MyThing(1) + + +def test_as_node_with_callable_function_returns_component() -> None: + out = as_node(create_a_thing) + assert out == Component(create_a_thing) + + +def test_as_node_with_class_returns_component() -> None: + out = as_node(MyThing) + assert out == Component(MyThing) diff --git a/tests/pipeline/test_choice.py b/tests/pipeline/test_choice.py index 44e79662..3552f39f 100644 --- a/tests/pipeline/test_choice.py +++ b/tests/pipeline/test_choice.py @@ -1,44 +1,75 @@ from __future__ import annotations -import pytest +from dataclasses import dataclass -from amltk import choice, step +from amltk.pipeline import Choice, Component -def test_error_when_uneven_weights() -> None: - with pytest.raises(ValueError, match="Weights must be the same length as choices"): - choice("choice", step("a", object), step("b", object), weights=[1]) +@dataclass +class Thing1: + """A thing.""" + x: int = 1 -def test_choice_shallow() -> None: - c = choice( - "choice", - step("a", object), - step("b", object) | step("c", object), + +@dataclass +class Thing2: + """A thing.""" + + x: int = 2 + + +def test_choice_creation_empty() -> None: + choice = Choice() + assert choice.nodes == () + + +def test_choice_construction() -> None: + choice = Choice(Thing1, Thing2) + assert choice.nodes == (Component(Thing1), Component(Thing2)) + + +def test_choice_copy() -> None: + choice = Choice(Component(Thing2, config={"x": 1})) + assert choice == choice.copy() + + +def test_choice_or() -> None: + """__or__ changes behavior when compared to other nodes.""" + choice = ( + Choice(name="choice") + | Component(Thing1, name="comp1", config={"x": 1}) + | Component(Thing1, name="comp2", config={"x": 1}) + | Component(Thing1, name="comp3", config={"x": 1}) ) - assert next(c.select({"choice": "a"})) == step("a", object) - assert next(c.select({"choice": "b"})) == step("b", object) | step("c", object) + assert choice == Choice( + Component(Thing1, name="comp1", config={"x": 1}), + Component(Thing1, name="comp2", config={"x": 1}), + Component(Thing1, name="comp3", config={"x": 1}), + name="choice", + ) -def test_choice_deep() -> None: - c = ( - step("head", object) - | choice( - "choice", - step("a", object), - step("b", object), - choice("choice2", step("c", object), step("d", object)) | step("e", 
object), - ) - | step("tail", object) +def test_choice_configured_gives_chosen_node() -> None: + choice = ( + Choice(name="choice_thing") + | Component(Thing1, name="comp1", config={"x": 1}) + | Component(Thing1, name="comp2", config={"x": 1}) + | Component(Thing1, name="comp3", config={"x": 1}) ) + configured_choice = choice.configure({"__choice__": "comp2"}) - expected = ( - step("head", object) - | step("d", object) - | step("e", object) - | step("tail", object) + assert configured_choice == Choice( + Component(Thing1, name="comp1", config={"x": 1}), + Component(Thing1, name="comp2", config={"x": 1}), + Component(Thing1, name="comp3", config={"x": 1}), + name="choice_thing", + config={"__choice__": "comp2"}, ) - chosen = next(c.select({"choice": "choice2", "choice2": "d"})) - assert chosen == expected + assert configured_choice.chosen() == Component( + Thing1, + name="comp2", + config={"x": 1}, + ) diff --git a/tests/pipeline/test_component.py b/tests/pipeline/test_component.py index 7338859b..7203b831 100644 --- a/tests/pipeline/test_component.py +++ b/tests/pipeline/test_component.py @@ -1,88 +1,42 @@ from __future__ import annotations -from itertools import chain, combinations +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any -from pytest_cases import case, parametrize, parametrize_with_cases +from pytest_cases import parametrize -from amltk import step -from amltk.pipeline import Component, Step +from amltk.pipeline.components import Component -@case -@parametrize("size", [1, 3, 10]) -def case_component_chain(size: int) -> Component: - component = Step.join(step(str(i), object) for i in range(size)) - assert isinstance(component, Component) - return component +@dataclass +class Thing: + """A thing.""" + x: int = 1 -@parametrize_with_cases("head", cases=".") -def test_traverse(head: Component) -> None: - # Component chains with no splits should traverse as they iter - assert list(head.traverse()) == list(head.iter()) +def thing_maker(x: int = 1) -> Thing: + return Thing(x) -@parametrize_with_cases("head", cases=".") -def test_walk(head: Component) -> None: - # Components chains with no splits should walk as they iter - walk = head.walk([], []) +@parametrize("maker", [Thing, thing_maker]) +def test_component_construction(maker: Any) -> None: + component = Component(maker, name="comp", config={"x": 2}) + assert component.name == "comp" + assert component.item == maker + assert component.config == {"x": 2} - # Ensure the head has no splits or parents - splits, parents, the_head = next(walk) - assert not any(splits) - assert not any(parents) - assert the_head == head - for splits, parents, current_step in walk: - assert not any(splits) - assert any(parents) - # Ensure that the parents are all the steps from the head up to the current step - assert parents == list(head.head().iter(to=current_step)) +@parametrize("maker", [Thing, thing_maker]) +def test_component_builds(maker: Callable[[], Thing]) -> None: + f = Component(maker, name="comp", config={"x": 5}) + obj = f.build_item() + assert obj == Thing(x=5) -@parametrize_with_cases("head", cases=".") -def test_replace_one(head: Component) -> None: - new_step = step("replacement", object) - for to_replace in head.iter(): - new_chain = list(head.replace({to_replace.name: new_step})) - expected = [new_step if s.name == to_replace.name else s for s in head.iter()] - assert new_chain == expected - - -@parametrize_with_cases("head", cases=".") -def test_replace_many(head: Component) -> None: - steps 
= list(head.iter()) - lens = range(1, len(steps) + 1) - replacements = [ - {s.name: step(f"{s.name}_r", object) for s in to_replace} - for to_replace in chain.from_iterable( - combinations(steps, length) for length in lens - ) - ] - - for to_replace in replacements: - new_chain = list(head.replace(to_replace)) - expected = [to_replace.get(s.name, s) for s in head.iter()] - assert new_chain == expected - - -@parametrize_with_cases("head", cases=".") -def test_remove_one(head: Component) -> None: - for to_remove in head.iter(): - removed_chain = list(head.remove([to_remove.name])) - expected = [s for s in head.iter() if s.name != to_remove.name] - assert expected == removed_chain - - -@parametrize_with_cases("head", cases=".") -def test_remove_many(head: Component) -> None: - steps = list(head.iter()) - lens = range(1, len(steps) + 1) - removals = chain.from_iterable(combinations(steps, length) for length in lens) - - for to_remove in removals: - names = [r.name for r in to_remove] - remaining = list(head.remove(names)) - expected = [s for s in head.iter() if s.name not in names] - assert expected == remaining +@parametrize("maker", [Thing, thing_maker]) +def test_copy(maker: Any) -> None: + f = Component(maker, name="comp", config={"x": 5}, space={"x": [1, 2, 3]}) + f2 = f.copy() + assert f == f2 diff --git a/tests/pipeline/test_configuring.py b/tests/pipeline/test_configuring.py new file mode 100644 index 00000000..8c5e4597 --- /dev/null +++ b/tests/pipeline/test_configuring.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from amltk.pipeline import Choice, Component, Sequential, Split + + +def test_heirarchical_str() -> None: + pipeline = ( + Sequential(name="pipeline") + >> Component(object, name="one", space={"v": [1, 2, 3]}) + >> Split( + Component(object, name="x", space={"v": [4, 5, 6]}), + Component(object, name="y", space={"v": [4, 5, 6]}), + name="split", + ) + >> Choice( + Component(object, name="a", space={"v": [4, 5, 6]}), + Component(object, name="b", space={"v": [4, 5, 6]}), + name="choice", + ) + ) + config = { + "pipeline:one:v": 1, + "pipeline:split:x:v": 4, + "pipeline:split:y:v": 5, + "pipeline:choice:__choice__": "a", + "pipeline:choice:a:v": 6, + } + result = pipeline.configure(config) + + expected = ( + Sequential(name="pipeline") + >> Component(object, name="one", config={"v": 1}, space={"v": [1, 2, 3]}) + >> Split( + Component(object, name="x", config={"v": 4}, space={"v": [4, 5, 6]}), + Component(object, name="y", config={"v": 5}, space={"v": [4, 5, 6]}), + name="split", + ) + >> Choice( + Component(object, name="a", config={"v": 6}, space={"v": [4, 5, 6]}), + Component(object, name="b", space={"v": [4, 5, 6]}), + name="choice", + config={"__choice__": "a"}, + ) + ) + + assert result == expected + + +def test_heirarchical_str_with_predefined_configs() -> None: + pipeline = ( + Sequential(name="pipeline") + >> Component(object, name="one", config={"v": 1}) + >> Split( + Component(object, name="x"), + Component(object, name="y", space={"v": [4, 5, 6]}), + name="split", + ) + >> Choice( + Component(object, name="a"), + Component(object, name="b"), + name="choice", + ) + ) + + config = { + "pipeline:one:v": 2, + "pipeline:one:w": 3, + "pipeline:split:x:v": 4, + "pipeline:split:x:w": 42, + "pipeline:choice:__choice__": "a", + "pipeline:choice:a:v": 3, + } + + expected = ( + Sequential(name="pipeline") + >> Component(object, name="one", config={"v": 2, "w": 3}) + >> Split( + Component(object, 
name="x", config={"v": 4, "w": 42}), + Component(object, name="y", space={"v": [4, 5, 6]}), + name="split", + ) + >> Choice( + Component(object, name="a", config={"v": 3}), + Component(object, name="b"), + name="choice", + config={"__choice__": "a"}, + ) + ) + + result = pipeline.configure(config) + assert result == expected + + +def test_config_transform() -> None: + def _transformer_1(_: Mapping, __: Any) -> Mapping: + return {"hello": "world"} + + def _transformer_2(_: Mapping, __: Any) -> Mapping: + return {"hi": "mars"} + + pipeline = ( + Sequential(name="pipeline") + >> Component( + object, + name="1", + space={"a": [1, 2, 3]}, + config_transform=_transformer_1, + ) + >> Component( + object, + name="2", + space={"b": [1, 2, 3]}, + config_transform=_transformer_2, + ) + ) + config = { + "pipeline:1:a": 1, + "pipeline:2:b": 1, + } + + expected = ( + Sequential(name="pipeline") + >> Component( + object, + name="1", + space={"a": [1, 2, 3]}, + config={"hello": "world"}, + config_transform=_transformer_1, + ) + >> Component( + object, + name="2", + space={"b": [1, 2, 3]}, + config={"hi": "mars"}, + config_transform=_transformer_2, + ) + ) + assert expected == pipeline.configure(config) diff --git a/tests/pipeline/test_frozen.py b/tests/pipeline/test_frozen.py new file mode 100644 index 00000000..e5fdd003 --- /dev/null +++ b/tests/pipeline/test_frozen.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from amltk.pipeline import Fixed + + +@dataclass +class Thing: + """A thing.""" + + x: int = 1 + + +def test_frozen_construction_direct() -> None: + f = Fixed(Thing(x=1)) + assert f.name == "Thing" + assert f.item == Thing(x=1) + + +def test_copy() -> None: + f = Fixed(Thing(x=1)) + f2 = f.copy() + assert f == f2 diff --git a/tests/pipeline/test_group.py b/tests/pipeline/test_group.py deleted file mode 100644 index 7edb71ac..00000000 --- a/tests/pipeline/test_group.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -from amltk.pipeline import group, step - - -def test_group_path_to_simple() -> None: - g = group("group", step("a", object) | step("b", object) | step("c", object)) - - a = g.find("a") - b = g.find("b") - c = g.find("c") - - assert g.path_to("group") == [g] - assert g.path_to("a") == [g, a] - assert g.path_to("b") == [g, a, b] - assert g.path_to("c") == [g, a, b, c] - - -def test_group_path_to_deep() -> None: - g = ( - step("head", object) - | group( - "group", - step("a", object), - step("b", object), - group("group2", step("c", object), step("d", object)) | step("e", object), - ) - | step("tail", object) - ) - - head = g.find("head") - _group = g.find("group") - a = g.find("a") - b = g.find("b") - group2 = g.find("group2") - c = g.find("c") - d = g.find("d") - e = g.find("e") - tail = g.find("tail") - - assert g.path_to("head") == [head] - assert g.path_to("group") == [head, _group] - assert g.path_to("a") == [head, _group, a] - assert g.path_to("b") == [head, _group, b] - assert g.path_to("group2") == [head, _group, group2] - assert g.path_to("c") == [head, _group, group2, c] - assert g.path_to("d") == [head, _group, group2, d] - assert g.path_to("e") == [head, _group, group2, e] - assert g.path_to("tail") == [head, _group, tail] - - -def test_group_simple_select() -> None: - g = group("group", step("a", object), step("b", object), step("c", object)) - - assert next(g.select({"group": "a"})) == step("a", object) - assert next(g.select({"group": "b"})) == step("b", object) - assert next(g.select({"group": "c"})) 
== step("c", object) - - -def test_group_deep_select() -> None: - g = ( - step("head", object) - | group( - "group", - step("a", object), - step("b", object), - group("group2", step("c", object), step("d", object)) | step("e", object), - ) - | step("tail", object) - ) - - expected = ( - step("head", object) - | step("d", object) - | step("e", object) - | step("tail", object) - ) - - chosen = next(g.select({"group": "group2", "group2": "d"})) - assert chosen == expected - - -def test_group_simple_traverse() -> None: - g = group("group", step("a", object), step("b", object), step("c", object)) - - assert list(g.traverse()) == [ - g, - step("a", object), - step("b", object), - step("c", object), - ] - - -def test_group_deep_traverse() -> None: - g = ( - step("head", object) - | group( - "group", - step("a", object), - step("b", object), - group("group2", step("c", object), step("d", object)) | step("e", object), - ) - | step("tail", object) - ) - - _group = g.find("group") - group2 = g.find("group2") - - expected = [ - step("head", object), - _group, - step("a", object), - step("b", object), - group2, - step("c", object), - step("d", object), - step("e", object), - step("tail", object), - ] - - assert list(g.traverse()) == expected - - -def test_group_simple_walk() -> None: - g = group("group", step("a", object), step("b", object), step("c", object)) - - assert list(g.walk()) == [ - ([], [], g), - ([g], [], step("a", object)), - ([g], [], step("b", object)), - ([g], [], step("c", object)), - ] - - -def test_group_list_walk() -> None: - g = group("group", step("a", object) | step("b", object) | step("c", object)) - - assert list(g.walk()) == [ - ([], [], g), - ([g], [], step("a", object)), - ([g], [step("a", object)], step("b", object)), - ([g], [step("a", object), step("b", object)], step("c", object)), - ] - - -def test_group_deep_walk() -> None: - g = ( - step("head", object) - | group( - "group", - step("a", object), - step("b", object), - group( - "group2", - step("c", object), - step("d", object) | step("extra", object), - ) - | step("e", object), - ) - | step("tail", object) - ) - - _group = g.find("group") - group2 = g.find("group2") - head = g.find("head") - - expected = [ - ([], [], head), - ([], [head], _group), - ([_group], [], step("a", object)), - ([_group], [], step("b", object)), - ([_group], [], group2), - ([_group, group2], [], step("c", object)), - ([_group, group2], [], step("d", object)), - ([_group, group2], [step("d", object)], step("extra", object)), - ([_group], [group2], step("e", object)), - ([], [head, _group], step("tail", object)), - ] - - assert list(g.walk()) == expected - - -def test_group_configure_simple() -> None: - g = group( - "group", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": [4, 5, 6]}), - step("c", object, space={"hp": [7, 8, 9]}), - ) - expected = group( - "group", - step("a", object, config={"hp": 1}), - step("b", object, config={"hp": 4}), - step("c", object, config={"hp": 7}), - ) - - configuration = { - "group:a:hp": 1, - "group:b:hp": 4, - "group:c:hp": 7, - } - - assert g.configure(configuration) == expected - - -def test_group_configure_deep() -> None: - g = ( - step("head", object) - | group( - "group", - step("a", object, space={"hp": [1, 2, 3]}), - step("b", object, space={"hp": [4, 5, 6]}), - group( - "group2", - step("c", object, space={"hp": [7, 8, 9]}), - step("d", object, space={"hp": [10, 11, 12]}) - | step("extra", object, space={"hp": [21, 22, 23]}), - ) - | step("e", object, space={"hp": [13, 14, 
15]}), - ) - | step("tail", object) - ) - expected = ( - step("head", object) - | group( - "group", - step("a", object, config={"hp": 1}), - step("b", object, config={"hp": 4}), - group( - "group2", - step("c", object, config={"hp": 7}), - step("d", object, config={"hp": 10}) - | step("extra", object, config={"hp": 21}), - ) - | step("e", object, config={"hp": 13}), - ) - | step("tail", object) - ) - - config = { - "group:a:hp": 1, - "group:b:hp": 4, - "group:group2:c:hp": 7, - "group:group2:d:hp": 10, - "group:group2:d:extra:hp": 21, - "group:e:hp": 13, - } - - assert g.configure(config) == expected diff --git a/tests/pipeline/test_join.py b/tests/pipeline/test_join.py new file mode 100644 index 00000000..8fd33b34 --- /dev/null +++ b/tests/pipeline/test_join.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from amltk.pipeline import Component, Join + + +@dataclass +class Thing: + """A thing.""" + + x: int = 1 + + +def test_join_creation_empty() -> None: + j = Join() + assert j.nodes == () + + +def test_join_construction() -> None: + j = Join( + Component(Thing, name="comp1"), + Component(Thing, name="comp2"), + name="join", + ) + assert j.name == "join" + assert j.nodes == (Component(Thing, name="comp1"), Component(Thing, name="comp2")) + + +def test_join_copy() -> None: + join1 = Join(Component(Thing, name="comp1", config={"x": 1}), name="join1") + join2 = Join(Component(Thing, name="comp2", config={"x": 1}), name="join2") + + assert join1 == join1.copy() + assert join2 == join2.copy() + + join3 = join1 & join2 + assert join3 == join3.copy() + + +def test_join_and() -> None: + """__and__ changes behavior when compared to other nodes.""" + join = ( + Join(name="join") + & Component(Thing, name="comp1", config={"x": 1}) + & Component(Thing, name="comp2", config={"x": 1}) + & Component(Thing, name="comp3", config={"x": 1}) + ) + + assert join == Join( + Component(Thing, name="comp1", config={"x": 1}), + Component(Thing, name="comp2", config={"x": 1}), + Component(Thing, name="comp3", config={"x": 1}), + name="join", + ) diff --git a/tests/pipeline/test_modules.py b/tests/pipeline/test_modules.py deleted file mode 100644 index c8c8a9a5..00000000 --- a/tests/pipeline/test_modules.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from ConfigSpace import ConfigurationSpace, EqualsCondition - -from amltk import Pipeline, choice, step -from amltk.configspace import ConfigSpaceParser - - -def test_pipeline_with_2_pipeline_modules() -> None: - pipeline = Pipeline.create( - step("1", object, space={"a": [1, 2, 3]}), - step("2", object, space={"b": [4, 5, 6]}), - ) - - module1 = Pipeline.create( - step("3", object, space={"c": [7, 8, 9]}), - step("4", object, space={"d": [10, 11, 12]}), - name="module1", - ) - - module2 = Pipeline.create( - choice( - "choice", - step("6", object, space={"e": [13, 14, 15]}), - step("7", object, space={"f": [16, 17, 18]}), - ), - name="module2", - ) - - pipeline = pipeline.attach(modules=(module1, module2)) - assert len(pipeline) == 2 - assert len(pipeline.modules) == 2 - - space = pipeline.space(parser=ConfigSpaceParser()) - assert isinstance(space, ConfigurationSpace) - - expected_space = { - "1:a": [1, 2, 3], - "2:b": [4, 5, 6], - "module1:3:c": [7, 8, 9], - "module1:4:d": [10, 11, 12], - "module2:choice": ["6", "7"], - "module2:choice:6:e": [13, 14, 15], - "module2:choice:7:f": [16, 17, 18], - } - expected = ConfigurationSpace(expected_space) - expected.add_conditions( - [ - EqualsCondition( - 
expected["module2:choice:6:e"], - expected["module2:choice"], - "6", - ), - EqualsCondition( - expected["module2:choice:7:f"], - expected["module2:choice"], - "7", - ), - ], - ) - assert space == expected diff --git a/tests/pipeline/test_node.py b/tests/pipeline/test_node.py new file mode 100644 index 00000000..f90aadf7 --- /dev/null +++ b/tests/pipeline/test_node.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import pytest + +from amltk.exceptions import RequestNotMetError +from amltk.pipeline import Choice, Join, Node, Sequential, request + + +def test_node_rshift() -> None: + node1 = Node(name="n1") + node2 = Node(name="n2") + node3 = Node(name="n3") + + out = node1 >> node2 >> node3 + expected_nodes = (node1, node2, node3) + + assert isinstance(out, Sequential) + assert out.nodes == expected_nodes + for a, b in zip(out.nodes, expected_nodes, strict=True): + assert id(a) != id(b) + + +def test_node_and() -> None: + node1 = Node(name="node1") + node2 = Node(name="node2") + node3 = Node(name="node3") + + out = node1 & node2 & node3 + expected_nodes = (node1, node2, node3) + + assert isinstance(out, Join) + assert out.nodes == expected_nodes + for a, b in zip(out.nodes, expected_nodes, strict=True): + assert id(a) != id(b) + + +def test_node_or() -> None: + node1 = Node(name="node1") + node2 = Node(name="node2") + node3 = Node(name="node3") + + out = node1 | node2 | node3 + expected_nodes = (node1, node2, node3) + + assert isinstance(out, Choice) + assert set(out.nodes) == set(expected_nodes) + for a, b in zip(out.nodes, expected_nodes, strict=True): + assert id(a) != id(b) + + +def test_single_node_configure() -> None: + node = Node(name="node") + node = node.configure({"a": 1, "b": 2}) + assert node == Node(name="node", config={"a": 1, "b": 2}) + + node = Node(name="node") + node = node.configure({"node:a": 1, "node:b": 2}) + assert node == Node(name="node", config={"a": 1, "b": 2}) + + +def test_with_children_configure() -> None: + node = Node(Node(name="child1"), Node(name="child2"), name="node") + node = node.configure( + {"node:a": 1, "node:b": 2, "node:child1:c": 3, "node:child2:d": 4}, + ) + + assert node == Node( + Node(config={"c": 3}, name="child1"), + Node(config={"d": 4}, name="child2"), + name="node", + config={"a": 1, "b": 2}, + ) + + +def test_deeply_nested_children_configuration() -> None: + node = Node( + Node(Node(name="child2"), name="child1"), + name="node", + ) + node = node.configure({"node:a": 1, "node:child1:b": 2, "node:child1:child2:c": 3}) + + assert node == Node( + Node( + Node(name="child2", config={"c": 3}), + name="child1", + config={"b": 2}, + ), + name="node", + config={"a": 1}, + ) + + +def test_configure_with_transform() -> None: + def _transform(config: Mapping[str, Any], _) -> dict: + c = (config["a"], config["b"]) + return {"c": c} + + node = Node(config_transform=_transform, name="1") + node = node.configure({"1:a": 1, "1:b": 2}) + assert node == Node(config={"c": (1, 2)}, name="1", config_transform=_transform) + + +def test_configure_with_param_request() -> None: + node = Node( + config={ + "x": request("x"), + "y": request("y"), + "z": request("z", default=3), + }, + name="1", + ) + + # Should configure as expected, with default and specified values + conf_node = node.configure({"a": -1}, params={"x": 1, "y": 2}) + assert conf_node == Node(name="1", config={"a": -1, "x": 1, "y": 2, "z": 3}) + + # When trying to configure with "x" missing, should raise + with 
pytest.raises(RequestNotMetError): + node.configure({"a": -1}, params={"y": 2}) + + +def test_find() -> None: + n1 = Node(name="1") + n2 = Node(name="2") + n3 = Node(name="3") + seq = n1 >> n2 >> n3 + + s1 = seq.find("1") + assert s1 is seq.nodes[0] + + s2 = seq.find("2") + assert s2 is seq.nodes[1] + + s3 = seq.find("3") + assert s3 is seq.nodes[2] + + s4 = seq.find("4") + assert s4 is None + + default = Node(name="default") + s5 = seq.find("5", default=default) + assert s5 is default + + +def test_walk() -> None: + sub3 = Node(name="sub3") + sub2 = Node(sub3, name="sub2") + n1 = Node(name="1") + n2 = Node(sub2, name="2") + n3 = Node(name="3") + + seq = n1 >> n2 >> n3 + + expected_path = [ + ([], seq), + ([seq], n1), + ([seq, n1], n2), + ([seq, n1, n2], sub2), + ([seq, n1, n2, sub2], sub3), + ([seq, n1, n2], n3), + ] + + for (path, node), (_exp_path, _exp_node) in zip( + seq.walk(), + expected_path, + strict=True, + ): + assert node == _exp_node + assert path == _exp_path diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py deleted file mode 100644 index a509a0c5..00000000 --- a/tests/pipeline/test_pipeline.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -import random - -import pytest -from more_itertools import first, last -from pytest_cases import case, parametrize, parametrize_with_cases - -from amltk.pipeline import Pipeline, Step, choice, request, split, step - - -class _A: - pass - - -@case(tags="shallow") -@parametrize("size", [1, 3, 10]) -def case_shallow_pipeline(size: int) -> Pipeline: - return Pipeline.create(step(str(i), _A) for i in range(size)) - - -@case(tags="deep") -def case_deep_pipeline() -> Pipeline: - # We want sequential split points - split1 = split("split1", step("1", _A), step("2", _A)) - split2 = split("split2", step("3", _A), step("4", _A)) - sequential = split1 | split2 - - # We want some choices - choice1 = choice( - "choice1", - step("5", _A), - step("6", _A) | step("7", _A), - ) - choice2 = choice( - "choice2", - step("8", _A), - step("9", _A) | step("10", _A), - ) - - # Use these to create at least one layer of depth - shallow_spread = split("split3", choice1, choice2) - - # Create a very deep part - deep_part = split("deep1", split("deep2", split("deep3", step("leaf", _A)))) - - # Throw on a long part - long_part = step("l1", _A) | step("l2", _A) | step("l3", _A) | step("l4", _A) - head = step("head", _A) - tail = step("tail", _A) - return Pipeline.create( - head, - sequential, - shallow_spread, - deep_part, - long_part, - tail, - ) - - -def test_pipeline_mixture_of_steps() -> None: - s1 = step("1", _A) - s2 = step("2", _A) - s3 = step("3", _A) - s4 = step("4", _A) - - pipeline = Pipeline.create(s1, s2 | s3, s4) - - assert list(pipeline) == [s1, s2, s3, s4] - - -@parametrize_with_cases("pipeline", cases=".") -def test_pipeline_has_steps(pipeline: Pipeline) -> None: - assert list(pipeline) == list(pipeline.steps) - - -@parametrize_with_cases("pipeline", cases=".") -def test_head(pipeline: Pipeline) -> None: - assert pipeline.head == pipeline.steps[0] == first(pipeline) - - -@parametrize_with_cases("pipeline", cases=".") -def test_tail(pipeline: Pipeline) -> None: - assert pipeline.tail == pipeline.steps[-1] == last(pipeline) - - -@parametrize_with_cases("pipeline", cases=".") -def test_len(pipeline: Pipeline) -> None: - assert len(pipeline) == len(pipeline.steps) - - -@parametrize_with_cases("pipeline", cases=".") -def test_iter_shallow(pipeline: Pipeline) -> None: - assert all(a == b for a, b in 
zip(pipeline, pipeline.steps)) - - -@parametrize_with_cases("pipeline", cases=".", has_tag="deep") -def test_traverse_contains_deeper_items_than_iter(pipeline: Pipeline) -> None: - # TODO: This should probably be tested better and with a specific example - shallow_items = {s.name for s in pipeline.iter()} - deep_items = {s.name for s in pipeline.traverse()} - - assert deep_items.issuperset(shallow_items) - assert len(deep_items) > len(shallow_items) - - -@parametrize_with_cases("pipeline", cases=".", has_tag="deep") -def test_traverse_contains_no_duplicates(pipeline: Pipeline) -> None: - seen: set[str] = set() - for item in pipeline.traverse(): - assert item.name not in seen - seen.add(item.name) - - -@parametrize_with_cases("pipeline", cases=".") -def test_find_shallow_success(pipeline: Pipeline) -> None: - for selected_step in pipeline.steps: - assert pipeline.find(selected_step.name) == selected_step - - -@parametrize_with_cases("pipeline", cases=".") -def test_find_default(pipeline: Pipeline) -> None: - o = _A() - assert pipeline.find("dummy", default=o) is o - - -@parametrize_with_cases("pipeline", cases=".") -def test_find_not_present(pipeline: Pipeline) -> None: - assert pipeline.find("dummy") is None - - -@parametrize_with_cases("pipeline", cases=".", has_tag="deep") -def test_find_deep(pipeline: Pipeline) -> None: - selected_step = random.choice(list(pipeline.traverse())) # noqa: S311 - assert pipeline.find(selected_step.name) == selected_step - - -def test_or_operator() -> None: - p1 = Pipeline.create(step("1", _A) | step("2", _A)) - p2 = Pipeline.create(step("3", _A) | step("4", _A)) - s = step("hello", _A) - pnew = p1 | p2 | s - assert pnew == p1 | p2 | s - - -def test_append() -> None: - p1 = Pipeline.create(step("1", _A) | step("2", _A)) - p2 = Pipeline.create(step("3", _A) | step("4", _A)) - s = step("hello", _A) - pnew = p1.append(p2).append(s) - # Need to make sure they have the same name for pipeline equality - assert pnew == Pipeline.create(p1, p2, s, name=pnew.name) - - -@parametrize_with_cases("pipeline", cases=".") -def test_replace(pipeline: Pipeline) -> None: - new_step = step("replacement", _A) - for selected_step in pipeline.traverse(): - assert selected_step in pipeline - assert new_step not in pipeline - - new_pipeline = pipeline.replace(selected_step.name, new_step) - assert selected_step not in new_pipeline - assert new_step in new_pipeline - - replacement_step = new_pipeline.find(new_step.name) - assert replacement_step is not None - assert replacement_step == new_step - assert replacement_step.nxt == selected_step.nxt - assert replacement_step.prv == selected_step.prv - - -@parametrize_with_cases("pipeline", cases=".", has_tag="shallow") -def test_remove_shallow(pipeline: Pipeline) -> None: - for s in pipeline.steps: - new_pipeline = pipeline.remove(s.name) - expected_steps = [*s.preceeding(), *s.proceeding()] - assert new_pipeline.steps == expected_steps - - -@parametrize_with_cases("pipeline", cases=".", has_tag="deep") -def test_remove_deep(pipeline: Pipeline) -> None: - for selected_step in pipeline.traverse(): - selected_prv = selected_step.prv - selected_nxt = selected_step.nxt - - assert selected_step in pipeline - new_pipeline = pipeline.remove(selected_step.name) - - assert selected_step not in new_pipeline - - # Ensure that the previous and next steps are still connected - if selected_prv is not None: - new_prv = new_pipeline.find(selected_prv.name) - assert new_prv is not None - assert new_prv.nxt == selected_nxt - - if selected_nxt is not None: - 
new_nxt = new_pipeline.find(selected_nxt.name) - assert new_nxt is not None - assert new_nxt.prv == selected_prv - - -@parametrize_with_cases("pipeline", cases=".") -def test_duplicate_name_error(pipeline: Pipeline) -> None: - first_step = pipeline.head - name = first_step.name - with pytest.raises(Step.DuplicateNameError): - pipeline | step(name, _A) # pyright: ignore[reportUnusedExpression] - - -def test_qualified_name() -> None: - pipeline = Pipeline.create( - step("1", _A) - | step("2", _A) - | split( - "split", - step("split1", _A) | step("split2", _A), - ) - | choice( - "3", - step("4", _A), - step("5", _A), - ), - ) - assert pipeline.qualified_name("1") == "1" - assert pipeline.qualified_name("2") == "2" - assert pipeline.qualified_name("split") == "split" - assert pipeline.qualified_name("split1") == "split:split1" - assert pipeline.qualified_name("split2") == "split:split2" - assert pipeline.qualified_name("3") == "3" - assert pipeline.qualified_name("4") == "3:4" - assert pipeline.qualified_name("5") == "3:5" - - -@parametrize_with_cases("pipeline", cases=".") -def test_renaming_function(pipeline: Pipeline) -> None: - new_name = "replaced_name" - x = step("nothing", _A) - - assert pipeline.replace(x.name, x, name=new_name).name == new_name - assert pipeline.remove(x.name, name=new_name).name == new_name - assert pipeline.append(x, name=new_name).name == new_name - assert pipeline.select({x.name: x.name}, name=new_name).name == new_name - - -def test_param_requests() -> None: - pipeline = Pipeline.create( - step("1", _A, config={"seed": request("seed1")}) - | step("2", _A, config={"seed": request("seed2")}) - | split( - "split", - ( - step( - "split1", - _A, - config={ - "seed": request("seed1", required=True), - }, - ) - | step( - "split2", - _A, - config={"seed": request("seed4", default=4)}, - ) - ), - ) - | choice( - "3", - step("4", _A, config={"seed": None}), - step("5", _A), - ), - ) - configured_pipeline = pipeline.configure(config={}, params={"seed1": 1, "seed2": 2}) - - assert configured_pipeline.config() == { - "1:seed": 1, - "2:seed": 2, - "split:split1:seed": 1, - "split:split2:seed": 4, - "3:4:seed": None, - } - - assert configured_pipeline == Pipeline.create( - step("1", _A, config={"seed": 1}) - | step("2", _A, config={"seed": 2}) - | split( - "split", - ( - step("split1", _A, config={"seed": 1}) - | step("split2", _A, config={"seed": 4}) - ), - ) - | choice( - "3", - step("4", _A, config={"seed": None}), - step("5", _A), - ), - name=pipeline.name, - ) diff --git a/tests/pipeline/test_searchable.py b/tests/pipeline/test_searchable.py new file mode 100644 index 00000000..052ba8c3 --- /dev/null +++ b/tests/pipeline/test_searchable.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from amltk.pipeline import Searchable + + +def test_searchable_construction() -> None: + component = Searchable({"x": ["red", "green", "blue"]}, name="searchable") + assert component.name == "searchable" + assert component.space == {"x": ["red", "green", "blue"]} + + +def test_searchable_copyable() -> None: + component = Searchable({"x": ["red", "green", "blue"]}, name="searchable") + assert component.copy() == component diff --git a/tests/pipeline/test_sequential.py b/tests/pipeline/test_sequential.py new file mode 100644 index 00000000..48da6d37 --- /dev/null +++ b/tests/pipeline/test_sequential.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from amltk.pipeline import Component, Sequential + + +@dataclass +class Thing: + """A 
thing.""" + + x: int = 1 + + +def test_sequential_construction_empty() -> None: + seq = Sequential(name="seq") + assert seq.name == "seq" + assert seq.nodes == () + + +def test_sequential_construction() -> None: + seq = Sequential( + Component(Thing, name="comp1"), + Component(Thing, name="comp2"), + name="seq", + ) + assert seq.name == "seq" + assert seq.nodes == (Component(Thing, name="comp1"), Component(Thing, name="comp2")) + + +def test_sequential_copy() -> None: + seq1 = Sequential(Component(Thing, name="comp1", config={"x": 1}), name="seq1") + seq2 = Sequential(Component(Thing, name="comp2", config={"x": 1}), name="seq2") + + assert seq1 == seq1.copy() + assert seq2 == seq2.copy() + + seq3 = seq1 >> seq2 + assert seq3 == seq3.copy() + + +def test_sequential_rshift() -> None: + """__rshift__ changes behavior when compared to other nodes.""" + seq = ( + Sequential(name="seq") + >> Component(Thing, name="comp1", config={"x": 1}) + >> Component(Thing, name="comp2", config={"x": 1}) + >> Component(Thing, name="comp3", config={"x": 1}) + ) + + assert seq == Sequential( + Component(Thing, name="comp1", config={"x": 1}), + Component(Thing, name="comp2", config={"x": 1}), + Component(Thing, name="comp3", config={"x": 1}), + name="seq", + ) diff --git a/tests/pipeline/test_split.py b/tests/pipeline/test_split.py index 342e72f9..f1f9d5d6 100644 --- a/tests/pipeline/test_split.py +++ b/tests/pipeline/test_split.py @@ -1,318 +1,36 @@ from __future__ import annotations -from amltk import split, step -from amltk.pipeline import Step, choice +from dataclasses import dataclass +from amltk.pipeline import Component, Split -def test_split() -> None: - split_step = split( - "split", - step("1", object) | step("2", object), - step("3", object) | step("4", object), - ) - expected_npaths = 2 - assert len(split_step.paths) == expected_npaths - - -def test_traverse_one_layer() -> None: - s1, s2, s3, s4 = ( - step("1", object), - step("2", object), - step("3", object), - step("4", object), - ) - split_step = split("split", s1 | s2, s3 | s4) - - assert list(split_step.traverse()) == [split_step, s1, s2, s3, s4] - - -def test_traverse_one_deep() -> None: - s1, s2, s3, s4 = ( - step("1", object), - step("2", object), - step("3", object), - step("4", object), - ) - subsplit = split("subsplit", s3 | s4) - split_step = split("split", s1, s2 | subsplit) - - assert list(split_step.traverse()) == [split_step, s1, s2, subsplit, s3, s4] - - -def test_traverse_sequential_splits() -> None: - s1, s2, s3, s4, s5, s6, s7, s8 = (step(str(i), object) for i in range(1, 9)) - split1 = split("split1", s1, s2) - split2 = split("split2", s3, s4) - split3 = split("split3", s5, s6) - split4 = split("split4", s7, s8) - steps = Step.join(split1, split2, split3, split4) - - expected = [split1, s1, s2, split2, s3, s4, split3, s5, s6, split4, s7, s8] - assert list(steps.traverse()) == expected - - -def test_traverse_deep() -> None: - s1, s2, s3, s4, s5, s6, s7, s8 = (step(str(i), object) for i in range(1, 9)) - subsub_split1 = split("subsplit1", s3 | s4) - sub_split1 = split("subsubsplit1", s1, s2 | subsub_split1) - - subsub_split2 = split("subsplit2", s7 | s8) - sub_split2 = split("subssubplit2", s5, s6 | subsub_split2) - - split_step = split("split1", sub_split1, sub_split2) - expected = [ - split_step, - sub_split1, - s1, - s2, - subsub_split1, - s3, - s4, - sub_split2, - s5, - s6, - subsub_split2, - s7, - s8, - ] - assert list(split_step.traverse()) == expected +@dataclass +class Thing: + """A thing.""" + x: int = 1 -def 
test_remove_split() -> None: - s1, s2, s3, s4, s5 = ( - step("1", object), - step("2", object), - step("3", object), - step("4", object), - step("5", object), - ) - split_step = split( - "split", - s1, - s2 | split("subsplit", s3 | s4) | s5, - ) - - new = Step.join(split_step.remove(["subsplit"])) - assert new == split( - "split", - s1, - s2 | s5, - ) - new = Step.join(split_step.remove(["3"])) - assert new == split( - "split", - s1, - s2 | split("subsplit", s4) | s5, - ) - - -def test_replace_split() -> None: - s1, s2, s3, s4, s5 = ( - step("1", object), - step("2", object), - step("3", object), - step("4", object), - step("5", object), - ) - split_step = split( - "split", - s1, - s2 | split("subsplit", s3 | s4) | s5, - ) +def test_split_creation_empty() -> None: + split = Split(name="split") + assert split.name == "split" + assert split.nodes == () - replacement = step("replacement", object) - new = Step.join(split_step.replace({"subsplit": replacement})) - assert new == split( - "split", - s1, - s2 | replacement | s5, - ) - new = Step.join(split_step.replace({"3": replacement})) - assert new == split( - "split", - s1, - s2 | split("subsplit", replacement | s4) | s5, +def test_split_construction() -> None: + split = Split( + Component(Thing, name="comp1"), + Component(Thing, name="comp2"), + name="split", ) - - -def test_split_on_path_with_one_entry_removes_properly() -> None: - s = split("split", step("1", object), step("2", object)) - result = next(s.remove(["1"])) - assert result == split("split", step("2", object)) - - -def test_split_on_head_of_path_does_not_remove_rest_of_path() -> None: - s = split("split", step("1", object) | step("2", object)) - result = next(s.remove(["1"])) - assert result == split("split", step("2", object)) - - -def test_configure_single() -> None: - s1 = split( - "split", - step("1", object, space={"a": [1, 2, 3]}) - | step("2", object, space={"b": [1, 2, 3]}), - step("3", object, space={"c": [1, 2, 3]}), - item=object, - space={"split_space": [1, 2, 3]}, + assert split.name == "split" + assert split.nodes == ( + Component(Thing, name="comp1"), + Component(Thing, name="comp2"), ) - configured_s1 = s1.configure({"split_space": 1, "1:a": 1, "2:b": 2, "3:c": 3}) - - expected_configs_by_name = { - "split": {"split_space": 1}, - "1": {"a": 1}, - "2": {"b": 2}, - "3": {"c": 3}, - } - for s in configured_s1.traverse(): - assert s.config == expected_configs_by_name[s.name] - assert s.search_space is None - - -def test_split_with_step_and_nested_choice() -> None: - s1 = split( - "split", - step("1", object, space={"a": [1, 2, 3]}) - | step("2", object, space={"b": [1, 2, 3]}), - choice( - "choice", - step("3", object, space={"c": [1, 2, 3]}) - | step("4", object, space={"d": [1, 2, 3]}), - step("5", object, space={"e": [1, 2, 3]}), - ), - config={"hello": "world"}, - ) - config = { - "split:1:a": 1, - "split:2:b": 1, - "split:choice": "3", - "split:choice:3:c": 1, - "split:choice:4:d": 1, - } - - expected = split( - "split", - step("1", object, config={"a": 1}) | step("2", object, config={"b": 1}), - step("3", object, config={"c": 1}) | step("4", object, config={"d": 1}), - config={"hello": "world"}, - ) - - assert s1.configure(config) == expected - - -def test_configure_chained() -> None: - head = ( - split( - "split", - step("1", object, space={"a": [1, 2, 3]}), - ) - | step("2", object, space={"b": [1, 2, 3]}) - | step("3", object, space={"c": [1, 2, 3]}) - ) - configured_head = head.configure({"split:1:a": 1, "2:b": 2, "3:c": 3}) - - expected_configs = { - 
"split": None, - "1": {"a": 1}, - "2": {"b": 2}, - "3": {"c": 3}, - } - for s in configured_head.traverse(): - assert s.config == expected_configs[s.name] - assert s.search_space is None - - -def test_qualified_name() -> None: - head = split( - "split", - step("0", object), - split("subsplit1", step("1", object) | step("2", object)), - split("subsplit2", step("3", object) | step("4", object)), - ) - assert head.qualified_name() == "split" - - s0 = head.find("0") - assert s0 is not None - assert s0.qualified_name() == "split:0" - - subsplit1 = head.find("subsplit1") - assert subsplit1 is not None - assert subsplit1.qualified_name() == "split:subsplit1" - - s1 = head.find("1") - assert s1 is not None - assert s1.qualified_name() == "split:subsplit1:1" - - s2 = head.find("2") - assert s2 is not None - assert s2.qualified_name() == "split:subsplit1:2" - - subsplit2 = head.find("subsplit2") - assert subsplit2 is not None - assert subsplit2.qualified_name() == "split:subsplit2" - - s3 = head.find("3") - assert s3 is not None - assert s3.qualified_name() == "split:subsplit2:3" - - s4 = head.find("4") - assert s4 is not None - assert s4.qualified_name() == "split:subsplit2:4" - - -def test_path_to() -> None: - head = split( - "split", - step("1", object), - split( - "subsplit", - step("2", object) | step("3", object), - ), - ) - _split = head.find("split") - assert _split is not None - - s1 = head.find("1") - assert s1 is not None - - subsplit = head.find("subsplit") - assert subsplit is not None - - s2 = head.find("2") - assert s2 is not None - - s3 = head.find("3") - assert s3 is not None - - assert _split.path_to(_split) == [_split] - assert _split.path_to(s1) == [_split, s1] - assert _split.path_to(subsplit) == [_split, subsplit] - assert _split.path_to(s2) == [_split, subsplit, s2] - assert _split.path_to(s3) == [_split, subsplit, s2, s3] - - assert s1.path_to(_split) == [s1, _split] - assert s1.path_to(s1) == [s1] - assert s1.path_to(subsplit) is None - assert s1.path_to(s2) is None - assert s1.path_to(s3) is None - - assert subsplit.path_to(_split) == [subsplit, _split] - assert subsplit.path_to(s1) is None - assert subsplit.path_to(subsplit) == [subsplit] - assert subsplit.path_to(s2) == [subsplit, s2] - assert subsplit.path_to(s3) == [subsplit, s2, s3] - assert s2.path_to(_split) == [s2, subsplit, _split] - assert s2.path_to(s1) is None - assert s2.path_to(subsplit) == [s2, subsplit] - assert s2.path_to(s2) == [s2] - assert s2.path_to(s3) == [s2, s3] - assert s3.path_to(_split) == [s3, s2, subsplit, _split] - assert s3.path_to(s1) is None - assert s3.path_to(subsplit) == [s3, s2, subsplit] - assert s3.path_to(s2) == [s3, s2] - assert s3.path_to(s3) == [s3] +def test_split_copy() -> None: + split = Split(Component(Thing, name="comp1", config={"x": 1}), name="split1") + assert split == split.copy() diff --git a/tests/pipeline/test_steps.py b/tests/pipeline/test_steps.py deleted file mode 100644 index 5d234e79..00000000 --- a/tests/pipeline/test_steps.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import annotations - -from amltk.pipeline import Component, Step, request, step - - -def test_step_component() -> None: - o = object - s = step("name", object) - assert s.name == "name" - assert s.item is o - assert s.config is None - assert isinstance(s, Component) - - -def test_step_searchable() -> None: - s = step("name", object, space={"a": [1, 2]}, config={"b": 2}) - assert s.name == "name" - assert s.search_space == {"a": [1, 2]} - assert s.config == {"b": 2} - assert isinstance(s, 
Component) - - -def test_step_joinable() -> None: - s1 = step("1", object) - s2 = step("2", object) - steps = s1 | s2 - - assert list(steps.iter()) == [s1, s2] - - -def test_step_head() -> None: - s1 = step("1", object) - s2 = step("2", object) - - x = s1 | s2 - - assert x.head() == s1 - - -def test_step_tail() -> None: - s1 = step("1", object) - s2 = step("2", object) - - x = s1 | s2 - - assert x.tail() == s2 - - -def test_step_iter() -> None: - s1 = step("1", object) - s2 = step("2", object) - s3 = step("3", object) - start = s1 | s2 | s3 - - middle = start.nxt - assert middle is not None - end = middle.nxt - assert end is not None - - # start - middle - end - # s1 - s2 - s3 - - assert list(start.iter()) == [s1, s2, s3] - assert list(middle.iter()) == [s2, s3] - assert list(end.iter()) == [s3] - - assert list(start.iter(backwards=True)) == [s1] - assert list(middle.iter(backwards=True)) == [s2, s1] - assert list(end.iter(backwards=True)) == [s3, s2, s1] - - assert list(start.proceeding()) == [s2, s3] - assert list(middle.proceeding()) == [s3] - assert list(end.proceeding()) == [] - - assert list(start.preceeding()) == [] - assert list(middle.preceeding()) == [s1] - assert list(end.preceeding()) == [s1, s2] - - -def test_join() -> None: - s1 = step("1", object) - s2 = step("2", object) - s3 = step("3", object) - - assert list(Step.join([s1, s2, s3]).iter()) == [s1, s2, s3] - assert list(Step.join(s1, [s2, s3]).iter()) == [s1, s2, s3] - assert list(Step.join([s1], [s2, s3]).iter()) == [s1, s2, s3] - assert list(Step.join([s1, s2], s3).iter()) == [s1, s2, s3] - - -def test_append_single() -> None: - s1 = step("1", object) - s2 = step("2", object) - x = s1.append(s2) - - assert list(x.iter()) == [s1, s2] - - -def test_append_chain() -> None: - s1 = step("1", object) - s2 = step("2", object) - s3 = step("3", object) - x = s1.append(s2 | s3) - - assert list(x.iter()) == [s1, s2, s3] - - -def test_configure_single() -> None: - s1 = step("1", object, space={"a": [1, 2, 3]}) - configured_s1 = s1.configure({"a": 1}) - - assert configured_s1.config == {"a": 1} - assert configured_s1.search_space is None - - -def test_configure_chain() -> None: - head = ( - step("1", object, space={"a": [1, 2, 3]}) - | step("2", object, space={"b": [1, 2, 3]}) - | step("3", object, space={"c": [1, 2, 3]}) - ) - configured_head = head.configure({"1:a": 1, "2:b": 2, "3:c": 3}) - - expected_configs = [ - {"a": 1}, - {"b": 2}, - {"c": 3}, - ] - for s, expected_config in zip(configured_head.iter(), expected_configs): - assert s.config == expected_config - assert s.search_space is None - - -def test_qualified_name() -> None: - head = step("1", object) | step("2", object) | step("3", object) - last = head.tail() - - # Should not have any prefixes from the other steps - assert last.qualified_name() == "3" - - -def test_path_to() -> None: - head = step("1", object) | step("2", object) | step("3", object) - - s1 = head.find("1") - assert s1 is not None - - s2 = head.find("2") - assert s2 is not None - - s3 = head.find("3") - assert s3 is not None - - assert s1.path_to(s1) == [s1] - assert s1.path_to(s3) == [s1, s2, s3] - assert s1.path_to(s2) == [s1, s2] - - assert s2.path_to(s2) == [s2] - assert s2.path_to(s3) == [s2, s3] - assert s2.path_to(s1) == [s2, s1] - - assert s3.path_to(s3) == [s3] - assert s3.path_to(s1) == [s3, s2, s1] - assert s3.path_to(s2) == [s3, s2] - - assert s3.path_to(s1, direction="forward") is None - assert s2.path_to(s1, direction="forward") is None - assert s1.path_to(s1, direction="forward") == [s1] - 
- assert s3.path_to(s3, direction="backward") == [s3] - assert s2.path_to(s3, direction="backward") is None - assert s1.path_to(s3, direction="backward") is None - - -def test_param_request() -> None: - component = step( - "rf", - object, - space={"n_estimators": (10, 100), "criterion": ["gini", "entropy"]}, - config={"random_state": request("seed", default=None)}, - ) - - config = {"n_estimators": 10, "criterion": "gini"} - configured_component = component.configure(config, params={"seed": 42}) - - assert configured_component == step( - "rf", - object, - config={ - "n_estimators": 10, - "criterion": "gini", - "random_state": 42, - }, - ) diff --git a/tests/pipeline/test_walk.py b/tests/pipeline/test_walk.py deleted file mode 100644 index 942b2412..00000000 --- a/tests/pipeline/test_walk.py +++ /dev/null @@ -1,24 +0,0 @@ -"""These tests are specifically traversal of steps and pipelines.""" -from __future__ import annotations - -from pytest_cases import parametrize - -from amltk import Pipeline, step - - -@parametrize("size", [1, 3, 10]) -def test_walk_shallow_pipeline(size: int) -> None: - pipeline = Pipeline.create(step(str(i), object) for i in range(size)) - - walk = pipeline.walk() - - # Ensure the head has no splits or parents - splits, parents, head = next(walk) - assert not any(splits) - assert not any(parents) - assert head == pipeline.head - - for splits, parents, current_step in walk: - assert not any(splits) - # Ensure that the parents are all the steps from the head up to the current step - assert parents == list(pipeline.head.iter(to=current_step)) diff --git a/tests/pipeline/test_xgboost.py b/tests/pipeline/test_xgboost.py index a48c1823..5eff66f8 100644 --- a/tests/pipeline/test_xgboost.py +++ b/tests/pipeline/test_xgboost.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Literal, Mapping +from collections.abc import Mapping +from typing import Literal import pytest from pytest_cases import parametrize @@ -26,11 +27,11 @@ def test_xgboost_pipeline_default_creation( name = f"name_{kind}_{space}" xgb = xgboost_component(name=name, kind=kind, space=space) assert xgb.item is expected_type - assert isinstance(xgb.search_space, Mapping) - assert len(xgb.search_space) > 1 + assert isinstance(xgb.space, Mapping) + assert len(xgb.space) > 1 assert xgb.name == name - model = xgb.build() + model = xgb.build_item() assert isinstance(model, expected_type) @@ -38,8 +39,8 @@ def test_xgboost_custom_config() -> None: eta = 0.341 xgb = xgboost_component("classifier", config={"eta": eta}) - assert isinstance(xgb.search_space, Mapping) - assert "eta" not in xgb.search_space + assert isinstance(xgb.space, Mapping) + assert "eta" not in xgb.space assert xgb.config is not None assert xgb.config["eta"] == eta diff --git a/tests/threadpoolctl/__init__.py b/tests/scheduling/plugins/__init__.py similarity index 100% rename from tests/threadpoolctl/__init__.py rename to tests/scheduling/plugins/__init__.py diff --git a/tests/scheduling/test_call_limiter_plugin.py b/tests/scheduling/plugins/test_call_limiter_plugin.py similarity index 87% rename from tests/scheduling/test_call_limiter_plugin.py rename to tests/scheduling/plugins/test_call_limiter_plugin.py index cd0007ba..e1c74f81 100644 --- a/tests/scheduling/test_call_limiter_plugin.py +++ b/tests/scheduling/plugins/test_call_limiter_plugin.py @@ -4,15 +4,15 @@ import time import warnings from collections import Counter +from collections.abc import Iterator from concurrent.futures import Executor, ProcessPoolExecutor, 
ThreadPoolExecutor -from typing import Any, Iterator +from typing import Any from dask.distributed import Client, LocalCluster, Worker from distributed.cfexecutor import ClientExecutor from pytest_cases import case, fixture, parametrize_with_cases -from amltk.scheduling import ExitState, Scheduler -from amltk.scheduling.task_plugin import CallLimiter +from amltk.scheduling import ExitState, Limiter, Scheduler @case(tags=["executor"]) @@ -63,15 +63,15 @@ def time_wasting_function(duration: int) -> int: def test_concurrency_limit_of_tasks(scheduler: Scheduler) -> None: - limiter = CallLimiter(max_concurrent=2) + limiter = Limiter(max_concurrent=2) task = scheduler.task(time_wasting_function, plugins=limiter) @scheduler.on_start(repeat=10) def launch_many() -> None: - task(duration=2) + task.submit(duration=2) end_status = scheduler.run(end_on_empty=True) - assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED) + assert end_status == ExitState(code=ExitState.Code.EXHAUSTED) assert task.event_counts == Counter( { @@ -96,15 +96,15 @@ def launch_many() -> None: def test_call_limit_of_tasks(scheduler: Scheduler) -> None: - limiter = CallLimiter(max_calls=2) + limiter = Limiter(max_calls=2) task = scheduler.task(time_wasting_function, plugins=limiter) @scheduler.on_start(repeat=10) def launch() -> None: - task(duration=2) + task.submit(duration=2) end_status = scheduler.run(end_on_empty=True) - assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED) + assert end_status == ExitState(code=ExitState.Code.EXHAUSTED) assert task.event_counts == Counter( { @@ -131,19 +131,19 @@ def launch() -> None: def test_call_limit_with_not_while_running(scheduler: Scheduler) -> None: task1 = scheduler.task(time_wasting_function) - limiter = CallLimiter(not_while_running=task1) + limiter = Limiter(not_while_running=task1) task2 = scheduler.task(time_wasting_function, plugins=limiter) @scheduler.on_start def launch() -> None: - task1(duration=2) + task1.submit(duration=2) @task1.on_submitted def launch2(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - task2(duration=2) + task2.submit(duration=2) end_status = scheduler.run(end_on_empty=True) - assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED) + assert end_status == ExitState(code=ExitState.Code.EXHAUSTED) assert task1.event_counts == Counter( {task1.SUBMITTED: 1, task1.DONE: 1, task1.RESULT: 1}, diff --git a/tests/scheduling/test_comm_plugin.py b/tests/scheduling/plugins/test_comm_plugin.py similarity index 85% rename from tests/scheduling/test_comm_plugin.py rename to tests/scheduling/plugins/test_comm_plugin.py index bab67ab9..4b16991c 100644 --- a/tests/scheduling/test_comm_plugin.py +++ b/tests/scheduling/plugins/test_comm_plugin.py @@ -3,8 +3,9 @@ import logging import warnings from collections import Counter +from collections.abc import Hashable, Iterator from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor -from typing import Any, Hashable, Iterator +from typing import Any import pytest from dask.distributed import Client, LocalCluster, Worker @@ -12,32 +13,24 @@ from pytest_cases import case, fixture, parametrize_with_cases from amltk.scheduling import ExitState, Scheduler -from amltk.scheduling.comms import Comm +from amltk.scheduling.plugins import Comm logger = logging.getLogger(__name__) -pytest.skip( - "We havn't revisited the Comm's works in a while and it has issues with " - " dask. 
We will revisit this in the future.", - allow_module_level=True, -) - -def sending_worker(replies: list[Any], comm: Comm | None = None) -> None: +def sending_worker(comm: Comm, replies: list[Any]) -> None: """A worker that responds to messages. Args: comm: The communication channel to use. replies: A list of replies to send to the client. """ - assert comm is not None - - with comm: + with comm.open(): for reply in replies: comm.send(reply) -def requesting_worker(requests: list[Any], comm: Comm | None = None) -> None: +def requesting_worker(comm: Comm, requests: list[Any]) -> None: """A worker that waits for messages. This will send a request, waiting for a response, finally @@ -48,9 +41,7 @@ def requesting_worker(requests: list[Any], comm: Comm | None = None) -> None: comm: The communication channel to use. requests: A list of requests to receive from the client. """ - assert comm is not None - - with comm: + with comm.open(): for request in requests: response = comm.request(request) comm.send(response) @@ -78,7 +69,7 @@ def case_dask_executor() -> ClientExecutor: # Dask will raise a warning when re-using the ports, hence # we silence the warnings here. pytest.skip( - "Dask executor stopped support for passing Connection" " objects in 2023.4", + "Dask executor stopped support for passing Connection objects in 2023.4", ) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -114,7 +105,7 @@ def handle_msg(msg: Any) -> None: @scheduler.on_start def start() -> None: - task(replies) + task.submit(replies) end_status = scheduler.run() @@ -123,13 +114,14 @@ def start() -> None: task.SUBMITTED: 1, task.DONE: 1, task.RESULT: 1, + Comm.OPEN: 1, Comm.MESSAGE: len(replies), Comm.CLOSE: 1, }, ) assert task.event_counts == task_counts - assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED) + assert end_status == ExitState(code=ExitState.Code.EXHAUSTED) scheduler_counts = Counter( { scheduler.STARTED: 1, @@ -138,6 +130,7 @@ def start() -> None: scheduler.EMPTY: 1, scheduler.FUTURE_SUBMITTED: 1, scheduler.FUTURE_DONE: 1, + scheduler.FUTURE_RESULT: 1, }, ) assert scheduler.event_counts == scheduler_counts @@ -161,10 +154,10 @@ def handle_msg(msg: Comm.Msg) -> None: @scheduler.on_start def start() -> None: - task(requests) + task.submit(requests) end_status = scheduler.run() - assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED) + assert end_status == ExitState(code=ExitState.Code.EXHAUSTED) assert results == [2, 4, 6] @@ -173,6 +166,7 @@ def start() -> None: task.SUBMITTED: 1, task.DONE: 1, task.RESULT: 1, + Comm.OPEN: 1, Comm.MESSAGE: len(results), Comm.REQUEST: len(requests), Comm.CLOSE: 1, @@ -187,5 +181,6 @@ def start() -> None: scheduler.EMPTY: 1, scheduler.FUTURE_SUBMITTED: 1, scheduler.FUTURE_DONE: 1, + scheduler.FUTURE_RESULT: 1, }, ) diff --git a/tests/pynisher/test_pynisher_plugin.py b/tests/scheduling/plugins/test_pynisher_plugin.py similarity index 93% rename from tests/pynisher/test_pynisher_plugin.py rename to tests/scheduling/plugins/test_pynisher_plugin.py index ef87769f..cd1f1820 100644 --- a/tests/pynisher/test_pynisher_plugin.py +++ b/tests/scheduling/plugins/test_pynisher_plugin.py @@ -4,16 +4,16 @@ import time import warnings from collections import Counter +from collections.abc import Iterator from concurrent.futures import Executor, ProcessPoolExecutor -from typing import Iterator import pytest from dask.distributed import Client, LocalCluster, Worker from distributed.cfexecutor import ClientExecutor from pytest_cases import case, fixture, 
parametrize_with_cases

-from amltk.pynisher import PynisherPlugin
 from amltk.scheduling import Scheduler
+from amltk.scheduling.plugins.pynisher import PynisherPlugin

 @case(tags=["executor"])
@@ -78,7 +78,7 @@ def test_memory_limited_task(scheduler: Scheduler) -> None:
     @scheduler.on_start
     def start_task() -> None:
-        task(mem_in_bytes=two_gb)
+        task.submit(mem_in_bytes=two_gb)
     with pytest.raises(PynisherPlugin.MemoryLimitException):
         scheduler.run(on_exception="raise")
@@ -109,12 +109,12 @@ def start_task() -> None:
 def test_time_limited_task(scheduler: Scheduler) -> None:
     task = scheduler.task(
         time_wasting_function,
-        plugins=PynisherPlugin(wall_time_limit=1),
+        plugins=PynisherPlugin(walltime_limit=1),
     )
     @scheduler.on_start
     def start_task() -> None:
-        task(duration=3)
+        task.submit(duration=3)
     with pytest.raises(PynisherPlugin.WallTimeoutException):
         scheduler.run(on_exception="raise")
@@ -147,12 +147,12 @@ def start_task() -> None:
 def test_cpu_time_limited_task(scheduler: Scheduler) -> None:
     task = scheduler.task(
         cpu_time_wasting_function,
-        plugins=PynisherPlugin(cpu_time_limit=1),
+        plugins=PynisherPlugin(cputime_limit=1),
     )
     @scheduler.on_start
     def start_task() -> None:
-        task(iterations=int(1e16))
+        task.submit(iterations=int(1e16))
     with pytest.raises(PynisherPlugin.CpuTimeoutException):
         scheduler.run(on_exception="raise")
diff --git a/tests/threadpoolctl/test_threadpoolctl_plugin.py b/tests/scheduling/plugins/test_threadpoolctl_plugin.py
similarity index 85%
rename from tests/threadpoolctl/test_threadpoolctl_plugin.py
rename to tests/scheduling/plugins/test_threadpoolctl_plugin.py
index 316d038b..ad338509 100644
--- a/tests/threadpoolctl/test_threadpoolctl_plugin.py
+++ b/tests/scheduling/plugins/test_threadpoolctl_plugin.py
@@ -8,16 +8,17 @@
 # We need these imported to ensure that the threadpoolctl plugin
 # actually does something.
-import numpy  # noqa: F401
-import sklearn  # noqa: F401
+import numpy  # noqa: F401 # type: ignore
+import pytest
+import sklearn  # noqa: F401 # type: ignore
+import threadpoolctl
 from dask.distributed import Client, LocalCluster, Worker
 from distributed.cfexecutor import ClientExecutor
 from pytest_cases import case, fixture, parametrize_with_cases
-import threadpoolctl
 from amltk.scheduling import Scheduler, SequentialExecutor
+from amltk.scheduling.plugins.threadpoolctl import ThreadPoolCTLPlugin
 from amltk.scheduling.scheduler import ExitState
-from amltk.threadpoolctl import ThreadPoolCTLPlugin
 from amltk.types import safe_isinstance
 logger = logging.getLogger(__name__)
@@ -68,6 +69,14 @@ def f() -> list[Any]:
 def test_empty_kwargs_does_not_change_anything(scheduler: Scheduler) -> None:
+    if isinstance(scheduler.executor, ClientExecutor):
+        pytest.skip(
+            "Unfortunately, dask is rather flaky in these tests."
+            " My current hypothesis is that this is due to the order in which"
+            " imports are done when dask uses its own unpickling strategy."
+            " It's rather nondeterministic.",
+        )
+
     task = scheduler.task(f, plugins=ThreadPoolCTLPlugin())
     retrieved_info = []
@@ -75,14 +84,14 @@ def test_empty_kwargs_does_not_change_anything(scheduler: Scheduler) -> None:
     @scheduler.on_start
     def start_task() -> None:
-        task()
+        task.submit()
     @task.on_result
     def check_threadpool_info(_, inner_info: list) -> None:
         retrieved_info.append(inner_info)
     end_status = scheduler.run()
-    assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED)
+    assert end_status == ExitState(code=ExitState.Code.EXHAUSTED)
     inside_info = retrieved_info[0]
     after = threadpoolctl.threadpool_info()
@@ -119,14 +128,14 @@ def test_limiting_thread_count_limits_only_inside_task(scheduler: Scheduler) ->
     @scheduler.on_start
     def start_task() -> None:
-        task()
+        task.submit()
     @task.on_result
     def check_threadpool_info(_, inner_info: list) -> None:
         retrieved_info.append(inner_info)
     end_status = scheduler.run()
-    assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED)
+    assert end_status == ExitState(code=ExitState.Code.EXHAUSTED)
     inside_info = retrieved_info[0]
     after = threadpoolctl.threadpool_info()
diff --git a/tests/scheduling/test_scheduler.py b/tests/scheduling/test_scheduler.py
index 909c9b95..b763dbc9 100644
--- a/tests/scheduling/test_scheduler.py
+++ b/tests/scheduling/test_scheduler.py
@@ -100,7 +100,7 @@ def append_result(_: Future, res: float) -> None:
     @scheduler.on_start
     def launch_task() -> None:
-        task(sleep_time=sleep_time)
+        task.submit(sleep_time=sleep_time)
     end_status = scheduler.run(timeout=0.1, wait=True)
     assert results == [sleep_time]
@@ -120,7 +120,7 @@ def launch_task() -> None:
             scheduler.FUTURE_RESULT: 1,
         },
     )
-    assert end_status == ExitState(code=Scheduler.ExitCode.TIMEOUT)
+    assert end_status == ExitState(code=ExitState.Code.TIMEOUT)
     assert scheduler.empty()
     assert not scheduler.running()
@@ -139,7 +139,7 @@ def test_scheduler_with_timeout_and_not_wait_for_tasks(scheduler: Scheduler) ->
     results: list[float] = []
     task = scheduler.task(sleep_and_return)
-    scheduler.on_start(lambda: task(sleep_time=10))
+    scheduler.on_start(lambda: task.submit(sleep_time=10))
     end_status = scheduler.run(timeout=0.1, wait=False)
@@ -166,13 +166,13 @@ def test_scheduler_with_timeout_and_not_wait_for_tasks(scheduler: Scheduler) ->
     # something that can be done with Python's default executors.
if isinstance( scheduler.executor, - (ClientExecutor, ProcessPoolExecutor), + ClientExecutor | ProcessPoolExecutor, ) or safe_isinstance(scheduler.executor, "_ReusablePoolExecutor"): expected_scheduler_counts[scheduler.FUTURE_CANCELLED] = 1 del expected_scheduler_counts[scheduler.FUTURE_DONE] assert scheduler.event_counts == expected_scheduler_counts - assert end_status == ExitState(code=Scheduler.ExitCode.TIMEOUT) + assert end_status == ExitState(code=ExitState.Code.TIMEOUT) assert scheduler.empty() assert not scheduler.running() @@ -183,11 +183,11 @@ def test_chained_tasks(scheduler: Scheduler) -> None: task_2 = scheduler.task(sleep_and_return) # Feed the output of task_1 into task_2 - task_1.on_result(lambda _, res: task_2(sleep_time=res)) + task_1.on_result(lambda _, res: task_2.submit(sleep_time=res)) task_1.on_result(lambda _, res: results.append(res)) task_2.on_result(lambda _, res: results.append(res)) - scheduler.on_start(lambda: task_1(sleep_time=0.1)) + scheduler.on_start(lambda: task_1.submit(sleep_time=0.1)) end_status = scheduler.run(wait=True) @@ -207,7 +207,7 @@ def test_chained_tasks(scheduler: Scheduler) -> None: }, ) assert results == [0.1, 0.1] - assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED) + assert end_status == ExitState(code=ExitState.Code.EXHAUSTED) assert scheduler.empty() assert not scheduler.running() @@ -218,7 +218,7 @@ def test_queue_empty_status(scheduler: Scheduler) -> None: # Reload on the first empty @scheduler.on_empty(when=lambda: scheduler.event_counts[scheduler.EMPTY] == 1) def launch_first() -> None: - task(sleep_time=0.1) + task.submit(sleep_time=0.1) # Stop on the second empty @scheduler.on_empty(when=lambda: scheduler.event_counts[scheduler.EMPTY] == 2) @@ -243,7 +243,7 @@ def stop_scheduler() -> None: scheduler.FUTURE_RESULT: 1, }, ) - assert end_status == ExitState(code=Scheduler.ExitCode.STOPPED) + assert end_status == ExitState(code=ExitState.Code.STOPPED) assert scheduler.empty() @@ -264,7 +264,7 @@ def append_1() -> None: scheduler.EMPTY: 1, }, ) - assert end_status == ExitState(code=Scheduler.ExitCode.EXHAUSTED) + assert end_status == ExitState(code=ExitState.Code.EXHAUSTED) assert scheduler.empty() assert results == [1] * 10 @@ -274,7 +274,7 @@ def test_raise_on_exception_in_task(scheduler: Scheduler) -> None: @scheduler.on_start def run_task() -> None: - task() + task.submit() with pytest.raises(CustomError): scheduler.run(on_exception="raise") @@ -285,10 +285,10 @@ def test_end_on_exception_in_task(scheduler: Scheduler) -> None: @scheduler.on_start def run_task() -> None: - task() + task.submit() end_status = scheduler.run(on_exception="end") - assert end_status.code == Scheduler.ExitCode.EXCEPTION + assert end_status.code == ExitState.Code.EXCEPTION assert isinstance(end_status.exception, CustomError) @@ -297,10 +297,10 @@ def test_dont_end_on_exception_in_task(scheduler: Scheduler) -> None: @scheduler.on_start def run_task() -> None: - task() + task.submit() end_status = scheduler.run(on_exception="ignore") - assert end_status.code == Scheduler.ExitCode.EXHAUSTED + assert end_status.code == ExitState.Code.EXHAUSTED def test_cant_subscribe_to_nonexistent_event(scheduler: Scheduler) -> None: diff --git a/tests/sklearn/test_builder.py b/tests/sklearn/test_builder.py deleted file mode 100644 index 6d2f39a7..00000000 --- a/tests/sklearn/test_builder.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import annotations - -import numpy as np -import pandas as pd -from pytest_cases import parametrize -from sklearn.compose 
import ColumnTransformer, make_column_selector -from sklearn.ensemble import RandomForestClassifier -from sklearn.pipeline import Pipeline as SklearnPipeline -from sklearn.preprocessing import OrdinalEncoder, StandardScaler -from sklearn.svm import SVC - -from amltk.configspace import ConfigSpaceParser -from amltk.pipeline import Pipeline, SpaceAdapter, choice, group, split, step - -# Some toy data -X = pd.DataFrame({"a": ["1", "0", "1", "dog"], "b": [4, 5, 6, 7], "c": [7, 8, 9, 10]}) -y = pd.Series([1, 0, 1, 1]) - - -def test_simple_pipeline() -> None: - # Defining a pipeline - pipeline = Pipeline.create( - step("ordinal", OrdinalEncoder), - step("std", StandardScaler), - step("rf", RandomForestClassifier), - ) - - # Building the pipeline - sklearn_pipeline: SklearnPipeline = pipeline.build() - - # Fitting the pipeline - sklearn_pipeline.fit(X, y) - - # Predicting with the pipeline - sklearn_pipeline.predict(X) - - -def test_passthrough() -> None: - # Defining a pipeline - step("passthrough", "passthrough") - pipeline = Pipeline.create( - step("passthrough", "passthrough"), - split( - "split", - step("a", OrdinalEncoder), - step("b", "passthrough"), - item=ColumnTransformer, - config={ - "a": make_column_selector(dtype_include=object), - "b": make_column_selector(dtype_include=np.number), - }, - ), - ) - - # Building the pipeline - sklearn_pipeline: SklearnPipeline = pipeline.build() - - # Fitting the pipeline - Xt = sklearn_pipeline.fit_transform(X, y) - - # Should ordinal encoder the strings - assert np.array_equal(Xt[:, 0], np.array([1, 0, 1, 2])) - - # Should leave the remaining columns untouched - assert np.array_equal(Xt[:, 1], np.array([4, 5, 6, 7])) - assert np.array_equal(Xt[:, 2], np.array([7, 8, 9, 10])) - - -def test_simple_pipeline_with_group() -> None: - # Defining a pipeline - pipeline = Pipeline.create( - group( - "feature_preprocessing", - step("ordinal", OrdinalEncoder) | step("std", StandardScaler), - ), - step("rf", RandomForestClassifier), - ) - - # Building the pipeline - sklearn_pipeline: SklearnPipeline = pipeline.build() - - # Fitting the pipeline - sklearn_pipeline.fit(X, y) - - # Predicting with the pipeline - sklearn_pipeline.predict(X) - - -@parametrize("adapter", [ConfigSpaceParser()]) -@parametrize("seed", range(10)) -def test_split_with_choice(adapter: SpaceAdapter, seed: int) -> None: - # Defining a pipeline - pipeline = Pipeline.create( - split( - "feature_preprocessing", - group( - "categoricals", - step("ordinal", OrdinalEncoder) | step("std", StandardScaler), - ), - group( - "numericals", - step("scaler", StandardScaler, space={"with_mean": [True, False]}), - ), - item=ColumnTransformer, - config={ - "categoricals": make_column_selector(dtype_include=object), - "numericals": make_column_selector(dtype_include=np.number), - }, - ), - step( - "another_standard_scaler", - StandardScaler, - config={"with_mean": False}, - ), - choice( - "algorithm", - step( - "rf", - item=RandomForestClassifier, - space={ - "n_estimators": [10, 100], - "criterion": ["gini", "entropy", "log_loss"], - }, - ), - step("svm", SVC, space={"C": [0.1, 1, 10]}), - ), - name="test_pipeline_sklearn", - ) - - space = pipeline.space(parser=adapter) - config = pipeline.sample(space=space, sampler=adapter, seed=seed) - configured_pipeline = pipeline.configure(config) - - sklearn_pipeline = configured_pipeline.build() - assert isinstance(sklearn_pipeline, SklearnPipeline) - - sklearn_pipeline = sklearn_pipeline.fit(X, y) - sklearn_pipeline.predict(X) - - -@parametrize("adapter", 
[ConfigSpaceParser()]) -@parametrize("seed", range(10)) -def test_build_module(adapter: SpaceAdapter, seed: int) -> None: - # Defining a pipeline - pipeline = Pipeline.create( - choice( - "algorithm", - step( - "rf", - item=RandomForestClassifier, - space={ - "n_estimators": [10, 100], - "criterion": ["gini", "entropy", "log_loss"], - }, - ), - step("svm", SVC, space={"C": [0.1, 1, 10]}), - ), - name="test_pipeline_sklearn", - ) - submodule_pipeline = pipeline.copy(name="sub") - - pipeline = pipeline.attach(modules=submodule_pipeline) - - space = pipeline.space(parser=adapter) - - config = pipeline.sample(space=space, sampler=adapter, seed=seed) - - configured_pipeline = pipeline.configure(config) - - # Build the pipeline and module - built_pipeline = configured_pipeline.build() - built_sub_pipeline = configured_pipeline.modules["sub"].build() - - assert isinstance(built_pipeline, SklearnPipeline) - assert isinstance(built_sub_pipeline, SklearnPipeline) diff --git a/tests/sklearn/test_data.py b/tests/sklearn/test_data.py index 2b5387a6..b88d4553 100644 --- a/tests/sklearn/test_data.py +++ b/tests/sklearn/test_data.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence import numpy as np import pandas as pd diff --git a/tests/store/test_bucket.py b/tests/store/test_bucket.py index 53e4679e..e259b64c 100644 --- a/tests/store/test_bucket.py +++ b/tests/store/test_bucket.py @@ -2,8 +2,9 @@ import operator import shutil +from collections.abc import Callable, Iterator from pathlib import Path -from typing import Callable, Iterator, Literal, TypeVar +from typing import Any, Literal, TypeVar import numpy as np import pandas as pd @@ -31,7 +32,7 @@ def bucket_path_bucket(tmp_path: Path) -> Iterator[PathBucket]: shutil.rmtree(path) -def unjson_serialisable(x): +def unjson_serialisable(x: Any) -> Any: return x