brails.processors.vlm_image_classifier.CLIPClassifier module
Class object to call CLIP for image classification.
- class brails.processors.vlm_image_classifier.CLIPClassifier.CLIPClassifier(task: str, input_dict: dict | None = None)
Bases:
object
A classifier that utilizes the CLIP model to predict attributes from images.
This class is designed to load a CLIP model, process input images, and make predictions for the entered textual prompts. It supports customizable classes and prompts to enhance prediction accuracy.
- Attributes__
- model_arch (str): The architecture of the model to be used. Available
model architectures are 'ViT-B/32' (default), 'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/16', 'ViT-L/14', and 'ViT-L/14@336px'.
- device (torch.device): The device (CPU or GPU) used for computations.
- preds (Optional[Dict[str, str]]): A dictionary to hold predictions.
- batch_size (int): The number of images processed in a single batch.
- template (str): A template for formatting text prompts.
- Args__
- task (str): The task for which the classifier is being used.
- input_dict (Optional[dict]): A dictionary containing model architecture
and other configuration parameters.
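Putting the constructor arguments together, a minimal configuration sketch is shown below. The `input_dict` key names and the example task string are assumptions based on this page, not verified against the library; the actual instantiation (commented out) additionally requires brails and torch to be installed.

```python
# Architectures listed in the Attributes section above.
SUPPORTED_ARCHS = {
    "ViT-B/32", "RN50", "RN101", "RN50x4", "RN50x16",
    "RN50x64", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px",
}


def make_input_dict(model_arch: str = "ViT-B/32", batch_size: int = 32) -> dict:
    """Hypothetical helper: build a CLIPClassifier input_dict.

    Key names are assumed for illustration; consult the brails source
    for the exact configuration schema.
    """
    if model_arch not in SUPPORTED_ARCHS:
        raise ValueError(f"unknown model_arch: {model_arch!r}")
    return {"model_arch": model_arch, "batch_size": batch_size}


cfg = make_input_dict()
# from brails.processors.vlm_image_classifier.CLIPClassifier import CLIPClassifier
# classifier = CLIPClassifier(task="roofshape", input_dict=cfg)  # task name illustrative
print(cfg["model_arch"])  # ViT-B/32
```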
- predict(images: ImageSet, model_path: str = '', classes: List[str] | None = None, text_prompts: List[str] | None = None) Dict[str, str]
Predicts classes for the given images using the CLIP model.
- Args__
- images (ImageSet): An object containing the images to classify.
- model_path (str): The path to the pre-trained model.
- classes (Optional[List[str]]): A list of class names.
- text_prompts (Optional[List[str]]): A list of text prompts
corresponding to the classes.
- Returns__
- Dict[str, str]: A dictionary mapping image keys to their predicted classes.
- Raises__
- AssertionError: If the conditions on classes and text_prompts are not met.
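The page states that predict raises AssertionError when the conditions on classes and text_prompts are not met, without spelling the conditions out. One plausible reading is that text prompts must be paired one-to-one with classes; the helper below sketches that interpretation only and is not the library's actual check.

```python
from typing import List, Optional


def validate_prompt_inputs(classes: Optional[List[str]],
                           text_prompts: Optional[List[str]]) -> None:
    """Hypothetical pre-check mirroring one reading of the documented
    AssertionError: prompts require a matching list of classes."""
    if text_prompts is not None:
        assert classes is not None, "text_prompts requires classes"
        assert len(classes) == len(text_prompts), \
            "each class needs exactly one text prompt"


# Consistent inputs pass silently (class names are illustrative):
validate_prompt_inputs(["wood", "masonry"],
                       ["a wooden building", "a masonry building"])

# Mismatched lengths trigger the assertion:
try:
    validate_prompt_inputs(["wood"],
                           ["a wooden building", "a masonry building"])
    raised = False
except AssertionError:
    raised = True
print(raised)  # True
```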