From 6402a49be571fe0c1112a36377a810f3dea0ec08 Mon Sep 17 00:00:00 2001
From: perdaug <2019828p@student.gla.ac.uk>
Date: Thu, 7 Sep 2017 13:24:22 +0300
Subject: [PATCH] [MaxVar split, Part 2] Added the visualisation improvements.
---
CHANGELOG.rst | 5 +
elfi/methods/bo/acquisition.py | 2 +
elfi/methods/parameter_inference.py | 169 +++++++++----------
elfi/visualization/interactive.py | 76 +++++----
elfi/visualization/visualization.py | 234 +++++++++++++++++++++++----
tests/unit/test_document_examples.py | 2 +-
6 files changed, 343 insertions(+), 145 deletions(-)
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index e3fc4d70..2a596fed 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,11 @@
Changelog
=========
+0.x
+---
+- Improved the interactive plotting (customised for the MaxVar-based acquisition methods)
+- Added a pair-wise plotting to plot_state() (a way to visualise n-dimensional parameters)
+
0.6.3 (2017-09-28)
------------------
diff --git a/elfi/methods/bo/acquisition.py b/elfi/methods/bo/acquisition.py
index 3a29db62..23c7ae98 100644
--- a/elfi/methods/bo/acquisition.py
+++ b/elfi/methods/bo/acquisition.py
@@ -201,6 +201,8 @@ def __init__(self, *args, delta=None, **kwargs):
kwargs['exploration_rate'] = 1 / delta
super(LCBSC, self).__init__(*args, **kwargs)
+ self.name = 'lcbsc'
+ self.label_fn = 'The Lower Confidence Bound Selection Criterion'
@property
def delta(self):
diff --git a/elfi/methods/parameter_inference.py b/elfi/methods/parameter_inference.py
index 8c70a233..5845d5b3 100644
--- a/elfi/methods/parameter_inference.py
+++ b/elfi/methods/parameter_inference.py
@@ -3,9 +3,9 @@
__all__ = ['Rejection', 'SMC', 'BayesianOptimization', 'BOLFI']
import logging
+from collections import OrderedDict
from math import ceil
-import matplotlib.pyplot as plt
import numpy as np
import elfi.client
@@ -89,7 +89,6 @@ def __init__(self,
model = model.model if isinstance(model, NodeReference) else model
if not model.parameter_names:
raise ValueError('Model {} defines no parameters'.format(model))
-
self.model = model.copy()
self.output_names = self._check_outputs(output_names)
@@ -161,7 +160,7 @@ def extract_result(self):
"""
raise NotImplementedError
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
ELFI calls this method when a new batch has been computed and the state of
@@ -174,10 +173,8 @@ def update(self, batch, batch_index):
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
batch_index : int
-
- Returns
- -------
- None
+ vis : bool, optional
+ Interactive visualisation of the iterations.
"""
self.state['n_batches'] += 1
@@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
"""
raise NotImplementedError
- def infer(self, *args, vis=None, **kwargs):
+ def infer(self, *args, **options):
"""Set the objective and start the iterate loop until the inference is finished.
See the other arguments from the `set_objective` method.
@@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs):
result : Sample
"""
- vis_opt = vis if isinstance(vis, dict) else {}
-
- self.set_objective(*args, **kwargs)
-
+ vis = options.pop('vis', None)
+ self.set_objective(*args, **options)
while not self.finished:
- self.iterate()
- if vis:
- self.plot_state(interactive=True, **vis_opt)
-
+ self.iterate(vis=vis)
self.batches.cancel_pending()
- if vis:
- self.plot_state(close=True, **vis_opt)
return self.extract_result()
- def iterate(self):
- """Advance the inference by one iteration.
+ def iterate(self, vis=None):
+ """Forward the inference one iteration.
This is a way to manually progress the inference. One iteration consists of
waiting and processing the result of the next batch in succession and possibly
@@ -272,6 +262,11 @@ def iterate(self):
will never be more batches submitted in parallel than the `max_parallel_batches`
setting allows.
+ Parameters
+ ----------
+ vis : bool, optional
+ Interactive visualisation of the iterations.
+
Returns
-------
None
@@ -286,7 +281,7 @@ def iterate(self):
# Handle the next ready batch in succession
batch, batch_index = self.batches.wait_next()
logger.debug('Received batch %d' % batch_index)
- self.update(batch, batch_index)
+ self.update(batch, batch_index, vis=vis)
@property
def finished(self):
@@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
# Reset the inference
self.batches.reset()
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
Parameters
----------
batch : dict
- dict with `self.outputs` as keys and the corresponding outputs for the batch
- as values
+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
+ vis : bool, optional
+ Interactive visualisation of the iterations.
batch_index : int
"""
+ if vis and self.state['samples'] is not None:
+ self.plot_state(interactive=True, **vis)
+
super(Rejection, self).update(batch, batch_index)
if self.state['samples'] is None:
# Lazy initialization of the outputs dict
@@ -584,8 +583,8 @@ def plot_state(self, **options):
displays = []
if options.get('interactive'):
from IPython import display
- displays.append(
- display.HTML('Threshold: {}'.format(self.state['threshold'])))
+ html_display = 'Threshold: {}'.format(self.state['threshold'])
+ displays.append(display.HTML(html_display))
visin.plot_sample(
self.state['samples'],
@@ -651,14 +650,15 @@ def extract_result(self):
threshold=pop.threshold,
**self._extract_result_kwargs())
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
Parameters
----------
batch : dict
- dict with `self.outputs` as keys and the corresponding outputs for the batch
- as values
+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
+ vis : bool, optional
+ Interactive visualisation of the iterations.
batch_index : int
"""
@@ -942,7 +942,7 @@ def extract_result(self):
return OptimizationResult(
x_min=batch_min, outputs=outputs, **self._extract_result_kwargs())
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
"""Update the GP regression model of the target node with a new batch.
Parameters
@@ -950,6 +950,8 @@ def update(self, batch, batch_index):
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
+ vis : bool, optional
+ Interactive visualisation of the iterations.
batch_index : int
"""
@@ -959,11 +961,22 @@ def update(self, batch, batch_index):
params = batch_to_arr2d(batch, self.parameter_names)
self._report_batch(batch_index, params, batch[self.target_name])
+ # Adding the acquisition plots.
+ if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
+ options = {}
+ options['point_acq'] = {'x': params, 'd': batch[self.target_name]}
+ options['method_acq'] = self.acquisition_method.label_fn
+ arr_ax = self.plot_state(interactive=True, **options)
+
optimize = self._should_optimize()
self.target_model.update(params, batch[self.target_name], optimize)
if optimize:
self.state['last_GP_update'] = self.target_model.n_evidence
+ # Adding the updated gp plots.
+ if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
+ self.plot_state(interactive=True, arr_ax=arr_ax, **options)
+
def prepare_new_batch(self, batch_index):
"""Prepare values for a new batch.
@@ -1040,60 +1053,51 @@ def _report_batch(self, batch_index, params, distances):
str += "{}{} at {}\n".format(fill, distances[i].item(), params[i])
logger.debug(str)
- def plot_state(self, **options):
- """Plot the GP surface.
+ def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **options):
+ """Plot the GP surface and the acquisition space.
- This feature is still experimental and currently supports only 2D cases.
- """
- f = plt.gcf()
- if len(f.axes) < 2:
- f, _ = plt.subplots(1, 2, figsize=(13, 6), sharex='row', sharey='row')
-
- gp = self.target_model
-
- # Draw the GP surface
- visin.draw_contour(
- gp.predict_mean,
- gp.bounds,
- self.parameter_names,
- title='GP target surface',
- points=gp.X,
- axes=f.axes[0],
- **options)
+ Notes
+ -----
+ - The plots of the GP surface and the acquisition space work for the
+ cases when dim < 3;
+ - The method is experimental.
- # Draw the latest acquisitions
- if options.get('interactive'):
- point = gp.X[-1, :]
- if len(gp.X) > 1:
- f.axes[1].scatter(*point, color='red')
+ Parameters
+ ----------
+ plot_acq_pairwise : bool, optional
+ The option to plot the pair-wise acquisition point relationships.
+ arr_ax : array_like, optional
+ Handled implicitly upon interactive visualisation.
- displays = [gp._gp]
+ Returns
+ -------
+ array_like
+ Axes for interactive visualisation.
- if options.get('interactive'):
- from IPython import display
- displays.insert(
- 0,
- display.HTML('Iteration {}: Acquired {} at {}'.format(
- len(gp.Y), gp.Y[-1][0], point)))
-
- # Update
- visin._update_interactive(displays, options)
-
- def acq(x):
- return self.acquisition_method.evaluate(x, len(gp.X))
-
- # Draw the acquisition surface
- visin.draw_contour(
- acq,
- gp.bounds,
- self.parameter_names,
- title='Acquisition surface',
- points=None,
- axes=f.axes[1],
- **options)
+ Raises
+ ------
+ ValueError
+ Unsupported dimension.
- if options.get('close'):
- plt.close()
+ """
+ if plot_acq_pairwise:
+ if len(self.parameter_names) == 1:
+ raise ValueError('Can not plot the pair-wise comparison for 1 parameter.')
+
+ # Transform the acquisition points in the accepted format.
+ dict_pts_acq = OrderedDict()
+ for idx_param, name_param in enumerate(self.parameter_names):
+ dict_pts_acq[name_param] = self.target_model.X[:, idx_param]
+ vis.plot_pairs(dict_pts_acq, **options)
+ else:
+ if len(self.parameter_names) == 1:
+ arr_ax = vis.plot_state_1d(self, arr_ax, **options)
+ return arr_ax
+ elif len(self.parameter_names) == 2:
+ arr_ax = vis.plot_state_2d(self, arr_ax, **options)
+ return arr_ax
+ else:
+ raise ValueError('The method is supported only for 1- or 2-dimensions.')
def plot_discrepancy(self, axes=None, **kwargs):
"""Plot acquired parameters vs. resulting discrepancy.
@@ -1133,7 +1137,7 @@ class BOLFI(BayesianOptimization):
"""
- def fit(self, n_evidence, threshold=None):
+ def fit(self, n_evidence, threshold=None, **options):
"""Fit the surrogate model.
Generates a regression model for the discrepancy given the parameters.
@@ -1150,9 +1154,8 @@ def fit(self, n_evidence, threshold=None):
if n_evidence is None:
raise ValueError(
- 'You must specify the number of evidence (n_evidence) for the fitting')
-
- self.infer(n_evidence)
+ 'You must specify the number of evidence( n_evidence) for the fitting')
+ self.infer(n_evidence, **options)
return self.extract_posterior(threshold)
def extract_posterior(self, threshold=None):
diff --git a/elfi/visualization/interactive.py b/elfi/visualization/interactive.py
index 2b28df10..f1b57c40 100644
--- a/elfi/visualization/interactive.py
+++ b/elfi/visualization/interactive.py
@@ -4,14 +4,17 @@
import matplotlib.pyplot as plt
import numpy as np
+import scipy.interpolate
logger = logging.getLogger(__name__)
def plot_sample(samples, nodes=None, n=-1, displays=None, **options):
- """Plot a scatterplot of samples.
+ """Plot a scatter-plot of samples.
- Experimental, only dims 1-2 supported.
+ Notes
+ -----
+ - Experimental, only dims 1-2 supported.
Parameters
----------
@@ -23,7 +26,8 @@ def plot_sample(samples, nodes=None, n=-1, displays=None, **options):
"""
axes = _prepare_axes(options)
-
+ if samples is None:
+ return
nodes = nodes or sorted(samples.keys())[:2]
if isinstance(nodes, str):
nodes = [nodes]
@@ -39,9 +43,8 @@ def plot_sample(samples, nodes=None, n=-1, displays=None, **options):
axes.set_ylabel(nodes[1])
axes.scatter(samples[nodes[0]][:n], samples[nodes[1]][:n])
- _update_interactive(displays, options)
-
- if options.get('close'):
+ if options.get('interactive'):
+ update_interactive(displays, options)
plt.close()
@@ -52,7 +55,8 @@ def get_axes(**options):
return plt.gca()
-def _update_interactive(displays, options):
+def update_interactive(displays, options):
+ """Update the interactive plot."""
displays = displays or []
if options.get('interactive'):
from IPython import display
@@ -67,7 +71,6 @@ def _prepare_axes(options):
if ion:
axes.clear()
-
if options.get('xlim'):
axes.set_xlim(options.get('xlim'))
if options.get('ylim'):
@@ -76,45 +79,52 @@ def _prepare_axes(options):
return axes
-def draw_contour(fn, bounds, nodes=None, points=None, title=None, **options):
+def draw_contour(fn, bounds, params=None, points=None, title=None, label=None, **options):
"""Plot a contour of a function.
+ Notes
+ -----
Experimental, only 2D supported.
Parameters
----------
- fn : callable
- bounds : list[arraylike]
+ fn : Callable
+ Description
+ bounds : list[array_like]
Bounds for the plot, e.g. [(0, 1), (0,1)].
- nodes : list[str], optional
- points : arraylike, optional
- Additional points to plot.
- title : str, optional
+ params : list[String], optional
+ Parameter names.
+ title : String, optional
+ label : String, optional
"""
- ax = get_axes(**options)
-
+ if options.get('axes'):
+ axes = options['axes']
+ plt.sca(axes)
x, y = np.meshgrid(np.linspace(*bounds[0]), np.linspace(*bounds[1]))
z = fn(np.c_[x.reshape(-1), y.reshape(-1)])
- if ax:
- plt.sca(ax)
- plt.cla()
+ # Plotting the contour.
+ CS = plt.contourf(x, y, z.reshape(x.shape), 25)
+ CB = plt.colorbar(CS, orientation='horizontal', format='%.1e')
+ CB.set_label(label)
+ rbf = scipy.interpolate.Rbf(x, y, z, function='linear')
+ zi = rbf(x, y)
+ plt.imshow(zi,
+ vmin=z.min(),
+ vmax=z.max(),
+ origin='lower',
+ extent=[x.min(), x.max(), y.min(), y.max()])
+
+ # Adding the points.
+ if points is not None:
+ plt.scatter(points[:, 0], points[:, 1], color='k')
+ # Adding the labels.
if title:
plt.title(title)
- try:
- plt.contour(x, y, z.reshape(x.shape))
- except ValueError:
- logger.warning('Could not draw a contour plot')
- if points is not None:
- plt.scatter(points[:-1, 0], points[:-1, 1])
- if options.get('interactive'):
- plt.scatter(points[-1, 0], points[-1, 1], color='r')
-
plt.xlim(bounds[0])
plt.ylim(bounds[1])
-
- if nodes:
- plt.xlabel(nodes[0])
- plt.ylabel(nodes[1])
+ if params:
+ plt.xlabel(params[0])
+ plt.ylabel(params[1])
diff --git a/elfi/visualization/visualization.py b/elfi/visualization/visualization.py
index d3736e5a..e04add7b 100644
--- a/elfi/visualization/visualization.py
+++ b/elfi/visualization/visualization.py
@@ -5,6 +5,7 @@
import matplotlib.pyplot as plt
import numpy as np
+import elfi.visualization.interactive as visin
from elfi.model.elfi_model import Constant, ElfiModel, NodeReference
@@ -99,6 +100,7 @@ def _create_axes(axes, shape, **kwargs):
else:
fig, axes = plt.subplots(ncols=shape[1], nrows=shape[0], **fig_kwargs)
axes = np.atleast_1d(axes)
+ fig.tight_layout(pad=2.0)
return axes, kwargs
@@ -156,14 +158,183 @@ def plot_marginals(samples, selector=None, bins=20, axes=None, **kwargs):
return axes
-def plot_pairs(samples, selector=None, bins=20, axes=None, **kwargs):
- """Plot pairwise relationships as a matrix with marginals on the diagonal.
+def plot_state_1d(model_bo, arr_ax=None, **options):
+ """Plot the GP surface and the acquisition function in 1-D.
- The y-axis of marginal histograms are scaled.
+ Notes
+ -----
+ The method is experimental.
- Parameters
+ Parameters
----------
- samples : OrderedDict of np.arrays
+ model_bo : elfi.methods.parameter_inference.BOLFI
+ arr_ax : array_like, optional
+
+ Returns
+ -------
+ array_like
+ Axes for interactive visualisation
+
+ """
+ gp = model_bo.target_model
+ pts_eval = np.linspace(*gp.bounds[0])
+
+ if arr_ax is None:
+ fig, arr_ax = plt.subplots(nrows=1,
+ ncols=2,
+ figsize=(12, 4),
+ sharex=True)
+ plt.ticklabel_format(style='sci', axis='y', scilimits=(-3, 4))
+ fig.tight_layout(pad=2.0)
+
+ # Plotting the acquisition space.
+ arr_ax[1].set_title('Acquisition surface')
+ arr_ax[1].set_xlabel(model_bo.parameter_names[0])
+ arr_ax[1].set_ylabel(options.pop('method_acq'))
+ score_acq = model_bo.acquisition_method.evaluate(pts_eval)
+ arr_ax[1].legend(loc='upper right')
+ arr_ax[1].plot(pts_eval,
+ score_acq,
+ color='k',
+ label='acquisition function')
+
+ # Plotting the confidence interval and the mean.
+ mean, var = gp.predict(pts_eval, noiseless=False)
+ sigma = np.sqrt(var)
+ z_95 = 1.96
+ lb_ci = mean - z_95 * (sigma)
+ ub_ci = mean + z_95 * (sigma)
+ arr_ax[0].fill(np.concatenate([pts_eval, pts_eval[::-1]]),
+ np.concatenate([lb_ci, ub_ci[::-1]]),
+ alpha=.1,
+ fc='k',
+ ec='None',
+ label='95% confidence interval')
+ arr_ax[0].plot(pts_eval, mean, color='k', label='mean')
+
+ # Plotting the acquisition threshold, epsilon.
+ if model_bo.acquisition_method.name in ['max_var', 'rand_max_var', 'exp_int_var']:
+ thresh_acq = np.repeat(model_bo.acquisition_method.eps,
+ len(pts_eval))
+ arr_ax[0].plot(pts_eval,
+ thresh_acq,
+ color='g',
+ label='acquisition threshold')
+
+ # Plotting the acquired points.
+ arr_ax[0].set_title('GP target surface')
+ arr_ax[0].set_xlabel(model_bo.parameter_names[0])
+ arr_ax[0].set_ylabel('Discrepancy')
+ arr_ax[0].legend(loc='upper right')
+ arr_ax[0].scatter(gp.X, gp.Y, color='k')
+
+ return arr_ax
+ else:
+ if options.get('interactive'):
+ from IPython import display
+ pt_last = options.pop('point_acq')
+
+ # Plotting the last acquired point on the GP target surface.
+ arr_ax[0].scatter(pt_last['x'], pt_last['d'], color='r')
+
+ # Plotting the lines indicating the acquisition's location on the acquisition space.
+ ymin, ymax = arr_ax[1].get_ylim()
+ arr_ax[1].vlines(x=pt_last['x'], ymin=ymin, ymax=ymax,
+ color='r', linestyle='--',
+ label='latest acquisition')
+
+ # Handling the interactive display.
+ displays = []
+ displays.append(gp.instance)
+ n_it = int(len(gp.Y) / model_bo.batch_size)
+ html_disp = 'Iteration {}: Acquired {} at {}' \
+ .format(n_it, pt_last['d'], pt_last['x'])
+ displays.append(display.HTML(html_disp))
+ visin.update_interactive(displays, options=options)
+
+ plt.close()
+
+
+def plot_state_2d(model_bo, arr_ax=None, **options):
+ """Plot the GP surface and the acquisition function in 2-D.
+
+ Notes
+ -----
+ The method is experimental.
+
+ Parameters
+ ----------
+ model_bo : elfi.methods.parameter_inference.BOLFI
+ arr_ax : array_like, optional
+
+ Returns
+ -------
+ array_like
+ Axes for interactive visualisation
+
+ """
+ gp = model_bo.target_model
+
+ if arr_ax is None:
+ _, arr_ax = plt.subplots(nrows=1,
+ ncols=2,
+ figsize=(16, 10),
+ sharex='row',
+ sharey='row')
+
+ # Plotting the acquisition space.
+ def fn_acq(x):
+ return model_bo.acquisition_method.evaluate(x, len(gp.X))
+ visin.draw_contour(fn_acq,
+ gp.bounds,
+ model_bo.parameter_names,
+ title='Acquisition surface',
+ axes=arr_ax[1],
+ label=options.pop('method_acq'),
+ **options)
+
+ # Plotting the GP target surface and the acquired points.
+ visin.draw_contour(gp.predict_mean,
+ gp.bounds,
+ model_bo.parameter_names,
+ title='GP target surface',
+ points=gp.X,
+ axes=arr_ax[0],
+ label='Discrepancy',
+ **options)
+
+ return arr_ax
+ else:
+ if options.get('interactive'):
+ from IPython import display
+ pt_last = options.pop('point_acq')
+
+ # Plotting the last acquired point on the GP target surface and the acquisition space.
+ arr_ax[0].scatter(pt_last['x'][:, 0], pt_last['x'][:, 1], color='r')
+ arr_ax[1].scatter(pt_last['x'][:, 0], pt_last['x'][:, 1], color='r')
+
+ # Handling the interactive display.
+ displays = []
+ displays.append(gp.instance)
+ n_it = int(len(gp.Y) / model_bo.batch_size)
+ html_disp = 'Iteration {}: Acquired {} at {}' \
+ .format(n_it, pt_last['d'], pt_last['x'])
+ displays.append(display.HTML(html_disp))
+ visin.update_interactive(displays, options=options)
+
+ plt.close()
+
+
+def plot_pairs(data, selector=None, bins=20, axes=None, **kwargs):
+ """Plot the pair-wise relationships in a grid with marginals on the diagonal.
+
+ Notes
+ -----
+ Removed: The y-axis of marginal histograms are scaled.
+
+ Parameters
+ ----------
+ data : OrderedDict of np.arrays
selector : iterable of ints or strings, optional
Indices or keys to use from samples. Default to all.
bins : int, optional
@@ -175,31 +346,38 @@ def plot_pairs(samples, selector=None, bins=20, axes=None, **kwargs):
axes : np.array of plt.Axes
"""
- samples = _limit_params(samples, selector)
- shape = (len(samples), len(samples))
edgecolor = kwargs.pop('edgecolor', 'none')
- dot_size = kwargs.pop('s', 2)
- kwargs['sharex'] = kwargs.get('sharex', 'col')
- kwargs['sharey'] = kwargs.get('sharey', 'row')
- axes, kwargs = _create_axes(axes, shape, **kwargs)
-
- for i1, k1 in enumerate(samples):
- min_samples = samples[k1].min()
- max_samples = samples[k1].max()
- for i2, k2 in enumerate(samples):
- if i1 == i2:
- # create a histogram with scaled y-axis
- hist, bin_edges = np.histogram(samples[k1], bins=bins)
- bar_width = bin_edges[1] - bin_edges[0]
- hist = (hist - hist.min()) * (max_samples - min_samples) / (
- hist.max() - hist.min())
- axes[i1, i2].bar(bin_edges[:-1], hist, bar_width, bottom=min_samples, **kwargs)
+ dot_size = kwargs.pop('s', 25)
+
+ # Filter the data.
+ data_selected = _limit_params(data, selector)
+
+ # Initialise the figure.
+ shape_fig = (len(data_selected), len(data_selected))
+ axes, kwargs = _create_axes(axes, shape_fig, **kwargs)
+
+ # Populate the grid of figures.
+ for idx_row, key_row in enumerate(data_selected):
+ for idx_col, key_col in enumerate(data_selected):
+ if idx_row == idx_col:
+ # Plot the marginals.
+ axes[idx_row, idx_col].hist(data_selected[key_row], bins=bins, **kwargs)
+ axes[idx_row, idx_col].set_xlabel(key_row)
+
+ # Experimental: Calculate the bin length.
+ x_min, x_max = axes[idx_row, idx_col].get_xlim()
+ length_bin = (x_max - x_min) / bins
+ axes[idx_row, idx_col].set_ylabel(
+ 'Count (bin length: {0:.2f})'.format(length_bin))
else:
- axes[i1, i2].scatter(
- samples[k2], samples[k1], s=dot_size, edgecolor=edgecolor, **kwargs)
-
- axes[i1, 0].set_ylabel(k1)
- axes[-1, i1].set_xlabel(k1)
+ # Plot the pairs.
+ axes[idx_row, idx_col].scatter(data_selected[key_row],
+ data_selected[key_col],
+ edgecolor=edgecolor,
+ s=dot_size,
+ **kwargs)
+ axes[idx_row, idx_col].set_xlabel(key_row)
+ axes[idx_row, idx_col].set_ylabel(key_col)
return axes
diff --git a/tests/unit/test_document_examples.py b/tests/unit/test_document_examples.py
index f873bb04..7fabef3a 100644
--- a/tests/unit/test_document_examples.py
+++ b/tests/unit/test_document_examples.py
@@ -31,7 +31,7 @@ def __init__(self, model, discrepancy_name, threshold, **kwargs):
def set_objective(self, n_sim):
self.objective['n_sim'] = n_sim
- def update(self, batch, batch_index):
+ def update(self, batch, batch_index, vis=None):
super(CustomMethod, self).update(batch, batch_index)
# Make a filter mask (logical numpy array) from the distance array