diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 4d65f86c0..3f1e72305 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -5,6 +5,8 @@ Changelog
---
- Added the general Gaussian noise example model (fixed covariance)
+- Improved the interactive plotting (customised for the MaxVar-based acquisition methods)
+- Added a pair-wise plotting to plot_state() (a way to visualise n-dimensional parameters)
0.6.2 (2017-09-06)
------------------
diff --git a/elfi/methods/bo/acquisition.py b/elfi/methods/bo/acquisition.py
index 3a29db620..23c7ae987 100644
--- a/elfi/methods/bo/acquisition.py
+++ b/elfi/methods/bo/acquisition.py
@@ -201,6 +201,8 @@ def __init__(self, *args, delta=None, **kwargs):
kwargs['exploration_rate'] = 1 / delta
super(LCBSC, self).__init__(*args, **kwargs)
+ self.name = 'lcbsc'
+ self.label_fn = 'The Lower Confidence Bound Selection Criterion'
@property
def delta(self):
diff --git a/elfi/methods/bo/gpy_regression.py b/elfi/methods/bo/gpy_regression.py
index 19b0a5fe8..8999a45f0 100644
--- a/elfi/methods/bo/gpy_regression.py
+++ b/elfi/methods/bo/gpy_regression.py
@@ -338,6 +338,16 @@ def Y(self):
"""Return output evidence."""
return self._gp.Y
+ @property
+ def noise(self):
+ """Return the noise."""
+ return self._gp.Gaussian_noise.variance[0]
+
+ @property
+ def instance(self):
+ """Return the gp instance."""
+ return self._gp
+
def copy(self):
"""Return a copy of current instance."""
kopy = copy.copy(self)
diff --git a/elfi/methods/parameter_inference.py b/elfi/methods/parameter_inference.py
index 8c70a233d..e168f4540 100644
--- a/elfi/methods/parameter_inference.py
+++ b/elfi/methods/parameter_inference.py
@@ -3,9 +3,9 @@
__all__ = ['Rejection', 'SMC', 'BayesianOptimization', 'BOLFI']
import logging
+from collections import OrderedDict
from math import ceil
-import matplotlib.pyplot as plt
import numpy as np
import elfi.client
@@ -89,7 +89,6 @@ def __init__(self,
model = model.model if isinstance(model, NodeReference) else model
if not model.parameter_names:
raise ValueError('Model {} defines no parameters'.format(model))
-
self.model = model.copy()
self.output_names = self._check_outputs(output_names)
@@ -161,7 +160,7 @@ def extract_result(self):
"""
raise NotImplementedError
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
ELFI calls this method when a new batch has been computed and the state of
@@ -174,10 +173,8 @@ def update(self, batch, batch_index):
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
batch_index : int
-
- Returns
- -------
- None
+ vis : bool, optional
+ Interactive visualisation of the iterations.
"""
self.state['n_batches'] += 1
@@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
"""
raise NotImplementedError
- def infer(self, *args, vis=None, **kwargs):
+ def infer(self, *args, **opts):
"""Set the objective and start the iterate loop until the inference is finished.
See the other arguments from the `set_objective` method.
@@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs):
result : Sample
"""
- vis_opt = vis if isinstance(vis, dict) else {}
-
- self.set_objective(*args, **kwargs)
-
+ vis = opts.pop('vis', None)
+ self.set_objective(*args, **opts)
while not self.finished:
- self.iterate()
- if vis:
- self.plot_state(interactive=True, **vis_opt)
-
+ self.iterate(vis=vis)
self.batches.cancel_pending()
- if vis:
- self.plot_state(close=True, **vis_opt)
return self.extract_result()
- def iterate(self):
- """Advance the inference by one iteration.
+ def iterate(self, vis=None):
+ """Forward the inference one iteration.
This is a way to manually progress the inference. One iteration consists of
waiting and processing the result of the next batch in succession and possibly
@@ -272,6 +262,11 @@ def iterate(self):
will never be more batches submitted in parallel than the `max_parallel_batches`
setting allows.
+ Parameters
+ ----------
+ vis : bool, optional
+ Interactive visualisation of the iterations.
+
Returns
-------
None
@@ -286,7 +281,7 @@ def iterate(self):
# Handle the next ready batch in succession
batch, batch_index = self.batches.wait_next()
logger.debug('Received batch %d' % batch_index)
- self.update(batch, batch_index)
+ self.update(batch, batch_index, vis=vis)
@property
def finished(self):
@@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
# Reset the inference
self.batches.reset()
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
Parameters
----------
batch : dict
- dict with `self.outputs` as keys and the corresponding outputs for the batch
- as values
+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
+ vis : bool, optional
+ Interactive visualisation of the iterations.
batch_index : int
"""
+ if vis and self.state['samples'] is not None:
+ self.plot_state(interactive=True, **vis)
+
super(Rejection, self).update(batch, batch_index)
if self.state['samples'] is None:
# Lazy initialization of the outputs dict
@@ -584,8 +583,8 @@ def plot_state(self, **options):
displays = []
if options.get('interactive'):
from IPython import display
- displays.append(
- display.HTML('Threshold: {}'.format(self.state['threshold'])))
+ html_display = 'Threshold: {}'.format(self.state['threshold'])
+ displays.append(display.HTML(html_display))
visin.plot_sample(
self.state['samples'],
@@ -651,14 +650,15 @@ def extract_result(self):
threshold=pop.threshold,
**self._extract_result_kwargs())
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
Parameters
----------
batch : dict
- dict with `self.outputs` as keys and the corresponding outputs for the batch
- as values
+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
+ vis : bool, optional
+ Interactive visualisation of the iterations.
batch_index : int
"""
@@ -833,7 +833,6 @@ def __init__(self,
output_names = [target_name] + model.parameter_names
super(BayesianOptimization, self).__init__(
model, output_names, batch_size=batch_size, **kwargs)
-
target_model = target_model or \
GPyRegression(self.model.parameter_names, bounds=bounds)
@@ -942,7 +941,7 @@ def extract_result(self):
return OptimizationResult(
x_min=batch_min, outputs=outputs, **self._extract_result_kwargs())
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the GP regression model of the target node with a new batch.
Parameters
@@ -950,6 +949,8 @@ def update(self, batch, batch_index):
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
+ vis : bool, optional
+ Interactive visualisation of the iterations.
batch_index : int
"""
@@ -958,12 +959,22 @@ def update(self, batch, batch_index):
params = batch_to_arr2d(batch, self.parameter_names)
self._report_batch(batch_index, params, batch[self.target_name])
+ # Adding the acquisition plots.
+ if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
+ opts = {}
+ opts['point_acq'] = {'x': params, 'd': batch[self.target_name]}
+ opts['method_acq'] = self.acquisition_method.label_fn
+ arr_ax = self.plot_state(interactive=True, **opts)
optimize = self._should_optimize()
self.target_model.update(params, batch[self.target_name], optimize)
if optimize:
self.state['last_GP_update'] = self.target_model.n_evidence
+ # Adding the updated gp plots.
+ if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
+ self.plot_state(interactive=True, arr_ax=arr_ax, **opts)
+
def prepare_new_batch(self, batch_index):
"""Prepare values for a new batch.
@@ -980,7 +991,6 @@ def prepare_new_batch(self, batch_index):
"""
t = self._get_acquisition_index(batch_index)
-
# Check if we still should take initial points from the prior
if t < 0:
return
@@ -1040,60 +1050,40 @@ def _report_batch(self, batch_index, params, distances):
str += "{}{} at {}\n".format(fill, distances[i].item(), params[i])
logger.debug(str)
- def plot_state(self, **options):
- """Plot the GP surface.
-
- This feature is still experimental and currently supports only 2D cases.
- """
- f = plt.gcf()
- if len(f.axes) < 2:
- f, _ = plt.subplots(1, 2, figsize=(13, 6), sharex='row', sharey='row')
-
- gp = self.target_model
-
- # Draw the GP surface
- visin.draw_contour(
- gp.predict_mean,
- gp.bounds,
- self.parameter_names,
- title='GP target surface',
- points=gp.X,
- axes=f.axes[0],
- **options)
-
- # Draw the latest acquisitions
- if options.get('interactive'):
- point = gp.X[-1, :]
- if len(gp.X) > 1:
- f.axes[1].scatter(*point, color='red')
+ def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **opts):
+ """Plot the GP surface and the acquisition space.
- displays = [gp._gp]
+ Notes
+ -----
+ - The plots of the GP surface and the acquisition space work for the
+ cases when dim < 3;
+ - The method is experimental.
- if options.get('interactive'):
- from IPython import display
- displays.insert(
- 0,
- display.HTML('Iteration {}: Acquired {} at {}'.format(
- len(gp.Y), gp.Y[-1][0], point)))
-
- # Update
- visin._update_interactive(displays, options)
-
- def acq(x):
- return self.acquisition_method.evaluate(x, len(gp.X))
-
- # Draw the acquisition surface
- visin.draw_contour(
- acq,
- gp.bounds,
- self.parameter_names,
- title='Acquisition surface',
- points=None,
- axes=f.axes[1],
- **options)
+ Parameters
+ ----------
+ plot_acq_pairwise : bool, optional
+ The option to plot the pair-wise acquisition point relationships.
- if options.get('close'):
- plt.close()
+ """
+ if plot_acq_pairwise:
+ if len(self.parameter_names) == 1:
+ logger.info('Can not plot the pair-wise comparison for 1 parameter.')
+ return
+ # Transform the acquisition points in the acceptable format.
+ dict_pts_acq = OrderedDict()
+ for idx_param, name_param in enumerate(self.parameter_names):
+ dict_pts_acq[name_param] = self.target_model.X[:, idx_param]
+ vis.plot_pairs(dict_pts_acq, **opts)
+ else:
+ if len(self.parameter_names) == 1:
+ arr_ax = vis.plot_state_1d(self, arr_ax, **opts)
+ return arr_ax
+ elif len(self.parameter_names) == 2:
+ arr_ax = vis.plot_state_2d(self, arr_ax, **opts)
+ return arr_ax
+ else:
+ logger.info('The method is supported for 1- or 2-dimensions.')
+ return
def plot_discrepancy(self, axes=None, **kwargs):
"""Plot acquired parameters vs. resulting discrepancy.
@@ -1133,7 +1123,7 @@ class BOLFI(BayesianOptimization):
"""
- def fit(self, n_evidence, threshold=None):
+ def fit(self, n_evidence, threshold=None, **opts):
"""Fit the surrogate model.
Generates a regression model for the discrepancy given the parameters.
@@ -1150,9 +1140,8 @@ def fit(self, n_evidence, threshold=None):
if n_evidence is None:
raise ValueError(
- 'You must specify the number of evidence (n_evidence) for the fitting')
-
- self.infer(n_evidence)
+ 'You must specify the number of evidence( n_evidence) for the fitting')
+ self.infer(n_evidence, **opts)
return self.extract_posterior(threshold)
def extract_posterior(self, threshold=None):
@@ -1235,12 +1224,10 @@ def sample(self,
else:
inds = np.argsort(self.target_model.Y[:, 0])
initials = np.asarray(self.target_model.X[inds])
-
self.target_model.is_sampling = True # enables caching for default RBF kernel
tasks_ids = []
ii_initial = 0
-
# sampling is embarrassingly parallel, so depending on self.client this may parallelize
for ii in range(n_chains):
seed = get_sub_seed(self.seed, ii)
@@ -1270,12 +1257,12 @@ def sample(self,
chains = np.asarray(chains)
- print(
- "{} chains of {} iterations acquired. Effective sample size and Rhat for each "
- "parameter:".format(n_chains, n_samples))
+ logger.info(
+ "%d chains of %d iterations acquired. Effective sample size and Rhat for each "
+ "parameter:" % (n_chains, n_samples))
for ii, node in enumerate(self.parameter_names):
- print(node, mcmc.eff_sample_size(chains[:, :, ii]),
- mcmc.gelman_rubin(chains[:, :, ii]))
+ chain = chains[:, :, ii]
+ logger.info("%s %d %d" % (node, mcmc.eff_sample_size(chain), mcmc.gelman_rubin(chain)))
self.target_model.is_sampling = False
diff --git a/elfi/visualization/interactive.py b/elfi/visualization/interactive.py
index 2b28df105..ceb1af51e 100644
--- a/elfi/visualization/interactive.py
+++ b/elfi/visualization/interactive.py
@@ -4,14 +4,17 @@
import matplotlib.pyplot as plt
import numpy as np
+import scipy.interpolate
logger = logging.getLogger(__name__)
def plot_sample(samples, nodes=None, n=-1, displays=None, **options):
- """Plot a scatterplot of samples.
+ """Plot a scatter-plot of samples.
- Experimental, only dims 1-2 supported.
+ Notes
+ -----
+ - Experimental, only dims 1-2 supported.
Parameters
----------
@@ -23,7 +26,8 @@ def plot_sample(samples, nodes=None, n=-1, displays=None, **options):
"""
axes = _prepare_axes(options)
-
+ if samples is None:
+ return
nodes = nodes or sorted(samples.keys())[:2]
if isinstance(nodes, str):
nodes = [nodes]
@@ -39,9 +43,8 @@ def plot_sample(samples, nodes=None, n=-1, displays=None, **options):
axes.set_ylabel(nodes[1])
axes.scatter(samples[nodes[0]][:n], samples[nodes[1]][:n])
- _update_interactive(displays, options)
-
- if options.get('close'):
+ if options.get('interactive'):
+ update_interactive(displays, options)
plt.close()
@@ -52,7 +55,8 @@ def get_axes(**options):
return plt.gca()
-def _update_interactive(displays, options):
+def update_interactive(displays, options):
+ """Update the interactive plot."""
displays = displays or []
if options.get('interactive'):
from IPython import display
@@ -67,7 +71,6 @@ def _prepare_axes(options):
if ion:
axes.clear()
-
if options.get('xlim'):
axes.set_xlim(options.get('xlim'))
if options.get('ylim'):
@@ -76,7 +79,7 @@ def _prepare_axes(options):
return axes
-def draw_contour(fn, bounds, nodes=None, points=None, title=None, **options):
+def draw_contour(fn, bounds, params=None, points=None, title=None, label=None, **options):
"""Plot a contour of a function.
Experimental, only 2D supported.
@@ -92,29 +95,33 @@ def draw_contour(fn, bounds, nodes=None, points=None, title=None, **options):
title : str, optional
"""
- ax = get_axes(**options)
-
+ # Preparing the contour plot settings.
+ if options.get('axes'):
+ axes = options['axes']
+ plt.sca(axes)
x, y = np.meshgrid(np.linspace(*bounds[0]), np.linspace(*bounds[1]))
z = fn(np.c_[x.reshape(-1), y.reshape(-1)])
+ # Plotting the contour.
+ CS = plt.contourf(x, y, z.reshape(x.shape), 25)
+ CB = plt.colorbar(CS, orientation='horizontal', format='%.1e')
+ CB.set_label(label)
+ rbf = scipy.interpolate.Rbf(x, y, z, function='linear')
+ zi = rbf(x, y)
+ plt.imshow(zi,
+ vmin=z.min(),
+ vmax=z.max(),
+ origin='lower',
+ extent=[x.min(), x.max(), y.min(), y.max()])
+
+ # Adding the acquisition points.
+ if points is not None:
+ plt.scatter(points[:, 0], points[:, 1], color='k')
- if ax:
- plt.sca(ax)
- plt.cla()
-
+ # Adding the labels.
if title:
plt.title(title)
- try:
- plt.contour(x, y, z.reshape(x.shape))
- except ValueError:
- logger.warning('Could not draw a contour plot')
- if points is not None:
- plt.scatter(points[:-1, 0], points[:-1, 1])
- if options.get('interactive'):
- plt.scatter(points[-1, 0], points[-1, 1], color='r')
-
plt.xlim(bounds[0])
plt.ylim(bounds[1])
-
- if nodes:
- plt.xlabel(nodes[0])
- plt.ylabel(nodes[1])
+ if params:
+ plt.xlabel(params[0])
+ plt.ylabel(params[1])
diff --git a/elfi/visualization/visualization.py b/elfi/visualization/visualization.py
index d3736e5a4..9ba7d3a3e 100644
--- a/elfi/visualization/visualization.py
+++ b/elfi/visualization/visualization.py
@@ -5,6 +5,7 @@
import matplotlib.pyplot as plt
import numpy as np
+import elfi.visualization.interactive as visin
from elfi.model.elfi_model import Constant, ElfiModel, NodeReference
@@ -99,6 +100,7 @@ def _create_axes(axes, shape, **kwargs):
else:
fig, axes = plt.subplots(ncols=shape[1], nrows=shape[0], **fig_kwargs)
axes = np.atleast_1d(axes)
+ fig.tight_layout(pad=2.0)
return axes, kwargs
@@ -156,12 +158,157 @@ def plot_marginals(samples, selector=None, bins=20, axes=None, **kwargs):
return axes
-def plot_pairs(samples, selector=None, bins=20, axes=None, **kwargs):
- """Plot pairwise relationships as a matrix with marginals on the diagonal.
+def plot_state_1d(model_bo, arr_ax=None, **options):
+ """Plot the GP surface and the acquisition function in 1D.
- The y-axis of marginal histograms are scaled.
+ Notes
+ -----
+ The method is experimental.
+
+ Parameters
+ ----------
+ model_bo : elfi.methods.parameter_inference.BOLFI
+
+ """
+ gp = model_bo.target_model
+ pts_eval = np.linspace(*gp.bounds[0])
+
+ if arr_ax is None:
+ fig, arr_ax = plt.subplots(nrows=1,
+ ncols=2,
+ figsize=(12, 4),
+ sharex=True)
+ plt.ticklabel_format(style='sci', axis='y', scilimits=(-3, 4))
+ fig.tight_layout(pad=2.0)
+
+ # Plotting the acquisition space and the recent acquisition.
+ arr_ax[1].set_title('Acquisition surface')
+ arr_ax[1].set_xlabel(model_bo.parameter_names[0])
+ arr_ax[1].set_ylabel(options.pop('method_acq'))
+ score_acq = model_bo.acquisition_method.evaluate(pts_eval)
+ arr_ax[1].plot(pts_eval,
+ score_acq,
+ color='k',
+ label='acquisition function')
+ # Plotting the confidence interval and the mean.
+ mean, var = gp.predict(pts_eval, noiseless=False)
+ sigma = np.sqrt(var)
+ z_95 = 1.96
+ lb_ci = mean - z_95 * (sigma)
+ ub_ci = mean + z_95 * (sigma)
+ arr_ax[0].fill(np.concatenate([pts_eval, pts_eval[::-1]]),
+ np.concatenate([lb_ci, ub_ci[::-1]]),
+ alpha=.1,
+ fc='k',
+ ec='None',
+ label='95% confidence interval')
+ arr_ax[0].plot(pts_eval, mean, color='k', label='mean')
+ # Plotting the acquisition threshold.
+ if model_bo.acquisition_method.name in ['max_var', 'rand_max_var', 'exp_int_var']:
+ thresh_acq = np.repeat(model_bo.acquisition_method.eps,
+ len(pts_eval))
+ arr_ax[0].plot(pts_eval,
+ thresh_acq,
+ color='g',
+ label='acquisition threshold')
+ # Plotting the acquired points.
+ arr_ax[0].scatter(gp.X, gp.Y, color='k')
+
+ arr_ax[0].legend(loc='upper right')
+ arr_ax[0].set_title('GP target surface')
+ arr_ax[0].set_xlabel(model_bo.parameter_names[0])
+ arr_ax[0].set_ylabel('Discrepancy')
+
+ return arr_ax
+ else:
+ if options.get('interactive'):
+ from IPython import display
+ pt_last = options.pop('point_acq')
+ arr_ax[0].scatter(pt_last['x'], pt_last['d'], color='r')
+ ymin, ymax = arr_ax[1].get_ylim()
+ arr_ax[1].vlines(x=pt_last['x'], ymin=ymin, ymax=ymax,
+ color='r', linestyle='--',
+ label='latest acquisition')
+
+ arr_ax[1].legend(loc='upper right')
+ displays = []
+ displays.append(gp.instance)
+ n_it = int(len(gp.Y) / model_bo.batch_size)
+ html_disp = 'Iteration {}: Acquired {} at {}' \
+ .format(n_it, pt_last['d'], pt_last['x'])
+ displays.append(display.HTML(html_disp))
+ visin.update_interactive(displays, options=options)
+
+ plt.close()
+
+
+def plot_state_2d(model_bo, arr_ax=None, pre=False, post=False, **options):
+ """Plot the GP surface and the acquisition function in 2D.
+
+ Notes
+ -----
+ The method is experimental.
+
+ Parameters
+ ----------
+ model_bo : elfi.methods.parameter_inference.BOLFI
- Parameters
+ """
+ gp = model_bo.target_model
+
+ if arr_ax is None:
+ # Defining the plotting settings.
+ _, arr_ax = plt.subplots(nrows=1,
+ ncols=2,
+ figsize=(16, 10),
+ sharex='row',
+ sharey='row')
+
+ # Plotting the acquisition space and the recent acquisition.
+ def fn_acq(x):
+ return model_bo.acquisition_method.evaluate(x, len(gp.X))
+ visin.draw_contour(fn_acq,
+ gp.bounds,
+ model_bo.parameter_names,
+ title='Acquisition surface',
+ axes=arr_ax[1],
+ label=options.pop('method_acq'),
+ **options)
+ # Plotting the GP target surface and the acquired points.
+ visin.draw_contour(gp.predict_mean,
+ gp.bounds,
+ model_bo.parameter_names,
+ title='GP target surface',
+ points=gp.X,
+ axes=arr_ax[0],
+ label='Discrepancy',
+ **options)
+ return arr_ax
+ else:
+ if options.get('interactive'):
+ from IPython import display
+ pt_last = options.pop('point_acq')
+ arr_ax[0].scatter(pt_last['x'][:, 0], pt_last['x'][:, 1], color='r')
+ arr_ax[1].scatter(pt_last['x'][:, 0], pt_last['x'][:, 1], color='r')
+
+ displays = []
+ displays.append(gp.instance)
+ n_it = int(len(gp.Y) / model_bo.batch_size)
+ html_disp = 'Iteration {}: Acquired {} at {}' \
+ .format(n_it, pt_last['d'], pt_last['x'])
+ displays.append(display.HTML(html_disp))
+ visin.update_interactive(displays, options=options)
+ plt.close()
+
+
+def plot_pairs(data, selector=None, bins=20, axes=None, **kwargs):
+ """Plot pair-wise relationships in a grid with marginals on the diagonal.
+
+ Notes
+ -----
+ Removed: The y-axis of marginal histograms are scaled.
+
+ Parameters
----------
samples : OrderedDict of np.arrays
selector : iterable of ints or strings, optional
@@ -175,31 +322,38 @@ def plot_pairs(samples, selector=None, bins=20, axes=None, **kwargs):
axes : np.array of plt.Axes
"""
- samples = _limit_params(samples, selector)
- shape = (len(samples), len(samples))
+ # Pop the target kwargs.
edgecolor = kwargs.pop('edgecolor', 'none')
- dot_size = kwargs.pop('s', 2)
- kwargs['sharex'] = kwargs.get('sharex', 'col')
- kwargs['sharey'] = kwargs.get('sharey', 'row')
- axes, kwargs = _create_axes(axes, shape, **kwargs)
-
- for i1, k1 in enumerate(samples):
- min_samples = samples[k1].min()
- max_samples = samples[k1].max()
- for i2, k2 in enumerate(samples):
- if i1 == i2:
- # create a histogram with scaled y-axis
- hist, bin_edges = np.histogram(samples[k1], bins=bins)
- bar_width = bin_edges[1] - bin_edges[0]
- hist = (hist - hist.min()) * (max_samples - min_samples) / (
- hist.max() - hist.min())
- axes[i1, i2].bar(bin_edges[:-1], hist, bar_width, bottom=min_samples, **kwargs)
+ dot_size = kwargs.pop('s', 25)
+
+ # Filter the data.
+ data_selected = _limit_params(data, selector)
+
+ # Initialise the figure.
+ shape_fig = (len(data_selected), len(data_selected))
+ axes, kwargs = _create_axes(axes, shape_fig, **kwargs)
+
+ # Populate the grid of figures.
+ for idx_row, key_row in enumerate(data_selected):
+ for idx_col, key_col in enumerate(data_selected):
+ if idx_row == idx_col:
+ # Plot the marginals.
+ axes[idx_row, idx_col].hist(data_selected[key_row], bins=bins, **kwargs)
+ axes[idx_row, idx_col].set_xlabel(key_row)
+ # Experimental: Calculate the bin length.
+ x_min, x_max = axes[idx_row, idx_col].get_xlim()
+ length_bin = (x_max - x_min) / bins
+ axes[idx_row, idx_col].set_ylabel(
+ 'Count (bin length: {0:.2f})'.format(length_bin))
else:
- axes[i1, i2].scatter(
- samples[k2], samples[k1], s=dot_size, edgecolor=edgecolor, **kwargs)
-
- axes[i1, 0].set_ylabel(k1)
- axes[-1, i1].set_xlabel(k1)
+ # Plot the pairs.
+ axes[idx_row, idx_col].scatter(data_selected[key_row],
+ data_selected[key_col],
+ edgecolor=edgecolor,
+ s=dot_size,
+ **kwargs)
+ axes[idx_row, idx_col].set_xlabel(key_row)
+ axes[idx_row, idx_col].set_ylabel(key_col)
return axes
diff --git a/tests/unit/test_document_examples.py b/tests/unit/test_document_examples.py
index f873bb04b..7fabef3af 100644
--- a/tests/unit/test_document_examples.py
+++ b/tests/unit/test_document_examples.py
@@ -31,7 +31,7 @@ def __init__(self, model, discrepancy_name, threshold, **kwargs):
def set_objective(self, n_sim):
self.objective['n_sim'] = n_sim
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
super(CustomMethod, self).update(batch, batch_index)
# Make a filter mask (logical numpy array) from the distance array