brails.processors.image_classifier.image_classifier module
Class object to train and use image classification models.
- class brails.processors.image_classifier.image_classifier.ImageClassifier
Bases: object
A class to manage and train an image classification model.
This class provides methods to train models using transfer learning, retrain existing models, and make predictions with these models. It supports several model architectures and runs on a GPU if one is available.
Attributes__
- device (torch.device): The device (GPU or CPU) on which the model will be
trained and evaluated.
- implemented_architectures (List[str]): List of supported model
architectures.
- predict(images: ImageSet, model_arch: str = 'efficientnetv2_s', model_path: str = 'tmp/models/trained_model.pth', classes: List[str] = ['Ants', 'Bees']) → Dict[str, str]
Predict the class of each image in the provided image set.
- Args__
- images (ImageSet): An object containing the directory path and
image data.
- model_arch (str): The architecture of the model to be used for
prediction.
- model_path (str, optional): Path to the pre-trained model. Defaults
to 'tmp/models/trained_model.pth'.
- classes (List[str], optional): List of class labels. Defaults to
['Ants', 'Bees'].
- Raises__
- NotImplementedError: If the specified model architecture is not
found in MODEL_PROPERTIES.
- NotADirectoryError: If the directory containing the images is not
valid.
- Returns__
- Dict[str, str]: A dictionary where the keys are image names and
the values are the predicted classes.
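Because predict returns a plain dictionary, its output can be post-processed with standard Python tools. The following is a minimal sketch, not part of the library: `summarize_predictions` is a hypothetical helper, and the image names are made-up sample data in the documented `Dict[str, str]` shape.

```python
from collections import Counter
from typing import Dict


def summarize_predictions(predictions: Dict[str, str]) -> Dict[str, int]:
    """Count how many images were assigned to each predicted class."""
    return dict(Counter(predictions.values()))


# Hypothetical predictions: keys are image names, values are predicted classes.
example_predictions = {
    "img_001.jpg": "Ants",
    "img_002.jpg": "Bees",
    "img_003.jpg": "Ants",
}

print(summarize_predictions(example_predictions))  # {'Ants': 2, 'Bees': 1}
```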
- retrain(model_arch: str = 'efficientnetv2_s', model_path: str = 'tmp/models/trained_model.pth', train_data_dir: str = 'tmp/hymenoptera_data', batch_size: int = 32, nepochs: int | List[int] = 100, es_tolerance: int = 10, plot_accuracy: bool = True) → None
Retrain an existing model using a training dataset and the given hyperparameters.
Parameters__
- model_arch (str): The architecture of the model to be retrained.
Default is 'efficientnetv2_s'.
- model_path (str): Path to the pre-trained model to be fine-tuned.
Default is 'tmp/models/trained_model.pth'.
- train_data_dir (str): Directory containing the training and
validation datasets. Default is 'tmp/hymenoptera_data'.
- model_inp_size (int): Input size for the model, used for resizing
images in the dataset. Default is 384.
- batch_size (int): Batch size for data loading. Default is 32.
- nepochs (Union[int, List[int]]): Number of epochs for training.
Must be an integer for retraining. Default is 100.
- es_tolerance (int): Early stopping tolerance; the number of epochs
without improvement before stopping. Default is 10.
- plot_accuracy (bool): Whether to plot the validation accuracy over
epochs. Default is True.
Returns__
- None
Raises__
- ValueError: If nepochs is not provided as an integer during
retraining.
Procedure__
1. Loads the model from the specified path.
2. Applies data augmentation and normalization to the training dataset
and normalization to the validation dataset.
3. If the default training data directory is used, downloads a sample
dataset.
4. Prepares PyTorch DataLoader objects for the training and
validation datasets.
5. Sends the model to the GPU and fine-tunes it using Stochastic
Gradient Descent (SGD) optimization.
After training, the model is saved at the specified path.
6. If plot_accuracy is True, plots the validation accuracy versus
training epochs.
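The procedure above reads both training and validation data from train_data_dir. The sketch below is only an illustration under a stated assumption: that the directory follows the layout of the sample hymenoptera dataset, with `train` and `val` subfolders that each hold one folder per class. The helper `classes_per_split` is hypothetical and is not part of the library.

```python
import tempfile
from pathlib import Path
from typing import Dict, List


def classes_per_split(data_dir: str) -> Dict[str, List[str]]:
    """Map each split folder ('train', 'val') to its sorted class subfolders."""
    root = Path(data_dir)
    result: Dict[str, List[str]] = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise NotADirectoryError(f"Missing split folder: {split_dir}")
        result[split] = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    return result


# Build a throwaway directory with the assumed layout and inspect it.
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "val"):
        for cls in ("ants", "bees"):
            (Path(tmp) / split / cls).mkdir(parents=True)
    layout = classes_per_split(tmp)

print(layout)  # {'train': ['ants', 'bees'], 'val': ['ants', 'bees']}
```

A check like this can catch a missing split folder before a long training run starts.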
Example__
>>> classifier = ImageClassifier()
>>> classifier.retrain(model_path='model.pth',
...                    train_data_dir='my_data', nepochs=50)
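The es_tolerance parameter is described as the number of epochs without improvement before training stops. A minimal sketch of that rule follows; it is not the library's implementation. It assumes that "improvement" means a strictly higher validation accuracy than the best value seen so far, and the function name `epochs_run_with_early_stopping` is hypothetical.

```python
from typing import Sequence


def epochs_run_with_early_stopping(
    val_accuracies: Sequence[float], es_tolerance: int
) -> int:
    """Return how many epochs run before early stopping triggers.

    Assumption: an epoch counts as an improvement only if its validation
    accuracy is strictly higher than the best accuracy seen so far.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, accuracy in enumerate(val_accuracies, start=1):
        if accuracy > best:
            best = accuracy
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= es_tolerance:
                return epoch  # stop after this epoch
    return len(val_accuracies)  # tolerance never exhausted


# Made-up validation accuracies for six epochs.
history = [0.70, 0.75, 0.74, 0.75, 0.73, 0.80]
print(epochs_run_with_early_stopping(history, es_tolerance=3))   # 5
print(epochs_run_with_early_stopping(history, es_tolerance=10))  # 6
```

With a tolerance of 3, training in this sketch stops at epoch 5 and never reaches the better epoch 6, which shows why a tolerance that is too small can end training prematurely.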
- train(model_arch: str = 'efficientnetv2_s', train_data_dir: str = 'tmp/hymenoptera_data', batch_size: int = 32, nepochs: int | List[int] = 100, es_tolerance: int = 10, plot_accuracy: bool = True) → None
Train a model using transfer learning.
Parameters__
- model_arch (str): The architecture of the model to use
(e.g., 'efficientnetv2_s').
- train_data_dir (str): The directory where the training and validation
data are located.
- batch_size (int): The number of samples per batch.
- nepochs (Union[int, List[int]]): Number of epochs for initial
training and fine-tuning. If an integer, it is split into two halves
for initial training and fine-tuning. If a list of two integers, the
two values are used as the epochs for initial training and
fine-tuning, respectively.
- es_tolerance (int): Number of epochs with no improvement after which
training will be stopped early.
- plot_accuracy (bool): Whether to plot the validation accuracy
against the number of training epochs.
Returns__
- None: This method does not return any value.
Raises__
- ValueError: If nepochs is not an integer or a list of two integers.
- NotImplementedError: If model_arch is not defined.
Example__
>>> trainer = ImageClassifier()
>>> trainer.train(
...     model_arch='resnet50',
...     train_data_dir='path/to/data',
...     batch_size=64,
...     nepochs=[10, 20],
...     es_tolerance=5,
...     plot_accuracy=True
... )
New classifier head trained using transfer learning.
Fine-tuning the model...
Training complete.
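The handling of nepochs described for train can be sketched in a few lines. This is an illustration, not the library's code: `split_epochs` is a hypothetical helper, and the treatment of odd integers (the extra epoch goes to fine-tuning) is an assumption, since the documentation only says the value is split into two halves.

```python
from typing import List, Tuple, Union


def split_epochs(nepochs: Union[int, List[int]]) -> Tuple[int, int]:
    """Return (initial_training_epochs, fine_tuning_epochs)."""
    if isinstance(nepochs, int):
        initial = nepochs // 2
        # Assumption: for odd values the remainder goes to fine-tuning.
        return initial, nepochs - initial
    if (
        isinstance(nepochs, list)
        and len(nepochs) == 2
        and all(isinstance(n, int) for n in nepochs)
    ):
        return nepochs[0], nepochs[1]
    raise ValueError("nepochs must be an integer or a list of two integers")


print(split_epochs(100))       # (50, 50)
print(split_epochs([10, 20]))  # (10, 20)
```

Passing a list gives explicit control over the two phases, for example a short initial phase to train the new classifier head followed by a longer fine-tuning phase.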