Skip to content

Commit

Permalink
added cache for figure_options and plotter options.
Browse files Browse the repository at this point in the history
  • Loading branch information
ItamarGoldman committed Feb 18, 2024
1 parent cc397c0 commit 33473b7
Showing 1 changed file with 76 additions and 60 deletions.
136 changes: 76 additions & 60 deletions qiskit_experiments/curve_analysis/curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(

self._models = models or []
self._name = name or self.__class__.__name__
self._cache = {}

@property
def name(self) -> str:
Expand Down Expand Up @@ -169,70 +170,85 @@ def set_options(self, **fields):
)
fields["plot_residuals"] = False
else:
# check we have model to fit into
if self.models:
self.plotter.set_figure_options(
ylabel=[
self.plotter.figure_options.get("ylabel", ""),
"Residuals",
],
)
model_names = self.model_names()
for model_name in model_names:
self.plotter.set_figure_options(
sharey=False,
series_params={
**self.plotter.figure_options["series_params"],
**{
model_name: {
"canvas": 0,
},
model_name
+ "_residuals": {
"canvas": 1,
},
},
},
)
self._add_residuals_plot_config()
if not fields.get("plot_residuals", True) and self.options.get("plot_residuals"):
self._remove_residuals_plot_config()

# Here add the configuration for the residuals plot:
self.plotter.set_options(
subplots=(2, 1),
style=PlotStyle(
{
"figsize": (8, 8),
"textbox_rel_pos": (0.28, -0.10),
"sub_plot_heights_list": [7 / 10, 3 / 10],
"sub_plot_widths_list": [1],
"style_name": "residuals",
}
),
)
super().set_options(**fields)

if not fields.get("plot_residuals", True) and self.options.get("plot_residuals"):
# set options for single plot and cancel residuals plotting.
if self.models:
def _add_residuals_plot_config(self):
"""Configure plotter options for residuals plot."""
# check we have model to fit into
if self.models:
# Cache figure options.
self._cache["figure_options"] = {}
self._cache["figure_options"]["ylabel"] = self.plotter.figure_options.get("ylabel")
self._cache["figure_options"]["series_params"] = self.plotter.figure_options[
"series_params"
]
self._cache["figure_options"]["sharey"] = self.plotter.figure_options["sharey"]

self.plotter.set_figure_options(
ylabel=[
self.plotter.figure_options.get("ylabel", ""),
"Residuals",
],
)
model_names = self.model_names()
for model_name in model_names:
# Cache figure options.
self._cache[model_name] = {}
self.plotter.set_figure_options(
ylabel=[self.plotter.figure_options.get("ylabel", "")[0]],
)
model_names = self.model_names()
for model_name in model_names:
self.plotter.figure_options["series_params"][model_name].pop("canvas", None)

# Here add the configuration for the residuals plot:
self.plotter.set_options(
subplots=(1, 1),
style=PlotStyle(
{
"figsize": (7, 5),
"textbox_rel_pos": (0, 0),
"sub_plot_heights_list": [1],
"sub_plot_widths_list": [1],
"style_name": "canceled_residuals",
}
),
sharey=False,
series_params={
**self.plotter.figure_options["series_params"],
**{
model_name: {
"canvas": 0,
},
model_name
+ "_residuals": {
"canvas": 1,
},
},
},
)
super().set_options(**fields)

# Cache plotter options.
self._cache["plotter"] = {}
self._cache["plotter"]["subplots"] = self.plotter.options.get("subplots")
self._cache["plotter"]["style"] = self.plotter.options.get("style", PlotStyle({}))

# Here add the configuration for the residuals plot:
self.plotter.set_options(
subplots=(2, 1),
style=PlotStyle(
{
"figsize": (8, 8),
"textbox_rel_pos": (0.28, -0.10),
"sub_plot_heights_list": [7 / 10, 3 / 10],
"sub_plot_widths_list": [1],
"style_name": "residuals",
}
),
)

def _remove_residuals_plot_config(self):
"""set options for a single plot to its cached values."""
if self.models:
self.plotter.set_figure_options(
ylabel=self._cache["figure_options"]["ylabel"],
sharey=self._cache["figure_options"]["sharey"],
series_params=self._cache["figure_options"]["series_params"],
)

# Here add the style_name so the plotter will know not to print the residual data.
self.plotter.set_options(
subplots=self._cache["plotter"]["subplots"],
style=PlotStyle.merge(
self._cache["plotter"]["style"], PlotStyle({"style_name": "canceled_residuals"})
),
)

def _run_data_processing(
self,
Expand Down

0 comments on commit 33473b7

Please sign in to comment.