brails.processors.vlm_image_classifier.CLIPClassifier module

This module provides a class that calls a CLIP model to classify images.

class brails.processors.vlm_image_classifier.CLIPClassifier.CLIPClassifier(input_dict: dict | None = None)

Bases: object

A classifier that uses a CLIP model to predict attributes from images.

This class loads a CLIP model, preprocesses the input images, and scores each image against the user-supplied text prompts to make predictions. It supports customizable classes and prompts, so the prompt wording can be tuned to improve prediction accuracy.
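The scoring step behind CLIP-style zero-shot classification can be sketched in plain Python, assuming the image and text embeddings have already been produced by CLIP's encoders. Both sets of vectors are normalized, compared by cosine similarity, multiplied by a logit scale (100.0 is the value used in OpenAI's published CLIP examples), and passed through a softmax. The embeddings, class names, and helper names below are hypothetical toy values, not the BRAILS implementation.

```python
import math


def l2_normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit length; CLIP compares normalized embeddings."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def zero_shot_probs(
    image_emb: list[float], prompt_embs: list[list[float]], logit_scale: float = 100.0
) -> list[float]:
    """Return softmax probabilities of one image over a set of text prompts."""
    img = l2_normalize(image_emb)
    logits = []
    for emb in prompt_embs:
        txt = l2_normalize(emb)
        cosine = sum(a * b for a, b in zip(img, txt))
        logits.append(logit_scale * cosine)
    # Numerically stable softmax: subtract the largest logit before exponentiating.
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical 3-D embeddings standing in for real CLIP encoder outputs.
image_embedding = [0.9, 0.1, 0.0]
prompt_embeddings = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
classes = ["class_a", "class_b", "class_c"]

probs = zero_shot_probs(image_embedding, prompt_embeddings)
predicted = classes[probs.index(max(probs))]
```

Because the image vector points mostly along the first axis, the first prompt receives almost all of the probability mass and `predicted` is `"class_a"`.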

Attributes:
model_arch (str):

The architecture of the model to be used. Available model architectures are ‘ViT-B/32’ (default), ‘RN50’, ‘RN101’, ‘RN50x4’, ‘RN50x16’, ‘RN50x64’, ‘ViT-B/16’, ‘ViT-L/14’, and ‘ViT-L/14@336px’.

device (torch.device):

The device (CPU or GPU) used for computations.

batch_size (int):

The number of images processed in a single batch.

template (str):

A template for formatting text prompts.

Args:
input_dict (Optional[dict]):

A dictionary containing the model architecture and other configuration parameters.

predict(images: ImageSet, classes: list[str], text_prompts: list[str]) → dict[str, str]

Predicts classes for the given images using the CLIP model.
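Since `batch_size` sets how many images are processed at once, the prediction loop probably iterates over the images in fixed-size chunks. The sketch below shows that chunking on a list of image keys. The helper `batched` is hypothetical. Python 3.12 and later also provide `itertools.batched`, which yields tuples instead of lists.

```python
from typing import Iterator


def batched(keys: list[str], batch_size: int) -> Iterator[list[str]]:
    """Yield consecutive chunks of at most batch_size image keys."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(keys), batch_size):
        yield keys[start:start + batch_size]


image_keys = ["img1", "img2", "img3", "img4", "img5"]
batches = list(batched(image_keys, batch_size=2))
# batches == [["img1", "img2"], ["img3", "img4"], ["img5"]]
```

The last batch may be smaller than `batch_size` when the image count is not an exact multiple.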

Args:
images (ImageSet):

An object containing the images to classify.

classes (list[str]):

A list of class names.

text_prompts (list[str]):

A list of text prompts corresponding to the classes.
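The `template` attribute formats text prompts before they are encoded. The sketch below shows one way to apply a template with a `{}` placeholder to each prompt. The template string and the roof-type prompts are hypothetical examples and are not necessarily the library's defaults.

```python
def format_prompts(template: str, prompts: list[str]) -> list[str]:
    """Insert each prompt into the template's {} placeholder."""
    return [template.format(p) for p in prompts]


template = "a photo of a {}."  # hypothetical template, not necessarily the library default
prompts = format_prompts(template, ["house with a gable roof", "house with a flat roof"])
# prompts == ["a photo of a house with a gable roof.", "a photo of a house with a flat roof."]
```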

Returns:
dict[str, str]:

A dictionary mapping image keys to their predicted classes.
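The returned mapping can be built from per-image scores by choosing the highest-scoring class for each image key. The sketch below assumes one score per class for every image, for example the softmax probabilities from the zero-shot step. The helper name and the score values are hypothetical.

```python
def predictions_to_dict(
    scores: dict[str, list[float]], classes: list[str]
) -> dict[str, str]:
    """Map each image key to the class with the highest score (ties go to the first)."""
    result: dict[str, str] = {}
    for key, row in scores.items():
        if len(row) != len(classes):
            raise ValueError(f"Score row for {key!r} does not match the number of classes")
        best = max(range(len(row)), key=row.__getitem__)
        result[key] = classes[best]
    return result


scores = {"img1": [0.7, 0.2, 0.1], "img2": [0.1, 0.3, 0.6]}  # hypothetical scores
predictions = predictions_to_dict(scores, ["gable", "hip", "flat"])
# predictions == {"img1": "gable", "img2": "flat"}
```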

Raises:
TypeError or ValueError:

If the conditions on classes and text_prompts are not met.
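This documentation does not state the exact conditions on `classes` and `text_prompts`. The sketch below shows one plausible set of checks, all of which are assumptions: both arguments must be lists of strings (otherwise `TypeError`), neither may be empty, and there must be exactly one prompt per class (otherwise `ValueError`). The real method may apply different rules, such as allowing several prompts per class.

```python
def validate_inputs(classes: object, text_prompts: object) -> None:
    """Raise TypeError/ValueError for malformed inputs (assumed rules, not the library's)."""
    for name, value in (("classes", classes), ("text_prompts", text_prompts)):
        # Assumption: each argument must be a list whose items are all strings.
        if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
            raise TypeError(f"{name} must be a list of strings")
        # Assumption: empty lists are rejected.
        if not value:
            raise ValueError(f"{name} must not be empty")
    # Assumption: exactly one prompt per class.
    if len(classes) != len(text_prompts):
        raise ValueError("classes and text_prompts must have the same length")
```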