In this project, we first develop code for an image classifier built with PyTorch, then convert it into a command-line application. A GPU is needed to train the deep learning model.
- Train Image Classifier Model: uses a pre-trained deep neural network (VGG/DenseNet) to train a classifier that recognizes different species of flowers (a dataset of 102 flower categories).
- Python App: lets the user pass command-line arguments to train the model and make predictions.
- Clone the repository and navigate to the downloaded folder. Cloning may take a minute or two because of the included image data.
$ git clone https://github.com/nalbert9/AI_App_Classification_Flowers_Python.git
- Create (and activate) a new Anaconda environment.
- Install PyTorch and torchvision. Because the command pins cudatoolkit to 11.3, this should install the latest PyTorch build available for that CUDA version:
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
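After installing, it can help to confirm that the packages are importable and that a GPU is visible before starting a long training run. The following is a minimal sketch, not part of the repository: the helper name `check_environment` is hypothetical, and it relies only on `torch.__version__` and `torch.cuda.is_available()`.

```python
import importlib
import importlib.util


def check_environment():
    """Report which required packages are importable and whether CUDA is usable."""
    report = {}
    for name in ("torch", "torchvision"):
        # find_spec returns None when a top-level package is not installed.
        report[name] = importlib.util.find_spec(name) is not None

    report["cuda"] = False
    if report["torch"]:
        try:
            torch = importlib.import_module("torch")
            report["torch_version"] = torch.__version__
            report["cuda"] = bool(torch.cuda.is_available())
        except ImportError:
            # The package is present but cannot be loaded (e.g. a broken install).
            report["torch"] = False
    return report


for key, value in check_environment().items():
    print(f"{key}: {value}")
```

If `cuda` is reported as `False`, training with `--gpu` is unlikely to work in that environment.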
Command Line
Train a new network on a data set with train.py
- Basic usage:
python train.py data_directory
- Prints out training loss, validation loss, and validation accuracy as the network trains.
- Options:
- Set directory to save checkpoints:
python train.py data_dir --save_dir save_directory
- Choose architecture:
python train.py data_dir --arch "vgg16"
- Set hyperparameters:
python train.py data_dir --learning_rate 0.01 --hidden_units 512 --epochs 20
- Use GPU for training:
python train.py data_dir --gpu
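The options above map naturally onto Python's `argparse`. The sketch below shows one way such a parser could be declared; it is an illustration rather than the repository's actual `train.py`. The function name `build_train_parser` and all default values are assumptions (the defaults simply reuse the example values shown above), and `flowers` is a hypothetical data directory.

```python
import argparse


def build_train_parser():
    """Build a parser for the training options listed above (illustrative only)."""
    parser = argparse.ArgumentParser(description="Train a flower classifier (sketch).")
    parser.add_argument("data_dir", help="directory containing the image data")
    # Defaults below are assumptions, not values confirmed by the repository.
    parser.add_argument("--save_dir", default=".", help="directory to save checkpoints")
    parser.add_argument("--arch", default="vgg16", help="pre-trained architecture to use")
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--hidden_units", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--gpu", action="store_true", help="train on the GPU")
    return parser


# Parse an explicit argument list instead of sys.argv so the example is repeatable.
args = build_train_parser().parse_args(["flowers", "--arch", "vgg16", "--epochs", "5", "--gpu"])
print(args.data_dir, args.arch, args.epochs, args.gpu)
```

Using `action="store_true"` for `--gpu` means the flag takes no value: it is `False` unless present on the command line.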
- Predict flower name from an image with predict.py, along with the probability of that name. That is, you pass in a single image /path/to/image and get back the flower name and class probability.
- Basic usage:
python predict.py /path/to/image checkpoint
- Options:
- Return top K most likely classes:
python predict.py --input checkpoint --top_k 3
- Use a mapping of categories to real names:
python predict.py --input checkpoint --category_names cat_to_name.json
- Use GPU for inference:
python predict.py --input checkpoint --gpu
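To make the `--top_k` and `--category_names` options concrete, the sketch below ranks class probabilities and maps class ids to readable names. It is a plain-Python illustration, not the repository's implementation: the function `top_k_named`, the probabilities, the class ids, and the two name entries are all made up for the example. In practice the probabilities would come from the network and the mapping from loading `cat_to_name.json`.

```python
import json


def top_k_named(probabilities, class_ids, k=3, category_names=None):
    """Return the k most probable (label, probability) pairs, highest first."""
    ranked = sorted(zip(class_ids, probabilities), key=lambda pair: pair[1], reverse=True)[:k]
    names = category_names or {}
    # Fall back to the raw class id when no readable name is available.
    return [(names.get(class_id, class_id), prob) for class_id, prob in ranked]


# Hypothetical model output for four classes.
probs = [0.05, 0.70, 0.20, 0.05]
ids = ["1", "21", "3", "45"]
# Stand-in for the parsed contents of cat_to_name.json; entries are illustrative.
names = json.loads('{"21": "fire lily", "3": "canterbury bells"}')

for label, prob in top_k_named(probs, ids, k=2, category_names=names):
    print(f"{label}: {prob:.2f}")
```

Without a category mapping, the same call returns the class ids themselves, which mirrors running the prediction without `--category_names`.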
You can imagine using something like this in a phone app that tells you the name of the flower your camera is looking at.
The contents of this repository are covered under the MIT License.