-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(Pipeline): Optimize pipelines directly with `optimize()` (#230)
- Loading branch information
1 parent
4198de7
commit bded378
Showing
14 changed files
with
1,108 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
from amltk._richutil.renderable import RichRenderable | ||
from amltk._richutil.renderers import Function, rich_make_column_selector | ||
from amltk._richutil.util import df_to_table, richify | ||
from amltk._richutil.util import df_to_table, is_jupyter, richify | ||
|
||
__all__ = [ | ||
"df_to_table", | ||
"richify", | ||
"RichRenderable", | ||
"Function", | ||
"rich_make_column_selector", | ||
"is_jupyter", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
|
||
def threadpoolctl_heuristic(item_contained_in_node: Any | None) -> bool:
    """Heuristic to determine if we should automatically set threadpoolctl.

    This is done by detecting if it's a scikit-learn `BaseEstimator` but this may
    be extended in the future.

    !!! tip

        The reason to have this heuristic is that when running scikit-learn, or any
        multithreaded model, in parallel, they will oversubscribe threads. This
        causes a significant performance hit as most of the time is spent switching
        thread contexts instead of doing work. This can be particularly bad for HPO
        where we are evaluating multiple models in parallel on the same system.

        The recommended thread count is 1 per core with no additional information to
        act upon.

    !!! todo

        This is potentially not an issue if running on multiple nodes of some cluster,
        as they do not share logical cores and hence do not clash.

    Args:
        item_contained_in_node: The item on which to base the heuristic.

    Returns:
        Whether we should automatically set threadpoolctl. This is `False` when the
        item is `None`, is not a class, or when scikit-learn cannot be imported.
    """
    # `isinstance(None, type)` is False, so this single check covers a missing
    # item as well as instances (only classes are considered).
    if not isinstance(item_contained_in_node, type):
        return False

    # Keep the `try` body limited to the import so that only a missing
    # scikit-learn is treated as "no".
    try:
        # NOTE: sklearn depends on threadpoolctl so it will be installed.
        from sklearn.base import BaseEstimator
    except ImportError:
        return False

    return issubclass(item_contained_in_node, BaseEstimator)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""Evaluation protocols for how a trial and a pipeline should be evaluated. | ||
TODO: Sorry | ||
""" | ||
from __future__ import annotations | ||
|
||
from collections.abc import Callable, Iterable | ||
from typing import TYPE_CHECKING | ||
|
||
from amltk.scheduling import Plugin | ||
|
||
if TYPE_CHECKING: | ||
from amltk.optimization import Trial | ||
from amltk.pipeline import Node | ||
from amltk.scheduling import Scheduler, Task | ||
|
||
|
||
class EvaluationProtocol: | ||
"""A protocol for how a trial should be evaluated on a pipeline.""" | ||
|
||
fn: Callable[[Trial, Node], Trial.Report] | ||
|
||
def task( | ||
self, | ||
scheduler: Scheduler, | ||
plugins: Plugin | Iterable[Plugin] | None = None, | ||
) -> Task[[Trial, Node], Trial.Report]: | ||
"""Create a task for this protocol. | ||
Args: | ||
scheduler: The scheduler to use for the task. | ||
plugins: The plugins to use for the task. | ||
Returns: | ||
The created task. | ||
""" | ||
_plugins: tuple[Plugin, ...] | ||
match plugins: | ||
case None: | ||
_plugins = () | ||
case Plugin(): | ||
_plugins = (plugins,) | ||
case Iterable(): | ||
_plugins = tuple(plugins) | ||
|
||
return scheduler.task(self.fn, plugins=_plugins) | ||
|
||
|
||
class CustomProtocol(EvaluationProtocol):
    """An evaluation protocol that wraps a user-supplied function."""

    def __init__(self, fn: Callable[[Trial, Node], Trial.Report]) -> None:
        """Store the evaluation function on the protocol.

        Args:
            fn: The callable used for the evaluation; it is kept as `self.fn`.
        """
        super().__init__()
        self.fn = fn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.