brails.processors.vlm_image_classifier.CLIPClassifier module
Class that calls the CLIP model for image classification.
- class brails.processors.vlm_image_classifier.CLIPClassifier.CLIPClassifier(input_dict=None)
Bases:
object
A classifier that utilizes the CLIP model to predict attributes from images.
This class loads a CLIP model, processes input images, and makes predictions based on the supplied text prompts. It supports customizable classes and prompts to improve prediction accuracy.
- model_arch
The architecture of the model to be used. Available model architectures are ‘ViT-B/32’ (default), ‘RN50’, ‘RN101’, ‘RN50x4’, ‘RN50x16’, ‘RN50x64’, ‘ViT-B/16’, ‘ViT-L/14’, and ‘ViT-L/14@336px’.
- Type:
str
- device
The device (CPU or GPU) used for computations.
- Type:
torch.device
- batch_size
The number of images processed in a single batch.
- Type:
int
- template
A template for formatting text prompts.
- Type:
str
- Parameters:
input_dict (Optional[dict]) – A dictionary containing model architecture and other configuration parameters.
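As an illustrative sketch, a configuration dictionary might look like the following. The key names are assumptions inferred from the attributes documented above (model_arch, batch_size, template); the actual keys accepted by the BRAILS implementation may differ.

```python
# Hypothetical configuration for CLIPClassifier. Key names are assumed
# from the documented attributes and may differ in the actual BRAILS code.
clip_config = {
    "model_arch": "ViT-B/16",        # one of the listed CLIP architectures
    "batch_size": 64,                # images processed per inference batch
    "template": "a photo of a {}",   # template for formatting text prompts
}

# Formatting a class name with the template yields a full text prompt:
prompt = clip_config["template"].format("gabled roof")

# classifier = CLIPClassifier(input_dict=clip_config)  # requires brails + torch
```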
- predict(images, classes, text_prompts)
Predicts classes for the given images using the CLIP model.
- Parameters:
images (ImageSet) – An object containing the images to classify.
classes (list[str]) – A list of class names.
text_prompts (list[str]) – A list of text prompts corresponding to the classes.
- Returns:
A dictionary mapping image keys to their predicted classes.
- Return type:
dict[str, str]
- Raises:
TypeError or ValueError – If the conditions on classes and text_prompts are not met.
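For intuition, zero-shot CLIP classification reduces to embedding the image and each text prompt, then choosing the class whose prompt embedding has the highest cosine similarity with the image embedding. The following is a dependency-free sketch of that final scoring step only; the toy 3-D vectors stand in for real CLIP features, and this is not the BRAILS implementation itself.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

def pick_class(image_embedding, text_embeddings, classes):
    """Return the class whose prompt embedding best matches the image."""
    scores = [cosine_similarity(image_embedding, t) for t in text_embeddings]
    return classes[max(range(len(scores)), key=scores.__getitem__)]

# Toy 3-D embeddings standing in for CLIP's high-dimensional features.
image_emb = [0.9, 0.1, 0.0]
text_embs = [
    [1.0, 0.0, 0.0],   # embedding of "a photo of a gabled roof"
    [0.0, 1.0, 0.0],   # embedding of "a photo of a flat roof"
]
classes = ["gabled", "flat"]
print(pick_class(image_emb, text_embs, classes))  # gabled
```

In the real classifier, `predict` applies this scoring per image and returns the winning class for each image key in the input ImageSet.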