Skip to content

Commit

Permalink
feat(Optuna): Allow for parsing of Choice Nodes (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
berombau authored Oct 31, 2024
1 parent b680838 commit 5221d80
Show file tree
Hide file tree
Showing 4 changed files with 367 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/amltk/optimization/optimizers/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def add_to_history(_, report: Trial.Report):
Sorry!
""" # noqa: E501

from __future__ import annotations

from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -291,8 +292,7 @@ def ask(
"""
if n is not None:
return (self.ask(n=None) for _ in range(n))

optuna_trial: optuna.Trial = self.study.ask(self.space)
optuna_trial = self.space.get_trial(self.study)
config = optuna_trial.params
trial_number = optuna_trial.number
unique_name = f"{trial_number=}"
Expand Down
112 changes: 100 additions & 12 deletions src/amltk/pipeline/parsers/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np
import optuna
from optuna.distributions import (
BaseDistribution,
CategoricalChoiceType,
Expand All @@ -103,17 +105,95 @@
)

from amltk._functional import prefix_keys
from amltk.pipeline.components import Choice

if TYPE_CHECKING:
from typing import TypeAlias

from amltk.pipeline import Node

OptunaSearchSpace: TypeAlias = dict[str, BaseDistribution]

PAIR = 2


@dataclass
class OptunaSearchSpace:
"""A class to represent an Optuna search space.
Wraps a dictionary of hyperparameters and their Optuna distributions.
"""

distributions: dict[str, BaseDistribution] = field(default_factory=dict)

def __repr__(self) -> str:
return f"OptunaSearchSpace({self.distributions})"

def __str__(self) -> str:
return str(self.distributions)

@classmethod
def parse(cls, *args: Any, **kwargs: Any) -> OptunaSearchSpace:
"""Parse a Node into an Optuna search space."""
return parser(*args, **kwargs)

def sample_configuration(self) -> dict[str, Any]:
"""Sample a configuration from the search space using a default Optuna Study."""
study = optuna.create_study()
trial = self.get_trial(study)
return trial.params

def get_trial(self, study: optuna.Study) -> optuna.Trial:
"""Get a trial from a given Optuna Study using this search space."""
optuna_trial: optuna.Trial
if any("__choice__" in k for k in self.distributions):
optuna_trial = study.ask()
# do all __choice__ suggestions with suggest_categorical
workspace = self.distributions.copy()
filter_patterns = []
for name, distribution in workspace.items():
if "__choice__" in name and isinstance(
distribution,
CategoricalDistribution,
):
possible_choices = distribution.choices
choice_made = optuna_trial.suggest_categorical(
name,
choices=possible_choices,
)
for c in possible_choices:
if c != choice_made:
# deletable options have the name of the unwanted choices
filter_patterns.append(f":{c}:")
# filter all parameters for the unwanted choices
filtered_workspace = {
k: v
for k, v in workspace.items()
if (
("__choice__" not in k)
and (
not any(
filter_pattern in k for filter_pattern in filter_patterns
)
)
)
}
# do all remaining suggestions with the correct suggest function
for name, distribution in filtered_workspace.items():
match distribution:
case CategoricalDistribution(choices=choices):
optuna_trial.suggest_categorical(name, choices=choices)
case IntDistribution(
low=low,
high=high,
log=log,
):
optuna_trial.suggest_int(name, low=low, high=high, log=log)
case FloatDistribution(low=low, high=high):
optuna_trial.suggest_float(name, low=low, high=high)
case _:
raise ValueError(f"Unknown distribution: {distribution}")
else:
optuna_trial = study.ask(self.distributions)
return optuna_trial


def _convert_hp_to_optuna_distribution(
name: str,
hp: tuple | Sequence | CategoricalChoiceType | BaseDistribution,
Expand Down Expand Up @@ -149,7 +229,7 @@ def _convert_hp_to_optuna_distribution(
raise ValueError(f"Could not parse {name} as a valid Optuna distribution.\n{hp=}")


def _parse_space(node: Node) -> OptunaSearchSpace:
def _parse_space(node: Node) -> dict[str, BaseDistribution]:
match node.space:
case None:
space = {}
Expand Down Expand Up @@ -196,13 +276,21 @@ def parser(
delim: The delimiter to use for the names of the hyperparameters.
"""
if conditionals:
raise NotImplementedError("Conditionals are not yet supported with Optuna.")

space = prefix_keys(_parse_space(node), prefix=f"{node.name}{delim}")

for child in node.nodes:
subspace = parser(child, flat=flat, conditionals=conditionals, delim=delim)
children = node.nodes

if isinstance(node, Choice) and any(children):
name = f"{node.name}{delim}__choice__"
space[name] = CategoricalDistribution([child.name for child in children])

for child in children:
subspace = parser(
child,
flat=flat,
conditionals=conditionals,
delim=delim,
).distributions
if not flat:
subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}")

Expand All @@ -214,4 +302,4 @@ def parser(
)
space[name] = hp

return space
return OptunaSearchSpace(distributions=space)
47 changes: 47 additions & 0 deletions tests/optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from amltk.optimization import Metric, Optimizer, Trial
from amltk.pipeline import Component
from amltk.pipeline.components import Choice
from amltk.profiling import Timer

if TYPE_CHECKING:
Expand All @@ -24,6 +25,10 @@ class _A:
pass


class _B:
pass


metrics = [
Metric("score_bounded", minimize=False, bounds=(0, 1)),
Metric("score_unbounded", minimize=False),
Expand Down Expand Up @@ -87,6 +92,25 @@ def opt_optuna(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
)


@case
@parametrize("metric", [*metrics, metrics])  # Single obj and multi
def opt_optuna_choice_hierarchical(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
    """Build an Optuna optimizer over a `Choice` between two components.

    The test is skipped when the Optuna optimizer cannot be imported.
    """
    try:
        from amltk.optimization.optimizers.optuna import OptunaOptimizer
    except ImportError:
        pytest.skip("Optuna is not installed")

    # A two-way choice named "hi"; each option carries one categorical
    # hyperparameter of its own.
    choice_pipeline = Choice(
        Component(_A, name="hi1", space={"a": [1, 2, 3]}),
        Component(_B, name="hi2", space={"b": [4, 5, 6]}),
        name="hi",
    )
    return OptunaOptimizer.create(
        space=choice_pipeline,
        metrics=metric,
        seed=42,
        bucket=tmp_path,
    )


@case
@parametrize("metric", [*metrics]) # Single obj
def opt_neps(metric: Metric, tmp_path: Path) -> NEPSOptimizer:
Expand Down Expand Up @@ -142,3 +166,26 @@ def test_batched_ask_generates_unique_configs(optimizer: Optimizer):
batch = list(optimizer.ask(10))
assert len(batch) == 10
assert all_unique(batch)


@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice")
def test_optuna_choice_output(optimizer: Optimizer):
    """An asked trial's config must contain at least one ``__choice__`` key."""
    trial = optimizer.ask()
    has_choice_key = any("__choice__" in key for key in trial.config)
    assert has_choice_key, trial.config


@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice")
def test_optuna_choice_no_params_left(optimizer: Optimizer):
    """Parameters sharing a choice's prefix must all name the selected option."""
    trial = optimizer.ask()
    config = trial.config
    regular_keys = [key for key in config if "__choice__" not in key]

    for choice_key, selected in config.items():
        if "__choice__" not in choice_key:
            continue
        # Everything before the "__choice__" suffix is the prefix shared by
        # the parameters of this choice's options.
        prefix = choice_key.removesuffix("__choice__")
        remaining = [key for key in regular_keys if key.startswith(prefix)]
        # Check that only params for the chosen choice are left
        assert all(selected in key for key in remaining), remaining
Loading

0 comments on commit 5221d80

Please sign in to comment.