diff --git a/nannyml/performance_estimation/confidence_based/metrics.py b/nannyml/performance_estimation/confidence_based/metrics.py index 1f0f60ff..ad29f953 100644 --- a/nannyml/performance_estimation/confidence_based/metrics.py +++ b/nannyml/performance_estimation/confidence_based/metrics.py @@ -782,11 +782,9 @@ def estimate_f1(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Union[pd.Ser fp = np.where(y_pred == 1, 1 - y_pred_proba, 0) fn = np.where(y_pred == 0, y_pred_proba, 0) TP, FP, FN = np.sum(tp), np.sum(fp), np.sum(fn) - if TP + 0.5 * (FP + FN) == 0: - metric = 0 - else: - metric = TP / (TP + 0.5 * (FP + FN)) - return metric + + denominator = TP + 0.5 * (FP + FN) + return TP / denominator if denominator != 0 else 0 @MetricFactory.register('precision', ProblemType.CLASSIFICATION_BINARY) @@ -929,11 +927,9 @@ def estimate_precision(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Union tp = np.where(y_pred == 1, y_pred_proba, 0) fp = np.where(y_pred == 1, 1 - y_pred_proba, 0) TP, FP = np.sum(tp), np.sum(fp) - if TP + FP == 0: - metric = 0 - else: - metric = TP / (TP + FP) - return metric + + denominator = TP + FP + return TP / denominator if denominator != 0 else 0 @MetricFactory.register('recall', ProblemType.CLASSIFICATION_BINARY) @@ -1076,11 +1072,9 @@ def estimate_recall(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Union[pd tp = np.where(y_pred == 1, y_pred_proba, 0) fn = np.where(y_pred == 0, y_pred_proba, 0) TP, FN = np.sum(tp), np.sum(fn) - if TP + FN == 0: - metric = 0 - else: - metric = TP / (TP + FN) - return metric + + denominator = TP + FN + return TP / denominator if denominator != 0 else 0 @MetricFactory.register('specificity', ProblemType.CLASSIFICATION_BINARY) @@ -1215,11 +1209,9 @@ def estimate_specificity(y_pred: Union[pd.Series, np.ndarray], y_pred_proba: Uni tn = np.where(y_pred == 0, 1 - y_pred_proba, 0) fp = np.where(y_pred == 1, 1 - y_pred_proba, 0) TN, FP = np.sum(tn), np.sum(fp) - if TN + FP == 0: - metric = 0 - else: - metric = TN / (TN + FP) - return metric + + denominator = TN + FP + return TN / denominator if denominator != 0 else 0 @MetricFactory.register('accuracy', ProblemType.CLASSIFICATION_BINARY)