brails.processors.image_classifier.image_classifier module

Class object to train and use image classification models.

class brails.processors.image_classifier.image_classifier.ImageClassifier

Bases: object

A class to manage and train an image classification model.

This class provides methods to train models using transfer learning, retrain existing models, and make predictions with these models. It supports several model architectures and runs on a GPU if one is available.

Attributes:
  • device (torch.device): The device (GPU or CPU) on which the model will be trained and evaluated.
  • implemented_architectures (List[str]): List of supported model architectures.

predict(images: ImageSet, model_arch: str = 'efficientnetv2_s', model_path: str = 'tmp/models/trained_model.pth', classes: List[str] = ['Ants', 'Bees']) → Dict[str, str]

Predict the class of each image in the provided image set.

Args:
  • images (ImageSet): An object containing the directory path and image data.
  • model_arch (str, optional): The architecture of the model to be used for prediction. Defaults to 'efficientnetv2_s'.
  • model_path (str, optional): Path to the pre-trained model. Defaults to 'tmp/models/trained_model.pth'.
  • classes (List[str], optional): List of class labels. Defaults to ['Ants', 'Bees'].

Raises:
  • NotImplementedError: If the specified model architecture is not found in MODEL_PROPERTIES.
  • NotADirectoryError: If the directory containing the images is not valid.

Returns:
  Dict[str, str]: A dictionary whose keys are image names and whose values are the predicted classes.
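As a sketch of how such a dictionary can be formed, the hypothetical helper scores_to_labels below takes per-image class scores and picks the highest-scoring label for each image (an argmax over the scores). It assumes the scores are listed in the same order as classes; how predict actually obtains and decodes the model outputs is not documented here.

```python
from typing import Dict, List, Sequence


def scores_to_labels(scores: Dict[str, Sequence[float]],
                     classes: List[str]) -> Dict[str, str]:
    """Map each image name to its highest-scoring class label.

    scores: hypothetical per-image model outputs (logits or
    probabilities), one value per class, in the same order as classes.
    """
    predictions: Dict[str, str] = {}
    for image_name, image_scores in scores.items():
        if len(image_scores) != len(classes):
            raise ValueError(f'{image_name}: expected {len(classes)} scores, '
                             f'got {len(image_scores)}')
        # Argmax over the class scores; max() keeps the first index on ties.
        best = max(range(len(classes)), key=lambda i: image_scores[i])
        predictions[image_name] = classes[best]
    return predictions


preds = scores_to_labels({'img1.jpg': [2.1, -0.3], 'img2.jpg': [0.2, 1.7]},
                         classes=['Ants', 'Bees'])
print(preds)  # {'img1.jpg': 'Ants', 'img2.jpg': 'Bees'}
```

The class order matters: swapping the entries of classes without retraining would swap every predicted label.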

retrain(model_arch: str = 'efficientnetv2_s', model_path: str = 'tmp/models/trained_model.pth', train_data_dir: str = 'tmp/hymenoptera_data', batch_size: int = 32, nepochs: int | List[int] = 100, es_tolerance: int = 10, plot_accuracy: bool = True) → None

Retrain an existing model using a training dataset and the specified hyperparameters.

Parameters:
  • model_arch (str): The architecture of the model to be fine-tuned. Default is 'efficientnetv2_s'.
  • model_path (str): Path to the pre-trained model to be fine-tuned. Default is 'tmp/models/trained_model.pth'.
  • train_data_dir (str): Directory containing the training and validation datasets. Default is 'tmp/hymenoptera_data'.
  • model_inp_size (int): Input size for the model, used for resizing images in the dataset. Default is 384.
  • batch_size (int): Batch size for data loading. Default is 32.
  • nepochs (Union[int, List[int]]): Number of epochs for training. Must be an integer when retraining. Default is 100.
  • es_tolerance (int): Early-stopping tolerance; the number of epochs without improvement before training stops. Default is 10.
  • plot_accuracy (bool): Whether to plot the validation accuracy over epochs. Default is True.

Returns:
  None

Raises:
  • ValueError: If nepochs is not provided as an integer during retraining.
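The es_tolerance parameter describes patience-based early stopping. The sketch below replays that rule over a list of per-epoch validation accuracies. The helper early_stopping_epochs is hypothetical, and it assumes that "improvement" means a strictly higher validation accuracy than the best seen so far; whether the module monitors accuracy or loss, or restores the best weights afterwards, is not stated in this documentation.

```python
from typing import List, Tuple


def early_stopping_epochs(val_accuracies: List[float],
                          es_tolerance: int) -> Tuple[int, int]:
    """Simulate patience-based early stopping over per-epoch accuracies.

    Returns (epochs_run, best_epoch), both counted from 1. Training stops
    once es_tolerance epochs in a row fail to beat the best accuracy.
    """
    best_acc = float('-inf')
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best_acc:  # assumption: only a strict increase counts
            best_acc = acc
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= es_tolerance:
                return epoch, best_epoch  # stop early
    return len(val_accuracies), best_epoch


accs = [0.70, 0.78, 0.81, 0.80, 0.81, 0.79, 0.85]
print(early_stopping_epochs(accs, es_tolerance=3))  # (6, 3)
```

Here the 0.85 recorded for epoch 7 is never reached, because epochs 4 to 6 all fail to beat the 0.81 from epoch 3.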

Procedure:
  1. Loads the model from the specified path.
  2. Applies data augmentation and normalization to the training dataset, and normalization only to the validation dataset.
  3. If the default training data directory is used, downloads a sample dataset.
  4. Prepares PyTorch DataLoader objects for the training and validation datasets.
  5. Sends the model to the GPU (if available) and fine-tunes it using stochastic gradient descent (SGD).
  6. Saves the trained model at the specified path.
  7. If plot_accuracy is True, plots the validation accuracy versus training epochs.
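Normalization in step 2 is typically a per-channel standardization, (x - mean) / std. The sketch below applies it to a single RGB pixel. The mean and standard deviation values are an assumption: they are the ImageNet statistics commonly paired with torchvision pretrained weights (in PyTorch this step is usually done with torchvision.transforms.Normalize), and the values this module actually uses are not documented here. The helper normalize_pixel is hypothetical.

```python
from typing import List, Sequence

# Assumption: ImageNet channel statistics commonly used with torchvision
# pretrained weights; the values used by this module are not documented.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Sequence[float],
                    mean: Sequence[float] = IMAGENET_MEAN,
                    std: Sequence[float] = IMAGENET_STD) -> List[float]:
    """Standardize one RGB pixel whose channels are already scaled to [0, 1].

    Each channel c becomes (x_c - mean_c) / std_c.
    """
    if len(rgb) != len(mean) or len(rgb) != len(std):
        raise ValueError('rgb, mean and std must have the same length')
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


print(normalize_pixel([0.485, 0.456, 0.406]))  # [0.0, 0.0, 0.0]
```

A pixel equal to the channel means maps to zero in every channel; the same transform is applied to training and validation images, while augmentation is applied to training images only.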

Example:
    >>> classifier = ImageClassifier()
    >>> classifier.retrain(model_path='model.pth',
    ...                    train_data_dir='my_data', nepochs=50)

train(model_arch: str = 'efficientnetv2_s', train_data_dir: str = 'tmp/hymenoptera_data', batch_size: int = 32, nepochs: int | List[int] = 100, es_tolerance: int = 10, plot_accuracy: bool = True) → None

Train a model using transfer learning.

Parameters:
  • model_arch (str): The architecture of the model to use (e.g., 'efficientnetv2_s'). Default is 'efficientnetv2_s'.
  • train_data_dir (str): The directory where the training and validation data are located. Default is 'tmp/hymenoptera_data'.
  • batch_size (int): The number of samples per batch. Default is 32.
  • nepochs (Union[int, List[int]]): Number of epochs for initial training of the new classifier head and for fine-tuning. If an integer, it is split into two halves, one for initial training and one for fine-tuning. If a list of two integers, the two values are used as the epochs for initial training and fine-tuning, respectively. Default is 100.
  • es_tolerance (int): Number of epochs with no improvement after which training is stopped early. Default is 10.
  • plot_accuracy (bool): Whether to plot the validation accuracy against the number of training epochs. Default is True.

Returns:
  None: This method does not return any value.

Raises:
  • ValueError: If nepochs is not an integer or a list of two integers.
  • NotImplementedError: If model_arch is not one of the implemented architectures.

Example:
    >>> trainer = ImageClassifier()
    >>> trainer.train(
    ...     model_arch='resnet50',
    ...     train_data_dir='path/to/data',
    ...     batch_size=64,
    ...     nepochs=[10, 20],
    ...     es_tolerance=5,
    ...     plot_accuracy=True
    ... )
    New classifier head trained using transfer learning.
    Fine-tuning the model...
    Training complete.
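The nepochs rule for train can be sketched as follows. The helper split_epochs is hypothetical, and the handling of an odd integer (the extra epoch goes to fine-tuning) is an assumption, since the documentation only says the value is split into two halves.

```python
from typing import List, Tuple, Union


def split_epochs(nepochs: Union[int, List[int]]) -> Tuple[int, int]:
    """Return (head_epochs, finetune_epochs) following the nepochs rule.

    An int is split into two halves; a list of two ints is used as-is.
    Assumption: for an odd int, the extra epoch goes to fine-tuning.
    """
    if isinstance(nepochs, int):
        head_epochs = nepochs // 2
        return head_epochs, nepochs - head_epochs
    if (isinstance(nepochs, list) and len(nepochs) == 2
            and all(isinstance(n, int) for n in nepochs)):
        return nepochs[0], nepochs[1]
    # Mirrors the documented ValueError for any other input.
    raise ValueError('nepochs must be an integer or a list of two integers')


print(split_epochs(100))       # (50, 50)
print(split_epochs(7))         # (3, 4)
print(split_epochs([10, 20]))  # (10, 20)
```

With nepochs=[10, 20], as in the example call, the new classifier head would be trained for up to 10 epochs and the full model fine-tuned for up to 20, subject to early stopping.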