brails.processors.image_segmenter.image_segmenter module

Classes for training and using image segmentation models.

class brails.processors.image_segmenter.image_segmenter.DatasetBinary(root: str, imageFolder: str, maskFolder: str, transforms: Callable | None = None)

Bases: VisionDataset

Create a PyTorch dataset for binary masks.

Args:
    root (str): Root directory of the dataset.
    imageFolder (str): Folder name containing the images.
    maskFolder (str): Folder name containing the binary masks.
    transforms (Optional[Callable], optional): A function/transform to apply to both the images and masks. Defaults to None.

Attributes:
    image_names (List[Path]): List of paths to the image files.
    mask_names (List[Path]): List of paths to the mask files.
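The arguments and attributes above describe a paired layout: images under root/imageFolder, masks under root/maskFolder, exposed as image_names and mask_names (DatasetRGB documents the same arguments and attributes). The stdlib-only sketch below mirrors that layout under stated assumptions. PairedMaskFiles and its loader parameter are hypothetical, pairing files by sorted filename is an assumption (the real classes may match files differently), and a real PyTorch dataset would decode each file into an image or tensor instead of returning raw bytes.

```python
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple


class PairedMaskFiles:
    """Hypothetical sketch: pair files in root/imageFolder with files in root/maskFolder."""

    def __init__(
        self,
        root: str,
        imageFolder: str,
        maskFolder: str,
        transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
        # Assumption: raw bytes stand in for decoded images in this sketch.
        loader: Callable[[Path], Any] = Path.read_bytes,
    ) -> None:
        image_dir = Path(root) / imageFolder
        mask_dir = Path(root) / maskFolder
        for folder in (image_dir, mask_dir):
            if not folder.is_dir():
                raise NotADirectoryError(f"{folder} is not a directory")
        # Assumption: the i-th image in sorted order belongs with the i-th mask.
        self.image_names: List[Path] = sorted(p for p in image_dir.iterdir() if p.is_file())
        self.mask_names: List[Path] = sorted(p for p in mask_dir.iterdir() if p.is_file())
        if len(self.image_names) != len(self.mask_names):
            raise ValueError("the number of images and masks differ")
        self.transforms = transforms
        self.loader = loader

    def __len__(self) -> int:
        return len(self.image_names)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        image = self.loader(self.image_names[index])
        mask = self.loader(self.mask_names[index])
        if self.transforms is not None:
            # A single call receives both, so random flips or crops stay aligned.
            image, mask = self.transforms(image, mask)
        return image, mask
```

Passing one callable that sees the image and its mask together, rather than separate transforms, is what keeps geometric augmentations from shifting the mask relative to the image.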

class brails.processors.image_segmenter.image_segmenter.DatasetRGB(root: str, imageFolder: str, maskFolder: str, transforms: Callable | None = None)

Bases: VisionDataset

Create a PyTorch dataset for RGB masks.

Args:
    root (str): Root directory of the dataset.
    imageFolder (str): Folder name containing the images.
    maskFolder (str): Folder name containing the RGB masks.
    transforms (Optional[Callable], optional): A function/transform to apply to the images and masks. Defaults to None.

Attributes:
    image_names (List[Path]): List of paths to the image files.
    mask_names (List[Path]): List of paths to the mask files.
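Unlike a binary mask, an RGB mask encodes each class as a color, so training code generally has to turn every pixel's color into an integer class index before computing a loss. The stdlib sketch below shows that conversion. The function rgb_mask_to_indices and the example palette are hypothetical and are not the colors DatasetRGB expects; a real pipeline would do the same lookup with array operations.

```python
from typing import Dict, List, Sequence, Tuple

RGB = Tuple[int, int, int]


def rgb_mask_to_indices(mask: Sequence[Sequence[RGB]], palette: Dict[RGB, int]) -> List[List[int]]:
    """Map each (R, G, B) pixel of a mask to the class index given by palette."""
    indices: List[List[int]] = []
    for row in mask:
        out_row: List[int] = []
        for pixel in row:
            color = tuple(pixel)
            if color not in palette:
                # Colors outside the palette usually signal a labeling or resampling error.
                raise ValueError(f"unknown mask color {color}")
            out_row.append(palette[color])
        indices.append(out_row)
    return indices


# Hypothetical palette: black background plus two foreground classes.
palette = {(0, 0, 0): 0, (255, 0, 0): 1, (0, 255, 0): 2}
mask = [[(0, 0, 0), (255, 0, 0)], [(0, 255, 0), (0, 0, 0)]]
print(rgb_mask_to_indices(mask, palette))  # [[0, 1], [2, 0]]
```

Raising on unknown colors, instead of silently mapping them to background, makes interpolation artifacts from resized masks visible early.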

class brails.processors.image_segmenter.image_segmenter.ImageSegmenter(modelArch='deeplabv3_resnet101')

Bases: object

A class to manage and train an image segmentation model.

This class provides methods to download and train models using transfer learning, retrain existing models, and make predictions with these models. It supports different model architectures, selected through modelArch, and runs on a GPU when one is available.

predict(imdir: str, classes: List[str], model_path: str = 'tmp/models/trained_seg_model.pth') → Dict[str, str]

Segment images in the specified directory using a pre-trained model.

Args:
    imdir (str): The directory containing the images to segment.
    classes (List[str]): List of class names corresponding to the model’s output classes.
    model_path (str, optional): Path to the trained model file. Defaults to 'tmp/models/trained_seg_model.pth'.

Returns:
    Dict[str, str]: A dictionary where keys are image filenames and values are predicted class labels.

Raises:
    FileNotFoundError: If the specified model file does not exist.
    NotADirectoryError: If imdir is not a directory.
    ValueError: If the image directory is empty or no valid images are found.
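The Raises entries for predict imply that its inputs are checked before any segmentation runs. The sketch below shows one way to perform those checks in the documented order. collect_prediction_inputs is a hypothetical helper, and the set of accepted image suffixes is an assumption, not the list the method actually uses.

```python
from pathlib import Path
from typing import List, Tuple

# Assumption: file suffixes treated as valid images.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}


def collect_prediction_inputs(imdir: str, model_path: str) -> Tuple[Path, List[Path]]:
    """Hypothetical pre-checks mirroring the documented Raises clauses of predict()."""
    model_file = Path(model_path)
    if not model_file.is_file():
        raise FileNotFoundError(f"model file not found: {model_file}")
    image_dir = Path(imdir)
    if not image_dir.is_dir():
        raise NotADirectoryError(f"not a directory: {image_dir}")
    # Keep regular files whose suffix (case-insensitive) looks like an image.
    images = sorted(
        p for p in image_dir.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    if not images:
        raise ValueError(f"no valid images found in {image_dir}")
    return model_file, images
```

Each returned path would then be segmented, with its filename serving as a key in the returned dictionary.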

train(train_data_dir: str, classes: List[str], batch_size: int = 2, nepochs: int = 100, es_tolerance: int = 10, plot_loss: bool = True) → None

Train a segmentation model using the specified training input.

Args:
    train_data_dir (str): Directory containing the training data. It should contain one subdirectory per class.
    classes (List[str]): List of class names. The classes should correspond to the subdirectories in train_data_dir.
    batch_size (int, optional): Number of samples per batch. Defaults to 2.
    nepochs (int, optional): Maximum number of training epochs. Defaults to 100.
    es_tolerance (int, optional): Number of epochs to wait for improvement before early stopping. Defaults to 10.
    plot_loss (bool, optional): Whether to plot the training loss curve. Defaults to True.
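es_tolerance describes patience-based early stopping: training halts once the monitored loss has gone es_tolerance consecutive epochs without improving on its best value so far. The sketch below replays that rule over a precomputed list of per-epoch losses. early_stopping_epoch is a hypothetical helper; it assumes any strict decrease counts as an improvement, and which loss train actually monitors is not stated here.

```python
from typing import Sequence


def early_stopping_epoch(losses: Sequence[float], es_tolerance: int = 10) -> int:
    """Return the 1-based epoch at which patience-based early stopping would halt.

    If the tolerance is never exhausted, the number of epochs run is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            # New best loss: remember it and reset the patience counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= es_tolerance:
                return epoch
    return len(losses)


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.69, 0.7]
print(early_stopping_epoch(losses, es_tolerance=3))  # 6
```

With es_tolerance=3, epochs 4 to 6 fail to beat the best loss of 0.7 from epoch 3, so training stops at epoch 6 even though epoch 7 would have improved. A larger tolerance trades extra epochs for a lower chance of stopping on a temporary plateau.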

Raises:
    FileNotFoundError: If train_data_dir is not found.
    ValueError: If classes is empty or if the number of epochs is invalid.

Returns:
    None
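The Raises entries for train correspond to checks that can run before any model is built. A minimal sketch under assumptions: check_train_arguments is hypothetical, and an "invalid" number of epochs is taken to mean anything other than a positive integer.

```python
from pathlib import Path
from typing import List


def check_train_arguments(train_data_dir: str, classes: List[str], nepochs: int) -> Path:
    """Hypothetical pre-checks mirroring the documented Raises clauses of train()."""
    data_dir = Path(train_data_dir)
    if not data_dir.is_dir():
        raise FileNotFoundError(f"training data directory not found: {data_dir}")
    if not classes:
        raise ValueError("classes must contain at least one class name")
    # Assumption: "invalid" means not a positive integer; bool is rejected explicitly
    # because it is a subclass of int in Python.
    if isinstance(nepochs, bool) or not isinstance(nepochs, int) or nepochs < 1:
        raise ValueError(f"nepochs must be a positive integer, got {nepochs!r}")
    return data_dir
```

Failing fast on these arguments avoids building the model and loading data only to hit an error at the first epoch.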