Skip to content

Commit

Permalink
Merge pull request #25 from rodrigo-arenas/develop
Browse files Browse the repository at this point in the history
ParameterGrid  implementation
  • Loading branch information
rodrigo-arenas authored Apr 9, 2021
2 parents 0107124 + 6722faa commit 228af30
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 12 deletions.
8 changes: 4 additions & 4 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ pytest==6.2.2
codecov==2.1.11
pytest-cov==2.11.1
twine==3.3.0
numpy>=1.18.1
numpy~=1.18.1
ortools>=7.8.7959
pandas>=1.0.0
scikit-learn>=0.20.0
joblib>=0.11
pandas~=1.0.0
joblib~=0.11

4 changes: 2 additions & 2 deletions examples/queing/multierlang.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

param_grid = {"transactions": [100, 200], "aht": [3], "interval": [30], "asa": [20 / 60], "shrinkage": [0.3]}
required_positions_scenarios = {"service_level": [0.85, 0.9], "max_occupancy": [0.8]}
service_level_scenarios= {"positions": [10, 20, 30], "scale_positions": [True, False]}
service_level_scenarios = {"positions": [10, 20, 30], "scale_positions": [True, False]}

multi_erlang = MultiErlangC(param_grid=param_grid, n_jobs=-1)

Expand All @@ -18,4 +18,4 @@
required_positions_scenarios = {"service_level": [0.7, 0.85, 0.9], "max_occupancy": [0.8]}

positions_requirements = multi_erlang.required_positions(required_positions_scenarios)
print("positions_requirements: ", positions_requirements)
print("positions_requirements: ", positions_requirements)
6 changes: 3 additions & 3 deletions pyworkforce/queuing/erlang.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from math import exp, ceil, floor
from sklearn.model_selection import ParameterGrid
from pyworkforce.utils import ParameterGrid
from joblib import Parallel, delayed


Expand Down Expand Up @@ -132,8 +132,8 @@ def required_positions(self, service_level: float, max_occupancy: float = 1.0):

class MultiErlangC:
"""
This class uses de ErlangC class using joblib's Parallel allowing to run multiples scenarios with one class It
finds solutions iterating over all possible combinations provided by the users, inspired how Sklearn's Grid
This class uses de ErlangC class using joblib's Parallel allowing to run multiples scenarios with one class.
It finds solutions iterating over all possible combinations provided by the users, inspired how Sklearn's Grid
Search works
"""

Expand Down
3 changes: 3 additions & 0 deletions pyworkforce/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pyworkforce.utils.grid import ParameterGrid

__all__ = ["ParameterGrid"]
113 changes: 113 additions & 0 deletions pyworkforce/utils/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from collections.abc import Mapping, Iterable
from itertools import product
from functools import partial, reduce
import operator
import numpy as np


class ParameterGrid:
"""
This implementation is taken from scikit-learn: https://github.com/scikit-learn/scikit-learn
Grid of parameters with a discrete number of values for each.
Can be used to iterate over parameter value combinations with the
Python built-in function iter.
The order of the generated parameter combinations is deterministic.
Read more in the :ref:`User Guide <grid_search>`.
Parameters
----------
param_grid : dict of str to sequence, or sequence of such
The parameter grid to explore, as a dictionary mapping estimator
parameters to sequences of allowed values.
An empty dict signifies default parameters.
A sequence of dicts signifies a sequence of grids to search, and is
useful to avoid exploring parameter combinations that make no sense
or have no effect. See the examples below.
"""

def __init__(self, param_grid):
if not isinstance(param_grid, (Mapping, Iterable)):
raise TypeError('Parameter grid is not a dict or '
'a list ({!r})'.format(param_grid))

if isinstance(param_grid, Mapping):
# wrap dictionary in a singleton list to support either dict
# or list of dicts
param_grid = [param_grid]

# check if all entries are dictionaries of lists
for grid in param_grid:
if not isinstance(grid, dict):
raise TypeError('Parameter grid is not a '
'dict ({!r})'.format(grid))
for key in grid:
if not isinstance(grid[key], Iterable):
raise TypeError('Parameter grid value is not iterable '
'(key={!r}, value={!r})'
.format(key, grid[key]))

self.param_grid = param_grid

def __iter__(self):
"""Iterate over the points in the grid.
Returns
-------
params : iterator over dict of str to any
Yields dictionaries mapping each estimator parameter to one of its
allowed values.
"""
for p in self.param_grid:
# Always sort the keys of a dictionary, for reproducibility
items = sorted(p.items())
if not items:
yield {}
else:
keys, values = zip(*items)
for v in product(*values):
params = dict(zip(keys, v))
yield params

def __len__(self):
"""Number of points on the grid."""
# Product function that can handle iterables (np.product can't).
product = partial(reduce, operator.mul)
return sum(product(len(v) for v in p.values()) if p else 1
for p in self.param_grid)

def __getitem__(self, ind):
"""Get the parameters that would be ``ind``th in iteration
Parameters
----------
ind : int
The iteration index
Returns
-------
params : dict of str to any
Equal to list(self)[ind]
"""
# This is used to make discrete sampling without replacement memory
# efficient.
for sub_grid in self.param_grid:
# XXX: could memoize information used here
if not sub_grid:
if ind == 0:
return {}
else:
ind -= 1
continue

# Reverse so most frequent cycling parameter comes first
keys, values_lists = zip(*sorted(sub_grid.items())[::-1])
sizes = [len(v_list) for v_list in values_lists]
total = np.product(sizes)

if ind >= total:
# Try the next grid
ind -= total
else:
out = {}
for key, v_list, n in zip(keys, values_lists, sizes):
ind, offset = divmod(ind, n)
out[key] = v_list[offset]
return out

raise IndexError('ParameterGrid index out of range')
Empty file.
75 changes: 75 additions & 0 deletions pyworkforce/utils/tests/test_parameter_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pyworkforce.utils import ParameterGrid
from collections.abc import Iterable, Sized
from itertools import chain, product
import pytest


def assert_grid_iter_equals_getitem(grid):
assert list(grid) == [grid[i] for i in range(len(grid))]


def test_parameter_grid():
"""
Test taken from scikit-learn
"""
# Test basic properties of ParameterGrid.
params1 = {"foo": [1, 2, 3]}
grid1 = ParameterGrid(params1)
assert isinstance(grid1, Iterable)
assert isinstance(grid1, Sized)
assert len(grid1) == 3
assert_grid_iter_equals_getitem(grid1)

params2 = {"foo": [4, 2],
"bar": ["ham", "spam", "eggs"]}
grid2 = ParameterGrid(params2)
assert len(grid2) == 6

# loop to assert we can iterate over the grid multiple times
for i in range(2):
# tuple + chain transforms {"a": 1, "b": 2} to ("a", 1, "b", 2)
points = set(tuple(chain(*(sorted(p.items())))) for p in grid2)
assert (points ==
set(("bar", x, "foo", y)
for x, y in product(params2["bar"], params2["foo"])))
assert_grid_iter_equals_getitem(grid2)

# Special case: empty grid (useful to get default estimator settings)
empty = ParameterGrid({})
assert len(empty) == 1
assert list(empty) == [{}]
assert_grid_iter_equals_getitem(empty)
with pytest.raises(IndexError):
empty[1]

has_empty = ParameterGrid([{'C': [1, 10]}, {}, {'C': [.5]}])
assert len(has_empty) == 4
assert list(has_empty) == [{'C': 1}, {'C': 10}, {}, {'C': .5}]
assert_grid_iter_equals_getitem(has_empty)


def test_non_iterable_parameter_grid():
params = {"foo": 4,
"bar": ["ham", "spam", "eggs"]}
with pytest.raises(Exception) as excinfo:
grid = ParameterGrid(params)
for grid in params:
for key in grid:
assert str(excinfo.value) == 'Parameter grid value is not iterable '
'(key={!r}, value={!r})'.format(key, grid[key])


def test_non_dict_parameter_grid():
params = [4]
with pytest.raises(Exception) as excinfo:
grid = ParameterGrid(params)
for grid in params:
assert str(excinfo.value) == 'Parameter grid is not a dict ({!r})'.format(grid)


def test_wrong_parameter_grid():
params = 4
with pytest.raises(Exception) as excinfo:
grid = ParameterGrid(params)
for grid in params:
assert str(excinfo.value) == 'Parameter grid is not a dict or a list ({!r})'.format(params)
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib
from setuptools import setup, find_packages

# python setup.py sdist bdist_wheel
# twine upload --skip-existing --repository-url https://test.pypi.org/legacy/ dist/*

HERE = pathlib.Path(__file__).parent
Expand All @@ -27,10 +28,9 @@
],
packages=find_packages(include=['pyworkforce', 'pyworkforce.*']),
install_requires=[
'numpy>=1.18.1',
'numpy',
'ortools>=7.8.7959',
'pandas>=1.0.0',
'scikit-learn>=0.20.0',
'pandas',
'joblib>=0.11'
],
python_requires=">=3.6",
Expand Down

0 comments on commit 228af30

Please sign in to comment.