-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from fugue-project/0.0.3.1
Add fugue_tune
- Loading branch information
Showing
17 changed files
with
950 additions
and
4 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import copy | ||
import inspect | ||
from typing import Any, Callable, Dict, List, Optional, no_type_check | ||
|
||
from fugue import ExecutionEngine | ||
from fugue._utils.interfaceless import ( | ||
FunctionWrapper, | ||
_ExecutionEngineParam, | ||
_FuncParam, | ||
_OtherParam, | ||
is_class_method, | ||
) | ||
from triad import assert_or_throw | ||
from triad.utils.convert import get_caller_global_local_vars, to_function | ||
|
||
from fugue_tune.exceptions import FugueTuneCompileError | ||
from fugue_tune.tunable import SimpleTunable, Tunable | ||
|
||
|
||
def tunable(distributable: bool = True) -> Callable[[Any], "_FuncAsTunable"]: | ||
def deco(func: Callable) -> "_FuncAsTunable": | ||
assert_or_throw( | ||
not is_class_method(func), | ||
NotImplementedError("tunable decorator can't be used on class methods"), | ||
) | ||
return _FuncAsTunable.from_func(func) | ||
|
||
return deco | ||
|
||
|
||
def _to_tunable( | ||
obj: Any, | ||
global_vars: Optional[Dict[str, Any]] = None, | ||
local_vars: Optional[Dict[str, Any]] = None, | ||
distributable: Optional[bool] = None, | ||
) -> Tunable: | ||
global_vars, local_vars = get_caller_global_local_vars(global_vars, local_vars) | ||
|
||
def get_tunable() -> Tunable: | ||
if isinstance(obj, Tunable): | ||
return copy.copy(obj) | ||
try: | ||
f = to_function(obj, global_vars=global_vars, local_vars=local_vars) | ||
# this is for string expression of function with decorator | ||
if isinstance(f, Tunable): | ||
return copy.copy(f) | ||
# this is for functions without decorator | ||
return _FuncAsTunable.from_func(f, distributable) | ||
except Exception as e: | ||
exp = e | ||
raise FugueTuneCompileError(f"{obj} is not a valid tunable function", exp) | ||
|
||
t = get_tunable() | ||
if distributable is None: | ||
distributable = t.distributable | ||
elif distributable: | ||
assert_or_throw( | ||
t.distributable, FugueTuneCompileError(f"{t} is not distributable") | ||
) | ||
return t | ||
|
||
|
||
class _SingleParam(_FuncParam): | ||
def __init__(self, param: Optional[inspect.Parameter]): | ||
super().__init__(param, "float", "s") | ||
|
||
|
||
class _DictParam(_FuncParam): | ||
def __init__(self, param: Optional[inspect.Parameter]): | ||
super().__init__(param, "Dict[str,Any]", "d") | ||
|
||
|
||
class _TunableWrapper(FunctionWrapper): | ||
def __init__(self, func: Callable): | ||
super().__init__(func, "^e?[^e]+$", "^[sd]$") | ||
|
||
def _parse_param( | ||
self, | ||
annotation: Any, | ||
param: Optional[inspect.Parameter], | ||
none_as_other: bool = True, | ||
) -> _FuncParam: | ||
if annotation is float: | ||
return _SingleParam(param) | ||
elif annotation is Dict[str, Any]: | ||
return _DictParam(param) | ||
elif annotation is ExecutionEngine: | ||
return _ExecutionEngineParam(param) | ||
else: | ||
return _OtherParam(param) | ||
|
||
@property | ||
def single(self) -> bool: | ||
return isinstance(self._rt, _SingleParam) | ||
|
||
@property | ||
def needs_engine(self) -> bool: | ||
return isinstance(self._params.get_value_by_index(0), _ExecutionEngineParam) | ||
|
||
|
||
class _FuncAsTunable(SimpleTunable): | ||
@no_type_check | ||
def tune(self, **kwargs: Any) -> Dict[str, Any]: | ||
# pylint: disable=no-member | ||
args: List[Any] = [self.execution_engine] if self._needs_engine else [] | ||
if self._single: | ||
return dict(error=self._func(*args, **kwargs)) | ||
else: | ||
return self._func(*args, **kwargs) | ||
|
||
@no_type_check | ||
def __call__(self, *args: Any, **kwargs: Any) -> Any: | ||
return self._func(*args, **kwargs) | ||
|
||
@property | ||
def distributable(self) -> bool: | ||
return self._distributable # type: ignore | ||
|
||
@no_type_check | ||
@staticmethod | ||
def from_func( | ||
func: Callable, distributable: Optional[bool] = None | ||
) -> "_FuncAsTunable": | ||
t = _FuncAsTunable() | ||
tw = _TunableWrapper(func) | ||
t._func = tw._func | ||
t._single = tw.single | ||
t._needs_engine = tw.needs_engine | ||
if distributable is None: | ||
t._distributable = not tw.needs_engine | ||
else: | ||
if distributable: | ||
assert_or_throw( | ||
not tw.needs_engine, | ||
"function with ExecutionEngine can't be distributable", | ||
) | ||
t._distributable = distributable | ||
|
||
return t |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import Any | ||
|
||
from fugue.exceptions import FugueWorkflowCompileError, FugueWorkflowRuntimeError | ||
|
||
|
||
class FugueTuneCompileError(FugueWorkflowCompileError): | ||
def __init__(self, *args: Any): | ||
super().__init__(*args) | ||
|
||
|
||
class FugueTuneRuntimeError(FugueWorkflowRuntimeError): | ||
def __init__(self, *args: Any): | ||
super().__init__(*args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import itertools | ||
from typing import Any, Dict, Iterable, List, Tuple | ||
|
||
|
||
def dict_product( | ||
d: Dict[str, Iterable[Any]], safe: bool = True | ||
) -> Iterable[Dict[str, Any]]: | ||
keys = d.keys() | ||
arrays = list(d.values()) | ||
if len(arrays) == 0: | ||
if safe: | ||
yield {} | ||
return | ||
for element in _safe_product(arrays, safe): | ||
yield {k: v for k, v in zip(keys, element) if v is not _EMPTY_ITER} | ||
|
||
|
||
def product( | ||
arrays: List[Iterable[Any]], safe: bool = False, remove_empty: bool = True | ||
) -> Iterable[List[Any]]: | ||
if len(arrays) == 0: | ||
if safe: | ||
yield [] | ||
return | ||
if remove_empty: | ||
for x in _safe_product(arrays, safe): | ||
yield [xx for xx in x if xx is not _EMPTY_ITER] | ||
else: | ||
for x in _safe_product(arrays, safe): | ||
yield [None if xx is _EMPTY_ITER else xx for xx in x] | ||
|
||
|
||
def safe_iter(it: Iterable[Any], safe: bool = True) -> Iterable[Any]: | ||
if not safe: | ||
yield from it | ||
else: | ||
n = 0 | ||
for x in it: | ||
yield x | ||
n += 1 | ||
if n == 0: | ||
yield _EMPTY_ITER | ||
|
||
|
||
def _safe_product(arrays: List[Iterable[Any]], safe: bool = True) -> Iterable[Tuple]: | ||
if not safe: | ||
yield from itertools.product(*arrays) | ||
else: | ||
arr = [safe_iter(t) for t in arrays] | ||
yield from itertools.product(*arr) | ||
|
||
|
||
class _EmptyIter(object): | ||
pass | ||
|
||
|
||
_EMPTY_ITER = _EmptyIter() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from copy import deepcopy | ||
from typing import Any, Dict, Iterable, List, Tuple, no_type_check | ||
|
||
from fugue_tune.iter import dict_product, product | ||
|
||
|
||
class Grid(object): | ||
def __init__(self, *args: Any): | ||
self._values = list(args) | ||
|
||
def __iter__(self) -> Iterable[Any]: | ||
yield from self._values | ||
|
||
|
||
class Choice(object): | ||
def __init__(self, *args: Any): | ||
self._values = list(args) | ||
|
||
def __iter__(self) -> Iterable[Any]: | ||
yield from self._values | ||
|
||
|
||
class Rand(object): | ||
def __init__(self, start: float, end: float, q: float, log: bool, normal: bool): | ||
self._start = start | ||
self._end = end | ||
self._q = q | ||
self._log = log | ||
self._normal = normal | ||
|
||
|
||
# TODO: make this inherit from iterable? | ||
class Space(object): | ||
def __init__(self, **kwargs: Any): | ||
self._value = deepcopy(kwargs) | ||
self._grid: List[List[Tuple[Any, Any, Any]]] = [] | ||
for k in self._value.keys(): | ||
self._search(self._value, k) | ||
|
||
def __iter__(self) -> Iterable[Dict[str, Any]]: | ||
for tps in product(self._grid, safe=True, remove_empty=True): # type: ignore | ||
for tp in tps: | ||
tp[0][tp[1]] = tp[2] | ||
yield deepcopy(self._value) | ||
|
||
def __mul__(self, other: Any) -> "HorizontalSpace": | ||
return HorizontalSpace(self, other) | ||
|
||
def __add__(self, other: Any) -> "VerticalSpace": | ||
return VerticalSpace(self, other) | ||
|
||
def _search(self, parent: Any, key: Any) -> None: | ||
node = parent[key] | ||
if isinstance(node, Grid): | ||
self._grid.append(self._grid_wrapper(parent, key)) | ||
elif isinstance(node, dict): | ||
for k in node.keys(): | ||
self._search(node, k) | ||
elif isinstance(node, list): | ||
for i in range(len(node)): | ||
self._search(node, i) | ||
|
||
def _grid_wrapper(self, parent: Any, key: Any) -> List[Tuple[Any, Any, Any]]: | ||
return [(parent, key, x) for x in parent[key]] | ||
|
||
|
||
class HorizontalSpace(Space): | ||
def __init__(self, *args: Any, **kwargs: Any): | ||
self._groups: List[VerticalSpace] = [] | ||
for x in args: | ||
if isinstance(x, HorizontalSpace): | ||
self._groups.append(VerticalSpace(x)) | ||
elif isinstance(x, VerticalSpace): | ||
self._groups.append(x) | ||
elif isinstance(x, Space): | ||
self._groups.append(VerticalSpace(x)) | ||
elif isinstance(x, dict): | ||
self._groups.append(VerticalSpace(HorizontalSpace(**x))) | ||
elif isinstance(x, list): | ||
self._groups.append(VerticalSpace(*x)) | ||
else: | ||
raise ValueError(f"{x} is invalid") | ||
self._dict = {k: _SpaceValue(v) for k, v in kwargs.items()} | ||
|
||
@no_type_check # TODO: remove this? | ||
def __iter__(self) -> Iterable[Dict[str, Any]]: | ||
dicts = list(dict_product(self._dict, safe=True)) | ||
for spaces in product( | ||
[g.spaces for g in self._groups], safe=True, remove_empty=True | ||
): | ||
for comb in product(list(spaces) + [dicts], safe=True, remove_empty=True): | ||
res: Dict[str, Any] = {} | ||
for d in comb: | ||
res.update(d) | ||
yield res | ||
|
||
|
||
class VerticalSpace(Space): | ||
def __init__(self, *args: Any): | ||
self._spaces: List[Space] = [] | ||
for x in args: | ||
if isinstance(x, Space): | ||
self._spaces.append(x) | ||
elif isinstance(x, dict): | ||
self._spaces.append(Space(**x)) | ||
elif isinstance(x, list): | ||
self._spaces.append(VerticalSpace(*x)) | ||
else: | ||
raise ValueError(f"{x} is invalid") | ||
|
||
@property | ||
def spaces(self) -> List[Space]: | ||
return self._spaces | ||
|
||
def __iter__(self) -> Iterable[Dict[str, Any]]: | ||
for space in self._spaces: | ||
yield from space # type: ignore | ||
|
||
|
||
class _SpaceValue(object): | ||
def __init__(self, value: Any): | ||
self.value = value | ||
|
||
@no_type_check # TODO: remove this? | ||
def __iter__(self) -> Iterable[Any]: | ||
if isinstance(self.value, (HorizontalSpace, VerticalSpace)): | ||
yield from self.value | ||
elif isinstance(self.value, dict): | ||
yield from dict_product( | ||
{k: _SpaceValue(v) for k, v in self.value.items()}, safe=True | ||
) | ||
elif isinstance(self.value, list): | ||
yield from product([_SpaceValue(v) for v in self.value], safe=True) | ||
else: | ||
yield self.value |
Oops, something went wrong.