This repository contains code for training, evaluating and testing image classification models implemented from scratch using PyTorch.
Classes: Ramen 🍜, Sashimi 🐟, Sushi 🍣, Takoyaki 🍡
Model | Variants |
---|---|
VGG | 11, 13, 16, 19 |
Resnet | 18, 34, 50, 101, 152 |
Mobilenet | v2, v3-small, v3-large |
Shufflenet | v2 |
ConvNeXt | tiny, small, base, large |
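For VGG and ResNet, the variant number is the network depth, counting weighted layers (convolutional plus fully connected; ResNet's projection shortcuts are not counted). The sketch below checks that arithmetic using the per-stage layouts from the original VGG and ResNet papers. The dictionary and function names are illustrative and are not taken from `models/`.

```python
# Per-stage layouts from the original papers (illustrative names, not the repo's code).
# ResNet: (block type, number of residual blocks per stage)
RESNET_BLOCKS = {
    "18": ("basic", [2, 2, 2, 2]),
    "34": ("basic", [3, 4, 6, 3]),
    "50": ("bottleneck", [3, 4, 6, 3]),
    "101": ("bottleneck", [3, 4, 23, 3]),
    "152": ("bottleneck", [3, 8, 36, 3]),
}
# VGG: number of 3x3 conv layers in each of the five stages
VGG_CONVS = {
    "11": [1, 1, 2, 2, 2],
    "13": [2, 2, 2, 2, 2],
    "16": [2, 2, 3, 3, 3],
    "19": [2, 2, 4, 4, 4],
}
# A basic block has two 3x3 convs; a bottleneck has 1x1, 3x3, 1x1.
CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def resnet_depth(variant: str) -> int:
    """Stem conv + convs inside all residual blocks + final fully connected layer."""
    block, counts = RESNET_BLOCKS[variant]
    return 1 + CONVS_PER_BLOCK[block] * sum(counts) + 1


def vgg_depth(variant: str) -> int:
    """All conv layers + the three fully connected layers of the classifier."""
    return sum(VGG_CONVS[variant]) + 3
```

For example, ResNet-50 has 16 bottleneck blocks, so its depth is 1 + 3 × 16 + 1 = 50.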
Install the required dependencies (see `requirements.txt`):
```bash
conda create -n jpfood python=3.8
conda activate jpfood
pip install -r requirements.txt
```
The dataset used in this repository is a subset of the Food-101 dataset.
The original dataset contains 1,000 images per class.
This repository uses 200 training, 100 validation and 100 test images.
The Japanese food dataset can be downloaded from the link.
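If you want to rebuild a similar split from Food-101 yourself, one simple approach is to shuffle each class's images with a fixed seed and slice off the three subsets. This is a minimal sketch under two assumptions: the 200/100/100 counts are per class, and the images are drawn at random. The repository's actual selection procedure is not documented here, and `split_per_class` is a hypothetical helper.

```python
import random
from typing import List, Sequence, Tuple


def split_per_class(
    files: Sequence[str],
    n_train: int = 200,
    n_val: int = 100,
    n_test: int = 100,
    seed: int = 0,
) -> Tuple[List[str], List[str], List[str]]:
    """Shuffle one class's file list reproducibly and cut it into train/val/test."""
    needed = n_train + n_val + n_test
    if len(files) < needed:
        raise ValueError(f"need at least {needed} files, got {len(files)}")
    rng = random.Random(seed)      # local RNG so the global random state is untouched
    shuffled = list(files)
    rng.shuffle(shuffled)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:needed]  # remaining images are left unused
    return train, val, test
```

Calling it once per class folder with the same seed keeps the three subsets disjoint and reproducible across runs.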
Model implementations are available in the `models/` folder as `models/{model}.py`.
To load PyTorch's pretrained model weights into our model architectures, run:
```bash
python convert.py --model resnet --variant 50
```
Options:
- `--model`: model name {vgg, resnet, mobilenet}
- `--variant`: model variant
- `--width_mulit`: channel width ratio for MobileNet
- `--num_class`: number of output classes (default: 1000, the number of ImageNet classes)

Pretrained models will be saved as `pretrained/{model}.pt`.
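One common way to transfer pretrained weights into an independently written architecture is to pair parameters by position, since layer names usually differ while the order and shapes match. The sketch below assumes exactly that, representing each `state_dict()` as an ordered mapping from parameter name to shape. `map_by_position` is a hypothetical helper and not necessarily what `convert.py` does.

```python
from collections import OrderedDict
from typing import Mapping, Tuple


def map_by_position(
    src_shapes: Mapping[str, Tuple[int, ...]],
    dst_shapes: Mapping[str, Tuple[int, ...]],
) -> "OrderedDict[str, str]":
    """Pair destination names with source names by position, checking shapes.

    Both arguments are ordered name -> shape mappings (stand-ins for state_dict()
    entries). Returns an OrderedDict of destination name -> source name.
    """
    if len(src_shapes) != len(dst_shapes):
        raise ValueError(f"parameter count mismatch: {len(src_shapes)} vs {len(dst_shapes)}")
    mapping: "OrderedDict[str, str]" = OrderedDict()
    for (src_name, src_shape), (dst_name, dst_shape) in zip(src_shapes.items(), dst_shapes.items()):
        # A shape mismatch at any position means the two orderings do not line up.
        if tuple(src_shape) != tuple(dst_shape):
            raise ValueError(f"shape mismatch: {src_name} {src_shape} vs {dst_name} {dst_shape}")
        mapping[dst_name] = src_name
    return mapping
```

With real tensors, the mapping could then be applied as `{dst: src_state[src] for dst, src in mapping.items()}` and passed to `model.load_state_dict(...)`. Checking the shape at every position catches an ordering mismatch early. It also shows why `--num_class` matters: a classifier head with a class count other than 1000 will not match the ImageNet weights.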
Before training and testing, set the desired model and variant in `config.yaml`:
```yaml
model:
  name: 'resnet'
  variant: '50'
```
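Quoting the variant (`'50'`) keeps it a string when the file is parsed. A minimal sketch of validating the parsed values against the model table follows, assuming `config.yaml` has already been loaded into a dict (for example with PyYAML's `yaml.safe_load`). `select_model` and the exact variant spellings such as `"v3-small"` are hypothetical.

```python
from typing import Dict, Tuple

# Variants per model, copied from the model table in this README.
# Assumption: config.yaml spells them exactly like this (lowercase, e.g. "v3-small").
SUPPORTED_VARIANTS = {
    "vgg": {"11", "13", "16", "19"},
    "resnet": {"18", "34", "50", "101", "152"},
    "mobilenet": {"v2", "v3-small", "v3-large"},
    "shufflenet": {"v2"},
    "convnext": {"tiny", "small", "base", "large"},
}


def select_model(config: Dict) -> Tuple[str, str]:
    """Return a validated (name, variant) pair from a parsed config.yaml dict."""
    model_cfg = config["model"]
    # str() also covers an unquoted variant such as 50, which YAML parses as an int.
    name = str(model_cfg["name"]).lower()
    variant = str(model_cfg["variant"]).lower()
    if name not in SUPPORTED_VARIANTS:
        raise ValueError(f"unknown model {name!r}; choose from {sorted(SUPPORTED_VARIANTS)}")
    if variant not in SUPPORTED_VARIANTS[name]:
        raise ValueError(
            f"unknown {name} variant {variant!r}; choose from {sorted(SUPPORTED_VARIANTS[name])}"
        )
    return name, variant
```

Failing fast on an unsupported pair avoids discovering a typo in `config.yaml` only after data loading has started.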
Recommended hyperparameters:
Model | Optimizer | Learning rate |
---|---|---|
VGG | sgd | 0.001 |
Resnet | sgd | 0.001 |
Mobilenet | sgd | 0.001 |
Shufflenet | sgd | 0.005 |
ConvNeXt | adamw | 0.00005 |
For training, run:
```bash
python train.py
```
Options:
- `--eval`: evaluate the model without training again
- `--eval_model`: path of the model to evaluate
The results of each training run are saved in the `experiments/{model_name}{datetime}/` folder.
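The two training options behave like a boolean switch and a path argument. A minimal `argparse` sketch of that interface follows. It only mirrors the option names listed for `train.py`; the defaults (`False` and `None`) and the helper name `build_train_parser` are assumptions, not necessarily what `train.py` uses.

```python
import argparse


def build_train_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented train.py options."""
    parser = argparse.ArgumentParser(description="Train or evaluate a classifier (sketch).")
    # Flag: present -> True, absent -> False.
    parser.add_argument("--eval", action="store_true",
                        help="evaluate the model without training again")
    # Path to a saved model; None when not given.
    parser.add_argument("--eval_model", type=str, default=None,
                        help="path of the model to evaluate")
    return parser
```

Under this interface, evaluating a saved model would look like `python train.py --eval --eval_model <path to saved model>`.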
For inference, run:
```bash
python inference.py --image {image_path} --viz
```
Options:
- `--image`: path to the folder of images to run inference on
- `--viz`: visualize the inference results
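Folder-based inference boils down to collecting the image files and turning each model output (one logit per class) into a top-1 label. The sketch below shows that post-processing with the standard library only. It assumes class indices follow alphabetical folder order (as torchvision's `ImageFolder` assigns them) and lowercase class names; `list_images`, `top1` and the extension set are hypothetical.

```python
import math
from pathlib import Path
from typing import List, Sequence, Tuple

# Assumption: index order is alphabetical by class folder name.
CLASSES = ["ramen", "sashimi", "sushi", "takoyaki"]
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def list_images(folder: str) -> List[Path]:
    """Image files directly inside `folder`, sorted for a stable order."""
    return sorted(p for p in Path(folder).iterdir()
                  if p.is_file() and p.suffix.lower() in IMAGE_EXTS)


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top1(logits: Sequence[float], classes: Sequence[str] = CLASSES) -> Tuple[str, float]:
    """Return the most probable class name and its probability."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return classes[idx], probs[idx]
```

For example, logits of `[0.1, 2.0, 0.3, -1.0]` select index 1, which is `"sashimi"` under the assumed ordering.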
For C++ inference, update the required inputs and run:
```bash
python trace_model.py
```
to convert the PyTorch model into a TorchScript model. After tracing, use the traced model with the libtorch-inference repo.
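Tracing runs the model once on an example input and records the executed operations as a TorchScript module that libtorch can load without Python. The sketch below applies `torch.jit.trace` to a hypothetical `TinyClassifier` standing in for one of the models in `models/`. The 224×224 default input size and the output file name are assumptions; match them to what `trace_model.py` and your C++ code expect.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a model from models/; 4 outputs for the 4 food classes."""

    def __init__(self, num_classes: int = 4):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)                 # (N, 8, 1, 1)
        return self.fc(torch.flatten(x, 1))  # (N, num_classes)


def trace_and_save(model: nn.Module, path: str, image_size: int = 224) -> torch.jit.ScriptModule:
    """Trace `model` with a dummy RGB batch and save the TorchScript module to `path`."""
    model.eval()  # put dropout / batch norm in inference mode before tracing
    example = torch.rand(1, 3, image_size, image_size)
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(path)
    return traced
```

On the C++ side, libtorch loads such a file with `torch::jit::load(path)`. Because tracing records only the operations executed for the example input, a model with data-dependent control flow would need `torch.jit.script` instead.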