From 268035f432bb8e23a78c4f72eaafa660e2584df1 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 6 Jul 2021 08:08:00 -0700 Subject: [PATCH] Type Fixes in Robust Metrics (#707) Summary: Fixes Mypy type checking issues with attack metrics to resolve CircleCI issues Pull Request resolved: https://github.com/pytorch/captum/pull/707 Reviewed By: NarineK Differential Revision: D29552315 Pulled By: vivekmig fbshipit-source-id: ba44d7e4121df30d26ac9e0bc796614ac726a9ed --- captum/_utils/common.py | 13 ++++ captum/attr/_utils/summarizer.py | 6 +- .../robust/_core/metrics/attack_comparator.py | 78 +++++++++++-------- .../_core/metrics/min_param_perturbation.py | 10 ++- tests/robust/test_min_param_perturbation.py | 6 +- 5 files changed, 72 insertions(+), 41 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 3e0ae30ae2..6c67009135 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -159,6 +159,19 @@ def _format_input(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, .. return _format_tensor_into_tuples(inputs) +def _format_float_or_tensor_into_tuples( + inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]] +) -> Tuple[Union[float, Tensor], ...]: + if not isinstance(inputs, tuple): + assert isinstance( + inputs, (torch.Tensor, float) + ), "`inputs` must have type float or torch.Tensor but {} found: ".format( + type(inputs) + ) + inputs = (inputs,) + return inputs + + @overload def _format_additional_forward_args(additional_forward_args: None) -> None: ... diff --git a/captum/attr/_utils/summarizer.py b/captum/attr/_utils/summarizer.py index 62a386843c..e99d74fbd6 100644 --- a/captum/attr/_utils/summarizer.py +++ b/captum/attr/_utils/summarizer.py @@ -43,7 +43,7 @@ def _copy_stats(self): return copy.deepcopy(self._stats) - def update(self, x: Union[Tensor, Tuple[Tensor, ...]]): + def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]): r""" Calls `update` on each `Stat` object within the summarizer @@ -57,9 +57,9 @@ def update(self, x: Union[Tensor, Tuple[Tensor, ...]]): # we want input to be consistently a single input or a tuple assert not (self._is_inputs_tuple ^ isinstance(x, tuple)) - from captum._utils.common import _format_tensor_into_tuples + from captum._utils.common import _format_float_or_tensor_into_tuples - x = _format_tensor_into_tuples(x) + x = _format_float_or_tensor_into_tuples(x) for i, inp in enumerate(x): if i >= len(self._summarizers): diff --git a/captum/robust/_core/metrics/attack_comparator.py b/captum/robust/_core/metrics/attack_comparator.py index 7555301f18..330a8c83aa 100644 --- a/captum/robust/_core/metrics/attack_comparator.py +++ b/captum/robust/_core/metrics/attack_comparator.py @@ -1,7 +1,19 @@ #!/usr/bin/env python3 import warnings from collections import namedtuple -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from torch import Tensor @@ -15,6 +27,10 @@ ORIGINAL_KEY = "Original" +MetricResultType = TypeVar( + "MetricResultType", float, Tensor, Tuple[Union[float, Tensor], ...] +) + class AttackInfo(NamedTuple): attack_fn: Union[Perturbation, Callable] @@ -33,7 +49,7 @@ def agg_metric(inp): return inp -class AttackComparator: +class AttackComparator(Generic[MetricResultType]): r""" Allows measuring model robustness for a given attack or set of attacks. 
This class can be used with any metric(s) as well as any set of attacks, either based on @@ -44,7 +60,7 @@ class AttackComparator: def __init__( self, forward_func: Callable, - metric: Callable[..., Union[float, Tensor, Tuple[Union[float, Tensor], ...]]], + metric: Callable[..., MetricResultType], preproc_fn: Callable = None, ) -> None: r""" @@ -74,10 +90,10 @@ def model_metric(model_out: Tensor, **kwargs: Any) additional_forward_args provided to evaluate. """ self.forward_func = forward_func - self.metric = metric + self.metric: Callable = metric self.preproc_fn = preproc_fn - self.attacks = {} - self.summary_results = {} + self.attacks: Dict[str, AttackInfo] = {} + self.summary_results: Dict[str, Summarizer] = {} self.metric_aggregator = agg_metric self.batch_stats = [Mean, Min, Max] self.aggregate_stats = [Mean] @@ -148,7 +164,7 @@ def add_attack( def _format_summary( self, summary: Union[Dict, List[Dict]] - ) -> Dict[str, Union[float, Tuple[float, ...]]]: + ) -> Dict[str, MetricResultType]: r""" This method reformats a given summary; particularly for tuples, the Summarizer's summary format is a list of dictionaries, @@ -159,12 +175,12 @@ def _format_summary( if isinstance(summary, dict): return summary else: - summary_dict = {} + summary_dict: Dict[str, Tuple] = {} for key in summary[0]: summary_dict[key] = tuple(s[key] for s in summary) if self.out_format: summary_dict[key] = self.out_format(*summary_dict[key]) - return summary_dict + return summary_dict # type: ignore def _update_out_format( self, out_metric: Union[float, Tensor, Tuple[Union[float, Tensor], ...]] @@ -174,7 +190,9 @@ def _update_out_format( and isinstance(out_metric, tuple) and hasattr(out_metric, "_fields") ): - self.out_format = namedtuple(type(out_metric).__name__, out_metric._fields) + self.out_format = namedtuple( # type: ignore + type(out_metric).__name__, cast(NamedTuple, out_metric)._fields + ) def _evaluate_batch( self, @@ -212,13 +230,10 @@ def _evaluate_batch( def evaluate( self, inputs: Any, - additional_forward_args: Optional[Tuple] = None, + additional_forward_args: Any = None, perturbations_per_eval: int = 1, **kwargs, - ) -> Dict[ - str, - Union[Tensor, Tuple[Tensor, ...], Dict[str, Union[Tensor, Tuple[Tensor, ...]]]], - ]: + ) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]: r""" Evaluate model and attack performance on provided inputs @@ -385,45 +400,44 @@ def _check_and_evaluate(input_list, key_list): def _parse_and_update_results( self, batch_summarizers: Dict[str, Summarizer] - ) -> Dict[ - str, Union[float, Tuple[float, ...], Dict[str, Union[float, Tuple[float, ...]]]] - ]: - results = { - ORIGINAL_KEY: self._format_summary(batch_summarizers[ORIGINAL_KEY].summary)[ - "mean" - ] + ) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]: + results: Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]] = { + ORIGINAL_KEY: self._format_summary( + cast(Union[Dict, List], batch_summarizers[ORIGINAL_KEY].summary) + )["mean"] } self.summary_results[ORIGINAL_KEY].update( self.metric_aggregator(results[ORIGINAL_KEY]) ) for attack_key in self.attacks: attack = self.attacks[attack_key] - results[attack.name] = self._format_summary( - batch_summarizers[attack.name].summary + attack_results = self._format_summary( + cast(Union[Dict, List], batch_summarizers[attack.name].summary) ) + results[attack.name] = attack_results - if len(results[attack.name]) == 1: - key = next(iter(results[attack.name])) + if len(attack_results) == 1: + key = next(iter(attack_results)) if 
attack.name not in self.summary_results: self.summary_results[attack.name] = Summarizer( [stat() for stat in self.aggregate_stats] ) self.summary_results[attack.name].update( - self.metric_aggregator(results[attack.name][key]) + self.metric_aggregator(attack_results[key]) ) else: - for key in results[attack.name]: + for key in attack_results: summary_key = f"{attack.name} {key.title()} Attempt" if summary_key not in self.summary_results: self.summary_results[summary_key] = Summarizer( [stat() for stat in self.aggregate_stats] ) self.summary_results[summary_key].update( - self.metric_aggregator(results[attack.name][key]) + self.metric_aggregator(attack_results[key]) ) return results - def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]: + def summary(self) -> Dict[str, Dict[str, MetricResultType]]: r""" Returns average results over all previous batches evaluated. @@ -440,7 +454,9 @@ def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]: per batch. """ return { - key: self._format_summary(self.summary_results[key].summary) + key: self._format_summary( + cast(Union[Dict, List], self.summary_results[key].summary) + ) for key in self.summary_results } diff --git a/captum/robust/_core/metrics/min_param_perturbation.py b/captum/robust/_core/metrics/min_param_perturbation.py index 74ec92d010..3c8a9268a8 100644 --- a/captum/robust/_core/metrics/min_param_perturbation.py +++ b/captum/robust/_core/metrics/min_param_perturbation.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import math from enum import Enum -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast import torch from torch import Tensor @@ -136,7 +136,9 @@ def correct_fn(model_out: Tensor, **kwargs: Any) -> bool self.num_attempts = num_attempts self.preproc_fn = preproc_fn self.apply_before_preproc = apply_before_preproc - self.correct_fn = correct_fn if correct_fn is not None else default_correct_fn + self.correct_fn = cast( + Callable, correct_fn if correct_fn is not None else default_correct_fn + ) assert ( mode.upper() in MinParamPerturbationMode.__members__ @@ -147,9 +149,9 @@ def _evaluate_batch( self, input_list: List, additional_forward_args: Any, - correct_fn_kwargs: Dict[str, Any], + correct_fn_kwargs: Optional[Dict[str, Any]], target: TargetType, - ) -> None: + ) -> Optional[int]: if additional_forward_args is None: additional_forward_args = () diff --git a/tests/robust/test_min_param_perturbation.py b/tests/robust/test_min_param_perturbation.py index aeed86f73b..a9b2cea842 100644 --- a/tests/robust/test_min_param_perturbation.py +++ b/tests/robust/test_min_param_perturbation.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import List +from typing import List, cast import torch from torch import Tensor @@ -55,7 +55,7 @@ def test_minimal_pert_basic_linear(self) -> None: target_inp, pert = minimal_pert.evaluate( inp, target=0, attack_kwargs={"ind": 0} ) - self.assertAlmostEqual(pert, 2.0) + self.assertAlmostEqual(cast(float, pert), 2.0) assertTensorAlmostEqual( self, target_inp, torch.tensor([[0.0, -9.0, 9.0, 1.0, -3.0]]) ) @@ -79,7 +79,7 @@ def test_minimal_pert_basic_binary(self) -> None: attack_kwargs={"ind": 0}, perturbations_per_eval=10, ) - self.assertAlmostEqual(pert, 2.0) + self.assertAlmostEqual(cast(float, pert), 2.0) assertTensorAlmostEqual( self, target_inp, torch.tensor([[0.0, -9.0, 9.0, 1.0, -3.0]]) )
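
The heart of this patch is the constrained MetricResultType TypeVar: making AttackComparator generic over it lets mypy bind the metric's concrete return type per instantiation, instead of threading the broad Union[float, Tensor, Tuple[Union[float, Tensor], ...]] through evaluate() and summary(). The sketch below shows the same pattern in isolation. Evaluator and accuracy are hypothetical stand-ins, not part of captum; _format_float_or_tensor_into_tuples is reproduced from the common.py hunk above (with the stray trailing ": " dropped from its assertion message).

    from typing import Callable, Generic, Tuple, TypeVar, Union

    import torch
    from torch import Tensor

    # Constrained TypeVar, as introduced in attack_comparator.py above: a
    # metric may return a float, a Tensor, or a tuple mixing the two.
    MetricResultType = TypeVar(
        "MetricResultType", float, Tensor, Tuple[Union[float, Tensor], ...]
    )


    def _format_float_or_tensor_into_tuples(
        inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
    ) -> Tuple[Union[float, Tensor], ...]:
        # Helper added by the patch to captum/_utils/common.py: wrap a bare
        # float or Tensor in a 1-tuple so callers can iterate uniformly.
        if not isinstance(inputs, tuple):
            assert isinstance(
                inputs, (torch.Tensor, float)
            ), "`inputs` must have type float or torch.Tensor but {} found".format(
                type(inputs)
            )
            inputs = (inputs,)
        return inputs


    class Evaluator(Generic[MetricResultType]):
        # Hypothetical stand-in for AttackComparator, stripped to the typing
        # skeleton: parametrizing the class over MetricResultType lets mypy
        # carry the metric's concrete return type through to evaluate().
        def __init__(self, metric: Callable[..., MetricResultType]) -> None:
            self.metric: Callable = metric

        def evaluate(self, model_out: Tensor) -> MetricResultType:
            return self.metric(model_out)


    def accuracy(model_out: Tensor) -> float:
        # Invented toy metric: fraction of rows whose argmax is class 0.
        return float((model_out.argmax(dim=1) == 0).float().mean())


    ev: Evaluator[float] = Evaluator(accuracy)
    score = ev.evaluate(torch.randn(8, 3))  # mypy infers: score is a float
    print(_format_float_or_tensor_into_tuples(score))  # e.g. (0.375,)

Because the TypeVar is value-constrained rather than bounded, mypy resolves it to exactly one of the three listed types at each use site, which is what permits the narrowed return annotations on evaluate(), _format_summary(), and summary() in the diff. The remaining cast(...) calls paper over the spots where Summarizer.summary is typed too loosely for that resolution to happen automatically.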