Skip to content

Commit

Permalink
Merge pull request #39 from zafarali/strategies_real_master
Browse files Browse the repository at this point in the history
Clean up for strategies
  • Loading branch information
williamFalcon authored Nov 21, 2018
2 parents 0703b08 + 9f40cb1 commit 5b49351
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 10 deletions.
59 changes: 49 additions & 10 deletions test_tube/hyper_opt_utils/strategies.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,81 @@
"""Hyperparameter search strategies."""
import itertools
import json
import random


def generate_trials(strategy, flat_params, nb_trials=None):
"""
r"""Generates the parameter combinations to search.
Two search strategies are implemented:
1. `grid_search`: creates a search space that consists of the
product of all flat_params. If `nb_trials` is specified
the first `nb_trials` combinations are searched.
2. `random_search`: Creates random combinations of the
hyperparameters. Can be used for a more efficient search.
See (Bergstra and Bengio, 2012) for more details.
:param strategy: The hyperparameter search to strategy. Can be
one of: {`grid_search`, `random`}.
:param flat_params: The hyperparameter arguments to iterate over.
:param nb_trials: The number of hyperparameter combinations to try.
Generates the parameter combinations for each requested trial
:param strategy:
:param flat_params:
:param nb_trials: The number of trials to un.
:return:
"""
# permute for grid search
if strategy == 'grid_search':
trials = generate_grid_search_trials(flat_params, nb_trials)
return trials

# generate random search
if strategy == 'random_search':
elif strategy == 'random_search':
trials = generate_random_search_trials(flat_params, nb_trials)
return trials
else:
raise ValueError(
('Unknown strategy "{}". Must be one of '
'{{grid_search, random_search}}').format(strategy))


def generate_grid_search_trials(flat_params, nb_trials):
"""
Standard grid search. Takes the product of `flat_params`
to generate the search space.
:param params: The hyperparameters options to search.
:param nb_trials: Returns the first `nb_trials` from the
combinations space. If this is None, all combinations
are returned.
:return: A dict containing the hyperparameters.
"""
trials = list(itertools.product(*flat_params))
if nb_trials:
trials = trials[0:nb_trials]
return trials


def generate_random_search_trials(params, nb_trials):
"""
Generates random combination of hyperparameters to try.
See (Bergstra and Bengio, 2012) for more details.
:param params: The hyperparameters options to search.
:param nb_trials: The number of trials to run.
:return: A dict containing the hyperparameters.
"""
if nb_trials is None:
raise TypeError(
'`random_search` strategy requires nb_trails to be an int.')
results = []

# ensures we have unique results
seen_trials = set()

# shuffle each param list
potential_trials = 1
for p in params:
random.shuffle(p)
potential_trials *= len(p)
for param in params:
random.shuffle(param)
potential_trials *= len(param)

# we can't sample more trials than are possible
max_iters = min(potential_trials, nb_trials)
Expand All @@ -46,8 +85,8 @@ def generate_random_search_trials(params, nb_trials):
while len(results) < max_iters:
trial = []
for param in params:
p = random.sample(param, 1)[0]
trial.append(p)
sampled_param = random.sample(param, 1)[0]
trial.append(sampled_param)

# verify this is a unique trial so we
# don't duplicate work
Expand Down
45 changes: 45 additions & 0 deletions tests/strategies_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from test_tube.hyper_opt_utils import strategies

GRID_SEARCH = 'grid_search'
RANDOM_SEARCH = 'random_search'

FLAT_PARAMS = [
[
{'idx': 0, 'val': 0.0001, 'name': 'learning_rate'},
{'idx': 1, 'val': 0.001, 'name': 'learning_rate'},
{'idx': 2, 'val': 0.01, 'name': 'learning_rate'},
{'idx': 3, 'val': 0.1, 'name': 'learning_rate'}
],
[
{'idx': 4, 'val': 0.99, 'name': 'decay'},
{'idx': 5, 'val': 0.999, 'name': 'decay'},
]
]
def test_unknown_strategy():
with pytest.raises(ValueError):
strategies.generate_trials(
'unknown_strategy', FLAT_PARAMS, nb_trials=None)

def test_grid_search_no_limit():
trials = strategies.generate_trials(
GRID_SEARCH, FLAT_PARAMS, nb_trials=None)
assert len(trials) == len(FLAT_PARAMS[0]) * len(FLAT_PARAMS[1])

def test_grid_search_limit():
trials = strategies.generate_trials(
GRID_SEARCH, FLAT_PARAMS, nb_trials=5)
assert len(trials) == 5


def test_random_search():
trials = strategies.generate_trials(
RANDOM_SEARCH, FLAT_PARAMS, nb_trials=5)
assert len(trials) == 5

def test_random_search_unbounded_error():
with pytest.raises(TypeError):
trials = strategies.generate_trials(
RANDOM_SEARCH, FLAT_PARAMS, nb_trials=None)

0 comments on commit 5b49351

Please sign in to comment.