2.2.2. PyTorch Image Classifier

The PyTorch Generic Image Classifier is a class for creating user-defined image classifiers. It can be used to train and evaluate a model. Users can apply the pre-trained model directly to make predictions, or fine-tune it with additional training data. The pre-trained model achieves good test accuracy (97.31% for roof type classification) when the test data come from the same distribution as the training data. However, because neural networks generalize poorly to out-of-distribution data (https://arxiv.org/pdf/1903.12261.pdf), users are advised to fine-tune the pre-trained model on their own data.


  1. __init__

    1. modelName: Name of the model. default='rooftype_resnet18_v1'

    2. imgDir: Directory containing the training data

    3. valimgDir: Directory containing the validation data

    4. download: Whether to download the pre-trained model. default=False

    5. random_split: Ratio used to split the data into a training set and a validation set if no validation data is provided

    6. resultFile: Name of the result file written when predicting multiple images. default='preds.csv'

    7. workDir: The working directory. default='tmp'

    8. printRes: Whether to show the predicted class and its probability. default=True

    9. printConfusionMatrix: Whether to print the confusion matrix. default=False

  2. train

    1. lr1: Learning rate. default=0.01

    2. epochs: Number of training epochs. default=10

    3. batch_size: Number of images per batch. default=64

    4. plot: default=False

  3. fine_tuning

    1. lr1: Learning rate. default=0.001

    2. epochs: Number of training epochs. default=10

    3. batch_size: Number of images per batch. default=32

    4. plot: default=False

  4. predictOneImage

    1. imagePath: Path to a single image

  5. predictMultipleImages

    1. imagePathList: A list of image paths

    2. resultFile: Name of the result file. default=None

  6. predictOneDirectory

    1. directory_name: Directory containing the images to classify

    2. resultFile: Name of the result file. default=None

Description

This class implements an image classifier based on PyTorch. It can first be used to train the classifier; once trained, the classifier predicts the class of each image in a given set. The class can also load a pre-trained model to make predictions, and the loaded model can be fine-tuned with user-provided images to mitigate the out-of-distribution generalization issue of neural networks.

Example

The following is an example, in which a classifier is created and trained.

The image dataset for this example contains satellite images categorized according to roof type.

The dataset can be downloaded here.

When unzipped, the archive contains a directory named roofType. Set imgDir to this directory. The roofType directory contains the training images, with one subdirectory per class:

│── class_1
│       └── *.png
│── class_2
│       └── *.png
│── ...
└── class_n
        └── *.png

Construct the image classifier
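Before constructing the classifier, it can help to confirm that imgDir really follows the one-folder-per-class layout shown above. The helper below, list_class_folders, is a hypothetical sketch and not part of BRAILS; it assumes, as in the roofType dataset, that every immediate subdirectory is a class and that images are *.png files.

```python
from pathlib import Path


def list_class_folders(img_dir):
    """Count the *.png images in each class folder under img_dir.

    Hypothetical helper. Assumes every immediate subdirectory of img_dir is
    one class and that its images are *.png files. Loose files directly in
    img_dir are ignored.
    """
    root = Path(img_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"imgDir is not a directory: {root}")
    counts = {}
    # Sort folder names so the report order is stable across platforms.
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(1 for _ in class_dir.glob("*.png"))
    return counts
```

For example, list_class_folders('./roofType/') returns a dictionary mapping each class folder name to its number of images. A count of 0, or a class folder that is missing entirely, usually points to a problem with the unzipped data.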

# import the module
from brails.modules import PytorchImageClassifier

# initialize the classifier, give it a name and a directory
roofClassifier = PytorchImageClassifier(modelName='rooftype_resnet18_v1', imgDir='./roofType/')

Fine-tune the model
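The construction above passes imgDir but no valimgDir. According to the random_split parameter, the validation set is therefore carved out of the imgDir data. The sketch below shows one way a ratio-based split can work. The helper name split_by_ratio, the convention that the ratio is the training fraction, and the fixed seed are all assumptions for illustration; this is not the BRAILS implementation.

```python
import random


def split_by_ratio(items, train_ratio=0.8, seed=0):
    """Shuffle items and split them into (train, val) lists.

    Hypothetical helper: train_ratio is taken to be the fraction of items kept
    for training. BRAILS's own random_split argument may use a different
    convention.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError("train_ratio must be strictly between 0 and 1")
    shuffled = list(items)
    # A dedicated Random instance makes the split reproducible without
    # touching the global random state.
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * train_ratio)
    return shuffled[:n_train], shuffled[n_train:]
```

A plain random split like this can leave a rare roof type underrepresented in the validation set. Splitting each class folder separately (a stratified split) keeps the class proportions similar in both sets.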

# Fine-tune the pre-trained model for 5 epochs with a learning rate of 0.001 and a batch size of 64
roofClassifier.fine_tuning(lr=0.001, batch_size=64, epochs=5)

It is recommended to run the above example on a GPU machine.

Please refer to https://github.com/rwightman/pytorch-image-models for the supported models. You may first need to install timm via pip: pip install timm.

Classify Images Based on the Model

Now you can use the trained model to predict the roof type class of a given image.
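predictMultipleImages takes a list of image paths, and the example that follows writes that list by hand. For many images, the list can be gathered from disk instead. The helper collect_image_paths below is hypothetical, not part of BRAILS, and assumes *.png images as in the roofType dataset. For the images of a single directory, the documented predictOneDirectory method may already be sufficient.

```python
from pathlib import Path


def collect_image_paths(directory, pattern="*.png", recursive=True):
    """Return a sorted list of image file paths (as strings) under directory.

    Hypothetical helper: with recursive=True, images in nested class folders
    such as roofType/<class>/ are included as well.
    """
    root = Path(directory)
    matches = root.rglob(pattern) if recursive else root.glob(pattern)
    # Keep only regular files and return plain strings, matching the
    # list-of-paths form passed to predictMultipleImages in the example.
    return sorted(str(p) for p in matches if p.is_file())
```

The result, for example collect_image_paths('./roofType/flat'), can then be passed to roofClassifier.predictMultipleImages in place of a hand-written list.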

# If you run inference in a separate session, initialize the classifier first:

from brails.modules import PytorchImageClassifier

roofClassifier = PytorchImageClassifier(modelName='rooftype_resnet18_v1')

# define the paths of images in a list
imgs = ["./roofType/flat/TopViewx-76.84779286x38.81642318.png"]  # add further image paths as needed

# use the model to predict
predictions_dataframe = roofClassifier.predictMultipleImages(imgs)

The predictions are written to preds.csv in the working directory.
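When the classified images come from the labeled roofType folders, each prediction can be checked against the folder the image sits in. This gives a quick accuracy estimate and confusion counts in the spirit of the printConfusionMatrix option. The sketch below is hypothetical. It takes (image path, predicted label) pairs and deliberately assumes no column names for preds.csv or the returned dataframe, since those are not documented here.

```python
from collections import Counter
from pathlib import Path


def evaluate_against_folders(predictions):
    """Compare predicted labels with the class folder of each image.

    Hypothetical helper. predictions is an iterable of
    (image_path, predicted_label) pairs; the true label is taken to be the
    name of the folder containing the image, e.g. "flat" for
    ./roofType/flat/x.png.
    Returns (accuracy, confusion), where confusion maps
    (true_label, predicted_label) -> count.
    """
    confusion = Counter()
    for path, predicted in predictions:
        true_label = Path(path).parent.name
        confusion[(true_label, predicted)] += 1
    total = sum(confusion.values())
    if total == 0:
        raise ValueError("no predictions given")
    correct = sum(n for (t, p), n in confusion.items() if t == p)
    return correct / total, dict(confusion)
```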


The generic image classifier is intended to illustrate the overall process of model training and prediction. The classifier takes an image as input and always produces a prediction. Because the classifier is trained to classify only a specific category of images, its prediction is meaningful only if the input image belongs to one of the categories the model was trained on.
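A typical classifier always produces a prediction because its final step selects the highest-scoring class among those it was trained on; it has no built-in "none of these" option. The sketch below illustrates this with a softmax over raw scores. It also adds a confidence threshold as a post-processing step. The threshold is a hypothetical addition and not a BRAILS feature. It is also only a rough safeguard, since networks can assign high confidence to out-of-distribution images.

```python
import math


def softmax(scores):
    """Convert raw class scores (logits) into probabilities that sum to 1."""
    # Subtracting the maximum keeps math.exp from overflowing.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_with_threshold(scores, class_names, min_confidence=0.5):
    """Return (label, confidence) for the highest-probability class.

    Plain argmax always returns one of class_names. The hypothetical
    min_confidence threshold returns None as the label instead when the
    top probability is too low.
    """
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    confidence = probs[best]
    label = class_names[best] if confidence >= min_confidence else None
    return label, confidence
```

The value 0.5 is arbitrary. A usable threshold would have to be chosen on held-out data for the task at hand.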