This repository contains an implementation of a Vision Transformer (ViT) model designed to classify brain tumor images into four categories (Meningioma, Pituitary, Glioma and No tumor). The model can be trained on any dataset of brain tumor MRI scans.
├── data/ # Dataset directory (Not included)
├── best_model.pth # Saved model with the best validation accuracy
├── cleanup.py # Script for cleaning up the dataset
├── requirements.txt # List of required Python packages
├── test.py # Script for testing the model on new images
├── train.py # Script for training the model
└── transformer.py # Definition of the Vision Transformer model
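For orientation, a Vision Transformer starts by cutting each input image into fixed-size patches and treating the flattened patches as a token sequence. The sketch below illustrates only that first step in plain Python. The image size (224) and patch size (16) are assumed values for illustration, and the helper names `patch_grid` and `split_into_patches` are hypothetical; none of them are taken from `transformer.py`.

```python
from typing import List, Tuple


def patch_grid(image_size: int, patch_size: int) -> Tuple[int, int]:
    """Return (patches_per_side, total_patches) for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side, per_side * per_side


def split_into_patches(image: List[List[int]], patch_size: int) -> List[List[int]]:
    """Split a 2-D list (H x W) into flattened patches in row-major order."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [
                image[r][c]
                for r in range(top, top + patch_size)
                for c in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


# Assumed values for illustration: a 224x224 input with 16x16 patches.
per_side, total = patch_grid(224, 16)
print(per_side, total)  # 14 196

# Tiny 4x4 single-channel "image" split into 2x2 patches.
toy = [[r * 4 + c for c in range(4)] for r in range(4)]
toy_patches = split_into_patches(toy, 2)
print(toy_patches[0])  # [0, 1, 4, 5]
```

In a real model each flattened patch would then be projected to an embedding vector and combined with a position embedding before entering the transformer encoder.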
- Clone the repository:
git clone https://github.com/marvelefe/VIT-POC.git
cd VIT-POC
- Create a virtual environment and install dependencies:
python3 -m venv venv
source venv/bin/activate # On Windows, use `venv\Scripts\activate`
pip install -r requirements.txt
- Prepare the dataset:
  - Place your dataset under the `./data` directory, split into training and validation sets in two sub-directories: `./data/Training` and `./data/Testing`.
A sample dataset is available on Kaggle: https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
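Before training, it can help to confirm that the directory layout matches what is described above. The following is a minimal sketch, not part of the repository: it assumes each split contains one sub-folder per class, and the folder names in `EXPECTED_CLASSES` are an assumption that you should adjust to match your dataset. The function names `summarize_split` and `check_dataset` are hypothetical.

```python
from pathlib import Path

# Assumed class folder names; change them to match your dataset's folders.
EXPECTED_CLASSES = ["glioma", "meningioma", "notumor", "pituitary"]
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def summarize_split(split_dir: Path) -> dict:
    """Count image files per class folder under one split directory."""
    counts = {}
    for class_name in EXPECTED_CLASSES:
        class_dir = split_dir / class_name
        if not class_dir.is_dir():
            counts[class_name] = None  # None marks a missing folder
            continue
        counts[class_name] = sum(
            1
            for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def check_dataset(root: Path = Path("data")) -> dict:
    """Summarize both splits described in this README."""
    return {split: summarize_split(root / split) for split in ("Training", "Testing")}


if __name__ == "__main__":
    for split, counts in check_dataset().items():
        for class_name, n in counts.items():
            status = "MISSING" if n is None else f"{n} images"
            print(f"{split}/{class_name}: {status}")
```

Running it from the repository root prints one line per split and class, so a missing or empty folder is easy to spot before a long training run.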
A snapshot of images from the dataset:
To train the model, run:
python train.py
Training progress, including loss and accuracy for both the training and validation sets, is displayed and saved as plots (`accuracy.png`, `loss.png`). The best-performing model is saved as `best_model.pth`.
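The "save only the best model" behaviour can be sketched as a small tracker that triggers a save whenever validation accuracy improves. This is an illustrative sketch under assumptions, not the code in `train.py`: the class `BestModelTracker` is hypothetical, and the accuracy values are made up.

```python
class BestModelTracker:
    """Call `save_fn(epoch)` only when validation accuracy improves."""

    def __init__(self, save_fn):
        self.save_fn = save_fn
        self.best_accuracy = float("-inf")
        self.best_epoch = None

    def update(self, epoch: int, val_accuracy: float) -> bool:
        """Record one epoch; return True if a new best was saved."""
        if val_accuracy > self.best_accuracy:
            self.best_accuracy = val_accuracy
            self.best_epoch = epoch
            self.save_fn(epoch)
            return True
        return False


# Made-up validation accuracies, for illustration only.
history = [0.71, 0.78, 0.76, 0.83, 0.81]

# Here the "save" just records the epoch. In a PyTorch loop it could be
# something like: lambda epoch: torch.save(model.state_dict(), "best_model.pth")
saved_epochs = []
tracker = BestModelTracker(save_fn=saved_epochs.append)

for epoch, accuracy in enumerate(history, start=1):
    tracker.update(epoch, accuracy)

print(saved_epochs)                               # [1, 2, 4]
print(tracker.best_epoch, tracker.best_accuracy)  # 4 0.83
```

Comparing with a strict `>` means a later epoch that merely ties the best accuracy does not overwrite the saved checkpoint.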
Training Accuracy and Validation Loss History:
Combined Training and Validation History (Epochs 1 to 20):
During training, the model is evaluated against the validation dataset. After training, a confusion matrix is generated and saved as `confusion_matrix.png`; it shows which classes the model confuses with one another.
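To make the confusion matrix concrete, here is a minimal plain-Python sketch of how one is built from true and predicted labels, and how overall accuracy and per-class recall follow from it. The class order in `CLASSES` and the label lists are assumptions made up for illustration; the repository may compute and plot its matrix differently.

```python
# Assumed class order, for illustration only.
CLASSES = ["Glioma", "Meningioma", "No tumor", "Pituitary"]


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        matrix[true_label][predicted_label] += 1
    return matrix


def per_class_recall(matrix):
    """Fraction of each true class that was predicted correctly."""
    recalls = []
    for i, row in enumerate(matrix):
        total = sum(row)
        recalls.append(row[i] / total if total else 0.0)
    return recalls


# Made-up labels (indices into CLASSES).
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = [0, 1, 1, 1, 2, 2, 3, 0]

matrix = confusion_matrix(y_true, y_pred, len(CLASSES))
accuracy = sum(matrix[i][i] for i in range(len(CLASSES))) / len(y_true)
recalls = per_class_recall(matrix)

for name, row, recall in zip(CLASSES, matrix, recalls):
    print(f"{name:<11} {row}  recall={recall:.2f}")
print(f"accuracy={accuracy:.2f}")  # accuracy=0.75
```

Off-diagonal entries are the informative part: in this toy data one Glioma scan is predicted as Meningioma and one Pituitary scan as Glioma.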
To classify a new image, use the `test.py` script:
python test.py
Replace the image path in the script with the path to your target image. The script prints the predicted tumor class and its confidence level, and saves a bar chart of the result as `prediction_result.png`.
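A confidence level of this kind is commonly obtained by applying softmax to the model's raw output scores and reporting the largest probability. The sketch below shows that calculation in plain Python. It assumes softmax is used (the README does not say how `test.py` derives confidence), and the class order and logits are made up for illustration.

```python
import math

# Assumed class order, for illustration only.
CLASSES = ["Glioma", "Meningioma", "No tumor", "Pituitary"]


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits, class_names):
    """Return (predicted class name, confidence in [0, 1])."""
    probabilities = softmax(logits)
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return class_names[best], probabilities[best]


# Made-up model outputs for a single image.
logits = [0.2, 2.9, -1.0, 0.4]
label, confidence = predict(logits, CLASSES)
print(f"{label}: {confidence:.1%}")
```

The full probability list returned by `softmax` is what a per-class bar chart would display.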
The model achieves 95.58% accuracy in classifying tumor images across the four classes. Performance metrics, including accuracy and loss plots, are available in the repository.