Skip to content

Commit

Permalink
fix handling of nan value in averaging
Browse files Browse the repository at this point in the history
  • Loading branch information
nkanazawa1989 committed Nov 18, 2023
1 parent 61ebcd9 commit 1f10cc2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion qiskit_experiments/curve_analysis/curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _format_data(
average = averaging_methods[self.options.average_method]
model_names = self.model_names()
formatted = []
for (class_id, xv), g in groupby(sorted(curve_data.values, key=sort_by), key=sort_by):
for (_, xv), g in groupby(sorted(curve_data.values, key=sort_by), key=sort_by):
g_values = np.array(list(g))
g_dict = dict(zip(columns, g_values.T))
avg_yval, avg_yerr, shots = average(g_dict["yval"], g_dict["yerr"], g_dict["shots"])
Expand Down
9 changes: 5 additions & 4 deletions qiskit_experiments/curve_analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asteval
import lmfit
import numpy as np
import pandas as pd
from qiskit.utils.deprecation import deprecate_func
from qiskit.utils import detach_prefix
from uncertainties import UFloat, wrap as wrap_function
Expand Down Expand Up @@ -243,9 +244,9 @@ def shot_weighted_average(
if len(yvals) == 1:
return yvals[0], yerrs[0], shots[0]

if np.any(shots < -1):
if any(s is pd.NA for s in shots):
# Shot number is unknown
return np.mean(yvals), np.nan, -1
return np.mean(yvals), np.nan, pd.NA

total_shots = np.sum(shots)
weights = shots / total_shots
Expand Down Expand Up @@ -276,7 +277,7 @@ def inverse_weighted_variance(
if len(yvals) == 1:
return yvals[0], yerrs[0], shots[0]

total_shots = np.sum(shots) if all(shots > 0) else -1
total_shots = np.sum(shots)
weights = 1 / yerrs**2
yvar = 1 / np.sum(weights)

Expand Down Expand Up @@ -307,7 +308,7 @@ def sample_average(
if len(yvals) == 1:
return yvals[0], 0.0, shots[0]

total_shots = np.sum(shots) if all(shots > 0) else -1
total_shots = np.sum(shots)

avg_yval = np.mean(yvals)
avg_yerr = np.sqrt(np.mean((avg_yval - yvals) ** 2) / len(yvals))
Expand Down

0 comments on commit 1f10cc2

Please sign in to comment.