brails.processors.vlm_segmenter.segment_anything.modeling.sam module

class brails.processors.vlm_segmenter.segment_anything.modeling.sam.Sam(image_encoder, prompt_encoder, mask_decoder, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375])

Bases: Module
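
Sam is usually not constructed by passing the three sub-modules directly; a builder assembles them with matched dimensions. A minimal sketch, assuming this vendored copy exposes the upstream segment-anything builders (sam_model_registry); the checkpoint path is a placeholder:

from brails.processors.vlm_segmenter.segment_anything import sam_model_registry

# Assumed builder from the upstream segment-anything package; the
# checkpoint path below is a placeholder for a downloaded SAM checkpoint.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.eval()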

property device
forward(batched_input, multimask_output)

Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using SamPredictor is recommended over calling the model directly.

Parameters:
  • batched_input (list(dict)) – A list over input images, each a dictionary with the following keys. A prompt key can be excluded if it is not present.

    'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.

    'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).

    'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already transformed to the input frame of the model.

    'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.

    'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of the model.

    'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.

  • multimask_output (bool) – Whether the model should predict multiple disambiguating masks, or return a single mask.

Returns:

A list over input images, where each element is a dictionary with the following keys.

'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.

'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.

'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed as mask input to subsequent iterations of prediction.

Return type:

(list(dict))
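
A minimal sketch of calling forward directly, assuming sam was built as above, image is an HxWx3 uint8 NumPy array, and this vendored copy ships the upstream ResizeLongestSide transform in its utils.transforms module. The point prompt coordinates are placeholders:

import numpy as np
import torch

from brails.processors.vlm_segmenter.segment_anything.utils.transforms import (
    ResizeLongestSide,
)

# Resize the image and prompts into the model's input frame.
transform = ResizeLongestSide(sam.image_encoder.img_size)
input_image = transform.apply_image(image)                    # HxWx3, resized
image_tensor = torch.as_tensor(input_image).permute(2, 0, 1)  # 3xHxW
point = transform.apply_coords(np.array([[500.0, 375.0]]), image.shape[:2])

batched_input = [{
    'image': image_tensor,
    'original_size': image.shape[:2],
    'point_coords': torch.as_tensor(point, dtype=torch.float)[None],  # 1xNx2
    'point_labels': torch.ones(1, 1),                                 # 1xN, 1 = foreground
}]

with torch.no_grad():
    outputs = sam(batched_input, multimask_output=True)

masks = outputs[0]['masks']  # BxCxHxW binary masks at the original image size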

image_format = 'RGB'
mask_threshold = 0.0
postprocess_masks(masks, input_size, original_size)

Remove padding and upscale masks to the original image size.

Parameters:
  • masks (torch.Tensor) – Batched masks from the mask_decoder, in BxCxHxW format.

  • input_size (tuple(int, int)) – The size of the image input to the model, in (H, W) format. Used to remove padding.

  • original_size (tuple(int, int)) – The original size of the image before resizing for input to the model, in (H, W) format.

Returns:

Batched masks in BxCxHxW format, where (H, W) is given by original_size.

Return type:

(torch.Tensor)
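
A short sketch of the typical follow-up to forward, continuing from the example above: upscale the low-resolution logits and binarize them with mask_threshold. Note that image_tensor's spatial size is the model input size before padding:

low_res = outputs[0]['low_res_logits']        # BxCx256x256
upscaled = sam.postprocess_masks(
    low_res,
    input_size=image_tensor.shape[-2:],       # model input size before padding, (H, W)
    original_size=image.shape[:2],            # original image size, (H, W)
)
binary = upscaled > sam.mask_threshold        # BxCxHxW boolean masks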

preprocess(x)

Normalize pixel values and pad to a square input.
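
For reference, preprocess applies a per-channel affine normalization with pixel_mean and pixel_std, then zero-pads the bottom and right edges up to the image encoder's square input size. A functionally equivalent sketch, mirroring the upstream segment-anything implementation (illustrative, not the exact source):

import torch
import torch.nn.functional as F

def preprocess_sketch(sam, x: torch.Tensor) -> torch.Tensor:
    """Normalize a 3xHxW image tensor and pad it to a square input."""
    x = (x - sam.pixel_mean) / sam.pixel_std      # per-channel normalization
    h, w = x.shape[-2:]
    pad_h = sam.image_encoder.img_size - h        # pad bottom
    pad_w = sam.image_encoder.img_size - w        # pad right
    return F.pad(x, (0, pad_w, 0, pad_h))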