brails.processors.vlm_image_classifier.CLIPClassifier module

Class for calling the CLIP model to perform image classification.

class brails.processors.vlm_image_classifier.CLIPClassifier.CLIPClassifier(input_dict=None)

Bases: object

A classifier that uses a CLIP model to predict attributes from images.

This class is designed to load a CLIP model, process input images, and make predictions for the provided text prompts. It supports customizable classes and prompts to improve prediction accuracy.

model_arch

The architecture of the model to be used. Available model architectures are ‘ViT-B/32’ (default), ‘RN50’, ‘RN101’, ‘RN50x4’, ‘RN50x16’, ‘RN50x64’, ‘ViT-B/16’, ‘ViT-L/14’, and ‘ViT-L/14@336px’.

Type:

str

device

The device (CPU or GPU) used for computations.

Type:

torch.device

batch_size

The number of images processed in a single batch.

Type:

int

template

A template for formatting text prompts.

Type:

str

Parameters:

input_dict (Optional[dict]) – A dictionary containing model architecture and other configuration parameters.
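
The snippet below sketches a minimal instantiation. The ‘model_arch’ key is an assumption inferred from the attribute name documented above; consult the class source for the exact configuration keys it recognizes.

Example:

    from brails.processors.vlm_image_classifier.CLIPClassifier import CLIPClassifier

    # Default configuration: uses the ViT-B/32 architecture noted above.
    classifier = CLIPClassifier()

    # Assumed: override the architecture via the input dictionary. The
    # 'model_arch' key is inferred from the attribute name and may differ
    # in the actual implementation.
    classifier = CLIPClassifier(input_dict={'model_arch': 'ViT-L/14'})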

predict(images, classes, text_prompts)

Predicts classes for the given images using the CLIP model.

Parameters:
  • images (ImageSet) – An object containing the images to classify.

  • classes (list[str]) – A list of class names.

  • text_prompts (list[str]) – A list of text prompts corresponding to the classes.

Returns:

A dictionary mapping image keys to their predicted classes.

Return type:

dict[str, str]

Raises:

TypeError or ValueError – If classes or text_prompts are not of the expected types or do not satisfy the required consistency conditions (e.g., each class having a corresponding prompt).
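
A minimal usage sketch follows. The class names and prompts are illustrative, and the ImageSet is assumed to come from an upstream BRAILS module (e.g., an image downloader); how it is assembled depends on the surrounding workflow.

Example:

    # Hypothetical roof-type classification; 'images' is assumed to be an
    # ImageSet produced by an upstream BRAILS module.
    classes = ['gable', 'hip', 'flat']
    text_prompts = ['a photo of a gable roof',
                    'a photo of a hip roof',
                    'a photo of a flat roof']

    predictions = classifier.predict(images, classes, text_prompts)

    # predictions maps each image key to its predicted class label:
    for key, label in predictions.items():
        print(f'{key}: {label}')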