Skip to content

Commit

Permalink
ENH add add_marginal_subplot (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
lorentzenchr authored Nov 24, 2024
1 parent 0752de0 commit 360328d
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 61 deletions.
97 changes: 42 additions & 55 deletions docs/examples/regression_on_workers_compensation.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ theme:
name: Switch to system preference

features:
- content.code.copy
- navigation.sections
- navigation.expand
- navigation.tabs
Expand Down Expand Up @@ -113,6 +114,9 @@ markdown_extensions:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
use_pygments: true
- pymdownx.inlinehilite
- pymdownx.keys
Expand Down
6 changes: 3 additions & 3 deletions src/model_diagnostics/_utils/isotonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ def gpava(
----------
fun : callable
Function that calculates the functional at interest, e.g.
```
```py
def median(x, w=None):
return np.quantile(x, 0.5, method="inverted_cdf")
````
```
or, to get the upper bound
````
```py
def median(x, w=None):
return -np.quantile(-x, 0.5, method="inverted_cdf")
```
Expand Down
8 changes: 6 additions & 2 deletions src/model_diagnostics/_utils/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@ def get_plotly_color(i):
return colors[i % len(colors)]


def get_xlabel(ax):
def get_xlabel(ax, xaxis=1):
if isinstance(ax, mpl.axes.Axes):
return ax.get_xlabel()
else:
elif xaxis == 1:
# ax = plotly figure
return ax.layout.xaxis.title.text
elif xaxis >= 1:
axis = getattr(ax.layout, f"xaxis{xaxis}")
return axis.title.text


def get_ylabel(ax, yaxis=1):
if isinstance(ax, mpl.axes.Axes):
return ax.get_ylabel()
elif yaxis == 1:
# ax = plotly figure
return ax.layout.yaxis.title.text
elif yaxis >= 2:
axis = getattr(ax.layout, f"yaxis{yaxis}")
Expand Down
8 changes: 7 additions & 1 deletion src/model_diagnostics/calibration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from .identification import compute_bias, compute_marginal, identification_function
from .plots import plot_bias, plot_marginal, plot_reliability_diagram
from .plots import (
add_marginal_subplot,
plot_bias,
plot_marginal,
plot_reliability_diagram,
)

__all__ = [
"add_marginal_subplot",
"compute_bias",
"compute_marginal",
"identification_function",
Expand Down
161 changes: 161 additions & 0 deletions src/model_diagnostics/calibration/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,91 @@ def plot_marginal(
setting the `plot_backend` via
[`model_diagnostics.set_config`][model_diagnostics.set_config] or
[`model_diagnostics.config_context`][model_diagnostics.config_context].
Examples
-----
If you wish to plot multiple features at once with subfigures, here is how to do it
with matplotlib:
```py
from math import ceil
import matplotlib.pyplot as plt
import numpy as np
from model_diagnostics.calibration import plot_marginal
# Replace by your own data and model.
n_obs = 100
y_obs = np.arange(n_obs)
X = np.ones((n_obs, 2))
X[:, 0] = np.sin(np.arange(n_obs))
X[:, 1] = y_obs ** 2
def model_predict(X):
s = 0.5 * n_obs * np.sin(X)
return s.sum(axis=1) + np.sqrt(X[:, 1])
# Now the plotting.
feature_list = [0, 1]
n_rows, n_cols = ceil(len(feature_list) / 2), 2
fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, sharey=True)
for i, ax in enumerate(axs):
plot_marginal(
y_obs=y_obs,
y_pred=model_predict(X),
X=X,
feature_name=feature_list[i],
predict_function=model_predict,
ax=ax,
)
fig.tight_layout()
```
For plotly, use the helper function
[`add_marginal_subplot`][model_diagnostics.calibration.plots.add_marginal_subplot]:
```py
from math import ceil
import numpy as np
from model_diagnostics import config_context
from plotly.subplots import make_subplots
from model_diagnostics.calibration import add_marginal_subplot, plot_marginal
# Replace by your own data and model.
n_obs = 100
y_obs = np.arange(n_obs)
X = np.ones((n_obs, 2))
X[:, 0] = np.sin(np.arange(n_obs))
X[:, 1] = y_obs ** 2
def model_predict(X):
s = 0.5 * n_obs * np.sin(X)
return s.sum(axis=1) + np.sqrt(X[:, 1])
# Now the plotting.
feature_list = [0, 1]
n_rows, n_cols = ceil(len(feature_list) / 2), 2
fig = make_subplots(
rows=n_rows,
cols=n_cols,
vertical_spacing=0.3 / n_rows, # equals default
# subplot_titles=feature_list, # maybe
specs=[[{"secondary_y": True}] * n_cols] * n_rows, # This is important!
)
for row in range(n_rows):
for col in range(n_cols):
i = n_cols * row + col
with config_context(plot_backend="plotly"):
subfig = plot_marginal(
y_obs=y_obs,
y_pred=model_predict(X),
X=X,
feature_name=feature_list[i],
predict_function=model_predict,
)
add_marginal_subplot(subfig, fig, row, col)
fig.show()
```
"""
if ax is None:
plot_backend = get_config()["plot_backend"]
Expand Down Expand Up @@ -1130,3 +1215,79 @@ def plot_marginal(
)

return ax


def add_marginal_subplot(subfig, fig, row: int, col: int):
"""Add a plotly subplot from plot_marginal to a multi-plot figure.
This is a helper function is accompanying
[`plot_marginal`][model_diagnostics.calibration.plot_marginal] in order to ease
plotting with subfigures with the plotly backend.
For it to work, you must call `make_subplots` with the `specs` argument and set
the appropriate number of `{"secondary_y": True}` in a list of lists.
```py hl_lines="7"
from plotly.subplots import make_subplots
n_rows, n_cols = ...
fig = make_subplots(
rows=n_rows,
cols=n_cols,
specs=[[{"secondary_y": True}] * n_cols] * n_rows, # This is important!
)
```
The reason is that `plot_marginal` uses a secondary yaxis (and swapped sides with
the primary yaxis).
Parameters
----------
subfig : plotly Figure
The subfigure which is added to `fig`.
fig : plotly Figure
The multi-plot figure to which `subfig` is added at positions `row` and `col`.
row : int
The (0-based) row index of `fig` at which `subfig` is added.
col : int
The (0-based) column index of `fig` at which `subfig` is added.
Returns
-------
fig
"""
# It returns a tuple of `range`s starting at 1.
plotly_rows, plotly_cols = fig._get_subplot_rows_columns() # noqa: SLF001
n_rows = len(plotly_rows)
n_cols = len(plotly_cols)
if row >= n_rows or col >= n_cols:
msg = (
f"The `fig` only has {n_rows} rows and {n_cols} columns. You specified "
f"(0-based) {row=} and {col=}."
)
raise ValueError(msg)
i = n_cols * row + col
# Plotly uses 1-based indices:
row += 1
col += 1
# Transfer the x-axis titles of the subfig to fig.
xaxis = "xaxis" if i == 0 else f"xaxis{i + 1}"
fig["layout"][xaxis]["title"] = subfig["layout"]["xaxis"]["title"]
# Change sides of y-axis.
yaxis = "yaxis" if i == 0 else f"yaxis{2 * i + 1}"
yaxis2 = f"yaxis{2 * (i + 1)}"
fig.update_layout(
**{
yaxis: {"side": "right", "showgrid": False},
yaxis2: {"side": "left", "title": "y"},
}
)
# Only the last added subfig should show the legends, but all the ones before
# should not.
# So don't show legends for row-1 and col-1.
if row > 1:
fig.update_traces(patch={"showlegend": False}, row=row - 1, col=col)
if col > 1:
fig.update_traces(patch={"showlegend": False}, row=row, col=col - 1)
for d in subfig.data:
fig.add_trace(d, row=row, col=col, secondary_y=d["yaxis"] == "y2")

fig.update_layout(title=subfig.layout.title.text)
73 changes: 73 additions & 0 deletions src/model_diagnostics/calibration/tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from math import ceil

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -26,6 +28,7 @@
pd_Series,
)
from model_diagnostics.calibration import (
add_marginal_subplot,
plot_bias,
plot_marginal,
plot_reliability_diagram,
Expand Down Expand Up @@ -683,3 +686,73 @@ def test_plot_marginal_show_lines(show_lines, feature_type, plot_backend):
for i in range(1, 1 + 3):
assert isinstance(ax.data[i], Scatter)
assert ax.data[i].mode == mode


def test_add_marginal_subplot_raises():
"""Test that add_marginal_subplot raises errors."""
pytest.importorskip("plotly")
import plotly.graph_objects as go
from plotly.subplots import make_subplots

n_rows, n_cols = 4, 3
fig = make_subplots(
rows=n_rows,
cols=n_cols,
)
msg = f"The `fig` only has {n_rows} rows and {n_cols} columns"
with pytest.raises(ValueError, match=msg):
add_marginal_subplot(go.Figure(), fig, 5, 4)


def test_add_marginal_subplot():
"""Test that add_marginal_subplot works."""
pytest.importorskip("plotly")
from plotly.subplots import make_subplots

n_features = 12
n_obs = 10
y_obs = np.arange(n_obs)
X = np.ones((n_obs, n_features))
X[:n_obs, 0] = np.sin(np.arange(n_obs))
X[:, 1] = y_obs**2

def model_predict(X):
s = 0.5 * n_obs * np.sin(X)
return s.sum(axis=1) + np.sqrt(X[:, 1])

# Now the plotting.
feature_list = list(range(n_features))
n_cols = 3
n_rows = ceil(len(feature_list) / n_cols)
fig = make_subplots(
rows=n_rows,
cols=n_cols,
specs=[[{"secondary_y": True}] * n_cols] * n_rows,
)
for row in range(n_rows):
for col in range(n_cols):
i = n_cols * row + col
with config_context(plot_backend="plotly"):
subfig = plot_marginal(
y_obs=y_obs,
y_pred=model_predict(X),
X=X,
feature_name=feature_list[i],
predict_function=model_predict,
)
add_marginal_subplot(subfig, fig, row, col)

assert get_xlabel(fig, xaxis=1) == "binned feature 0"
assert get_xlabel(fig, xaxis=3) == "binned feature 2"
assert get_xlabel(fig, xaxis=4) == "binned feature 3"
assert get_ylabel(fig, yaxis=2) == "y"
assert get_title(fig) == "Marginal Plot"

legend_text = get_legend_list(fig)
# TODO: It is not 100% clear why legend_text has most often more entries than 3 or
# 4. We therefor test >= instead of ==.
# It is also unclear why for matplotlib the order varies.
assert len(legend_text) >= 3
assert legend_text[0] == "mean y_obs"
assert legend_text[1] == "mean y_pred"
assert legend_text[2] == "partial dependence"

0 comments on commit 360328d

Please sign in to comment.