Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow seed fix in Sampler #669

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
17 changes: 13 additions & 4 deletions batchflow/plotter/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def image(self, data):

image_keys = ['alpha', 'vmin', 'vmax', 'extent']
image_config = self.config.filter(keys=image_keys, prefix='image_')

image = self.ax.imshow(data, cmap=cmap, **image_config)

return [image]
Expand Down Expand Up @@ -653,7 +652,8 @@ def add_colorbar(self, image, width=0.2, pad=None, color='black', position='righ

return colorbar

def add_legend(self, mode='image', label=None, color='none', alpha=1, size=15, **kwargs):
def add_legend(self, mode='image', label=None, color='none', alpha=1,
size=15, family='sans-serif', properties=None, **kwargs):
""" Add patches to subplot legend.

Parameters
Expand All @@ -671,10 +671,18 @@ def add_legend(self, mode='image', label=None, color='none', alpha=1, size=15, *
alpha : number or list of numbers from 0 to 1
Legend handles opacity.
size : int
Legend size.
Legend text font size.
family : str
Legend text font family.
properties : dict
Legend font parameters, must be valid for `matplotlib.font_manager.FontProperties`.
kwargs : misc
For `matplotlib.legend`.
"""
if properties is None:
properties = {}
properties = {'size': size, 'family': family, **properties}

# get legend that already exists
legend = self.ax.get_legend()
old_handles = getattr(legend, 'legendHandles', [])
Expand All @@ -689,6 +697,7 @@ def add_legend(self, mode='image', label=None, color='none', alpha=1, size=15, *
for label_item, label_color, label_alpha in zip(labels, colors, alphas):
if label_item is None:
continue

if isinstance(label_item, str):
if mode in ('image', 'histogram'):
if is_color_like(label_color):
Expand All @@ -705,7 +714,7 @@ def add_legend(self, mode='image', label=None, color='none', alpha=1, size=15, *
if len(new_handles) > 0:
# extend existing handles and labels with new ones
handles = old_handles + new_handles
legend = self.ax.legend(prop={'size': size}, handles=handles, handler_map=handler_map, **kwargs)
legend = self.ax.legend(prop=properties, handles=handles, handler_map=handler_map, **kwargs)

return legend

Expand Down
12 changes: 6 additions & 6 deletions batchflow/research/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __init__(self, domain=None, **kwargs):
self.n_updates = 0
self.additional = True
self.create_id_prefix = False
self.random_state = None
self.rng = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a chain of random_state variables and their SeedSequences in batchflow:

---Pipeline
---Dataset
------Batch
---------inbatch_parallel workers (threads / processes / for-items)

And Research is even one level above that.

Have you read this entire chain of properties before changing it?


self.values_indices = dict()

Expand Down Expand Up @@ -318,7 +318,7 @@ def set_iter_params(self, n_items=None, n_reps=1, repeat_each=None, produced=0,
self.n_produced = produced
self.additional = additional
self.create_id_prefix = create_id_prefix
self.random_state = make_rng(seed)
self.rng = make_rng(seed)
self.reset_iter()

def set_update(self, function, when, **kwargs):
Expand All @@ -344,7 +344,7 @@ def update(self, generated, research):
domain.updates = self.updates
domain.n_updates = self.n_updates + 1
domain.values_indices = self.values_indices
domain.set_iter_params(produced=generated, additional=self.additional, seed=self.random_state,
domain.set_iter_params(produced=generated, additional=self.additional, seed=self.rng,
create_id_prefix=self.create_id_prefix, **update['iter_kwargs'])
return domain
return None
Expand Down Expand Up @@ -453,7 +453,7 @@ def reset_iter(self):
for cube in self.cubes:
for _, values in cube:
if isinstance(values, Sampler):
values.state = make_rng(self.random_state)
values.rng = make_rng(self.rng)
self._iterator = None

def create_iter(self):
Expand All @@ -468,7 +468,7 @@ def _iterator():
weights[np.isnan(weights)] = 1
iterators = [self._cube_iterator(cube) for cube in np.array(self.cubes, dtype=object)[block]]
while len(iterators) > 0:
index = self.random_state.choice(len(block), p=weights/weights.sum())
index = self.rng.choice(len(block), p=weights/weights.sum())
try:
yield next(iterators[index])
except StopIteration:
Expand Down Expand Up @@ -519,7 +519,7 @@ def iterator(self):
""" Get domain iterator. """
if self._iterator is None:
self.set_iter_params(self.n_items, self.n_reps, self.repeat_each, self.n_produced,
self.additional, self.create_id_prefix, self.random_state)
self.additional, self.create_id_prefix, self.rng)
self.create_iter()
return self._iterator

Expand Down
130 changes: 87 additions & 43 deletions batchflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,32 @@
'c': {'np': 'choice'}
}

def _get_method_by_alias(alias, module, tf_distributions=None):
def _get_method_by_alias(name, module, tf_distributions=None):
""" Fetch fullname of a randomizer from ``scipy.stats``, ``tensorflow`` or
``numpy`` by its alias or fullname.
``numpy`` by its alias or name.
"""
rnd_submodules = {'np': np.random,
'tf': tf_distributions,
'ss': ss}
# fetch fullname
fullname = ALIASES.get(alias, {module: alias for module in ['np', 'tf', 'ss']}).get(module, None)
if fullname is None:
raise ValueError(f"Distribution {alias} has no implementaion in module {module}")
rnd_modules = {'np': np.random,
'tf': tf_distributions,
'ss': ss}

if isinstance(module, str):
if name in ALIASES: # pylint: disable=consider-using-get
aliases = ALIASES[name]

if module in aliases:
name = aliases[module]
else:
msg = f"The mapping of distribution alias '{name}' to its full name "\
f"for module '{module}' is not defined."
raise ValueError(msg)

module = rnd_modules[module]

# check that the randomizer is implemented in corresponding module
if not hasattr(rnd_submodules[module], fullname):
raise ValueError(f"Distribution {fullname} has no implementaion in module {module}")
if not hasattr(module, name):
raise ValueError(f"Distribution {name} has no implementaion in module {module}")

return fullname
return name


def arithmetize(cls):
Expand Down Expand Up @@ -84,10 +93,13 @@ class Sampler():
weight : float
weight of Sampler self in mixtures.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, seed=None, **kwargs):
self.__array_priority__ = 100
self.weight = 1.0

self.seed = seed
self.rng = make_rng(seed)

# if dim is supplied, redefine sampling method
if 'dim' in kwargs:
# assemble stacked sampler
Expand Down Expand Up @@ -216,8 +228,9 @@ def truncate(self, high=None, low=None, expr=None, prob=0.5, max_iters=None, sam
class OrSampler(Sampler):
""" Class for implementing `|` (mixture) operation on `Sampler`-instances.
"""
def __init__(self, left, right, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, left, right, *args, seed=None, **kwargs):
seed = mix_samplers_seeds(left, right) if seed is None else seed
super().__init__(*args, seed=seed, **kwargs)
self.bases = [left, right]

# calculate probs of samplers in mixture
Expand All @@ -230,21 +243,22 @@ def sample(self, size):
defined by weights (`self.weight`-attr) from two samplers invoked (`self.bases`-attr) and
mixes them in one sample of needed size.
"""
up_size = np.random.binomial(size, self.normed[0])
up_size = self.rng.binomial(size, self.normed[0])
low_size = size - up_size

up_sample = self.bases[0].sample(size=up_size)
low_sample = self.bases[1].sample(size=low_size)
sample_points = np.concatenate([up_sample, low_sample])
sample_points = sample_points[np.random.permutation(size)]
sample_points = sample_points[self.rng.permutation(size)]

return sample_points

class AndSampler(Sampler):
""" Class for implementing `&` (coordinates stacking) operation on `Sampler`-instances.
"""
def __init__(self, left, right, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, left, right, *args, seed=None, **kwargs):
seed = mix_samplers_seeds(left, right) if seed is None else seed
super().__init__(*args, seed=seed, **kwargs)
self.bases = [left, right]

def sample(self, size):
Expand All @@ -258,8 +272,9 @@ def sample(self, size):
class ApplySampler(Sampler):
""" Class for implementing `apply` (adding transform) operation on `Sampler`-instances.
"""
def __init__(self, sampler, transform, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, sampler, transform, *args, seed=None, **kwargs):
seed = sampler.seed if seed is None else seed
super().__init__(*args, seed=seed, **kwargs)
self.bases = [sampler]
self.transform = transform

Expand All @@ -276,10 +291,12 @@ class TruncateSampler(Sampler):
# from the region of interest using this number of iterations, we throw a Warning or ValueError
max_iters = 1e7

def __init__(self, sampler, high=None, low=None, expr=None, prob=0.5, max_iters=None,
sample_anyways=False, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, sampler, *args, high=None, low=None, expr=None, prob=0.5,
max_iters=None, sample_anyways=False, seed=None, **kwargs):
seed = sampler.seed if seed is None else seed
super().__init__(*args, seed=seed, **kwargs)
self.bases = [sampler]

self.high = high
self.low = low
self.expr = expr
Expand Down Expand Up @@ -346,8 +363,9 @@ class BaseOperationSampler(Sampler):
""" Base class for implementing all arithmetic operations on `Sampler`-instances.
"""
operation = None
def __init__(self, left, right, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, left, right, *args, seed=None, **kwargs):
seed = mix_samplers_seeds(left, right) if seed is None else seed
super().__init__(*args, seed=seed, **kwargs)
self.bases = [left, right]

def sample(self, size):
Expand Down Expand Up @@ -448,7 +466,7 @@ class NumpySampler(Sampler):
name : str
a distribution name (a method from `numpy random`) or its alias.
seed : int
random seed for setting up sampler's state (see :func:`~.make_rng`).
random seed for setting up sampler's random numbers generator (see :func:`~.make_rng`).
**kwargs
additional keyword-arguments defining properties of specific
distribution (e.g. ``loc`` for 'normal').
Expand All @@ -457,17 +475,16 @@ class NumpySampler(Sampler):
----------
name : str
a distribution name (a method from `numpy random`).
state : numpy.random.Generator
rng : numpy.random.Generator
a random number generator
_params : dict
dict of args for Sampler's distribution.
"""
def __init__(self, name, seed=None, **kwargs):
super().__init__(name, seed, **kwargs)
name = _get_method_by_alias(name, 'np')
super().__init__(name, seed=seed, **kwargs)
name = _get_method_by_alias(name=name, module=self.rng)
self.name = name
self._params = copy(kwargs)
self.state = make_rng(seed)


def sample(self, size):
Expand All @@ -483,7 +500,7 @@ def sample(self, size):
np.ndarray
array of shape (size, Sampler's dimension).
"""
sampler = getattr(self.state, self.name)
sampler = getattr(self.rng, self.name)
sample = sampler(size=size, **self._params)
if len(sample.shape) == 1:
sample = sample.reshape(-1, 1)
Expand All @@ -498,7 +515,7 @@ class ScipySampler(Sampler):
name : str
a distribution name, a class from `scipy.stats`, or its alias.
seed : int
random seed for setting up sampler's state (see :func:`~.make_rng`).
random seed for setting up sampler's random number generator (see :func:`~.make_rng`).
**kwargs
additional parameters for specification of the distribution.
For instance, `scale` for name='gamma'.
Expand All @@ -507,16 +524,15 @@ class ScipySampler(Sampler):
----------
name : str
a distribution name (a class from `scipy.stats`).
state : numpy.random.Generator
rng : numpy.random.Generator
a random number generator
distr
a distribution class
"""
def __init__(self, name, seed=None, **kwargs):
super().__init__(name, seed, **kwargs)
super().__init__(name, seed=seed, **kwargs)
name = _get_method_by_alias(name, 'ss')
self.name = name
self.state = make_rng(seed)
self.distr = getattr(ss, self.name)(**kwargs)

def sample(self, size):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rng should be allowed to be passed as an optional argument

Expand All @@ -534,7 +550,7 @@ def sample(self, size):
array of shape (size, Sampler's dimension).
"""
sampler = self.distr.rvs
sample = sampler(size=size, random_state=self.state)
sample = sampler(size=size, random_state=self.rng)
if len(sample.shape) == 1:
sample = sample.reshape(-1, 1)
return sample
Expand All @@ -551,7 +567,7 @@ class HistoSampler(Sampler):
edges : list
list of len=histo_dimension, contains edges of bins along axes.
seed : int
random seed for setting up sampler's state (see :func:`~.make_rng`).
random seed for setting up sampler's random numbers generator (see :func:`~.make_rng`).

Attributes
----------
Expand All @@ -567,7 +583,7 @@ class HistoSampler(Sampler):
Otherwise, edges should be supplied. In this case all bins are empty.
"""
def __init__(self, histo=None, edges=None, seed=None, **kwargs):
super().__init__(histo, edges, seed, **kwargs)
super().__init__(histo, edges, seed=seed, **kwargs)
if histo is not None:
self.bins = histo[0]
self.edges = histo[1]
Expand All @@ -585,8 +601,7 @@ def __init__(self, histo=None, edges=None, seed=None, **kwargs):
self.nonzero_probs_idx = np.asarray(self.probs != 0.0).nonzero()[0]
self.nonzero_probs = self.probs[self.nonzero_probs_idx]

self.state = make_rng(seed)
self.state_sampler = self.state.uniform
self.rng_sampler = self.rng.uniform

def sample(self, size):
""" Sampling method of ``HistoSampler``.
Expand All @@ -604,11 +619,11 @@ def sample(self, size):
array of shape (size, histo dimension).
"""
# Choose bins to use according to non-zero probabilities
bin_nums = self.state.choice(self.nonzero_probs_idx, p=self.nonzero_probs, size=size)
bin_nums = self.rng.choice(self.nonzero_probs_idx, p=self.nonzero_probs, size=size)

# uniformly generate samples from selected boxes
low, high = self.l_all[bin_nums], self.h_all[bin_nums]
return self.state_sampler(low=low, high=high)
return self.rng_sampler(low=low, high=high)

def update(self, points):
""" Update bins of sampler's histogram by throwing in additional points.
Expand Down Expand Up @@ -638,3 +653,32 @@ def cart_prod(*arrs):
"""
grids = np.meshgrid(*arrs, indexing='ij')
return np.stack(grids, axis=-1).reshape(-1, len(arrs))

def mix_samplers_seeds(left, right):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be a staticmethod of the base Sampler class

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are mixing seeds, which is wrong. If I mixed two samplers in the proposed way and generated one random number, it would be the same, irrespective of how many times each of those two samplers was called before creating the mixture.

You need to mix entropies, and there is a well-established (and used in other places in batchflow) way to do so: np.random.SeedSequence

rng1, rng2
state1 = rng1.bit_generator.state['state']['state']
state2 = rng2.bit_generator.state['state']['state']
seed = np.random.SeedSequence([state1, state2])
rng = np.random.default_rng(seed)

While a difference between these two approaches is hard to observe in any realistic example, the latter is the official way to do it.

In either case, the currently proposed way to seed the RNG in the sampler would not work with the batchflow+seismiQB ways to fix the randomization, and the only thing you actually need in order to fix the seed for make_locations(sampler) in seismiQB is the ability to pass a custom rng into the Sampler.sample call.

""" Mix seeds of provided samplers with minimal possible collisions.

- If both seeds are None, returns None.
- If both seeds are fixed numbers of the same type, returns a new number whose binary form contains
odd bits of left sampler's seed at odd positions and even bits of right sampler's seed at even positions.
- If one of seeds is None and the other one is not, raises ValueError.

Defines a mapping from R^2 to R^1 such that its image uniformly covers the whole values range of seeds data type.
Never causes data type overflow on chain call (contrary to addition, multiplication, cantor or szudzik pairing).
Contrary to exclusive disjunction (aka xor), does not produce a zero when two identical seeds are mixed.
"""
if left.seed is None and right.seed is None:
return None

if left.seed is not None and right.seed is not None:
left_nbits = 8 * np.nbytes[type(left.seed)]
left_mask = (2 ** left_nbits - 1) // 3 # create mask of form '01010101' * nbytes

right_nbits = 8 * np.nbytes[type(right.seed)]
right_mask = (2 ** right_nbits - 1) // 3 * 2 # create mask of form '10101010' * nbytes

return left.seed & left_mask | right.seed & right_mask # seed type changes to np.int64

msg = "Samplers seeds can be combined only if both of them are either numbers or None. "\
f"Left sampler {left} has seed {left.seed} and right {right} has seed {right.seed}. "\
"Either fix both seeds or none at all."
raise ValueError(msg)
Loading