brails.processors.vlm_image_classifier.CLIPClassifier module
Class that calls CLIP to classify images.
- class brails.processors.vlm_image_classifier.CLIPClassifier.CLIPClassifier(input_dict: dict | None = None)
Bases: object
A classifier that uses a CLIP model to predict attributes from images.
This class loads a CLIP model, processes input images, and makes predictions based on user-supplied text prompts. It supports customizable classes and prompts to improve prediction accuracy.
- Attributes:
- model_arch (str):
The architecture of the model to be used. Available model architectures are ‘ViT-B/32’ (default), ‘RN50’, ‘RN101’, ‘RN50x4’, ‘RN50x16’, ‘RN50x64’, ‘ViT-B/16’, ‘ViT-L/14’, and ‘ViT-L/14@336px’.
- device (torch.device):
The device (CPU or GPU) used for computations.
- batch_size (int):
The number of images processed in a single batch.
- template (str):
A template for formatting text prompts.
- Args:
- input_dict (Optional[dict]):
A dictionary containing the model architecture and other configuration parameters.
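The attributes and constructor argument above suggest a simple configuration flow: read settings from input_dict, fall back to defaults, check model_arch against the documented architectures, format prompts with the template, and split images into batches of batch_size. The sketch below illustrates that flow without BRAILS or PyTorch. The dictionary keys, the default batch size, the default template string, and the helpers parse_config and batches are all hypothetical; consult the BRAILS source for the keys the class actually reads.

```python
from typing import Optional

# Architectures listed in the documentation of model_arch.
AVAILABLE_ARCHS = [
    "ViT-B/32", "RN50", "RN101", "RN50x4", "RN50x16",
    "RN50x64", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px",
]


def parse_config(input_dict: Optional[dict] = None) -> dict:
    """Resolve settings from input_dict, falling back to defaults.

    Hypothetical: the key names and the defaults for batch_size and
    template are assumptions made for illustration only.
    """
    cfg = dict(input_dict or {})
    model_arch = cfg.get("model_arch", "ViT-B/32")  # documented default
    if model_arch not in AVAILABLE_ARCHS:
        raise ValueError(f"Unsupported model_arch: {model_arch!r}")
    return {
        "model_arch": model_arch,
        "batch_size": int(cfg.get("batch_size", 16)),  # hypothetical default
        "template": cfg.get("template", "a photo of {}"),  # hypothetical default
    }


def batches(items: list, batch_size: int) -> list:
    """Split items into consecutive chunks of at most batch_size."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


config = parse_config({"model_arch": "RN50", "batch_size": 2})
prompts = [config["template"].format(p)
           for p in ["a one-story building", "a two-story building"]]
print(prompts)  # ['a photo of a one-story building', 'a photo of a two-story building']
print(batches(["img1", "img2", "img3"], config["batch_size"]))  # [['img1', 'img2'], ['img3']]
```

Validating model_arch up front surfaces a typo before any model weights are loaded; the real class may instead defer this check to the CLIP loader.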
- predict(images: ImageSet, classes: list[str], text_prompts: list[str]) → dict[str, str]
Predicts classes for the given images using the CLIP model.
- Args:
- images (ImageSet):
An object containing the images to classify.
- classes (list[str]):
A list of class names.
- text_prompts (list[str]):
A list of text prompts corresponding to the classes.
- Returns:
- dict[str, str]:
A dictionary mapping image keys to their predicted classes.
- Raises:
- TypeError or ValueError:
If classes or text_prompts do not meet the method's type or value requirements.
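The prediction step follows the usual CLIP zero-shot pattern: embed each image and each text prompt, score every image against every prompt by cosine similarity, and assign the class whose prompt scores highest. The sketch below shows that logic with toy 2-D vectors standing in for the outputs of CLIP's image and text encoders. It is not the BRAILS implementation. The validation rules in validate_inputs (lists of str, exactly one prompt per class) are assumptions, since the documentation only says that a TypeError or ValueError is raised. The helper names are hypothetical.

```python
import math


def validate_inputs(classes: list, text_prompts: list) -> None:
    # Hypothetical checks: the documentation only states that a TypeError or
    # ValueError is raised when conditions on these arguments are not met.
    if not isinstance(classes, list) or not all(isinstance(c, str) for c in classes):
        raise TypeError("classes must be a list of str")
    if not isinstance(text_prompts, list) or not all(isinstance(p, str) for p in text_prompts):
        raise TypeError("text_prompts must be a list of str")
    # Assumption for this sketch: exactly one prompt per class.
    if not classes or len(text_prompts) != len(classes):
        raise ValueError("expected one text prompt per class")


def cosine(a: list, b: list) -> float:
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def zero_shot_predict(image_embeddings: dict, classes: list,
                      prompt_embeddings: list) -> dict:
    """Map each image key to the class whose prompt embedding is most similar."""
    predictions = {}
    for key, image_vec in image_embeddings.items():
        scores = [cosine(image_vec, p) for p in prompt_embeddings]
        best = max(range(len(scores)), key=scores.__getitem__)
        predictions[key] = classes[best]
    return predictions


# Toy 2-D vectors stand in for CLIP image and text embeddings.
classes = ["one-story", "two-story"]
text_prompts = ["a photo of a one-story building", "a photo of a two-story building"]
validate_inputs(classes, text_prompts)
prompt_embeddings = [[1.0, 0.0], [0.0, 1.0]]
image_embeddings = {"img1.jpg": [0.9, 0.1], "img2.jpg": [0.2, 0.8]}
preds = zero_shot_predict(image_embeddings, classes, prompt_embeddings)
print(preds)  # {'img1.jpg': 'one-story', 'img2.jpg': 'two-story'}
```

CLIP pipelines usually multiply these similarities by a learned positive logit scale and apply a softmax to obtain probabilities. Both operations preserve ordering, so the argmax, and therefore the predicted class, is unchanged.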