Skip to content

Commit

Permalink
[MaxVar split, Part 2] Added the visualisation improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
perdaug committed Oct 4, 2017
1 parent 077308b commit 6402a49
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 145 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
=========

0.x
---
- Improved the interactive plotting (customised for the MaxVar-based acquisition methods)
- Added a pair-wise plotting to plot_state() (a way to visualise n-dimensional parameters)

0.6.3 (2017-09-28)
------------------

Expand Down
2 changes: 2 additions & 0 deletions elfi/methods/bo/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def __init__(self, *args, delta=None, **kwargs):
kwargs['exploration_rate'] = 1 / delta

super(LCBSC, self).__init__(*args, **kwargs)
self.name = 'lcbsc'
self.label_fn = 'The Lower Confidence Bound Selection Criterion'

@property
def delta(self):
Expand Down
169 changes: 86 additions & 83 deletions elfi/methods/parameter_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
__all__ = ['Rejection', 'SMC', 'BayesianOptimization', 'BOLFI']

import logging
from collections import OrderedDict
from math import ceil

import matplotlib.pyplot as plt
import numpy as np

import elfi.client
Expand Down Expand Up @@ -89,7 +89,6 @@ def __init__(self,
model = model.model if isinstance(model, NodeReference) else model
if not model.parameter_names:
raise ValueError('Model {} defines no parameters'.format(model))

self.model = model.copy()
self.output_names = self._check_outputs(output_names)

Expand Down Expand Up @@ -161,7 +160,7 @@ def extract_result(self):
"""
raise NotImplementedError

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
ELFI calls this method when a new batch has been computed and the state of
Expand All @@ -174,10 +173,8 @@ def update(self, batch, batch_index):
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
batch_index : int
Returns
-------
None
vis : bool, optional
Interactive visualisation of the iterations.
"""
self.state['n_batches'] += 1
Expand Down Expand Up @@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
"""
raise NotImplementedError

def infer(self, *args, vis=None, **kwargs):
def infer(self, *args, **options):
"""Set the objective and start the iterate loop until the inference is finished.
See the other arguments from the `set_objective` method.
Expand All @@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs):
result : Sample
"""
vis_opt = vis if isinstance(vis, dict) else {}

self.set_objective(*args, **kwargs)

vis = options.pop('vis', None)
self.set_objective(*args, **options)
while not self.finished:
self.iterate()
if vis:
self.plot_state(interactive=True, **vis_opt)

self.iterate(vis=vis)
self.batches.cancel_pending()
if vis:
self.plot_state(close=True, **vis_opt)

return self.extract_result()

def iterate(self):
"""Advance the inference by one iteration.
def iterate(self, vis=None):
"""Forward the inference one iteration.
This is a way to manually progress the inference. One iteration consists of
waiting and processing the result of the next batch in succession and possibly
Expand All @@ -272,6 +262,11 @@ def iterate(self):
will never be more batches submitted in parallel than the `max_parallel_batches`
setting allows.
Parameters
----------
vis : bool, optional
Interactive visualisation of the iterations.
Returns
-------
None
Expand All @@ -286,7 +281,7 @@ def iterate(self):
# Handle the next ready batch in succession
batch, batch_index = self.batches.wait_next()
logger.debug('Received batch %d' % batch_index)
self.update(batch, batch_index)
self.update(batch, batch_index, vis=vis)

@property
def finished(self):
Expand Down Expand Up @@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
# Reset the inference
self.batches.reset()

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
Parameters
----------
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
vis : bool, optional
Interactive visualisation of the iterations.
batch_index : int
"""
if vis and self.state['samples'] is not None:
self.plot_state(interactive=True, **vis)

super(Rejection, self).update(batch, batch_index)
if self.state['samples'] is None:
# Lazy initialization of the outputs dict
Expand Down Expand Up @@ -584,8 +583,8 @@ def plot_state(self, **options):
displays = []
if options.get('interactive'):
from IPython import display
displays.append(
display.HTML('<span>Threshold: {}</span>'.format(self.state['threshold'])))
html_display = '<span>Threshold: {}</span>'.format(self.state['threshold'])
displays.append(display.HTML(html_display))

visin.plot_sample(
self.state['samples'],
Expand Down Expand Up @@ -651,14 +650,15 @@ def extract_result(self):
threshold=pop.threshold,
**self._extract_result_kwargs())

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.
Parameters
----------
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
vis : bool, optional
Interactive visualisation of the iterations.
batch_index : int
"""
Expand Down Expand Up @@ -942,14 +942,16 @@ def extract_result(self):
return OptimizationResult(
x_min=batch_min, outputs=outputs, **self._extract_result_kwargs())

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the GP regression model of the target node with a new batch.
Parameters
----------
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
vis : bool, optional
Interactive visualisation of the iterations.
batch_index : int
"""
Expand All @@ -959,11 +961,22 @@ def update(self, batch, batch_index):
params = batch_to_arr2d(batch, self.parameter_names)
self._report_batch(batch_index, params, batch[self.target_name])

# Adding the acquisition plots.
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
options = {}
options['point_acq'] = {'x': params, 'd': batch[self.target_name]}
options['method_acq'] = self.acquisition_method.label_fn
arr_ax = self.plot_state(interactive=True, **options)

optimize = self._should_optimize()
self.target_model.update(params, batch[self.target_name], optimize)
if optimize:
self.state['last_GP_update'] = self.target_model.n_evidence

# Adding the updated gp plots.
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
self.plot_state(interactive=True, arr_ax=arr_ax, **options)

def prepare_new_batch(self, batch_index):
"""Prepare values for a new batch.
Expand Down Expand Up @@ -1040,60 +1053,51 @@ def _report_batch(self, batch_index, params, distances):
str += "{}{} at {}\n".format(fill, distances[i].item(), params[i])
logger.debug(str)

def plot_state(self, **options):
"""Plot the GP surface.
def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **options):
"""Plot the GP surface and the acquisition space.
This feature is still experimental and currently supports only 2D cases.
"""
f = plt.gcf()
if len(f.axes) < 2:
f, _ = plt.subplots(1, 2, figsize=(13, 6), sharex='row', sharey='row')

gp = self.target_model

# Draw the GP surface
visin.draw_contour(
gp.predict_mean,
gp.bounds,
self.parameter_names,
title='GP target surface',
points=gp.X,
axes=f.axes[0],
**options)
Notes
-----
- The plots of the GP surface and the acquisition space work for the
cases when dim < 3;
- The method is experimental.
# Draw the latest acquisitions
if options.get('interactive'):
point = gp.X[-1, :]
if len(gp.X) > 1:
f.axes[1].scatter(*point, color='red')
Parameters
----------
plot_acq_pairwise : bool, optional
The option to plot the pair-wise acquisition point relationships.
arr_ax : array_like, optional
Handled implicitly upon interactive visualisation.
displays = [gp._gp]
Returns
-------
array_like
Axes for interactive visualisation.
if options.get('interactive'):
from IPython import display
displays.insert(
0,
display.HTML('<span><b>Iteration {}:</b> Acquired {} at {}</span>'.format(
len(gp.Y), gp.Y[-1][0], point)))

# Update
visin._update_interactive(displays, options)

def acq(x):
return self.acquisition_method.evaluate(x, len(gp.X))

# Draw the acquisition surface
visin.draw_contour(
acq,
gp.bounds,
self.parameter_names,
title='Acquisition surface',
points=None,
axes=f.axes[1],
**options)
Raises
------
ValueError
Unsupported dimension.
if options.get('close'):
plt.close()
"""
if plot_acq_pairwise:
if len(self.parameter_names) == 1:
raise ValueError('Can not plot the pair-wise comparison for 1 parameter.')

# Transform the acquisition points in the accepted format.
dict_pts_acq = OrderedDict()
for idx_param, name_param in enumerate(self.parameter_names):
dict_pts_acq[name_param] = self.target_model.X[:, idx_param]
vis.plot_pairs(dict_pts_acq, **options)
else:
if len(self.parameter_names) == 1:
arr_ax = vis.plot_state_1d(self, arr_ax, **options)
return arr_ax
elif len(self.parameter_names) == 2:
arr_ax = vis.plot_state_2d(self, arr_ax, **options)
return arr_ax
else:
raise ValueError('The method is supported only for 1- or 2-dimensions.')

def plot_discrepancy(self, axes=None, **kwargs):
"""Plot acquired parameters vs. resulting discrepancy.
Expand Down Expand Up @@ -1133,7 +1137,7 @@ class BOLFI(BayesianOptimization):
"""

def fit(self, n_evidence, threshold=None):
def fit(self, n_evidence, threshold=None, **options):
"""Fit the surrogate model.
Generates a regression model for the discrepancy given the parameters.
Expand All @@ -1150,9 +1154,8 @@ def fit(self, n_evidence, threshold=None):

if n_evidence is None:
raise ValueError(
'You must specify the number of evidence (n_evidence) for the fitting')

self.infer(n_evidence)
'You must specify the number of evidence( n_evidence) for the fitting')
self.infer(n_evidence, **options)
return self.extract_posterior(threshold)

def extract_posterior(self, threshold=None):
Expand Down
Loading

0 comments on commit 6402a49

Please sign in to comment.