PyTorch-CAM

This project provides scripts for class activation map (CAM) visualization, which can be used to explain model predictions and improve model interpretability.

Installation

  • Install via pip
$ pip install torchcam
  • Install from source
$ pip install --upgrade git+https://github.com/Tramac/pytorch-cam.git

Usage

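from torchcam import open_image, image2batch, int2tensor, getCAM
from torchvision.models import resnet18

# Load the image, resize it to 224x224 and convert it to RGB
img = open_image('./data/cat.jpg', (224, 224), convert_mode='RGB')
input = image2batch(img)      # image -> input batch for the model
image_class = 284             # ImageNet class index to explain (a cat class)
target = int2tensor(image_class)

model = resnet18(pretrained=True)
model.eval()                  # inference mode: fixed BatchNorm statistics, no dropout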
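For reference, a pretrained torchvision ResNet expects a batch shaped N x 3 x H x W, scaled to [0, 1] and normalized with the ImageNet per-channel mean and standard deviation. The numpy sketch below shows that standard preprocessing. `to_normalized_batch` is a hypothetical helper, and whether `image2batch` normalizes in exactly this way is an assumption, not something this README states.

```python
import numpy as np

# Standard ImageNet statistics used by torchvision's pretrained models.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_batch(img_hwc: np.ndarray) -> np.ndarray:
    """Hypothetical helper: H x W x 3 uint8 RGB image -> 1 x 3 x H x W float32 batch."""
    x = img_hwc.astype(np.float32) / 255.0   # scale pixel values to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD   # per-channel normalization (broadcast on last axis)
    x = np.transpose(x, (2, 0, 1))           # HWC -> CHW
    return x[np.newaxis, ...]                # add the batch dimension


img = np.full((224, 224, 3), 128, dtype=np.uint8)  # stand-in for a loaded 224x224 image
batch = to_normalized_batch(img)
print(batch.shape)  # (1, 3, 224, 224)
```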
  • Basic
# gradcam
cam = getCAM(model, img, input, target, display=True, save=False)
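Grad-CAM weights each feature map of a target convolutional layer by the spatial average of the class score's gradient with respect to that map, sums the weighted maps, and keeps only the positive part. The numpy sketch below shows just that combination step; it is not torchcam's implementation, and it assumes the layer's activations and gradients for one image have already been captured (in PyTorch, typically with forward and backward hooks on the layer).

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Combine K x H x W feature maps and their gradients into an H x W heatmap.

    activations: A^k, the target layer's output for one image.
    gradients:   dy^c / dA^k, gradient of the class score w.r.t. that output.
    """
    # Channel weights alpha_k: global-average-pool the gradients over space.
    weights = gradients.mean(axis=(1, 2))                              # shape (K,)
    # Weighted sum over channels, then ReLU to keep positive evidence only.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # shape (H, W)
    # Rescale to [0, 1] for display; an all-zero map is returned unchanged.
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy example: channel 0 fires at (3, 4) and supports the class,
# channel 1 fires at (0, 0) and argues against it.
acts = np.zeros((2, 7, 7))
acts[0, 3, 4] = 5.0
acts[1, 0, 0] = 5.0
grads = np.stack([np.full((7, 7), 1.0), np.full((7, 7), -1.0)])

heatmap = grad_cam(acts, grads)
peak_pos = tuple(int(i) for i in np.unravel_index(heatmap.argmax(), heatmap.shape))
print(peak_pos)  # (3, 4)
```

For ResNet-18 the last convolutional stage produces a 7x7 map for a 224x224 input, so a full pipeline would then upsample the heatmap to the image size before overlaying it.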
  • Advanced

Besides the default gradcam method, the following additional methods are also available: vanilla_grad, grad_x_input, saliency, integrate_grad, deconv, smooth_grad.

from torchcam import saliency

results = saliency.get_image_saliency_result(model, img, input, target, methods=['smooth_grad', 'vanilla_grad', 'grad_x_input', 'saliency'])
figure = saliency.get_image_saliency_plot(results, display=True, save=False)
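Several of these methods are simple functions of the input gradient. The sketch below illustrates vanilla gradients, gradient x input, SmoothGrad (averaging gradients over noisy copies of the input) and integrated gradients (averaging gradients along a straight path from a baseline to the input). To stay self-contained it uses a hypothetical toy model `f(x) = sigmoid(w . x)` with an analytic gradient in place of a network and autograd; the function names, noise level and step count are illustrative assumptions, not torchcam's API or defaults.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=8)  # weights of the toy model f(x) = sigmoid(w . x)


def f(x: np.ndarray) -> float:
    return 1.0 / (1.0 + np.exp(-(w @ x)))


def grad_f(x: np.ndarray) -> np.ndarray:
    s = f(x)
    return s * (1.0 - s) * w  # analytic gradient of sigmoid(w . x)


def vanilla_grad(x):
    return grad_f(x)


def grad_x_input(x):
    return x * grad_f(x)


def smooth_grad(x, n_samples=50, sigma=0.1, seed=0):
    # Average the gradient over n_samples Gaussian-perturbed copies of x.
    noise_rng = np.random.default_rng(seed)
    noisy = x + noise_rng.normal(scale=sigma, size=(n_samples, x.size))
    return np.mean([grad_f(xi) for xi in noisy], axis=0)


def integrated_gradients(x, baseline=None, steps=100):
    # Midpoint Riemann sum of the gradient along baseline -> x, scaled by (x - baseline).
    baseline = np.zeros_like(x) if baseline is None else baseline
    alphas = (np.arange(steps) + 0.5) / steps
    avg_grad = np.mean([grad_f(baseline + a * (x - baseline)) for a in alphas], axis=0)
    return (x - baseline) * avg_grad


x = rng.normal(size=8)
ig = integrated_gradients(x)
# Completeness check: attributions sum to f(x) - f(baseline).
print(np.isclose(ig.sum(), f(x) - f(np.zeros_like(x)), atol=1e-3))  # True
```

The completeness check is a useful sanity test for any integrated-gradients implementation: with enough steps, the attributions should add up to the change in model output between the baseline and the input.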
  • Your own model
model = YourModel()  # your own torch.nn.Module

cam = getCAM(model, img, input, target, layer_path=['xxx']) # replace 'xxx' with the key of the target layer in your model (the layer where backpropagation stops)
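To pick a value for `layer_path`, it helps to see how a list of keys can be resolved to a layer. The sketch below assumes, without confirmation from this README, that each entry names one level of nesting; `resolve_layer` is a hypothetical helper that walks the keys with `getattr`. On a real PyTorch model, `model.named_modules()` lists the valid names, `getattr` also reaches numbered children of `nn.Sequential` (e.g. `'0'`), and `model.get_submodule('layer4.1.conv2')` resolves a dotted name directly.

```python
def resolve_layer(model, layer_path):
    """Hypothetical helper: follow a list of attribute names down to a submodule."""
    layer = model
    for key in layer_path:
        if not hasattr(layer, key):
            raise KeyError(f"{type(layer).__name__} has no child named {key!r}")
        layer = getattr(layer, key)
    return layer


# Toy stand-ins for a nested model; a string plays the role of the target layer.
class Block:
    def __init__(self):
        self.conv2 = "conv2-layer"


class ToyNet:
    def __init__(self):
        self.layer4 = Block()


net = ToyNet()
print(resolve_layer(net, ["layer4", "conv2"]))  # conv2-layer
```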

Result

TODO

  • support more explainers
  • optimize code
  • test your own model

Reference