diff --git a/RELEASE.md b/RELEASE.md
index f90df85..26e8fae 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -4,6 +4,8 @@
 ## Major Features and Improvements
 
+ * Add fairness indicator metrics in the third_party library.
+
 ## Bug Fixes and Other Changes
 
 ## Breaking Changes
diff --git a/fairness_indicators/example_model.py b/fairness_indicators/example_model.py
index 09266b5..7bfdec2 100644
--- a/fairness_indicators/example_model.py
+++ b/fairness_indicators/example_model.py
@@ -19,10 +19,10 @@
 results can be visualized using tools like TensorBoard.
 """
 
+from fairness_indicators import fairness_indicators_metrics  # pylint: disable=unused-import
 from tensorflow import keras
 import tensorflow.compat.v1 as tf
 import tensorflow_model_analysis as tfma
-from tensorflow_model_analysis.addons.fairness.post_export_metrics import fairness_indicators  # pylint: disable=unused-import
 
 
 TEXT_FEATURE = 'comment_text'
diff --git a/fairness_indicators/fairness_indicators_metrics.py b/fairness_indicators/fairness_indicators_metrics.py
new file mode 100644
index 0000000..94b785c
--- /dev/null
+++ b/fairness_indicators/fairness_indicators_metrics.py
@@ -0,0 +1,208 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fairness Indicators Metrics."""
+
+import collections
+from typing import Any, Dict, List, Optional, Sequence
+
+from tensorflow_model_analysis.metrics import binary_confusion_matrices
+from tensorflow_model_analysis.metrics import metric_types
+from tensorflow_model_analysis.metrics import metric_util
+from tensorflow_model_analysis.proto import config_pb2
+
+FAIRNESS_INDICATORS_METRICS_NAME = 'fairness_indicators_metrics'
+FAIRNESS_INDICATORS_SUB_METRICS = (
+    'false_positive_rate',
+    'false_negative_rate',
+    'true_positive_rate',
+    'true_negative_rate',
+    'positive_rate',
+    'negative_rate',
+    'false_discovery_rate',
+    'false_omission_rate',
+    'precision',
+    'recall',
+)
+
+DEFAULT_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
+
+
+class FairnessIndicators(metric_types.Metric):
+  """Fairness indicators metrics."""
+
+  def computations_with_logging(self):
+    """Add streamz logging for fairness indicators."""
+
+    computations_fn = metric_util.merge_per_key_computations(
+        _fairness_indicators_metrics_at_thresholds
+    )
+
+    def merge_and_log_computations_fn(
+        eval_config: Optional[config_pb2.EvalConfig] = None,
+        # A tf metadata schema.
+        schema: Optional[Any] = None,
+        model_names: Optional[List[str]] = None,
+        output_names: Optional[List[str]] = None,
+        sub_keys: Optional[List[Optional[metric_types.SubKey]]] = None,
+        aggregation_type: Optional[metric_types.AggregationType] = None,
+        class_weights: Optional[Dict[int, float]] = None,
+        example_weighted: bool = False,
+        query_key: Optional[str] = None,
+        **kwargs
+    ):
+      return computations_fn(
+          eval_config,
+          schema,
+          model_names,
+          output_names,
+          sub_keys,
+          aggregation_type,
+          class_weights,
+          example_weighted,
+          query_key,
+          **kwargs
+      )
+
+    return merge_and_log_computations_fn
+
+  def __init__(
+      self,
+      thresholds: Sequence[float] = DEFAULT_THRESHOLDS,
+      name: str = FAIRNESS_INDICATORS_METRICS_NAME,
+  ):
+    """Initializes fairness indicators metrics.
+
+    Args:
+      thresholds: Thresholds to use for fairness metrics.
+      name: Metric name.
+    """
+    super().__init__(
+        self.computations_with_logging(), thresholds=thresholds, name=name
+    )
+
+
+def calculate_digits(thresholds):
+  digits = [len(str(t)) - 2 for t in thresholds]
+  return max(max(digits), 1)
+
+
+def _fairness_indicators_metrics_at_thresholds(
+    thresholds: List[float],
+    name: str = FAIRNESS_INDICATORS_METRICS_NAME,
+    eval_config: Optional[config_pb2.EvalConfig] = None,
+    model_name: str = '',
+    output_name: str = '',
+    aggregation_type: Optional[metric_types.AggregationType] = None,
+    sub_key: Optional[metric_types.SubKey] = None,
+    class_weights: Optional[Dict[int, float]] = None,
+    example_weighted: bool = False,
+) -> metric_types.MetricComputations:
+  """Returns computations for fairness metrics at thresholds."""
+  metric_key_by_name_by_threshold = collections.defaultdict(dict)
+  keys = []
+  digits_num = calculate_digits(thresholds)
+  for t in thresholds:
+    for m in FAIRNESS_INDICATORS_SUB_METRICS:
+      key = metric_types.MetricKey(
+          name='%s/%s@%.*f'
+          % (
+              name,
+              m,
+              digits_num,
+              t,
+          ),  # e.g. "fairness_indicators_metrics/positive_rate@0.5"
+          model_name=model_name,
+          output_name=output_name,
+          sub_key=sub_key,
+          example_weighted=example_weighted,
+      )
+      keys.append(key)
+      metric_key_by_name_by_threshold[t][m] = key
+
+  # Make sure matrices are calculated.
+  computations = binary_confusion_matrices.binary_confusion_matrices(
+      eval_config=eval_config,
+      model_name=model_name,
+      output_name=output_name,
+      sub_key=sub_key,
+      aggregation_type=aggregation_type,
+      class_weights=class_weights,
+      example_weighted=example_weighted,
+      thresholds=thresholds,
+  )
+  confusion_matrices_key = computations[-1].keys[-1]
+
+  def result(
+      metrics: Dict[metric_types.MetricKey, Any],
+  ) -> Dict[metric_types.MetricKey, Any]:
+    """Returns fairness metrics values."""
+    metric = metrics[confusion_matrices_key]
+    output = {}
+
+    for i, threshold in enumerate(thresholds):
+      num_positives = metric.tp[i] + metric.fn[i]
+      num_negatives = metric.tn[i] + metric.fp[i]
+
+      tpr = metric.tp[i] / (num_positives or float('nan'))
+      tnr = metric.tn[i] / (num_negatives or float('nan'))
+      fpr = metric.fp[i] / (num_negatives or float('nan'))
+      fnr = metric.fn[i] / (num_positives or float('nan'))
+      pr = (metric.tp[i] + metric.fp[i]) / (
+          (num_positives + num_negatives) or float('nan')
+      )
+      nr = (metric.tn[i] + metric.fn[i]) / (
+          (num_positives + num_negatives) or float('nan')
+      )
+      precision = metric.tp[i] / ((metric.tp[i] + metric.fp[i]) or float('nan'))
+      recall = metric.tp[i] / ((metric.tp[i] + metric.fn[i]) or float('nan'))
+
+      fdr = metric.fp[i] / ((metric.fp[i] + metric.tp[i]) or float('nan'))
+      fomr = metric.fn[i] / ((metric.fn[i] + metric.tn[i]) or float('nan'))
+
+      output[
+          metric_key_by_name_by_threshold[threshold]['false_positive_rate']
+      ] = fpr
+      output[
+          metric_key_by_name_by_threshold[threshold]['false_negative_rate']
+      ] = fnr
+      output[
+          metric_key_by_name_by_threshold[threshold]['true_positive_rate']
+      ] = tpr
+      output[
+          metric_key_by_name_by_threshold[threshold]['true_negative_rate']
+      ] = tnr
+      output[metric_key_by_name_by_threshold[threshold]['positive_rate']] = pr
+      output[metric_key_by_name_by_threshold[threshold]['negative_rate']] = nr
+      output[
+          metric_key_by_name_by_threshold[threshold]['false_discovery_rate']
+      ] = fdr
+      output[
+          metric_key_by_name_by_threshold[threshold]['false_omission_rate']
+      ] = fomr
+      output[metric_key_by_name_by_threshold[threshold]['precision']] = (
+          precision
+      )
+      output[metric_key_by_name_by_threshold[threshold]['recall']] = recall
+
+    return output
+
+  derived_computation = metric_types.DerivedMetricComputation(
+      keys=keys, result=result
+  )
+
+  computations.append(derived_computation)
+  return computations
+
+
+metric_types.register_metric(FairnessIndicators)
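Usage note (not part of the patch above): because the new module calls metric_types.register_metric(FairnessIndicators), the metric can be referenced from a TFMA MetricsSpec. The sketch below is a minimal, hypothetical wiring example; the label key 'toxicity', the slice column 'gender', and the threshold values are illustrative assumptions, not values taken from this change.

# Hypothetical sketch: evaluating with the newly added FairnessIndicators metric.
import tensorflow_model_analysis as tfma
from fairness_indicators import fairness_indicators_metrics  # pylint: disable=unused-import

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='toxicity')],  # assumed label key
    slicing_specs=[
        tfma.SlicingSpec(),                         # overall slice
        tfma.SlicingSpec(feature_keys=['gender']),  # assumed per-group slice
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(
                class_name='FairnessIndicators',
                module='fairness_indicators.fairness_indicators_metrics',
                config='{"thresholds": [0.3, 0.5, 0.7]}',
            ),
        ]),
    ],
)

With three thresholds, each slice would report the ten sub-metrics from FAIRNESS_INDICATORS_SUB_METRICS under keys such as fairness_indicators_metrics/false_positive_rate@0.5, following the key naming built in _fairness_indicators_metrics_at_thresholds.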