This is a wrapper I built around a ResNet18 model fine-tuned on the Flowers102 dataset.
Flowers102 is a flower image dataset provided by the University of Oxford. It contains images of varying sizes, each belonging to one of 102 classes. The training and validation splits have 1020 images each, while the test split has 6149.
The repository contains:
- Training.ipynb: A notebook that shows the transfer learning process.
- FlowersResnetWrapper.py: The wrapper class for ResNet18.
- resnet.pth: The PyTorch checkpoint used by the wrapper.
- labels.txt: The list of human-readable classes needed for the wrapper to work.
To initialize the wrapper, you only need to provide the checkpoint's path, either as a string or as a Path object.
The wrapper accepts either PIL images or tensors. If a tensor is provided, the model expects a tensor of shape either
The predictions are returned as a dictionary with three keys:
- logits: The raw logits produced by the model.
- class_ids: The class id of each prediction, computed as $\arg\max(\mathrm{softmax}(\text{logits}))$.
- class_names: The human-readable name of the class id of each prediction.
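A sketch of how such a dictionary can be assembled from a batch of logits and the contents of labels.txt. It assumes labels.txt holds one class name per line in class-id order; load_labels and format_predictions are hypothetical names, not the wrapper's actual methods.

```python
from pathlib import Path
from typing import Dict, List, Union

import torch


def load_labels(path: Union[str, Path]) -> List[str]:
    """ASSUMPTION: one class name per line, line i is the name of class id i."""
    with open(path, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) are skipped.
        return [line.strip() for line in f if line.strip()]


def format_predictions(logits: torch.Tensor, labels: List[str]) -> Dict[str, object]:
    """Build the three-key output from logits of shape (batch, num_classes)."""
    probs = torch.softmax(logits, dim=1)
    # softmax is monotonic, so this equals logits.argmax(dim=1).
    class_ids = probs.argmax(dim=1)
    return {
        "logits": logits,
        "class_ids": class_ids,
        "class_names": [labels[i] for i in class_ids.tolist()],
    }
```

Because softmax preserves the ordering of the logits, the softmax step does not change class_ids; it only matters if you also want per-class probabilities.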
Requirements:
- torch==2.4.1+cu124
- torchvision==0.19.1+cu124
- pillow==10.4.0
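To confirm that an environment matches the pinned versions above, a small standard-library sketch using importlib.metadata. The names PINS and check_pins are hypothetical, and the exact string comparison is a deliberate simplification (for example, it treats 2.4.1 and 2.4.1+cu124 as different versions).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict

# The pinned versions from the list above.
PINS = {"torch": "2.4.1+cu124", "torchvision": "0.19.1+cu124", "pillow": "10.4.0"}


def check_pins(pins: Dict[str, str], lookup: Callable[[str], str] = version) -> Dict[str, str]:
    """Return {package: problem} for every pin that is missing or differs."""
    problems = {}
    for package, wanted in pins.items():
        try:
            installed = lookup(package)
        except PackageNotFoundError:
            problems[package] = "not installed"
            continue
        if installed != wanted:
            problems[package] = f"installed {installed}, pinned {wanted}"
    return problems
```

Running check_pins(PINS) returns an empty dictionary when every pinned package is installed at exactly the pinned version.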