From 3a09d1a97a27b309fc6651d45d3073357ba3d8ec Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Tue, 17 Jul 2018 12:53:42 +0200 Subject: [PATCH 01/41] upgrading scannables --- artemis/general/scannable_functions.py | 94 +++++++++++++++------ artemis/general/test_scannable_functions.py | 10 +-- 2 files changed, 75 insertions(+), 29 deletions(-) diff --git a/artemis/general/scannable_functions.py b/artemis/general/scannable_functions.py index 3d0faa6d..c9a1b8e5 100644 --- a/artemis/general/scannable_functions.py +++ b/artemis/general/scannable_functions.py @@ -1,4 +1,9 @@ -def scannable(state, output=None, returns=None): +from collections import namedtuple + +from artemis.general.should_be_builtins import izip_equal + + +def scannable(state, returns=None, output=None): """ A decorator for turning functions into stateful objects. The decorator attaches a "scan" method to the given function, which can be called to create an object which stores the state that gets fed back into the next function call. This @@ -18,18 +23,44 @@ def simple_moving_average(x, avg, n): :param dict state: A dictionary whose keys are the names of the arguments to feed back, and whose values are the default initial state. These initial values can be overridden when calling function.scan(arg_name, initial_value) - :param Optional[Union[str, Sequence[str]]] output: If there is more than one state variable or more than one output, + :param Optional[Union[str, Sequence[str]]] returns: If there is more than one state variable or more than one output, include the list of output names, so that the scan knows which outputs to use to update the state. - :param Optional[Union[str, Sequence[str]]] returns: If there is more than one output and you only wish to return a + :param Optional[Union[str, Sequence[str]]] output: If there is more than one output and you only wish to return a subset of the outputs, indicate here which variables you want to return. :return: func, but with a "scan" function attched. """ def wrapper(func): def create_scannable(**kwargs): - return Scannable(func=func, state=state, output=output, returns=returns, kwargs=kwargs) - + return Scannable(func=func, state=state, returns=returns, output=output, kwargs=kwargs) func.scan = create_scannable + + + + + class StatelessUpdater(namedtuple('StatelessUpdater for {}'.format(func.__name__))): + + def __call__(self, *args, **kwargs): + + + + + return StatelessUpdater(*(return_value for return_name, return_value in izip_equal(returns, return_values) if return_name in state)) + + + state_object = None if isinstance(state, str) else namedtuple('State of {}'.format(func.__name__), state) + output_object = None if isinstance(output, str) else namedtuple('Output of {}'.format(func.__name__, output)) + + def standard_form(input_values, state_values): + + kwargs = input_value + + + return output_values, state_values + + + func.standard_form = standard_form + return func return wrapper @@ -40,7 +71,7 @@ class Scannable(object): SINGLE_OUTPUT_FORMAT = object() TUPLE_OUTPUT_FORMAT = object() - def __init__(self, func, state, output, returns, kwargs = None): + def __init__(self, func, state, returns, output, kwargs = None): """ See scannable docstring """ @@ -52,34 +83,34 @@ def __init__(self, func, state, output, returns, kwargs = None): if kwargs is not None: state.update(kwargs) - if output is None: + if returns is None: assert len(state_names)==1, "If there is more than one state variable, you must specify the output!" - output = next(iter(state_names)) - if isinstance(output, str): - assert output in state_names, 'Output name "{}" was not provided not included in the state dict: "{}"'.format(output, state_names) + returns = next(iter(state_names)) + if isinstance(returns, str): + assert returns in state_names, 'Output name "{}" was not provided not included in the state dict: "{}"'.format(returns, state_names) self._output_format = Scannable.SINGLE_OUTPUT_FORMAT - self._state_names = output - output_names = [output] + self._state_names = returns + output_names = [returns] self._output_format = Scannable.SINGLE_OUTPUT_FORMAT else: - assert isinstance(output, (list, tuple)), "output must be a string, a list/tuple, or None" - assert all(sn in output for sn in state_names), 'Variabels name(s) {} were listed as state variables but not included in the list of outputs: {}'.format([sn for sn in state_names if sn not in output], output) - output_names = output + assert isinstance(returns, (list, tuple)), "output must be a string, a list/tuple, or None" + assert all(sn in returns for sn in state_names), 'Variabels name(s) {} were listed as state variables but not included in the list of outputs: {}'.format([sn for sn in state_names if sn not in returns], returns) + output_names = returns self._output_format = Scannable.TUPLE_OUTPUT_FORMAT self._state_names = tuple(state_names) self._state_indices_in_output = [output_names.index(state_name) for state_name in state_names] - if isinstance(returns, str): - assert output is not None, 'If you specify returns, you must specify output' - if isinstance(output, str): - assert returns==output_names + if isinstance(output, str): + assert returns is not None, 'If you specify returns, you must specify output' + if isinstance(returns, str): + assert output == output_names return_index = None else: - assert isinstance(output, (list, tuple)) - return_index = output.index(returns) - elif isinstance(returns, (list, tuple)): - return_index = tuple(output_names.index(r) for r in returns) + assert isinstance(returns, (list, tuple)) + return_index = returns.index(output) + elif isinstance(output, (list, tuple)): + return_index = tuple(output_names.index(r) for r in output) else: - assert returns is None + assert output is None return_index = None self.func = func @@ -109,3 +140,18 @@ def __call__(self, *args, **kwargs): @property def state(self): return self._state.copy() + +# +# class ScannableStateLess(object): +# +# def __init__(self, func, inputs, returns, outputs): +# pass +# +# def __call__(self, *args, **kwargs): +# """ +# :param args: +# :param kwargs: +# :return: +# """ + + diff --git a/artemis/general/test_scannable_functions.py b/artemis/general/test_scannable_functions.py index d8e3bed3..f62f7115 100644 --- a/artemis/general/test_scannable_functions.py +++ b/artemis/general/test_scannable_functions.py @@ -8,7 +8,7 @@ def test_simple_moving_average(): seq = np.random.randn(100) + np.sin(np.linspace(0, 10, 100)) - @scannable(state=['avg', 'n'], output=['avg', 'n'], returns='avg') + @scannable(state=['avg', 'n'], returns=['avg', 'n'], output='avg') def simple_moving_average(x, avg=0, n=0): return (n/(1.+n))*avg + (1./(1.+n))*x, n+1 @@ -51,7 +51,7 @@ def test_rnn_type_comp(): w_hh = rng.randn(n_hid, n_hid) w_hy = rng.randn(n_hid, n_out) - @scannable(state='hid', output=['out', 'hid'], returns='out') + @scannable(state='hid', returns=['out', 'hid'], output='out') def rnn_like_func(x, hid= np.zeros(n_hid)): new_hid = np.tanh(x.dot(w_xh) + hid.dot(w_hh)) out = new_hid.dot(w_hy) @@ -87,19 +87,19 @@ def moving_average_with_typo(x, decay, avg=0): simply_smoothed_signal = [f(x=x, decay=1./(t+1)) for t, x in enumerate(seq)] with pytest.raises(AssertionError): # Should really be done before instance-creation, but whatever. - @scannable(state='avg', output='avgf') + @scannable(state='avg', returns='avgf') def moving_average_with_typo(x, decay, avg=0): return (1-decay)*avg + decay*x f = moving_average_with_typo.scan() with pytest.raises(ValueError): # Invalid return name - @scannable(state=['avg'], output=['avg'], returns='avgf') + @scannable(state=['avg'], returns=['avg'], output='avgf') def moving_average_with_typo(x, decay, avg=0): return (1-decay)*avg + decay*x f = moving_average_with_typo.scan() with pytest.raises(TypeError): # Wrong output format - @scannable(state=['avg'], output=['avg', 'something'], returns='avg') + @scannable(state=['avg'], returns=['avg', 'something'], output='avg') def moving_average_with_typo(x, decay, avg=0): return (1-decay)*avg + decay*x f = moving_average_with_typo.scan() From 8fb2bd92542ac821c9552734ec4df8068f1b7776 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Wed, 18 Jul 2018 16:42:21 +0200 Subject: [PATCH 02/41] ok, the scan thing is internally messy but seems to work' --- artemis/general/scannable_functions.py | 314 ++++++++++++-------- artemis/general/test_scannable_functions.py | 99 +++++- 2 files changed, 277 insertions(+), 136 deletions(-) diff --git a/artemis/general/scannable_functions.py b/artemis/general/scannable_functions.py index c9a1b8e5..3398f02d 100644 --- a/artemis/general/scannable_functions.py +++ b/artemis/general/scannable_functions.py @@ -1,6 +1,193 @@ from collections import namedtuple +from artemis.general.functional import advanced_getargspec -from artemis.general.should_be_builtins import izip_equal + +def _normalize_formats(func, returns, state, output, outer_kwargs): + """ + + :param func: + :param returns: + :param state: + :param output: + :return: + """ + + if returns is None: + assert isinstance(state, str), "If there is more than one state variable, you must specify the return variables!" + returns = state + + single_return_format = isinstance(returns, str) + + if isinstance(state, str): + state = (state, ) + assert isinstance(state, (list, tuple)), 'State should be a list of state names. Got a {}'.format(state.__class__) + + arg_names, _, _, initial_state_dict = advanced_getargspec(func) + + for s in state: + assert s in arg_names, "The state '{}' is not a parameter to the function '{}'".format(s, func.__name__) + assert s==returns if single_return_format else s in returns, "The state variable '{}' is not updated by the returns '{}'".format(s, returns) + + if output is None: + output = returns + + if isinstance(state, dict): + initial_state_dict = state + else: + for name in state: + assert name in initial_state_dict, "Argument '{}' is part of your state, but it does not have an initial value. Provide one either by passing in state as a dict or adding a default in the function signature".format(name) + + parameter_kwargs = {} + for k, v in outer_kwargs.items(): + if k in state: + initial_state_dict[k] = v + else: + parameter_kwargs[k] = v + + if isinstance(state, (list, tuple, dict)): + return_state_indices = None if single_return_format else [returns.index(s) for s in state] + else: + raise Exception('State must be a list, tuple, dict, string. Not {}'.format(output)) + + if single_return_format: + return_state_names = returns + else: + return_state_names = [returns[ix] for ix in return_state_indices] + + if isinstance(output, (list, tuple)): + output_type = namedtuple('ScanOutputs_of_{}'.format(func.__name__), output) + return_output_indices = [returns.index(o) for o in output] + elif isinstance(output, str): + output_type = None + return_output_indices = returns.index(output) + else: + raise Exception('Output must be a list/tuple or string. Not {}'.format(output)) + + if single_return_format: + return_output_indices = None + + return single_return_format, return_output_indices, return_state_indices, return_state_names, initial_state_dict, output_type, parameter_kwargs + + +def immutable_scan(func, returns, state, output, outer_kwargs = {}): + """ + Create a StatelessUpdater object from a function. + + A StatelessUpdater is an Immutable callable object, which stores state. When called, it returns a new Statelessupdater, + containing the new state, and the specified return value. + + e.g. + + def moving_average(avg, x, t=0): + return avg*t/(t+1)+x/(t+1), t+1 + + sup = immutable_scan(moving_average, state=['avg', 't'], returns = ['avg', 't'], output='avg') + + sup2, avg = sup(3) + assert avg==3 + sup3, avg = sup2(4) + assert avg == 3.5 + + Note, you may choose to use the @scannable decorator instead: + + @scannable(state=['avg', 't'], returns = ['avg', 't'], output='avg') + def moving_average(avg, x, t=0): + return avg*t/(t+1)+x/(t+1), t+1 + + sup = moving_average.immutable_scan() + + :param Callable func: A function which defines a step in an iterative process + :param Union[Sequence[str], str] state: A list of names of state variables. + :param Union[Sequence[str], str] returns: A list of names of variables returned from the function. + :param Union[Sequence[str], str] output: A list of names of "output" variables + :return Callable[[...], Tuple[Callable, Any]]: An immutable callable of the form: + new_object_state, outputs = old_object_state(**inputs) + """ + + single_return_format, return_output_indices, return_state_indices, return_state_names, initial_state_dict, output_type, parameter_kwargs = _normalize_formats(func=func, returns=returns, state=state, output=output, outer_kwargs=outer_kwargs) + single_output_format = not isinstance(return_output_indices, (list, tuple)) + + class ImmutableScan(namedtuple('ImmutableScan_of_{}'.format(func.__name__), state)): + + def __call__(self, *args, **kwargs): + """ + :param args: + :param kwargs: + :return StatelessUpdater, Any: + Where the second output is an arbitrary value if output is specified as a string, or a namedtuple if outputs is specified as a list/tuple + """ + arguments = self._asdict() + arguments.update(**parameter_kwargs) + arguments.update(**kwargs) + return_values = func(*args, **arguments) + + if single_return_format: + if single_output_format: + output_values = return_values + else: + output_values = output_type(return_values) + new_state = ImmutableScan(return_values) + else: + try: + assert len(return_values) == len(returns), 'The number of return values: {}, does not match the length of the specified return variables: {} ({})'.format(len(return_values), len(returns), returns) + except TypeError: + raise TypeError('{} should have returned an iterable of length {} containing variables {}, but got a non-iterable: {}'.format(func.__name__, len(returns), returns, return_values)) + if single_output_format: + output_values = return_values[return_output_indices] + else: + output_values = output_type(*(return_values[i] for i in return_output_indices)) + new_state = ImmutableScan(*(return_values[ix] for ix in return_state_indices)) + + return new_state, output_values + + return ImmutableScan(**initial_state_dict) + + +def mutable_scan(func, state, returns, output, outer_kwargs = {}): + + single_return_format, return_output_indices, return_state_indices, return_state_names, initial_state_dict, output_type, parameter_kwargs = _normalize_formats(func=func, returns=returns, state=state, output=output, outer_kwargs=outer_kwargs) + single_output_format = not isinstance(return_output_indices, (list, tuple)) + + try: + from recordclass import recordclass + except: + raise ImportError('Stateful Updaters require recordclass to be installed. Run "pip install recordclass".') + + class MutableScan(recordclass('MutableScan_of_{}'.format(func.__name__), state)): + + def __call__(self, *args, **kwargs): + """ + :param args: + :param kwargs: + :return StatelessUpdater, Any: + Where the second output is an arbitrary value if output is specified as a string, or a namedtuple if outputs is specified as a list/tuple + """ + arguments = self._asdict() + arguments.update(**parameter_kwargs) + arguments.update(**kwargs) + return_values = func(*args, **arguments) + + if single_return_format: + if single_output_format: + output_values = return_values + else: + output_values = output_type(return_values) + setattr(self, return_state_names, return_values) + else: + try: + assert len(return_values) == len(returns), 'The number of return values: {}, does not match the length of the specified return variables: {} ({})'.format(len(return_values), len(returns), returns) + except TypeError: + raise TypeError('{} should have returned an iterable of length {} containing variables {}, but got a non-iterable: {}'.format(func.__name__, len(returns), returns, return_values)) + if single_output_format: + output_values = return_values[return_output_indices] + else: + output_values = output_type(*(return_values[i] for i in return_output_indices)) + for ix, name in zip(return_state_indices, return_state_names): + setattr(self, name, return_values[ix]) + + return output_values + + return MutableScan(**initial_state_dict) def scannable(state, returns=None, output=None): @@ -31,127 +218,14 @@ def simple_moving_average(x, avg, n): """ def wrapper(func): - def create_scannable(**kwargs): - return Scannable(func=func, state=state, returns=returns, output=output, kwargs=kwargs) - func.scan = create_scannable - - - - - class StatelessUpdater(namedtuple('StatelessUpdater for {}'.format(func.__name__))): - - def __call__(self, *args, **kwargs): - - - - - return StatelessUpdater(*(return_value for return_name, return_value in izip_equal(returns, return_values) if return_name in state)) + def make_mutable_scan(**kwargs): + return mutable_scan(func=func, state=state, returns=returns, output=output, outer_kwargs=kwargs) + func.mutable_scan = make_mutable_scan + def make_immutable_scan(**kwargs): + return immutable_scan(func=func, state=state, returns=returns, output=output, outer_kwargs=kwargs) - state_object = None if isinstance(state, str) else namedtuple('State of {}'.format(func.__name__), state) - output_object = None if isinstance(output, str) else namedtuple('Output of {}'.format(func.__name__, output)) - - def standard_form(input_values, state_values): - - kwargs = input_value - - - return output_values, state_values - - - func.standard_form = standard_form - + func.immutable_scan = make_immutable_scan return func return wrapper - - -class Scannable(object): - - SINGLE_OUTPUT_FORMAT = object() - TUPLE_OUTPUT_FORMAT = object() - - def __init__(self, func, state, returns, output, kwargs = None): - """ - See scannable docstring - """ - if isinstance(state, str): - state = (state, ) - assert isinstance(state, (list, tuple)), 'State should be a list of state names. Got a {}'.format(state.__class__) - state_names = state - state = {} - if kwargs is not None: - state.update(kwargs) - - if returns is None: - assert len(state_names)==1, "If there is more than one state variable, you must specify the output!" - returns = next(iter(state_names)) - if isinstance(returns, str): - assert returns in state_names, 'Output name "{}" was not provided not included in the state dict: "{}"'.format(returns, state_names) - self._output_format = Scannable.SINGLE_OUTPUT_FORMAT - self._state_names = returns - output_names = [returns] - self._output_format = Scannable.SINGLE_OUTPUT_FORMAT - else: - assert isinstance(returns, (list, tuple)), "output must be a string, a list/tuple, or None" - assert all(sn in returns for sn in state_names), 'Variabels name(s) {} were listed as state variables but not included in the list of outputs: {}'.format([sn for sn in state_names if sn not in returns], returns) - output_names = returns - self._output_format = Scannable.TUPLE_OUTPUT_FORMAT - self._state_names = tuple(state_names) - self._state_indices_in_output = [output_names.index(state_name) for state_name in state_names] - if isinstance(output, str): - assert returns is not None, 'If you specify returns, you must specify output' - if isinstance(returns, str): - assert output == output_names - return_index = None - else: - assert isinstance(returns, (list, tuple)) - return_index = returns.index(output) - elif isinstance(output, (list, tuple)): - return_index = tuple(output_names.index(r) for r in output) - else: - assert output is None - return_index = None - - self.func = func - self._state = state - self._return_index = return_index - self._output_names = output_names - - def __str__(self): - output = self._output_names[0] if self._output_format is Scannable.SINGLE_OUTPUT_FORMAT else self._output_names - returns = None if self._return_index is None else repr(self._output_names[self._return_index]) if isinstance(self._return_index, int) else tuple(self._output_names[i] for i in self._return_index) - self._strrep = '{}(func={}, state={}, output={}, returns={})'.format(self.__class__.__name__, self.func.__name__, self._state, output, returns) - return self._strrep - - def __call__(self, *args, **kwargs): - kwargs.update(self._state) - values_returned = self.func(*args, **kwargs) - if self._output_format is Scannable.SINGLE_OUTPUT_FORMAT: - self._state[self._state_names] = values_returned - else: - try: - assert len(values_returned) == len(self._output_names), 'The number of outputs: {}, does not match the length of the specified outputs: {} ({})'.format(len(values_returned), len(self._output_names), self._output_names) - except TypeError: - raise TypeError('{} should have returned an iterable of length {} containing variables {}, but got a non-iterable: {}'.format(self.func.__name__, len(self._output_names), self._output_names, values_returned)) - self._state.update((state_name, values_returned[ix]) for state_name, ix in zip(self._state_names, self._state_indices_in_output)) - return values_returned if self._return_index is None else values_returned[self._return_index] if isinstance(self._return_index, int) else tuple(values_returned[i] for i in self._return_index) - - @property - def state(self): - return self._state.copy() - -# -# class ScannableStateLess(object): -# -# def __init__(self, func, inputs, returns, outputs): -# pass -# -# def __call__(self, *args, **kwargs): -# """ -# :param args: -# :param kwargs: -# :return: -# """ - - diff --git a/artemis/general/test_scannable_functions.py b/artemis/general/test_scannable_functions.py index f62f7115..1e30187a 100644 --- a/artemis/general/test_scannable_functions.py +++ b/artemis/general/test_scannable_functions.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from artemis.general.scannable_functions import scannable +from artemis.general.scannable_functions import scannable, immutable_scan, mutable_scan def test_simple_moving_average(): @@ -12,11 +12,11 @@ def test_simple_moving_average(): def simple_moving_average(x, avg=0, n=0): return (n/(1.+n))*avg + (1./(1.+n))*x, n+1 - f = simple_moving_average.scan() + f = simple_moving_average.mutable_scan() averaged_signal = [f(x=x) for t, x in enumerate(seq)] truth = np.cumsum(seq)/np.arange(1, len(seq)+1) assert np.allclose(averaged_signal, truth) - assert np.allclose(f.state['avg'], np.mean(seq)) + assert np.allclose(f.avg, np.mean(seq)) def test_moving_average(): @@ -27,14 +27,14 @@ def moving_average(x, decay, avg=0): seq = np.random.randn(100) + np.sin(np.linspace(0, 10, 100)) - f = moving_average.scan() + f = moving_average.mutable_scan() simply_smoothed_signal = [f(x=x, decay=1./(t+1)) for t, x in enumerate(seq)] truth = np.cumsum(seq)/np.arange(1, len(seq)+1) assert np.allclose(simply_smoothed_signal, truth) - assert list(f.state.keys())==['avg'] - assert np.allclose(f.state['avg'], np.mean(seq)) + assert list(f._fields)==['avg'] + assert np.allclose(f.avg, np.mean(seq)) - f = moving_average.scan() + f = moving_average.mutable_scan() exponentially_smoothed_signal = [f(x=x, decay=0.1) for x in seq] truth = [avg for avg in [0] for x in seq for avg in [0.9*avg + 0.1*x]] assert np.allclose(exponentially_smoothed_signal, truth) @@ -69,45 +69,112 @@ def rnn_like_func(x, hid= np.zeros(n_hid)): outputs.append(y) # The NEW way of doing things. - rnn_step = rnn_like_func.scan(hid=initial_state) + rnn_step = rnn_like_func.mutable_scan(hid=initial_state) outputs2 = [rnn_step(x) for x in seq] assert np.allclose(outputs, outputs2) - assert np.allclose(rnn_step.state['hid'], h) + assert np.allclose(rnn_step.hid, h) + + # Now try the immutable version: + rnn_step = rnn_like_func.immutable_scan(hid=initial_state) + outputs3 = [output for rnn_step in [rnn_step] for x in seq for rnn_step, output in [rnn_step(x)]] + assert np.allclose(outputs, outputs3) def test_bad_beheviour_caught(): seq = np.random.randn(100) + np.sin(np.linspace(0, 10, 100)) - with pytest.raises(TypeError): # Typo in state name + with pytest.raises(AssertionError): # Typo in state name @scannable(state='avgfff') def moving_average_with_typo(x, decay, avg=0): return (1-decay)*avg + decay*x - - f = moving_average_with_typo.scan() - simply_smoothed_signal = [f(x=x, decay=1./(t+1)) for t, x in enumerate(seq)] + f = moving_average_with_typo.mutable_scan() with pytest.raises(AssertionError): # Should really be done before instance-creation, but whatever. @scannable(state='avg', returns='avgf') def moving_average_with_typo(x, decay, avg=0): return (1-decay)*avg + decay*x - f = moving_average_with_typo.scan() + f = moving_average_with_typo.mutable_scan() with pytest.raises(ValueError): # Invalid return name @scannable(state=['avg'], returns=['avg'], output='avgf') def moving_average_with_typo(x, decay, avg=0): return (1-decay)*avg + decay*x - f = moving_average_with_typo.scan() + f = moving_average_with_typo.mutable_scan() with pytest.raises(TypeError): # Wrong output format @scannable(state=['avg'], returns=['avg', 'something'], output='avg') def moving_average_with_typo(x, decay, avg=0): return (1-decay)*avg + decay*x - f = moving_average_with_typo.scan() + f = moving_average_with_typo.mutable_scan() simply_smoothed_signal = [f(x=x, decay=1./(t+1)) for t, x in enumerate(seq)] +def test_stateless_updater(): + + # Direct API + def moving_average(x, avg=0, t=0): + t_next = t+1. + return avg*t/t_next+x/t_next, t_next + + sup = immutable_scan(moving_average, state=['avg', 't'], returns = ['avg', 't'], output='avg') + sup2, avg = sup(3) + assert avg==3 + sup3, avg = sup2(4) + assert avg == 3.5 + sup2a, avg = sup2(1) + assert avg == 2 + + +def test_stateless_updater_with_decorator(): + # Using Decordator + @scannable(state=['avg', 't'], output='avg', returns=['avg', 't']) + def moving_average(x, avg=0, t=0): + t_next = t+1. + return avg*t/t_next+x/t_next, t_next + + sup = moving_average.immutable_scan() + sup2, avg = sup(3) + assert avg==3 + sup3, avg = sup2(4) + assert avg == 3.5 + sup2a, avg = sup2(1) + assert avg == 2 + + +def test_stateful_updater(): + + # Direct API + def moving_average(x, avg=0, t=0): + t_next = t+1. + return avg*t/t_next+x/t_next, t_next + + sup = mutable_scan(moving_average, state=['avg', 't'], returns = ['avg', 't'], output='avg') + avg = sup(3) + assert avg==3 + avg = sup(4) + assert avg == 3.5 + + +def test_stateful_updater_with_decorator(): + # Using Decordator + @scannable(state=['avg', 't'], output='avg', returns=['avg', 't']) + def moving_average(x, avg=0, t=0): + t_next = t+1. + return avg*t/t_next+x/t_next, t_next + + sup = mutable_scan(moving_average, state=['avg', 't'], returns = ['avg', 't'], output='avg') + avg = sup(3) + assert avg==3 + avg = sup(4) + assert avg == 3.5 + + if __name__ == '__main__': test_simple_moving_average() test_moving_average() test_rnn_type_comp() test_bad_beheviour_caught() + test_stateless_updater() + test_stateless_updater_with_decorator() + test_stateful_updater() + test_stateful_updater_with_decorator() \ No newline at end of file From 8688d2242752832b393e2f55ce81c69c4e6faa4c Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Thu, 19 Jul 2018 14:32:09 +0200 Subject: [PATCH 03/41] fixed rsync-copying-all-experiments problem --- artemis/experiments/experiment_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/artemis/experiments/experiment_management.py b/artemis/experiments/experiment_management.py index d66f9466..42a93527 100644 --- a/artemis/experiments/experiment_management.py +++ b/artemis/experiments/experiment_management.py @@ -64,7 +64,7 @@ def pull_experiment_records(user, ip, experiment_names, include_variants=True, n # +["--include='**/*-{exp_name}{variants}/*'".format(exp_name=exp_name, variants = '*' if include_variants else '') for exp_name in experiment_names] # This was the old line, but it could be too long for many experiments. if not need_pass: - output = subprocess.check_output(command) + output = subprocess.check_output(' '.join(command), shell=True) else: # This one works if you need a password password = getpass.getpass("Enter password for {}@{}:".format(user, ip)) From a0a4a692c3dfb1f3d329749d1431b3ae4e15596f Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Tue, 7 Aug 2018 14:50:38 +0200 Subject: [PATCH 04/41] a bunch of things with nested structures --- artemis/experiments/experiment_management.py | 4 + artemis/experiments/experiment_record_view.py | 2 +- artemis/general/nested_structures.py | 99 ++++++++++++++----- artemis/general/pareto_efficiency.py | 31 ++++-- artemis/general/test_nested_structures.py | 34 ++++++- artemis/general/test_pareto_efficiency.py | 33 ++++--- artemis/ml/predictors/predictor_tests.py | 5 +- 7 files changed, 158 insertions(+), 50 deletions(-) diff --git a/artemis/experiments/experiment_management.py b/artemis/experiments/experiment_management.py index 42a93527..46786b32 100644 --- a/artemis/experiments/experiment_management.py +++ b/artemis/experiments/experiment_management.py @@ -323,6 +323,10 @@ def _filter_records(user_range, exp_record_dict): current_time = datetime.now() for exp_id, _ in base.items(): base[exp_id] = [filter_func(current_time - load_experiment_record(rec_id).get_datetime(), time_delta) for rec_id in exp_record_dict[exp_id]] + elif user_range.startswith('has:'): + phrase = user_range[len('has:'):] + for exp_id, records in base.items(): + base[exp_id] = [True]*len(records) if phrase in exp_id else [False]*len(records) else: raise RecordSelectionError("Don't know how to interpret subset '{}'. Possible subsets: {}".format(user_range, list(_named_record_filters.keys()))) return base diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 750f7667..5b093fc1 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -112,7 +112,7 @@ def get_record_full_string(record, show_info = True, show_logs = True, truncate_ return '\n'.join(parts) -def get_record_invalid_arg_string(record, recursive=True, ignore_valid_keys=(), note_version = 'full'): +def get_record_invalid_arg_string(record, recursive=False, ignore_valid_keys=(), note_version = 'full'): """ Return a string identifying ig the arguments for this experiment are still valid. :return: diff --git a/artemis/general/nested_structures.py b/artemis/general/nested_structures.py index 79ee05f3..f9c9cfe6 100644 --- a/artemis/general/nested_structures.py +++ b/artemis/general/nested_structures.py @@ -3,7 +3,7 @@ import numpy as np from six import string_types, next -from artemis.general.should_be_builtins import all_equal +from artemis.general.should_be_builtins import all_equal, izip_equal __author__ = 'peter' @@ -65,10 +65,10 @@ def flatten_struct(struct, primatives = PRIMATIVE_TYPES, custom_handlers = {}, def _is_primitive_container(obj): - return isinstance(obj, _primitive_containers) + return isinstance(obj, _primitive_containers) or hasattr(obj, '_fields') -def get_meta_object(data_object, is_container_func = _is_primitive_container): +def get_meta_object(data_object, is_container = _is_primitive_container): """ Given an arbitrary data structure, return a "meta object" which is the same structure, except all non-container objects are replaced by their types. @@ -77,18 +77,55 @@ def get_meta_object(data_object, is_container_func = _is_primitive_container): get_meta_obj([1, 2, {'a':(3, 4), 'b':['hey', 'yeah']}, 'woo']) == [int, int, {'a':(int, int), 'b':[str, str]}, str] :param data_object: A data object with arbitrary nested structure - :param is_container_func: A callback which returns True if an object is to be considered a container and False otherwise + :param is_container: A callback which returns True if an object is to be considered a container and False otherwise :return: """ - if is_container_func(data_object): - if isinstance(data_object, (list, tuple, set)): - return type(data_object)(get_meta_object(x, is_container_func=is_container_func) for x in data_object) + if is_container(data_object): + if hasattr(data_object, '_fields'): + return type(data_object)(*(get_meta_object(x, is_container=is_container) for x in data_object)) + elif isinstance(data_object, (list, tuple, set)): + return type(data_object)(get_meta_object(x, is_container=is_container) for x in data_object) elif isinstance(data_object, dict): - return type(data_object)((k, get_meta_object(v, is_container_func=is_container_func)) for k, v in data_object.items()) + return type(data_object)((k, get_meta_object(v, is_container=is_container)) for k, v in data_object.items()) else: return type(data_object) +def broadcast_into_meta_object(meta_object, data_object, is_container = _is_primitive_container, check_types = True): + """ + "Broadcast" the data object into the meta object. This puts the data into the structure of the meta-object. + E.g. + + >>> broadcast_into_meta_object([int, int, int], 1) + [1, 1, 1] + >>> broadcast_into_meta_object([int, (int, int), int], (1, 2, 3)) + [1, (2, 2), 3] + + :param meta_object: A nested structure of types + :param data_object: A nested structure of data + :param is_container: A function that returns True if an object is considered to be in a container. + :return: A new data object with the structure of the meta object and the data of the data object. + """ + kwargs = dict(check_types=check_types, is_container=is_container) + if is_container(meta_object): + if isnamedtupleinstance(meta_object): + if isinstance(data_object, (list, tuple, set)): + return meta_object.__class__(*(broadcast_into_meta_object(m, d, **kwargs) for m, d in izip_equal(meta_object, data_object))) + else: + return meta_object.__class__(*(broadcast_into_meta_object(m, data_object, **kwargs) for m in meta_object)) + elif isinstance(meta_object, (list, tuple, set)): + if isinstance(data_object, (list, tuple, set)): + return meta_object.__class__(broadcast_into_meta_object(m, d, **kwargs) for m, d in izip_equal(meta_object, data_object)) + else: + return meta_object.__class__(broadcast_into_meta_object(m, data_object, **kwargs) for m in meta_object) + else: + raise NotImplementedError('Dict iteration not supported yet.') + else: + if check_types: + assert isinstance(data_object, meta_object), "Data object {} does not have type of meta-object: {}".format(data_object, meta_object) + return data_object + + class NestedType(object): """ An object which represents the type of an arbitrarily nested data structure. It can be constructed directly @@ -122,14 +159,29 @@ def __repr__(self): def __eq__(self, other): return self.meta_object == other.meta_object - def get_leaves(self, data_object, check_types = True, is_container_func = _is_primitive_container): + def broadcast(self, data_object, is_container = _is_primitive_container, check_types=True): + """ + "Broadcast" a data object to have the given structure. e.g. + + >>> structure = NestedType([int, (int, int), int]) + >>> structure.broadcast((1, 2, 3)) + [1, (2, 2), 3] + + :param data_object: A nested data object which can be broadcast onto this structure. + :return: A new data object with a structure matching this object's. + """ + return broadcast_into_meta_object(meta_object=self.meta_object, data_object=data_object, is_container=is_container, check_types=check_types) + + def get_leaves(self, data_object, check_types = True, broadcast=False, is_container = _is_primitive_container): """ :param data_object: Given a nested object, get the "leaf" values in Depth-First Order :return: A list of leaf values. """ - if check_types: + if broadcast: + data_object = self.broadcast(data_object, check_types=check_types, is_container=is_container) + elif check_types: self.check_type(data_object) - return get_leaf_values(data_object, is_container_func=is_container_func) + return get_leaf_values(data_object, is_container_func=is_container) def expand_from_leaves(self, leaves, check_types = True, assert_fully_used=True, is_container_func = _is_primitive_container): """ @@ -149,17 +201,15 @@ def from_data(data_object, is_container_func = _is_primitive_container): :param is_container_func: A callback which returns True if an object is to be considered a container and False otherwise :return: A NestedType object """ - return NestedType(get_meta_object(data_object, is_container_func=is_container_func)) + return NestedType(get_meta_object(data_object, is_container=is_container_func)) -def isnestedinstance(data, meta_obj): - """ - Check if the data is - :param data: - :param meta_obj: - :return: - """ - raise NotImplementedError() +def isnamedtuple(thing): + return hasattr(thing, '_fields') and len(thing.__bases__)==1 and thing.__bases__[0]==tuple + + +def isnamedtupleinstance(thing): + return isnamedtuple(thing.__class__) def get_leaf_values(data_object, is_container_func = _is_primitive_container): @@ -206,7 +256,9 @@ def _fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, ch try: if is_container_func(meta_object): - if isinstance(meta_object, (list, tuple, set)): + if isnamedtupleinstance(meta_object): + filled_object = type(meta_object)(*(_fill_meta_object(None, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func) for x in meta_object._fields)) + elif isinstance(meta_object, (list, tuple, set)): filled_object = type(meta_object)(_fill_meta_object(x, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func) for x in meta_object) elif isinstance(meta_object, OrderedDict): filled_object = type(meta_object)((k, _fill_meta_object(val, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func)) for k, val in meta_object.items()) @@ -216,7 +268,7 @@ def _fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, ch raise Exception('Cannot handle container type: "{}"'.format(type(meta_object))) else: next_data = next(data_iteratable) - if check_types and meta_object is not type(next_data): + if check_types and meta_object is not type(next_data) and meta_object is not None: raise TypeError('The type of the data object: {} did not match type from the meta object: {}'.format(type(next_data), meta_object)) filled_object = next_data except StopIteration: @@ -245,11 +297,10 @@ def nested_map(func, *nested_objs, **kwargs): is_container_func = kwargs['is_container_func'] if 'is_container_func' in kwargs else _is_primitive_container check_types = kwargs['check_types'] if 'check_types' in kwargs else False assert len(nested_objs)>0, 'nested_map requires at least 2 args' - assert callable(func), 'func must be a function with one argument.' nested_types = [NestedType.from_data(nested_obj, is_container_func=is_container_func) for nested_obj in nested_objs] assert all_equal(nested_types), "The nested objects you provided had different data structures:\n{}".format('\n'.join(str(s) for s in nested_types)) - leaf_values = zip(*[nested_type.get_leaves(nested_obj, is_container_func=is_container_func, check_types=check_types) for nested_type, nested_obj in zip(nested_types, nested_objs)]) + leaf_values = zip(*[nested_type.get_leaves(nested_obj, is_container=is_container_func, check_types=check_types) for nested_type, nested_obj in zip(nested_types, nested_objs)]) new_leaf_values = [func(*v) for v in leaf_values] new_nested_obj = nested_types[0].expand_from_leaves(new_leaf_values, check_types=check_types, is_container_func=is_container_func) return new_nested_obj diff --git a/artemis/general/pareto_efficiency.py b/artemis/general/pareto_efficiency.py index ad4a6bf3..70dac8e7 100644 --- a/artemis/general/pareto_efficiency.py +++ b/artemis/general/pareto_efficiency.py @@ -27,15 +27,30 @@ def is_pareto_efficient(costs): return is_efficient -def is_pareto_efficient_ixs(costs): +def is_pareto_efficient_indexed(costs, return_mask = True): + """ + :param costs: An (n_points, n_costs) array + :param return_mask: True to return a mask + :return: An array of indices of pareto-efficient points. + If return_mask is True, this will be an (n_points, ) boolean array + Otherwise it will be a (n_efficient_points, ) integer array of indices. + """ + is_efficient = np.arange(costs.shape[0]) + n_points = costs.shape[0] + next_point_index = 0 # Next index in the is_efficient array to search for - candidates = np.arange(costs.shape[0]) - for i, c in enumerate(costs): - if 0 < np.searchsorted(candidates, i) < len(candidates): # If this element has not yet been eliminated - candidates = candidates[np.any(costs[candidates]<=c, axis=1)] - is_efficient = np.zeros(costs.shape[0], dtype = bool) - is_efficient[candidates] = True - return is_efficient + while next_point_index0 for c in costs[ixs]: @@ -26,24 +26,31 @@ def test_is_pareto_efficient(plot=False): plt.show() -def profile_pareto_efficient(): +def profile_pareto_efficient(n_points=5000, n_costs=2, include_dumb = True): rng = np.random.RandomState(1234) - costs = rng.rand(5000, 2) + costs = rng.rand(n_points, n_costs) - with EZProfiler('dumb'): - dumb_ixs = is_pareto_efficient_dumb(costs) + if include_dumb: + with EZProfiler('is_pareto_efficient_dumb'): + base_ixs = dumb_ixs = is_pareto_efficient_dumb(costs) - with EZProfiler('smart'): + with EZProfiler('is_pareto_efficient'): less_dumb__ixs = is_pareto_efficient(costs) - assert np.array_equal(dumb_ixs, less_dumb__ixs) + if not include_dumb: + base_ixs = less_dumb__ixs + assert np.array_equal(base_ixs, less_dumb__ixs) - with EZProfiler('index-tracking'): - smart_ixs = is_pareto_efficient_ixs(costs) + with EZProfiler('is_pareto_efficient_indexed'): + smart_indexed = is_pareto_efficient_indexed(costs, return_mask=True) + assert np.array_equal(base_ixs, smart_indexed) - assert np.array_equal(dumb_ixs, smart_ixs) + with EZProfiler('is_pareto_efficient_indexed_reordered'): + smart_indexed = is_pareto_efficient_indexed(costs, return_mask=True, rank_reorder=True) + assert np.array_equal(base_ixs, smart_indexed) if __name__ == '__main__': - test_is_pareto_efficient() + # test_is_pareto_efficient() + profile_pareto_efficient(n_points=100000, n_costs=2, include_dumb=False) diff --git a/artemis/ml/predictors/predictor_tests.py b/artemis/ml/predictors/predictor_tests.py index 760084d5..1c68f500 100644 --- a/artemis/ml/predictors/predictor_tests.py +++ b/artemis/ml/predictors/predictor_tests.py @@ -1,11 +1,10 @@ import numpy as np from artemis.ml.predictors.predictor_comparison import assess_online_predictor -from artemis.ml.predictors.train_and_test import percent_argmax_correct -from plato.tools.common.bureaucracy import multichannel from artemis.ml.datasets.synthetic_clusters import get_synthetic_clusters_dataset +from artemis.ml.tools.costs import percent_argmax_correct from artemis.ml.tools.processors import OneHotEncoding - +from plato.tools.common.bureaucracy import multichannel __author__ = 'peter' From 0db3d4058e4dd59e9b9e95c512bd3b6ea37e8ea9 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 14 Sep 2018 17:18:54 +0200 Subject: [PATCH 05/41] everything --- artemis/experiments/decorators.py | 10 +- artemis/experiments/experiment_record.py | 2 +- artemis/experiments/experiment_record_view.py | 108 ++++++--- artemis/experiments/experiments.py | 50 +++-- artemis/experiments/ui.py | 96 ++++++-- artemis/general/checkpoint_counter.py | 25 +++ artemis/general/dead_easy_ui.py | 212 ++++++++++++++++++ artemis/general/display.py | 8 +- artemis/general/duck.py | 16 +- artemis/general/mymath.py | 15 ++ artemis/general/nested_structures.py | 75 +++++-- artemis/general/progress_indicator.py | 26 ++- artemis/general/should_be_builtins.py | 23 +- artemis/general/table_ui.py | 121 ++++++++++ artemis/general/tables.py | 21 +- artemis/general/test_checkpoint_counter.py | 9 +- artemis/general/test_duck.py | 58 +++-- artemis/general/test_mymath.py | 14 +- artemis/general/test_nested_structures.py | 24 +- artemis/general/test_should_be_builtins.py | 20 +- artemis/general/test_time_parser.py | 26 +++ artemis/general/time_parser.py | 15 +- artemis/plotting/data_conversion.py | 40 ++++ artemis/plotting/db_plotting.py | 124 ++++++---- artemis/plotting/expanding_subplots.py | 2 +- artemis/plotting/matplotlib_backend.py | 20 +- artemis/plotting/pyplot_plus.py | 4 +- artemis/plotting/test_db_plotting.py | 14 +- 28 files changed, 963 insertions(+), 215 deletions(-) create mode 100644 artemis/general/dead_easy_ui.py create mode 100644 artemis/general/table_ui.py create mode 100644 artemis/general/test_time_parser.py diff --git a/artemis/experiments/decorators.py b/artemis/experiments/decorators.py index 91a5fcf3..7c2bbaa5 100644 --- a/artemis/experiments/decorators.py +++ b/artemis/experiments/decorators.py @@ -44,7 +44,7 @@ class ExperimentFunction(object): This is the most general decorator. You can use this to add details on the experiment. """ - def __init__(self, show = show_record, compare = compare_experiment_records, display_function=None, comparison_function=None, one_liner_function=sensible_str, is_root=False): + def __init__(self, show = None, compare = compare_experiment_records, display_function=None, comparison_function=None, one_liner_function=None, result_parser = None, is_root=False): """ :param show: A function that is called when you "show" an experiment record in the UI. It takes an experiment record as an argument. @@ -60,11 +60,11 @@ def __init__(self, show = show_record, compare = compare_experiment_records, dis self.compare = compare if display_function is not None: - assert show is show_record, "You can't set both display function and show. (display_function is deprecated)" + assert show is None, "You can't set both display function and show. (display_function is deprecated)" show = lambda rec: display_function(rec.get_result()) if comparison_function is not None: - assert compare is compare_experiment_records, "You can't set both display function and show. (display_function is deprecated)" + assert compare is None, "You can't set both display function and show. (display_function is deprecated)" def compare(records): record_experiment_ids_uniquified = uniquify_duplicates(rec.get_experiment_id() for rec in records) @@ -74,6 +74,7 @@ def compare(records): self.compare = compare self.is_root = is_root self.one_liner_function = one_liner_function + self.result_parser = result_parser def __call__(self, f): f.is_base_experiment = True @@ -83,6 +84,7 @@ def __call__(self, f): show=self.show, compare = self.compare, one_liner_function=self.one_liner_function, - is_root=self.is_root + is_root=self.is_root, + result_parser=self.result_parser, ) return ex \ No newline at end of file diff --git a/artemis/experiments/experiment_record.py b/artemis/experiments/experiment_record.py index e40a3f26..5e67d99f 100644 --- a/artemis/experiments/experiment_record.py +++ b/artemis/experiments/experiment_record.py @@ -289,7 +289,7 @@ def get_experiment(self): """ Load the experiment associated with this record. Note that this will raise an ExperimentNotFoundError if the experiment has not been imported. - :return: An Experiment object + :return Experiment: An Experiment object """ from artemis.experiments.experiments import load_experiment return load_experiment(self.get_experiment_id()) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 5b093fc1..ce0eedc7 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -1,21 +1,20 @@ import re from collections import OrderedDict +from six import string_types from tabulate import tabulate -from artemis.experiments.experiment_management import load_lastest_experiment_results +import numpy as np from artemis.experiments.experiment_record import NoSavedResultError, ExpInfoFields, ExperimentRecord, \ load_experiment_record, is_matplotlib_imported, UnPicklableArg -from artemis.experiments.experiments import is_experiment_loadable, get_global_experiment_library -from artemis.general.display import deepstr, truncate_string, hold_numpy_printoptions, side_by_side, CaptureStdOut, \ - surround_with_header, section_with_header +from artemis.general.display import deepstr, truncate_string, hold_numpy_printoptions, side_by_side, \ + surround_with_header, section_with_header, dict_to_str from artemis.general.nested_structures import flatten_struct, PRIMATIVE_TYPES -from artemis.general.should_be_builtins import separate_common_items, all_equal, bad_value, izip_equal, \ - remove_duplicates +from artemis.general.should_be_builtins import separate_common_items, bad_value, izip_equal, \ + remove_duplicates, get_unique_name, entries_to_table from artemis.general.tables import build_table -from six import string_types -def get_record_result_string(record, func='deep', truncate_to = None, array_print_threshold=8, array_float_format='.3g', oneline=False): +def get_record_result_string(record, func='deep', truncate_to = None, array_print_threshold=8, array_float_format='.3g', oneline=False, default_one_liner_func=str): """ Get a string representing the result of the experiment. :param record: @@ -36,7 +35,7 @@ def get_record_result_string(record, func='deep', truncate_to = None, array_prin return '' string = func(result) if not isinstance(string, string_types): - string = str(string) + string = default_one_liner_func(string) if truncate_to is not None: string = truncate_string(string, truncation=truncate_to, message = '...') @@ -117,6 +116,7 @@ def get_record_invalid_arg_string(record, recursive=False, ignore_valid_keys=(), Return a string identifying ig the arguments for this experiment are still valid. :return: """ + from artemis.experiments.experiments import is_experiment_loadable assert note_version in ('full', 'short') experiment_id = record.get_experiment_id() if is_experiment_loadable(experiment_id): @@ -149,7 +149,7 @@ def get_record_invalid_arg_string(record, recursive=False, ignore_valid_keys=(), return notes -def get_oneline_result_string(record, truncate_to=None, array_float_format='.3g', array_print_threshold=8): +def get_oneline_result_string(record, truncate_to=None, array_float_format='.3g', array_print_threshold=8, default_one_liner_func=dict_to_str): """ Get a string that describes the result of the record in one line. This can optionally be specified by experiment.one_liner_function. @@ -160,16 +160,17 @@ def get_oneline_result_string(record, truncate_to=None, array_float_format='.3g' :param array_print_threshold: :return: A string with no newlines briefly describing the result of the record. """ + from artemis.experiments.experiments import is_experiment_loadable if isinstance(record, string_types): record = load_experiment_record(record) if not is_experiment_loadable(record.get_experiment_id()): - one_liner_function=str + one_liner_function=default_one_liner_func else: one_liner_function = record.get_experiment().one_liner_function if one_liner_function is None: - one_liner_function = str + one_liner_function = default_one_liner_func return get_record_result_string(record, func=one_liner_function, truncate_to=truncate_to, array_print_threshold=array_print_threshold, - array_float_format=array_float_format, oneline=True) + array_float_format=array_float_format, oneline=True, default_one_liner_func=default_one_liner_func) def print_experiment_record_argtable(records): @@ -204,6 +205,69 @@ def lookup_fcn(record_id, column): print(tabulate(rows)) +def get_column_change_ordering(tabular_data): + """ + Get the order in which to rearrange the columns so that the fastest-changing data comes last. + + :param tabular_data: A list of equal-length lists + :return: A set of permutation indices for the columns. + """ + n_rows, n_columns = len(tabular_data), len(tabular_data[0]) + deltas = [sum(row_prev[i]!=row[i] for row_prev, row in zip(tabular_data[:-1], tabular_data[1:])) for i in range(n_columns)] + return np.argsort(deltas) + + +def get_different_args(args, no_arg_filler = 'N/A', arrange_by_deltas=False): + """ + Get a table of different args between records. + :param Sequence[List[Tuple[str, Any]]] args: A list of lists of argument (name, value) pairs. + :param no_arg_filler: The filler value to use if a record does not have a particular argument (possibly due to an argument being added to the code after the record was made) + :param arrange_by_deltas: If true, order arguments so that the fastest-changing ones are in the last column + :return Tuple[List[str], List[List[Any]]]: (arg_names, arg_values) where: + arg_names is a list of arguments that differ between records + arg_values is a len(records)-list of len(arg_names) lists of values of the arguments for each record. + """ + args = list(args) + common_args, different_args = separate_common_items(args) + all_different_args = list(remove_duplicates((k for dargs in different_args for k in dargs.keys()))) + values = [[record_args[argname] if argname in record_args else no_arg_filler for argname in all_different_args] for record_args in args] + if arrange_by_deltas: + col_shuf_ixs = get_column_change_ordering(values) + all_different_args = [all_different_args[i] for i in col_shuf_ixs] + values = [[row[i] for i in col_shuf_ixs] for row in values] + return all_different_args, values + + +def get_exportiment_record_arg_result_table(records): + record_ids = [record.get_id() for record in records] + all_different_args, arg_values = get_different_args([r.get_args() for r in records], no_arg_filler='N/A') + + parsed_results = [record.get_experiment().result_parser(record.get_result()) for record in records] + result_fields, result_data = entries_to_table(parsed_results) + result_fields = [get_unique_name(rf, all_different_args) for rf in result_fields] # Just avoid name collisions + + # result_column_name = get_unique_name('Results', taken_names=all_different_args) + + def lookup_fcn(record_id, arg_or_result_name): + row_index = record_ids.index(record_id) + if arg_or_result_name in result_fields: + return result_data[row_index][result_fields.index(arg_or_result_name)] + else: + column_index = all_different_args.index(arg_or_result_name) + return arg_values[row_index][column_index] + + rows = build_table(lookup_fcn, + row_categories=record_ids, + column_categories=all_different_args + result_fields, + prettify_labels=False, + include_row_category=False, + ) + + return rows[0], rows[1:] + + # return tabulate(rows[1:], headers=rows[0]) + + def show_record(record, show_logs=True, truncate_logs=None, truncate_result=10000, header_width=100, show_result ='deep', hang=True): """ Show the results of an experiment record. @@ -221,8 +285,6 @@ def show_record(record, show_logs=True, truncate_logs=None, truncate_result=1000 has_matplotlib_figures = any(loc.endswith('.pkl') for loc in record.get_figure_locs()) if has_matplotlib_figures: - from matplotlib import pyplot as plt - from artemis.plotting.saving_plots import interactive_matplotlib_context record.show_figures(hang=hang) print(string) @@ -280,22 +342,6 @@ def compare_experiment_records(records, parallel_text=None, show_logs=True, trun return has_matplotlib_figures -def find_experiment(*search_terms): - """ - Find an experiment. Invoke - :param search_term: A term that will be used to search for an experiment. - :return: - """ - global_lib = get_global_experiment_library() - found_experiments = OrderedDict((name, ex) for name, ex in global_lib.items() if all(re.search(term, name) for term in search_terms)) - if len(found_experiments)==0: - raise Exception("None of the {} experiments matched the search: '{}'".format(len(global_lib), search_terms)) - elif len(found_experiments)>1: - raise Exception("More than one experiment matched the search '{}', you need to be more specific. Found: {}".format(search_terms, found_experiments.keys())) - else: - return found_experiments.values()[0] - - def make_record_comparison_table(records, args_to_show=None, results_extractor = None, print_table = False, tablefmt='simple', reorder_by_args=False): """ Make a table comparing the arguments and results of different experiment records. You can use the output diff --git a/artemis/experiments/experiments.py b/artemis/experiments/experiments.py index db4a7a2f..f9776356 100644 --- a/artemis/experiments/experiments.py +++ b/artemis/experiments/experiments.py @@ -3,12 +3,13 @@ from collections import OrderedDict from contextlib import contextmanager from functools import partial - from six import string_types - from artemis.experiments.experiment_record import ExpStatusOptions, experiment_id_to_record_ids, load_experiment_record, \ get_all_record_ids, clear_experiment_records from artemis.experiments.experiment_record import run_and_record +from artemis.experiments.experiment_record_view import compare_experiment_records, show_record +from artemis.general.display import sensible_str + from artemis.general.functional import get_partial_root, partial_reparametrization, \ advanced_getargspec, PartialReparametrization @@ -19,7 +20,7 @@ class Experiment(object): create variants using decorated_function.add_variant() """ - def __init__(self, function=None, show=None, compare=None, one_liner_function=None, + def __init__(self, function=None, show=None, compare=None, one_liner_function=None, result_parser = None, name=None, is_root=False): """ :param function: The function defining the experiment @@ -31,9 +32,10 @@ def __init__(self, function=None, show=None, compare=None, one_liner_function=No """ self.name = name self.function = function - self._show = show - self._one_liner_results = one_liner_function - self._compare = compare + self._show = show_record if show is None else show + self._one_liner_results = sensible_str if one_liner_function is None else one_liner_function + self._result_parser = (lambda result: [('Result', self.one_liner_function(result))]) if result_parser is None else result_parser + self._compare = compare_experiment_records if compare is None else compare self.variants = OrderedDict() self._notes = [] self.is_root = is_root @@ -62,6 +64,10 @@ def compare(self): def compare(self, val): self._compare = val + @property + def result_parser(self): + return self._result_parser + def __call__(self, *args, **kwargs): """ Run the function as normal, without recording or anything. You can also modify with arguments. """ return self.function(*args, **kwargs) @@ -161,6 +167,7 @@ def _create_experiment_variant(self, args, kwargs, is_root): show=self._show, compare=self._compare, one_liner_function=self.one_liner_function, + result_parser=self._result_parser, is_root=is_root ) self.variants[name] = ex @@ -183,7 +190,7 @@ def add_variant(self, variant_name = None, **kwargs): :param variant_name: Optionally, the name of the experiment :param kwargs: The named arguments which will differ from the base experiment. - :return: The experiment. + :return Experiment: The experiment. """ return self._create_experiment_variant(() if variant_name is None else (variant_name, ), kwargs, is_root=False) @@ -204,7 +211,7 @@ def add_root_variant(self, variant_name=None, **kwargs): :param variant_name: Optionally, the name of the experiment :param kwargs: The named arguments which will differ from the base experiment. - :return: The experiment. + :return Experiment: The experiment. """ return self._create_experiment_variant(() if variant_name is None else (variant_name, ), kwargs, is_root=True) @@ -320,12 +327,13 @@ def browse(self, command=None, catch_errors = False, close_after = False, filter :param display_format: How experements and their records are displayed: 'nested' or 'flat'. 'nested' might be better for narrow console outputs. """ - from artemis.experiments.ui import browse_experiments - browse_experiments(command = command, root_experiment=self, catch_errors=catch_errors, close_after=close_after, - filterexp=filterexp, filterrec=filterrec, - view_mode=view_mode, raise_display_errors=raise_display_errors, run_args=run_args, keep_record=keep_record, - truncate_result_to=truncate_result_to, cache_result_string=cache_result_string, remove_prefix=remove_prefix, - display_format=display_format, **kwargs) + from artemis.experiments.ui import ExperimentBrowser + experiments = get_ordered_descendents_of_root(root_experiment=self) + browser = ExperimentBrowser(experiments=experiments, catch_errors=catch_errors, close_after=close_after, + filterexp=filterexp, filterrec=filterrec, view_mode=view_mode, raise_display_errors=raise_display_errors, + run_args=run_args, keep_record=keep_record, truncate_result_to=truncate_result_to, cache_result_string=cache_result_string, + remove_prefix=remove_prefix, display_format=display_format, **kwargs) + browser.launch(command=command) # Above this line is the core api.... # ----------------------------------- @@ -434,7 +442,6 @@ def add_two_numbers(a=1, b=2): return a+b with capture_created_experiments() as exps: - add_two_numbers.add_variant(a=2) add_two_numbers.add_variant(a=3) for ex in exps: @@ -445,7 +452,7 @@ def add_two_numbers(a=1, b=2): current_len = len(_GLOBAL_EXPERIMENT_LIBRARY) new_experiments = [] yield new_experiments - for ex in _GLOBAL_EXPERIMENT_LIBRARY.values()[current_len:]: + for ex in list(_GLOBAL_EXPERIMENT_LIBRARY.values())[current_len:]: new_experiments.append(ex) @@ -458,6 +465,16 @@ def get_nonroot_global_experiment_library(): return OrderedDict((name, exp) for name, exp in _GLOBAL_EXPERIMENT_LIBRARY.items() if not exp.is_root) +def get_ordered_descendents_of_root(root_experiment): + """ + :param Experiment root_experiment: An experiment which has variants + :return List[Experiment]: A list of the descendents (i.e. variants and subvariants) of the root experiment, in the + order in which they were created + """ + descendents_of_root = set(ex for ex in root_experiment.get_all_variants(include_self=True)) + return [ex for ex in get_nonroot_global_experiment_library().values() if ex in descendents_of_root] + + def get_experiment_info(name): experiment = load_experiment(name) return str(experiment) @@ -480,6 +497,7 @@ def _kwargs_to_experiment_name(kwargs): string = string.replace('/', '_SLASH_') return string + @contextmanager def hold_global_experiment_libary(new_lib = None): if new_lib is None: diff --git a/artemis/experiments/ui.py b/artemis/experiments/ui.py index 3dff208c..d07302dc 100644 --- a/artemis/experiments/ui.py +++ b/artemis/experiments/ui.py @@ -22,7 +22,8 @@ load_experiment_record, ExpInfoFields) from artemis.experiments.experiment_record_view import (get_record_full_string, get_record_invalid_arg_string, print_experiment_record_argtable, get_oneline_result_string, - compare_experiment_records) + compare_experiment_records, + get_exportiment_record_arg_result_table, get_different_args) from artemis.experiments.experiment_record_view import show_record, show_multiple_records from artemis.experiments.experiments import load_experiment, get_nonroot_global_experiment_library from artemis.fileman.local_dir import get_artemis_data_path @@ -30,6 +31,7 @@ from artemis.general.hashing import compute_fixed_hash from artemis.general.mymath import levenshtein_distance from artemis.general.should_be_builtins import all_equal, insert_at, izip_equal, separate_common_items, bad_value +from artemis.general.table_ui import TableExplorerUI try: import readline # Makes input() behave like interactive shell. @@ -56,12 +58,12 @@ def _warn_with_prompt(message= None, prompt = 'Press Enter to continue or q then return out -def browse_experiments(command=None, **kwargs): +def browse_experiments(experiments = None, command=None, **kwargs): """ Browse Experiments :param command: Optionally, a string command to pass directly to the UI. (e.g. "run 1") - :param root_experiment: The Experiment whose (self and) children to browse + :param experiments: If root_experiment not specified, the list of experiments to look over. :param catch_errors: Catch errors that arise while running experiments :param close_after: Close after issuing one command. :param just_last_record: Only show the most recent record for each experiment. @@ -73,7 +75,7 @@ def browse_experiments(command=None, **kwargs): :param cache_result_string: Cache the result string (useful when it takes a very long time to display the results when opening up the menu - often when results are long lists). """ - browser = ExperimentBrowser(**kwargs) + browser = ExperimentBrowser(experiments=experiments, **kwargs) browser.launch(command=command) @@ -162,12 +164,13 @@ class ExperimentBrowser(object): age<24h Select records which are less than 24h old. """ - def __init__(self, root_experiment = None, catch_errors = True, close_after = False, filterexp=None, filterrec = None, + def __init__(self, experiments=None, catch_errors = True, close_after = False, filterexp=None, filterrec = None, view_mode ='full', raise_display_errors=False, run_args=None, keep_record=True, truncate_result_to=100, ignore_valid_keys=(), cache_result_string = False, slurm_kwargs={}, remove_prefix = None, display_format='nested', - show_args=False, catch_selection_errors=True, max_width=None, table_package = 'tabulate', show_archived=True): + show_args=False, catch_selection_errors=True, max_width=None, table_package = 'tabulate', show_archived=True, + sortkey = None): """ - :param root_experiment: The Experiment whose (self and) children to browse + :param Sequence[Experiment] experiments: If root_experiment not specified: A list of experiments to include. :param catch_errors: Catch errors that arise while running experiments :param close_after: Close after issuing one command. :param filterexp: Filter the experiments with this selection (see help for how to use) @@ -184,6 +187,8 @@ def __init__(self, root_experiment = None, catch_errors = True, close_after = Fa :param remove_prefix: Remove the common prefix on the experiment ids in the display. :param display_format: How experements and their records are displayed: 'nested' or 'flat'. 'nested' might be better for narrow console outputs. + :param sortkey: A key function to use for sorting experiments. It takes experiment name and returns an object to + be used by the sorter. """ if run_args is None: @@ -192,7 +197,12 @@ def __init__(self, root_experiment = None, catch_errors = True, close_after = Fa run_args['keep_record'] = keep_record if remove_prefix is None: remove_prefix = display_format=='flat' - self.root_experiment = root_experiment + + if experiments is not None: + self._experiment_names = [ex.name for ex in experiments if not ex.is_root] + else: + self._experiment_names = get_nonroot_global_experiment_library().keys() + self.close_after = close_after self.catch_errors = catch_errors self.exp_record_dict = None @@ -212,18 +222,13 @@ def __init__(self, root_experiment = None, catch_errors = True, close_after = Fa self.max_width = max_width self.table_package = table_package self.show_archived = show_archived + self._sortkey = sortkey def _reload_record_dict(self): - names = get_nonroot_global_experiment_library().keys() - if self.root_experiment is not None: - # We could just go [ex.name for ex in self.root_experiment.get_all_variants(include_self=True)] - # but we want to preserve the order in which experiments were created - descendents_of_root = set(ex.name for ex in self.root_experiment.get_all_variants(include_self=True)) - names = [name for name in names if name in descendents_of_root] - all_experiments = get_experient_to_record_dict(names) + all_experiments = get_experient_to_record_dict(self._experiment_names) return all_experiments - def _filter_record_dict(self, all_experiments): + def _filter_record_dict(self, all_experiments, sortkey=None): # Apply filters and display Table: if self._filter is not None: try: @@ -239,6 +244,10 @@ def _filter_record_dict(self, all_experiments): old_filterrec = self._filterrec self._filterrec = None raise RecordSelectionError("Failed to apply record filter: '{}' because {}. Removing filter.".format(old_filterrec, err)) + + if sortkey is not None: + all_experiments = OrderedDict((exp_name, all_experiments[exp_name]) for exp_name in sorted(all_experiments.keys(), key=sortkey)) + return all_experiments def launch(self, command=None): @@ -249,6 +258,7 @@ def launch(self, command=None): 'show': self.show, 'call': self.call, 'kill': self.kill, + 'argsort': self.argsort, 'selectexp': self.selectexp, 'selectrec': self.selectrec, 'view': self.view, @@ -260,6 +270,7 @@ def launch(self, command=None): 'explist': self.explist, 'sidebyside': self.side_by_side, 'argtable': self.argtable, + 'argcompare': self.argcompare, 'compare': self.compare, 'delete': self.delete, 'errortrace': self.errortrace, @@ -275,7 +286,7 @@ def launch(self, command=None): if display_again: all_experiments = self._reload_record_dict() try: - self.exp_record_dict = self._filter_record_dict(all_experiments) + self.exp_record_dict = self._filter_record_dict(all_experiments, sortkey=self._sortkey) except RecordSelectionError as err: _warn_with_prompt(str(err), use_prompt=self.catch_selection_errors) if not self.catch_selection_errors: @@ -423,6 +434,28 @@ def remove_notes_if_no_notes(_record_rows, _record_headers): table = tabulate(rows, headers=full_headers) else: raise NotImplementedError(self.table_package) + + elif self.display_format == 'args': + + def arg_iterator(exp_record_dict_): + for exp_name, record_ids_ in exp_record_dict_.items(): + if len(record_ids_)==0: + yield load_experiment(exp_name).get_args() + else: + for record_id_ in record_ids_: + yield load_experiment_record(record_id_).get_args() + + arg_names, different_args = get_different_args(arg_iterator(exp_record_dict), arrange_by_deltas=True) + rows = [] + arg_iter = iter(different_args) + info_to_include = [ExpRecordDisplayFields.STATUS, ExpRecordDisplayFields.RESULT_STR] + for i, (exp_id, record_ids) in enumerate(exp_record_dict.items()): + if len(record_ids)==0: + rows.append([str(i), '']+list(next(arg_iter))) + else: + for j, record_id in enumerate(record_ids): + rows.append([str(i) if j==0 else '', j] + list(next(arg_iter)) + row_func(record_id, info_to_include, raise_display_errors=self.raise_display_errors, truncate_to=self.truncate_result_to, ignore_valid_keys=self.ignore_valid_keys)) + table = tabulate(rows, headers=['E#', 'R#']+list(arg_names)+[f.value for f in info_to_include]) else: raise NotImplementedError(self.display_format) @@ -479,6 +512,21 @@ def run(self, *args): if result=='q': quit() + def argsort(self, *args): + + # First verify that all args are included... + all_arg_names = set(a for exp_name in self.exp_record_dict.keys() for a, v in load_experiment(exp_name).get_args().items()) + if any(a not in all_arg_names for a in args): + raise RecordSelectionError('Arg(s) [{}] were not included in any experiments') + + # Define a comparison function that will always compare. + def key_sorting_function(exp_name): + exp_args = load_experiment(exp_name).get_args() + return tuple(() if name not in exp_args else (None, exp_args[name]) if isinstance(exp_args[name], (int, float)) else (str(type(exp_args[name])), exp_args[name]) for name in args) + + self._sortkey = key_sorting_function + return ExperimentBrowser.REFRESH + def archive(self, *args): parser = argparse.ArgumentParser() parser.add_argument('user_range', action='store', help='A selection of experiments to run. Examples: "3" or "3-5", or "3,4,5"') @@ -557,7 +605,7 @@ def compare(self, *args): _warn_with_prompt(use_prompt=False) def displayformat(self, new_format): - assert new_format in ('nested', 'flat'), "Display format must be 'nested' or 'flat', not '{}'".format(new_format) + assert new_format in ('nested', 'flat', 'args'), "Display format must be 'nested' or 'flat' or 'args', not '{}'".format(new_format) self.display_format = new_format return ExperimentBrowser.REFRESH @@ -609,6 +657,18 @@ def argtable(self, *args): parser.add_argument('user_range', action='store', nargs = '?', default='all', help='A selection of experiment records to run. Examples: "3" or "3-5", or "3,4,5"') args = parser.parse_args(args) records = select_experiment_records(args.user_range, self.exp_record_dict, flat=True) + + headers, rows = get_exportiment_record_arg_result_table(records) + TableExplorerUI(table_data=rows, col_headers=headers).launch() + + # print(get_exportiment_record_arg_result_table(records)) + _warn_with_prompt(use_prompt=False) + + def argcompare(self, *args): + parser = argparse.ArgumentParser() + parser.add_argument('user_range', action='store', nargs = '?', default='all', help='A selection of experiment records to run. Examples: "3" or "3-5", or "3,4,5"') + args = parser.parse_args(args) + records = select_experiment_records(args.user_range, self.exp_record_dict, flat=True) print_experiment_record_argtable(records) _warn_with_prompt(use_prompt=False) diff --git a/artemis/general/checkpoint_counter.py b/artemis/general/checkpoint_counter.py index 7c7bc745..c1f94264 100644 --- a/artemis/general/checkpoint_counter.py +++ b/artemis/general/checkpoint_counter.py @@ -56,6 +56,7 @@ def __init__(self, checkpoint_generator, default_units = None, skip_first = Fals A string e.g. '3s', indicating "plot every 3 seconds" A generator object yielding checkpoints A list/tuple/array of checkpoints + A dict mapping progress point to interval. e.g. {0: 0.25, 1: 0.5, 2: 1) means "move in steps of 0.25 at first, then 0.5 after 1 epoch, then 1 after 2 epochs" ('even', interval) ('exp', first, growth) None @@ -81,6 +82,8 @@ def __init__(self, checkpoint_generator, default_units = None, skip_first = Fals checkpoint_generator = (first*i*(1+growth)**(i-1) for i in itertools.count(0)) else: raise Exception("Can't make a checkpoint generator {}".format(checkpoint_generator)) + elif isinstance(checkpoint_generator, dict): + checkpoint_generator = dictionary_interval_generator(checkpoint_generator) elif isinstance(checkpoint_generator, (list, tuple, np.ndarray)): checkpoint_generator = iter(checkpoint_generator) elif isinstance(checkpoint_generator, (int, float)): @@ -163,3 +166,25 @@ def do_every(interval, counter_id=None, units = None): if counter_id not in _COUNTERS_DICT: _COUNTERS_DICT[counter_id] = Checkpoints(checkpoint_generator=interval, default_units=units) return _COUNTERS_DICT[counter_id]() + + +def dictionary_interval_generator(position_increment_dict): + """ + Generate checkpoints based on a dictionary of . Eg. + {0: 0.5, 2: 1, 5: 2} will generate + (0, 0.5, 1, 1.5, 2, 3, 4, 5, 7, 9, 11, ...} + + :param Dict[float, float] position_increment_dict: A dict mapping the current position to the increment for the next position. + :return Generator[float]: A series of checkpoints. + """ + change_points = sorted(position_increment_dict.keys()) + intervals = [position_increment_dict[p] for p in change_points] + current_interval_index = 0 + pt = change_points[0] + yield pt + while True: + increment = intervals[current_interval_index] + pt = pt+increment + yield pt + if current_interval_index < len(change_points)-1 and change_points[current_interval_index+1] <= pt: + current_interval_index+=1 diff --git a/artemis/general/dead_easy_ui.py b/artemis/general/dead_easy_ui.py new file mode 100644 index 00000000..4285dc95 --- /dev/null +++ b/artemis/general/dead_easy_ui.py @@ -0,0 +1,212 @@ +from __future__ import print_function +from __future__ import absolute_import +from builtins import range +from builtins import input +from builtins import zip +import inspect +import shlex +from collections import OrderedDict + +# Code taken and modified from: +# Taken from https://github.com/braincorp/dead_easy_ui/ + + +class DeadEasyUI(object): + """ + A tool for making a small program that can be run either programmatically or as a command-line user interface. Use + this class by extending it. Whatever methods you add to the extended class will become commands that you can run. + + Example (which you can also run at the bottom of this file): + + class MyUserInterface(DeadEasyUI): + + def isprime(self, number): + print '{} is {}prime'.format(number, {False: '', True: 'not '}[next(iter([True for i in range(3, number) if number%i==0]+[False]))]) + + def showargs(self, arg1, arg2): + print ' arg1: {}: {}\n arg2: {}: {}\n'.format(arg1, type(arg1), arg2, type(arg2)) + + MyUserInterface().launch(run_in_loop=True) + + This will bring up a UI + + ==== MyUserInterface console menu ==== + Command Options: {isprime, showargs} + Enter Command or "help" for help >> isprime 117 + 117 is not prime + ==== MyUserInterface console menu ==== + Command Options: {isprime, showargs} + Enter Command or "help" for help >> showargs 'abc' arg2=4.3 + arg1: abc: + arg2: 4.3: + + """ + + def _get_menu_string(self): + return '==== {} console menu ====\n'.format(self.__class__.__name__) if (self.__doc__ is None or self.__doc__ == DeadEasyUI.__doc__) else self.__doc__ if self.__doc__.endswith('\n') else self.__doc__+ '\n' + + def launch(self, prompt = 'Enter Command or "h" for help >> ', run_in_loop = True, arg_handling_mode='fallback'): + """ + Launch a command-line UI. + :param prompt: + :param run_in_loop: + :param arg_handling_mode: Can be ('str', 'guess', 'literal') + 'str': Means pass all args to the method as strings + 'literal': Means use eval to parse each arg. + 'fallback': Try literal parsing, and if it fails, fall back to string + :return: + """ + + def linenumber_of_member(k, m): + try: + return m.__func__.__code__.co_firstlineno + except AttributeError: + return -1 + + mymethods = sorted(inspect.getmembers(self, predicate=inspect.ismethod), + key = lambda pair: linenumber_of_member(*pair)) + mymethods = [(method_name, method) for method_name, method in mymethods if method_name!='launch' and not method_name.startswith('_')] + mymethods = OrderedDict(mymethods) + + options_doc = 'Command Options: {{{}}}'.format(', '.join([k for k in list(mymethods.keys())]+['quit', 'help'])) + + skip_info = False + while True: + doc = self._get_menu_string() + + if not skip_info: + print('{}{}'.format(doc, options_doc)) + user_input = input(' {}'.format(prompt)) + cmd, args, kwargs = parse_user_function_call(user_input, arg_handling_mode=arg_handling_mode) + + if cmd is None: + continue + + skip_info = False + if cmd in ('h', 'help'): + print(self._get_help_string(mymethods=mymethods, method_names_for_help=[args[0]] if len(args) > 0 else None)) + skip_info = True + continue + elif cmd in ('q', 'quit'): + print('Quitting {}. So long.'.format(self.__class__.__name__)) + break + elif cmd in mymethods: + mymethods[cmd](*args, **kwargs) + else: + print("Unknown command '{}'. Options are {}".format(cmd, list(mymethods.keys())+['help'])) + skip_info = True + continue + if not run_in_loop: + break + + def _get_help_string(self, mymethods, method_names_for_help=None): + string = '' + string += '----------------------------\n' + string += "To run a command, type the method name and then space-separated arguments. e.g.\n >> my_method 1 'string-arg' named_arg=2\n\n" + if method_names_for_help is None: + method_names_for_help = list(mymethods.keys()) + if len(method_names_for_help) == 0: + string+= "Class {} has no methods, and is therefor a useless console menu. Add methods.\n".format( + self.__class__.__name__) + for method_name in method_names_for_help: + argspec = inspect.getargspec(mymethods[method_name]) + default_start_ix = len(argspec.args) if argspec.defaults is None else len(argspec.args) - len( + argspec.defaults) + argstring = ' '.join([repr(a) for a in argspec.args[1:default_start_ix]] + + ['[{}={}]'.format(a, repr(v)) for a, v in zip(argspec.args[default_start_ix:], argspec.defaults if argspec.defaults is not None else [])]) \ + if len(argspec.args)>1 else '' + doc = mymethods[method_name].__doc__ if mymethods[method_name].__doc__ is not None else '' + string+= '- {} {}: {}\n'.format(method_name, argstring, doc) + string += '----------------------------\n' + return string + + +def parse_user_function_call(cmd_str, arg_handling_mode = 'fallback'): + """ + A simple way to parse a user call to a python function. The purpose of this is to make it easy for a user + to specify a python function and the arguments to call it with from the console. Example: + + parse_user_function_call("my_function 1 'two' a='three'") == ('my_function', (1, 'two'), {'a': 'three'}) + + Other code can use this to actually call the function + + Parse arguments to a Python function + :param str cmd_str: The command string. e.g. "my_function 1 'two' a='three'" + :param forgive_unquoted_strings: Allow for unnamed string args to be unquoted. + e.g. "my_function my_arg_string" would interpreted as "my_function 'my_arg_string' instead of throwing an error" + :return: The function name, args, kwargs + :rtype: Tuple[str, Tuple[Any]. Dict[str: Any] + """ + + assert arg_handling_mode in ('str', 'literal', 'fallback') + + # def _fake_func(*args, **kwargs): + # Just exists to help with extracting args, kwargs + # return args, kwargs + + cmd_args = shlex.split(cmd_str, posix=False) + assert len(cmd_args) == len(shlex.split(cmd_str, posix=True)), "Parse error on string '{}'. You're not allowed having spaces in the values of string keyword args:".format(cmd_str) + + if len(cmd_args)==0: + return None, None, None + + func_name = cmd_args[0] + + def parse_arg(arg_str): + if arg_handling_mode=='str': + return arg_str + elif arg_handling_mode=='literal': + return eval(arg_str, {}, {}) + else: + try: + return eval(arg_str, {}, {}) + except: + return arg_str + + args = [] + kwargs = {} + for arg in cmd_args[1:]: + if '=' not in arg: # Positional + assert len(kwargs)==0, 'You entered a positional arg after a keyword arg. Keyword args {} aleady exist.'.format(kwargs) + args.append(parse_arg(arg)) + else: + arg_name, arg_val = arg.split('=', 1) + kwargs[arg_name] = parse_arg(arg_val) + + return func_name, args, kwargs + + # if forgive_unquoted_strings: + # cmd_args = [cmd_args[0]] + [_quote_args_that_you_forgot_to_quote(arg) for arg in cmd_args[1:]] + # + # args, kwargs = eval('_fake_func(' + ','.join(cmd_args[1:]) + ')', {'_fake_func': _fake_func}, {}) + # return func_name, args, kwargs + + +def _quote_args_that_you_forgot_to_quote(arg): + """Wrap the arg in quotes if the user failed to do it.""" + if arg.startswith('"') or arg.startswith("'"): + return arg + elif '=' in arg and sum(a=='=' for a in arg)==1: # Keyword + name, val = arg.split('=') + if val[0].isalpha(): + return '{}="{}"'.format(name, val) + else: + return arg + else: + if arg[0].isalpha(): + return '"{}"'.format(arg) + else: + return arg + + +if __name__ == '__main__': + + class MyUserInterface(DeadEasyUI): + + def isprime(self, number): + print('{} is {}prime'.format(number, {False: '', True: 'not '}[next(iter([True for i in range(3, number) if number % i==0]+[False]))])) + + def showargs(self, arg1, arg2): + print(' arg1: {}: {}\n arg2: {}: {}\n'.format(arg1, type(arg1), arg2, type(arg2))) + + MyUserInterface().launch(run_in_loop=True) diff --git a/artemis/general/display.py b/artemis/general/display.py index 1d95cb16..03957f6d 100644 --- a/artemis/general/display.py +++ b/artemis/general/display.py @@ -39,7 +39,13 @@ def dict_to_str(d): :param dict d: A dict :return str: A nice, formatted version of this dict. """ - return ', '.join('{}:{}'.format(k, repr(v)) for k, v in d.items()) + if isinstance(d, (list, tuple)) and all(isinstance(el, (list, tuple)) and len(el)==2 for el in d): + items = d + elif isinstance(d, dict): + items = d.items() + else: + raise Exception("Can't interpret object {}".format(d)) + return ', '.join('{}:{:.3g}'.format(k, v) if isinstance(v, float) else '{}:{}'.format(k, repr(v)) for k, v in items) def pyfuncstring_to_tex(pyfuncstr): diff --git a/artemis/general/duck.py b/artemis/general/duck.py index bb0cd6d3..063dd45a 100644 --- a/artemis/general/duck.py +++ b/artemis/general/duck.py @@ -410,9 +410,12 @@ def __getitem__(self, indices): new_substruct = self._struct[first_selector] if isinstance(new_substruct, UniversalCollection) and not isinstance(new_substruct, Duck): # This will happen if the selector is a slice or something... new_substruct = Duck(new_substruct, recurse=False) + if len(indices)==1: # Case 1: Simple... this is the last selector, so we can just return it. return new_substruct - else: # Case 2: + else: # Case 2: There are deeper indices to get + if not isinstance(new_substruct, Duck): + raise KeyError('Leave value "{}" can not be broken into with {}'.format(new_substruct, indices[1:])) if isinstance(first_selector, (list, np.ndarray, slice)): # Sliced selection, with more sub-indices return new_substruct.map(lambda x: x.__getitem__(indices[1:])) else: # Simple selection, with more sub-indices @@ -634,8 +637,15 @@ def open(self, *ixs): ixs = tuple(-1 if ix is next else ix for ix in ixs) return self[ixs] - def has_key(self, *key_chain): - return self._struct.has_key() + def has_key(self, key, *deeper_keys): + + try: + self[(key, )+deeper_keys] + return True + except (KeyError, AttributeError): + return False + # Alternate definition that checks for exact key values but does not handle slices, negative indices, etc. + # return self._struct.has_key(key) and (len(deeper_keys)==0 or (isinstance(self._struct[key], Duck) and self._struct[key].has_key(*deeper_keys))) def keys(self, depth=None): if depth is None: diff --git a/artemis/general/mymath.py b/artemis/general/mymath.py index 4490ad5e..b02bfbed 100644 --- a/artemis/general/mymath.py +++ b/artemis/general/mymath.py @@ -289,6 +289,21 @@ def is_parallel(a, b, angular_tolerance = 1e-7): return angle < angular_tolerance +def vector_projection(v, u, axis=-1, norm_factor=0.): + """ + Project v onto u. + :param v: A vector or collection of vectors + :param u: A vector or collection of vectors which is broadcastable against v + :param axis: The axis of v along which the vector is defined. + :return: An array the same shape as v projected onto u. + """ + return u*((norm_factor+(u*v).sum(axis=axis, keepdims=True)) / (norm_factor+(u*u).sum(axis=axis, keepdims=True))) + # true_axis = v.ndim+axis if axis<0 else axis + # u_norm = u/(u*u).sum(axis=axis, keepdims=True) + # vu_dot = (u*v).sum(axis=axis, keepdims=True) / (u*u).sum(axis=axis, keepdims=True) + # return vu_dot*u + + def align_curves(xs, ys, n_bins='median', xrange = ('min', 'max'), spacing = 'lin'): """ Given multiple curves with different x-coordinates, interpolate so that each has the same x points. diff --git a/artemis/general/nested_structures.py b/artemis/general/nested_structures.py index f9c9cfe6..74e3f992 100644 --- a/artemis/general/nested_structures.py +++ b/artemis/general/nested_structures.py @@ -1,4 +1,6 @@ -from collections import OrderedDict +import inspect +from collections import OrderedDict, Iterable +from functools import partial import numpy as np from six import string_types, next @@ -64,11 +66,19 @@ def flatten_struct(struct, primatives = PRIMATIVE_TYPES, custom_handlers = {}, _primitive_containers = (list, tuple, dict, set) -def _is_primitive_container(obj): +def isgenerator(iterable): + return hasattr(iterable,'__iter__') and not hasattr(iterable,'__len__') + + +def is_primitive_container(obj): return isinstance(obj, _primitive_containers) or hasattr(obj, '_fields') -def get_meta_object(data_object, is_container = _is_primitive_container): +def is_container_or_generator(obj): + return isinstance(obj, _primitive_containers) or hasattr(obj, '_fields') or isgenerator(obj) + + +def get_meta_object(data_object, is_container = is_primitive_container, flat_list = None): """ Given an arbitrary data structure, return a "meta object" which is the same structure, except all non-container objects are replaced by their types. @@ -82,16 +92,52 @@ def get_meta_object(data_object, is_container = _is_primitive_container): """ if is_container(data_object): if hasattr(data_object, '_fields'): - return type(data_object)(*(get_meta_object(x, is_container=is_container) for x in data_object)) + return type(data_object)(*(get_meta_object(x, is_container=is_container, flat_list=flat_list) for x in data_object)) elif isinstance(data_object, (list, tuple, set)): - return type(data_object)(get_meta_object(x, is_container=is_container) for x in data_object) + return type(data_object)(get_meta_object(x, is_container=is_container, flat_list=flat_list) for x in data_object) elif isinstance(data_object, dict): - return type(data_object)((k, get_meta_object(v, is_container=is_container)) for k, v in data_object.items()) + return type(data_object)((k, get_meta_object(v, is_container=is_container, flat_list=flat_list)) for k, v in data_object.items()) + elif isgenerator(data_object): + return tuple(get_meta_object(x, is_container=is_container, flat_list=flat_list) for x in data_object) + else: + raise Exception("Don't know how to handle containier: {}".format(data_object)) else: + if flat_list is not None: + flat_list.append(data_object) return type(data_object) -def broadcast_into_meta_object(meta_object, data_object, is_container = _is_primitive_container, check_types = True): +def get_leaves_and_rebuilder(nested_object, is_container = is_container_or_generator, check_types=True, assert_fully_used=True): + """ + Given a nested structure, get the leaves in the structure, and a function to rebuild them. + + e.g. + flat_list, rebuilder = get_leaves_and_rebuilder({'a': 1, 'b': (2, 3)}) + assert flat_list == [1, 2, 3] + assert rebuilder(a*2 for a in flat_list) == {'a': 2, 'b': (4, 6)} + + :param nested_object An arbitrarily nested object + :return Tuple[List, Callable[[Sequence], Any]] : Return the flattened sequence and the function required to rebuild into the nested format. + """ + # TODO: Consider making leaves a generator so this could be used for streams. + leaves = [] + meta_obj = get_meta_object(nested_object, is_container=is_container, flat_list=leaves) + return leaves, (lambda data_iteratable: _fill_meta_object(meta_object=meta_obj, data_iteratable=iter(data_iteratable), check_types=check_types, assert_fully_used=assert_fully_used, is_container_func=is_container)) + + +def get_leaves(nested_object, is_container = is_primitive_container): + """ + + :param nested_object: + :param is_container: + :return: + """ + leaves = [] + meta_obj = get_meta_object(nested_object, is_container=is_container, flat_list=leaves) + return leaves + + +def broadcast_into_meta_object(meta_object, data_object, is_container = is_primitive_container, check_types = True): """ "Broadcast" the data object into the meta object. This puts the data into the structure of the meta-object. E.g. @@ -159,7 +205,7 @@ def __repr__(self): def __eq__(self, other): return self.meta_object == other.meta_object - def broadcast(self, data_object, is_container = _is_primitive_container, check_types=True): + def broadcast(self, data_object, is_container = is_primitive_container, check_types=True): """ "Broadcast" a data object to have the given structure. e.g. @@ -172,7 +218,7 @@ def broadcast(self, data_object, is_container = _is_primitive_container, check_t """ return broadcast_into_meta_object(meta_object=self.meta_object, data_object=data_object, is_container=is_container, check_types=check_types) - def get_leaves(self, data_object, check_types = True, broadcast=False, is_container = _is_primitive_container): + def get_leaves(self, data_object, check_types = True, broadcast=False, is_container = is_primitive_container): """ :param data_object: Given a nested object, get the "leaf" values in Depth-First Order :return: A list of leaf values. @@ -183,7 +229,7 @@ def get_leaves(self, data_object, check_types = True, broadcast=False, is_contai self.check_type(data_object) return get_leaf_values(data_object, is_container_func=is_container) - def expand_from_leaves(self, leaves, check_types = True, assert_fully_used=True, is_container_func = _is_primitive_container): + def expand_from_leaves(self, leaves, check_types = True, assert_fully_used=True, is_container_func = is_primitive_container): """ Given an iterator of leaf values, fill the meta-object represented by this type. @@ -195,7 +241,7 @@ def expand_from_leaves(self, leaves, check_types = True, assert_fully_used=True, return _fill_meta_object(self.meta_object, (x for x in leaves), check_types=check_types, assert_fully_used=assert_fully_used, is_container_func=is_container_func) @staticmethod - def from_data(data_object, is_container_func = _is_primitive_container): + def from_data(data_object, is_container_func = is_primitive_container): """ :param data_object: A nested data object :param is_container_func: A callback which returns True if an object is to be considered a container and False otherwise @@ -204,6 +250,7 @@ def from_data(data_object, is_container_func = _is_primitive_container): return NestedType(get_meta_object(data_object, is_container=is_container_func)) + def isnamedtuple(thing): return hasattr(thing, '_fields') and len(thing.__bases__)==1 and thing.__bases__[0]==tuple @@ -212,7 +259,7 @@ def isnamedtupleinstance(thing): return isnamedtuple(thing.__class__) -def get_leaf_values(data_object, is_container_func = _is_primitive_container): +def get_leaf_values(data_object, is_container_func = is_primitive_container): """ Collect leaf values of a nested data_obj in Depth-First order. @@ -244,7 +291,7 @@ def get_leaf_values(data_object, is_container_func = _is_primitive_container): return [data_object] -def _fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, check_types = True, is_container_func = _is_primitive_container): +def _fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, check_types = True, is_container_func = is_primitive_container): """ Fill the data from the iterable into the meta_object. :param meta_object: A nested type descripter. See NestedType init @@ -294,7 +341,7 @@ def nested_map(func, *nested_objs, **kwargs): :param is_container_func: A callback which returns True if an object is to be considered a container and False otherwise :return: A nested objectect with the same structure, but func applied to every value. """ - is_container_func = kwargs['is_container_func'] if 'is_container_func' in kwargs else _is_primitive_container + is_container_func = kwargs['is_container_func'] if 'is_container_func' in kwargs else is_primitive_container check_types = kwargs['check_types'] if 'check_types' in kwargs else False assert len(nested_objs)>0, 'nested_map requires at least 2 args' assert callable(func), 'func must be a function with one argument.' diff --git a/artemis/general/progress_indicator.py b/artemis/general/progress_indicator.py index 4e365eed..42554413 100644 --- a/artemis/general/progress_indicator.py +++ b/artemis/general/progress_indicator.py @@ -38,15 +38,16 @@ def __init__(self, expected_iterations=None, name=None, update_every = (2, 'seco def __call__(self, iteration = None): self.print_update(iteration) - def print_update(self, progress=None): + def print_update(self, progress=None, info=None): self._current_time = time.time() elapsed = self._current_time - self._start_time if self._expected_iterations is None: if self._should_update(): - print ('Progress{}: {:.1f}s Elapsed. {}. {} calls averaging {:.2g} calls/s'.format( + print ('Progress{}: {:.1f}s Elapsed{}{}. {} calls averaging {:.2g} calls/s'.format( '' if self.name is None else ' of '+self.name, elapsed, - self._post_info_callback() if self._post_info_callback is not None else '', + '. '+ self._post_info_callback() if self._post_info_callback is not None else '', + ', '+ info if info is not None else '', self._i+1, (self._i+1)/elapsed )) @@ -62,15 +63,16 @@ def print_update(self, progress=None): else: remaining = elapsed * (1 / frac - 1) if frac > 0 else float('NaN') elapsed = self._current_time - self._start_time - print('Progress{}: {}%. {:.1f}s Elapsed, {:.1f}s Remaining{}. {} {} calls averaging {:.2g} calls/s'.format( - '' if self.name is None else ' of '+self.name, - int(100*frac), - elapsed, - remaining, - ', {:.1f}s Total'.format(elapsed+remaining) if self.show_total else '', - self._post_info_callback() if self._post_info_callback is not None else '', - self._i+1, - (self._i+1)/elapsed + print('Progress{name}: {progress}%. {elapsed:.1f}s Elapsed, {remaining:.1f}s Remaining{total}. {info_cb}{info}{n_calls} calls averaging {rate:.2g} calls/s'.format( + name = '' if self.name is None else ' of '+self.name, + progress = int(100*frac), + elapsed = elapsed, + remaining = remaining, + total = ', {:.1f}s Total'.format(elapsed+remaining) if self.show_total else '', + info_cb = '. '+ self._post_info_callback() if self._post_info_callback is not None else '', + info=', '+ info if info is not None else '', + n_calls=self._i+1, + rate=(self._i+1)/elapsed )) self._last_update = progress if self._update_unit == 'iterations' else self._current_time self._i += 1 diff --git a/artemis/general/should_be_builtins.py b/artemis/general/should_be_builtins.py index 1749b266..b686c790 100644 --- a/artemis/general/should_be_builtins.py +++ b/artemis/general/should_be_builtins.py @@ -162,8 +162,9 @@ def remove_duplicates(sequence, hashable=True, key=None, keep_last=False): :param keep_last: Keep the last element, rather than the first (only makes sense if key is not None) :returns: A list that maintains the order, but with duplicates removed """ + sequence = list(sequence) is_dup = detect_duplicates(sequence, hashable=hashable, key=key, keep_last=keep_last) - return [x for x, is_duplicate in zip(sequence, is_dup) if not is_duplicate] + return (x for x, is_duplicate in zip(sequence, is_dup) if not is_duplicate) def uniquify_duplicates(sequence_of_strings): @@ -257,7 +258,7 @@ def separate_common_items(list_of_lists): if are_dicts: list_of_lists = [el.items() for el in list_of_lists] all_items = [item for list_of_items in list_of_lists for item in list_of_items] - common_items = remove_duplicates([k for k, c in count_unique_items(all_items) if c==len(list_of_lists)], hashable=False) + common_items = list(remove_duplicates([k for k, c in count_unique_items(all_items) if c==len(list_of_lists)], hashable=False)) different_items = [[item for item in list_of_items if item not in common_items] for list_of_items in list_of_lists] if are_dicts: return dict(common_items), [dict(el) for el in different_items] @@ -461,4 +462,20 @@ def unzip(iterable): :param iterable: Any iterable object yielding N-tuples :return: A N-tuple of iterables """ - return zip(*iterable) \ No newline at end of file + return zip(*iterable) + + +def entries_to_table(tuplelist, fill_value = None): + """ + Turn a bunch of entries into a table. e.g. + + >>> entries_to_table([[('a', 1), ('b', 2)], [('a', 3), ('b', 4), ('c', 5)]]) + (['a', 'b', 'c'], [[1, 2, None], [3, 4, 5]]) + + :param Sequence[Sequence[Tuple[str, Any]]] tuplelist: N_samples samples of N_observations observations, each represented by (observation_name, observation_value) + :return Tuple[Sequence[str], Sequence[Sequence[Any]]: (observation_names, data) A list of observation_names and the data suitable for tabular plotting. + """ + all_entries = list(remove_duplicates((k for sample in tuplelist for k, v in (sample.items() if isinstance(sample, dict) else sample)))) + data = [dict(sample) for sample in tuplelist] + new_data = [[d[k] if k in d else fill_value for k in all_entries] for d in data] + return all_entries, new_data diff --git a/artemis/general/table_ui.py b/artemis/general/table_ui.py new file mode 100644 index 00000000..1f6658de --- /dev/null +++ b/artemis/general/table_ui.py @@ -0,0 +1,121 @@ +import numpy as np +from tabulate import tabulate +from artemis.general.dead_easy_ui import DeadEasyUI + + +class TableExplorerUI(DeadEasyUI): + + def __init__(self, table_data, col_headers=None, row_headers=None, col_indices=None, row_indices = None): + + assert all(len(r)==len(table_data[0]) for r in table_data), "All rows of table data must have the same length. Got lengths: {}".format([len(r) for r in table_data]) + table_data = np.array(table_data, dtype=object) + assert table_data.ndim==2, "Table must consist of 2d data" + + assert col_headers is None or len(col_headers)==table_data.shape[1] + assert row_headers is None or len(row_headers)==table_data.shape[0] + + self._table_data = table_data + self._col_indices = np.array(col_indices) if col_indices is not None else None + self._row_indices = np.array(row_indices) if row_indices is not None else None + self._col_headers = np.array(col_headers) if col_headers is not None else None + self._row_headers = np.array(row_headers) if row_headers is not None else None + self._old_data_buffer = [] + + @property + def n_rows(self): + return self._table_data.shape[0] + + @property + def n_cols(self): + return self._table_data.shape[1] + + def _get_full_table(self): + n_total_rows = 1 + int(self._col_headers is not None) + self._table_data.shape[0] + n_total_cols = 1 + int(self._row_headers is not None) + self._table_data.shape[1] + table_data = np.empty((n_total_rows, n_total_cols), dtype=object) + table_data[:2, :2] = '' + table_data[0, -self.n_cols:] = self._col_indices if self._col_indices is not None else ['{}'.format(i) for i in range(1, self.n_cols+1)] + table_data[-self.n_rows:, 0] = self._row_indices if self._row_indices is not None else ['{}'.format(i) for i in range(1, self.n_rows+1)] + if self._col_headers is not None: + table_data[1, -self.n_cols:] = self._col_headers + if self._row_headers is not None: + table_data[-self.n_rows:, 1] = self._row_headers + table_data[-self.n_rows:, -self.n_cols:] = self._table_data + return table_data + + def _get_menu_string(self): + table_str = tabulate(self._get_full_table()) + return '{}\n'.format(table_str) + + def _backup(self): + self._old_data_buffer.append((self._table_data, self._row_headers, self._row_indices, self._col_headers, self._col_indices)) + + def undo(self): + if len(self._old_data_buffer)==0: + print("Can't undo, no history") + else: + self._table_data, self._row_headers, self._row_indices, self._col_headers, self._col_indices = self._old_data_buffer.pop() + + def _parse_indices(self, user_range): + if isinstance(user_range, str): + user_range = user_range.split(',') + return [int(i)-1 for i in user_range] + + def _reindex(self, row_ixs=None, col_ixs=None): + self._backup() + if row_ixs is not None: + self._table_data = self._table_data[row_ixs, :] + if self._row_headers is not None: + self._row_headers = self._row_headers[row_ixs] + if self._row_indices is not None: + self._row_indices = self._row_indices[row_ixs] + if col_ixs is not None: + self._table_data = self._table_data[:, col_ixs] + if self._col_headers is not None: + self._col_headers = self._col_headers[col_ixs] + if self._col_indices is not None: + self._col_indices = self._col_indices[col_ixs] + + def delcol(self, user_range): + self._reindex(col_ixs=[i for i in range(self.n_cols) if i not in self._parse_indices(user_range)]) + + def delrow(self, user_range): + self._reindex(row_ixs=[i for i in range(self.n_rows) if i not in self._parse_indices(user_range)]) + + def shufrows(self, user_range): + indices = self._parse_indices(user_range) + self._reindex(row_ixs=indices + [i for i in range(self.n_rows) if i not in indices]) + + def shufcols(self, user_range): + indices = self._parse_indices(user_range) + self._reindex(col_ixs=indices + [i for i in range(self.n_cols) if i not in indices]) + + def sortrows(self, by_cols=None, shuffle_cols=True): + key_order_indices = self._parse_indices(by_cols) if by_cols is not None else range(self.n_cols) + + sorting_data = self._table_data[:, key_order_indices[::-1]].copy() + for col in range(sorting_data.shape[1]): + if np.mean([np.isreal(x) for x in sorting_data[:, col]]) % 1 != 0: # Indicating not some numeric and some non-numeric data + sorting_data[:, col] = [(not np.isreal(x), x) for x in sorting_data[:, col]] + + indices = np.lexsort(sorting_data.T) + self._reindex(row_ixs=indices) + if shuffle_cols: + self.shufcols(by_cols) + + def sortcols(self, by_rows=None, shuffle_rows=True): + key_order_indices = self._parse_indices(by_rows) if by_rows is not None else range(self.n_rows) + indices = np.lexsort(self._table_data[key_order_indices[::-1], :]) + self._reindex(col_ixs=indices) + if shuffle_rows: + self.shufrows(by_rows) + + +if __name__ == '__main__': + + ui = TableExplorerUI( + col_headers=['param1', 'size', 'cost'], + row_headers=['exp1', 'exp2', 'exp3'], + table_data= [[4, 'Bella', 100], [3, 'Abe', 120], [4, 'Clarence', 117]], + ) + ui.launch() diff --git a/artemis/general/tables.py b/artemis/general/tables.py index 48cb81ed..2ea64e74 100644 --- a/artemis/general/tables.py +++ b/artemis/general/tables.py @@ -7,7 +7,7 @@ def build_table(lookup_fcn, row_categories, column_categories, clear_repeated_headers = True, prettify_labels = True, - row_header_labels = None, remove_unchanging_cols = False): + row_header_labels = None, remove_unchanging_cols = False, include_row_category=True, include_column_category = True): """ Build the rows of a table. You can feed these rows into tabulate to generate pretty things. @@ -43,11 +43,13 @@ def lookup_function(prisoner_a_choice, prisoner_b_choice): :param clear_repeated_headers: True to not repeat row headers. :param row_header_labels: Labels for the row headers. :param remove_unchanging_cols: Remove columns for which all d - :return: A list of rows. + :param include_row_category: Include the row category in the table (as the first column in each row) + :param include_column_category: Include the column category in the table (as the first row of each column) + :return Sequence[Sequence[Any]]: A list of lists containing the entries of the table. """ # Now, build that table! - single_row_category = all(isinstance(c, string_types) for c in row_categories) - single_column_category = all(isinstance(c, string_types) for c in column_categories) + single_row_category = all(not isinstance(c, (list, tuple)) for c in row_categories) + single_column_category = all(not isinstance(c, (list, tuple)) for c in column_categories) if single_row_category: row_categories = [row_categories] @@ -57,10 +59,11 @@ def lookup_function(prisoner_a_choice, prisoner_b_choice): assert len(row_header_labels) == len(row_categories) rows = [] column_headers = list(zip(*itertools.product(*column_categories))) - for i, c in enumerate(column_headers): - row_header = row_header_labels if row_header_labels is not None and i==len(column_headers)-1 else [' ']*len(row_categories) - row = row_header+(blank_out_repeats(c) if clear_repeated_headers else list(c)) - rows.append([prettify_label(el) for el in row] if prettify_labels else row) + if include_column_category: + for i, c in enumerate(column_headers): + row_header = [] if not include_row_category else row_header_labels if row_header_labels is not None and i==len(column_headers)-1 else [' ']*len(row_categories) + row = row_header+(blank_out_repeats(c) if clear_repeated_headers else list(c)) + rows.append([prettify_label(el) for el in row] if prettify_labels else row) last_row_data = [' ']*len(row_categories) for row_info in itertools.product(*row_categories): if clear_repeated_headers: @@ -70,7 +73,7 @@ def lookup_function(prisoner_a_choice, prisoner_b_choice): if prettify_labels: row_header = [prettify_label(str(el)) for el in row_header] data = [lookup_fcn(row_info[0] if single_row_category else row_info, column_info[0] if single_column_category else column_info) for column_info in itertools.product(*column_categories)] - rows.append(list(row_header) + data) + rows.append(list(row_header) + data if include_row_category else data) assert all_equal((len(r) for r in rows)), "All rows must have equal length. They now have lengths: {}".format([len(r) for r in rows]) if remove_unchanging_cols: diff --git a/artemis/general/test_checkpoint_counter.py b/artemis/general/test_checkpoint_counter.py index 2562de1c..2bfe796e 100644 --- a/artemis/general/test_checkpoint_counter.py +++ b/artemis/general/test_checkpoint_counter.py @@ -1,6 +1,6 @@ from itertools import count from artemis.general.checkpoint_counter import CheckPointCounter, Checkpoints - +import numpy as np __author__ = 'peter' @@ -27,9 +27,14 @@ def test_checkpoint_counter(): def test_checkpoints(): is_test = Checkpoints(('exp', 10, .1)) - assert [a for a in range(100) if is_test()]==[0, 10, 22, 37, 54, 74, 97] + is_test = Checkpoints({0: 0.25, 0.75: 0.5, 2.: 1}) + assert np.allclose([a for a in np.arange(0, 6, 0.1) if is_test(a)], [0, 0.3, 0.5, 0.8, 1.3, 1.8, 2.3, 3.3, 4.3, 5.3]) + + is_test = Checkpoints({1: 0.5, 2: 1, 5: 3}) + assert np.allclose([a for a in np.arange(0, 12, 0.1) if is_test(a)], [1, 1.5, 2, 3, 4, 5, 8, 11]) + if __name__ == '__main__': test_checkpoint_counter() diff --git a/artemis/general/test_duck.py b/artemis/general/test_duck.py index 9228b26f..275c3640 100644 --- a/artemis/general/test_duck.py +++ b/artemis/general/test_duck.py @@ -541,30 +541,44 @@ def test_occasional_value_filter(): assert a.filter[:, 'b'] == [2, 5] +def test_has_key(): + duck = _get_standard_test_duck() + assert duck.has_key('b') + assert duck.has_key('b', 0) + assert duck.has_key('b', 0, 'subfield1') + assert not duck.has_key('b', 0, 'subfield1', 'dadadada') + assert not duck.has_key('b', 0, 'subfield1XXX') + assert duck.has_key('b', -1, 'subfield1') + assert duck.has_key('b', slice(None), 'subfield1') + assert not duck.has_key(slice(None), slice(None), 'subfield1') + assert not duck.has_key('q') + + if __name__ == '__main__': - # test_so_demo() - # test_dict_assignment() - # test_dictarraylist() - # test_simple_growing() - # test_open_key() - # test_open_next() - # test_to_struct() - # test_next_elipsis_assignment() - # test_slice_assignment() - # test_arrayify_empty_stuct() - # test_slice_on_start() - # test_assign_tuple_keys() - # test_broadcast_bug() - # test_key_values() - # test_description() - # test_duck_array_build() - # test_split_get_assign() - # test_assign_from_struct() - # test_arrayify_axis_demo() - # test_string_slices() - # test_reasonable_errors_on_wrong_keys() - # test_reasonable_error_messages() + test_so_demo() + test_dict_assignment() + test_dictarraylist() + test_simple_growing() + test_open_key() + test_open_next() + test_to_struct() + test_next_elipsis_assignment() + test_slice_assignment() + test_arrayify_empty_stuct() + test_slice_on_start() + test_assign_tuple_keys() + test_broadcast_bug() + test_key_values() + test_description() + test_duck_array_build() + test_split_get_assign() + test_assign_from_struct() + test_arrayify_axis_demo() + test_string_slices() + test_reasonable_errors_on_wrong_keys() + test_reasonable_error_messages() test_break_in() test_copy() test_key_get_on_set_bug() test_occasional_value_filter() + test_has_key() diff --git a/artemis/general/test_mymath.py b/artemis/general/test_mymath.py index 17d8a8b3..4e1acf10 100644 --- a/artemis/general/test_mymath.py +++ b/artemis/general/test_mymath.py @@ -4,7 +4,8 @@ from artemis.general.mymath import (softmax, cummean, cumvar, sigm, expected_sigm_of_norm, mode, cummode, normalize, is_parallel, align_curves, angle_between, fixed_diff, decaying_cumsum, geosum, selective_sum, - conv_fanout, conv2_fanout_map, proportional_random_assignment, clip_to_sum) + conv_fanout, conv2_fanout_map, proportional_random_assignment, clip_to_sum, + vector_projection) import numpy as np from six.moves import xrange @@ -295,6 +296,14 @@ def test_clip_to_sum(): assert np.array_equal(clip_to_sum([1,4,8,3], 20), [1,4,8,3]) +def test_projection(): + + v = np.array([[2, 2], [2, 1]]) + u = np.array([[0, 1], [1, 1]]) + v_proj_u = vector_projection(v, u, axis=1) + assert np.allclose(v_proj_u, [[0, 2], [1.5, 1.5]]) + + if __name__ == '__main__': test_decaying_cumsum() test_fixed_diff() @@ -314,4 +323,5 @@ def test_clip_to_sum(): test_fanout_map() test_conv2_fanout_map() test_proportional_random_assignment() - test_clip_to_sum() \ No newline at end of file + test_clip_to_sum() + test_projection() \ No newline at end of file diff --git a/artemis/general/test_nested_structures.py b/artemis/general/test_nested_structures.py index 19fa7f19..7639911d 100644 --- a/artemis/general/test_nested_structures.py +++ b/artemis/general/test_nested_structures.py @@ -8,7 +8,8 @@ from artemis.general.nested_structures import (flatten_struct, get_meta_object, NestedType, seqstruct_to_structseq, structseq_to_seqstruct, nested_map, - get_leaf_values, broadcast_into_meta_object) + get_leaf_values, broadcast_into_meta_object, + get_leaves_and_rebuilder) def test_flatten_struct(): @@ -152,6 +153,24 @@ def test_namedtuple_breakin(): assert struct.broadcast([1, 2]) == [thing(1, 1), thing(2, 2)] +def test_flatten_nested_struct_and_rebuild(): + + obj = [1, 2, {'a': (3, 4.), 'b': 'c'}] + flat_list, rebuilder = get_leaves_and_rebuilder(obj) + assert flat_list==[1, 2, 3, 4., 'c'] + new_obj = rebuilder(flat_list) + assert new_obj==obj + + obj = ((j for j in range(i)) for i in range(2, 5)) + flat_list, rebuilder = get_leaves_and_rebuilder(obj) + assert flat_list == [0, 1, 0, 1, 2, 0, 1, 2, 3] + assert rebuilder(flat_list) == ((0, 1), (0, 1, 2), (0, 1, 2, 3)) + assert rebuilder((f*2 for f in flat_list)) == ((0, 2), (0, 2, 4), (0, 2, 4, 6)) + + flat_list, rebuilder = get_leaves_and_rebuilder({'a': 1, 'b': (2, 3)}) + assert flat_list == [1, 2, 3] + assert rebuilder(a*2 for a in flat_list) == {'a': 2, 'b': (4, 6)} + if __name__ == '__main__': test_flatten_struct() @@ -163,4 +182,5 @@ def test_namedtuple_breakin(): test_nested_map_with_container_func() test_none_bug() test_nested_struct_broadcast() - test_namedtuple_breakin() \ No newline at end of file + test_namedtuple_breakin() + test_flatten_nested_struct_and_rebuild() \ No newline at end of file diff --git a/artemis/general/test_should_be_builtins.py b/artemis/general/test_should_be_builtins.py index e3afbf29..b6f8aa40 100644 --- a/artemis/general/test_should_be_builtins.py +++ b/artemis/general/test_should_be_builtins.py @@ -4,7 +4,7 @@ from artemis.general.should_be_builtins import itermap, reducemap, separate_common_items, remove_duplicates, \ detect_duplicates, remove_common_prefix, all_equal, get_absolute_module, insert_at, get_shifted_key_value, \ - divide_into_subsets + divide_into_subsets, entries_to_table __author__ = 'peter' @@ -32,11 +32,11 @@ def test_separate_common_items(): def test_remove_duplicates(): - assert remove_duplicates(['a', 'b', 'a', 'c', 'c'])==['a', 'b', 'c'] - assert remove_duplicates(['a', 'b', 'a', 'c', 'c'], keep_last=True)==['b', 'a', 'c'] - assert remove_duplicates(['Alfred', 'Bob', 'Cindy', 'Alina', 'Karol', 'Betty'], key=lambda x: x[0])==['Alfred', 'Bob', 'Cindy', 'Karol'] - assert remove_duplicates(['Alfred', 'Bob', 'Cindy', 'Alina', 'Karol', 'Betty'], key=lambda x: x[0], keep_last=True)==['Cindy', 'Alina', 'Karol', 'Betty'] - assert remove_duplicates(['Alfred', 'Bob', 'Cindy', 'Alina', 'Karol', 'Betty'], key=lambda x: x[0], keep_last=True, hashable=False)==['Cindy', 'Alina', 'Karol', 'Betty'] + assert list(remove_duplicates(['a', 'b', 'a', 'c', 'c']))==['a', 'b', 'c'] + assert list(remove_duplicates(['a', 'b', 'a', 'c', 'c'], keep_last=True))==['b', 'a', 'c'] + assert list(remove_duplicates(['Alfred', 'Bob', 'Cindy', 'Alina', 'Karol', 'Betty'], key=lambda x: x[0]))==['Alfred', 'Bob', 'Cindy', 'Karol'] + assert list(remove_duplicates(['Alfred', 'Bob', 'Cindy', 'Alina', 'Karol', 'Betty'], key=lambda x: x[0], keep_last=True))==['Cindy', 'Alina', 'Karol', 'Betty'] + assert list(remove_duplicates(['Alfred', 'Bob', 'Cindy', 'Alina', 'Karol', 'Betty'], key=lambda x: x[0], keep_last=True, hashable=False))==['Cindy', 'Alina', 'Karol', 'Betty'] def test_detect_duplicates(): @@ -117,6 +117,11 @@ def test_divide_into_subsets(): assert divide_into_subsets(range(9), subset_size=3) == [[0, 1, 2], [3, 4, 5], [6, 7, 8]] +def test_entries_to_table(): + + assert entries_to_table([[('a', 1), ('b', 2)], [('a', 3), ('b', 4), ('c', 5)]]) == (['a', 'b', 'c'], [[1, 2, None], [3, 4, 5]]) + + if __name__ == '__main__': test_separate_common_items() test_reducemap() @@ -128,4 +133,5 @@ def test_divide_into_subsets(): test_get_absolute_module() test_insert_at() test_get_shifted_key_value() - test_divide_into_subsets() \ No newline at end of file + test_divide_into_subsets() + test_entries_to_table() diff --git a/artemis/general/test_time_parser.py b/artemis/general/test_time_parser.py new file mode 100644 index 00000000..f11a05b4 --- /dev/null +++ b/artemis/general/test_time_parser.py @@ -0,0 +1,26 @@ +from datetime import timedelta + +from pytest import raises + +from artemis.general.time_parser import parse_time + + +def test_time_parser(): + + assert parse_time('8h') == timedelta(hours=8) + assert parse_time('3d8h') == timedelta(days=3, hours=8) + assert parse_time('5s') == timedelta(seconds=5) + assert parse_time('.25s') == timedelta(seconds=0.25) + assert parse_time('.25d4h') == timedelta(days=0.25, hours=4) + with raises(ValueError): + assert parse_time('0.0.25d4h') == timedelta(days=0.25, hours=4) + with raises(AssertionError): + assert parse_time('5hr') + with raises(AssertionError): + assert parse_time('5q') + with raises(AssertionError): + print(parse_time('5h4q')) + + +if __name__ == '__main__': + test_time_parser() diff --git a/artemis/general/time_parser.py b/artemis/general/time_parser.py index ae97a97a..ae94effa 100644 --- a/artemis/general/time_parser.py +++ b/artemis/general/time_parser.py @@ -4,24 +4,19 @@ from datetime import timedelta -regex = re.compile(r'((?P\d+?)hr)?((?P\d+?)m)?((?P\d+?)s)?') +regex = re.compile(r'^((?P[\.\d]+?)d)?((?P[\.\d]+?)h)?((?P[\.\d]+?)m)?((?P[\.\d]+?)s)?$') def parse_time(time_str): """ - Parse a time string e.g. (13m) into a timedelta object. + Parse a time string e.g. (2h13m) into a timedelta object. - Taken from virhilo at https://stackoverflow.com/a/4628148/851699 + Modified from virhilo's answer at https://stackoverflow.com/a/4628148/851699 :param time_str: A string identifying a duration. (eg. 2h13m) :return datetime.timedelta: A datetime.timedelta object """ parts = regex.match(time_str) - if not parts: - return - parts = parts.groupdict() - time_params = {} - for (name, param) in parts.items(): - if param: - time_params[name] = int(param) + assert parts is not None, "Could not parse any time information from '{}'. Examples of valid strings: '8h', '2d8h5m20s', '2m4s'".format(time_str) + time_params = {name: float(param) for name, param in parts.groupdict().items() if param} return timedelta(**time_params) diff --git a/artemis/plotting/data_conversion.py b/artemis/plotting/data_conversion.py index 404d2674..6f2922ba 100644 --- a/artemis/plotting/data_conversion.py +++ b/artemis/plotting/data_conversion.py @@ -276,3 +276,43 @@ def insert_data(self, data): def retrieve_data(self): return self._buffer[:self._index] + + +class ResamplingRecordBuffer(DataBuffer): + """ + Keeps a buffer of incoming data. When this data reaches the buffer size, it is culled (one of every cull_factor + samples is kept and the rest thrown away). Not that this will throw away some data. + """ + # TODO: Add option for averaging, instead of throwing away culled samples. + + def __init__(self, buffer_len, cull_factor=2): + self._buffer = None + self._buffer_len = buffer_len + self._index = 0 + self._cull_factor = cull_factor + self._sample_times = np.arange(buffer_len) + self._count = 0 + self._n_culls = 0 + + def insert_data(self, data): + + if self._count % (self._n_culls+1) == 0: + + if self._buffer is None: + shape = () if np.isscalar(data) else data.shape + dtype = data.dtype if isinstance(data, np.ndarray) else type(data) if isinstance(data, (int, float, bool)) else object + self._buffer = np.empty((self._buffer_len, )+shape, dtype = dtype) + + if self._index==self._buffer_len: + self._buffer[:int(np.ceil(self._buffer_len/float(self._cull_factor)))] = self._buffer[::self._cull_factor].copy() + self._sample_times = self._sample_times*self._cull_factor + self._index //= self._cull_factor + self._n_culls += 1 + + self._buffer[self._index] = data + self._index += 1 + + self._count+=1 + + def retrieve_data(self): + return self._sample_times[:self._index], self._buffer[:self._index] diff --git a/artemis/plotting/db_plotting.py b/artemis/plotting/db_plotting.py index 5938e84a..9bce645c 100644 --- a/artemis/plotting/db_plotting.py +++ b/artemis/plotting/db_plotting.py @@ -5,7 +5,7 @@ from artemis.config import get_artemis_config_value from artemis.general.checkpoint_counter import Checkpoints -from artemis.plotting.matplotlib_backend import BarPlot, BoundingBoxPlot +from artemis.plotting.matplotlib_backend import BarPlot, BoundingBoxPlot, ResamplingLineHistory from matplotlib.axes import Axes from matplotlib.gridspec import SubplotSpec from contextlib import contextmanager @@ -48,13 +48,16 @@ def dbplot(data, name = None, plot_type = None, axis=None, plot_mode = 'live', d :param data: Any data. Hopefully, we at dbplot will be able to figure out a plot for it. :param name: A name uniquely identifying this plot. - :param plot_type: A specialized constructor to be used the first time when plotting. You can also pass - certain string to give hints as to what kind of plot you want (can resolve cases where the given data could be - plotted in multiple ways): - 'line': Plots a line plot - 'img': An image plot - 'colour': A colour image plot - 'pic': A picture (no scale bars, axis labels, etc). + :param Union[Callable[[],LinePlot],str,Tuple[Callable, Dict]] plot_type : A specialized constructor to be used the + first time when plotting. Several predefined constructors are defined in the DBPlotTypes class - you can pass + those. For back-compatibility you can also pass a string matching the name of one of the fields in the DBPlotTypes + class. + DBPlotTypes.LINE: Plots a line plot + DBPlotTypes.IMG: An image plot + DBPlotTypes.COLOUR: A colour image plot + DBPlotTypes.PIC: A picture (no scale bars, axis labels, etc) + You can also, pass a tuple of (constructor, keyword_args) where keyword args is a dict of arcuments to the plot + constructor. :param axis: A string identifying which axis to plot on. By default, it is the same as "name". Only use this argument if you indend to make multiple dbplots share the same axis. :param plot_mode: Influences how the data should be used to choose the plot type: @@ -95,11 +98,16 @@ def dbplot(data, name = None, plot_type = None, axis=None, plot_mode = 'live', d if name not in suplot_dict: # Initialize new axis if isinstance(plot_type, str): - plot = PLOT_CONSTRUCTORS[plot_type]() + plot = DBPlotTypes.from_string(plot_type)() elif isinstance(plot_type, tuple): assert len(plot_type)==2 and isinstance(plot_type[0], str) and isinstance(plot_type[1], dict), 'If you specify a tuple for plot_type, we expect (name, arg_dict). Got: {}'.format(plot_type) plot_type_name, plot_type_args = plot_type - plot = PLOT_CONSTRUCTORS[plot_type_name](**plot_type_args) + if isinstance(plot_type_name, str): + plot = DBPlotTypes.from_string(plot_type_name)(**plot_type_args) + elif callable(plot_type_name): + plot = plot_type_name(**plot_type_args) + else: + raise Exception('The first argument of the plot type tuple must be a plot type name or a callable plot type constructor.') elif plot_type is None: plot = get_plot_from_data(data, mode=plot_mode) else: @@ -154,11 +162,8 @@ def dbplot(data, name = None, plot_type = None, axis=None, plot_mode = 'live', d if draw_now and not _hold_plots and (draw_every is None or ((fig, name) not in _draw_counters) or _draw_counters[fig, name]()): plot.plot() - if hang: - plt.figure(_DBPLOT_FIGURES[fig].figure.number) - plt.show() - else: - redraw_figure(_DBPLOT_FIGURES[fig].figure) + display_figure(_DBPLOT_FIGURES[fig].figure, hang=hang) + return _DBPLOT_FIGURES[fig].subplots[name].axis @@ -179,32 +184,36 @@ def dbplot(data, name = None, plot_type = None, axis=None, plot_mode = 'live', d _default_layout = 'grid' -PLOT_CONSTRUCTORS = { - 'line': LinePlot, - 'thick-line': partial(LinePlot, plot_kwargs={'linewidth': 3}), - 'pos_line': partial(LinePlot, y_bounds=(0, None), y_bound_extend=(0, 0.05)), - 'bbox': partial(BoundingBoxPlot, linewidth=2, axes_update_mode='expand'), - 'bbox_r': partial(BoundingBoxPlot, linewidth=2, color='r', axes_update_mode='expand'), - 'bbox_b': partial(BoundingBoxPlot, linewidth=2, color='b', axes_update_mode='expand'), - 'bbox_g': partial(BoundingBoxPlot, linewidth=2, color='g', axes_update_mode='expand'), - 'bar': BarPlot, - 'img': ImagePlot, - 'cimg': partial(ImagePlot, channel_first=True), - 'line_history': MovingPointPlot, - 'img_stable': partial(ImagePlot, only_grow_clims=True), - 'colour': partial(ImagePlot, is_colour_data=True), - 'equal_aspect': partial(ImagePlot, aspect='equal'), - 'image_history': MovingImagePlot, - 'fixed_line_history': partial(MovingPointPlot, buffer_len=100), - 'pic': partial(ImagePlot, show_clims=False, aspect='equal'), - 'notice': partial(TextPlot, max_history=1, horizontal_alignment='center', vertical_alignment='center', size='x-large'), - 'cost': partial(MovingPointPlot, y_bounds=(0, None), y_bound_extend=(0, 0.05)), - 'percent': partial(MovingPointPlot, y_bounds=(0, 100)), - 'trajectory': partial(Moving2DPointPlot, axes_update_mode='expand'), - 'trajectory+': partial(Moving2DPointPlot, axes_update_mode='expand', x_bounds=(0, None), y_bounds=(0, None)), - 'histogram': partial(HistogramPlot, edges = np.linspace(-5, 5, 20)), - 'cumhist': partial(CumulativeLineHistogram, edges = np.linspace(-5, 5, 20)), - } +class DBPlotTypes: + LINE= LinePlot + THICK_LINE= partial(LinePlot, plot_kwargs={'linewidth': 3}) + POS_LINE= partial(LinePlot, y_bounds=(0, None), y_bound_extend=(0, 0.05)) + BBOX= partial(BoundingBoxPlot, linewidth=2, axes_update_mode='expand') + BBOX_R= partial(BoundingBoxPlot, linewidth=2, color='r', axes_update_mode='expand') + BBOX_B= partial(BoundingBoxPlot, linewidth=2, color='b', axes_update_mode='expand') + BBOX_G= partial(BoundingBoxPlot, linewidth=2, color='g', axes_update_mode='expand') + BAR= BarPlot + IMG= ImagePlot + CIMG= partial(ImagePlot, channel_first=True) + LINE_HISTORY= MovingPointPlot + IMG_STABLE= partial(ImagePlot, only_grow_clims=True) + COLOUR= partial(ImagePlot, is_colour_data=True) + EQUAL_ASPECT= partial(ImagePlot, aspect='equal') + IMAGE_HISTORY= MovingImagePlot + FIXED_LINE_HISTORY= partial(MovingPointPlot, buffer_len=100) + LINE_HISTORY_RESAMPLED= partial(ResamplingLineHistory, buffer_len=400) + PIC= partial(ImagePlot, show_clims=False, aspect='equal') + NOTICE= partial(TextPlot, max_history=1, horizontal_alignment='center', vertical_alignment='center', size='x-large') + COST= partial(MovingPointPlot, y_bounds=(0, None), y_bound_extend=(0, 0.05)) + PERCENT= partial(MovingPointPlot, y_bounds=(0, 100)) + TRAJECTORY= partial(Moving2DPointPlot, axes_update_mode='expand') + TRAJECTORY_PLUS= partial(Moving2DPointPlot, axes_update_mode='expand', x_bounds=(0, None), y_bounds=(0, None)) + HISTOGRAM= partial(HistogramPlot, edges = np.linspace(-5, 5, 20)) + CUMHIST= partial(CumulativeLineHistogram, edges = np.linspace(-5, 5, 20)) + + @classmethod + def from_string(cls, str): # For back-compatibility + return getattr(cls, str.upper().replace('-', '_').replace('+', '_PLUS')) def reset_dbplot(): @@ -253,17 +262,30 @@ def freeze_all_dbplots(fig = None): freeze_dbplot(name, fig=fig) -def replot_and_redraw_figure(fig): +def replot_and_redraw_figure(fig, hang): for subplot in _DBPLOT_FIGURES[fig].subplots.values(): plt.subplot(subplot.axis) subplot.plot_object.plot() - redraw_figure(_DBPLOT_FIGURES[fig].figure) + display_figure(_DBPLOT_FIGURES[fig].figure, hang) + + +def display_figure(fig, hang): + if hang is True: + plt.figure(fig.number) + plt.show() + elif hang in (None, False): + redraw_figure(fig) + elif isinstance(hang, (int, float)): + redraw_figure(fig) + plt.pause(hang) + else: + raise TypeError("Can't interpret hang argument {}".format(hang)) @contextmanager -def hold_dbplots(fig = None, draw_every = None): +def hold_dbplots(fig = None, hang=False, draw_every = None): """ Use this in a "with" statement to prevent plotting until the end. :param fig: @@ -291,7 +313,7 @@ def hold_dbplots(fig = None, draw_every = None): plot_now = True if plot_now and fig in _DBPLOT_FIGURES: - replot_and_redraw_figure(fig) + replot_and_redraw_figure(fig, hang = hang) def clear_dbplot(fig = None): @@ -309,11 +331,15 @@ def get_dbplot_axis(axis_name, fig=None): return _DBPLOT_FIGURES[fig].axes[axis_name] -def dbplot_hang(): - plt.show() +def dbplot_hang(timeout=None): + if timeout is None: + plt.show() + else: + redraw_figure() + plt.pause(timeout) -def dbplot_collection(collection, name, axis = None, draw_every=None, **kwargs): +def dbplot_collection(collection, name, hang=False, axis = None, draw_every=None, **kwargs): """ Plot a collection of items in one go. :param collection: @@ -321,7 +347,7 @@ def dbplot_collection(collection, name, axis = None, draw_every=None, **kwargs): :param kwargs: :return: """ - with hold_dbplots(draw_every=draw_every): + with hold_dbplots(draw_every=draw_every, hang=hang): if isinstance(collection, (list, tuple)): for i, el in enumerate(collection): dbplot(el, '{}[{}]'.format(name, i), axis='{}[{}]'.format(axis, i) if axis is not None else None, **kwargs) diff --git a/artemis/plotting/expanding_subplots.py b/artemis/plotting/expanding_subplots.py index 24a5dd4d..48c92a26 100644 --- a/artemis/plotting/expanding_subplots.py +++ b/artemis/plotting/expanding_subplots.py @@ -264,7 +264,7 @@ def vstack_plots(spacing=0, sharex=True, sharey = False, show_x = 'once', show_y new_subplots[-1].tick_params(axis='x', labelbottom='on') if xlabel is not None: - new_subplots[-1].set_xlabel(xlabel) + new_subplots[-1].set_xlabcel(xlabel) if remove_ticks: new_subplots[-1].get_xaxis().set_visible(True) diff --git a/artemis/plotting/matplotlib_backend.py b/artemis/plotting/matplotlib_backend.py index d9ddff34..da156fb6 100644 --- a/artemis/plotting/matplotlib_backend.py +++ b/artemis/plotting/matplotlib_backend.py @@ -7,8 +7,9 @@ from artemis.config import get_artemis_config_value from artemis.general.should_be_builtins import bad_value -from artemis.plotting.data_conversion import (put_data_in_grid, RecordBuffer, data_to_image, put_list_of_images_in_array, - UnlimitedRecordBuffer) +from artemis.plotting.data_conversion import (put_data_in_grid, RecordBuffer, data_to_image, + put_list_of_images_in_array, + UnlimitedRecordBuffer, ResamplingRecordBuffer) from matplotlib import pyplot as plt import numpy as np from six.moves import xrange @@ -368,6 +369,21 @@ def plot(self): LinePlot.plot(self) +class ResamplingLineHistory(LinePlot): + + def __init__(self, buffer_len, cull_factor=2, **kwargs): + LinePlot.__init__(self, **kwargs) + self._buffer = ResamplingRecordBuffer(buffer_len=buffer_len, cull_factor=cull_factor) + + def update(self, data): + self._buffer.insert_data(data) + + def plot(self): + x_data, y_data = self._buffer.retrieve_data() + LinePlot.update(self, (x_data, y_data)) + LinePlot.plot(self) + + class Moving2DPointPlot(LinePlot): def __init__(self, buffer_len=None, **kwargs): diff --git a/artemis/plotting/pyplot_plus.py b/artemis/plotting/pyplot_plus.py index ea7f1cad..46ccef14 100644 --- a/artemis/plotting/pyplot_plus.py +++ b/artemis/plotting/pyplot_plus.py @@ -153,7 +153,7 @@ def set_default_figure_size(width, height): def get_lines_color_cycle(): - return _lines_colour_cycle + return (p['color'] for p in plt.rcParams['axes.prop_cycle']) def get_color_cycle_map(name, length): @@ -169,7 +169,7 @@ def set_lines_color_cycle_map(name, length): def get_line_color(ix, modifier=None): - colour = _lines_colour_cycle[ix] + colour = next(c for i, c in enumerate(get_lines_color_cycle()) if i==ix) if modifier=='dark': return tuple(c/2 for c in colors.hex2color(colour)) elif modifier=='light': diff --git a/artemis/plotting/test_db_plotting.py b/artemis/plotting/test_db_plotting.py index 15388fc8..8bdd566c 100644 --- a/artemis/plotting/test_db_plotting.py +++ b/artemis/plotting/test_db_plotting.py @@ -5,7 +5,8 @@ from artemis.plotting.demo_dbplot import demo_dbplot from artemis.plotting.db_plotting import dbplot, clear_dbplot, hold_dbplots, freeze_all_dbplots, reset_dbplot, \ dbplot_hang -from artemis.plotting.matplotlib_backend import LinePlot, HistogramPlot, MovingPointPlot, is_server_plotting_on +from artemis.plotting.matplotlib_backend import LinePlot, HistogramPlot, MovingPointPlot, is_server_plotting_on, \ + ResamplingLineHistory import pytest from matplotlib import pyplot as plt from matplotlib import gridspec @@ -69,9 +70,14 @@ def test_history_plot_updating(): def test_moving_point_multiple_points(): reset_dbplot() - for i in xrange(5): - dbplot(np.sin([i/10., i/15.]), 'unlim buffer', plot_type = partial(MovingPointPlot)) - dbplot(np.sin([i/10., i/15.]), 'lim buffer', plot_type = partial(MovingPointPlot,buffer_len=20)) + p1 = 5. + p2 = 8. + for i in xrange(50): + with hold_dbplots(draw_every=5): + dbplot(np.sin([i/p1, i/p2]), 'unlim buffer', plot_type = partial(MovingPointPlot)) + dbplot(np.sin([i/p1, i/p2]), 'lim buffer', plot_type = partial(MovingPointPlot,buffer_len=20)) + dbplot(np.sin([i/p1, i/p2]), 'resampling buffer', plot_type = partial(ResamplingLineHistory, buffer_len=20)) # Only looks bad because of really small buffer length from testing. + def test_same_object(): """ From ef24bb32fe5d5cec6db93e6cd0670cd4c720eb7e Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Sat, 22 Sep 2018 16:24:11 +0200 Subject: [PATCH 06/41] ok these are kind of working --- artemis/experiments/experiment_record_view.py | 24 +++++++++++++++++++ artemis/experiments/ui.py | 7 +++--- artemis/plotting/db_plotting.py | 3 +++ artemis/plotting/expanding_subplots.py | 4 ++-- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index ce0eedc7..b1ed0e98 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -425,3 +425,27 @@ def separate_common_args(records, as_dicts=False, return_dict = False, only_shar if return_dict: argdiff = {rec.get_id(): args for rec, args in zip(records, argdiff)} return common, argdiff + + +def compare_timeseries_records(records, yfield, xfield = None): + """ + :param Sequence[ExperimentRecord] records: A list of records containing results of the form + Sequence[Dict[str, number]] + :param yfield: The name of the fields for the x-axis + :param xfield: The name of the field for the y-axis + """ + from matplotlib import pyplot as plt + results = [rec.get_result() for rec in records] + all_different_args, values = get_different_args([r.get_args() for r in records]) + + ax = plt.figure().add_subplot(1, 1, 1) + for result, argvals in izip_equal(results, values): + xvals = [r[xfield] for r in result] if xfield is not None else list(range(len(result))) + yvals = [r[yfield] for r in result] + ax.plot(xvals, yvals, label=', '.join(f'{argname}={argval}' for argname, argval in izip_equal(all_different_args, argvals))) + ax.grid(True) + if xfield is not None: + ax.set_xlabel(xfield) + ax.set_ylabel(yfield) + plt.legend() + plt.show() diff --git a/artemis/experiments/ui.py b/artemis/experiments/ui.py index d07302dc..8c0c5a0a 100644 --- a/artemis/experiments/ui.py +++ b/artemis/experiments/ui.py @@ -110,6 +110,7 @@ class ExperimentBrowser(object): > showarchived Toggle display of archived results. > view results View just the columns for experiment name and result > view full View all columns (the default view) +> kill 4.1,4.5 Kill the selected currently running records (you'll be prompted for confirmation) > show 4 Show the output from the last run of experiment 4 (if it has been run already). > show 4-6 Show the output of experiments 4,5,6 together. > records Browse through all experiment records. @@ -517,12 +518,12 @@ def argsort(self, *args): # First verify that all args are included... all_arg_names = set(a for exp_name in self.exp_record_dict.keys() for a, v in load_experiment(exp_name).get_args().items()) if any(a not in all_arg_names for a in args): - raise RecordSelectionError('Arg(s) [{}] were not included in any experiments') + raise RecordSelectionError('Arg(s) {} were not included in any experiments. Possible names: {}'.format(list(a for a in args if a not in all_arg_names), all_arg_names)) # Define a comparison function that will always compare. def key_sorting_function(exp_name): exp_args = load_experiment(exp_name).get_args() - return tuple(() if name not in exp_args else (None, exp_args[name]) if isinstance(exp_args[name], (int, float)) else (str(type(exp_args[name])), exp_args[name]) for name in args) + return tuple(() if name not in exp_args else ('!', float(exp_args[name])) if isinstance(exp_args[name], (int, float)) else (str(type(exp_args[name])), exp_args[name]) for name in args) self._sortkey = key_sorting_function return ExperimentBrowser.REFRESH @@ -691,7 +692,7 @@ def pull(self, *args): def kill(self, *args): parser = argparse.ArgumentParser() - parser.add_argument('user_range', action='store', help='A selection of experiments whose records to pull. Examples: "3" or "3-5", or "3,4,5"') + parser.add_argument('user_range', action='store', help='A selection of experiments whose records to kill. Examples: "3.2" or "3-5", or "3,4,5"') parser.add_argument('-s', '--skip', action='store_true', default=True, help='Skip the check that all selected records are currently running (just filter running ones)') args = parser.parse_args(args) diff --git a/artemis/plotting/db_plotting.py b/artemis/plotting/db_plotting.py index 9bce645c..52a5fa9e 100644 --- a/artemis/plotting/db_plotting.py +++ b/artemis/plotting/db_plotting.py @@ -81,6 +81,9 @@ def dbplot(data, name = None, plot_type = None, axis=None, plot_mode = 'live', d dbplot_remotely(arg_locals=arg_locals) return + if data.__class__.__module__ == 'torch' and data.__class__.__name__ == 'Tensor': + data = data.detach().cpu().numpy() + if isinstance(fig, plt.Figure): assert None not in _DBPLOT_FIGURES, "If you pass a figure, you can only do it on the first call to dbplot (for now)" _DBPLOT_FIGURES[None] = _PlotWindow(figure=fig, subplots=OrderedDict(), axes={}) diff --git a/artemis/plotting/expanding_subplots.py b/artemis/plotting/expanding_subplots.py index 48c92a26..d7c7afed 100644 --- a/artemis/plotting/expanding_subplots.py +++ b/artemis/plotting/expanding_subplots.py @@ -164,14 +164,14 @@ def add_subplot(layout = None, fig = None, **subplot_args): return select_subplot(name=None, fig=fig, layout=layout, **subplot_args) -def subplot_at(row, col, fig=None): +def subplot_at(row, col, fig=None, **subplot_args): """ Create or select a the subplot at position (row, col) :param row: The row :param col: The column :return: An axes object """ - return select_subplot(position=(row, col), fig=None) + return select_subplot(position=(row, col), fig=None, **subplot_args) @contextmanager From e11bceb2a25ac3492475921433d3e2960017e1d0 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 19 Oct 2018 11:59:08 +0200 Subject: [PATCH 07/41] stuuuuuf --- artemis/experiments/decorators.py | 11 ++- artemis/experiments/experiment_management.py | 11 ++- artemis/experiments/experiment_record.py | 6 +- artemis/experiments/experiment_record_view.py | 43 +++++++++--- artemis/experiments/experiments.py | 70 +++++++++++++++++++ artemis/experiments/hyperparameter_search.py | 38 ++++++++++ artemis/experiments/test_experiments.py | 23 ++++++ artemis/experiments/ui.py | 28 +++++++- artemis/general/checkpoint_counter.py | 14 ++-- artemis/general/iteratorize.py | 48 +++++++++++++ artemis/general/measuring_periods.py | 21 ++++++ artemis/general/progress_indicator.py | 31 ++++++-- artemis/general/should_be_builtins.py | 6 +- artemis/general/speedometer.py | 1 + artemis/general/test_progress_inidicator.py | 21 ++++++ artemis/ml/parameter_schedule.py | 2 +- artemis/plotting/data_conversion.py | 4 +- artemis/plotting/db_plotting.py | 4 +- artemis/plotting/expanding_subplots.py | 2 +- docs/source/plotting.rst | 4 +- 20 files changed, 354 insertions(+), 34 deletions(-) create mode 100644 artemis/experiments/hyperparameter_search.py create mode 100644 artemis/general/iteratorize.py create mode 100644 artemis/general/measuring_periods.py create mode 100644 artemis/general/test_progress_inidicator.py diff --git a/artemis/experiments/decorators.py b/artemis/experiments/decorators.py index 7c2bbaa5..413432e9 100644 --- a/artemis/experiments/decorators.py +++ b/artemis/experiments/decorators.py @@ -44,7 +44,7 @@ class ExperimentFunction(object): This is the most general decorator. You can use this to add details on the experiment. """ - def __init__(self, show = None, compare = compare_experiment_records, display_function=None, comparison_function=None, one_liner_function=None, result_parser = None, is_root=False): + def __init__(self, show = None, compare = compare_experiment_records, display_function=None, comparison_function=None, one_liner_function=None, result_parser = None, is_root=False, name=None): """ :param show: A function that is called when you "show" an experiment record in the UI. It takes an experiment record as an argument. @@ -55,6 +55,7 @@ def __init__(self, show = None, compare = compare_experiment_records, display_fu You can use call this via the UI with the compare_experiment_results command. :param one_liner_function: A function that takes your results and returns a 1 line string summarizing them. :param is_root: True to make this a root experiment - so that it is not listed to be run itself. + :param name: Custom name (if None, experiment will be named after decorated function) """ self.show = show self.compare = compare @@ -75,11 +76,17 @@ def compare(records): self.is_root = is_root self.one_liner_function = one_liner_function self.result_parser = result_parser + self.name = name def __call__(self, f): + """ + :param Callable f: The function you decorated + :return Experiment: An Experiment object (It still behaves as the original function when you call it, but now + has additional methods attached to it associated with the experiment). + """ f.is_base_experiment = True ex = Experiment( - name=f.__name__, + name=f.__name__ if self.name is None else self.name, function=f, show=self.show, compare = self.compare, diff --git a/artemis/experiments/experiment_management.py b/artemis/experiments/experiment_management.py index 46786b32..12ac236e 100644 --- a/artemis/experiments/experiment_management.py +++ b/artemis/experiments/experiment_management.py @@ -126,7 +126,11 @@ def select_experiments(user_range, exp_record_dict, return_dict=False): def _filter_experiments(user_range, exp_record_dict, return_is_in = False): - if user_range.startswith('~'): + if '|' in user_range: + is_in = [any(xs) for xs in zip(*(_filter_experiments(subrange, exp_record_dict, return_is_in=True) for subrange in user_range.split('|')))] + elif '&' in user_range: + is_in = [all(xs) for xs in zip(*(_filter_experiments(subrange, exp_record_dict, return_is_in=True) for subrange in user_range.split('&')))] + elif user_range.startswith('~'): is_in = _filter_experiments(user_range=user_range[1:], exp_record_dict=exp_record_dict, return_is_in=True) is_in = [not r for r in is_in] else: @@ -141,6 +145,9 @@ def _filter_experiments(user_range, exp_record_dict, return_is_in = False): elif user_range.startswith('has:'): phrase = user_range[len('has:'):] is_in = [phrase in exp_id for exp_id in exp_record_dict] + elif user_range.startswith('tag:'): + tag = user_range[len('tag:'):] + is_in = [tag in load_experiment(exp_id).get_tags() for exp_id in exp_record_dict] elif user_range.startswith('1diff:'): base_range = user_range[len('1diff:'):] base_range_exps = select_experiments(base_range, exp_record_dict) # list @@ -308,7 +315,7 @@ def _filter_records(user_range, exp_record_dict): try: sign = user_range[3] assert sign in ('<', '>') - filter_func = (lambda a, b: ab) + filter_func = (lambda a, b: (a is not None and b is not None) and ab) time_delta = parse_time(user_range[4:]) except: if user_range.startswith('dur'): diff --git a/artemis/experiments/experiment_record.py b/artemis/experiments/experiment_record.py index 5e67d99f..6fd30e83 100644 --- a/artemis/experiments/experiment_record.py +++ b/artemis/experiments/experiment_record.py @@ -310,7 +310,10 @@ def get_runtime(self): """ :return datetime.timedelta: A timedelta object """ - return timedelta(seconds=self.info.get_field(ExpInfoFields.RUNTIME)) + try: + return timedelta(seconds=self.info.get_field(ExpInfoFields.RUNTIME)) + except KeyError: # Which will happen if the experiment is still running or was killed without due process + return None def get_dir(self): """ @@ -506,6 +509,7 @@ def get_current_record_id(): def get_current_record_dir(default_if_none = True): """ The directory in which the results of the current experiment are recorded. + :param default_if_none: True to put records in the "default" dir if no experiment is running. """ if _CURRENT_EXPERIMENT_RECORD is None and default_if_none: return get_artemis_data_path('experiments/default/', make_local_dir=True) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index b1ed0e98..9a408b44 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -1,6 +1,7 @@ import re from collections import OrderedDict +import itertools from six import string_types from tabulate import tabulate import numpy as np @@ -238,12 +239,25 @@ def get_different_args(args, no_arg_filler = 'N/A', arrange_by_deltas=False): return all_different_args, values -def get_exportiment_record_arg_result_table(records): +def get_exportiment_record_arg_result_table(records, result_parser = None, fill_value='N/A', arg_rename_dict = None): + """ + Given a list of ExperimentRecords, make a table containing the arguments that differ between them, and their results. + :param Sequence[ExperimentRecord] records: + :param Optional[Callable] result_parser: Takes the result and returns either: + - a List[Tuple[str, Any]], containing the (name, value) pairs of results which will form the rightmost columns of the table + - Anything else, in which case the header of the last column is taken to be "Result" and the value is put in the table + :param fill_value: Value to fill in when the experiment does not have a particular argument. + :return Tuple[List[str], List[List[Any]]]: headers, results + """ + if arg_rename_dict is not None: + arg_processor = lambda args: OrderedDict((arg_rename_dict[name] if name in arg_rename_dict else name, val) for name, val in args.items() if name not in arg_rename_dict or arg_rename_dict[name] is not None) + else: + arg_processor = lambda args: args record_ids = [record.get_id() for record in records] - all_different_args, arg_values = get_different_args([r.get_args() for r in records], no_arg_filler='N/A') + all_different_args, arg_values = get_different_args([arg_processor(r.get_args()) for r in records], no_arg_filler=fill_value) - parsed_results = [record.get_experiment().result_parser(record.get_result()) for record in records] - result_fields, result_data = entries_to_table(parsed_results) + parsed_results = [(result_parser or record.get_experiment().result_parser)(record.get_result()) if record.has_result() else [('Result', 'N/A')] for record in records] + result_fields, result_data = entries_to_table(parsed_results, fill_value = fill_value) result_fields = [get_unique_name(rf, all_different_args) for rf in result_fields] # Just avoid name collisions # result_column_name = get_unique_name('Results', taken_names=all_different_args) @@ -298,6 +312,7 @@ def show_multiple_records(records, func = None): from artemis.plotting.manage_plotting import delay_show with delay_show(): for rec in records: + func(rec) else: for rec in records: @@ -427,7 +442,7 @@ def separate_common_args(records, as_dicts=False, return_dict = False, only_shar return common, argdiff -def compare_timeseries_records(records, yfield, xfield = None): +def compare_timeseries_records(records, yfield, xfield = None, hang=True, ax=None): """ :param Sequence[ExperimentRecord] records: A list of records containing results of the form Sequence[Dict[str, number]] @@ -438,14 +453,22 @@ def compare_timeseries_records(records, yfield, xfield = None): results = [rec.get_result() for rec in records] all_different_args, values = get_different_args([r.get_args() for r in records]) - ax = plt.figure().add_subplot(1, 1, 1) + if not isinstance(yfield, (list, tuple)): + yfield = [yfield] + + ax = ax if ax is not None else plt.figure().add_subplot(1, 1, 1) for result, argvals in izip_equal(results, values): xvals = [r[xfield] for r in result] if xfield is not None else list(range(len(result))) - yvals = [r[yfield] for r in result] - ax.plot(xvals, yvals, label=', '.join(f'{argname}={argval}' for argname, argval in izip_equal(all_different_args, argvals))) + # yvals = [r[yfield[0]] for r in result] + h, = ax.plot(xvals, [r[yfield[0]] for r in result], label=(yfield[0]+': ' if len(yfield)>1 else '')+', '.join(f'{argname}={argval}' for argname, argval in izip_equal(all_different_args, argvals))) + for yf, linestyle in zip(yfield[1:], itertools.cycle(['--', ':', '-.'])): + ax.plot(xvals, [r[yf] for r in result], linestyle=linestyle, color=h.get_color(), label=yf+': '+', '.join(f'{argname}={argval}' for argname, argval in izip_equal(all_different_args, argvals))) + ax.grid(True) if xfield is not None: ax.set_xlabel(xfield) - ax.set_ylabel(yfield) + if len(yfield)==1: + ax.set_ylabel(yfield[0]) plt.legend() - plt.show() + if hang: + plt.show() diff --git a/artemis/experiments/experiments.py b/artemis/experiments/experiments.py index f9776356..b6862744 100644 --- a/artemis/experiments/experiments.py +++ b/artemis/experiments/experiments.py @@ -4,14 +4,18 @@ from contextlib import contextmanager from functools import partial from six import string_types + from artemis.experiments.experiment_record import ExpStatusOptions, experiment_id_to_record_ids, load_experiment_record, \ get_all_record_ids, clear_experiment_records from artemis.experiments.experiment_record import run_and_record from artemis.experiments.experiment_record_view import compare_experiment_records, show_record +from artemis.experiments.hyperparameter_search import parameter_search from artemis.general.display import sensible_str from artemis.general.functional import get_partial_root, partial_reparametrization, \ advanced_getargspec, PartialReparametrization +from artemis.general.should_be_builtins import izip_equal +from artemis.general.test_mode import is_test_mode class Experiment(object): @@ -39,6 +43,7 @@ def __init__(self, function=None, show=None, compare=None, one_liner_function=No self.variants = OrderedDict() self._notes = [] self.is_root = is_root + self._tags= set() if not is_root: all_args, varargs_name, kargs_name, defaults = advanced_getargspec(function) @@ -416,6 +421,71 @@ def get_variant_records(self, only_completed=False, only_last=False, flat=False) else: return exp_record_dict + def add_parameter_search(self, name='parameter_search', space = None, n_calls=100, search_params = None, scalar_func=None): + """ + :param name: Name of the Experiment to be created + :param dict[str, skopt.space.Dimension] space: A dict mapping param name to Dimension. + e.g. space=dict(a = Real(1, 100, 'log-uniform'), b = Real(1, 100, 'log-uniform')) + :param Callable[[Any], float] scalar_func: Takes the return value of the experiment and turns it into a scalar + which we aim to minimize. + :param dict[str, Any] search_params: Args passed to parameter_search + :return Experiment: A new experiment which runs the search and yields current-best parameters with every iteration. + """ + assert space is not None, "You must specify a parameter search space. See this method's documentation" + if name is None: # TODO: Set name=None in the default after deadline + name = 'parameter_search[{}]'.format(','.join(space.keys())) + + if search_params is None: + search_params = {} + + def objective(**current_params): + output = self.call(**current_params) + if scalar_func is not None: + output = scalar_func(output) + return output + + from artemis.experiments import ExperimentFunction + + @ExperimentFunction(name = self.name + '.'+ name, show = show_parameter_search_record, one_liner_function=parameter_search_one_liner) + def search_exp(): + if is_test_mode(): + nonlocal n_calls + n_calls = 3 # When just verifying that experiment runs, do the minimum + + for iter_info in parameter_search(objective, n_calls=n_calls, space=space, **search_params): + info = dict(names=list(space.keys()), x_iters =iter_info.x_iters, func_vals=iter_info.func_vals, score = iter_info.func_vals, x=iter_info.x, fun=iter_info.fun) + latest_info = {name: val for name, val in izip_equal(info['names'], iter_info.x_iters[-1])} + print(f'Latest: {latest_info}, Score: {iter_info.func_vals[-1]:.3g}') + yield info + + self.variants[name] = search_exp + search_exp.tag('psearch') # Secret feature that makes it easy to select all parameter experiments in ui with "filter tag:psearch" + return search_exp + + def tag(self, tag): + """ + Add a "tag" - a string identifying the experiment as being in some sort of group. + You can use tags in the UI with 'filter tag:my_tag' to select experiments with a given tag + :param tag: + :return: + """ + self._tags.add(tag) + return self + + def get_tags(self): + return self._tags + + +def show_parameter_search_record(record): + from tabulate import tabulate + result = record.get_result() + table = tabulate([list(xs)+[fun] for xs, fun in zip(result['x_iters'], result['func_vals'])], headers=list(result['names'])+['score']) + print(table) + + +def parameter_search_one_liner(result): + return f'{len(result["x_iters"])} Runs : ' + ', '.join(f'{k}={v:.3g}' for k, v in izip_equal(result['names'], result['x'])) + f' : Score = {result["fun"]:.3g}' + _GLOBAL_EXPERIMENT_LIBRARY = OrderedDict() diff --git a/artemis/experiments/hyperparameter_search.py b/artemis/experiments/hyperparameter_search.py new file mode 100644 index 00000000..2594ca38 --- /dev/null +++ b/artemis/experiments/hyperparameter_search.py @@ -0,0 +1,38 @@ +from skopt import gp_minimize +from skopt.utils import use_named_args +from tabulate import tabulate + +from artemis.general.iteratorize import Iteratorize + + +def parameter_search(objective, space, n_calls, n_random_starts=3, acq_optimizer="auto", n_jobs=4): + """ + :param Callable[[Any], scalar] objective: The objective function that we're trying to optimize + :param dict[str, Dimension] space: + :param n_calls: + :param n_random_starts: + :param acq_optimizer: + :return Generator[{'names': List[str], 'x_iters': List[]: + """ # TODO: Finish building this + + for k, var in space.items(): + var.name=k + space = list(space.values()) + + objective = use_named_args(space)(objective) + + iter = Iteratorize( + func = lambda callback: gp_minimize(objective, + dimensions=space, + n_calls=n_calls, + n_random_starts = n_random_starts, + random_state=1234, + n_jobs=n_jobs, + verbose=False, + callback=callback, + acq_optimizer = acq_optimizer, + ), + ) + + for i, iter_info in enumerate(iter): + yield iter_info diff --git a/artemis/experiments/test_experiments.py b/artemis/experiments/test_experiments.py index 1855f397..6208e28a 100644 --- a/artemis/experiments/test_experiments.py +++ b/artemis/experiments/test_experiments.py @@ -114,8 +114,31 @@ def my_exp(a, b, c): assert XXXX() == 1+(5*5)*5 +def test_parameter_search(): + + from skopt.space import Real + + with experiment_testing_context(new_experiment_lib=True): + + @experiment_root + def bowl(x, y): + return {'z': (x-2)**2 + (y+3)**2} + + ex_search = bowl.add_parameter_search( + space = {'x': Real(-5, 5, 'uniform'), 'y': Real(-5, 5, 'uniform')}, + scalar_func=lambda result: result['z'], + search_params=dict(n_calls=5) + ) + + record = ex_search.run() + result = record.get_result() + assert result['names']==['x', 'y'] + assert result['func_vals'][-1] < result['func_vals'][0] + + if __name__ == '__main__': test_unpicklable_args() test_config_variant() test_config_bug_catching() test_args_are_checked() + test_parameter_search() \ No newline at end of file diff --git a/artemis/experiments/ui.py b/artemis/experiments/ui.py index 8c0c5a0a..55f9335a 100644 --- a/artemis/experiments/ui.py +++ b/artemis/experiments/ui.py @@ -18,6 +18,7 @@ select_experiment_records_from_list, interpret_numbers, run_multiple_experiments) from artemis.experiments.experiment_record import ExpStatusOptions +from artemis.experiments.experiment_record import ExperimentRecord from artemis.experiments.experiment_record import (get_all_record_ids, clear_experiment_records, load_experiment_record, ExpInfoFields) from artemis.experiments.experiment_record_view import (get_record_full_string, get_record_invalid_arg_string, @@ -27,7 +28,8 @@ from artemis.experiments.experiment_record_view import show_record, show_multiple_records from artemis.experiments.experiments import load_experiment, get_nonroot_global_experiment_library from artemis.fileman.local_dir import get_artemis_data_path -from artemis.general.display import IndentPrint, side_by_side, truncate_string, surround_with_header, format_duration, format_time_stamp +from artemis.general.display import IndentPrint, side_by_side, truncate_string, surround_with_header, format_duration, \ + format_time_stamp, section_with_header from artemis.general.hashing import compute_fixed_hash from artemis.general.mymath import levenshtein_distance from artemis.general.should_be_builtins import all_equal, insert_at, izip_equal, separate_common_items, bad_value @@ -278,7 +280,9 @@ def launch(self, command=None): 'q': self.quit, 'records': self.records, 'pull': self.pull, + 'info': self.info, 'clearcache': clear_ui_cache, + 'logs': self.logs, } display_again = True @@ -572,6 +576,28 @@ def show(self, *args): show_multiple_records(records, func) _warn_with_prompt(use_prompt=False) + def info(self, *args): + parser = argparse.ArgumentParser() + parser.add_argument('user_range', action='store', help='A selection of experiment records to show. ') + args = parser.parse_args(args) + user_range = args.user_range + records = select_experiment_records(user_range, self.exp_record_dict, flat=True) + for record in records: + print('='*64) + print(record.info.get_text()) + print('='*64) + + def logs(self, *args): + parser = argparse.ArgumentParser() + parser.add_argument('user_range', action='store', help='A selection of experiment records to show. ') + args = parser.parse_args(args) + user_range = args.user_range + records = select_experiment_records(user_range, self.exp_record_dict, flat=True) # type: list[ExperimentRecord] + for record in records: + print('='*64) + print(record.get_log()) + print('='*64) + def compare(self, *args): parser = argparse.ArgumentParser() parser.add_argument('user_range', action='store', help='A selection of experiment records to compare. Examples: "3" or "3-5", or "3,4,5"') diff --git a/artemis/general/checkpoint_counter.py b/artemis/general/checkpoint_counter.py index c1f94264..4edece2e 100644 --- a/artemis/general/checkpoint_counter.py +++ b/artemis/general/checkpoint_counter.py @@ -89,14 +89,18 @@ def __init__(self, checkpoint_generator, default_units = None, skip_first = Fals elif isinstance(checkpoint_generator, (int, float)): step = checkpoint_generator checkpoint_generator = (step*i for i in itertools.count(0)) + elif checkpoint_generator is None: + checkpoint_generator = (np.inf for _ in itertools.count(0)) else: assert isinstance(checkpoint_generator, types.GeneratorType) - if skip_first: - next(checkpoint_generator) - - self.checkpoint_generator = checkpoint_generator - self._next_checkpoint = float('inf') if checkpoint_generator is None else next(checkpoint_generator) + try: + if skip_first: + next(checkpoint_generator) + self.checkpoint_generator = checkpoint_generator + self._next_checkpoint = next(checkpoint_generator) + except StopIteration: + raise Exception('Your checkpoint generator provided no checkpoints.') self._counter = 0 self._start_time = time.time() diff --git a/artemis/general/iteratorize.py b/artemis/general/iteratorize.py new file mode 100644 index 00000000..f5581887 --- /dev/null +++ b/artemis/general/iteratorize.py @@ -0,0 +1,48 @@ + + +""" +Thanks to Brice for this piece of code. Taken from https://stackoverflow.com/a/9969000/851699 + +""" + +# from thread import start_new_thread +from collections import Iterable +from queue import Queue +from threading import Thread + + +class Iteratorize(Iterable): + """ + Transforms a function that takes a callback + into a lazy iterator (generator). + """ + + def __init__(self, func): + """ + :param Callable[Callable, Any] func: A function that takes a callback as an argument then runs. + """ + self.mfunc = func + # self.ifunc = ifunc + self.q = Queue(maxsize=1) + self.sentinel = object() + + def _callback(val): + self.q.put(val) + + def gentask(): + ret = self.mfunc(_callback) + self.q.put(self.sentinel) + + # start_new_thread(gentask, ()) + Thread(target=gentask).start() + + def __iter__(self): + return self + + def __next__(self): + + obj = self.q.get(True, None) + if obj is self.sentinel: + raise StopIteration + else: + return obj diff --git a/artemis/general/measuring_periods.py b/artemis/general/measuring_periods.py new file mode 100644 index 00000000..2ca59b1d --- /dev/null +++ b/artemis/general/measuring_periods.py @@ -0,0 +1,21 @@ + +import time + +_last_time_dict = {} + + +def measure_period(identifier): + """ + You can call this in a loop to get an easy measure of how much time has elapsed since the last call. + On the first call it will return NaN. + :param Any identifier: + :return float: Elapsed time since last measure + """ + if identifier not in _last_time_dict: + _last_time_dict[identifier] = time.time() + return float('nan') + else: + now = time.time() + elapsed = now - _last_time_dict[identifier] + _last_time_dict[identifier] = now + return elapsed diff --git a/artemis/general/progress_indicator.py b/artemis/general/progress_indicator.py index 42554413..7af0d3b4 100644 --- a/artemis/general/progress_indicator.py +++ b/artemis/general/progress_indicator.py @@ -1,5 +1,7 @@ import time +from decorator import contextmanager + class ProgressIndicator(object): @@ -34,13 +36,14 @@ def __init__(self, expected_iterations=None, name=None, update_every = (2, 'seco self._last_time = self._start_time self._last_progress = 0 self.show_total = show_total + self._pause_time = 0 def __call__(self, iteration = None): self.print_update(iteration) def print_update(self, progress=None, info=None): self._current_time = time.time() - elapsed = self._current_time - self._start_time + elapsed = self._current_time - self._start_time - self._pause_time if self._expected_iterations is None: if self._should_update(): print ('Progress{}: {:.1f}s Elapsed{}{}. {} calls averaging {:.2g} calls/s'.format( @@ -57,12 +60,10 @@ def print_update(self, progress=None, info=None): progress = self._i frac = float(progress)/(self._expected_iterations-1) if self._expected_iterations>1 else 1. if self._should_update() or progress == self._expected_iterations-1: - elapsed = self._current_time - self._start_time if self.just_use_last is True: remaining = (self._current_time - self._last_time)/(frac - self._last_progress) * (1-frac) if frac > 0 else float('NaN') else: remaining = elapsed * (1 / frac - 1) if frac > 0 else float('NaN') - elapsed = self._current_time - self._start_time print('Progress{name}: {progress}%. {elapsed:.1f}s Elapsed, {remaining:.1f}s Remaining{total}. {info_cb}{info}{n_calls} calls averaging {rate:.2g} calls/s'.format( name = '' if self.name is None else ' of '+self.name, progress = int(100*frac), @@ -82,7 +83,7 @@ def print_update(self, progress=None, info=None): self._last_progress = frac def get_elapsed(self): - return time.time() - self._start_time + return time.time() - self._start_time - self._pause_time def get_iterations(self): return self._i @@ -92,3 +93,25 @@ def _should_update_time(self): def _should_update_iter(self): return self._i - self._last_update > self._update_interval + + def pause_measurement(self): + """ + Context manager meaning "don't count this interval". + + Usage: + + n_iter = 100 + pi = ProgressInidicator(n_iter) + for i in range(n_iter): + do_something_worth_counting + with pi.pause_measurement(): + do_something_that_doesnt_count() + pi.print_update() + """ + @contextmanager + def pause_counting(): + start_pause_time = time.time() + yield + self._pause_time += time.time() - start_pause_time + + return pause_counting() diff --git a/artemis/general/should_be_builtins.py b/artemis/general/should_be_builtins.py index b686c790..b2e9894b 100644 --- a/artemis/general/should_be_builtins.py +++ b/artemis/general/should_be_builtins.py @@ -313,7 +313,6 @@ def remove_common_prefix(list_of_lists, max_elements=None, keep_base = True): count = 0 min_len = 1 if keep_base else 0 - while min(len(parts) for parts in list_of_lists)>min_len: if max_elements is not None and count >= max_elements: break @@ -479,3 +478,8 @@ def entries_to_table(tuplelist, fill_value = None): data = [dict(sample) for sample in tuplelist] new_data = [[d[k] if k in d else fill_value for k in all_entries] for d in data] return all_entries, new_data + + +def print_thru(x): + print(x) + return x \ No newline at end of file diff --git a/artemis/general/speedometer.py b/artemis/general/speedometer.py index 4a8ed09d..d3e61983 100644 --- a/artemis/general/speedometer.py +++ b/artemis/general/speedometer.py @@ -20,3 +20,4 @@ def __call__(self, progress=None): self._last_time = this_time return speed + diff --git a/artemis/general/test_progress_inidicator.py b/artemis/general/test_progress_inidicator.py new file mode 100644 index 00000000..e27d5688 --- /dev/null +++ b/artemis/general/test_progress_inidicator.py @@ -0,0 +1,21 @@ +from artemis.general.progress_indicator import ProgressIndicator +import time + +def test_progress_inidicator(): + + n_iter = 100 + + pi = ProgressIndicator(n_iter, update_every='1s') + + start=time.time() + for i in range(n_iter): + time.sleep(0.001) + if i % 10==0: + with pi.pause_measurement(): + time.sleep(0.02) + + assert pi.get_elapsed() < (time.time() - start)/2. + + +if __name__ == '__main__': + test_progress_inidicator() diff --git a/artemis/ml/parameter_schedule.py b/artemis/ml/parameter_schedule.py index f10fd1ae..0fec2b65 100644 --- a/artemis/ml/parameter_schedule.py +++ b/artemis/ml/parameter_schedule.py @@ -7,7 +7,7 @@ def __init__(self, schedule, print_variable_name = None): """ Given a schedule for a changing parameter (e.g. learning rate) get the values for this parameter at a given time. e.g.: - learning_rate_scheduler = ScheduledParameter({0: 0.1, 10: 0.01, 100: 0.001}, print_variable_name='eta') + learning_rate_scheduler = ParameterSchedule({0: 0.1, 10: 0.01, 100: 0.001}, print_variable_name='eta') new_learning_rate = learning_rate_scheduler.get_new_value(epoch=14) assert new_learning_rate == 0.01 diff --git a/artemis/plotting/data_conversion.py b/artemis/plotting/data_conversion.py index 6f2922ba..bf5e8aa4 100644 --- a/artemis/plotting/data_conversion.py +++ b/artemis/plotting/data_conversion.py @@ -24,12 +24,12 @@ def vector_length_to_tile_dims(vector_length, ): return grid_shape -def put_vector_in_grid(vec, shape = None): +def put_vector_in_grid(vec, shape = None, empty_val = 0): if shape is None: n_rows, n_cols = vector_length_to_tile_dims(len(vec)) else: n_rows, n_cols = shape - grid = np.zeros(n_rows*n_cols, dtype = vec.dtype) + grid = np.zeros(n_rows*n_cols, dtype = vec.dtype) + empty_val grid[:len(vec)]=vec grid=grid.reshape(n_rows, n_cols) return grid diff --git a/artemis/plotting/db_plotting.py b/artemis/plotting/db_plotting.py index 52a5fa9e..655ccf9d 100644 --- a/artemis/plotting/db_plotting.py +++ b/artemis/plotting/db_plotting.py @@ -265,7 +265,7 @@ def freeze_all_dbplots(fig = None): freeze_dbplot(name, fig=fig) -def replot_and_redraw_figure(fig, hang): +def dbplot_redraw_all(fig = None, hang = False): for subplot in _DBPLOT_FIGURES[fig].subplots.values(): plt.subplot(subplot.axis) @@ -316,7 +316,7 @@ def hold_dbplots(fig = None, hang=False, draw_every = None): plot_now = True if plot_now and fig in _DBPLOT_FIGURES: - replot_and_redraw_figure(fig, hang = hang) + dbplot_redraw_all(fig, hang = hang) def clear_dbplot(fig = None): diff --git a/artemis/plotting/expanding_subplots.py b/artemis/plotting/expanding_subplots.py index d7c7afed..f310a0e5 100644 --- a/artemis/plotting/expanding_subplots.py +++ b/artemis/plotting/expanding_subplots.py @@ -264,7 +264,7 @@ def vstack_plots(spacing=0, sharex=True, sharey = False, show_x = 'once', show_y new_subplots[-1].tick_params(axis='x', labelbottom='on') if xlabel is not None: - new_subplots[-1].set_xlabcel(xlabel) + new_subplots[-1].set_xlabel(xlabel) if remove_ticks: new_subplots[-1].get_xaxis().set_visible(True) diff --git a/docs/source/plotting.rst b/docs/source/plotting.rst index 95fcf69d..e791ca7b 100644 --- a/docs/source/plotting.rst +++ b/docs/source/plotting.rst @@ -71,8 +71,8 @@ Plotting Demos ###################### * `A demo of showing how to make various kinds of live updating plots. `_ -* `A demo repo showing how to use Artemis from your code `_ -* `A guide on using Artemis for remote plotting `_ +* `A demo repo showing how to use Artemis from your code `_ +* `A guide on using Artemis for remote plotting `_ ###################### From e473ff1f877f1c4c88f3c9fd902f32f01f7f0996 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 19 Oct 2018 14:19:09 +0200 Subject: [PATCH 08/41] aahhh --- artemis/experiments/hyperparameter_search.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/artemis/experiments/hyperparameter_search.py b/artemis/experiments/hyperparameter_search.py index 2594ca38..82639d7b 100644 --- a/artemis/experiments/hyperparameter_search.py +++ b/artemis/experiments/hyperparameter_search.py @@ -1,7 +1,3 @@ -from skopt import gp_minimize -from skopt.utils import use_named_args -from tabulate import tabulate - from artemis.general.iteratorize import Iteratorize @@ -14,6 +10,8 @@ def parameter_search(objective, space, n_calls, n_random_starts=3, acq_optimizer :param acq_optimizer: :return Generator[{'names': List[str], 'x_iters': List[]: """ # TODO: Finish building this + from skopt import gp_minimize # Soft requirements are imported in here. + from skopt.utils import use_named_args for k, var in space.items(): var.name=k From 623225619452207ac3c83ba27f60a35c94fb3d56 Mon Sep 17 00:00:00 2001 From: Peter Date: Mon, 12 Nov 2018 15:50:48 +0100 Subject: [PATCH 09/41] ARTEMIS CHANGES FROM DEAD MACHINE --- artemis/experiments/experiment_record.py | 45 ++++- artemis/experiments/experiment_record_view.py | 89 +++++++++- artemis/experiments/test_experiment_record.py | 20 ++- .../test_experiment_record_view_and_ui.py | 36 +++- artemis/experiments/ui.py | 12 ++ artemis/general/deferred_defaults.py | 20 +++ artemis/general/display.py | 2 +- artemis/general/ezprofile.py | 33 +++- artemis/general/functional.py | 4 +- artemis/general/global_rates.py | 15 ++ artemis/general/global_vars.py | 29 ++++ artemis/general/should_be_builtins.py | 22 ++- artemis/general/test_deferred_defaults.py | 33 ++++ artemis/general/test_should_be_builtins.py | 8 +- artemis/ml/predictors/predictor_comparison.py | 2 +- artemis/ml/tools/iteration.py | 14 +- artemis/ml/tools/processors.py | 120 -------------- artemis/ml/tools/running_averages.py | 154 ++++++++++++++++++ artemis/ml/tools/test_running_averages.py | 64 ++++++++ artemis/plotting/db_plotting.py | 1 + artemis/plotting/matplotlib_backend.py | 9 +- artemis/plotting/test_db_plotting.py | 76 ++++----- 22 files changed, 609 insertions(+), 199 deletions(-) create mode 100644 artemis/general/deferred_defaults.py create mode 100644 artemis/general/global_rates.py create mode 100644 artemis/general/global_vars.py create mode 100644 artemis/general/test_deferred_defaults.py create mode 100644 artemis/ml/tools/running_averages.py create mode 100644 artemis/ml/tools/test_running_averages.py diff --git a/artemis/experiments/experiment_record.py b/artemis/experiments/experiment_record.py index 6fd30e83..1d26f980 100644 --- a/artemis/experiments/experiment_record.py +++ b/artemis/experiments/experiment_record.py @@ -13,6 +13,7 @@ from contextlib import contextmanager from getpass import getuser from pickle import PicklingError +import itertools from datetime import datetime, timedelta from uuid import getnode @@ -23,7 +24,7 @@ from artemis.general.display import CaptureStdOut from artemis.general.functional import get_partial_chain, get_defined_and_undefined_args from artemis.general.hashing import compute_fixed_hash -from artemis.general.should_be_builtins import nested +from artemis.general.should_be_builtins import nested, natural_keys from artemis.general.test_mode import is_test_mode from artemis.general.test_mode import set_test_mode from artemis._version import __version__ as ARTEMIS_VERSION @@ -240,6 +241,7 @@ def get_figure_locs(self, include_directory=True): :return: A list of string file paths. """ locs = [f for f in os.listdir(self._experiment_directory) if f.startswith('fig-')] + locs = sorted(locs, key=natural_keys) if include_directory: return [os.path.join(self._experiment_directory, f) for f in locs] else: @@ -661,7 +663,34 @@ def clear_experiment_records(ids): ExperimentRecord(exp_path).delete() -def save_figure_in_record(name, fig=None, default_ext='.pkl'): +# +# +# def save_figure_in_current_experiment_directory(name='fig-{}.pkl', figure = None): +# +# if figure is None: +# figure = plt.gcf() +# +# current_dir = get_current_record_dir() +# start_ix = _figure_ixs[current_dir] if current_dir in _figure_ixs else 0 +# for ix in count(start_ix): +# full_path = os.path.join(current_dir, name).format(ix) +# if not os.path.exists(_figure_ixs[current_dir]): +# save_figure(figure, path = full_path) +# _figure_ixs[current_dir] = ix+1 +# return full_path +_figure_ixs = {} + + +def _get_next_figure_name(name_pattern, directory): + start_ix = _figure_ixs[directory] if directory in _figure_ixs else 0 + for ix in itertools.count(start_ix): + full_path = os.path.join(directory, name_pattern).format(ix) + if not os.path.exists(full_path): + _figure_ixs[directory] = ix+1 + return full_path + + +def save_figure_in_record(name=None, fig=None, default_ext='.pkl'): ''' Saves the given figure in the experiment directory. If no figure is passed, plt.gcf() is saved instead. :param name: The name of the figure to be saved @@ -671,11 +700,17 @@ def save_figure_in_record(name, fig=None, default_ext='.pkl'): ''' import matplotlib.pyplot as plt from artemis.plotting.saving_plots import save_figure + if fig is None: fig = plt.gcf() - save_path = os.path.join(get_current_record_dir(), name) - save_figure(fig, path=save_path, default_ext=default_ext) - return save_path + + current_dir = get_current_record_dir() + if name is None: + path = _get_next_figure_name(name_pattern='fig-{}.pkl', directory=current_dir) + else: + path = os.path.join(current_dir, name) + save_figure(fig, path=path, default_ext=default_ext) + return path def get_serialized_args(argdict): diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 9a408b44..57341e0f 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -1,6 +1,6 @@ import re from collections import OrderedDict - +from functools import partial import itertools from six import string_types from tabulate import tabulate @@ -13,6 +13,7 @@ from artemis.general.should_be_builtins import separate_common_items, bad_value, izip_equal, \ remove_duplicates, get_unique_name, entries_to_table from artemis.general.tables import build_table +import os def get_record_result_string(record, func='deep', truncate_to = None, array_print_threshold=8, array_float_format='.3g', oneline=False, default_one_liner_func=str): @@ -446,8 +447,8 @@ def compare_timeseries_records(records, yfield, xfield = None, hang=True, ax=Non """ :param Sequence[ExperimentRecord] records: A list of records containing results of the form Sequence[Dict[str, number]] - :param yfield: The name of the fields for the x-axis - :param xfield: The name of the field for the y-axis + :param yfield: The name of the field for the x-axis + :param xfield: The name of the field(s) for the y-axis """ from matplotlib import pyplot as plt results = [rec.get_result() for rec in records] @@ -472,3 +473,85 @@ def compare_timeseries_records(records, yfield, xfield = None, hang=True, ax=Non plt.legend() if hang: plt.show() + + +def get_timeseries_record_comparison_function(yfield, xfield = None, hang=True, ax=None): + """ + :param yfield: The name of the field for the x-axis + :param xfield: The name of the field(s) for the y-axis + """ + return lambda records: compare_timeseries_records(records, yfield, xfield = xfield, hang=hang, ax=ax) + + + +def timeseries_oneliner_function(result, fields, show_len, show = 'last'): + assert show=='last', 'Only support showing last element now' + return (f'{len(result)} items. ' if show_len else '')+', '.join(f'{k}: {result[-1][k]:.3g}' if isinstance(result[-1][k], float) else f'{k}: {result[-1][k]}' for k in fields) + + +def get_timeseries_oneliner_function(fields, show_len=False, show='last'): + return partial(timeseries_oneliner_function, fields=fields, show_len=show_len, show=show) + + +def browse_record_figs(record): + """ + Browse through the figures associated with an experiment record + :param ExperimentRecord record: An experiment record + """ + # TODO: Generalize this to just browse through the figures in a directory. + + from artemis.plotting.saving_plots import interactive_matplotlib_context + import pickle + from matplotlib import pyplot as plt + from artemis.plotting.drawing_plots import redraw_figure + fig_locs = record.get_figure_locs() + + class nonlocals: + this_fig = None + figno = 0 + + def show_figure(ix): + path = fig_locs[ix] + dir, name = os.path.split(path) + if nonlocals.this_fig is not None: + plt.close(nonlocals.this_fig) + # with interactive_matplotlib_context(): + plt.close(plt.gcf()) + with open(path, "rb") as f: + fig = pickle.load(f) + fig.canvas.set_window_title(record.get_id()+': ' +name+': (Figure {}/{})'.format(ix+1, len(fig_locs))) + fig.canvas.mpl_connect('key_press_event', changefig) + print('Showing {}: Figure {}/{}. Full path: {}'.format(name, ix+1, len(fig_locs), path)) + # redraw_figure() + plt.show() + nonlocals.this_fig = plt.gcf() + + def changefig(keyevent): + if keyevent.key=='right': + nonlocals.figno = (nonlocals.figno+1)%len(fig_locs) + elif keyevent.key=='left': + nonlocals.figno = (nonlocals.figno-1)%len(fig_locs) + elif keyevent.key=='up': + nonlocals.figno = (nonlocals.figno-10)%len(fig_locs) + elif keyevent.key=='down': + nonlocals.figno = (nonlocals.figno+10)%len(fig_locs) + + elif keyevent.key==' ': + nonlocals.figno = queryfig() + else: + print("No handler for key: {}. Changing Nothing".format(keyevent.key)) + show_figure(nonlocals.figno) + + def queryfig(): + user_input = input('Which Figure (of 1-{})? >>'.format(len(fig_locs))) + try: + nonlocals.figno = int(user_input)-1 + except: + if user_input=='q': + raise Exception('Quit') + else: + print("No handler for input '{}'".format(user_input)) + return nonlocals.figno + + print('Use Left/Right arrows to navigate, ') + show_figure(nonlocals.figno) diff --git a/artemis/experiments/test_experiment_record.py b/artemis/experiments/test_experiment_record.py index ed0ddfbf..102573d8 100644 --- a/artemis/experiments/test_experiment_record.py +++ b/artemis/experiments/test_experiment_record.py @@ -17,7 +17,7 @@ load_experiment_record, ExperimentRecord, record_experiment, \ delete_experiment_with_id, get_current_record_dir, open_in_record_dir, \ ExpStatusOptions, get_current_experiment_id, get_current_experiment_record, \ - get_current_record_id, has_experiment_record, experiment_id_to_record_ids + get_current_record_id, has_experiment_record, experiment_id_to_record_ids, save_figure_in_record from artemis.experiments.experiments import get_experiment_info, load_experiment, experiment_testing_context, \ clear_all_experiments from artemis.experiments.test_experiments import test_unpicklable_args @@ -459,6 +459,23 @@ def my_generator_exp(n_steps, poison_4 = False): assert rec2.get_result() == 3 +def test_figure_saving_and_loading(): + + from artemis.plotting.db_plotting import dbplot + with experiment_testing_context(new_experiment_lib=True): + @experiment_function + def my_exp(): + for t in range(4): + dbplot(np.random.randn(20, 20, 3), 'plot') + save_figure_in_record() + + rec = my_exp.run() # type: ExperimentRecord + + fig_locs = rec.get_figure_locs() + + assert set(fig_locs) == {os.path.join(rec.get_dir(), 'fig-{}.pkl'.format(i)) for i in range(4)} + + if __name__ == '__main__': set_test_mode(True) @@ -482,3 +499,4 @@ def my_generator_exp(n_steps, poison_4 = False): test_current_experiment_access_functions() test_generator_experiment() test_unpicklable_args() + test_figure_saving_and_loading() \ No newline at end of file diff --git a/artemis/experiments/test_experiment_record_view_and_ui.py b/artemis/experiments/test_experiment_record_view_and_ui.py index 50cba1a8..9b2bba66 100644 --- a/artemis/experiments/test_experiment_record_view_and_ui.py +++ b/artemis/experiments/test_experiment_record_view_and_ui.py @@ -1,25 +1,27 @@ import pytest from artemis.experiments.decorators import ExperimentFunction, experiment_function +from artemis.experiments.experiment_record import save_figure_in_record from artemis.experiments.experiment_record_view import get_oneline_result_string, print_experiment_record_argtable, \ - compare_experiment_records, get_record_invalid_arg_string + compare_experiment_records, get_record_invalid_arg_string, browse_record_figs from artemis.experiments.experiments import experiment_testing_context, clear_all_experiments from artemis.general.display import CaptureStdOut, assert_things_are_printed +import numpy as np -def display_it(result): - print(str(result) + 'aaa') +def display_it(record): + print(str(record.get_result()) + 'aaa') def one_liner(result): return str(result) + 'bbb' -def compare_them(results): - print(', '.join('{}: {}'.format(k, results[k]) for k in sorted(results.keys()))) +def compare_them(records): + print(', '.join('{}: {}'.format(record.get_experiment().name, record.get_result()) for record in records)) -@ExperimentFunction(display_function=display_it, one_liner_function=one_liner, comparison_function=compare_them) +@ExperimentFunction(show=display_it, one_liner_function=one_liner, compare=compare_them) def my_xxxyyy_test_experiment(a=1, b=2): if b==17: @@ -82,8 +84,8 @@ def test_experiment_function_ui(): import time time.sleep(0.1) - with assert_things_are_printed(min_len=1200, things=['Common Args', 'Different Args', 'Result', 'a=1, b=2', 'a=2, b=2', 'a=1, b=17']): - my_xxxyyy_test_experiment.browse(raise_display_errors=True, command='argtable all', close_after=True) + # with assert_things_are_printed(min_len=1200, things=['Common Args', 'Different Args', 'Result', 'a=1, b=2', 'a=2, b=2', 'a=1, b=17']): + # my_xxxyyy_test_experiment.browse(raise_display_errors=True, command='argtable all', close_after=True) with assert_things_are_printed(min_len=600, things=['my_xxxyyy_test_experiment: 3', 'my_xxxyyy_test_experiment.a2: 4']): my_xxxyyy_test_experiment.browse(raise_display_errors=True, command='compare all -r', close_after=True) @@ -219,6 +221,23 @@ def my_simdfdscds(a=1): assert string.count('Start Time') == 1 +def demo_browse_record_figs(): + + from artemis.plotting.db_plotting import dbplot + from matplotlib import pyplot as plt + with experiment_testing_context(new_experiment_lib=True): + @experiment_function + def my_exp(): + for t in range(4): + pts = np.linspace(0, 3*(t+1), 400) + dbplot((pts*np.cos(pts), pts*np.sin(pts)), 'plot', title='t={}'.format(t), plot_type='line') + save_figure_in_record() + plt.close(plt.gcf()) + rec = my_exp.run() # type: ExperimentRecord + + browse_record_figs(rec) + + if __name__ == '__main__': test_experiments_function_additions() test_experiment_function_ui() @@ -227,3 +246,4 @@ def my_simdfdscds(a=1): test_simple_experiment_show() test_view_modes() test_duplicate_headers_when_no_records_bug_is_gone() + # demo_browse_record_figs() \ No newline at end of file diff --git a/artemis/experiments/ui.py b/artemis/experiments/ui.py index 55f9335a..ad6b6b71 100644 --- a/artemis/experiments/ui.py +++ b/artemis/experiments/ui.py @@ -267,6 +267,7 @@ def launch(self, command=None): 'view': self.view, 'archive': self.archive, 'h': self.help, + 'figures': self.figures, 'filter': self.filter, 'filterrec': self.filterrec, 'displayformat': self.displayformat, @@ -587,6 +588,17 @@ def info(self, *args): print(record.info.get_text()) print('='*64) + def figures(self, *args): + parser = argparse.ArgumentParser() + parser.add_argument('user_range', action='store', help='A selection of experiment records to show. ') + args = parser.parse_args(args) + user_range = args.user_range + records = select_experiment_records(user_range, self.exp_record_dict, flat=True) + if len(records)>1: + raise RecordSelectionError('Can only show figures for one record at a time. You selected {}'.format(len(records))) + from artemis.experiments.experiment_record_view import browse_record_figs + browse_record_figs(records[0]) + def logs(self, *args): parser = argparse.ArgumentParser() parser.add_argument('user_range', action='store', help='A selection of experiment records to show. ') diff --git a/artemis/general/deferred_defaults.py b/artemis/general/deferred_defaults.py new file mode 100644 index 00000000..dd26a77a --- /dev/null +++ b/artemis/general/deferred_defaults.py @@ -0,0 +1,20 @@ +import sys +import inspect +_CACHE = {} + + +def default(function, arg): + """ + Get the default value to the argument named 'arg' for function "function" + :param Callable function: The function from which to get the default value. + :param str arg: The name of the argument + :return Any: The default value + """ + if function not in _CACHE: + if sys.version_info < (3, 4): + all_arg_names, varargs_name, kwargs_name, defaults = inspect.getargspec(function) + else: + all_arg_names, varargs_name, kwargs_name, defaults, _, _, _ = inspect.getfullargspec(function) + _CACHE[function] = dict(zip(all_arg_names[-len(defaults):], defaults)) + assert arg in _CACHE[function], 'Function {} has no default argument "{}"'.format(function, arg) + return _CACHE[function][arg] diff --git a/artemis/general/display.py b/artemis/general/display.py index 03957f6d..28093cd4 100644 --- a/artemis/general/display.py +++ b/artemis/general/display.py @@ -405,7 +405,7 @@ def format_duration(seconds): else: return res else: - days = seconds//_seconds_in_day + days = int(seconds//_seconds_in_day) return '{:d}d,{}'.format(days, format_duration(seconds % _seconds_in_day)) diff --git a/artemis/general/ezprofile.py b/artemis/general/ezprofile.py index f394aff2..cdf853be 100644 --- a/artemis/general/ezprofile.py +++ b/artemis/general/ezprofile.py @@ -1,7 +1,7 @@ from logging import Logger from time import time from collections import OrderedDict - +from contextlib import contextmanager __author__ = 'peter' @@ -34,6 +34,10 @@ def lap(self, lap_name = None): def get_current_time(self): return time() - self._lap_times['Start'] + def get_total_time(self): + assert 'Stop' in self._lap_times, "The profiler has not exited yet, so you cannot get total time." + return self._lap_times['Stop'] - self._lap_times['Start'] + def __enter__(self): start_time = time() self.start_time = start_time @@ -60,3 +64,30 @@ def get_report(self): deltas = OrderedDict((key, self._lap_times[key] - self._lap_times[last_key]) for last_key, key in zip(keys[:-1], keys[1:])) return self.profiler_name + '\n '.join(['']+['%s: Elapsed time is %.4gs' % (key, val) for key, val in deltas.items()] + (['Total: %.4gs' % (self._lap_times.values()[-1] - self._lap_times.values()[0])] if len(deltas)>1 else [])) + + +_profile_contexts = OrderedDict() + + +@contextmanager +def profile_context(name, print_result = False): + + with EZProfiler(name, print_result=print_result) as prof: + yield prof + if name in _profile_contexts: + n_calls, elapsed = _profile_contexts[name] + else: + n_calls, elapsed = 0, 0. + n_calls, elapsed = n_calls+1, elapsed + prof.get_total_time() + _profile_contexts[name] = (n_calls, elapsed) + + +def get_profile_contexts(names=None, fill_empty_with_zero = False): + + if names is None: + return _profile_contexts + else: + if fill_empty_with_zero: + return OrderedDict((k, _profile_contexts[k] if k in _profile_contexts else 0) for k in names) + else: + return OrderedDict((k, _profile_contexts[k]) for k in names) diff --git a/artemis/general/functional.py b/artemis/general/functional.py index 67efa277..a1660c99 100644 --- a/artemis/general/functional.py +++ b/artemis/general/functional.py @@ -42,7 +42,7 @@ def __init__(self, func, arg_constructors): assert arg_name in all_arg_names, "Function {} has no argument named '{}'".format(func, arg_name) assert callable(arg_constructor), "The configuration for argument '{}' must be a function which constructs the argument. Got a {}".format(arg_name, type(arg_constructor).__name__) assert not inspect.isclass(arg_constructor), "'{}' is a class object. You must instead pass a function to construct an instance of this class. You can use lambda for this.".format(arg_constructor.__name__) - assert isinstance(arg_constructor, types.FunctionType), "The constructor '{}' appeared not to be a pure function. If it is an instance of a callable class, you probably meant to give a either a constructor for that instance.".format(arg_constructor) + assert isinstance(arg_constructor, types.FunctionType), "The constructor '{}' appeared not to be a pure function. If it is an instance of a callable class, you should instead give a staticmethod constructor for that instance.".format(arg_constructor) sub_arg_names, _, _, _ = advanced_getargspec(arg_constructor) for a in sub_arg_names: if a != arg_name: # If the name of your reparemetrizing argument is not the same as the argument you are replacing.... @@ -162,7 +162,7 @@ def advanced_getargspec(f): assert k in all_arg_names, "Constructed Argument '{}' appears not to exist in function {}".format(k, chain[0]) sub_all_arg_names, sub_varargs_name, sub_kwargs_name, sub_defaults = advanced_getargspec(constructor) assert sub_varargs_name is None, "Currently can't handle unnamed arguments for argument constructor {}={}".format(k, constructor) - assert sub_kwargs_name is None, "Currently can't handle unnamed keyword arguments for argument constructor {}={}".format(k, constructor) + assert sub_kwargs_name is None, "Currently can't handle unnamed keyword arguments. Constructor {}={}".format(k, constructor) all_arg_names.remove(k) # Since the argument has been reparameterized, it is removed from the list of constructor signature current_layer_arg_names.remove(k) assert not any(a in current_layer_arg_names for a in sub_all_arg_names), "The constructor for argument '{}' has name '{}', which us already used by the function '{}'. Rename it.".format(k, next(a for a in sub_all_arg_names if a in current_layer_arg_names), chain[0]) diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py new file mode 100644 index 00000000..67773a16 --- /dev/null +++ b/artemis/general/global_rates.py @@ -0,0 +1,15 @@ +from artemis.general.global_vars import get_global, set_global +import time + + +class _RateMeasureSingleton: + pass + + +def measure_global_rate(name): + this_time = time.time() + key = (_RateMeasureSingleton, name) + n_calls, start_time = get_global(key, constructor=lambda: (0, this_time)) + set_global(key, (n_calls+1, start_time)) + return n_calls / (this_time - start_time) if this_time!=start_time else float('inf') + diff --git a/artemis/general/global_vars.py b/artemis/general/global_vars.py new file mode 100644 index 00000000..fed2a38d --- /dev/null +++ b/artemis/general/global_vars.py @@ -0,0 +1,29 @@ +from decorator import contextmanager + +_GLOBALS = {} + + +@contextmanager +def global_context(context_dict = None): + global _GLOBALS + if context_dict is None: + context_dict = {} + old_globals = _GLOBALS + _GLOBALS = context_dict + yield context_dict + _GLOBALS = old_globals + + +def get_global(identifier, constructor=None): + + if identifier not in _GLOBALS: + if constructor is not None: + _GLOBALS[identifier] = constructor() + else: + raise KeyError('No global variable with key: {}'.format(identifier)) + return _GLOBALS[identifier] + + +def set_global(identifier, value): + + _GLOBALS[identifier] = value diff --git a/artemis/general/should_be_builtins.py b/artemis/general/should_be_builtins.py index b2e9894b..4244c34f 100644 --- a/artemis/general/should_be_builtins.py +++ b/artemis/general/should_be_builtins.py @@ -2,7 +2,7 @@ from collections import OrderedDict import itertools import os - +import re import math from six.moves import xrange, zip_longest @@ -482,4 +482,22 @@ def entries_to_table(tuplelist, fill_value = None): def print_thru(x): print(x) - return x \ No newline at end of file + return x + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + """ + A key function to use for sorting strings. This captures numbers in the strings, so for example it will sort + + sorted(['y8', 'x10', 'x2', 'y12', 'x9'], key=natural_keys) == ['x2', 'x9', 'x10', 'y8', 'y12'] + + Taken from: https://stackoverflow.com/a/5967539/851699 + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + """ + return tuple(atoi(c) for c in re.split('(\d+)', text)) diff --git a/artemis/general/test_deferred_defaults.py b/artemis/general/test_deferred_defaults.py new file mode 100644 index 00000000..6d36c461 --- /dev/null +++ b/artemis/general/test_deferred_defaults.py @@ -0,0 +1,33 @@ +from artemis.general.deferred_defaults import default +from pytest import raises + + +def test_deferred_defaults(): + + def subfunction_1(a=2, b=3): + return a+b + + def subfunction_2(c=4): + return c**2 + + def main_function(a=default(subfunction_1, 'a'), b=default(subfunction_1, 'b'), c=default(subfunction_2, 'c')): + return subfunction_1(a=a, b=b) * subfunction_2(c=c) + + assert main_function()==(2+3)*4**2 + assert main_function(b=5)==(2+5)*4**2 + assert main_function(b=5, c=1)==(2+5)*1**2 + + +def check_that_errors_caught(): + + def g(a=4): + return a*2 + + with raises(AssertionError): + def f(a = default(g, 'b')): + return a + + +if __name__ == '__main__': + test_deferred_defaults() + check_that_errors_caught() diff --git a/artemis/general/test_should_be_builtins.py b/artemis/general/test_should_be_builtins.py index b6f8aa40..d637ecb7 100644 --- a/artemis/general/test_should_be_builtins.py +++ b/artemis/general/test_should_be_builtins.py @@ -4,7 +4,7 @@ from artemis.general.should_be_builtins import itermap, reducemap, separate_common_items, remove_duplicates, \ detect_duplicates, remove_common_prefix, all_equal, get_absolute_module, insert_at, get_shifted_key_value, \ - divide_into_subsets, entries_to_table + divide_into_subsets, entries_to_table, natural_keys __author__ = 'peter' @@ -122,6 +122,11 @@ def test_entries_to_table(): assert entries_to_table([[('a', 1), ('b', 2)], [('a', 3), ('b', 4), ('c', 5)]]) == (['a', 'b', 'c'], [[1, 2, None], [3, 4, 5]]) +def test_natural_keys(): + + assert sorted(['y8', 'x10', 'x2', 'y12', 'x9'], key=natural_keys) == ['x2', 'x9', 'x10', 'y8', 'y12'] + + if __name__ == '__main__': test_separate_common_items() test_reducemap() @@ -135,3 +140,4 @@ def test_entries_to_table(): test_get_shifted_key_value() test_divide_into_subsets() test_entries_to_table() + test_natural_keys() \ No newline at end of file diff --git a/artemis/ml/predictors/predictor_comparison.py b/artemis/ml/predictors/predictor_comparison.py index 4f8b5ed8..970d5b5b 100644 --- a/artemis/ml/predictors/predictor_comparison.py +++ b/artemis/ml/predictors/predictor_comparison.py @@ -8,7 +8,7 @@ from artemis.ml.tools.costs import get_evaluation_function from artemis.ml.tools.iteration import checkpoint_minibatch_index_generator from artemis.general.mymath import sqrtspace -from artemis.ml.tools.processors import RunningAverage +from artemis.ml.tools.running_averages import RunningAverage def compare_predictors(dataset, online_predictors={}, offline_predictors={}, minibatch_size = 'full', diff --git a/artemis/ml/tools/iteration.py b/artemis/ml/tools/iteration.py index b9ae2b0f..6ba00c2b 100644 --- a/artemis/ml/tools/iteration.py +++ b/artemis/ml/tools/iteration.py @@ -281,7 +281,7 @@ def generator_pool(generator_generator): yield generator -def batchify_generator(generator_generator, batch_size, receive_input=False, out_format ='array'): +def batchify_generator(generator_generator, batch_size = None, receive_input=False, out_format ='array'): """ Best understood by example: @@ -300,7 +300,7 @@ def batchify_generator(generator_generator, batch_size, receive_input=False, out new movies to start. :param generator_generator: An generator which generates generators - :param batch_size: The size if the batch you want to yield + :param batch_size: The size if the batch you want to yield. :param receive_input: Expect a "send" to this generatoer AFTER it yields. (see Python coroutines) :param out_format: 'array' or 'tuple_of_arrays' currently supported. :yield: An array consisting of batch_size of the outputs of the subgenerator, batched together. @@ -310,7 +310,15 @@ def batchify_generator(generator_generator, batch_size, receive_input=False, out total = batch_size assert out_format in ('array', 'tuple_of_arrays') - generators = [next(generator_generator) for _ in range(batch_size)] + + if batch_size is not None: + generators = [next(generator_generator) for _ in range(batch_size)] + else: + assert isinstance(generator_generator, (list, tuple)), "If you don't specify a batch size your generator-generator must be a finite list." + batch_size = len(generator_generator) + generators = generator_generator + generator_generator = iter(generator_generator) + while True: items = [] for i in range(batch_size): diff --git a/artemis/ml/tools/processors.py b/artemis/ml/tools/processors.py index b7fea780..5c033c7e 100644 --- a/artemis/ml/tools/processors.py +++ b/artemis/ml/tools/processors.py @@ -1,6 +1,5 @@ from abc import abstractmethod import numpy as np -from artemis.general.mymath import recent_moving_average from six.moves import xrange __author__ = 'peter' @@ -31,57 +30,6 @@ def inverse(self, data): return np.argmax(data, axis = 1) -class RunningAverage(object): - - def __init__(self): - self._n_samples_seen = 0 - self._average = 0 - - def __call__(self, data): - self._n_samples_seen+=1 - frac = 1./self._n_samples_seen - self._average = (1-frac)*self._average + frac*data - return self._average - - @classmethod - def batch(cls, x): - return np.cumsum(x, axis=0)/np.arange(1, len(x)+1).astype(np.float)[(slice(None), )+(None, )*(x.ndim-1)] - - -class RecentRunningAverage(object): - - def __init__(self): - self._n_samples_seen = 0 - self._average = 0 - - def __call__(self, data): - self._n_samples_seen+=1 - frac = 1/self._n_samples_seen**.5 - self._average = (1-frac)*self._average + frac*data - return self._average - - @classmethod - def batch(cls, x): - return recent_moving_average(x, axis=0) # Works only for python 2.X, with weave - # ra = cls() - # return np.array([ra(x_) for x_ in x]) - - -class RunningAverageWithBurnin(object): - - def __init__(self, burn_in_steps): - self._burn_in_step_remaining = burn_in_steps - self.averager = RunningAverage() - - def __call__(self, x): - - if self._burn_in_step_remaining > 0: - self._burn_in_step_remaining-=1 - return x - else: - return self.averager(x) - - class IDifferentiableFunction(object): @abstractmethod @@ -108,74 +56,6 @@ def backprop_delta(self, delta_y): return delta_y -class RunningCenter(IDifferentiableFunction): - """ - Keep an exponentially decaying running mean, subtract this from the value. - """ - def __init__(self, half_life): - self.decay_constant = np.exp(-np.log(2)/half_life) - self.one_minus_decay_constant = 1-self.decay_constant - self.running_mean = None - - def __call__(self, x): - if self.running_mean is None: - self.running_mean = np.zeros_like(x) - self.running_mean[:] = self.decay_constant * self.running_mean + self.one_minus_decay_constant * x - return x - self.running_mean - - def backprop_delta(self, delta_y): - return self.decay_constant * delta_y - - -class ExponentialRunningVariance(object): - - def __init__(self, decay): - self.decay = decay - self.running_mean = 0 - self.running_mean_sq = 1 - - def __call__(self, x, decay = None): - - decay = self.decay if decay is None else decay - self.running_mean = (1-decay) * self.running_mean + decay * x - self.running_mean_sq = (1-decay) * self.running_mean_sq + decay * x**2 - var = self.running_mean_sq - self.running_mean**2 - return np.maximum(0, var) # TODO: VERIFY THIS... Due to numerical issues, small negative values are possible... - -class RunningNormalize(IDifferentiableFunction): - - def __init__(self, half_life, eps = 1e-7, initial_std=1): - self.decay_constant = np.exp(-np.log(2)/half_life) - self.one_minus_decay_constant = 1-self.decay_constant - self.running_mean = None - self.eps = eps - self.initial_std = initial_std - - def __call__(self, x): - if self.running_mean is None: - self.running_mean = np.zeros_like(x) - self.running_mean_sq = np.zeros_like(x) + self.initial_std**2 - self.running_mean[:] = self.decay_constant * self.running_mean + self.one_minus_decay_constant * x - self.running_mean_sq[:] = self.decay_constant * self.running_mean_sq + self.one_minus_decay_constant * x**2 - std = np.sqrt(self.running_mean_sq - self.running_mean**2) - return (x - self.running_mean) / (std+self.eps) - - def backprop_delta(self, delta_y): - """ - Ok, we're not doing this right at all, but lets just ignore the contribution of the current - sample to the mean/std. This makes the gradient waaaaaay simpler. If you want to see the real thing, put - - (x-(a*u+(1-a)*x))/sqrt((a*s+(1-a)*x^2 - (a*u+(1-a)*x)^2)) - into http://www.derivative-calculator.net/ - (a stands for lambda here) - - :param delta_y: The derivative of the cost wrt the output of this normalizer - :return: delta_x: The derivative of the cost wrt the input of this normalizer - """ - std = np.sqrt(self.running_mean_sq - self.running_mean**2) - return delta_y/std - - def single_to_batch(fcn, *batch_inputs, **batch_kwargs): """ :param fcn: A function diff --git a/artemis/ml/tools/running_averages.py b/artemis/ml/tools/running_averages.py new file mode 100644 index 00000000..e9ec31e3 --- /dev/null +++ b/artemis/ml/tools/running_averages.py @@ -0,0 +1,154 @@ +import numpy as np + +from artemis.general.global_vars import get_global +from artemis.general.mymath import recent_moving_average +from artemis.ml.tools.processors import IDifferentiableFunction + + +class RunningAverage(object): + + def __init__(self): + self._n_samples_seen = 0 + self._average = 0 + + def __call__(self, data): + self._n_samples_seen+=1 + frac = 1./self._n_samples_seen + self._average = (1-frac)*self._average + frac*data + return self._average + + @classmethod + def batch(cls, x): + return np.cumsum(x, axis=0)/np.arange(1, len(x)+1).astype(np.float)[(slice(None), )+(None, )*(x.ndim-1)] + + +class RecentRunningAverage(object): + + def __init__(self): + self._n_samples_seen = 0 + self._average = 0 + + def __call__(self, data): + self._n_samples_seen+=1 + frac = 1/self._n_samples_seen**.5 + self._average = (1-frac)*self._average + frac*data + return self._average + + @classmethod + def batch(cls, x): + return recent_moving_average(x, axis=0) # Works only for python 2.X, with weave + + +class OptimalStepSizeAverage(object): + + def __init__(self, error_stepsize_target=0.01, initial_stepsize = 1., epsilon=1--7): + + self.error_stepsize_target = error_stepsize_target + self.error_stepsize = initial_stepsize # (nu) + self.error_stepsize_target = 0.001 # (nu-bar) + self.step_size = 1. # (a) + self.avg = 0 # (theta) + self.beta = 0. + self.delta = 0. + self.lambdaa = 0. + self.epsilon = epsilon + self.first_iter = True + + def __call__(self, x): + error = x-self.avg + error_stepsize = self.error_stepsize / (1 + self.error_stepsize - self.error_stepsize_target) + self.beta = (1-error_stepsize) * self.beta + error_stepsize * error + self.delta = (1-error_stepsize) * self.delta + error_stepsize * error**2 + sigma_sq = (self.delta-self.beta**2)/(1+self.lambdaa) + self.step_size = np.array(1.) if self.first_iter else 1 - (sigma_sq+self.epsilon) / (self.delta+self.epsilon) + # step_size = 1 - (sigma_sq+self.epsilon) / (delta+self.epsilon) + self.lambdaa = (1-self.step_size)**2* self.lambdaa + self.step_size**2 # TODO: Test: Should it be (1-step_size**2) ?? + avg = (1-self.step_size) * self.avg + self.step_size * x + # new_obj = OptimalStepSizer(error_stepsize=error_stepsize, error_stepsize_target=self.error_stepsize_target, + # step_size=step_size, avg=avg, beta=beta, delta = delta, lambdaa=lambdaa, epsilon=self.epsilon, first_iter=False) + + if np.any(np.isnan(avg)): + raise Exception() + return avg + + +class RunningAverageWithBurnin(object): + + def __init__(self, burn_in_steps): + self._burn_in_step_remaining = burn_in_steps + self.averager = RunningAverage() + + def __call__(self, x): + + if self._burn_in_step_remaining > 0: + self._burn_in_step_remaining-=1 + return x + else: + return self.averager(x) + + +class RunningCenter(object): + """ + Keep an exponentially decaying running mean, subtract this from the value. + """ + def __init__(self, half_life): + self.decay_constant = np.exp(-np.log(2)/half_life) + self.one_minus_decay_constant = 1-self.decay_constant + self.running_mean = None + + def __call__(self, x): + if self.running_mean is None: + self.running_mean = np.zeros_like(x) + self.running_mean[:] = self.decay_constant * self.running_mean + self.one_minus_decay_constant * x + return x - self.running_mean + + +class ExponentialRunningVariance(object): + + def __init__(self, decay): + self.decay = decay + self.running_mean = 0 + self.running_mean_sq = 1 + + def __call__(self, x, decay = None): + + decay = self.decay if decay is None else decay + self.running_mean = (1-decay) * self.running_mean + decay * x + self.running_mean_sq = (1-decay) * self.running_mean_sq + decay * x**2 + var = self.running_mean_sq - self.running_mean**2 + return np.maximum(0, var) # TODO: VERIFY THIS... Due to numerical issues, small negative values are possible... + + +class RunningNormalize(IDifferentiableFunction): + + def __init__(self, half_life, eps = 1e-7, initial_std=1): + + self.decay_constant = np.exp(-np.log(2)/half_life) + self.one_minus_decay_constant = 1-self.decay_constant + self.running_mean = None + self.eps = eps + self.initial_std = initial_std + + def __call__(self, x): + if self.running_mean is None: + self.running_mean = np.zeros_like(x) + self.running_mean_sq = np.zeros_like(x) + self.initial_std**2 + self.running_mean[:] = self.decay_constant * self.running_mean + self.one_minus_decay_constant * x + self.running_mean_sq[:] = self.decay_constant * self.running_mean_sq + self.one_minus_decay_constant * x**2 + std = np.sqrt(self.running_mean_sq - self.running_mean**2) + return (x - self.running_mean) / (std+self.eps) + + +_running_averages = {} + + +def get_global_running_average(value, identifier, ra_type='simple'): + """ + Get the running average of a variable. + :param value: The latest value of the variable + :param identifier: An identifier (to store the state of the running averager) + :param ra_type: The type of running averge. Options are 'simple', 'recent', 'osa' + :return: The running average + """ + running_averager = get_global(identifier=identifier, constructor=lambda: (ra_type() if callable(ra_type) else {'simple': RunningAverage, 'recent': RecentRunningAverage, 'osa': OptimalStepSizeAverage}[ra_type]())) + return running_averager(value) diff --git a/artemis/ml/tools/test_running_averages.py b/artemis/ml/tools/test_running_averages.py new file mode 100644 index 00000000..8795bcbb --- /dev/null +++ b/artemis/ml/tools/test_running_averages.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest +from six.moves import xrange + +from artemis.general.global_vars import global_context +from artemis.ml.tools.running_averages import RunningAverage, RecentRunningAverage, get_global_running_average + +__author__ = 'peter' + + +def test_running_average(): + + inp = np.arange(5) + processor = RunningAverage() + out = [processor(el) for el in inp] + assert out == [0, 0.5, 1, 1.5, 2] + assert np.array_equal(out, RunningAverage.batch(inp)) + + inp = np.random.randn(10, 5) + processor = RunningAverage() + out = [processor(el) for el in inp] + assert all(np.allclose(out[i], np.mean(inp[:i+1], axis = 0)) for i in xrange(len(inp))) + + +@pytest.mark.skipif(True, reason='Depends on weave, which is deprecated for python 3') +def test_recent_running_average(): + + inp = np.arange(5) + processor = RecentRunningAverage() + out = [processor(el) for el in inp] + out2 = processor.batch(inp) + assert np.allclose(out, out2) + assert np.allclose(out, [0.0, 0.7071067811865475, 1.4535590291019362, 2.226779514550968, 3.019787823462811]) + + inp = np.random.randn(10, 5) + processor = RunningAverage() + out = [processor(el) for el in inp] + out2 = processor.batch(inp) + assert np.allclose(out, out2) + + +def test_get_global_running_average(): + + n_steps = 100 + + rng = np.random.RandomState(1234) + + sig = 2.5*(1-np.exp(-np.linspace(0, 10, n_steps))) + noise = rng.randn(n_steps)*0.1 + fullsig = sig + noise + with global_context(): + for x in fullsig: + ra = get_global_running_average(x, 'my_ra_simple', ra_type='simple') + assert 2.24 < ra < 2.25 + for x in fullsig: + ra = get_global_running_average(x, 'my_ra_recent', ra_type='recent') + assert 2.490 < ra < 2.491 + for x in fullsig: + ra = get_global_running_average(x, 'my_ra_osa', ra_type='osa') + assert 2.44 < ra < 2.45 + + +if __name__ == '__main__': + test_get_global_running_average() \ No newline at end of file diff --git a/artemis/plotting/db_plotting.py b/artemis/plotting/db_plotting.py index 655ccf9d..e7ee3a31 100644 --- a/artemis/plotting/db_plotting.py +++ b/artemis/plotting/db_plotting.py @@ -191,6 +191,7 @@ class DBPlotTypes: LINE= LinePlot THICK_LINE= partial(LinePlot, plot_kwargs={'linewidth': 3}) POS_LINE= partial(LinePlot, y_bounds=(0, None), y_bound_extend=(0, 0.05)) + SCATTER= partial(LinePlot, plot_kwargs=dict(marker='.', markersize=7), linestyle='') BBOX= partial(BoundingBoxPlot, linewidth=2, axes_update_mode='expand') BBOX_R= partial(BoundingBoxPlot, linewidth=2, color='r', axes_update_mode='expand') BBOX_B= partial(BoundingBoxPlot, linewidth=2, color='b', axes_update_mode='expand') diff --git a/artemis/plotting/matplotlib_backend.py b/artemis/plotting/matplotlib_backend.py index da156fb6..c623c75e 100644 --- a/artemis/plotting/matplotlib_backend.py +++ b/artemis/plotting/matplotlib_backend.py @@ -321,12 +321,15 @@ def __init__(self, axes_update_mode='expand', **kwargs): self._image_handle = None self._last_data_shape = None - def update(self, data): + def _plot_last_data(self, data): """ :param data: A (left, bottom, right, top) bounding box. """ if self._image_handle is None: - self._image_handle = next(c for c in plt.gca().get_children() if isinstance(c, AxesImage)) + try: + self._image_handle = next(c for c in plt.gca().get_children() if isinstance(c, AxesImage)) + except StopIteration: + raise Exception('Could not find any image plots in the current axis to draw bounding boxes on! Check that "axis" argument matches the name of a previous image plot') data_shape = self._image_handle.get_array().shape # Hopefully this isn't copying if data_shape != self._last_data_shape: @@ -342,7 +345,7 @@ def update(self, data): x = np.array([l, l, r, r, l]) # Note: should we be adding .5? The extend already subtracts .5 y = np.array([t, b, b, t, t]) - LinePlot.update(self, (x, y)) + LinePlot._plot_last_data(self, (x, y)) class MovingPointPlot(LinePlot): diff --git a/artemis/plotting/test_db_plotting.py b/artemis/plotting/test_db_plotting.py index 8bdd566c..b6a8178f 100644 --- a/artemis/plotting/test_db_plotting.py +++ b/artemis/plotting/test_db_plotting.py @@ -4,7 +4,7 @@ import numpy as np from artemis.plotting.demo_dbplot import demo_dbplot from artemis.plotting.db_plotting import dbplot, clear_dbplot, hold_dbplots, freeze_all_dbplots, reset_dbplot, \ - dbplot_hang + dbplot_hang, DBPlotTypes from artemis.plotting.matplotlib_backend import LinePlot, HistogramPlot, MovingPointPlot, is_server_plotting_on, \ ResamplingLineHistory import pytest @@ -204,51 +204,31 @@ def test_individual_periodic_plotting(): time.sleep(0.02) +def test_bbox_display(): + + # It once was the case that bboxes failed when in a hold block with their image. Not any more. + with hold_dbplots(): + dbplot((np.random.rand(40, 40)*255.999).astype(np.uint8), 'gfdsg') + dbplot((np.random.rand(40, 40)*255.999).astype(np.uint8), 'img') + dbplot([10, 20, 25, 30], 'bbox', axis='img', plot_type=DBPlotTypes.BBOX) + + if __name__ == '__main__': - if is_server_plotting_on(): - test_cornertext() - time.sleep(2.) - test_trajectory_plot() - time.sleep(2.) - test_demo_dbplot() - time.sleep(2.) - test_two_plots_in_the_same_axis_version_1() - time.sleep(2.) - test_two_plots_in_the_same_axis_version_2() - time.sleep(2.) - test_moving_point_multiple_points() - time.sleep(2.) - test_list_of_images() - time.sleep(2.) - test_multiple_figures() - time.sleep(2.) - test_same_object() - time.sleep(2.) - test_history_plot_updating() - time.sleep(2.) - test_particular_plot() - time.sleep(2.) - test_dbplot() - time.sleep(2.) - test_custom_axes_placement() - time.sleep(2.) - test_close_and_open() - time.sleep(2.) - else: - test_cornertext() - test_trajectory_plot() - test_demo_dbplot() - test_freeze_dbplot() - test_two_plots_in_the_same_axis_version_1() - test_two_plots_in_the_same_axis_version_2() - test_moving_point_multiple_points() - test_list_of_images() - test_multiple_figures() - test_same_object() - test_history_plot_updating() - test_particular_plot() - test_dbplot() - test_custom_axes_placement() - test_close_and_open() - test_periodic_plotting() - test_individual_periodic_plotting() + test_cornertext() + test_trajectory_plot() + test_demo_dbplot() + test_freeze_dbplot() + test_two_plots_in_the_same_axis_version_1() + test_two_plots_in_the_same_axis_version_2() + test_moving_point_multiple_points() + test_list_of_images() + test_multiple_figures() + test_same_object() + test_history_plot_updating() + test_particular_plot() + test_dbplot() + test_custom_axes_placement() + test_close_and_open() + test_periodic_plotting() + test_individual_periodic_plotting() + test_bbox_display() From 949dcd41309a01f34fc020346cba68bbdcb7ffd8 Mon Sep 17 00:00:00 2001 From: Peter Date: Tue, 20 Nov 2018 11:27:18 +0100 Subject: [PATCH 10/41] ducks now support boolean indexing --- artemis/experiments/experiment_record_view.py | 26 +++++++++++++++- artemis/general/duck.py | 31 +++++++++++++++++-- artemis/general/test_duck.py | 14 +++++++++ 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 57341e0f..f02fbf17 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -9,6 +9,7 @@ load_experiment_record, is_matplotlib_imported, UnPicklableArg from artemis.general.display import deepstr, truncate_string, hold_numpy_printoptions, side_by_side, \ surround_with_header, section_with_header, dict_to_str +from artemis.general.duck import Duck from artemis.general.nested_structures import flatten_struct, PRIMATIVE_TYPES from artemis.general.should_be_builtins import separate_common_items, bad_value, izip_equal, \ remove_duplicates, get_unique_name, entries_to_table @@ -358,6 +359,29 @@ def compare_experiment_records(records, parallel_text=None, show_logs=True, trun return has_matplotlib_figures +def make_record_comparison_duck(records, only_different_args = False, results_extractor = None): + """ + Make a data structure containing arguments and results of the experiment. + :param Sequence[ExperimentRecord] records: + :param Optional[Callable] results_extractor: + :return Duck: A Duck with one entry per record. Each entry has keys ['args', 'result'] + """ + duck = Duck() + + if only_different_args: + common, diff = separate_common_args(records) + else: + common = None + + for rec in records: + duck[next, 'args', :] = rec.get_args() if common is None else OrderedDict((k, v) for k, v in rec.get_args().items() if k not in common) + result = rec.get_result() + if results_extractor is not None: + result = results_extractor(result) + duck[-1, 'result', ...] = result + return duck + + def make_record_comparison_table(records, args_to_show=None, results_extractor = None, print_table = False, tablefmt='simple', reorder_by_args=False): """ Make a table comparing the arguments and results of different experiment records. You can use the output @@ -425,7 +449,7 @@ def separate_common_args(records, as_dicts=False, return_dict = False, only_shar :param records: A List of records :param return_dict: Return the different args as a dict - :return: (common, different) + :return Tuple[OrderedDict[str, Any], List[OrderedDict[str, Any]]: (common, different) Where common is an OrderedDict of common args different is a list (the same lengths of records) of OrderedDicts containing args that are not the same in all records. """ diff --git a/artemis/general/duck.py b/artemis/general/duck.py index 063dd45a..4c9a03da 100644 --- a/artemis/general/duck.py +++ b/artemis/general/duck.py @@ -87,6 +87,8 @@ def from_struct(cls, struct): return DynamicSequence(struct) elif isinstance(struct, dict): return UniversalOrderedStruct(struct) + elif isinstance(struct, Duck): + return UniversalCollection.from_struct(struct.to_struct()) elif struct is None or isinstance(struct, EmptyCollection): return EmptyCollection() else: @@ -167,8 +169,17 @@ def keys(self): def __getitem__(self, ix): if isinstance(ix, slice): return DynamicSequence(list.__getitem__(self, ix)) + elif isinstance(ix, UniversalCollection): + return self.__getitem__(ix.to_struct()) elif isinstance(ix, (list, tuple)): - return DynamicSequence((list.__getitem__(self, i) for i in ix)) + arrix = np.array(ix) + if arrix.dtype==np.bool: + if len(arrix) != len(self): + raise InvalidKeyError('If you use boolean indices, the length ({} here) must match the length of the collection ({} here)'.format(len(arrix), len(self))) + else: + return DynamicSequence(a for a, b in izip_equal(self, arrix) if b) + else: + return DynamicSequence((list.__getitem__(self, i) for i in ix)) else: try: return list.__getitem__(self, ix) @@ -416,7 +427,7 @@ def __getitem__(self, indices): else: # Case 2: There are deeper indices to get if not isinstance(new_substruct, Duck): raise KeyError('Leave value "{}" can not be broken into with {}'.format(new_substruct, indices[1:])) - if isinstance(first_selector, (list, np.ndarray, slice)): # Sliced selection, with more sub-indices + if isinstance(first_selector, (list, np.ndarray, slice, UniversalCollection)): # Sliced selection, with more sub-indices return new_substruct.map(lambda x: x.__getitem__(indices[1:])) else: # Simple selection, with more sub-indices return new_substruct[indices[1:]] @@ -715,3 +726,19 @@ def description(self, max_expansion=4, _skip_intro=False): if i>max_expansion: break return ('' if _skip_intro else (str(self) + '')) + indent_string(key_value_string, indent='| ', include_first=False) + + def each_eq(self, item): + """ + :param item: Any python object. + :return: A new Duck filled with boolean values indicating if each element of this Duck is equal to the given item. + (this can be used for boolean indexing) + """ + return self.map(lambda x: item==x) + + def each_in(self, item_set): + """ + :param Sequence item: A set of items + :return: A new Duck filled with boolean values indicating if each element of this Duck is in the set. + (this can be used for boolean indexing) + """ + return self.map(lambda x: (x in item_set)) diff --git a/artemis/general/test_duck.py b/artemis/general/test_duck.py index 275c3640..d7dd6f74 100644 --- a/artemis/general/test_duck.py +++ b/artemis/general/test_duck.py @@ -554,6 +554,19 @@ def test_has_key(): assert not duck.has_key('q') +def test_boolean_indexing(): + + d = Duck() + d[next, :] = {'a': 1, 'b': 2} + d[next, :] = {'a': 4, 'b': 3} + d[next, :] = {'a': 3, 'b': 6} + d[next, :] = {'a': 6, 'b': 2} + + assert d[[True, False, False, True], 'a'] == [1, 6] + assert d[d[:, 'b'].each_eq(2), 'a'] == [1, 6] + assert d[d[:, 'b'].each_in({3, 6}), 'a'] == [4, 3] + + if __name__ == '__main__': test_so_demo() test_dict_assignment() @@ -582,3 +595,4 @@ def test_has_key(): test_key_get_on_set_bug() test_occasional_value_filter() test_has_key() + test_boolean_indexing() From e1700cad8038dab271d572a5f6b4f60ba2894758 Mon Sep 17 00:00:00 2001 From: Peter Date: Fri, 7 Dec 2018 18:57:29 +0100 Subject: [PATCH 11/41] stuuuufff --- artemis/experiments/experiment_record_view.py | 2 ++ artemis/fileman/config_files.py | 2 +- artemis/fileman/file_getter.py | 8 ++++- artemis/general/duck.py | 19 +++++++++++ artemis/general/global_rates.py | 25 +++++++++++++++ artemis/general/test_duck.py | 10 +++--- artemis/ml/tools/running_averages.py | 6 +++- artemis/plotting/drawing_plots.py | 2 +- artemis/plotting/expanding_subplots.py | 2 +- artemis/plotting/matplotlib_backend.py | 2 +- artemis/plotting/pyplot_plus.py | 32 +++++++++++++++---- 11 files changed, 94 insertions(+), 16 deletions(-) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index f02fbf17..7da1c8bb 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -378,6 +378,8 @@ def make_record_comparison_duck(records, only_different_args = False, results_ex result = rec.get_result() if results_extractor is not None: result = results_extractor(result) + duck[-1, 'exp_id'] = rec.get_experiment_id() + duck[-1, 'id'] = rec.get_id() duck[-1, 'result', ...] = result return duck diff --git a/artemis/fileman/config_files.py b/artemis/fileman/config_files.py index ea2a65a6..c03b7352 100644 --- a/artemis/fileman/config_files.py +++ b/artemis/fileman/config_files.py @@ -65,7 +65,7 @@ def get_config_value(config_filename, section, option, default_generator=None, w config = _get_config_object(config_path) if not config.has_section(section): config.add_section(section) - config.set(section, option, value) + config.set(section, option, str(value)) with open(config_path, 'w') as f: config.write(f) diff --git a/artemis/fileman/file_getter.py b/artemis/fileman/file_getter.py index 0790916e..820b41ca 100644 --- a/artemis/fileman/file_getter.py +++ b/artemis/fileman/file_getter.py @@ -1,6 +1,9 @@ import hashlib from contextlib import contextmanager +from io import BytesIO from shutil import rmtree + +import sys from six.moves import StringIO import gzip import tarfile @@ -166,7 +169,10 @@ def get_archive(url, relative_path=None, force_extract=False, archive_type = Non def unzip_gz(data): - return gzip.GzipFile(fileobj = StringIO(data)).read() + if sys.version_info[0] < 3: + return gzip.GzipFile(fileobj = StringIO(data)).read() + else: + return gzip.GzipFile(fileobj = BytesIO(data)).read() def get_file_path(relative_name = None, url=None, make_folder = False): diff --git a/artemis/general/duck.py b/artemis/general/duck.py index 4c9a03da..077d2b7c 100644 --- a/artemis/general/duck.py +++ b/artemis/general/duck.py @@ -58,6 +58,15 @@ def __setitem__(self, key, value): def __eq__(self, other): return self.to_struct() == (other.to_struct() if isinstance(other, UniversalCollection) else other) + def __and__(self, other): + return self.from_struct([a & b for a, b in izip_equal(self, other)]) + + def __or__(self, other): + return self.from_struct([a | b for a, b in izip_equal(self, other)]) + + def __invert__(self): + return self.from_struct([not a for a in self]) + @abstractmethod def to_struct(self): raise NotImplementedError() @@ -181,6 +190,7 @@ def __getitem__(self, ix): else: return DynamicSequence((list.__getitem__(self, i) for i in ix)) else: + assert not isinstance(ix, bool), 'You cannot index with a boolean.' try: return list.__getitem__(self, ix) except TypeError: @@ -694,6 +704,15 @@ def items(self, depth=None): for k in self.keys(depth=depth): yield k, self[k] + def only(self): + """ + Assert that this duck contains only one element, and return that element. + :return: The only element inside this duck. + """ + keys = list(self.keys()) + assert len(keys)==1, 'You called Duck.only() on a duck with {} elements. "only" can only be called on single-element ducks.'.format(len(self)) + return self[keys[0]] + def __str__(self, max_key_len=4): keys = list(self.keys()) if len(keys)>max_key_len: diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py index 67773a16..040e6eb8 100644 --- a/artemis/general/global_rates.py +++ b/artemis/general/global_rates.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + from artemis.general.global_vars import get_global, set_global import time @@ -13,3 +15,26 @@ def measure_global_rate(name): set_global(key, (n_calls+1, start_time)) return n_calls / (this_time - start_time) if this_time!=start_time else float('inf') + +class _ElapsedMeasureSingleton: + pass + + +@contextmanager +def measure_rate_context(name): + start = time.time() + key = (_ElapsedMeasureSingleton, name) + n_calls, elapsed = get_global(key, constructor=lambda: (0, 0.)) + yield n_calls / elapsed if elapsed > 0 else float('nan') + end = time.time() + set_global(key, (n_calls+1, elapsed+(end-start))) + + +@contextmanager +def measure_runtime_context(name): + start = time.time() + key = (_ElapsedMeasureSingleton, name) + n_calls, elapsed = get_global(key, constructor=lambda: (0, 0.)) + yield elapsed / n_calls if n_calls > 0 else float('nan') + end = time.time() + set_global(key, (n_calls+1, elapsed+(end-start))) diff --git a/artemis/general/test_duck.py b/artemis/general/test_duck.py index d7dd6f74..dd0906e1 100644 --- a/artemis/general/test_duck.py +++ b/artemis/general/test_duck.py @@ -557,14 +557,16 @@ def test_has_key(): def test_boolean_indexing(): d = Duck() - d[next, :] = {'a': 1, 'b': 2} - d[next, :] = {'a': 4, 'b': 3} - d[next, :] = {'a': 3, 'b': 6} - d[next, :] = {'a': 6, 'b': 2} + d[next, :] = {'a': 1, 'b': 2, 'c': 7} + d[next, :] = {'a': 4, 'b': 3, 'c': 8} + d[next, :] = {'a': 3, 'b': 6, 'c': 9} + d[next, :] = {'a': 6, 'b': 2, 'c': 0} assert d[[True, False, False, True], 'a'] == [1, 6] assert d[d[:, 'b'].each_eq(2), 'a'] == [1, 6] assert d[d[:, 'b'].each_in({3, 6}), 'a'] == [4, 3] + assert d[d[:, 'b'].each_eq(3) | d[:, 'b'].each_eq(6), 'a'] == [4, 3] + assert d[d[:, 'b'].each_in({3, 6}) & ~d[:, 'a'].each_in({3, 6})].only()['c'] == 8 # "Find the 'c' value of the only item in the duck where b is in {3, 6} and 'a' is not in {3, 6} if __name__ == '__main__': diff --git a/artemis/ml/tools/running_averages.py b/artemis/ml/tools/running_averages.py index e9ec31e3..3dedaf48 100644 --- a/artemis/ml/tools/running_averages.py +++ b/artemis/ml/tools/running_averages.py @@ -36,7 +36,11 @@ def __call__(self, data): @classmethod def batch(cls, x): - return recent_moving_average(x, axis=0) # Works only for python 2.X, with weave + try: + return recent_moving_average(x, axis=0) # Works only for python 2.X, with weave + except ModuleNotFoundError: + rma = RecentRunningAverage() + return np.array([rma(xt) for xt in x]) class OptimalStepSizeAverage(object): diff --git a/artemis/plotting/drawing_plots.py b/artemis/plotting/drawing_plots.py index 890cfe49..c9636415 100644 --- a/artemis/plotting/drawing_plots.py +++ b/artemis/plotting/drawing_plots.py @@ -2,7 +2,7 @@ from matplotlib import pyplot as plt __author__ = 'peter' -_plotting_mode = get_artemis_config_value(section='plotting', option='mode') +_plotting_mode = get_artemis_config_value(section='plotting', option='mode', default_generator=lambda: 'safe') if _plotting_mode == 'safe': diff --git a/artemis/plotting/expanding_subplots.py b/artemis/plotting/expanding_subplots.py index f310a0e5..af41370d 100644 --- a/artemis/plotting/expanding_subplots.py +++ b/artemis/plotting/expanding_subplots.py @@ -226,7 +226,7 @@ def hstack_plots(spacing=0, sharex=False, sharey = True, grid=False, show_x=True with _define_plot_settings(layout='h', show_y = False if show_y=='once' else show_y, show_x = show_x, grid=grid, sharex=sharex, sharey=sharey, xlabel=xlabel, xlim=xlim, ylim=ylim): set_figure_border_size(wspace=spacing, left=left_pad, right=right_pad, top=top_pad, bottom=bottom_pad) yield - new_subplots = cap.get_new_subplots().values() + new_subplots = list(cap.get_new_subplots().values()) if clip_x: set_same_xlims(new_subplots) diff --git a/artemis/plotting/matplotlib_backend.py b/artemis/plotting/matplotlib_backend.py index c623c75e..af3e4b65 100644 --- a/artemis/plotting/matplotlib_backend.py +++ b/artemis/plotting/matplotlib_backend.py @@ -620,4 +620,4 @@ def get_plotting_server_address(): return _PLOTTING_SERVER -BACKEND = get_artemis_config_value(section='plotting', option='backend') +BACKEND = get_artemis_config_value(section='plotting', option='backend', default_generator=lambda: 'matplotlib', write_default=True) diff --git a/artemis/plotting/pyplot_plus.py b/artemis/plotting/pyplot_plus.py index 46ccef14..43a43e73 100644 --- a/artemis/plotting/pyplot_plus.py +++ b/artemis/plotting/pyplot_plus.py @@ -169,14 +169,34 @@ def set_lines_color_cycle_map(name, length): def get_line_color(ix, modifier=None): - colour = next(c for i, c in enumerate(get_lines_color_cycle()) if i==ix) - if modifier=='dark': - return tuple(c/2 for c in colors.hex2color(colour)) - elif modifier=='light': - return tuple(1-(1-c)/2 for c in colors.hex2color(colour)) + # Back compatibilituy + return modify_color(ix, modifier=modifier) + + +def modify_color(color_specifier, modifier): + rgba = get_color_from_spec(color_specifier) + if callable(modifier): + return modifier(rgba) + elif isinstance(modifier, str): + if modifier=='dark': + return tuple(c/2 for c in colors.hex2color(rgba)) + elif modifier=='light': + return tuple(1-(1-c)/2 for c in colors.hex2color(rgba)) + elif modifier.startswith('alpha:'): + alpha_val = float(modifier[len('alpha:'):]) + return rgba[:3]+(alpha_val, ) + else: + raise NotImplementedError(modifier) elif modifier is not None: raise NotImplementedError(modifier) - return colors.hex2color(colour) + + +def get_color_from_spec(spec): + if isinstance(spec, int): + colour = next(c for i, c in enumerate(get_lines_color_cycle()) if i==spec) + return colour + (1., ) + else: + return tuple(colors.to_rgba(spec)) def relabel_axis(axis, value_array, n_points = 5, format_str='{:.2g}'): From 2e412dddd0c0433c9f60f73e1462c1bc1155325d Mon Sep 17 00:00:00 2001 From: Peter Date: Fri, 14 Dec 2018 13:28:24 +0100 Subject: [PATCH 12/41] join and split --- artemis/general/test_data_splitting.py | 28 ++++++++++++++++++++++ artemis/ml/tools/data_splitting.py | 33 ++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 artemis/general/test_data_splitting.py diff --git a/artemis/general/test_data_splitting.py b/artemis/general/test_data_splitting.py new file mode 100644 index 00000000..9d735b3b --- /dev/null +++ b/artemis/general/test_data_splitting.py @@ -0,0 +1,28 @@ +import numpy as np + +from artemis.ml.tools.data_splitting import join_arrays_and_get_rebuild_func + + +def test_join_arrays_and_get_rebuild_function(): + + n_samples = 5 + randn = np.random.RandomState(1234).randn + + struct = [ + (randn(n_samples, 3), randn(n_samples)), + randn(n_samples, 4, 5) + ] + + joined, rebuild_func = join_arrays_and_get_rebuild_func(struct, axis=1) + + assert joined.shape == (n_samples, 3+1+4*5) + + new_struct = rebuild_func(joined*2) + + assert np.array_equal(struct[0][0]*2, new_struct[0][0]) + assert np.array_equal(struct[0][1]*2, new_struct[0][1]) + assert np.array_equal(struct[1]*2, new_struct[1]) + + +if __name__ == '__main__': + test_join_arrays_and_get_rebuild_function() diff --git a/artemis/ml/tools/data_splitting.py b/artemis/ml/tools/data_splitting.py index a06f3e5f..994b34ef 100644 --- a/artemis/ml/tools/data_splitting.py +++ b/artemis/ml/tools/data_splitting.py @@ -1,8 +1,12 @@ import numpy as np from six.moves import xrange +from artemis.general.nested_structures import NestedType +from artemis.general.should_be_builtins import izip_equal + __author__ = 'peter' + def split_data_by_label(data, labels, frac_training = 0.5): """ Split the data so that each label gets approximately the correct proportions between the training and test sets @@ -22,3 +26,32 @@ def split_data_by_label(data, labels, frac_training = 0.5): training_indices = np.sort(np.concatenate([ixs[:c] for ixs, c in zip(label_indices, cutoffs)])) test_indices = np.sort(np.concatenate([ixs[c:] for ixs, c in zip(label_indices, cutoffs)])) return data[training_indices], labels[training_indices], data[test_indices], labels[test_indices] + + +def join_arrays_and_get_rebuild_func(arrays, axis = 0): + """ + Given a nested structure of arrays, join them into a single array by flattening dimensions from axis on + concatenating them. Return the joined array and a function which can take the joined array and reproduce the + original structure. + + :param arrays: A possibly nested structure containing arrays which you want to join into a single array. + :param axis: Axis after which to flatten and join all arrays. The resulting array will be (dim+1) dimensional. + :return ndarray, Callable[[ndarray], [Any]]: The joined array, and the function which can be called to reconstruct + the structure from the joined array. + """ + nested_type = NestedType.from_data(arrays) + data_list = nested_type.get_leaves(arrays) + split_shapes = [x_.shape for x_ in data_list] + pre_join_shapes = [list(x_.shape[:axis]) + [np.prod(list(x_.shape[axis:]), dtype=int)] for x_ in data_list] + split_axis_ixs = np.cumsum([0]+[s_[-1] for s_ in pre_join_shapes], axis=0) + joined_arr = np.concatenate(list(x_.reshape(s_) for x_, s_ in izip_equal(data_list, pre_join_shapes)), axis=axis) + + def rebuild_function(joined_array, share_data = True): + if share_data: + x_split = [joined_array[..., start:end].reshape(shape) for (start, end, shape) in izip_equal(split_axis_ixs[:-1], split_axis_ixs[1:], split_shapes)] + else: # Note: this will raise an Error if the self.dim != 0, because the data is no longer contigious in memory. + x_split = [joined_array[..., start:end].copy().reshape(shape) for (start, end, shape) in izip_equal(split_axis_ixs[:-1], split_axis_ixs[1:], split_shapes)] + x_reassembled = nested_type.expand_from_leaves(x_split, check_types=False) + return x_reassembled + + return joined_arr, rebuild_function From b745818be92e8d5c22c9d24ec2dde0166832c8ac Mon Sep 17 00:00:00 2001 From: Peter Date: Wed, 19 Dec 2018 10:59:10 +0100 Subject: [PATCH 13/41] fixed pareto stuff --- artemis/general/global_rates.py | 24 +++++++++ artemis/general/nested_structures.py | 17 ++++--- artemis/general/pareto_efficiency.py | 38 +++++++++++--- artemis/general/test_pareto_efficiency.py | 60 +++++++++++++++-------- artemis/ml/tools/data_splitting.py | 54 ++++++++++++++------ artemis/ml/tools/running_averages.py | 7 +++ 6 files changed, 149 insertions(+), 51 deletions(-) diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py index 040e6eb8..d8e2d484 100644 --- a/artemis/general/global_rates.py +++ b/artemis/general/global_rates.py @@ -38,3 +38,27 @@ def measure_runtime_context(name): yield elapsed / n_calls if n_calls > 0 else float('nan') end = time.time() set_global(key, (n_calls+1, elapsed+(end-start))) + + +class _LastTimeMeasureSingleton: + pass + + +def is_elapsed(identifier, period, current = None): + """ + Return True if the given span has elapsed since this function last returned True + :param identifier: A string, or anything identifier + :param period: The span which should have elapsed for this to return True again. This is measured in time in seconds + if no argument is provided for "current" or for whatever the unit of "current" is otherwise. + :param current: Optionally, the current state of progress. If ommitted, this defaults to the current time. + :return bool: True if first call or at least "span" units of time have elapsed. + """ + if current is None: + current = time.time() + key = (_LastTimeMeasureSingleton, identifier) + last = get_global(key, constructor=lambda: -float('inf')) + assert current>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" + has_elapsed = current - last >= period + if has_elapsed: + set_global(key, current) + return has_elapsed diff --git a/artemis/general/nested_structures.py b/artemis/general/nested_structures.py index 74e3f992..f319009e 100644 --- a/artemis/general/nested_structures.py +++ b/artemis/general/nested_structures.py @@ -122,7 +122,7 @@ def get_leaves_and_rebuilder(nested_object, is_container = is_container_or_gener # TODO: Consider making leaves a generator so this could be used for streams. leaves = [] meta_obj = get_meta_object(nested_object, is_container=is_container, flat_list=leaves) - return leaves, (lambda data_iteratable: _fill_meta_object(meta_object=meta_obj, data_iteratable=iter(data_iteratable), check_types=check_types, assert_fully_used=assert_fully_used, is_container_func=is_container)) + return leaves, (lambda data_iteratable: fill_meta_object(meta_object=meta_obj, data_iteratable=iter(data_iteratable), check_types=check_types, assert_fully_used=assert_fully_used, is_container_func=is_container)) def get_leaves(nested_object, is_container = is_primitive_container): @@ -238,7 +238,7 @@ def expand_from_leaves(self, leaves, check_types = True, assert_fully_used=True, :param assert_fully_used: Assert that all the leaf values are used :return: A nested object, filled with the leaf data, whose structure is represented in this NestedType instance. """ - return _fill_meta_object(self.meta_object, (x for x in leaves), check_types=check_types, assert_fully_used=assert_fully_used, is_container_func=is_container_func) + return fill_meta_object(self.meta_object, (x for x in leaves), check_types=check_types, assert_fully_used=assert_fully_used, is_container_func=is_container_func) @staticmethod def from_data(data_object, is_container_func = is_primitive_container): @@ -291,7 +291,7 @@ def get_leaf_values(data_object, is_container_func = is_primitive_container): return [data_object] -def _fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, check_types = True, is_container_func = is_primitive_container): +def fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, check_types = True, is_container_func = is_primitive_container): """ Fill the data from the iterable into the meta_object. :param meta_object: A nested type descripter. See NestedType init @@ -304,13 +304,13 @@ def _fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, ch try: if is_container_func(meta_object): if isnamedtupleinstance(meta_object): - filled_object = type(meta_object)(*(_fill_meta_object(None, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func) for x in meta_object._fields)) + filled_object = type(meta_object)(*(fill_meta_object(None, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func) for x in meta_object._fields)) elif isinstance(meta_object, (list, tuple, set)): - filled_object = type(meta_object)(_fill_meta_object(x, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func) for x in meta_object) + filled_object = type(meta_object)(fill_meta_object(x, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func) for x in meta_object) elif isinstance(meta_object, OrderedDict): - filled_object = type(meta_object)((k, _fill_meta_object(val, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func)) for k, val in meta_object.items()) + filled_object = type(meta_object)((k, fill_meta_object(val, data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func)) for k, val in meta_object.items()) elif isinstance(meta_object, dict): - filled_object = type(meta_object)((k, _fill_meta_object(meta_object[k], data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func)) for k in sorted(meta_object.keys(), key=str)) + filled_object = type(meta_object)((k, fill_meta_object(meta_object[k], data_iteratable, assert_fully_used=False, check_types=check_types, is_container_func=is_container_func)) for k in sorted(meta_object.keys(), key=str)) else: raise Exception('Cannot handle container type: "{}"'.format(type(meta_object))) else: @@ -330,6 +330,9 @@ def _fill_meta_object(meta_object, data_iteratable, assert_fully_used = True, ch return filled_object +_fill_meta_object = fill_meta_object # For backwards compatibility + + def nested_map(func, *nested_objs, **kwargs): """ An equivalent of pythons built-in map, but for nested objects. This function crawls the object and applies func diff --git a/artemis/general/pareto_efficiency.py b/artemis/general/pareto_efficiency.py index 70dac8e7..ff816959 100644 --- a/artemis/general/pareto_efficiency.py +++ b/artemis/general/pareto_efficiency.py @@ -4,31 +4,38 @@ import numpy as np +# Very slow for many datapoints. Fastest for many costs, most readable def is_pareto_efficient_dumb(costs): """ + Find the pareto-efficient points :param costs: An (n_points, n_costs) array :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient """ is_efficient = np.ones(costs.shape[0], dtype = bool) for i, c in enumerate(costs): - is_efficient[i] = np.all(np.any(costs>=c, axis=1)) + is_efficient[i] = np.all(np.any(costs[:i]>c, axis=1)) and np.all(np.any(costs[i+1:]>c, axis=1)) return is_efficient -def is_pareto_efficient(costs): +# Fairly fast for many datapoints, less fast for many costs, somewhat readable +def is_pareto_efficient_simple(costs): """ + Find the pareto-efficient points :param costs: An (n_points, n_costs) array :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient """ is_efficient = np.ones(costs.shape[0], dtype = bool) for i, c in enumerate(costs): if is_efficient[i]: - is_efficient[is_efficient] = np.any(costs[is_efficient]<=c, axis=1) # Remove dominated points + is_efficient[is_efficient] = np.any(costs[is_efficient]0 - for c in costs[ixs]: - assert np.all(np.any(c<=costs, axis=1)) + assert np.sum(ixs)>0 + for c in costs[ixs]: + assert np.all(np.any(c<=costs, axis=1)) - if plot and n_costs==2: - import matplotlib.pyplot as plt - plt.plot(costs[:, 0], costs[:, 1], '.') - plt.plot(costs[ixs, 0], costs[ixs, 1], 'ro') - plt.show() + if plot and n_costs==2: + import matplotlib.pyplot as plt + plt.plot(costs[:, 0], costs[:, 1], '.') + plt.plot(costs[ixs, 0], costs[ixs, 1], 'ro') + plt.show() + + +def test_is_pareto_efficient_integer(): + + assert np.array_equal(is_pareto_efficient_dumb(np.array([[1,2], [3,4], [2,1], [1,1]])), [False, False, False, True]) + assert np.array_equal(is_pareto_efficient_simple(np.array([[1, 2], [3, 4], [2, 1], [1, 1]])), [False, False, False, True]) + assert np.array_equal(is_pareto_efficient(np.array([[1, 2], [3, 4], [2, 1], [1, 1]])), [False, False, False, True]) def profile_pareto_efficient(n_points=5000, n_costs=2, include_dumb = True): rng = np.random.RandomState(1234) - costs = rng.rand(n_points, n_costs) + costs = rng.randn(n_points, n_costs) + + print('{} samples, {} costs'.format(n_points, n_costs)) if include_dumb: with EZProfiler('is_pareto_efficient_dumb'): base_ixs = dumb_ixs = is_pareto_efficient_dumb(costs) + else: + print('is_pareto_efficient_dumb: Really, really, slow') - with EZProfiler('is_pareto_efficient'): - less_dumb__ixs = is_pareto_efficient(costs) + with EZProfiler('is_pareto_efficient_simple'): + less_dumb__ixs = is_pareto_efficient_simple(costs) if not include_dumb: base_ixs = less_dumb__ixs assert np.array_equal(base_ixs, less_dumb__ixs) - with EZProfiler('is_pareto_efficient_indexed'): - smart_indexed = is_pareto_efficient_indexed(costs, return_mask=True) + with EZProfiler('is_pareto_efficient_reordered'): + reordered_ixs = is_pareto_efficient_reordered(costs) + assert np.array_equal(base_ixs, reordered_ixs) + + with EZProfiler('is_pareto_efficient'): + smart_indexed = is_pareto_efficient(costs, return_mask=True) assert np.array_equal(base_ixs, smart_indexed) with EZProfiler('is_pareto_efficient_indexed_reordered'): - smart_indexed = is_pareto_efficient_indexed(costs, return_mask=True, rank_reorder=True) + smart_indexed = is_pareto_efficient_indexed_reordered(costs, return_mask=True) assert np.array_equal(base_ixs, smart_indexed) if __name__ == '__main__': # test_is_pareto_efficient() - profile_pareto_efficient(n_points=100000, n_costs=2, include_dumb=False) + # test_is_pareto_efficient_integer() + profile_pareto_efficient(n_points=10000, n_costs=2, include_dumb=True) + profile_pareto_efficient(n_points=1000000, n_costs=2, include_dumb=False) + profile_pareto_efficient(n_points=10000, n_costs=15, include_dumb=True) diff --git a/artemis/ml/tools/data_splitting.py b/artemis/ml/tools/data_splitting.py index 994b34ef..a54f411c 100644 --- a/artemis/ml/tools/data_splitting.py +++ b/artemis/ml/tools/data_splitting.py @@ -1,7 +1,7 @@ import numpy as np from six.moves import xrange -from artemis.general.nested_structures import NestedType +from artemis.general.nested_structures import get_meta_object, fill_meta_object, get_leaf_values from artemis.general.should_be_builtins import izip_equal __author__ = 'peter' @@ -28,7 +28,35 @@ def split_data_by_label(data, labels, frac_training = 0.5): return data[training_indices], labels[training_indices], data[test_indices], labels[test_indices] -def join_arrays_and_get_rebuild_func(arrays, axis = 0): +class ArrayStructRebuilder(object): + """ + A parameterized function which rebuilds a data structure given a flattened array containing the values. + Suggest using it through join_arrays_and_get_rebuild_func + """ + + def __init__(self, split_shapes, meta_object): + """ + :param Sequence[Tuple[int]] split_shapes: The shapes + :param Any meta_object: A nested object defining the structure in which to rebuild (see get_meta_object) + """ + self.split_shapes = split_shapes + self.meta_object = meta_object + + def __call__(self, joined_array, share_data = True, transform_func = None, check_types=True): + axis = joined_array.ndim-1 + pre_join_shapes = [list(s[:axis]) + [np.prod(list(s[axis:]), dtype=int)] for s in self.split_shapes] + split_axis_ixs = np.cumsum([0]+[s_[-1] for s_ in pre_join_shapes], axis=0) + if share_data: + x_split = [joined_array[..., start:end].reshape(shape) for (start, end, shape) in izip_equal(split_axis_ixs[:-1], split_axis_ixs[1:], self.split_shapes)] + else: # Note: this will raise an Error if the self.dim != 0, because the data is no longer contigious in memory. + x_split = [joined_array[..., start:end].copy().reshape(shape) for (start, end, shape) in izip_equal(split_axis_ixs[:-1], split_axis_ixs[1:], self.split_shapes)] + if transform_func is not None: + x_split = [transform_func(xs) for xs in x_split] + x_reassembled = fill_meta_object(self.meta_object, (x for x in x_split), check_types=check_types) + return x_reassembled + + +def join_arrays_and_get_rebuild_func(arrays, axis = 0, transform_func = None): """ Given a nested structure of arrays, join them into a single array by flattening dimensions from axis on concatenating them. Return the joined array and a function which can take the joined array and reproduce the @@ -36,22 +64,16 @@ def join_arrays_and_get_rebuild_func(arrays, axis = 0): :param arrays: A possibly nested structure containing arrays which you want to join into a single array. :param axis: Axis after which to flatten and join all arrays. The resulting array will be (dim+1) dimensional. - :return ndarray, Callable[[ndarray], [Any]]: The joined array, and the function which can be called to reconstruct + :param transform_func: Optionally, a function which you apply to every element in the nested struture of arrays first. + :return ndarray, ArrayStructRebuilder: The joined array, and the function which can be called to reconstruct the structure from the joined array. """ - nested_type = NestedType.from_data(arrays) - data_list = nested_type.get_leaves(arrays) + meta_object = get_meta_object(arrays) + data_list = get_leaf_values(arrays) + if transform_func is not None: + data_list = [transform_func(d) for d in data_list] split_shapes = [x_.shape for x_ in data_list] - pre_join_shapes = [list(x_.shape[:axis]) + [np.prod(list(x_.shape[axis:]), dtype=int)] for x_ in data_list] - split_axis_ixs = np.cumsum([0]+[s_[-1] for s_ in pre_join_shapes], axis=0) + pre_join_shapes = [list(s[:axis]) + [np.prod(list(s[axis:]), dtype=int)] for s in split_shapes] joined_arr = np.concatenate(list(x_.reshape(s_) for x_, s_ in izip_equal(data_list, pre_join_shapes)), axis=axis) - - def rebuild_function(joined_array, share_data = True): - if share_data: - x_split = [joined_array[..., start:end].reshape(shape) for (start, end, shape) in izip_equal(split_axis_ixs[:-1], split_axis_ixs[1:], split_shapes)] - else: # Note: this will raise an Error if the self.dim != 0, because the data is no longer contigious in memory. - x_split = [joined_array[..., start:end].copy().reshape(shape) for (start, end, shape) in izip_equal(split_axis_ixs[:-1], split_axis_ixs[1:], split_shapes)] - x_reassembled = nested_type.expand_from_leaves(x_split, check_types=False) - return x_reassembled - + rebuild_function = ArrayStructRebuilder(split_shapes=split_shapes, meta_object=meta_object) return joined_arr, rebuild_function diff --git a/artemis/ml/tools/running_averages.py b/artemis/ml/tools/running_averages.py index 3dedaf48..08a890d2 100644 --- a/artemis/ml/tools/running_averages.py +++ b/artemis/ml/tools/running_averages.py @@ -1,5 +1,6 @@ import numpy as np +from artemis.general.global_rates import is_elapsed from artemis.general.global_vars import get_global from artemis.general.mymath import recent_moving_average from artemis.ml.tools.processors import IDifferentiableFunction @@ -156,3 +157,9 @@ def get_global_running_average(value, identifier, ra_type='simple'): """ running_averager = get_global(identifier=identifier, constructor=lambda: (ra_type() if callable(ra_type) else {'simple': RunningAverage, 'recent': RecentRunningAverage, 'osa': OptimalStepSizeAverage}[ra_type]())) return running_averager(value) + + +def periodically_report_running_average(identifier, time, period, value, ra_type = 'simple', format_str = '{identifier}: Average at t={time:.3g}: {avg:.3g} '): + avg = get_global_running_average(value=value, identifier=identifier, ra_type=ra_type) + if is_elapsed(identifier, period=period, current=time): + print(format_str.format(identifier=identifier, time=time, avg=avg)) From c1fa4b610427b9a0c0c96e4e718cc2ec193e2ee2 Mon Sep 17 00:00:00 2001 From: Peter Date: Wed, 19 Dec 2018 15:38:07 +0100 Subject: [PATCH 14/41] ooook --- artemis/general/pareto_efficiency.py | 2 +- artemis/general/test_pareto_efficiency.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/artemis/general/pareto_efficiency.py b/artemis/general/pareto_efficiency.py index ff816959..038acf54 100644 --- a/artemis/general/pareto_efficiency.py +++ b/artemis/general/pareto_efficiency.py @@ -32,7 +32,7 @@ def is_pareto_efficient_simple(costs): return is_efficient -# Fastest than is_pareto_efficient, but less readable. +# Faster than is_pareto_efficient_simple, but less readable. def is_pareto_efficient(costs, return_mask = True): """ Find the pareto-efficient points diff --git a/artemis/general/test_pareto_efficiency.py b/artemis/general/test_pareto_efficiency.py index d90f61ac..7f48a60a 100644 --- a/artemis/general/test_pareto_efficiency.py +++ b/artemis/general/test_pareto_efficiency.py @@ -69,8 +69,8 @@ def profile_pareto_efficient(n_points=5000, n_costs=2, include_dumb = True): if __name__ == '__main__': - # test_is_pareto_efficient() - # test_is_pareto_efficient_integer() + test_is_pareto_efficient() + test_is_pareto_efficient_integer() profile_pareto_efficient(n_points=10000, n_costs=2, include_dumb=True) profile_pareto_efficient(n_points=1000000, n_costs=2, include_dumb=False) profile_pareto_efficient(n_points=10000, n_costs=15, include_dumb=True) From f4a2da60a832c6d8b78bcf95d8a3f5496b56504b Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 21 Dec 2018 15:56:08 +0100 Subject: [PATCH 15/41] before-changing-again --- artemis/general/global_rates.py | 20 ++++--- artemis/general/global_vars.py | 4 ++ artemis/ml/tools/running_averages.py | 29 +++++++--- artemis/plotting/db_plotting.py | 67 ++++++++++++----------- artemis/plotting/point_remapping_plots.py | 39 +++++++++++++ artemis/plotting/pyplot_plus.py | 1 + artemis/plotting/test_db_plotting.py | 14 ++++- 7 files changed, 128 insertions(+), 46 deletions(-) create mode 100644 artemis/plotting/point_remapping_plots.py diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py index d8e2d484..de4a8a50 100644 --- a/artemis/general/global_rates.py +++ b/artemis/general/global_rates.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from artemis.general.global_vars import get_global, set_global +from artemis.general.global_vars import get_global, set_global, has_global import time @@ -44,21 +44,27 @@ class _LastTimeMeasureSingleton: pass -def is_elapsed(identifier, period, current = None): +def is_elapsed(identifier, period, current = None, count_initial = True): """ Return True if the given span has elapsed since this function last returned True :param identifier: A string, or anything identifier :param period: The span which should have elapsed for this to return True again. This is measured in time in seconds if no argument is provided for "current" or for whatever the unit of "current" is otherwise. :param current: Optionally, the current state of progress. If ommitted, this defaults to the current time. + :param count_initial: Count the initial point :return bool: True if first call or at least "span" units of time have elapsed. """ if current is None: current = time.time() key = (_LastTimeMeasureSingleton, identifier) - last = get_global(key, constructor=lambda: -float('inf')) - assert current>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" - has_elapsed = current - last >= period - if has_elapsed: + + if not has_global(key): set_global(key, current) - return has_elapsed + return count_initial + else: + last = get_global(key) + assert current>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" + has_elapsed = current - last >= period + if has_elapsed: + set_global(key, current) + return has_elapsed diff --git a/artemis/general/global_vars.py b/artemis/general/global_vars.py index fed2a38d..6cc51940 100644 --- a/artemis/general/global_vars.py +++ b/artemis/general/global_vars.py @@ -24,6 +24,10 @@ def get_global(identifier, constructor=None): return _GLOBALS[identifier] +def has_global(identifier): + return identifier in _GLOBALS + + def set_global(identifier, value): _GLOBALS[identifier] = value diff --git a/artemis/ml/tools/running_averages.py b/artemis/ml/tools/running_averages.py index 08a890d2..aff7066a 100644 --- a/artemis/ml/tools/running_averages.py +++ b/artemis/ml/tools/running_averages.py @@ -1,7 +1,7 @@ import numpy as np from artemis.general.global_rates import is_elapsed -from artemis.general.global_vars import get_global +from artemis.general.global_vars import get_global, has_global, set_global from artemis.general.mymath import recent_moving_average from artemis.ml.tools.processors import IDifferentiableFunction @@ -147,7 +147,14 @@ def __call__(self, x): _running_averages = {} -def get_global_running_average(value, identifier, ra_type='simple'): +def construct_running_averager(ra_type): + if callable(ra_type): + return ra_type() + else: + return {'simple': RunningAverage, 'recent': RecentRunningAverage, 'osa': OptimalStepSizeAverage}[ra_type]() + + +def get_global_running_average(value, identifier, ra_type='simple', reset=False): """ Get the running average of a variable. :param value: The latest value of the variable @@ -155,11 +162,19 @@ def get_global_running_average(value, identifier, ra_type='simple'): :param ra_type: The type of running averge. Options are 'simple', 'recent', 'osa' :return: The running average """ - running_averager = get_global(identifier=identifier, constructor=lambda: (ra_type() if callable(ra_type) else {'simple': RunningAverage, 'recent': RecentRunningAverage, 'osa': OptimalStepSizeAverage}[ra_type]())) - return running_averager(value) + if not has_global(identifier): + set_global(identifier, construct_running_averager(ra_type)) + running_averager = get_global(identifier=identifier) + avg = running_averager(value) + if reset: + set_global(identifier, construct_running_averager(ra_type)) + return avg + + +def periodically_report_running_average(identifier, time, period, value, ra_type = 'simple', format_str = '{identifier}: Average at t={time:.3g}: {avg:.3g} ', reset_between = False): -def periodically_report_running_average(identifier, time, period, value, ra_type = 'simple', format_str = '{identifier}: Average at t={time:.3g}: {avg:.3g} '): - avg = get_global_running_average(value=value, identifier=identifier, ra_type=ra_type) - if is_elapsed(identifier, period=period, current=time): + report_time = is_elapsed(identifier, period=period, current=time, count_initial=False) + avg = get_global_running_average(value=value, identifier=identifier, ra_type=ra_type, reset=reset_between and report_time) + if report_time: print(format_str.format(identifier=identifier, time=time, avg=avg)) diff --git a/artemis/plotting/db_plotting.py b/artemis/plotting/db_plotting.py index e7ee3a31..06753a90 100644 --- a/artemis/plotting/db_plotting.py +++ b/artemis/plotting/db_plotting.py @@ -84,16 +84,9 @@ def dbplot(data, name = None, plot_type = None, axis=None, plot_mode = 'live', d if data.__class__.__module__ == 'torch' and data.__class__.__name__ == 'Tensor': data = data.detach().cpu().numpy() - if isinstance(fig, plt.Figure): - assert None not in _DBPLOT_FIGURES, "If you pass a figure, you can only do it on the first call to dbplot (for now)" - _DBPLOT_FIGURES[None] = _PlotWindow(figure=fig, subplots=OrderedDict(), axes={}) - fig = None - elif fig not in _DBPLOT_FIGURES or not plt.fignum_exists(_DBPLOT_FIGURES[fig].figure.number): # Second condition handles closed figures. - _DBPLOT_FIGURES[fig] = _PlotWindow(figure = _make_dbplot_figure(), subplots=OrderedDict(), axes = {}) - if fig is not None: - _DBPLOT_FIGURES[fig].figure.canvas.set_window_title(fig) + plot_object = _get_dbplot_plot_object(fig) # type: _PlotWindow - suplot_dict = _DBPLOT_FIGURES[fig].subplots + suplot_dict = plot_object.subplots if axis is None: axis=name @@ -123,51 +116,51 @@ def dbplot(data, name = None, plot_type = None, axis=None, plot_mode = 'live', d ax = axis ax_name = str(axis) elif isinstance(axis, string_types) or axis is None: - ax = select_subplot(axis, fig=_DBPLOT_FIGURES[fig].figure, layout=_default_layout if layout is None else layout) + ax = select_subplot(axis, fig=plot_object.figure, layout=_default_layout if layout is None else layout) ax_name = axis # ax.set_title(axis) else: raise Exception("Axis specifier must be a string, an Axis object, or a SubplotSpec object. Not {}".format(axis)) - if ax_name not in _DBPLOT_FIGURES[fig].axes: + if ax_name not in plot_object.axes: ax.set_title(name) - _DBPLOT_FIGURES[fig].subplots[name] = _Subplot(axis=ax, plot_object=plot) - _DBPLOT_FIGURES[fig].axes[ax_name] = ax + plot_object.subplots[name] = _Subplot(axis=ax, plot_object=plot) + plot_object.axes[ax_name] = ax - _DBPLOT_FIGURES[fig].subplots[name] = _Subplot(axis=_DBPLOT_FIGURES[fig].axes[ax_name], plot_object=plot) - plt.sca(_DBPLOT_FIGURES[fig].axes[ax_name]) + plot_object.subplots[name] = _Subplot(axis=plot_object.axes[ax_name], plot_object=plot) + plt.sca(plot_object.axes[ax_name]) if xlabel is not None: - _DBPLOT_FIGURES[fig].subplots[name].axis.set_xlabel(xlabel) + plot_object.subplots[name].axis.set_xlabel(xlabel) if ylabel is not None: - _DBPLOT_FIGURES[fig].subplots[name].axis.set_ylabel(ylabel) + plot_object.subplots[name].axis.set_ylabel(ylabel) if draw_every is not None: _draw_counters[fig, name] = Checkpoints(draw_every) if grid: plt.grid() - plot = _DBPLOT_FIGURES[fig].subplots[name].plot_object + plot = plot_object.subplots[name].plot_object if reset_color_cycle: - get_dbplot_axis(axis_name=axis, fig=fig).set_color_cycle(None) + use_dbplot_axis(axis, fig=fig, clear=False).set_color_cycle(None) plot.update(data) # Update Labels... if cornertext is not None: - if not hasattr(_DBPLOT_FIGURES[fig].figure, '__cornertext'): - _DBPLOT_FIGURES[fig].figure.__cornertext = next(iter(_DBPLOT_FIGURES[fig].subplots.values())).axis.annotate(cornertext, xy=(0, 0), xytext=(0.01, 0.98), textcoords='figure fraction') + if not hasattr(plot_object.figure, '__cornertext'): + plot_object.figure.__cornertext = next(iter(plot_object.subplots.values())).axis.annotate(cornertext, xy=(0, 0), xytext=(0.01, 0.98), textcoords='figure fraction') else: - _DBPLOT_FIGURES[fig].figure.__cornertext.set_text(cornertext) + plot_object.figure.__cornertext.set_text(cornertext) if title is not None: - _DBPLOT_FIGURES[fig].subplots[name].axis.set_title(title) + plot_object.subplots[name].axis.set_title(title) if legend is not None: - _DBPLOT_FIGURES[fig].subplots[name].axis.legend(legend, loc='best', framealpha=0.5) + plot_object.subplots[name].axis.legend(legend, loc='best', framealpha=0.5) if draw_now and not _hold_plots and (draw_every is None or ((fig, name) not in _draw_counters) or _draw_counters[fig, name]()): plot.plot() - display_figure(_DBPLOT_FIGURES[fig].figure, hang=hang) + display_figure(plot_object.figure, hang=hang) - return _DBPLOT_FIGURES[fig].subplots[name].axis + return plot_object.subplots[name].axis _PlotWindow = namedtuple('PlotWindow', ['figure', 'subplots', 'axes']) @@ -243,6 +236,18 @@ def get_dbplot_figure(name=None): return _DBPLOT_FIGURES[name].figure +def _get_dbplot_plot_object(fig): + if isinstance(fig, plt.Figure): + assert None not in _DBPLOT_FIGURES, "If you pass a figure, you can only do it on the first call to dbplot (for now)" + _DBPLOT_FIGURES[None] = _PlotWindow(figure=fig, subplots=OrderedDict(), axes={}) + fig = None + elif fig not in _DBPLOT_FIGURES or not plt.fignum_exists(_DBPLOT_FIGURES[fig].figure.number): # Second condition handles closed figures. + _DBPLOT_FIGURES[fig] = _PlotWindow(figure = _make_dbplot_figure(), subplots=OrderedDict(), axes = {}) + if fig is not None: + _DBPLOT_FIGURES[fig].figure.canvas.set_window_title(fig) + return _DBPLOT_FIGURES[fig] + + def get_dbplot_subplot(name, fig_name=None): return _DBPLOT_FIGURES[fig_name].subplots[name].axis @@ -328,11 +333,11 @@ def clear_dbplot(fig = None): _DBPLOT_FIGURES[fig].axes.clear() -def get_dbplot_axis(axis_name, fig=None): - """ - Get the named axis of a dbplot. - """ - return _DBPLOT_FIGURES[fig].axes[axis_name] +def use_dbplot_axis(name, fig=None, layout=None, clear = False, ): + ax = select_subplot(name, fig=_get_dbplot_plot_object(fig).figure, layout=_default_layout if layout is None else layout) + if clear: + ax.clear() + return ax def dbplot_hang(timeout=None): diff --git a/artemis/plotting/point_remapping_plots.py b/artemis/plotting/point_remapping_plots.py new file mode 100644 index 00000000..dc054712 --- /dev/null +++ b/artemis/plotting/point_remapping_plots.py @@ -0,0 +1,39 @@ +import numpy as np +from matplotlib import pyplot as plt + + +def get_2d_point_colours(points): + points_norm = (points - points.min(axis=0, keepdims=True)) / (points.max(axis=0, keepdims=True) - points.min(axis=0, keepdims=True)) + return [(y, x, 1-x) for x, y in points_norm] + + +def plot_2D_mapping(old_xy_points, new_xy_points, axes = None, old_title = 'x', new_title = 'f(x)'): + """ + :param old_xy_points: (N,2) array + :param new_xy_points: (Nx2) array + """ + + colours = get_2d_point_colours(old_xy_points) + + ax = plt.subplot(1, 2, 1) if axes is None else axes[0] + ax.scatter(old_xy_points[:, 0], old_xy_points[:, 1], c=colours) + ax.set_title(old_title) + + ax = plt.subplot(1, 2, 2) if axes is None else axes[1] + ax.scatter(new_xy_points[:, 0], new_xy_points[:, 1], c=colours) + ax.set_title(new_title) + + +if __name__ == '__main__': + n_x = 40 + n_y = 30 + + # Generate a grid of points + old_xy_points = np.array([v.flatten() for v in np.meshgrid(np.linspace(0, 1, n_y), np.linspace(0, 1, n_x))]).T + + # Apply some transformation + theta = 5*np.pi/6 + transform_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + new_xy_points = np.tanh(old_xy_points @ transform_matrix) + + plot_2D_mapping(old_xy_points, new_xy_points) diff --git a/artemis/plotting/pyplot_plus.py b/artemis/plotting/pyplot_plus.py index 43a43e73..3020924a 100644 --- a/artemis/plotting/pyplot_plus.py +++ b/artemis/plotting/pyplot_plus.py @@ -260,6 +260,7 @@ def _get_centered_colour_scale(cmin, cmax): })) + def center_colour_scale(h): current_min, current_max = h.get_clim() absmax = np.maximum(np.abs(current_min), np.abs(current_max)) diff --git a/artemis/plotting/test_db_plotting.py b/artemis/plotting/test_db_plotting.py index b6a8178f..41f8f4c1 100644 --- a/artemis/plotting/test_db_plotting.py +++ b/artemis/plotting/test_db_plotting.py @@ -4,7 +4,7 @@ import numpy as np from artemis.plotting.demo_dbplot import demo_dbplot from artemis.plotting.db_plotting import dbplot, clear_dbplot, hold_dbplots, freeze_all_dbplots, reset_dbplot, \ - dbplot_hang, DBPlotTypes + dbplot_hang, DBPlotTypes, use_dbplot_axis from artemis.plotting.matplotlib_backend import LinePlot, HistogramPlot, MovingPointPlot, is_server_plotting_on, \ ResamplingLineHistory import pytest @@ -213,6 +213,17 @@ def test_bbox_display(): dbplot([10, 20, 25, 30], 'bbox', axis='img', plot_type=DBPlotTypes.BBOX) +def test_inline_custom_plots(): + + for t in range(10): + with hold_dbplots(): + x = np.sin(t/10. + np.linspace(0, 10, 200)) + dbplot(x, 'x', plot_type='line') + use_dbplot_axis('custom', clear=True) + plt.plot(x, label='x', linewidth=2) + plt.plot(x**2, label='$x**2$', linewidth=2) + + if __name__ == '__main__': test_cornertext() test_trajectory_plot() @@ -232,3 +243,4 @@ def test_bbox_display(): test_periodic_plotting() test_individual_periodic_plotting() test_bbox_display() + test_inline_custom_plots() \ No newline at end of file From 1339210873156f2e49994ef7514e300ec171a82e Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 21 Dec 2018 18:04:06 +0100 Subject: [PATCH 16/41] puuuush --- artemis/ml/tools/running_averages.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/artemis/ml/tools/running_averages.py b/artemis/ml/tools/running_averages.py index aff7066a..0a09705f 100644 --- a/artemis/ml/tools/running_averages.py +++ b/artemis/ml/tools/running_averages.py @@ -172,9 +172,12 @@ def get_global_running_average(value, identifier, ra_type='simple', reset=False) return avg -def periodically_report_running_average(identifier, time, period, value, ra_type = 'simple', format_str = '{identifier}: Average at t={time:.3g}: {avg:.3g} ', reset_between = False): +def periodically_report_running_average(identifier, time, period, value, ra_type = 'simple', format_str = '{identifier}: Average at t={time:.3g}: {avg} ', reset_between = False): report_time = is_elapsed(identifier, period=period, current=time, count_initial=False) - avg = get_global_running_average(value=value, identifier=identifier, ra_type=ra_type, reset=reset_between and report_time) + if not isinstance(value, dict): + avg = get_global_running_average(value=value, identifier=identifier, ra_type=ra_type, reset=reset_between and report_time) + else: + avg = {k: f'{get_global_running_average(value=v, identifier=(identifier, k), ra_type=ra_type, reset=reset_between and report_time):.3g}' for k, v in value.items()} if report_time: print(format_str.format(identifier=identifier, time=time, avg=avg)) From 0962503879d85cf40f9e71f58de36476af507aaa Mon Sep 17 00:00:00 2001 From: Peter Date: Wed, 2 Jan 2019 20:17:09 +0100 Subject: [PATCH 17/41] added async dataloaders --- artemis/general/async.py | 66 +++++++++++++++++++++++++++++++++++ artemis/general/test_async.py | 56 +++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 artemis/general/async.py create mode 100644 artemis/general/test_async.py diff --git a/artemis/general/async.py b/artemis/general/async.py new file mode 100644 index 00000000..148fd14f --- /dev/null +++ b/artemis/general/async.py @@ -0,0 +1,66 @@ +from multiprocessing import Process, Queue, Manager, Value, Lock +import time + + +class PoisonPill: + pass + + +def _async_queue_manager(gen_func, queue): + for item in gen_func(): + queue.put(item) + queue.put(PoisonPill) + + +def iter_asynchronously(gen_func): + """ Given a generator function, make it asynchonous. """ + q = Queue() + p = Process(target=_async_queue_manager, args=(gen_func, q)) + p.start() + while True: + item = q.get() + if item is PoisonPill: + break + else: + yield item + + +def _async_value_setter(gen_func, namespace, lock): + for item in gen_func(): + with lock: + namespace.time_and_data = (time.time(), item) + with lock: + namespace.time_and_data = (time.time(), PoisonPill) + + +class Uninitialized: + pass + + +def iter_latest_asynchonously(gen_func, timeout = None, empty_value = None): + """ + Given a generator function, make an iterator that pulls the latest value yielded when running it asynchronously. + If a value has never been set, or timeout is exceeded, yield empty_value instead. + + :param gen_func: A generator function (a function returning a generator); + :return: + """ + m = Manager() + namespace = m.Namespace() + + lock = Lock() + + with lock: + namespace.time_and_data = (-float('inf'), Uninitialized) + + p = Process(target=_async_value_setter, args=(gen_func, namespace, lock)) + p.start() + while True: + with lock: + lasttime, item = namespace.time_and_data + if item is PoisonPill: # The generator has terminated + break + elif item is Uninitialized or timeout is not None and (time.time() - lasttime) > timeout: # Nothing written or nothing recent enough + yield empty_value + else: + yield item diff --git a/artemis/general/test_async.py b/artemis/general/test_async.py new file mode 100644 index 00000000..935d5277 --- /dev/null +++ b/artemis/general/test_async.py @@ -0,0 +1,56 @@ +import time +from functools import partial + +from artemis.general.async import iter_asynchronously, iter_latest_asynchonously + +LOAD_INTERVAL = 0.1 + +# SUM_INTERVAL = LOAD_INTERVAL + PROCESS_INTERVAL + + +def dataloader_example(upto): + + for i in range(upto): + time.sleep(LOAD_INTERVAL) + yield i + + +def test_async_dataloader(): + + process_interval = 0.1 + start = time.time() + for data in dataloader_example(upto=4): + time.sleep(process_interval) + elapsed = time.time()-start + print('Sync Processed Data {} at t={:.3g}: '.format(data, elapsed)) + assert (LOAD_INTERVAL + process_interval)*4 < elapsed < (LOAD_INTERVAL + process_interval)*5 + print('Sync: {:.4g}s elapsed'.format(elapsed)) + + start = time.time() + for data in iter_asynchronously(partial(dataloader_example, upto=4)): + time.sleep(process_interval) + elapsed = time.time()-start + print('Sync Processed Data {} at t={:.3g}: '.format(data, elapsed)) + print('Async: {:.4g}s elapsed'.format(elapsed)) + assert LOAD_INTERVAL + max(LOAD_INTERVAL, process_interval)*4 < elapsed < LOAD_INTERVAL + max(LOAD_INTERVAL, process_interval)*5 + + +def test_async_value_setter(): + + process_interval = 0.25 + + start = time.time() + data_points = [] + for data in iter_latest_asynchonously(gen_func = partial(dataloader_example, upto=10)): + time.sleep(process_interval) + data_points.append(data) + elapsed = time.time()-start + + assert data_points[0] is None + assert all(dn-dp > 1 for dn, dp in zip(data_points[2:], data_points[1:-1])) + print(data_points) + + +if __name__ == '__main__': + test_async_dataloader() + test_async_value_setter() From ba84ebd46f6e06ef3ac10339fcbbfe91ef764bc5 Mon Sep 17 00:00:00 2001 From: Peter Date: Wed, 2 Jan 2019 22:29:10 +0100 Subject: [PATCH 18/41] async updates --- artemis/general/async.py | 15 ++++++++++++--- artemis/general/test_async.py | 2 ++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/artemis/general/async.py b/artemis/general/async.py index 148fd14f..36b794e5 100644 --- a/artemis/general/async.py +++ b/artemis/general/async.py @@ -1,4 +1,4 @@ -from multiprocessing import Process, Queue, Manager, Value, Lock +from multiprocessing import Process, Queue, Manager, Lock, set_start_method import time @@ -37,7 +37,7 @@ class Uninitialized: pass -def iter_latest_asynchonously(gen_func, timeout = None, empty_value = None): +def iter_latest_asynchonously(gen_func, timeout = None, empty_value = None, use_forkserver = False, uninitialized_wait = None): """ Given a generator function, make an iterator that pulls the latest value yielded when running it asynchronously. If a value has never been set, or timeout is exceeded, yield empty_value instead. @@ -45,6 +45,9 @@ def iter_latest_asynchonously(gen_func, timeout = None, empty_value = None): :param gen_func: A generator function (a function returning a generator); :return: """ + if use_forkserver: + set_start_method('forkserver') # On macos this is necessary to start camera in separate thread + m = Manager() namespace = m.Namespace() @@ -60,7 +63,13 @@ def iter_latest_asynchonously(gen_func, timeout = None, empty_value = None): lasttime, item = namespace.time_and_data if item is PoisonPill: # The generator has terminated break - elif item is Uninitialized or timeout is not None and (time.time() - lasttime) > timeout: # Nothing written or nothing recent enough + elif item is Uninitialized: + if uninitialized_wait is not None: + time.sleep(uninitialized_wait) + continue + else: + yield empty_value + elif timeout is not None and (time.time() - lasttime) > timeout: # Nothing written or nothing recent enough yield empty_value else: yield item diff --git a/artemis/general/test_async.py b/artemis/general/test_async.py index 935d5277..06656517 100644 --- a/artemis/general/test_async.py +++ b/artemis/general/test_async.py @@ -51,6 +51,8 @@ def test_async_value_setter(): print(data_points) + + if __name__ == '__main__': test_async_dataloader() test_async_value_setter() From 455ebe4286dbf502c7c79a5b63e7c5fc9af0dd45 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Thu, 3 Jan 2019 12:33:32 +0100 Subject: [PATCH 19/41] rate limiter made --- artemis/general/global_rates.py | 51 ++++++++++++++++++++++++++++ artemis/general/test_global_rates.py | 34 +++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 artemis/general/test_global_rates.py diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py index de4a8a50..97432a79 100644 --- a/artemis/general/global_rates.py +++ b/artemis/general/global_rates.py @@ -44,6 +44,25 @@ class _LastTimeMeasureSingleton: pass +def elapsed_time(identifier, current = None): + """ + Return the time that has elapsed since this function was called with the given identifier. + """ + if current is None: + current = time.time() + key = (_LastTimeMeasureSingleton, identifier) + + if not has_global(key): + set_global(key, current) + return float('inf') + else: + last = get_global(key) + assert current>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" + elapsed = current - last + set_global(key, current) + return elapsed + + def is_elapsed(identifier, period, current = None, count_initial = True): """ Return True if the given span has elapsed since this function last returned True @@ -68,3 +87,35 @@ def is_elapsed(identifier, period, current = None, count_initial = True): if has_elapsed: set_global(key, current) return has_elapsed + + +def limit_rate(identifier, period): + """ + :param identifier: Any python object to uniquely identify what you're limiting. + :param period: The minimum period + :param current: The time measure (if None, system time will be used) + :return: Whether the rate was exceeded (True) or not (False) + """ + + enter_time = time.time() + key = (_LastTimeMeasureSingleton, identifier) + if not has_global(key): # First call + set_global(key, enter_time) + return False + else: + last = get_global(key) + assert enter_time>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" + elapsed = enter_time - last + if elapsed < period: # Rate has been exceeded + time.sleep(period - elapsed) + set_global(key, time.time()) + return False + else: + set_global(key, enter_time) + return True + + +def limit_iteration_rate(iterable, period): + for x in iterable: + limit_rate(id(iterable), period=period) + yield x diff --git a/artemis/general/test_global_rates.py b/artemis/general/test_global_rates.py new file mode 100644 index 00000000..f4e3fcfb --- /dev/null +++ b/artemis/general/test_global_rates.py @@ -0,0 +1,34 @@ +import itertools +import time + +from artemis.general.global_rates import limit_rate, limit_iteration_rate +from artemis.general.global_vars import global_context + + +def test_limit_rate(): + + with global_context(): + start = time.time() + for t in itertools.count(0): + limit_rate('this_rate', period=0.1) + current = time.time() + if current - start > 0.5: + break + print((t, current - start)) + assert t<6 + + +def test_limit_rate_iterator(): + with global_context(): + start = time.time() + for t in limit_iteration_rate(itertools.count(0), period=0.1): + current = time.time() + if current - start > 0.5: + break + print((t, current - start)) + assert t<6 + + +if __name__ == '__main__': + # test_limit_rate() + test_limit_rate_iterator() From eb56b70f839b3ef0b59f4ac10aeffaae8097e976 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Thu, 3 Jan 2019 12:40:49 +0100 Subject: [PATCH 20/41] oook --- artemis/general/test_global_rates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/artemis/general/test_global_rates.py b/artemis/general/test_global_rates.py index f4e3fcfb..37cd368f 100644 --- a/artemis/general/test_global_rates.py +++ b/artemis/general/test_global_rates.py @@ -30,5 +30,5 @@ def test_limit_rate_iterator(): if __name__ == '__main__': - # test_limit_rate() + test_limit_rate() test_limit_rate_iterator() From 64e7c269b87f2e5b306a0b83afe968dbf3ca7274 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 4 Jan 2019 17:25:15 +0100 Subject: [PATCH 21/41] profile improvements --- artemis/general/ezprofile.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/artemis/general/ezprofile.py b/artemis/general/ezprofile.py index cdf853be..eaca4b88 100644 --- a/artemis/general/ezprofile.py +++ b/artemis/general/ezprofile.py @@ -83,11 +83,22 @@ def profile_context(name, print_result = False): def get_profile_contexts(names=None, fill_empty_with_zero = False): - + """ + :param names: Names of profiling contexts to get (from previous calls to profile_context). If None, use all. + :param fill_empty_with_zero: If names are not found, just fill with zeros. + :return: An OrderedDict + """ if names is None: return _profile_contexts else: if fill_empty_with_zero: - return OrderedDict((k, _profile_contexts[k] if k in _profile_contexts else 0) for k in names) + return OrderedDict((k, _profile_contexts[k] if k in _profile_contexts else (0, 0.)) for k in names) else: return OrderedDict((k, _profile_contexts[k]) for k in names) + + +def get_profile_contexts_string(names=None, fill_empty_with_zero = False): + + profile = get_profile_contexts(names=names, fill_empty_with_zero=fill_empty_with_zero) + string = ', '.join(f'{name}: {elapsed/n_calls:.3g}s/iter' for name, (n_calls, elapsed) in profile.items()) + return string From 0a70f5bf5770a15c9486ef8e66f0982c699e973f Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Sat, 5 Jan 2019 14:38:54 +0100 Subject: [PATCH 22/41] compatibility --- artemis/general/ezprofile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/artemis/general/ezprofile.py b/artemis/general/ezprofile.py index eaca4b88..26ea9c86 100644 --- a/artemis/general/ezprofile.py +++ b/artemis/general/ezprofile.py @@ -100,5 +100,5 @@ def get_profile_contexts(names=None, fill_empty_with_zero = False): def get_profile_contexts_string(names=None, fill_empty_with_zero = False): profile = get_profile_contexts(names=names, fill_empty_with_zero=fill_empty_with_zero) - string = ', '.join(f'{name}: {elapsed/n_calls:.3g}s/iter' for name, (n_calls, elapsed) in profile.items()) + string = ', '.join('{}: {:.3g}s/iter'.format(name, elapsed/n_calls) for name, (n_calls, elapsed) in profile.items()) return string From 8e9aa2f36be9a9d6a8cabc0638d5e5a4a8fb2906 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Thu, 17 Jan 2019 17:43:45 +0100 Subject: [PATCH 23/41] ook --- artemis/general/should_be_builtins.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/artemis/general/should_be_builtins.py b/artemis/general/should_be_builtins.py index 4244c34f..1a1dac20 100644 --- a/artemis/general/should_be_builtins.py +++ b/artemis/general/should_be_builtins.py @@ -150,6 +150,19 @@ def izip_equal(*iterables): yield combo +def adjacent_pairs(iterable): + """ + Given an iterable like ['a', 'b', 'c', 'd'], yield adjacent pairs like [('a', 'b'), ('b', 'c'), ('c', 'd')] + :param iterable: + :return: + """ + iterator = iter(iterable) + last = next(iterator) + for item in iterator: + yield (last, item) + last = item + + def remove_duplicates(sequence, hashable=True, key=None, keep_last=False): """ Remove duplicates while maintaining order. From c97b756af723acb66cd66b4044a64d57b461e7de Mon Sep 17 00:00:00 2001 From: Peter Date: Tue, 22 Jan 2019 18:54:53 +0100 Subject: [PATCH 24/41] oook --- artemis/general/global_rates.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py index 97432a79..4128f042 100644 --- a/artemis/general/global_rates.py +++ b/artemis/general/global_rates.py @@ -8,12 +8,14 @@ class _RateMeasureSingleton: pass -def measure_global_rate(name): +def measure_global_rate(name, n_steps = None): this_time = time.time() key = (_RateMeasureSingleton, name) n_calls, start_time = get_global(key, constructor=lambda: (0, this_time)) - set_global(key, (n_calls+1, start_time)) - return n_calls / (this_time - start_time) if this_time!=start_time else float('inf') + if n_steps is None: + n_steps = n_calls + set_global(key, (n_steps+1, start_time)) + return n_steps / (this_time - start_time) if this_time!=start_time else float('inf') class _ElapsedMeasureSingleton: From 792bd8df02c645e1f353101fd5a43fbe0447596d Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Thu, 24 Jan 2019 16:49:14 +0100 Subject: [PATCH 25/41] lilthings --- artemis/general/dead_easy_ui.py | 2 +- artemis/general/should_be_builtins.py | 32 ++++++++++++++++++++++ artemis/general/test_dead_easy_ui.py | 10 +++++++ artemis/general/test_should_be_builtins.py | 26 ++++++++++++++++-- artemis/ml/parameter_schedule.py | 2 +- artemis/ml/tools/iteration.py | 19 ++++++++----- artemis/ml/tools/test_iteration.py | 12 +++++++- 7 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 artemis/general/test_dead_easy_ui.py diff --git a/artemis/general/dead_easy_ui.py b/artemis/general/dead_easy_ui.py index 4285dc95..d415602d 100644 --- a/artemis/general/dead_easy_ui.py +++ b/artemis/general/dead_easy_ui.py @@ -173,7 +173,7 @@ def parse_arg(arg_str): arg_name, arg_val = arg.split('=', 1) kwargs[arg_name] = parse_arg(arg_val) - return func_name, args, kwargs + return func_name, tuple(args), kwargs # if forgive_unquoted_strings: # cmd_args = [cmd_args[0]] + [_quote_args_that_you_forgot_to_quote(arg) for arg in cmd_args[1:]] diff --git a/artemis/general/should_be_builtins.py b/artemis/general/should_be_builtins.py index 1a1dac20..38ecaccc 100644 --- a/artemis/general/should_be_builtins.py +++ b/artemis/general/should_be_builtins.py @@ -514,3 +514,35 @@ def natural_keys(text): (See Toothy's implementation in the comments) """ return tuple(atoi(c) for c in re.split('(\d+)', text)) + + +class switch: + """ + A switch statement, made by Ian Bell at https://stackoverflow.com/a/30012053/851699 + + Usage: + + with switch(name) as case: + if case('bob', 'nancy'): + print("Come in, you're on the guest list") + elif case('drew'): + print("Sorry, after last time we can't let you in") + else: + print("Sorry, {}, we can't let you in.".format(case.value)) + """ + + def __init__(self, value): + self._val = value + + @property + def value(self): + return self._val + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + return False # Allows traceback to occur + + def __call__(self, *mconds): + return self._val in mconds diff --git a/artemis/general/test_dead_easy_ui.py b/artemis/general/test_dead_easy_ui.py new file mode 100644 index 00000000..3eff71d2 --- /dev/null +++ b/artemis/general/test_dead_easy_ui.py @@ -0,0 +1,10 @@ +from artemis.general.dead_easy_ui import parse_user_function_call + + +def test_parse_user_function_call(): + + assert parse_user_function_call("myfunc 1 a 'a b' c=3 ddd=[2,3] ee='abc'") == ("myfunc", (1, 'a', 'a b'), dict(c=3, ddd=[2, 3], ee='abc')) + + +if __name__ == '__main__': + test_parse_user_function_call() diff --git a/artemis/general/test_should_be_builtins.py b/artemis/general/test_should_be_builtins.py index d637ecb7..83171082 100644 --- a/artemis/general/test_should_be_builtins.py +++ b/artemis/general/test_should_be_builtins.py @@ -4,7 +4,7 @@ from artemis.general.should_be_builtins import itermap, reducemap, separate_common_items, remove_duplicates, \ detect_duplicates, remove_common_prefix, all_equal, get_absolute_module, insert_at, get_shifted_key_value, \ - divide_into_subsets, entries_to_table, natural_keys + divide_into_subsets, entries_to_table, natural_keys, switch __author__ = 'peter' @@ -127,6 +127,27 @@ def test_natural_keys(): assert sorted(['y8', 'x10', 'x2', 'y12', 'x9'], key=natural_keys) == ['x2', 'x9', 'x10', 'y8', 'y12'] +def test_switch_statement(): + + responses = [] + for name in ['nancy', 'joe', 'bob', 'drew']: + with switch(name) as case: + if case('bob', 'nancy'): + response = "Come in, you're on the guest list" + elif case('drew'): + response = "Sorry, after what happened last time we can't let you in" + else: + response = "Sorry, {}, we can't let you in.".format(case.value) + responses.append(response) + + assert responses == [ + "Come in, you're on the guest list", + "Sorry, joe, we can't let you in.", + "Come in, you're on the guest list", + "Sorry, after what happened last time we can't let you in" + ] + + if __name__ == '__main__': test_separate_common_items() test_reducemap() @@ -140,4 +161,5 @@ def test_natural_keys(): test_get_shifted_key_value() test_divide_into_subsets() test_entries_to_table() - test_natural_keys() \ No newline at end of file + test_natural_keys() + test_switch_statement() diff --git a/artemis/ml/parameter_schedule.py b/artemis/ml/parameter_schedule.py index 0fec2b65..9ee9ef0a 100644 --- a/artemis/ml/parameter_schedule.py +++ b/artemis/ml/parameter_schedule.py @@ -17,7 +17,7 @@ def __init__(self, schedule, print_variable_name = None): - A function which takes the epoch and returns a parameter value. - A number or array, in which case the value remains constant """ - if isinstance(schedule, (int, float, np.ndarray)): + if isinstance(schedule, (int, float, np.ndarray, str)): schedule = {0: schedule} if isinstance(schedule, dict): assert all(isinstance(num, (int, float)) for num in schedule.keys()) diff --git a/artemis/ml/tools/iteration.py b/artemis/ml/tools/iteration.py index 6ba00c2b..72f0ff50 100644 --- a/artemis/ml/tools/iteration.py +++ b/artemis/ml/tools/iteration.py @@ -294,7 +294,7 @@ def batchify_generator(generator_generator, batch_size = None, receive_input=Fal -----------vid-2-----------|--------vid-6---------| -----vid-3-------|----------vid-4------------------ - generator_genererator yields 7 generators, corresponding to each of the movies. + generator_genererator yields 7 generators, corresponding to each of the videos. Each of those generators is a frame-generator, which produces the frames in a given video. Here, we generate frames from each movie, and start a new movies whenever an old one stops, until there are no new movies to start. @@ -307,17 +307,23 @@ def batchify_generator(generator_generator, batch_size = None, receive_input=Fal """ assert receive_input in (False, 'post'), 'pre-receive not yet implemented' - total = batch_size - assert out_format in ('array', 'tuple_of_arrays') + # if isinstance(generator_generator, (list, tuple)): + # generator_generator = iter(generator_generator) + + # generators is a list of currently active generators + # generator_generator is a generator which yields new generators to be swapped into generators when the old ones get used up. if batch_size is not None: - generators = [next(generator_generator) for _ in range(batch_size)] + if isinstance(generator_generator, (list, tuple)): + generator_generator = iter(generator_generator) + generators = [iter(next(generator_generator)) for _ in range(batch_size)] + else: assert isinstance(generator_generator, (list, tuple)), "If you don't specify a batch size your generator-generator must be a finite list." batch_size = len(generator_generator) - generators = generator_generator generator_generator = iter(generator_generator) + generators = [iter(gen) for gen in generator_generator] while True: items = [] @@ -327,8 +333,7 @@ def batchify_generator(generator_generator, batch_size = None, receive_input=Fal items.append(next(generators[i])) break except StopIteration: - total+=1 - generators[i] = next(generator_generator) # This will rais StopIteration when we're out of generators + generators[i] = iter(next(generator_generator)) # This will raise StopIteration when we're out of generators if out_format=='array': output= np.array(items) diff --git a/artemis/ml/tools/test_iteration.py b/artemis/ml/tools/test_iteration.py index 094b2f55..b096a524 100644 --- a/artemis/ml/tools/test_iteration.py +++ b/artemis/ml/tools/test_iteration.py @@ -1,5 +1,5 @@ from artemis.ml.tools.iteration import minibatch_index_generator, checkpoint_minibatch_index_generator, \ - zip_minibatch_iterate_info, minibatch_process + zip_minibatch_iterate_info, minibatch_process, batchify_generator __author__ = 'peter' import numpy as np @@ -113,9 +113,19 @@ def func(x): assert np.allclose(y1, y2) # weird numpy rounding makes it not exactly equal +def test_batchify_generator(): + + a = [x for x in batchify_generator(batch_size=2, generator_generator=[[1, 2, 3], [4, 5], [6, 7, 8], [9, 10, 11, 12]])] + assert np.array_equal(a, [[1, 4], [2, 5], [3, 6], [9, 7], [10, 8]]) + + a = [x for x in batchify_generator(batch_size=None, generator_generator=[[1, 2, 3], [4, 5], [6, 7, 8], [9, 10, 11, 12]])] + assert np.array_equal(a, [[1, 4, 6, 9], [2, 5, 7, 10]]) + + if __name__ == '__main__': test_minibatch_index_even() test_minibatch_process() test_minibatch_iterate_info() test_minibatch_index_generator() test_checkpoint_minibatch_generator() + test_batchify_generator() From 13ba79f94f48a96b9057fba8263977cc0c2dd62a Mon Sep 17 00:00:00 2001 From: O'Connor Date: Mon, 28 Jan 2019 18:38:27 +0100 Subject: [PATCH 26/41] windowsfix --- artemis/fileman/local_dir.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/artemis/fileman/local_dir.py b/artemis/fileman/local_dir.py index 2953048c..a01ee537 100644 --- a/artemis/fileman/local_dir.py +++ b/artemis/fileman/local_dir.py @@ -3,6 +3,7 @@ from artemis.config import get_artemis_config_value import os from six.moves import xrange +from os.path import expanduser __author__ = 'peter' @@ -19,7 +20,7 @@ def get_default_local_path(): - return os.path.join(os.getenv("HOME"), '.artemis') + return os.path.join(expanduser("~"), '.artemis') LOCAL_DIR = get_artemis_config_value(section='fileman', option='data_dir', default_generator = get_default_local_path, write_default = True) From 4db00489bd511810d68afadcfd254ba13484a57b Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Wed, 13 Feb 2019 18:32:04 -0900 Subject: [PATCH 27/41] parameter_search improvements --- artemis/experiments/experiments.py | 23 ++++++++++++++++++----- artemis/general/functional.py | 7 +++++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/artemis/experiments/experiments.py b/artemis/experiments/experiments.py index b6862744..c803f786 100644 --- a/artemis/experiments/experiments.py +++ b/artemis/experiments/experiments.py @@ -381,7 +381,10 @@ def get_latest_record(self, only_completed=False, if_none = 'skip'): Return the ExperimentRecord from the latest run of this Experiment. :param only_completed: Only search among records of that have run to completion. - :param err_if_none: If True, raise an error if no record exists. Otherwise, just return None in this case. + :param if_none: What to do if no record exists. Options are: + 'skip': Return None + 'err': Raise an exception + 'run': Run the experiment to get the record :return ExperimentRecord: An ExperimentRecord object """ assert if_none in ('skip', 'err', 'run') @@ -421,9 +424,10 @@ def get_variant_records(self, only_completed=False, only_last=False, flat=False) else: return exp_record_dict - def add_parameter_search(self, name='parameter_search', space = None, n_calls=100, search_params = None, scalar_func=None): + def add_parameter_search(self, name='parameter_search', fixed_args = {}, space = None, n_calls=100, search_params = None, scalar_func=None): """ :param name: Name of the Experiment to be created + :param dict[str, Any] fixed_args: Any fixed-arguments to provide to all experiments. :param dict[str, skopt.space.Dimension] space: A dict mapping param name to Dimension. e.g. space=dict(a = Real(1, 100, 'log-uniform'), b = Real(1, 100, 'log-uniform')) :param Callable[[Any], float] scalar_func: Takes the return value of the experiment and turns it into a scalar @@ -446,18 +450,27 @@ def objective(**current_params): from artemis.experiments import ExperimentFunction - @ExperimentFunction(name = self.name + '.'+ name, show = show_parameter_search_record, one_liner_function=parameter_search_one_liner) - def search_exp(): + def search_func(fixed): if is_test_mode(): nonlocal n_calls n_calls = 3 # When just verifying that experiment runs, do the minimum - for iter_info in parameter_search(objective, n_calls=n_calls, space=space, **search_params): + this_objective = partial(objective, **fixed) + + for iter_info in parameter_search(this_objective, n_calls=n_calls, space=space, **search_params): info = dict(names=list(space.keys()), x_iters =iter_info.x_iters, func_vals=iter_info.func_vals, score = iter_info.func_vals, x=iter_info.x, fun=iter_info.fun) latest_info = {name: val for name, val in izip_equal(info['names'], iter_info.x_iters[-1])} print(f'Latest: {latest_info}, Score: {iter_info.func_vals[-1]:.3g}') yield info + # The following is a hack to dynamically create a function with the given args + # arg_string = ', '.join('{}={}'.format(k, v) for k, v in fixed_args.items()) + # param_search = None + # exec('global param_search\ndef func({fixed}): search_func(fixed_args=dict({fixed})); param_search=func'.format(fixed=arg_string)) + # param_search = locals()['param_search'] + search_exp_func = partial(search_func, fixed=fixed_args) # We do this so that the fixed parameters will be recorded and we will see if they changed. + + search_exp = ExperimentFunction(name = self.name + '.'+ name, show = show_parameter_search_record, one_liner_function=parameter_search_one_liner)(search_exp_func) self.variants[name] = search_exp search_exp.tag('psearch') # Secret feature that makes it easy to select all parameter experiments in ui with "filter tag:psearch" return search_exp diff --git a/artemis/general/functional.py b/artemis/general/functional.py index a1660c99..3acc6b97 100644 --- a/artemis/general/functional.py +++ b/artemis/general/functional.py @@ -185,9 +185,12 @@ def get_defined_and_undefined_args(func): """ undefined_arg_names, varargs_name, kwargs_name, defined_args = advanced_getargspec(func) assert varargs_name is None - assert kwargs_name is None for k in defined_args.keys(): - undefined_arg_names.remove(k) + if kwargs_name is None: # If the function does not have **kwargs + undefined_arg_names.remove(k) + else: + if k in undefined_arg_names: + undefined_arg_names.remove(k) return defined_args, undefined_arg_names From fd82c625eac5751925c9349af24a91c21e52140f Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 15 Feb 2019 15:31:42 -0900 Subject: [PATCH 28/41] stuuuff --- artemis/experiments/experiment_management.py | 28 ++++++++++ artemis/plotting/range_plots.py | 57 ++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 artemis/plotting/range_plots.py diff --git a/artemis/experiments/experiment_management.py b/artemis/experiments/experiment_management.py index 12ac236e..b6e5ce90 100644 --- a/artemis/experiments/experiment_management.py +++ b/artemis/experiments/experiment_management.py @@ -545,6 +545,34 @@ def run_multiple_experiments(experiments, prefixes = None, parallel = False, dis return [ex.run(raise_exceptions=raise_exceptions, display_results=display_results, notes=notes, **run_args) for ex in experiments] +def get_multiple_records(experiment, n, only_completed=True, if_not_enough='run'): + """ + Get n records from a single experiment. + :param Experiment experiment: The experiment + :param int n: Number of records to get + :param only_completed: True if you only want completed records + :param if_not_enough: What to do if there are not enough records ready. + 'run': Run more + 'cut': Just return the number that are already calculated + 'err': Raise an excepetion + :return: + """ + if isinstance(experiment, str): + experiment = load_experiment(experiment) + assert if_not_enough in ('run', 'cut', 'err') + records = experiment.get_records(only_completed=only_completed) + if if_not_enough == 'err': + assert len(records) >= n, "You asked for {} records, but only {} were available".format(n, len(records)) + return records[-n:] + elif if_not_enough=='run': + for k in range(n-len(records)): + record = experiment.run() + records.append(record) + return records[-n:] + else: + return records + + def remove_common_results_prefix(results_dict): """ Remove the common prefix for the results you are comparing. diff --git a/artemis/plotting/range_plots.py b/artemis/plotting/range_plots.py new file mode 100644 index 00000000..6b370d45 --- /dev/null +++ b/artemis/plotting/range_plots.py @@ -0,0 +1,57 @@ +from matplotlib import pyplot as plt +import numpy as np + + +def plot_sample_mean_and_var(*x_and_ys, var_rep ='std', fill_alpha = 0.25, **plot_kwargs): + """ + Given a collection of signals, plot their mean and fill a range around the mean. Example: + x = np.arange(-5, 5) + ys = np.random.randn(20, len(x_data)) + x ** 2 - 2 + plot_sample_mean_and_var(x, ys, var_rep='std') + :param x_and_ys: You can provide either x and the y-signals or just the y-signals + :param var_rep: How to represent the variance. Options are: + 'std': Standard Deviation + 'sterr': Standard Error of the Mean + 'lim': Min/max + :param fill_alpha: + :param plot_kwargs: + :return: + """ + if len(x_and_ys)==2: + x, ys = x_and_ys + else: + assert len(x_and_ys) == 1, "You must provide unnamed arguments in order (ys) or (x, ys)" + ys, = x_and_ys + x = range(len(ys[0])) + + assert var_rep in ('std', 'sterr', 'lim') + + mean = np.mean(ys, axis=0) + + if var_rep == 'std': + std = np.std(ys, axis=0) + lower, upper = mean-std, mean+std + elif var_rep == 'sterr': + sterr = np.std(ys, axis=0)/np.sqrt(len(ys)) + lower, upper = mean-sterr, mean+sterr + elif var_rep == 'lim': + lower, upper = np.min(ys, axis=0), np.max(ys, axis=0) + else: + raise NotImplementedError(var_rep) + + mean_handel, = plt.plot(x, mean, **plot_kwargs) + fill_handle = plt.fill_between(x, lower, upper, color=mean_handel.get_color(), alpha=fill_alpha) + return mean_handel, fill_handle + + +if __name__ == '__main__': + x_data = np.arange(-5, 5) + ys1 = np.random.randn(20, len(x_data)) + x_data ** 2 - 2 + plot_sample_mean_and_var(x_data, ys1, var_rep='std') + + ys2 = np.random.randn(20, len(x_data)) + .9 * x_data ** 2 - 2 + plot_sample_mean_and_var(x_data, ys2, var_rep='std') + + ys3 = np.random.randn(20, len(x_data)) + .7 * x_data ** 2 - 2 + plot_sample_mean_and_var(x_data, ys3, var_rep='std') + plt.show() From 50c6c44178a3ac03371a437f84a75b4a770f50d1 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Fri, 15 Feb 2019 19:25:28 -0900 Subject: [PATCH 29/41] preclean --- artemis/experiments/experiment_record_view.py | 26 +++ artemis/plotting/parallel_coords_plots.py | 160 ++++++++++++++++++ artemis/plotting/pyplot_plus.py | 8 +- 3 files changed, 190 insertions(+), 4 deletions(-) create mode 100644 artemis/plotting/parallel_coords_plots.py diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 7da1c8bb..2b45153e 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -2,6 +2,7 @@ from collections import OrderedDict from functools import partial import itertools + from six import string_types from tabulate import tabulate import numpy as np @@ -16,6 +17,9 @@ from artemis.general.tables import build_table import os +from artemis.plotting.parallel_coords_plots import plot_hyperparameter_search_parallel_coords +from artemis.plotting.pyplot_plus import get_color_cycle_map + def get_record_result_string(record, func='deep', truncate_to = None, array_print_threshold=8, array_float_format='.3g', oneline=False, default_one_liner_func=str): """ @@ -581,3 +585,25 @@ def queryfig(): print('Use Left/Right arrows to navigate, ') show_figure(nonlocals.figno) + + +def plot_hyperparameter_search(record, relabel = None, show_order_first = True, show_score_last = True, score_name='score'): + """ + Create a parallel coordinates plot representing a hyperparameter search experiment record. + :param record: + :param show_order_first: + :param show_score_last: + :param score_name: + :return: + """ + result = record.get_result() + + assert {'names', 'x_iters', 'func_vals'}.issubset(result.keys()), "Record {} does not appear to be from a Parameter Search experiment!".format(record) + + names, x_iters, func_vals = result['names'], result['x_iters'], result['func_vals'] + + if relabel is not None: + assert set(relabel.keys()).issubset(names), 'Not all relabeling keys {} were found in names {}'.format(list(relabel.keys()), list(names)) + names = [relabel[n] if n in relabel else n for n in names] + + return plot_hyperparameter_search_parallel_coords(field_names=list(names), x_iters=x_iters, func_vals=func_vals, show_iter_first=show_order_first, show_score_last=show_score_last, score_name=score_name) diff --git a/artemis/plotting/parallel_coords_plots.py b/artemis/plotting/parallel_coords_plots.py new file mode 100644 index 00000000..3cfca020 --- /dev/null +++ b/artemis/plotting/parallel_coords_plots.py @@ -0,0 +1,160 @@ + +import matplotlib +from matplotlib import pyplot as plt +import numpy as np +from artemis.general.should_be_builtins import izip_equal, bad_value +# +# +# def parallel_coords_plot(field_names, values, color_field = None): +# """ +# Create a Parallel coordinates plot. +# Code lifted and modified from http://benalexkeen.com/parallel-coordinates-in-matplotlib/ +# +# :param field_names: A list of (n_fields) field names +# :param values: A (n_fields, n_samples) array of values. +# :return: +# """ +# +# n_fields, n_samples = values.shape +# df = {name: row for name, row in izip_equal(field_names, values)} +# +# from matplotlib import ticker +# +# assert len(field_names)==len(values), 'The number of field names must equal the number of rows in values.' +# +# # field_names = ['displacement', 'cylinders', 'horsepower', 'weight', 'acceleration'] +# x = [i for i, _ in enumerate(field_names)] +# # colours = ['#2e8ad8', '#cd3785', '#c64c00', '#889a00'] +# +# # create dict of categories: colours +# # colours = {df['mpg'].cat.categories[i]: colours[i] for i, _ in enumerate(df['mpg'].cat.categories)} +# +# # Create (X-1) sublots along x axis +# fig, axes = plt.subplots(1, len(x)-1, sharey=False, figsize=(15,5)) +# +# # Get min, max and range for each column +# # Normalize the data for each column +# min_max_range = {} +# for col in field_names: +# min_max_range[col] = [np.min(df[col]), np.max(df[col]), np.ptp(df[col])] +# # df[col] = np.true_divide(df[col] - np.min(df[col]), np.ptp(df[col])) +# values = (values-np.min(values, axis=1, keepdims=True)) / (np.max(values, axis=1, keepdims=True)-np.min(values, axis=1, keepdims=True)) +# +# # Plot each row +# for i, ax in enumerate(axes): +# for idx in range(n_samples): +# +# # ax.plot(df[]) +# +# # mpg_category = df.loc[idx, 'mpg'] +# +# # ax.plot(x, df.loc[idx, field_names], colours[mpg_category]) +# ax.plot(x, values[:, idx]) +# ax.set_xlim([x[i], x[i+1]]) +# +# # Set the tick positions and labels on y axis for each plot +# # Tick positions based on normalised data +# # Tick labels are based on original data +# def set_ticks_for_axis(dim, ax, ticks): +# min_val, max_val, val_range = min_max_range[field_names[dim]] +# step = val_range / float(ticks-1) +# tick_labels = [round(min_val + step * i, 2) for i in range(ticks)] +# norm_min = df[field_names[dim]].min() +# norm_range = np.ptp(df[field_names[dim]]) +# norm_step = norm_range / float(ticks-1) +# ticks = [round(norm_min + norm_step * i, 2) for i in range(ticks)] +# ax.yaxis.set_ticks(ticks) +# ax.set_yticklabels(tick_labels) +# +# for dim, ax in enumerate(axes): +# ax.xaxis.set_major_locator(ticker.FixedLocator([dim])) +# set_ticks_for_axis(dim, ax, ticks=6) +# ax.set_xticklabels([field_names[dim]]) +# +# +# # Move the final axis' ticks to the right-hand side +# ax = plt.twinx(axes[-1]) +# dim = len(axes) +# ax.xaxis.set_major_locator(ticker.FixedLocator([x[-2], x[-1]])) +# set_ticks_for_axis(dim, ax, ticks=6) +# ax.set_xticklabels([field_names[-2], field_names[-1]]) +# +# +# # Remove space between subplots +# plt.subplots_adjust(wspace=0) + + # Add legend to plot + # plt.legend( + # [plt.Line2D((0,1),(0,0), color=colours[cat]) for cat in df['mpg'].cat.categories], + # df['mpg'].cat.categories, + # bbox_to_anchor=(1.2, 1), loc=2, borderaxespad=0.) +# +# +# # plt.title("Values of car attributes by MPG category") +# + # plt.show() +from artemis.plotting.pyplot_plus import axhlines + +def draw_norm_y_axis(x_position, lims, scale='lin', axis_thickness=2, n_intermediates=3, tickwidth=0.1, axiscolor='k'): + """ + Draw a y-axis in a Parallel Coordinates plot + """ + assert scale=='lin', 'For now' + lower, upper = lims + line = plt.axvline(x=x_position, ymin=0, ymax=1, linewidth=axis_thickness, color=axiscolor) + y_axisticks = np.linspace(0, 1, n_intermediates+2) + y_labels = ['{:.2g}'.format(y*(upper-lower)+lower) for y in y_axisticks] + h_ticklabels = [plt.text(x=x_position+tickwidth/2., y=y, s=ylab, color='k', bbox=dict(boxstyle="square", fc=(1., 1., 1., 0.5), ec=(0, 0, 0, 0.))) for y, ylab in izip_equal(y_axisticks, y_labels)] + h_ticks = axhlines(ys = y_axisticks, lims=(x_position-tickwidth/2., x_position+tickwidth/2.), linewidth=axis_thickness, color=axiscolor, zorder=4) + return line, h_ticks, h_ticklabels + + +def parallel_coords_plot(field_names, values, scales = {}, ax=None): + """ + Create a Parallel coordinates plot. + + :param field_names: A list of (n_fields) field names + :param values: A (n_fields, n_samples) array of values. + :return: A list of handles to the plot objectss + """ + + assert set(scales.keys()).issubset(field_names), 'All scales must be in field names.' + assert len(field_names) == len(values) + if ax is None: + ax = plt.gca() + v_min, v_max = np.min(values, axis=1, keepdims=True), np.max(values, axis=1, keepdims=True) + + norm_lines = (values-v_min) / (v_max-v_min) + + cmap = matplotlib.cm.get_cmap('Spectral') + hs = [plt.plot(line, color=cmap(line[-1]))[0] for i, line in enumerate(norm_lines.T)] + + for i, f in enumerate(field_names): + draw_norm_y_axis(x_position=i, lims=(v_min[i, 0], v_max[i, 0]), scale = 'lin' if f not in scales else scales[f]) + + ax.set_xticks(range(len(field_names))) + ax.set_xticklabels(field_names) + + ax.tick_params(axis='y', labelleft='off') + ax.set_yticks([]) + # ax.set_yticklabels([]) + ax.set_xlim(0, len(field_names)-1) + + return hs + + +def plot_hyperparameter_search_parallel_coords(field_names, x_iters, func_vals, show_iter_first = True, show_score_last = True, iter_name='iter', score_name='score'): + """ + Visualize the result of a hyperparameter search using a Parallel Coordinates plot + :param field_names: A (n_hyperparameters) list of names of the hyperparameters + :param x_iters: A (n_steps, n_hyperparameters) list of hyperparameter values + :param func_vals: A (n_hyperparameters) list of scores computed for each value + :param show_iter_first: Insert "iter" (the interation index in the search) as a first column to the plot + :param show_score_last: Insert "score" as a last column to the plot + :param iter_name: Name of the "iter" field + :param score_name: Name of the "score" field. + :return: A list of plot handels + """ + field_names = ([iter_name] if show_iter_first else []) + list(field_names) + ([score_name] if show_score_last else []) + lines = [([i] if show_iter_first else []) + list(params) + ([val] if show_score_last else []) for i, (params, val) in enumerate(izip_equal(x_iters, func_vals))] + return parallel_coords_plot(field_names=field_names, values=np.array(lines).T) diff --git a/artemis/plotting/pyplot_plus.py b/artemis/plotting/pyplot_plus.py index 3020924a..8864dce9 100644 --- a/artemis/plotting/pyplot_plus.py +++ b/artemis/plotting/pyplot_plus.py @@ -15,7 +15,7 @@ """ -def axhlines(ys, ax=None, **plot_kwargs): +def axhlines(ys, lims=None, ax=None, **plot_kwargs): """ Draw horizontal lines across plot :param ys: A scalar, list, or 1D array of vertical offsets @@ -26,14 +26,14 @@ def axhlines(ys, ax=None, **plot_kwargs): if ax is None: ax = plt.gca() ys = np.array((ys, ) if np.isscalar(ys) else ys, copy=False) - lims = ax.get_xlim() + lims = ax.get_xlim() if lims is None else lims y_points = np.repeat(ys[:, None], repeats=3, axis=1).flatten() x_points = np.repeat(np.array(lims + (np.nan, ))[None, :], repeats=len(ys), axis=0).flatten() plot = ax.plot(x_points, y_points, scalex = False, **plot_kwargs) return plot -def axvlines(xs, ax=None, **plot_kwargs): +def axvlines(xs, lims=None, ax=None, **plot_kwargs): """ Draw vertical lines on plot :param xs: A scalar, list, or 1D array of horizontal offsets @@ -44,7 +44,7 @@ def axvlines(xs, ax=None, **plot_kwargs): if ax is None: ax = plt.gca() xs = np.array((xs, ) if np.isscalar(xs) else xs, copy=False) - lims = ax.get_ylim() + lims = ax.get_ylim() if lims is None else lims x_points = np.repeat(xs[:, None], repeats=3, axis=1).flatten() y_points = np.repeat(np.array(lims + (np.nan, ))[None, :], repeats=len(xs), axis=0).flatten() plot = ax.plot(x_points, y_points, scaley = False, **plot_kwargs) From 4bf2fe0cd390219c2399e2256c8fbc1aff16fa85 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Sat, 16 Feb 2019 16:12:50 -0900 Subject: [PATCH 30/41] cleaned up parallel coords plot --- .gitignore | 2 +- artemis/experiments/experiment_record_view.py | 23 +-- artemis/plotting/parallel_coords_plots.py | 186 +++++++----------- 3 files changed, 83 insertions(+), 128 deletions(-) diff --git a/.gitignore b/.gitignore index 6e998fcc..e4150e3e 100644 --- a/.gitignore +++ b/.gitignore @@ -71,4 +71,4 @@ venv/ # Files /Data /docs/build - +.pytest_cache diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 2b45153e..322c0882 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -587,23 +587,18 @@ def queryfig(): show_figure(nonlocals.figno) -def plot_hyperparameter_search(record, relabel = None, show_order_first = True, show_score_last = True, score_name='score'): +def plot_hyperparameter_search(record, relabel = None, assert_all_relabels_used = True, **hypersearch_parallel_kwargs): """ Create a parallel coordinates plot representing a hyperparameter search experiment record. - :param record: - :param show_order_first: - :param show_score_last: - :param score_name: - :return: + :param ExperimentRecord record: An experiment record object + :param hypersearch_parallel_kwargs: See plot_hyperparameter_search_parallel_coords + :return: A bunch of plot handels """ result = record.get_result() - - assert {'names', 'x_iters', 'func_vals'}.issubset(result.keys()), "Record {} does not appear to be from a Parameter Search experiment!".format(record) - - names, x_iters, func_vals = result['names'], result['x_iters'], result['func_vals'] - + assert {'names', 'x_iters', 'func_vals', 'x'}.issubset(result.keys()), "Record {} does not appear to be from a Parameter Search experiment!".format(record) + names, x_iters, func_vals, x = result['names'], result['x_iters'], result['func_vals'], result['x'] if relabel is not None: - assert set(relabel.keys()).issubset(names), 'Not all relabeling keys {} were found in names {}'.format(list(relabel.keys()), list(names)) + if assert_all_relabels_used: + assert set(relabel.keys()).issubset(names), 'Not all relabeling keys {} were found in names {}'.format(list(relabel.keys()), list(names)) names = [relabel[n] if n in relabel else n for n in names] - - return plot_hyperparameter_search_parallel_coords(field_names=list(names), x_iters=x_iters, func_vals=func_vals, show_iter_first=show_order_first, show_score_last=show_score_last, score_name=score_name) + return plot_hyperparameter_search_parallel_coords(field_names=list(names), param_sequence=x_iters, func_vals=func_vals, final_params=x, **hypersearch_parallel_kwargs) diff --git a/artemis/plotting/parallel_coords_plots.py b/artemis/plotting/parallel_coords_plots.py index 3cfca020..8e88e302 100644 --- a/artemis/plotting/parallel_coords_plots.py +++ b/artemis/plotting/parallel_coords_plots.py @@ -1,154 +1,107 @@ import matplotlib -from matplotlib import pyplot as plt import numpy as np -from artemis.general.should_be_builtins import izip_equal, bad_value -# -# -# def parallel_coords_plot(field_names, values, color_field = None): -# """ -# Create a Parallel coordinates plot. -# Code lifted and modified from http://benalexkeen.com/parallel-coordinates-in-matplotlib/ -# -# :param field_names: A list of (n_fields) field names -# :param values: A (n_fields, n_samples) array of values. -# :return: -# """ -# -# n_fields, n_samples = values.shape -# df = {name: row for name, row in izip_equal(field_names, values)} -# -# from matplotlib import ticker -# -# assert len(field_names)==len(values), 'The number of field names must equal the number of rows in values.' -# -# # field_names = ['displacement', 'cylinders', 'horsepower', 'weight', 'acceleration'] -# x = [i for i, _ in enumerate(field_names)] -# # colours = ['#2e8ad8', '#cd3785', '#c64c00', '#889a00'] -# -# # create dict of categories: colours -# # colours = {df['mpg'].cat.categories[i]: colours[i] for i, _ in enumerate(df['mpg'].cat.categories)} -# -# # Create (X-1) sublots along x axis -# fig, axes = plt.subplots(1, len(x)-1, sharey=False, figsize=(15,5)) -# -# # Get min, max and range for each column -# # Normalize the data for each column -# min_max_range = {} -# for col in field_names: -# min_max_range[col] = [np.min(df[col]), np.max(df[col]), np.ptp(df[col])] -# # df[col] = np.true_divide(df[col] - np.min(df[col]), np.ptp(df[col])) -# values = (values-np.min(values, axis=1, keepdims=True)) / (np.max(values, axis=1, keepdims=True)-np.min(values, axis=1, keepdims=True)) -# -# # Plot each row -# for i, ax in enumerate(axes): -# for idx in range(n_samples): -# -# # ax.plot(df[]) -# -# # mpg_category = df.loc[idx, 'mpg'] -# -# # ax.plot(x, df.loc[idx, field_names], colours[mpg_category]) -# ax.plot(x, values[:, idx]) -# ax.set_xlim([x[i], x[i+1]]) -# -# # Set the tick positions and labels on y axis for each plot -# # Tick positions based on normalised data -# # Tick labels are based on original data -# def set_ticks_for_axis(dim, ax, ticks): -# min_val, max_val, val_range = min_max_range[field_names[dim]] -# step = val_range / float(ticks-1) -# tick_labels = [round(min_val + step * i, 2) for i in range(ticks)] -# norm_min = df[field_names[dim]].min() -# norm_range = np.ptp(df[field_names[dim]]) -# norm_step = norm_range / float(ticks-1) -# ticks = [round(norm_min + norm_step * i, 2) for i in range(ticks)] -# ax.yaxis.set_ticks(ticks) -# ax.set_yticklabels(tick_labels) -# -# for dim, ax in enumerate(axes): -# ax.xaxis.set_major_locator(ticker.FixedLocator([dim])) -# set_ticks_for_axis(dim, ax, ticks=6) -# ax.set_xticklabels([field_names[dim]]) -# -# -# # Move the final axis' ticks to the right-hand side -# ax = plt.twinx(axes[-1]) -# dim = len(axes) -# ax.xaxis.set_major_locator(ticker.FixedLocator([x[-2], x[-1]])) -# set_ticks_for_axis(dim, ax, ticks=6) -# ax.set_xticklabels([field_names[-2], field_names[-1]]) -# -# -# # Remove space between subplots -# plt.subplots_adjust(wspace=0) - - # Add legend to plot - # plt.legend( - # [plt.Line2D((0,1),(0,0), color=colours[cat]) for cat in df['mpg'].cat.categories], - # df['mpg'].cat.categories, - # bbox_to_anchor=(1.2, 1), loc=2, borderaxespad=0.) -# -# -# # plt.title("Values of car attributes by MPG category") -# - # plt.show() +from matplotlib import pyplot as plt + +from artemis.general.mymath import cosine_distance +from artemis.general.should_be_builtins import izip_equal from artemis.plotting.pyplot_plus import axhlines -def draw_norm_y_axis(x_position, lims, scale='lin', axis_thickness=2, n_intermediates=3, tickwidth=0.1, axiscolor='k'): + +def draw_norm_y_axis(x_position, lims, scale='lin', axis_thickness=2, n_intermediates=3, tickwidth=0.1, axiscolor='k', ticklabel_format='{:.3g}', tick_round_grid=40): """ Draw a y-axis in a Parallel Coordinates plot + + :param x_position: Position in x to draw the axis + :param lims: The (min, max) limit of the y-axis + :param scale: Not implemented for now, just leave at 'lin'. (Todo: implement 'log') + :param axis_thickness: Thickness of the axis line + :param n_intermediates: Number of ticks to put in between ends of axis + :param tickwidth: Width of tick lines + :param axiscolor: Color of axis + :param ticklabel_format: Format for string ticklabel numbers + :param tick_round_grid: Round ticks to a grid with this number of points, or None to not do this. (Causes nicer axis labels) + :return: The handels for the (, , ) """ assert scale=='lin', 'For now' lower, upper = lims - line = plt.axvline(x=x_position, ymin=0, ymax=1, linewidth=axis_thickness, color=axiscolor) + vertical_line_handel = plt.axvline(x=x_position, ymin=0, ymax=1, linewidth=axis_thickness, color=axiscolor) y_axisticks = np.linspace(0, 1, n_intermediates+2) - y_labels = ['{:.2g}'.format(y*(upper-lower)+lower) for y in y_axisticks] - h_ticklabels = [plt.text(x=x_position+tickwidth/2., y=y, s=ylab, color='k', bbox=dict(boxstyle="square", fc=(1., 1., 1., 0.5), ec=(0, 0, 0, 0.))) for y, ylab in izip_equal(y_axisticks, y_labels)] - h_ticks = axhlines(ys = y_axisticks, lims=(x_position-tickwidth/2., x_position+tickwidth/2.), linewidth=axis_thickness, color=axiscolor, zorder=4) - return line, h_ticks, h_ticklabels + y_trueticks = y_axisticks * (upper - lower) + lower + if tick_round_grid is not None: + # spacing = (upper - lower)/tick_round_grid + spacing = 10**np.round(np.log10((upper-lower)/tick_round_grid)) + y_trueticks = np.round(y_trueticks/spacing)*spacing + y_axisticks = (y_trueticks - y_trueticks[0])/(y_trueticks[-1] - y_trueticks[0]) + y_labels = [ticklabel_format.format(y) for y in y_trueticks] + tick_label_handels = [plt.text(x=x_position+tickwidth/2., y=y, s=ylab, color='k', bbox=dict(boxstyle="square", fc=(1., 1., 1., 0.5), ec=(0, 0, 0, 0.))) for y, ylab in izip_equal(y_axisticks, y_labels)] + tick_handels = axhlines(ys = y_axisticks, lims=(x_position-tickwidth/2., x_position+tickwidth/2.), linewidth=axis_thickness, color=axiscolor, zorder=4) + return vertical_line_handel, tick_handels, tick_label_handels -def parallel_coords_plot(field_names, values, scales = {}, ax=None): +def parallel_coords_plot(field_names, values, special_formats = {}, scales = {}, color_index=-1, ax=None, alpha='auto', cmap='Spectral', **plot_kwargs): """ - Create a Parallel coordinates plot. + Create a Parallel coordinates plot. These plots are useful for visualizing high-dimensional data. :param field_names: A list of (n_fields) field names - :param values: A (n_fields, n_samples) array of values. - :return: A list of handles to the plot objectss + :param values: A (n_samples, n_fields) array of values. + :param Dict[int, Dict] special_formats: Optionally a dictionary mapping from sample index to line format. This can be used to highlight certain lines. + :param Dict[str, str] scales: (currently not implemented) Identifies the scale ('lin' or 'log') for each field name + :param color_index: Which column of values to use to colour-code the lines. Defaults to the last column. + :param ax: The plot axis (if None, use current axis (gca)) + :param alpha: The alpha (opaqueness) value to use. If 'auto', this function automatically lowers alpha in regions of dense overlap. + :param plot_kwargs: Other kwargs to pass to line plots (these will be overridden on a per-plot basis by special_formats, alpha) + :return: A list of handles to the plot objects """ - + values = np.array(values, copy=False) assert set(scales.keys()).issubset(field_names), 'All scales must be in field names.' - assert len(field_names) == len(values) + assert len(field_names) == values.shape[1] if ax is None: ax = plt.gca() - v_min, v_max = np.min(values, axis=1, keepdims=True), np.max(values, axis=1, keepdims=True) + v_min, v_max = np.min(values, axis=0), np.max(values, axis=0) norm_lines = (values-v_min) / (v_max-v_min) - cmap = matplotlib.cm.get_cmap('Spectral') - hs = [plt.plot(line, color=cmap(line[-1]))[0] for i, line in enumerate(norm_lines.T)] + cmap = matplotlib.cm.get_cmap(cmap) + formats = {i: plot_kwargs.copy() for i in range(len(norm_lines))} + for i, line in enumerate(norm_lines): # Color lines according to score + formats[i]['color']=cmap(1-line[color_index]) + if alpha=='auto': + mean_param = np.mean(norm_lines, axis=0) + for i, line in enumerate(norm_lines): + sameness = max(0, cosine_distance(mean_param, line)) # (0 to 1 where 1 means same as the mean) + alpha = sameness * (1./np.sqrt(values.shape[0])) + (1-sameness)*1. + formats[i]['alpha'] = alpha + else: + for i in range(len(norm_lines)): + formats[i]['alpha'] = alpha + for i, form in special_formats.items(): # Add special formats + formats[i].update(form) + plot_kwargs.update(dict(alpha=alpha)) + + hs = [plt.plot(line, **formats[i])[0] for i, line in enumerate(norm_lines)] for i, f in enumerate(field_names): - draw_norm_y_axis(x_position=i, lims=(v_min[i, 0], v_max[i, 0]), scale = 'lin' if f not in scales else scales[f]) + draw_norm_y_axis(x_position=i, lims=(v_min[i], v_max[i]), scale = 'lin' if f not in scales else scales[f]) ax.set_xticks(range(len(field_names))) ax.set_xticklabels(field_names) ax.tick_params(axis='y', labelleft='off') ax.set_yticks([]) - # ax.set_yticklabels([]) ax.set_xlim(0, len(field_names)-1) return hs -def plot_hyperparameter_search_parallel_coords(field_names, x_iters, func_vals, show_iter_first = True, show_score_last = True, iter_name='iter', score_name='score'): +def plot_hyperparameter_search_parallel_coords(field_names, param_sequence, func_vals, final_params = None, show_iter_first = True, show_score_last = True, iter_name='iter', score_name='score'): """ Visualize the result of a hyperparameter search using a Parallel Coordinates plot :param field_names: A (n_hyperparameters) list of names of the hyperparameters - :param x_iters: A (n_steps, n_hyperparameters) list of hyperparameter values + :param param_sequence: A (n_steps, n_hyperparameters) list of hyperparameter values :param func_vals: A (n_hyperparameters) list of scores computed for each value + :param final_params: Optionally, provide the final chosen set of hyperparameters. These will be plotted as a thick + black dotted line. :param show_iter_first: Insert "iter" (the interation index in the search) as a first column to the plot :param show_score_last: Insert "score" as a last column to the plot :param iter_name: Name of the "iter" field @@ -156,5 +109,12 @@ def plot_hyperparameter_search_parallel_coords(field_names, x_iters, func_vals, :return: A list of plot handels """ field_names = ([iter_name] if show_iter_first else []) + list(field_names) + ([score_name] if show_score_last else []) - lines = [([i] if show_iter_first else []) + list(params) + ([val] if show_score_last else []) for i, (params, val) in enumerate(izip_equal(x_iters, func_vals))] - return parallel_coords_plot(field_names=field_names, values=np.array(lines).T) + lines = [([i+1] if show_iter_first else []) + list(params) + ([val] if show_score_last else []) for i, (params, val) in enumerate(izip_equal(param_sequence, func_vals))] + + if final_params is not None: # This adds a black dotted line over the final set of hyperparameters + ix = next(i for i, v in enumerate(param_sequence) if np.array_equal(v, final_params)) + lines.append(([ix] if show_iter_first else [])+list(final_params)+([func_vals[ix] if show_score_last else []])) + special_formats = {len(lines)-1: dict(linewidth=2, color='k', linestyle='--', alpha=1)} + else: + special_formats = {} + return parallel_coords_plot(field_names=field_names, values=lines, special_formats=special_formats) From c46b2fd01d13b6bef4063afa2fa22bb2582d2dc1 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Thu, 11 Apr 2019 08:57:01 -0700 Subject: [PATCH 31/41] indice build --- artemis/experiments/experiment_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/artemis/experiments/experiment_management.py b/artemis/experiments/experiment_management.py index b6e5ce90..dea19ce9 100644 --- a/artemis/experiments/experiment_management.py +++ b/artemis/experiments/experiment_management.py @@ -175,7 +175,7 @@ def select_experiment_records(user_range, exp_record_dict, flat=True, load_recor :param user_range: :param exp_record_dict: An OrderedDict> :param flat: Return a list of experiment records, instead of an OrderedDict - :return: if not flat, an An OrderedDict> + :return: if not flat, an OrderedDict> otherwise a list """ filters = _filter_records(user_range, exp_record_dict) From 7003281655932a1abe07c642e21bc7bf0f4ee2a8 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Tue, 16 Apr 2019 11:49:22 +0900 Subject: [PATCH 32/41] addressed test fail --- artemis/experiments/experiment_record_view.py | 4 ++-- artemis/experiments/experiments.py | 11 ++++------- .../experiments/test_experiment_record_view_and_ui.py | 2 +- artemis/experiments/test_experiments.py | 2 +- artemis/general/global_rates.py | 4 ++-- artemis/ml/tools/running_averages.py | 2 +- artemis/plotting/point_remapping_plots.py | 2 +- artemis/plotting/range_plots.py | 11 +++++------ 8 files changed, 17 insertions(+), 21 deletions(-) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 322c0882..36c19f13 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -491,9 +491,9 @@ def compare_timeseries_records(records, yfield, xfield = None, hang=True, ax=Non for result, argvals in izip_equal(results, values): xvals = [r[xfield] for r in result] if xfield is not None else list(range(len(result))) # yvals = [r[yfield[0]] for r in result] - h, = ax.plot(xvals, [r[yfield[0]] for r in result], label=(yfield[0]+': ' if len(yfield)>1 else '')+', '.join(f'{argname}={argval}' for argname, argval in izip_equal(all_different_args, argvals))) + h, = ax.plot(xvals, [r[yfield[0]] for r in result], label=(yfield[0]+': ' if len(yfield)>1 else '')+', '.join('{}={}'.format(argname, argval) for argname, argval in izip_equal(all_different_args, argvals))) for yf, linestyle in zip(yfield[1:], itertools.cycle(['--', ':', '-.'])): - ax.plot(xvals, [r[yf] for r in result], linestyle=linestyle, color=h.get_color(), label=yf+': '+', '.join(f'{argname}={argval}' for argname, argval in izip_equal(all_different_args, argvals))) + ax.plot(xvals, [r[yf] for r in result], linestyle=linestyle, color=h.get_color(), label=yf+': '+', '.join('{}={}'.format(argname, argval) for argname, argval in izip_equal(all_different_args, argvals))) ax.grid(True) if xfield is not None: diff --git a/artemis/experiments/experiments.py b/artemis/experiments/experiments.py index c803f786..f9ca8495 100644 --- a/artemis/experiments/experiments.py +++ b/artemis/experiments/experiments.py @@ -424,7 +424,7 @@ def get_variant_records(self, only_completed=False, only_last=False, flat=False) else: return exp_record_dict - def add_parameter_search(self, name='parameter_search', fixed_args = {}, space = None, n_calls=100, search_params = None, scalar_func=None): + def add_parameter_search(self, name='parameter_search', fixed_args = {}, space = None, n_calls=None, search_params = None, scalar_func=None): """ :param name: Name of the Experiment to be created :param dict[str, Any] fixed_args: Any fixed-arguments to provide to all experiments. @@ -435,6 +435,7 @@ def add_parameter_search(self, name='parameter_search', fixed_args = {}, space = :param dict[str, Any] search_params: Args passed to parameter_search :return Experiment: A new experiment which runs the search and yields current-best parameters with every iteration. """ + assert space is not None, "You must specify a parameter search space. See this method's documentation" if name is None: # TODO: Set name=None in the default after deadline name = 'parameter_search[{}]'.format(','.join(space.keys())) @@ -451,13 +452,9 @@ def objective(**current_params): from artemis.experiments import ExperimentFunction def search_func(fixed): - if is_test_mode(): - nonlocal n_calls - n_calls = 3 # When just verifying that experiment runs, do the minimum - + n_calls_to_make = n_calls if n_calls is not None else 3 if is_test_mode() else 100 this_objective = partial(objective, **fixed) - - for iter_info in parameter_search(this_objective, n_calls=n_calls, space=space, **search_params): + for iter_info in parameter_search(this_objective, n_calls=n_calls_to_make, space=space, **search_params): info = dict(names=list(space.keys()), x_iters =iter_info.x_iters, func_vals=iter_info.func_vals, score = iter_info.func_vals, x=iter_info.x, fun=iter_info.fun) latest_info = {name: val for name, val in izip_equal(info['names'], iter_info.x_iters[-1])} print(f'Latest: {latest_info}, Score: {iter_info.func_vals[-1]:.3g}') diff --git a/artemis/experiments/test_experiment_record_view_and_ui.py b/artemis/experiments/test_experiment_record_view_and_ui.py index 9b2bba66..3813c478 100644 --- a/artemis/experiments/test_experiment_record_view_and_ui.py +++ b/artemis/experiments/test_experiment_record_view_and_ui.py @@ -1,6 +1,7 @@ import pytest from artemis.experiments.decorators import ExperimentFunction, experiment_function +from artemis.experiments.experiment_record import ExperimentRecord from artemis.experiments.experiment_record import save_figure_in_record from artemis.experiments.experiment_record_view import get_oneline_result_string, print_experiment_record_argtable, \ compare_experiment_records, get_record_invalid_arg_string, browse_record_figs @@ -246,4 +247,3 @@ def my_exp(): test_simple_experiment_show() test_view_modes() test_duplicate_headers_when_no_records_bug_is_gone() - # demo_browse_record_figs() \ No newline at end of file diff --git a/artemis/experiments/test_experiments.py b/artemis/experiments/test_experiments.py index 6208e28a..7bf5bfea 100644 --- a/artemis/experiments/test_experiments.py +++ b/artemis/experiments/test_experiments.py @@ -127,7 +127,7 @@ def bowl(x, y): ex_search = bowl.add_parameter_search( space = {'x': Real(-5, 5, 'uniform'), 'y': Real(-5, 5, 'uniform')}, scalar_func=lambda result: result['z'], - search_params=dict(n_calls=5) + search_params=dict(n_calls=5), ) record = ex_search.run() diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py index 4128f042..94d6ed34 100644 --- a/artemis/general/global_rates.py +++ b/artemis/general/global_rates.py @@ -84,7 +84,7 @@ def is_elapsed(identifier, period, current = None, count_initial = True): return count_initial else: last = get_global(key) - assert current>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" + assert current>=last, "Current value ({}) must be greater or equal to the last value ({})".format(current, last) has_elapsed = current - last >= period if has_elapsed: set_global(key, current) @@ -106,7 +106,7 @@ def limit_rate(identifier, period): return False else: last = get_global(key) - assert enter_time>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" + assert enter_time>=last, "Current value ({}) must be greater or equal to the last value ({})".format(enter_time, last) elapsed = enter_time - last if elapsed < period: # Rate has been exceeded time.sleep(period - elapsed) diff --git a/artemis/ml/tools/running_averages.py b/artemis/ml/tools/running_averages.py index 0a09705f..b2730473 100644 --- a/artemis/ml/tools/running_averages.py +++ b/artemis/ml/tools/running_averages.py @@ -178,6 +178,6 @@ def periodically_report_running_average(identifier, time, period, value, ra_type if not isinstance(value, dict): avg = get_global_running_average(value=value, identifier=identifier, ra_type=ra_type, reset=reset_between and report_time) else: - avg = {k: f'{get_global_running_average(value=v, identifier=(identifier, k), ra_type=ra_type, reset=reset_between and report_time):.3g}' for k, v in value.items()} + avg = {k: '{:.3g}'.format(get_global_running_average(value=v, identifier=(identifier, k), ra_type=ra_type, reset=reset_between and report_time)) for k, v in value.items()} if report_time: print(format_str.format(identifier=identifier, time=time, avg=avg)) diff --git a/artemis/plotting/point_remapping_plots.py b/artemis/plotting/point_remapping_plots.py index dc054712..bc609d11 100644 --- a/artemis/plotting/point_remapping_plots.py +++ b/artemis/plotting/point_remapping_plots.py @@ -34,6 +34,6 @@ def plot_2D_mapping(old_xy_points, new_xy_points, axes = None, old_title = 'x', # Apply some transformation theta = 5*np.pi/6 transform_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) - new_xy_points = np.tanh(old_xy_points @ transform_matrix) + new_xy_points = np.tanh(np.dot(old_xy_points, transform_matrix)) plot_2D_mapping(old_xy_points, new_xy_points) diff --git a/artemis/plotting/range_plots.py b/artemis/plotting/range_plots.py index 6b370d45..5e1fc986 100644 --- a/artemis/plotting/range_plots.py +++ b/artemis/plotting/range_plots.py @@ -2,7 +2,7 @@ import numpy as np -def plot_sample_mean_and_var(*x_and_ys, var_rep ='std', fill_alpha = 0.25, **plot_kwargs): +def plot_sample_mean_and_var(x_or_ys, ys=None, var_rep ='std', fill_alpha = 0.25, **plot_kwargs): """ Given a collection of signals, plot their mean and fill a range around the mean. Example: x = np.arange(-5, 5) @@ -17,12 +17,11 @@ def plot_sample_mean_and_var(*x_and_ys, var_rep ='std', fill_alpha = 0.25, **plo :param plot_kwargs: :return: """ - if len(x_and_ys)==2: - x, ys = x_and_ys - else: - assert len(x_and_ys) == 1, "You must provide unnamed arguments in order (ys) or (x, ys)" - ys, = x_and_ys + if ys is None: + ys = x_or_ys x = range(len(ys[0])) + else: + x = x_or_ys assert var_rep in ('std', 'sterr', 'lim') From 6c3e2142b04e92bebb9fbe09a2fd05687c2ad8d0 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Tue, 16 Apr 2019 15:39:17 +0900 Subject: [PATCH 33/41] fixed --- artemis/general/global_rates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/artemis/general/global_rates.py b/artemis/general/global_rates.py index 94d6ed34..20ed7ba6 100644 --- a/artemis/general/global_rates.py +++ b/artemis/general/global_rates.py @@ -59,7 +59,7 @@ def elapsed_time(identifier, current = None): return float('inf') else: last = get_global(key) - assert current>=last, f"Current value ({current}) must be greater or equal to the last value ({last})" + assert current>=last, "Current value ({}) must be greater or equal to the last value ({})".format(current, last) elapsed = current - last set_global(key, current) return elapsed From c8ba0d8ee72fa74112819665a44e38193936b0c5 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Tue, 16 Apr 2019 17:30:39 +0900 Subject: [PATCH 34/41] addressed errors --- artemis/experiments/experiment_record_view.py | 3 +-- artemis/experiments/experiments.py | 4 ++-- artemis/general/global_vars.py | 2 +- artemis/general/mymath.py | 5 ++++- artemis/general/progress_indicator.py | 2 +- artemis/ml/tools/running_averages.py | 2 +- artemis/plotting/test_animation.py | 4 +++- 7 files changed, 13 insertions(+), 9 deletions(-) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 36c19f13..9f6d9229 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -513,10 +513,9 @@ def get_timeseries_record_comparison_function(yfield, xfield = None, hang=True, return lambda records: compare_timeseries_records(records, yfield, xfield = xfield, hang=hang, ax=ax) - def timeseries_oneliner_function(result, fields, show_len, show = 'last'): assert show=='last', 'Only support showing last element now' - return (f'{len(result)} items. ' if show_len else '')+', '.join(f'{k}: {result[-1][k]:.3g}' if isinstance(result[-1][k], float) else f'{k}: {result[-1][k]}' for k in fields) + return ('{} items. '.format(len(result)) if show_len else '')+', '.join('{}: {:.3g}'.format(k, result[-1][k]) if isinstance(result[-1][k], float) else '{}: {}'.format(k, result[-1][k]) for k in fields) def get_timeseries_oneliner_function(fields, show_len=False, show='last'): diff --git a/artemis/experiments/experiments.py b/artemis/experiments/experiments.py index f9ca8495..bf0aa78f 100644 --- a/artemis/experiments/experiments.py +++ b/artemis/experiments/experiments.py @@ -457,7 +457,7 @@ def search_func(fixed): for iter_info in parameter_search(this_objective, n_calls=n_calls_to_make, space=space, **search_params): info = dict(names=list(space.keys()), x_iters =iter_info.x_iters, func_vals=iter_info.func_vals, score = iter_info.func_vals, x=iter_info.x, fun=iter_info.fun) latest_info = {name: val for name, val in izip_equal(info['names'], iter_info.x_iters[-1])} - print(f'Latest: {latest_info}, Score: {iter_info.func_vals[-1]:.3g}') + print('Latest: {}, Score: {:.3g}'.format(latest_info, iter_info.func_vals[-1])) yield info # The following is a hack to dynamically create a function with the given args @@ -494,7 +494,7 @@ def show_parameter_search_record(record): def parameter_search_one_liner(result): - return f'{len(result["x_iters"])} Runs : ' + ', '.join(f'{k}={v:.3g}' for k, v in izip_equal(result['names'], result['x'])) + f' : Score = {result["fun"]:.3g}' + return '{} Runs : '.format(len(result["x_iters"])) + ', '.join('{}={:.3g}'.format(k, v) for k, v in izip_equal(result['names'], result['x'])) + ' : Score = {:.3g}'.format(result["fun"]) _GLOBAL_EXPERIMENT_LIBRARY = OrderedDict() diff --git a/artemis/general/global_vars.py b/artemis/general/global_vars.py index 6cc51940..34c69126 100644 --- a/artemis/general/global_vars.py +++ b/artemis/general/global_vars.py @@ -1,4 +1,4 @@ -from decorator import contextmanager +from contextlib import contextmanager _GLOBALS = {} diff --git a/artemis/general/mymath.py b/artemis/general/mymath.py index b02bfbed..bc9ae100 100644 --- a/artemis/general/mymath.py +++ b/artemis/general/mymath.py @@ -188,7 +188,10 @@ def recent_moving_average(x, axis = 0): a[t] = (1-frac)*a[t-1] + frac*x[t] """ - import weave # ONLY WORKS IN PYTHON 2.X !!! + try: + import weave # ONLY WORKS IN PYTHON 2.X !!! + except: + raise ImportError('Weave module could not be found. Maybe because it only works in Python 2.X') if x.ndim!=2: y = recent_moving_average(x.reshape(x.shape[0], x.size//x.shape[0]), axis=0) return y.reshape(x.shape) diff --git a/artemis/general/progress_indicator.py b/artemis/general/progress_indicator.py index 7af0d3b4..f98c9d81 100644 --- a/artemis/general/progress_indicator.py +++ b/artemis/general/progress_indicator.py @@ -1,6 +1,6 @@ import time -from decorator import contextmanager +from contextlib import contextmanager class ProgressIndicator(object): diff --git a/artemis/ml/tools/running_averages.py b/artemis/ml/tools/running_averages.py index b2730473..d0842636 100644 --- a/artemis/ml/tools/running_averages.py +++ b/artemis/ml/tools/running_averages.py @@ -39,7 +39,7 @@ def __call__(self, data): def batch(cls, x): try: return recent_moving_average(x, axis=0) # Works only for python 2.X, with weave - except ModuleNotFoundError: + except ImportError: rma = RecentRunningAverage() return np.array([rma(xt) for xt in x]) diff --git a/artemis/plotting/test_animation.py b/artemis/plotting/test_animation.py index 055b7ffa..ca525de1 100644 --- a/artemis/plotting/test_animation.py +++ b/artemis/plotting/test_animation.py @@ -1,9 +1,11 @@ import matplotlib import pytest -matplotlib.use('TkAgg') import matplotlib.pyplot as plt import numpy as np +from matplotlib.testing import is_called_from_pytest from six.moves import xrange +if not is_called_from_pytest(): + matplotlib.use('TkAgg') __author__ = 'peter' From 8d0123e336a0d246674565f4b5e22990b104ba59 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Tue, 16 Apr 2019 23:53:34 +0900 Subject: [PATCH 35/41] again --- artemis/general/async.py | 3 ++- artemis/general/dead_easy_ui.py | 11 ++--------- artemis/general/iteratorize.py | 9 +++++---- artemis/ml/tools/test_processors.py | 2 +- artemis/plotting/test_animation.py | 6 +----- 5 files changed, 11 insertions(+), 20 deletions(-) diff --git a/artemis/general/async.py b/artemis/general/async.py index 36b794e5..8d11cdc6 100644 --- a/artemis/general/async.py +++ b/artemis/general/async.py @@ -1,4 +1,4 @@ -from multiprocessing import Process, Queue, Manager, Lock, set_start_method +from multiprocessing import Process, Queue, Manager, Lock import time @@ -46,6 +46,7 @@ def iter_latest_asynchonously(gen_func, timeout = None, empty_value = None, use_ :return: """ if use_forkserver: + from multiprocessing import set_start_method # Only Python 3.X set_start_method('forkserver') # On macos this is necessary to start camera in separate thread m = Manager() diff --git a/artemis/general/dead_easy_ui.py b/artemis/general/dead_easy_ui.py index d415602d..bfe9a482 100644 --- a/artemis/general/dead_easy_ui.py +++ b/artemis/general/dead_easy_ui.py @@ -1,8 +1,6 @@ -from __future__ import print_function +from __future__ import print_function, absolute_import from __future__ import absolute_import -from builtins import range -from builtins import input -from builtins import zip +from six.moves import input import inspect import shlex from collections import OrderedDict @@ -139,11 +137,6 @@ def parse_user_function_call(cmd_str, arg_handling_mode = 'fallback'): """ assert arg_handling_mode in ('str', 'literal', 'fallback') - - # def _fake_func(*args, **kwargs): - # Just exists to help with extracting args, kwargs - # return args, kwargs - cmd_args = shlex.split(cmd_str, posix=False) assert len(cmd_args) == len(shlex.split(cmd_str, posix=True)), "Parse error on string '{}'. You're not allowed having spaces in the values of string keyword args:".format(cmd_str) diff --git a/artemis/general/iteratorize.py b/artemis/general/iteratorize.py index f5581887..bfca0a61 100644 --- a/artemis/general/iteratorize.py +++ b/artemis/general/iteratorize.py @@ -4,10 +4,12 @@ Thanks to Brice for this piece of code. Taken from https://stackoverflow.com/a/9969000/851699 """ - -# from thread import start_new_thread from collections import Iterable -from queue import Queue +import sys +if sys.version_info < (3, 0): + from Queue import Queue +else: + from queue import Queue from threading import Thread @@ -22,7 +24,6 @@ def __init__(self, func): :param Callable[Callable, Any] func: A function that takes a callback as an argument then runs. """ self.mfunc = func - # self.ifunc = ifunc self.q = Queue(maxsize=1) self.sentinel = object() diff --git a/artemis/ml/tools/test_processors.py b/artemis/ml/tools/test_processors.py index 2c94a7be..ffa44350 100644 --- a/artemis/ml/tools/test_processors.py +++ b/artemis/ml/tools/test_processors.py @@ -2,7 +2,7 @@ import pytest from six.moves import xrange -from artemis.ml.tools.processors import RunningAverage, RecentRunningAverage +from artemis.ml.tools.running_averages import RunningAverage, RecentRunningAverage __author__ = 'peter' diff --git a/artemis/plotting/test_animation.py b/artemis/plotting/test_animation.py index ca525de1..fd3d1d6f 100644 --- a/artemis/plotting/test_animation.py +++ b/artemis/plotting/test_animation.py @@ -1,11 +1,7 @@ -import matplotlib -import pytest import matplotlib.pyplot as plt import numpy as np -from matplotlib.testing import is_called_from_pytest +import pytest from six.moves import xrange -if not is_called_from_pytest(): - matplotlib.use('TkAgg') __author__ = 'peter' From 114c5464109f85fc160e47c9506ecb1010ea5d30 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Wed, 17 Apr 2019 06:25:23 +0900 Subject: [PATCH 36/41] should pass? --- artemis/experiments/test_experiments.py | 1 + artemis/general/test_scannable_functions.py | 2 +- requirements.txt | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/artemis/experiments/test_experiments.py b/artemis/experiments/test_experiments.py index 7bf5bfea..db0b57ad 100644 --- a/artemis/experiments/test_experiments.py +++ b/artemis/experiments/test_experiments.py @@ -114,6 +114,7 @@ def my_exp(a, b, c): assert XXXX() == 1+(5*5)*5 +@pytest.mark.skipif(True, reason='We dont want to make scikit-optimize a hard requirement just for this so we skip the test.') def test_parameter_search(): from skopt.space import Real diff --git a/artemis/general/test_scannable_functions.py b/artemis/general/test_scannable_functions.py index 1e30187a..29174ef7 100644 --- a/artemis/general/test_scannable_functions.py +++ b/artemis/general/test_scannable_functions.py @@ -177,4 +177,4 @@ def moving_average(x, avg=0, t=0): test_stateless_updater() test_stateless_updater_with_decorator() test_stateful_updater() - test_stateful_updater_with_decorator() \ No newline at end of file + test_stateful_updater_with_decorator() diff --git a/requirements.txt b/requirements.txt index ce1a70c1..a13d78f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ tabulate netifaces paramiko si-prefix +recordclass # weave # Only works in python 2.X # Other things we may want (uncomment to add these to requirements) # scikit-learn From 11c0a1b14e4228135ad334431ef67debd2d6f99f Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Wed, 17 Apr 2019 11:28:12 +0900 Subject: [PATCH 37/41] address tests --- artemis/general/test_scannable_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/artemis/general/test_scannable_functions.py b/artemis/general/test_scannable_functions.py index 29174ef7..a92004aa 100644 --- a/artemis/general/test_scannable_functions.py +++ b/artemis/general/test_scannable_functions.py @@ -31,7 +31,7 @@ def moving_average(x, decay, avg=0): simply_smoothed_signal = [f(x=x, decay=1./(t+1)) for t, x in enumerate(seq)] truth = np.cumsum(seq)/np.arange(1, len(seq)+1) assert np.allclose(simply_smoothed_signal, truth) - assert list(f._fields)==['avg'] + assert list(f._asdict().keys())==['avg'] assert np.allclose(f.avg, np.mean(seq)) f = moving_average.mutable_scan() From f1635f3757c556f875f9159fc72b99db0e2d5c12 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Wed, 17 Apr 2019 11:51:33 +0900 Subject: [PATCH 38/41] oops forgot to push cleanup --- artemis/experiments/experiment_record.py | 15 --------------- artemis/experiments/experiment_record_view.py | 2 -- 2 files changed, 17 deletions(-) diff --git a/artemis/experiments/experiment_record.py b/artemis/experiments/experiment_record.py index 1d26f980..7502458e 100644 --- a/artemis/experiments/experiment_record.py +++ b/artemis/experiments/experiment_record.py @@ -663,21 +663,6 @@ def clear_experiment_records(ids): ExperimentRecord(exp_path).delete() -# -# -# def save_figure_in_current_experiment_directory(name='fig-{}.pkl', figure = None): -# -# if figure is None: -# figure = plt.gcf() -# -# current_dir = get_current_record_dir() -# start_ix = _figure_ixs[current_dir] if current_dir in _figure_ixs else 0 -# for ix in count(start_ix): -# full_path = os.path.join(current_dir, name).format(ix) -# if not os.path.exists(_figure_ixs[current_dir]): -# save_figure(figure, path = full_path) -# _figure_ixs[current_dir] = ix+1 -# return full_path _figure_ixs = {} diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 9f6d9229..185c5724 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -285,8 +285,6 @@ def lookup_fcn(record_id, arg_or_result_name): return rows[0], rows[1:] - # return tabulate(rows[1:], headers=rows[0]) - def show_record(record, show_logs=True, truncate_logs=None, truncate_result=10000, header_width=100, show_result ='deep', hang=True): """ From b98f38b3b90d39f00f2df4e3c086a1a63e53bac7 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Wed, 17 Apr 2019 11:52:57 +0900 Subject: [PATCH 39/41] more clean --- artemis/experiments/experiment_record_view.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/artemis/experiments/experiment_record_view.py b/artemis/experiments/experiment_record_view.py index 185c5724..9bad1052 100644 --- a/artemis/experiments/experiment_record_view.py +++ b/artemis/experiments/experiment_record_view.py @@ -542,14 +542,12 @@ def show_figure(ix): dir, name = os.path.split(path) if nonlocals.this_fig is not None: plt.close(nonlocals.this_fig) - # with interactive_matplotlib_context(): plt.close(plt.gcf()) with open(path, "rb") as f: fig = pickle.load(f) fig.canvas.set_window_title(record.get_id()+': ' +name+': (Figure {}/{})'.format(ix+1, len(fig_locs))) fig.canvas.mpl_connect('key_press_event', changefig) print('Showing {}: Figure {}/{}. Full path: {}'.format(name, ix+1, len(fig_locs), path)) - # redraw_figure() plt.show() nonlocals.this_fig = plt.gcf() From ba685ebf72742f702fd34442b277f1ee094039d8 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 15 Sep 2022 08:36:38 -0700 Subject: [PATCH 40/41] added a bunch of typing and allowed experiment selection by user range --- artemis/experiments/experiment_management.py | 155 ++++++++++--------- artemis/experiments/experiments.py | 142 ++++++++++------- artemis/fileman/disk_memoize.py | 28 +++- artemis/fileman/test_disk_memoize.py | 59 +++++-- artemis/general/display.py | 2 +- artemis/general/functional.py | 52 +++++-- artemis/general/hashing.py | 1 + 7 files changed, 278 insertions(+), 161 deletions(-) diff --git a/artemis/experiments/experiment_management.py b/artemis/experiments/experiment_management.py index dea19ce9..0cba0b00 100644 --- a/artemis/experiments/experiment_management.py +++ b/artemis/experiments/experiment_management.py @@ -13,6 +13,7 @@ from time import time import math +from typing import Union, Sequence, Mapping from artemis.fileman.local_dir import make_dir from artemis.general.display import equalize_string_lengths @@ -20,9 +21,9 @@ from six.moves import reduce, xrange from artemis.experiments.experiment_record import (load_experiment_record, ExpInfoFields, ExpStatusOptions, ARTEMIS_LOGGER, record_id_to_experiment_id, - get_all_record_ids, get_experiment_dir, has_experiment_record) + get_all_record_ids, get_experiment_dir, has_experiment_record, ExperimentRecord) from artemis.experiments.experiments import load_experiment, get_global_experiment_library -from artemis.fileman.config_files import get_home_dir,set_non_persistent_config_value +from artemis.fileman.config_files import get_home_dir, set_non_persistent_config_value from artemis.general.hashing import compute_fixed_hash from artemis.general.time_parser import parse_time from artemis.remote.child_processes import SlurmPythonProcess @@ -31,7 +32,7 @@ divide_into_subsets -def pull_experiment_records(user, ip, experiment_names, include_variants=True, need_pass = False): +def pull_experiment_records(user, ip, experiment_names, include_variants=True, need_pass=False): """ Pull experiments from another computer matching the given experiment name. @@ -49,19 +50,19 @@ def pull_experiment_records(user, ip, experiment_names, include_variants=True, n home = get_home_dir() - file_list = ["**/*-{exp_name}{variants}/*".format(exp_name=exp_name, variants = '*' if include_variants else '') for exp_name in experiment_names] + file_list = ["**/*-{exp_name}{variants}/*".format(exp_name=exp_name, variants='*' if include_variants else '') for exp_name in experiment_names] _, experiment_directory_file = tempfile.mkstemp() with open(experiment_directory_file, 'w') as f: f.write('\n'.join(file_list)) # This one works if you have keys set up - command = ['rsync', '-a', '-m', '-i']\ - +['{user}@{ip}:~/.artemis/experiments/'.format(user=user, ip=ip)] \ - +['{home}/.artemis/experiments/'.format(home=home)]\ - +["--include-from={}".format(experiment_directory_file)]\ - +["--include='*/'", "--exclude='*'"] - # +["--include='**/*-{exp_name}{variants}/*'".format(exp_name=exp_name, variants = '*' if include_variants else '') for exp_name in experiment_names] # This was the old line, but it could be too long for many experiments. + command = ['rsync', '-a', '-m', '-i'] \ + + ['{user}@{ip}:~/.artemis/experiments/'.format(user=user, ip=ip)] \ + + ['{home}/.artemis/experiments/'.format(home=home)] \ + + ["--include-from={}".format(experiment_directory_file)] \ + + ["--include='*/'", "--exclude='*'"] + # +["--include='**/*-{exp_name}{variants}/*'".format(exp_name=exp_name, variants = '*' if include_variants else '') for exp_name in experiment_names] # This was the old line, but it could be too long for many experiments. if not need_pass: output = subprocess.check_output(' '.join(command), shell=True) @@ -81,7 +82,7 @@ def pull_experiment_records(user, ip, experiment_names, include_variants=True, n return output -def load_lastest_experiment_results(experiments, error_if_no_result = True): +def load_lastest_experiment_results(experiments, error_if_no_result=True): """ Given a list of experiments (or experiment ids), return an OrderedDict :param experiments: A list of Experiment objects (or strings identifying experiment ID is ok too) @@ -95,7 +96,7 @@ def load_lastest_experiment_results(experiments, error_if_no_result = True): return experiment_latest_results -def load_record_results(records, err_if_no_result =True, index_by_id = False): +def load_record_results(records, err_if_no_result=True, index_by_id=False): """ Given a list of experiment records, return an OrderedDict :param records: A list of ExperimentRecord objects @@ -116,7 +117,7 @@ def load_record_results(records, err_if_no_result =True, index_by_id = False): return results -def select_experiments(user_range, exp_record_dict, return_dict=False): +def select_experiments(user_range, exp_record_dict, return_dict=False) -> Union[Sequence[ExperimentRecord], Mapping[str, ExperimentRecord]]: exp_filter = _filter_experiments(user_range, exp_record_dict) if return_dict: return OrderedDict((name, exp_record_dict[name]) for name in exp_record_dict if exp_filter[name]) @@ -124,8 +125,7 @@ def select_experiments(user_range, exp_record_dict, return_dict=False): return [name for name in exp_record_dict if exp_filter[name]] -def _filter_experiments(user_range, exp_record_dict, return_is_in = False): - +def _filter_experiments(user_range, exp_record_dict, return_is_in=False): if '|' in user_range: is_in = [any(xs) for xs in zip(*(_filter_experiments(subrange, exp_record_dict, return_is_in=True) for subrange in user_range.split('|')))] elif '&' in user_range: @@ -135,13 +135,15 @@ def _filter_experiments(user_range, exp_record_dict, return_is_in = False): is_in = [not r for r in is_in] else: if user_range in exp_record_dict: - is_in = [k==user_range for k in exp_record_dict] + is_in = [k == user_range for k in exp_record_dict] + elif user_range in (k.split('.')[-1] for k in exp_record_dict): + is_in = [user_range == k.split('.')[-1] for k in exp_record_dict] else: number_range = interpret_numbers(user_range) if number_range is not None: is_in = [i in number_range for i in xrange(len(exp_record_dict))] elif user_range == 'all': - is_in = [True]*len(exp_record_dict) + is_in = [True] * len(exp_record_dict) elif user_range.startswith('has:'): phrase = user_range[len('has:'):] is_in = [phrase in exp_id for exp_id in exp_record_dict] @@ -150,9 +152,10 @@ def _filter_experiments(user_range, exp_record_dict, return_is_in = False): is_in = [tag in load_experiment(exp_id).get_tags() for exp_id in exp_record_dict] elif user_range.startswith('1diff:'): base_range = user_range[len('1diff:'):] - base_range_exps = select_experiments(base_range, exp_record_dict) # list - all_exp_args_hashes = {eid: set(compute_fixed_hash(a) for a in load_experiment(eid).get_args().items()) for eid in exp_record_dict} # dict> - is_in = [any(len(all_exp_args_hashes[eid].difference(all_exp_args_hashes[other_eid]))<=1 for other_eid in base_range_exps) for eid in exp_record_dict] + base_range_exps = select_experiments(base_range, exp_record_dict) # list + all_exp_args_hashes = {eid: set(compute_fixed_hash(a) for a in load_experiment(eid).get_args().items()) for eid in + exp_record_dict} # dict> + is_in = [any(len(all_exp_args_hashes[eid].difference(all_exp_args_hashes[other_eid])) <= 1 for other_eid in base_range_exps) for eid in exp_record_dict] elif user_range.startswith('hasnot:'): phrase = user_range[len('hasnot:'):] is_in = [phrase not in exp_id for exp_id in exp_record_dict] @@ -170,7 +173,7 @@ def _filter_experiments(user_range, exp_record_dict, return_is_in = False): return OrderedDict((exp_id, exp_is_in) for exp_id, exp_is_in in izip_equal(exp_record_dict, is_in)) -def select_experiment_records(user_range, exp_record_dict, flat=True, load_records = True): +def select_experiment_records(user_range, exp_record_dict, flat=True, load_records=True): """ :param user_range: :param exp_record_dict: An OrderedDict> @@ -218,16 +221,15 @@ def _bitwise_not(a): def _bitwise_filter_op(op, *filter_sets): - output_set = filter_sets[0].copy() - if op=='not': - assert len(filter_sets)==1 + if op == 'not': + assert len(filter_sets) == 1 for k in output_set.keys(): output_set[k] = _bitwise_not(filter_sets[0][k]) elif op in ('and', 'or'): for k in output_set.keys(): - output_set[k] = reduce(_bitwise_and if op=='and' else _bitwise_or, [fs[k] for fs in filter_sets]) - elif op=='andcascade': + output_set[k] = reduce(_bitwise_and if op == 'and' else _bitwise_or, [fs[k] for fs in filter_sets]) + elif op == 'andcascade': for k in output_set.keys(): output_set[k] = reduce(_bitwise_andcascade, [fs[k] for fs in filter_sets[::-1]]) else: @@ -236,14 +238,14 @@ def _bitwise_filter_op(op, *filter_sets): _named_record_filters = {} -_named_record_filters['old'] = lambda rec_ids: ([True]*(len(rec_ids)-1)+[False]) if len(rec_ids)>0 else [] -_named_record_filters['corrupt'] = lambda rec_ids: [load_experiment_record(rec_id).info.get_status_field()==ExpStatusOptions.CORRUPT for rec_id in rec_ids] +_named_record_filters['old'] = lambda rec_ids: ([True] * (len(rec_ids) - 1) + [False]) if len(rec_ids) > 0 else [] +_named_record_filters['corrupt'] = lambda rec_ids: [load_experiment_record(rec_id).info.get_status_field() == ExpStatusOptions.CORRUPT for rec_id in rec_ids] _named_record_filters['finished'] = lambda rec_ids: [load_experiment_record(rec_id).info.get_field(ExpInfoFields.STATUS) == ExpStatusOptions.FINISHED for rec_id in rec_ids] _named_record_filters['invalid'] = lambda rec_ids: [load_experiment_record(rec_id).args_valid() is False for rec_id in rec_ids] -_named_record_filters['all'] = lambda rec_ids: [True]*len(rec_ids) -_named_record_filters['errors'] = lambda rec_ids: [load_experiment_record(rec_id).info.get_field(ExpInfoFields.STATUS)==ExpStatusOptions.ERROR for rec_id in rec_ids] +_named_record_filters['all'] = lambda rec_ids: [True] * len(rec_ids) +_named_record_filters['errors'] = lambda rec_ids: [load_experiment_record(rec_id).info.get_field(ExpInfoFields.STATUS) == ExpStatusOptions.ERROR for rec_id in rec_ids] _named_record_filters['result'] = lambda rec_ids: [load_experiment_record(rec_id).has_result() for rec_id in rec_ids] -_named_record_filters['running'] = lambda rec_ids: [load_experiment_record(rec_id).info.get_field(ExpInfoFields.STATUS)==ExpStatusOptions.STARTED for rec_id in rec_ids] +_named_record_filters['running'] = lambda rec_ids: [load_experiment_record(rec_id).info.get_field(ExpInfoFields.STATUS) == ExpStatusOptions.STARTED for rec_id in rec_ids] def _filter_records(user_range, exp_record_dict): @@ -270,9 +272,9 @@ def _filter_records(user_range, exp_record_dict): :return: An OrderedDict list> indicating whether each record from the given experiment passed the filter """ - if user_range=='unfinished': + if user_range == 'unfinished': return _filter_records('~finished', exp_record_dict) - elif user_range=='last': + elif user_range == 'last': return _filter_records('~old', exp_record_dict) elif '|' in user_range: return _bitwise_filter_op('or', *[_filter_records(subrange, exp_record_dict) for subrange in user_range.split('|')]) @@ -280,7 +282,7 @@ def _filter_records(user_range, exp_record_dict): return _bitwise_filter_op('and', *[_filter_records(subrange, exp_record_dict) for subrange in user_range.split('&')]) elif '@' in user_range: ix = user_range.index('@') - first_part, second_part = user_range[:ix], user_range[ix+1:] + first_part, second_part = user_range[:ix], user_range[ix + 1:] _first_stage_filters = _filter_records(first_part, exp_record_dict) _new_dict = _select_record_ids_from_filters(_first_stage_filters, exp_record_dict) _second_stage_filters = _filter_records(second_part, _new_dict) @@ -289,9 +291,9 @@ def _filter_records(user_range, exp_record_dict): elif user_range.startswith('~'): return _bitwise_filter_op('not', _filter_records(user_range[1:], exp_record_dict)) - base = OrderedDict((k, [False]*len(v)) for k, v in exp_record_dict.items()) + base = OrderedDict((k, [False] * len(v)) for k, v in exp_record_dict.items()) if user_range in exp_record_dict: # User just lists an experiment - base[user_range] = [True]*len(base[user_range]) + base[user_range] = [True] * len(base[user_range]) return base number_range = interpret_numbers(user_range) @@ -302,20 +304,20 @@ def _filter_records(user_range, exp_record_dict): base[exp_id] = _named_record_filters[user_range](exp_record_dict[exp_id]) elif number_range is not None: # e.g. '6-12' for i in number_range: - if i>len(keys): - raise RecordSelectionError('Experiment {} does not exist (they go from 0 to {})'.format(i, len(keys)-1)) - base[keys[i]] = [True]*len(base[keys[i]]) + if i > len(keys): + raise RecordSelectionError('Experiment {} does not exist (they go from 0 to {})'.format(i, len(keys) - 1)) + base[keys[i]] = [True] * len(base[keys[i]]) elif '.' in user_range: # e.b. 6.3-4 exp_rec_pairs = interpret_record_identifier(user_range) for exp_number, rec_number in exp_rec_pairs: - if rec_number>=len(base[keys[exp_number]]): + if rec_number >= len(base[keys[exp_number]]): raise RecordSelectionError('Selection {}.{} does not exist.'.format(exp_number, rec_number)) base[keys[exp_number]][rec_number] = True elif user_range.startswith('dur') or user_range.startswith('age'): # Eg dur<25 Means "All records that ran less than 25s" try: sign = user_range[3] assert sign in ('<', '>') - filter_func = (lambda a, b: (a is not None and b is not None) and ab) + filter_func = (lambda a, b: (a is not None and b is not None) and a < b) if sign == '<' else (lambda a, b: (a is not None and b is not None) and a > b) time_delta = parse_time(user_range[4:]) except: if user_range.startswith('dur'): @@ -333,26 +335,25 @@ def _filter_records(user_range, exp_record_dict): elif user_range.startswith('has:'): phrase = user_range[len('has:'):] for exp_id, records in base.items(): - base[exp_id] = [True]*len(records) if phrase in exp_id else [False]*len(records) + base[exp_id] = [True] * len(records) if phrase in exp_id else [False] * len(records) else: raise RecordSelectionError("Don't know how to interpret subset '{}'. Possible subsets: {}".format(user_range, list(_named_record_filters.keys()))) return base class RecordSelectionError(Exception): - pass def _filter_experiment_record_list(user_range, experiment_record_ids): - if user_range=='all': - return [True]*len(experiment_record_ids) - elif user_range=='new': + if user_range == 'all': + return [True] * len(experiment_record_ids) + elif user_range == 'new': return detect_duplicates(experiment_record_ids, key=record_id_to_experiment_id, keep_last=True) # return [n for n, is_old in izip_equal(get_record_ids(), old) if not old] - elif user_range=='old': + elif user_range == 'old': return [not x for x in _filter_records(user_range, 'new')] - elif user_range=='orphans': + elif user_range == 'orphans': orphans = [] global_lib = get_global_experiment_library() for i, record_id in enumerate(experiment_record_ids): @@ -373,7 +374,7 @@ def _filter_experiment_record_list(user_range, experiment_record_ids): which_ones = interpret_numbers(user_range) if which_ones is None: raise Exception('Could not interpret user range: "{}"'.format(user_range)) - filters = [False]*len(experiment_record_ids) + filters = [False] * len(experiment_record_ids) for i in which_ones: filters[i] = True return filters @@ -410,12 +411,13 @@ def interpret_numbers(user_range): """ if all(d in '0123456789-,' for d in user_range): numbers_and_ranges = user_range.split(',') - numbers = [n for lst in [[int(s)] if '-' not in s else range(int(s[:s.index('-')]), int(s[s.index('-')+1:])+1) for s in numbers_and_ranges] for n in lst] + numbers = [n for lst in [[int(s)] if '-' not in s else range(int(s[:s.index('-')]), int(s[s.index('-') + 1:]) + 1) for s in numbers_and_ranges] for n in lst] return numbers else: return None -def run_experiment(experiment, slurm_job = False, experiment_path=None, **experiment_record_kwargs): + +def run_experiment(experiment, slurm_job=False, experiment_path=None, **experiment_record_kwargs): """ Run an experiment and save the results. Return a string which uniquely identifies the experiment. You can run the experiment again later by calling show_experiment(location_string): @@ -438,11 +440,12 @@ def run_experiment(experiment, slurm_job = False, experiment_path=None, **experi I am aware that we could potentially save code and make this super slick by designing a subclass of Experiment which would be a 'DistributedSlurmExperiment', but this is future work. For now, this works. """ - assert "SLURM_NODEID" in os.environ.keys(), "You indicated that the experiment '{}' is run within a SLURM call, however the environment variable 'SLURM_NODEID' could not be found".format(experiment.get_id()) + assert "SLURM_NODEID" in os.environ.keys(), "You indicated that the experiment '{}' is run within a SLURM call, however the environment variable 'SLURM_NODEID' could not be found".format( + experiment.get_id()) if int(os.environ["SLURM_NODEID"]) > 0: return if experiment_path: - 'As mentioned above, global variables are reset, so I reset the one element I actually use' #TODO: Make this more elegant + 'As mentioned above, global variables are reset, so I reset the one element I actually use' # TODO: Make this more elegant set_non_persistent_config_value(config_filename=".artemisrc", section="experiments", option="experiment_directory", value=experiment_path) return experiment.run(**experiment_record_kwargs) @@ -464,7 +467,7 @@ def run_experiment_by_name(name, exp_dict='global', slurm_job=False, experiment_ if exp_dict == 'global': exp_dict = get_global_experiment_library() experiment = exp_dict[name] - return run_experiment(experiment,slurm_job, experiment_path, **experiment_record_kwargs) + return run_experiment(experiment, slurm_job, experiment_path, **experiment_record_kwargs) def run_experiment_ignoring_errors(name, **kwargs): @@ -480,28 +483,28 @@ def run_multiple_experiments_with_slurm(experiments, n_parallel=None, max_proces ''' if n_parallel and n_parallel > 1: # raise NotImplementedError("No parallel Slurm execution at the moment. Implement it!") - print ('Warning... parallel-slurm integration is very beta. Use with caution') + print('Warning... parallel-slurm integration is very beta. Use with caution') experiment_subsets = divide_into_subsets(experiments, subset_size=n_parallel) for i, exp_subset in enumerate(experiment_subsets): nanny = Nanny() function_call = partial(run_multiple_experiments, - experiments=exp_subset, - parallel=n_parallel if max_processes_per_node is None else max_processes_per_node, - display_results=False, - run_args = run_args - ) - spp = SlurmPythonProcess(name="Group %i"%i, function=function_call,ip_address="127.0.0.1", slurm_kwargs=slurm_kwargs) + experiments=exp_subset, + parallel=n_parallel if max_processes_per_node is None else max_processes_per_node, + display_results=False, + run_args=run_args + ) + spp = SlurmPythonProcess(name="Group %i" % i, function=function_call, ip_address="127.0.0.1", slurm_kwargs=slurm_kwargs) # Using Nanny only for convenient stdout & stderr forwarding. - nanny.register_child_process(spp,monitor_for_termination=False) + nanny.register_child_process(spp, monitor_for_termination=False) nanny.execute_all_child_processes(time_out=2) else: - for i,exp in enumerate(experiments): + for i, exp in enumerate(experiments): nanny = Nanny() function_call = partial(run_experiment, experiment=exp, slurm_job=True, experiment_path=get_experiment_dir(), - raise_exceptions=raise_exceptions,display_results=False, **run_args) - spp = SlurmPythonProcess(name="Exp %i"%i, function=function_call,ip_address="127.0.0.1", slurm_kwargs=slurm_kwargs) + raise_exceptions=raise_exceptions, display_results=False, **run_args) + spp = SlurmPythonProcess(name="Exp %i" % i, function=function_call, ip_address="127.0.0.1", slurm_kwargs=slurm_kwargs) # Using Nanny only for convenient stdout & stderr forwarding. - nanny.register_child_process(spp,monitor_for_termination=False) + nanny.register_child_process(spp, monitor_for_termination=False) nanny.execute_all_child_processes(time_out=2) @@ -513,7 +516,7 @@ def _parallel_run_target(experiment_id_and_prefix, raise_exceptions, **kwargs): return run_experiment_ignoring_errors(experiment_id, prefix=prefix, **kwargs) -def run_multiple_experiments(experiments, prefixes = None, parallel = False, display_results=False, raise_exceptions=True, notes = (), run_args = {}): +def run_multiple_experiments(experiments, prefixes=None, parallel=False, display_results=False, raise_exceptions=True, notes=(), run_args={}): """ Run multiple experiments, optionally in parallel with multiprocessing. @@ -535,8 +538,8 @@ def run_multiple_experiments(experiments, prefixes = None, parallel = False, dis experiment_identifiers = [ex.get_id() for ex in experiments] if prefixes is None: prefixes = range(len(experiment_identifiers)) - prefixes = [s+': ' for s in equalize_string_lengths(prefixes, side='right')] - print ('Prefix key: \n'+'\n'.join('{}{}'.format(p, eid) for p, eid in izip_equal(prefixes, experiment_identifiers))) + prefixes = [s + ': ' for s in equalize_string_lengths(prefixes, side='right')] + print('Prefix key: \n' + '\n'.join('{}{}'.format(p, eid) for p, eid in izip_equal(prefixes, experiment_identifiers))) target_func = partial(_parallel_run_target, notes=notes, raise_exceptions=raise_exceptions, **run_args) p = multiprocessing.Pool(processes=parallel) @@ -564,8 +567,8 @@ def get_multiple_records(experiment, n, only_completed=True, if_not_enough='run' if if_not_enough == 'err': assert len(records) >= n, "You asked for {} records, but only {} were available".format(n, len(records)) return records[-n:] - elif if_not_enough=='run': - for k in range(n-len(records)): + elif if_not_enough == 'run': + for k in range(n - len(records)): record = experiment.run() records.append(record) return records[-n:] @@ -587,7 +590,7 @@ def remove_common_results_prefix(results_dict): return OrderedDict((k, v) for k, v in izip_equal(trimmed_keys, results_dict.values())) -def get_experient_to_record_dict(experiment_ids = None): +def get_experient_to_record_dict(experiment_ids=None): """ Given a list of experiment ids, return an OrderedDict whose keys are the experiment ids and whose values are lists of experiment record ids. @@ -626,15 +629,15 @@ def get_experiment_tuple(exp_id): if exp_id in exp_to_parent: parent_id = exp_to_parent[exp_id] parent_tuple = get_experiment_tuple(parent_id) - return parent_tuple + (exp_id[len(parent_id)+1:], ) + return parent_tuple + (exp_id[len(parent_id) + 1:],) else: - return (exp_id, ) + return (exp_id,) # Then for each experiment in the list, tuples = [get_experiment_tuple(eid) for eid in experiment_ids] de_prefixed_tuples = remove_common_prefix(tuples, keep_base=False) - start_with = '' if len(de_prefixed_tuples[0])==len(tuples[0]) else '.' - new_strings = [start_with+'.'.join(ex_tup) for ex_tup in de_prefixed_tuples] + start_with = '' if len(de_prefixed_tuples[0]) == len(tuples[0]) else '.' + new_strings = [start_with + '.'.join(ex_tup) for ex_tup in de_prefixed_tuples] return new_strings diff --git a/artemis/experiments/experiments.py b/artemis/experiments/experiments.py index bf0aa78f..4e559478 100644 --- a/artemis/experiments/experiments.py +++ b/artemis/experiments/experiments.py @@ -3,10 +3,14 @@ from collections import OrderedDict from contextlib import contextmanager from functools import partial +from typing import Optional, Mapping, Callable, Any, Iterator, Tuple +from typing import Sequence + from six import string_types from artemis.experiments.experiment_record import ExpStatusOptions, experiment_id_to_record_ids, load_experiment_record, \ get_all_record_ids, clear_experiment_records +from artemis.experiments.experiment_record import ExperimentRecord from artemis.experiments.experiment_record import run_and_record from artemis.experiments.experiment_record_view import compare_experiment_records, show_record from artemis.experiments.hyperparameter_search import parameter_search @@ -24,8 +28,15 @@ class Experiment(object): create variants using decorated_function.add_variant() """ - def __init__(self, function=None, show=None, compare=None, one_liner_function=None, result_parser = None, - name=None, is_root=False): + def __init__(self, + function: Optional[Callable] = None, + show: Optional[Callable[[ExperimentRecord], None]] = None, + compare: Optional[Callable[[Sequence[ExperimentRecord]], bool]] = None, + one_liner_function=Optional[Callable[[ExperimentRecord], str]], + result_parser: Optional[Callable[[ExperimentRecord], Sequence[Tuple[str, str]]]] = None, + name: Optional[str] = None, + is_root=False + ): """ :param function: The function defining the experiment :param display_function: A function that can be called to display the results returned by function. @@ -43,26 +54,31 @@ def __init__(self, function=None, show=None, compare=None, one_liner_function=No self.variants = OrderedDict() self._notes = [] self.is_root = is_root - self._tags= set() + self._tags = set() if not is_root: all_args, varargs_name, kargs_name, defaults = advanced_getargspec(function) undefined_args = [a for a in all_args if a not in defaults] - assert len(undefined_args)==0, "{} is not a root-experiment, but arguments {} are undefined. Either provide a value for these arguments or define this as a root_experiment (see {})."\ - .format(self, undefined_args, 'X.add_root_variant(...)' if isinstance(function, partial) else 'X.add_config_root_variant(...)' if isinstance(function, PartialReparametrization) else '@experiment_root') + assert len( + undefined_args) == 0, "{} is not a root-experiment, but arguments {} are undefined. Either provide a value for these arguments or define this as a root_experiment (see {})." \ + .format(self, undefined_args, 'X.add_root_variant(...)' if isinstance(function, partial) else 'X.add_config_root_variant(...)' if isinstance(function, + PartialReparametrization) else '@experiment_root') _register_experiment(self) @property - def show(self): + def show(self) -> Callable[[ExperimentRecord], None]: + """ A function that somehow displays the experiment record to the user. """ return self._show @property - def one_liner_function(self): + def one_liner_function(self) -> Callable[[ExperimentRecord], str]: + """ A function which summarizes the experiment result as a one-line string """ return self._one_liner_results @property - def compare(self): + def compare(self) -> Callable[[Sequence[ExperimentRecord]], None]: + """ Get a function that visually compares multiple records """ return self._compare @compare.setter @@ -70,7 +86,8 @@ def compare(self, val): self._compare = val @property - def result_parser(self): + def result_parser(self) -> Callable[[ExperimentRecord], Sequence[Tuple[str, str]]]: + """ Get the function that parses the experiment result into a sequence of (column, column_value) for display as a row of a table """ return self._result_parser def __call__(self, *args, **kwargs): @@ -80,20 +97,20 @@ def __call__(self, *args, **kwargs): def __str__(self): return 'Experiment {}'.format(self.name) - def get_args(self): + def get_args(self) -> Mapping[str, Any]: """ :return OrderedDict[str, Any]: An OrderedDict of arguments to the experiment """ all_arg_names, _, _, defaults = advanced_getargspec(self.function) return OrderedDict((name, defaults[name]) for name in all_arg_names) - def get_root_function(self): + def get_root_function(self) -> Callable: return get_partial_root(self.function) - def is_generator(self): + def is_generator(self) -> bool: return inspect.isgeneratorfunction(self.get_root_function()) - def call(self, *args, **kwargs): + def call(self, *args, **kwargs) -> ExperimentRecord: """ Call the experiment function without running as an experiment. If the experiment is a function, this is the same as just result = my_exp_func(). If it's defined as a generator, it loops and returns the last result. @@ -108,7 +125,8 @@ def call(self, *args, **kwargs): return result def run(self, print_to_console=True, show_figs=None, test_mode=None, keep_record=None, raise_exceptions=True, - display_results=False, notes = (), **experiment_record_kwargs): + display_results=False, notes: Optional[str] = (), **experiment_record_kwargs + ) -> ExperimentRecord: """ Run the experiment, and return the ExperimentRecord that is generated. @@ -130,20 +148,23 @@ def run(self, print_to_console=True, show_figs=None, test_mode=None, keep_record """ for exp_rec in self.iterator(print_to_console=print_to_console, show_figs=show_figs, test_mode=test_mode, keep_record=keep_record, - raise_exceptions=raise_exceptions, display_results=display_results, notes=notes, **experiment_record_kwargs): + raise_exceptions=raise_exceptions, display_results=display_results, notes=notes, **experiment_record_kwargs): pass return exp_rec def iterator(self, print_to_console=True, show_figs=None, test_mode=None, keep_record=None, raise_exceptions=True, - display_results=False, notes = (), **experiment_record_kwargs): + display_results=False, notes=(), **experiment_record_kwargs + ) -> Iterator[ExperimentRecord]: + """ Create an iteratator from an experiment defined on a generator-function. + The iterator yields an ExperimentResults wrapping the yield of the generator-function """ if keep_record is None: keep_record = keep_record_by_default if keep_record_by_default is not None else not test_mode exp_rec = None for exp_rec in run_and_record( - function = self.function, + function=self.function, experiment_id=self.name, print_to_console=print_to_console, show_figs=show_figs, @@ -152,16 +173,18 @@ def iterator(self, print_to_console=True, show_figs=None, test_mode=None, keep_r raise_exceptions=raise_exceptions, notes=notes, **experiment_record_kwargs - ): + ): yield exp_rec assert exp_rec is not None, 'Should nevah happen.' if display_results: self.show(exp_rec) return - def _create_experiment_variant(self, args, kwargs, is_root): + def _create_experiment_variant(self, args: Sequence[Any], kwargs: Mapping[str, Any], is_root: bool + ) -> 'Experiment': # TODO: For non-root variants, assert that all args are defined - assert len(args) in (0, 1), "When creating an experiment variant, you can either provide one unnamed argument (the experiment name), or zero, in which case the experiment is named after the named argumeents. See add_variant docstring" + assert len(args) in (0, + 1), "When creating an experiment variant, you can either provide one unnamed argument (the experiment name), or zero, in which case the experiment is named after the named argumeents. See add_variant docstring" name = args[0] if len(args) == 1 else _kwargs_to_experiment_name(kwargs) assert isinstance(name, str), 'Name should be a string. Not: {}'.format(name) assert name not in self.variants, 'Variant "%s" already exists.' % (name,) @@ -178,7 +201,8 @@ def _create_experiment_variant(self, args, kwargs, is_root): self.variants[name] = ex return ex - def add_variant(self, variant_name = None, **kwargs): + def add_variant(self, variant_name=None, **kwargs + ) -> 'Experiment': """ Add a variant to this experiment, and register it on the list of experiments. There are two ways you can do this: @@ -197,9 +221,10 @@ def add_variant(self, variant_name = None, **kwargs): :param kwargs: The named arguments which will differ from the base experiment. :return Experiment: The experiment. """ - return self._create_experiment_variant(() if variant_name is None else (variant_name, ), kwargs, is_root=False) + return self._create_experiment_variant(() if variant_name is None else (variant_name,), kwargs, is_root=False) - def add_root_variant(self, variant_name=None, **kwargs): + def add_root_variant(self, variant_name=None, **kwargs + ) -> 'Experiment': """ Add a variant to this experiment, but do NOT register it on the list of experiments. There are two ways you can do this: @@ -218,9 +243,9 @@ def add_root_variant(self, variant_name=None, **kwargs): :param kwargs: The named arguments which will differ from the base experiment. :return Experiment: The experiment. """ - return self._create_experiment_variant(() if variant_name is None else (variant_name, ), kwargs, is_root=True) + return self._create_experiment_variant(() if variant_name is None else (variant_name,), kwargs, is_root=True) - def copy_variants(self, other_experiment): + def copy_variants(self, other_experiment: 'Experiment') -> None: """ Copy over the variants from another experiment. @@ -230,12 +255,13 @@ def copy_variants(self, other_experiment): for variant in other_experiment.get_variants(): if variant is not self: variant_args = variant.get_args() - different_args = {k: v for k, v in variant_args.items() if base_args[k]!=v} - name_diff = variant.get_id()[len(other_experiment.get_id())+1:] + different_args = {k: v for k, v in variant_args.items() if base_args[k] != v} + name_diff = variant.get_id()[len(other_experiment.get_id()) + 1:] v = self.add_variant(name_diff, **different_args) v.copy_variants(variant) - def _add_config(self, name, arg_constructors, is_root): + def _add_config(self, name: str, arg_constructors: Mapping[str, Callable[[], Any]], is_root: bool + ) -> 'Experiment': assert isinstance(name, str), 'Name should be a string. Not: {}'.format(name) assert name not in self.variants, 'Variant "%s" already exists.' % (name,) assert '/' not in name, 'Experiment names cannot have "/" in them: {}'.format(name) @@ -251,7 +277,8 @@ def _add_config(self, name, arg_constructors, is_root): self.variants[name] = ex return ex - def add_config_variant(self, name, **arg_constructors): + def add_config_variant(self, name: str, **arg_constructors + ) -> 'Experiment': """ Add a variant where you redefine the constructor for arguments to the experiment. e.g. @@ -271,19 +298,21 @@ def demo_smooth_out_signal(smoother, signal): """ return self._add_config(name, arg_constructors=arg_constructors, is_root=False) - def add_config_root_variant(self, name, **arg_constructors): + def add_config_root_variant(self, name: str, **arg_constructors + ) -> 'Experiment': """ Add a config variant which requires additional parametrization. (See add_config_variant) """ return self._add_config(name, arg_constructors=arg_constructors, is_root=True) - def get_id(self): + def get_id(self) -> str: """ :return: A string uniquely identifying this experiment. """ return self.name - def get_variant(self, variant_name=None, **kwargs): + def get_variant(self, variant_name: Optional[str] = None, **kwargs + ) -> 'Experiment': """ Get a variant on this experiment. @@ -294,11 +323,12 @@ def get_variant(self, variant_name=None, **kwargs): if variant_name is None: variant_name = _kwargs_to_experiment_name(kwargs) else: - assert len(kwargs)==0, 'If you provide a variant name ({}), there is no need to specify the keyword arguments. ({})'.format(variant_name, kwargs) + assert len(kwargs) == 0, 'If you provide a variant name ({}), there is no need to specify the keyword arguments. ({})'.format(variant_name, kwargs) assert variant_name in self.variants, "No variant '{}' exists. Existing variants: {}".format(variant_name, list(self.variants.keys())) return self.variants[variant_name] - def get_records(self, only_completed=False): + def get_records(self, only_completed=False + ) -> 'Sequence[ExperimentRecord]': """ Get all records associated with this experiment. @@ -307,12 +337,12 @@ def get_records(self, only_completed=False): """ records = [load_experiment_record(rid) for rid in experiment_id_to_record_ids(self.name)] if only_completed: - records = [record for record in records if record.get_status()==ExpStatusOptions.FINISHED] + records = [record for record in records if record.get_status() == ExpStatusOptions.FINISHED] return records - def browse(self, command=None, catch_errors = False, close_after = False, filterexp=None, filterrec = None, - view_mode ='full', raise_display_errors=False, run_args=None, keep_record=True, truncate_result_to=100, - cache_result_string = False, remove_prefix = None, display_format='nested', **kwargs): + def browse(self, command: Optional[str] = None, catch_errors=False, close_after=False, filterexp: Optional[str] = None, filterrec=None, + view_mode='full', raise_display_errors=False, run_args=None, keep_record=True, truncate_result_to=100, + cache_result_string=False, remove_prefix=None, display_format='nested', **kwargs) -> None: """ Open up the UI, which allows you to run experiments and view their results. @@ -335,9 +365,9 @@ def browse(self, command=None, catch_errors = False, close_after = False, filter from artemis.experiments.ui import ExperimentBrowser experiments = get_ordered_descendents_of_root(root_experiment=self) browser = ExperimentBrowser(experiments=experiments, catch_errors=catch_errors, close_after=close_after, - filterexp=filterexp, filterrec=filterrec, view_mode=view_mode, raise_display_errors=raise_display_errors, - run_args=run_args, keep_record=keep_record, truncate_result_to=truncate_result_to, cache_result_string=cache_result_string, - remove_prefix=remove_prefix, display_format=display_format, **kwargs) + filterexp=filterexp, filterrec=filterrec, view_mode=view_mode, raise_display_errors=raise_display_errors, + run_args=run_args, keep_record=keep_record, truncate_result_to=truncate_result_to, cache_result_string=cache_result_string, + remove_prefix=remove_prefix, display_format=display_format, **kwargs) browser.launch(command=command) # Above this line is the core api.... @@ -354,7 +384,7 @@ def has_record(self, completed=True, valid=True): records = self.get_records(only_completed=completed) if valid: records = [record for record in records if record.args_valid()] - return len(records)>0 + return len(records) > 0 def get_variants(self): return self.variants.values() @@ -376,7 +406,7 @@ def get_all_variants(self, include_roots=False, include_self=True): def test(self, **kwargs): self.run(test_mode=True, **kwargs) - def get_latest_record(self, only_completed=False, if_none = 'skip'): + def get_latest_record(self, only_completed=False, if_none='skip'): """ Return the ExperimentRecord from the latest run of this Experiment. @@ -389,10 +419,10 @@ def get_latest_record(self, only_completed=False, if_none = 'skip'): """ assert if_none in ('skip', 'err', 'run') records = self.get_records(only_completed=only_completed) - if len(records)==0: - if if_none=='run': + if len(records) == 0: + if if_none == 'run': return self.run() - elif if_none=='err': + elif if_none == 'err': raise Exception('No{} records for experiment "{}"'.format(' completed' if only_completed else '', self.name)) else: return None @@ -424,7 +454,7 @@ def get_variant_records(self, only_completed=False, only_last=False, flat=False) else: return exp_record_dict - def add_parameter_search(self, name='parameter_search', fixed_args = {}, space = None, n_calls=None, search_params = None, scalar_func=None): + def add_parameter_search(self, name='parameter_search', fixed_args={}, space=None, n_calls=None, search_params=None, scalar_func=None): """ :param name: Name of the Experiment to be created :param dict[str, Any] fixed_args: Any fixed-arguments to provide to all experiments. @@ -455,7 +485,7 @@ def search_func(fixed): n_calls_to_make = n_calls if n_calls is not None else 3 if is_test_mode() else 100 this_objective = partial(objective, **fixed) for iter_info in parameter_search(this_objective, n_calls=n_calls_to_make, space=space, **search_params): - info = dict(names=list(space.keys()), x_iters =iter_info.x_iters, func_vals=iter_info.func_vals, score = iter_info.func_vals, x=iter_info.x, fun=iter_info.fun) + info = dict(names=list(space.keys()), x_iters=iter_info.x_iters, func_vals=iter_info.func_vals, score=iter_info.func_vals, x=iter_info.x, fun=iter_info.fun) latest_info = {name: val for name, val in izip_equal(info['names'], iter_info.x_iters[-1])} print('Latest: {}, Score: {:.3g}'.format(latest_info, iter_info.func_vals[-1])) yield info @@ -467,7 +497,7 @@ def search_func(fixed): # param_search = locals()['param_search'] search_exp_func = partial(search_func, fixed=fixed_args) # We do this so that the fixed parameters will be recorded and we will see if they changed. - search_exp = ExperimentFunction(name = self.name + '.'+ name, show = show_parameter_search_record, one_liner_function=parameter_search_one_liner)(search_exp_func) + search_exp = ExperimentFunction(name=self.name + '.' + name, show=show_parameter_search_record, one_liner_function=parameter_search_one_liner)(search_exp_func) self.variants[name] = search_exp search_exp.tag('psearch') # Secret feature that makes it easy to select all parameter experiments in ui with "filter tag:psearch" return search_exp @@ -489,12 +519,13 @@ def get_tags(self): def show_parameter_search_record(record): from tabulate import tabulate result = record.get_result() - table = tabulate([list(xs)+[fun] for xs, fun in zip(result['x_iters'], result['func_vals'])], headers=list(result['names'])+['score']) + table = tabulate([list(xs) + [fun] for xs, fun in zip(result['x_iters'], result['func_vals'])], headers=list(result['names']) + ['score']) print(table) def parameter_search_one_liner(result): - return '{} Runs : '.format(len(result["x_iters"])) + ', '.join('{}={:.3g}'.format(k, v) for k, v in izip_equal(result['names'], result['x'])) + ' : Score = {:.3g}'.format(result["fun"]) + return '{} Runs : '.format(len(result["x_iters"])) + ', '.join('{}={:.3g}'.format(k, v) for k, v in izip_equal(result['names'], result['x'])) + ' : Score = {:.3g}'.format( + result["fun"]) _GLOBAL_EXPERIMENT_LIBRARY = OrderedDict() @@ -502,7 +533,7 @@ def parameter_search_one_liner(result): class ExperimentNotFoundError(Exception): def __init__(self, experiment_id): - Exception.__init__(self,'Experiment "{}" could not be loaded, either because it has not been imported, or its definition was removed.'.format(experiment_id)) + Exception.__init__(self, 'Experiment "{}" could not be loaded, either because it has not been imported, or its definition was removed.'.format(experiment_id)) def clear_all_experiments(): @@ -537,7 +568,8 @@ def add_two_numbers(a=1, b=2): def _register_experiment(experiment): - assert experiment.name not in _GLOBAL_EXPERIMENT_LIBRARY, 'You have already registered an experiment named {} in {}'.format(experiment.name, inspect.getmodule(experiment.get_root_function()).__name__) + assert experiment.name not in _GLOBAL_EXPERIMENT_LIBRARY, 'You have already registered an experiment named {} in {}'.format(experiment.name, inspect.getmodule( + experiment.get_root_function()).__name__) _GLOBAL_EXPERIMENT_LIBRARY[experiment.name] = experiment @@ -579,7 +611,7 @@ def _kwargs_to_experiment_name(kwargs): @contextmanager -def hold_global_experiment_libary(new_lib = None): +def hold_global_experiment_libary(new_lib=None): if new_lib is None: new_lib = OrderedDict() @@ -598,7 +630,7 @@ def get_global_experiment_library(): @contextmanager -def experiment_testing_context(close_figures_at_end = True, new_experiment_lib = False): +def experiment_testing_context(close_figures_at_end=True, new_experiment_lib=False): """ Use this context when testing the experiment/experiment_record infrastructure. Should only really be used in test_experiment_record.py diff --git a/artemis/fileman/disk_memoize.py b/artemis/fileman/disk_memoize.py index 3cbef5e7..51952905 100644 --- a/artemis/fileman/disk_memoize.py +++ b/artemis/fileman/disk_memoize.py @@ -1,5 +1,8 @@ +import inspect import logging import os +import time +from contextlib import contextmanager from functools import partial from shutil import rmtree @@ -7,6 +10,7 @@ from artemis.general.functional import infer_arg_values from artemis.general.hashing import compute_fixed_hash from artemis.general.test_mode import is_test_mode +from eagle_eyes.utils.utils_for_testing import hold_tempdir logging.basicConfig() LOGGER = logging.getLogger(__name__) @@ -20,6 +24,16 @@ MEMO_DIR = get_artemis_data_path('memoize_to_disk') +@contextmanager +def hold_temp_memo_dir(): + global MEMO_DIR + oldone = MEMO_DIR + with hold_tempdir() as path: + MEMO_DIR = path + yield path + MEMO_DIR = oldone + + def memoize_to_disk(fcn, local_cache = False, disable_on_tests=False, use_cpickle = False, suppress_info = False): """ Save (memoize) computed results to disk, so that the same function, called with the @@ -83,8 +97,12 @@ def check_memos(*args, **kwargs): with open(filepath, 'rb') as f: try: if not suppress_info: - LOGGER.info('Reading memo for function {}'.format(fcn.__name__, )) + LOGGER.info('Reading memo for function {}...'.format(fcn.__name__, )) + tstart = time.monotonic() result = pickle.load(f) + if not suppress_info: + LOGGER.info(f'...Reading memo for function {fcn.__name__} took {time.monotonic()-tstart:.5f}s'.format(fcn.__name__, )) + except (ValueError, ImportError, EOFError) as err: if isinstance(err, (ValueError, EOFError)) and not suppress_info: LOGGER.warn('Memo-file "{}" was corrupt. ({}: {}). Recomputing.'.format(filepath, err.__class__.__name__, str(err))) @@ -94,7 +112,13 @@ def check_memos(*args, **kwargs): result = fcn(*args, **kwargs) else: result_computed = True - result = fcn(*args, **kwargs) + if inspect.isgeneratorfunction(fcn): + # TODO: Do this properly - caching results one at a time + LOGGER.info(f"Computing results from generator {fcn} in advance...") + result = list(fcn(*args, **kwargs)) + LOGGER.info('... Done') + else: + result = fcn(*args, **kwargs) else: result_computed = True result = fcn(*args, **kwargs) diff --git a/artemis/fileman/test_disk_memoize.py b/artemis/fileman/test_disk_memoize.py index 3915ac5e..809028b8 100644 --- a/artemis/fileman/test_disk_memoize.py +++ b/artemis/fileman/test_disk_memoize.py @@ -15,6 +15,14 @@ def compute_slow_thing(a, b, c): return (a+b)/float(c), call_time + +@memoize_to_disk_test +def compute_slow_thing_with_type_annotations(a, b, c: int =3): + call_time = time.time() + time.sleep(0.01) + return (a+b)/float(c), call_time + + def test_memoize_to_disk(): clear_memo_files_for_function(compute_slow_thing) @@ -106,7 +114,7 @@ def test_clear_error_for_missing_arg(): clear_memo_files_for_function(compute_slow_thing) - with raises(AssertionError): + with raises(TypeError): compute_slow_thing(1) @@ -114,7 +122,7 @@ def test_clear_arror_for_wrong_arg(): clear_memo_files_for_function(compute_slow_thing) - with raises(AssertionError): + with raises(TypeError): compute_slow_thing(a=1, b=2, c=3, d=4) @@ -129,7 +137,7 @@ def test_unnoticed_wrong_arg_bug_is_dead(): clear_memo_files_for_function(compute_slow_thing) compute_slow_thing(a=1, b=2, c=3) # Creates a memo - with raises(AssertionError): + with raises(TypeError): compute_slow_thing(a=1, b=2, see=3) # Previously, this was not caught, leading you not to notice your typo @@ -151,12 +159,43 @@ def test_catch_kwarg_error(): assert t3 == t1 +def test_memoize_to_disk_with_annotations(): + + clear_memo_files_for_function(compute_slow_thing_with_type_annotations) + + t = time.time() + num, t1 = compute_slow_thing_with_type_annotations(1, 3) + assert t-t1 < 0.01 + assert num == (1+3)/3. + + num, t2 = compute_slow_thing_with_type_annotations(1, 3) + assert num == (1+3)/3. + assert t2==t1 + + +@memoize_to_disk_test +def iter_slowly(): + for i in range(5): + time.sleep(0.01) + yield time.time(), i + + +def test_memoize_iter_slowly(): + + clear_memo_files_for_function(iter_slowly) + results_1 = list(iter_slowly()) + results_2 = list(iter_slowly()) + assert results_1 == results_2 + + if __name__ == '__main__': set_test_mode(True) - test_unnoticed_wrong_arg_bug_is_dead() - test_catch_kwarg_error() - test_clear_arror_for_wrong_arg() - test_clear_error_for_missing_arg() - test_memoize_to_disk_and_cache() - test_memoize_to_disk() - test_complex_args() + # test_unnoticed_wrong_arg_bug_is_dead() + # test_catch_kwarg_error() + # test_clear_arror_for_wrong_arg() + # test_clear_error_for_missing_arg() + # test_memoize_to_disk_and_cache() + # test_memoize_to_disk() + # test_complex_args() + # test_memoize_to_disk() + test_memoize_iter_slowly() diff --git a/artemis/general/display.py b/artemis/general/display.py index 28093cd4..aa25296e 100644 --- a/artemis/general/display.py +++ b/artemis/general/display.py @@ -78,7 +78,7 @@ def equalize_string_lengths(arr, side = 'left'): return strings -def sensible_str(data, size_limit=4, compact=True): +def sensible_str(data, size_limit=4, compact=True) -> str: """ Crawl through an data structure and try to make a sensible compact representation of it. :param data: Some data structure. diff --git a/artemis/general/functional.py b/artemis/general/functional.py index 3acc6b97..b8b950ac 100644 --- a/artemis/general/functional.py +++ b/artemis/general/functional.py @@ -3,6 +3,9 @@ from collections import OrderedDict from functools import partial import collections + +from cv2.gapi.ie.detail import PARAM_DESC_KIND_LOAD + from artemis.general.should_be_builtins import separate_common_items import sys import types @@ -206,23 +209,38 @@ def infer_arg_values(f, args=(), kwargs={}): :param kwargs: A dict of keyword args :return: An OrderedDict(arg_name->arg_value) """ - all_arg_names, varargs_name, kwargs_name, defaults = inspect.getargspec(f) + # all_arg_names, varargs_name, kwargs_name, defaults = inspect.getargspec(f) + + sig = inspect.signature(f) + all_arg_names = sig.parameters + + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() - assert varargs_name is None, "This function doesn't work with unnamed args" - default_args = {k: v for k, v in zip(all_arg_names[len(all_arg_names)-(len(defaults) if defaults is not None else 0):], defaults if defaults is not None else [])} - args_with_values = set(all_arg_names[:len(args)]+list(default_args.keys())+list(kwargs.keys())) - assert set(all_arg_names).issubset(args_with_values), "Arguments {} require values but are not given any. ".format(tuple(set(all_arg_names).difference(args_with_values))) + assert not any(p.kind.name=='VAR_POSITIONAL' for p in sig.parameters.values()), "This function doesn't work with unnamed args" + # assert varargs_name is None, "This function doesn't work with unnamed args" + # default_args = {k: v for k, v in zip(all_arg_names[len(all_arg_names)-(len(defaults) if defaults is not None else 0):], defaults if defaults is not None else [])} + # args_with_values = set(all_arg_names[:len(args)]+list(default_args.keys())+list(kwargs.keys())) + + + # assert set(all_arg_names).issubset(args_with_values), "Arguments {} require values but are not given any. ".format(tuple(set(all_arg_names).difference(args_with_values))) assert len(args) <= len(all_arg_names), "You provided {} arguments, but the function only takes {}".format(len(args), len(all_arg_names)) - full_args = tuple( - list(zip(all_arg_names, args)) # Handle unnamed args f(1, 2) - + [(name, kwargs[name] if name in kwargs else default_args[name]) for name in all_arg_names[len(args):]] # Handle named keyworkd args f(a=1, b=2) - + [(name, kwargs[name]) for name in kwargs if name not in all_arg_names[len(args):]] # Need to handle case if f takes **kwargs - ) - duplicates = tuple(item for item, count in collections.Counter([a for a, _ in full_args]).items() if count > 1) - assert len(duplicates)==0, 'Arguments {} have been defined multiple times: {}'.format(duplicates, full_args) - - common_args, (different_args, different_given_args) = separate_common_items([tuple(all_arg_names), tuple(n for n, _ in full_args)]) - if kwargs_name is None: # There is no **kwargs - assert len(different_given_args)==0, "Function {} was given args {} but didn't ask for them".format(f, different_given_args) - assert len(different_args)==0, "Function {} needs values for args {} but didn't get them".format(f, different_args) + + full_args = bound_args.arguments + # full_args = tuple( + # (pname, ) + # ) + # + # full_args = tuple( + # list(zip(all_arg_names, args)) # Handle unnamed args f(1, 2) + # + [(name, kwargs[name] if name in kwargs else default_args[name]) for name in all_arg_names[len(args):]] # Handle named keyworkd args f(a=1, b=2) + # + [(name, kwargs[name]) for name in kwargs if name not in all_arg_names[len(args):]] # Need to handle case if f takes **kwargs + # ) + # duplicates = tuple(item for item, count in collections.Counter([a for a, _ in full_args]).items() if count > 1) + # assert len(duplicates)==0, 'Arguments {} have been defined multiple times: {}'.format(duplicates, full_args) + + # common_args, (different_args, different_given_args) = separate_common_items([tuple(all_arg_names), tuple(n for n, _ in full_args)]) + # if kwargs_name is None: # There is no **kwargs + # assert len(different_given_args)==0, "Function {} was given args {} but didn't ask for them".format(f, different_given_args) + # assert len(different_args)==0, "Function {} needs values for args {} but didn't get them".format(f, different_args) return OrderedDict(full_args) diff --git a/artemis/general/hashing.py b/artemis/general/hashing.py index dab32412..e1c965c7 100644 --- a/artemis/general/hashing.py +++ b/artemis/general/hashing.py @@ -78,6 +78,7 @@ def compute_fixed_hash(obj, try_objects=False, _hasher = None, _memo = None, _co for k in keys: compute_fixed_hash(k, **kwargs) compute_fixed_hash(obj.__dict__[k], **kwargs) + else: # TODO: Consider whether to pickle by default. Note that pickle strings are not necessairly the same for identical objects. raise NotImplementedError("Don't have a method for hashing this %s" % (obj, )) From d2bd393f4ba385b0c9a4f5dd95472075f3a66946 Mon Sep 17 00:00:00 2001 From: Peter O'Connor Date: Sat, 7 Jan 2023 18:07:15 -0800 Subject: [PATCH 41/41] Update README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 552b80fc..14a088d8 100644 --- a/README.md +++ b/README.md @@ -63,16 +63,16 @@ For more examples of how to use artemis, read the [Artemis Documentation](http:/ To use artemis from within your project, use the following to install Artemis and its dependencies: (You probably want to do this in a virtualenv with the latest version of pip - run `virtualenv venv; source venv/bin/activate; pip install --upgrade pip;` to make one and enter it). -**Option 1: Simple install:** +**Option 1: Simple install (stable) :** ``` pip install artemis-ml ``` -**Option 2: Install as source.** +**Option 2: Install as source (master branch, do this if you want latest updates or you want to contribute to Artemis).** ``` -pip install -e git+http://github.com/QUVA-Lab/artemis.git#egg=artemis +pip install -e git+https://github.com/petered/artemis.git#egg=artemis-ml ``` This will install it in `(virtual env or system python root)/src/artemis`. You can edit the code and submit pull requests to our git repo. To install with the optional [remote plotting](https://github.com/QUVA-Lab/artemis/blob/master/artemis/remote/README.md) mode enabled, add the `[remote_plotting]` option, as in: `pip install -e git+http://github.com/QUVA-Lab/artemis.git#egg=artemis[remote_plotting]`