-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow seed fix in Sampler
#669
base: master
Are you sure you want to change the base?
Changes from all commits
9e3f4ed
0323e71
c150739
b7503b7
aeb5dc6
9aad366
222db30
39f6316
81d70bd
1821ce4
ff2b716
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,23 +30,32 @@ | |
'c': {'np': 'choice'} | ||
} | ||
|
||
def _get_method_by_alias(alias, module, tf_distributions=None): | ||
def _get_method_by_alias(name, module, tf_distributions=None): | ||
""" Fetch fullname of a randomizer from ``scipy.stats``, ``tensorflow`` or | ||
``numpy`` by its alias or fullname. | ||
``numpy`` by its alias or name. | ||
""" | ||
rnd_submodules = {'np': np.random, | ||
'tf': tf_distributions, | ||
'ss': ss} | ||
# fetch fullname | ||
fullname = ALIASES.get(alias, {module: alias for module in ['np', 'tf', 'ss']}).get(module, None) | ||
if fullname is None: | ||
raise ValueError(f"Distribution {alias} has no implementaion in module {module}") | ||
rnd_modules = {'np': np.random, | ||
'tf': tf_distributions, | ||
'ss': ss} | ||
|
||
if isinstance(module, str): | ||
if name in ALIASES: # pylint: disable=consider-using-get | ||
aliases = ALIASES[name] | ||
|
||
if module in aliases: | ||
name = aliases[module] | ||
else: | ||
msg = f"The mapping of distribution alias '{name}' to its full name "\ | ||
f"for module '{module}' is not defined." | ||
raise ValueError(msg) | ||
|
||
module = rnd_modules[module] | ||
|
||
# check that the randomizer is implemented in corresponding module | ||
if not hasattr(rnd_submodules[module], fullname): | ||
raise ValueError(f"Distribution {fullname} has no implementaion in module {module}") | ||
if not hasattr(module, name): | ||
raise ValueError(f"Distribution {name} has no implementaion in module {module}") | ||
|
||
return fullname | ||
return name | ||
|
||
|
||
def arithmetize(cls): | ||
|
@@ -84,10 +93,13 @@ class Sampler(): | |
weight : float | ||
weight of Sampler self in mixtures. | ||
""" | ||
def __init__(self, *args, **kwargs): | ||
def __init__(self, *args, seed=None, **kwargs): | ||
self.__array_priority__ = 100 | ||
self.weight = 1.0 | ||
|
||
self.seed = seed | ||
self.rng = make_rng(seed) | ||
|
||
# if dim is supplied, redefine sampling method | ||
if 'dim' in kwargs: | ||
# assemble stacked sampler | ||
|
@@ -216,8 +228,9 @@ def truncate(self, high=None, low=None, expr=None, prob=0.5, max_iters=None, sam | |
class OrSampler(Sampler): | ||
""" Class for implementing `|` (mixture) operation on `Sampler`-instances. | ||
""" | ||
def __init__(self, left, right, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
def __init__(self, left, right, *args, seed=None, **kwargs): | ||
seed = mix_samplers_seeds(left, right) if seed is None else seed | ||
super().__init__(*args, seed=seed, **kwargs) | ||
self.bases = [left, right] | ||
|
||
# calculate probs of samplers in mixture | ||
|
@@ -230,21 +243,22 @@ def sample(self, size): | |
defined by weights (`self.weight`-attr) from two samplers invoked (`self.bases`-attr) and | ||
mixes them in one sample of needed size. | ||
""" | ||
up_size = np.random.binomial(size, self.normed[0]) | ||
up_size = self.rng.binomial(size, self.normed[0]) | ||
low_size = size - up_size | ||
|
||
up_sample = self.bases[0].sample(size=up_size) | ||
low_sample = self.bases[1].sample(size=low_size) | ||
sample_points = np.concatenate([up_sample, low_sample]) | ||
sample_points = sample_points[np.random.permutation(size)] | ||
sample_points = sample_points[self.rng.permutation(size)] | ||
|
||
return sample_points | ||
|
||
class AndSampler(Sampler): | ||
""" Class for implementing `&` (coordinates stacking) operation on `Sampler`-instances. | ||
""" | ||
def __init__(self, left, right, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
def __init__(self, left, right, *args, seed=None, **kwargs): | ||
seed = mix_samplers_seeds(left, right) if seed is None else seed | ||
super().__init__(*args, seed=seed, **kwargs) | ||
self.bases = [left, right] | ||
|
||
def sample(self, size): | ||
|
@@ -258,8 +272,9 @@ def sample(self, size): | |
class ApplySampler(Sampler): | ||
""" Class for implementing `apply` (adding transform) operation on `Sampler`-instances. | ||
""" | ||
def __init__(self, sampler, transform, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
def __init__(self, sampler, transform, *args, seed=None, **kwargs): | ||
seed = sampler.seed if seed is None else seed | ||
super().__init__(*args, seed=seed, **kwargs) | ||
self.bases = [sampler] | ||
self.transform = transform | ||
|
||
|
@@ -276,10 +291,12 @@ class TruncateSampler(Sampler): | |
# from the region of interest using this number of iterations, we throw a Warning or ValueError | ||
max_iters = 1e7 | ||
|
||
def __init__(self, sampler, high=None, low=None, expr=None, prob=0.5, max_iters=None, | ||
sample_anyways=False, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
def __init__(self, sampler, *args, high=None, low=None, expr=None, prob=0.5, | ||
max_iters=None, sample_anyways=False, seed=None, **kwargs): | ||
seed = sampler.seed if seed is None else seed | ||
super().__init__(*args, seed=seed, **kwargs) | ||
self.bases = [sampler] | ||
|
||
self.high = high | ||
self.low = low | ||
self.expr = expr | ||
|
@@ -346,8 +363,9 @@ class BaseOperationSampler(Sampler): | |
""" Base class for implementing all arithmetic operations on `Sampler`-instances. | ||
""" | ||
operation = None | ||
def __init__(self, left, right, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
def __init__(self, left, right, *args, seed=None, **kwargs): | ||
seed = mix_samplers_seeds(left, right) if seed is None else seed | ||
super().__init__(*args, seed=seed, **kwargs) | ||
self.bases = [left, right] | ||
|
||
def sample(self, size): | ||
|
@@ -448,7 +466,7 @@ class NumpySampler(Sampler): | |
name : str | ||
a distribution name (a method from `numpy random`) or its alias. | ||
seed : int | ||
random seed for setting up sampler's state (see :func:`~.make_rng`). | ||
random seed for setting up sampler's random numbers generator (see :func:`~.make_rng`). | ||
**kwargs | ||
additional keyword-arguments defining properties of specific | ||
distribution (e.g. ``loc`` for 'normal'). | ||
|
@@ -457,17 +475,16 @@ class NumpySampler(Sampler): | |
---------- | ||
name : str | ||
a distribution name (a method from `numpy random`). | ||
state : numpy.random.Generator | ||
rng : numpy.random.Generator | ||
a random number generator | ||
_params : dict | ||
dict of args for Sampler's distribution. | ||
""" | ||
def __init__(self, name, seed=None, **kwargs): | ||
super().__init__(name, seed, **kwargs) | ||
name = _get_method_by_alias(name, 'np') | ||
super().__init__(name, seed=seed, **kwargs) | ||
name = _get_method_by_alias(name=name, module=self.rng) | ||
self.name = name | ||
self._params = copy(kwargs) | ||
self.state = make_rng(seed) | ||
|
||
|
||
def sample(self, size): | ||
|
@@ -483,7 +500,7 @@ def sample(self, size): | |
np.ndarray | ||
array of shape (size, Sampler's dimension). | ||
""" | ||
sampler = getattr(self.state, self.name) | ||
sampler = getattr(self.rng, self.name) | ||
sample = sampler(size=size, **self._params) | ||
if len(sample.shape) == 1: | ||
sample = sample.reshape(-1, 1) | ||
|
@@ -498,7 +515,7 @@ class ScipySampler(Sampler): | |
name : str | ||
a distribution name, a class from `scipy.stats`, or its alias. | ||
seed : int | ||
random seed for setting up sampler's state (see :func:`~.make_rng`). | ||
random seed for setting up sampler's random number generator (see :func:`~.make_rng`). | ||
**kwargs | ||
additional parameters for specification of the distribution. | ||
For instance, `scale` for name='gamma'. | ||
|
@@ -507,16 +524,15 @@ class ScipySampler(Sampler): | |
---------- | ||
name : str | ||
a distribution name (a class from `scipy.stats`). | ||
state : numpy.random.Generator | ||
rng : numpy.random.Generator | ||
a random number generator | ||
distr | ||
a distribution class | ||
""" | ||
def __init__(self, name, seed=None, **kwargs): | ||
super().__init__(name, seed, **kwargs) | ||
super().__init__(name, seed=seed, **kwargs) | ||
name = _get_method_by_alias(name, 'ss') | ||
self.name = name | ||
self.state = make_rng(seed) | ||
self.distr = getattr(ss, self.name)(**kwargs) | ||
|
||
def sample(self, size): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
@@ -534,7 +550,7 @@ def sample(self, size): | |
array of shape (size, Sampler's dimension). | ||
""" | ||
sampler = self.distr.rvs | ||
sample = sampler(size=size, random_state=self.state) | ||
sample = sampler(size=size, random_state=self.rng) | ||
if len(sample.shape) == 1: | ||
sample = sample.reshape(-1, 1) | ||
return sample | ||
|
@@ -551,7 +567,7 @@ class HistoSampler(Sampler): | |
edges : list | ||
list of len=histo_dimension, contains edges of bins along axes. | ||
seed : int | ||
random seed for setting up sampler's state (see :func:`~.make_rng`). | ||
random seed for setting up sampler's random numbers generator (see :func:`~.make_rng`). | ||
|
||
Attributes | ||
---------- | ||
|
@@ -567,7 +583,7 @@ class HistoSampler(Sampler): | |
Otherwise, edges should be supplied. In this case all bins are empty. | ||
""" | ||
def __init__(self, histo=None, edges=None, seed=None, **kwargs): | ||
super().__init__(histo, edges, seed, **kwargs) | ||
super().__init__(histo, edges, seed=seed, **kwargs) | ||
if histo is not None: | ||
self.bins = histo[0] | ||
self.edges = histo[1] | ||
|
@@ -585,8 +601,7 @@ def __init__(self, histo=None, edges=None, seed=None, **kwargs): | |
self.nonzero_probs_idx = np.asarray(self.probs != 0.0).nonzero()[0] | ||
self.nonzero_probs = self.probs[self.nonzero_probs_idx] | ||
|
||
self.state = make_rng(seed) | ||
self.state_sampler = self.state.uniform | ||
self.rng_sampler = self.rng.uniform | ||
|
||
def sample(self, size): | ||
""" Sampling method of ``HistoSampler``. | ||
|
@@ -604,11 +619,11 @@ def sample(self, size): | |
array of shape (size, histo dimension). | ||
""" | ||
# Choose bins to use according to non-zero probabilities | ||
bin_nums = self.state.choice(self.nonzero_probs_idx, p=self.nonzero_probs, size=size) | ||
bin_nums = self.rng.choice(self.nonzero_probs_idx, p=self.nonzero_probs, size=size) | ||
|
||
# uniformly generate samples from selected boxes | ||
low, high = self.l_all[bin_nums], self.h_all[bin_nums] | ||
return self.state_sampler(low=low, high=high) | ||
return self.rng_sampler(low=low, high=high) | ||
|
||
def update(self, points): | ||
""" Update bins of sampler's histogram by throwing in additional points. | ||
|
@@ -638,3 +653,32 @@ def cart_prod(*arrs): | |
""" | ||
grids = np.meshgrid(*arrs, indexing='ij') | ||
return np.stack(grids, axis=-1).reshape(-1, len(arrs)) | ||
|
||
def mix_samplers_seeds(left, right): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Should be a staticmethod of the base class.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
You are mixing seeds, which is wrong. If I were to mix two samplers in the proposed way and generate one random number, it would be the same irrespective of how many times each of those two samplers had been called before creating the mixture. You need to mix entropies, and there is a well-established way to do so (also used in other places in the library). Given two generators rng1, rng2:
state1 = rng1.bit_generator.state['state']['state']
state2 = rng2.bit_generator.state['state']['state']
seed = np.random.SeedSequence([state1, state2])
rng = np.random.default_rng(seed)
While the difference between these two approaches is hard to observe in any realistic example, the latter is the official way to do so. In either case, the currently proposed way to seed the RNG in the sampler would not work with |
||
""" Mix seeds of provided samplers with minimal possible collisions. | ||
|
||
- If both seed are None, returns None. | ||
- If both seeds are fixed numbers of the same type, returns a new number which binary form contains | ||
odd bits of left sampler's seed at odd positions and even bits of right sampler's seed at even position. | ||
- If one of seeds is None and the other one is not, raises ValueError. | ||
|
||
Defines a mapping from R^2 to R^1 such as its image uniformly covers the whole values range of seeds data type. | ||
Never causes data type overflow on chain call (contrary to addition, multiplication, cantor or szudzik pairing). | ||
Contrary to exclusive disjunction (aka xor), does not produce a zero when two identical seeds are mixed. | ||
""" | ||
if left.seed is None and right.seed is None: | ||
return None | ||
|
||
if left.seed is not None and right.seed is not None: | ||
left_nbits = 8 * np.nbytes[type(left.seed)] | ||
left_mask = (2 ** left_nbits - 1) // 3 # create mask of form '01010101' * nbytes | ||
|
||
right_nbits = 8 * np.nbytes[type(right.seed)] | ||
right_mask = (2 ** right_nbits - 1) // 3 * 2 # create mask of form '10101010' * nbytes | ||
|
||
return left.seed & left_mask | right.seed & right_mask # seed type changes to np.int64 | ||
|
||
msg = "Samplers seeds can be combined only if both of them are either numbers or None. "\ | ||
f"Left sampler {left} has seed {left.seed} and right {right} has seed {right.seed}. "\ | ||
"Either fix both seeds or none at all." | ||
raise ValueError(msg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a chain of `random_state` variables and their SeedSequences in `batchflow`:
And `Research` is even one level above that. Have you read this entire chain of properties before changing it?