Image classification with PyTorch and model interpretability
Before running the code, set up the Python environment specified in the environment.yml file. To do so, run the following commands in your terminal:
$ conda env create -f environment.yml
$ conda activate pytorch_img_classif
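To remove the environment once you no longer need it, deactivate it first and then delete it:
$ conda deactivate
$ conda remove --name pytorch_img_classif --all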
The model training code is contained in the transfer_learning.ipynb notebook, which must be run as a standalone script. Besides training, it also evaluates the model and computes GradCAM heatmaps.
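The GradCAM step can be illustrated with a minimal, dependency-free sketch of the standard Grad-CAM computation. It assumes that the activations of one convolutional layer, and the gradients of the target class score with respect to them, have already been captured (in PyTorch this is commonly done with forward and backward hooks on the last convolutional layer, which is not shown here) and stored as nested lists indexed `[map][row][column]`. The function name `grad_cam` is hypothetical, and the notebook's own implementation may differ, for instance in how the map is normalized.

```python
from typing import List

# A stack of K feature maps, each H x W, stored as nested lists.
FeatureMaps = List[List[List[float]]]


def grad_cam(activations: FeatureMaps, gradients: FeatureMaps) -> List[List[float]]:
    """Compute a Grad-CAM heatmap for one convolutional layer.

    activations[k][i][j]: value of feature map k at spatial position (i, j).
    gradients[k][i][j]: derivative of the target class score w.r.t. that value.
    Returns an H x W heatmap rescaled to [0, 1].
    """
    num_maps = len(activations)
    height, width = len(activations[0]), len(activations[0][0])

    # Step 1: one weight per feature map = spatial average of its gradients.
    weights = [
        sum(sum(row) for row in gradients[k]) / (height * width)
        for k in range(num_maps)
    ]

    # Step 2: weighted sum of the feature maps followed by a ReLU, so only
    # regions that increase the class score are kept.
    cam = [
        [
            max(0.0, sum(weights[k] * activations[k][i][j] for k in range(num_maps)))
            for j in range(width)
        ]
        for i in range(height)
    ]

    # Step 3: rescale to [0, 1] for display (an all-zero map is left unchanged).
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[value / peak for value in row] for row in cam]
    return cam
```

The resulting coarse map is then typically upsampled to the input image resolution and overlaid on the image to obtain a heatmap.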
Before running the script, make sure you have set the data_dir Python variable to the path of the folder containing your dataset.
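A quick sanity check of `data_dir` before launching training can save a failed run. The sketch below assumes a torchvision `ImageFolder`-style layout, with one subfolder per split (for example `train/` and `test/`) and one subfolder per class inside each split; that layout, the example path, and the helper names are hypothetical and should be adapted to how your copy of the dataset is organized.

```python
from pathlib import Path
from typing import Dict

# Hypothetical dataset location: replace with the path to your own folder.
data_dir = "path/to/casting_dataset"

# File extensions treated as images (an assumption; extend as needed).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(split_dir: Path) -> Dict[str, int]:
    """Count image files in each class subfolder of one split (e.g. train/)."""
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def check_data_dir(root: str) -> Dict[str, Dict[str, int]]:
    """Return {split: {class: n_images}}, failing early if data_dir is wrong."""
    root_path = Path(root)
    if not root_path.is_dir():
        raise FileNotFoundError(f"data_dir does not exist: {root_path}")
    summary = {
        split_dir.name: count_images_per_class(split_dir)
        for split_dir in sorted(p for p in root_path.iterdir() if p.is_dir())
    }
    if not summary:
        raise ValueError(f"no split folders found under {root_path}")
    return summary
```

Calling `print(check_data_dir(data_dir))` once the path is set shows how many images were found per split and per class.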
This repository contains the results obtained by running this transfer learning code with VGG16 on the casting defect detection dataset from Kaggle. This dataset contains images of cast manufacturing products, and the task is to classify each product as defective or not defective.
The first attempts yielded an accuracy greater than 99%. Moreover, the GradCAM heatmaps computed on test images are meaningful, as most of the defects are highlighted by the method (see figure below).