-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from rodrigo-arenas/develop
ParameterGrid implementation
- Loading branch information
Showing
8 changed files
with
203 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from pyworkforce.utils.grid import ParameterGrid | ||
|
||
__all__ = ["ParameterGrid"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from collections.abc import Mapping, Iterable | ||
from itertools import product | ||
from functools import partial, reduce | ||
import operator | ||
import numpy as np | ||
|
||
|
||
class ParameterGrid: | ||
""" | ||
This implementation is taken from scikit-learn: https://github.com/scikit-learn/scikit-learn | ||
Grid of parameters with a discrete number of values for each. | ||
Can be used to iterate over parameter value combinations with the | ||
Python built-in function iter. | ||
The order of the generated parameter combinations is deterministic. | ||
Read more in the :ref:`User Guide <grid_search>`. | ||
Parameters | ||
---------- | ||
param_grid : dict of str to sequence, or sequence of such | ||
The parameter grid to explore, as a dictionary mapping estimator | ||
parameters to sequences of allowed values. | ||
An empty dict signifies default parameters. | ||
A sequence of dicts signifies a sequence of grids to search, and is | ||
useful to avoid exploring parameter combinations that make no sense | ||
or have no effect. See the examples below. | ||
""" | ||
|
||
def __init__(self, param_grid): | ||
if not isinstance(param_grid, (Mapping, Iterable)): | ||
raise TypeError('Parameter grid is not a dict or ' | ||
'a list ({!r})'.format(param_grid)) | ||
|
||
if isinstance(param_grid, Mapping): | ||
# wrap dictionary in a singleton list to support either dict | ||
# or list of dicts | ||
param_grid = [param_grid] | ||
|
||
# check if all entries are dictionaries of lists | ||
for grid in param_grid: | ||
if not isinstance(grid, dict): | ||
raise TypeError('Parameter grid is not a ' | ||
'dict ({!r})'.format(grid)) | ||
for key in grid: | ||
if not isinstance(grid[key], Iterable): | ||
raise TypeError('Parameter grid value is not iterable ' | ||
'(key={!r}, value={!r})' | ||
.format(key, grid[key])) | ||
|
||
self.param_grid = param_grid | ||
|
||
def __iter__(self): | ||
"""Iterate over the points in the grid. | ||
Returns | ||
------- | ||
params : iterator over dict of str to any | ||
Yields dictionaries mapping each estimator parameter to one of its | ||
allowed values. | ||
""" | ||
for p in self.param_grid: | ||
# Always sort the keys of a dictionary, for reproducibility | ||
items = sorted(p.items()) | ||
if not items: | ||
yield {} | ||
else: | ||
keys, values = zip(*items) | ||
for v in product(*values): | ||
params = dict(zip(keys, v)) | ||
yield params | ||
|
||
def __len__(self): | ||
"""Number of points on the grid.""" | ||
# Product function that can handle iterables (np.product can't). | ||
product = partial(reduce, operator.mul) | ||
return sum(product(len(v) for v in p.values()) if p else 1 | ||
for p in self.param_grid) | ||
|
||
def __getitem__(self, ind): | ||
"""Get the parameters that would be ``ind``th in iteration | ||
Parameters | ||
---------- | ||
ind : int | ||
The iteration index | ||
Returns | ||
------- | ||
params : dict of str to any | ||
Equal to list(self)[ind] | ||
""" | ||
# This is used to make discrete sampling without replacement memory | ||
# efficient. | ||
for sub_grid in self.param_grid: | ||
# XXX: could memoize information used here | ||
if not sub_grid: | ||
if ind == 0: | ||
return {} | ||
else: | ||
ind -= 1 | ||
continue | ||
|
||
# Reverse so most frequent cycling parameter comes first | ||
keys, values_lists = zip(*sorted(sub_grid.items())[::-1]) | ||
sizes = [len(v_list) for v_list in values_lists] | ||
total = np.product(sizes) | ||
|
||
if ind >= total: | ||
# Try the next grid | ||
ind -= total | ||
else: | ||
out = {} | ||
for key, v_list, n in zip(keys, values_lists, sizes): | ||
ind, offset = divmod(ind, n) | ||
out[key] = v_list[offset] | ||
return out | ||
|
||
raise IndexError('ParameterGrid index out of range') |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from pyworkforce.utils import ParameterGrid | ||
from collections.abc import Iterable, Sized | ||
from itertools import chain, product | ||
import pytest | ||
|
||
|
||
def assert_grid_iter_equals_getitem(grid): | ||
assert list(grid) == [grid[i] for i in range(len(grid))] | ||
|
||
|
||
def test_parameter_grid(): | ||
""" | ||
Test taken from scikit-learn | ||
""" | ||
# Test basic properties of ParameterGrid. | ||
params1 = {"foo": [1, 2, 3]} | ||
grid1 = ParameterGrid(params1) | ||
assert isinstance(grid1, Iterable) | ||
assert isinstance(grid1, Sized) | ||
assert len(grid1) == 3 | ||
assert_grid_iter_equals_getitem(grid1) | ||
|
||
params2 = {"foo": [4, 2], | ||
"bar": ["ham", "spam", "eggs"]} | ||
grid2 = ParameterGrid(params2) | ||
assert len(grid2) == 6 | ||
|
||
# loop to assert we can iterate over the grid multiple times | ||
for i in range(2): | ||
# tuple + chain transforms {"a": 1, "b": 2} to ("a", 1, "b", 2) | ||
points = set(tuple(chain(*(sorted(p.items())))) for p in grid2) | ||
assert (points == | ||
set(("bar", x, "foo", y) | ||
for x, y in product(params2["bar"], params2["foo"]))) | ||
assert_grid_iter_equals_getitem(grid2) | ||
|
||
# Special case: empty grid (useful to get default estimator settings) | ||
empty = ParameterGrid({}) | ||
assert len(empty) == 1 | ||
assert list(empty) == [{}] | ||
assert_grid_iter_equals_getitem(empty) | ||
with pytest.raises(IndexError): | ||
empty[1] | ||
|
||
has_empty = ParameterGrid([{'C': [1, 10]}, {}, {'C': [.5]}]) | ||
assert len(has_empty) == 4 | ||
assert list(has_empty) == [{'C': 1}, {'C': 10}, {}, {'C': .5}] | ||
assert_grid_iter_equals_getitem(has_empty) | ||
|
||
|
||
def test_non_iterable_parameter_grid(): | ||
params = {"foo": 4, | ||
"bar": ["ham", "spam", "eggs"]} | ||
with pytest.raises(Exception) as excinfo: | ||
grid = ParameterGrid(params) | ||
for grid in params: | ||
for key in grid: | ||
assert str(excinfo.value) == 'Parameter grid value is not iterable ' | ||
'(key={!r}, value={!r})'.format(key, grid[key]) | ||
|
||
|
||
def test_non_dict_parameter_grid(): | ||
params = [4] | ||
with pytest.raises(Exception) as excinfo: | ||
grid = ParameterGrid(params) | ||
for grid in params: | ||
assert str(excinfo.value) == 'Parameter grid is not a dict ({!r})'.format(grid) | ||
|
||
|
||
def test_wrong_parameter_grid(): | ||
params = 4 | ||
with pytest.raises(Exception) as excinfo: | ||
grid = ParameterGrid(params) | ||
for grid in params: | ||
assert str(excinfo.value) == 'Parameter grid is not a dict or a list ({!r})'.format(params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters