brails.processors.vlm_segmenter.segment_anything.modeling.sam module
- class brails.processors.vlm_segmenter.segment_anything.modeling.sam.Sam(image_encoder: ImageEncoderViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder, pixel_mean: List[float] = [123.675, 116.28, 103.53], pixel_std: List[float] = [58.395, 57.12, 57.375])
Bases: Module
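In practice the image encoder, prompt encoder, and mask decoder are rarely assembled by hand; the upstream segment_anything code ships build helpers for this. A minimal construction sketch, assuming the vendored copy re-exports the same sam_model_registry interface and that a matching checkpoint file has been downloaded locally (both are assumptions, not guarantees of this module):

```python
import torch

# Assumption: the vendored package re-exports the upstream build helpers;
# adjust the import path to match the actual installation layout.
from brails.processors.vlm_segmenter.segment_anything import sam_model_registry

checkpoint = "sam_vit_b_01ec64.pth"  # hypothetical local path to ViT-B SAM weights
sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
sam.to("cuda" if torch.cuda.is_available() else "cpu")
sam.eval()
```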
- property device: Any
- forward(batched_input: List[Dict[str, Any]], multimask_output: bool) List[Dict[str, Tensor]]
Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using SamPredictor is recommended over calling the model directly. A usage sketch follows the argument and return descriptions below.
- Arguments:
- batched_input (list(dict)): A list over input images, each a
dictionary with the following keys. A prompt key can be excluded if it is not present.
- ‘image’: The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
- ‘original_size’: (tuple(int, int)) The original size of
the image before transformation, as (H, W).
- ‘point_coords’: (torch.Tensor) Batched point prompts for
this image, with shape BxNx2. Already transformed to the input frame of the model.
- ‘point_labels’: (torch.Tensor) Batched labels for point prompts,
with shape BxN.
- ‘boxes’: (torch.Tensor) Batched box inputs, with shape Bx4.
Already transformed to the input frame of the model.
- ‘mask_inputs’: (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
- multimask_output (bool): Whether the model should predict multiple
disambiguating masks, or return a single mask.
- Returns:
- (list(dict)): A list over input images, where each element is
a dictionary with the following keys.
- ‘masks’: (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
- ‘iou_predictions’: (torch.Tensor) The model’s predictions
of mask quality, in shape BxC.
- ‘low_res_logits’: (torch.Tensor) Low resolution logits with
shape BxCxHxW, where H=W=256. Can be passed as mask input to subsequent iterations of prediction.
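A minimal sketch of calling forward directly with a single point prompt, assuming a Sam instance sam built as above and assuming the vendored utilities provide the upstream ResizeLongestSide transform for mapping the image and prompt coordinates into the model's input frame; the image and point values are placeholders:

```python
import numpy as np
import torch

# Assumption: the vendored utilities mirror upstream segment_anything's
# ResizeLongestSide transform; adjust the import path to your installation.
from brails.processors.vlm_segmenter.segment_anything.utils.transforms import ResizeLongestSide

device = next(sam.parameters()).device       # sam: an already constructed Sam instance
transform = ResizeLongestSide(sam.image_encoder.img_size)

image = np.zeros((768, 1024, 3), dtype=np.uint8)   # placeholder HxWx3 RGB image
input_image = transform.apply_image(image)          # longest side resized to the model input size
input_tensor = torch.as_tensor(input_image, device=device).permute(2, 0, 1).contiguous()

# One foreground point prompt, mapped into the model's input frame.
point_coords = transform.apply_coords(np.array([[500.0, 400.0]]), image.shape[:2])
batched_input = [{
    "image": input_tensor,                   # 3xHxW, already transformed
    "original_size": image.shape[:2],        # (H, W) before resizing
    "point_coords": torch.as_tensor(point_coords, dtype=torch.float, device=device)[None],  # BxNx2
    "point_labels": torch.ones((1, 1), dtype=torch.int, device=device),                     # BxN, 1 = foreground
}]

with torch.no_grad():
    outputs = sam(batched_input, multimask_output=False)

masks = outputs[0]["masks"]  # BxCxHxW binary masks at the original image size
```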
- image_format: str = 'RGB'
- mask_threshold: float = 0.0
- postprocess_masks(masks: Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...]) Tensor
Remove padding and upscale masks to the original image size.
- Arguments:
- masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
- input_size (tuple(int, int)): The size of the image input to the
model, in (H, W) format. Used to remove padding.
- original_size (tuple(int, int)): The original size of the image
before resizing for input to the model, in (H, W) format.
- Returns:
- (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
is given by original_size.
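A short illustration of postprocess_masks on low-resolution decoder logits, reusing the assumed sam instance; all shapes and sizes below are made-up examples (an image resized to 1024x683 from an original 1500x1000):

```python
import torch

# Illustrative values: 2 prompts, 3 candidate masks each, 256x256 low-res logits.
low_res_logits = torch.randn(2, 3, 256, 256)

masks = sam.postprocess_masks(
    low_res_logits,
    input_size=(1024, 683),      # resized (pre-padding) size of the model input
    original_size=(1500, 1000),  # image size before any resizing
)

binary_masks = masks > sam.mask_threshold  # 2x3x1500x1000 boolean masks
```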
- preprocess(x: Tensor) Tensor
Normalize pixel values and pad to a square input.
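As a reference point, preprocess subtracts pixel_mean, divides by pixel_std, and zero-pads the bottom and right edges up to the image encoder's square input size; a sketch with an illustrative resized image:

```python
import torch

# A resized image whose longest side already matches the model input size.
x = torch.randint(0, 256, (3, 1024, 683)).float()
padded = sam.preprocess(x)
print(padded.shape)  # torch.Size([3, 1024, 1024]) for the default 1024-pixel ViT input
```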
- training: bool