brails.processors.image_segmenter.image_segmenter module
Class objects to train and use image segmentation models.
- class brails.processors.image_segmenter.image_segmenter.DatasetBinary(root: str, imageFolder: str, maskFolder: str, transforms: Callable | None = None)
Bases: VisionDataset
Create a PyTorch dataset for binary masks.
- Args:
  - root (str): Root directory of the dataset.
  - imageFolder (str): Name of the folder containing the images.
  - maskFolder (str): Name of the folder containing the binary masks.
  - transforms (Optional[Callable], optional): A function/transform to apply to both the images and masks. Defaults to None.
- Attributes:
  - image_names (List[Path]): List of paths to the image files.
  - mask_names (List[Path]): List of paths to the mask files.
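The attributes above list image and mask paths but do not say how an image is matched to its mask. A common approach, and an assumption here, is to sort both folder listings and pair them by position, so each image and its mask must share a sort order (typically the same filename). The standard-library sketch below illustrates that pairing. `pair_images_and_masks` is a hypothetical helper, not part of BRAILS, and the actual dataset class may pair files differently.

```python
from pathlib import Path
from typing import List, Tuple


def pair_images_and_masks(root: str, image_folder: str,
                          mask_folder: str) -> List[Tuple[Path, Path]]:
    """Hypothetical helper (not part of BRAILS): pair images with masks.

    Assumes images and masks line up when both folders are sorted,
    e.g. because each mask has the same filename as its image.
    """
    image_dir = Path(root) / image_folder
    mask_dir = Path(root) / mask_folder
    for folder in (image_dir, mask_dir):
        if not folder.is_dir():
            raise FileNotFoundError(f"{folder} does not exist")

    # Sorted listings so the i-th image lines up with the i-th mask
    image_names = sorted(p for p in image_dir.iterdir() if p.is_file())
    mask_names = sorted(p for p in mask_dir.iterdir() if p.is_file())
    if len(image_names) != len(mask_names):
        raise ValueError(
            f"{len(image_names)} images but {len(mask_names)} masks")
    return list(zip(image_names, mask_names))
```

Pairing by position does not check that names actually match, so under this assumption the safest layout is to give each mask exactly the same filename as its image.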
- class brails.processors.image_segmenter.image_segmenter.DatasetRGB(root: str, imageFolder: str, maskFolder: str, transforms: Callable | None = None)
Bases: VisionDataset
Create a PyTorch dataset for RGB masks.
- Args:
  - root (str): Root directory of the dataset.
  - imageFolder (str): Name of the folder containing the images.
  - maskFolder (str): Name of the folder containing the RGB masks.
  - transforms (Optional[Callable], optional): A function/transform to apply to both the images and masks. Defaults to None.
- Attributes:
  - image_names (List[Path]): List of paths to the image files.
  - mask_names (List[Path]): List of paths to the mask files.
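This page does not say how the colours in an RGB mask translate into class labels. One common convention, assumed here, is a fixed colour-to-class palette. In the sketch below, the palette and the `rgb_mask_to_indices` helper are hypothetical, and nested lists of `(R, G, B)` tuples stand in for image arrays so the example needs only the standard library.

```python
from typing import Dict, List, Tuple

RGB = Tuple[int, int, int]

# Hypothetical palette: the real colour coding depends on how your masks
# were drawn and is not specified by this documentation.
PALETTE: Dict[RGB, int] = {
    (0, 0, 0): 0,      # class 0 (e.g. background)
    (255, 0, 0): 1,    # class 1
    (0, 255, 0): 2,    # class 2
}


def rgb_mask_to_indices(mask: List[List[RGB]],
                        palette: Dict[RGB, int] = PALETTE) -> List[List[int]]:
    """Hypothetical helper (not part of BRAILS): map RGB pixels to class ids."""
    indices = []
    for row in mask:
        out_row = []
        for pixel in row:
            # Reject colours the palette does not know about
            if pixel not in palette:
                raise ValueError(f"Unexpected mask colour {pixel}")
            out_row.append(palette[pixel])
        indices.append(out_row)
    return indices
```

A strict lookup like this surfaces anti-aliased or mislabelled pixels as errors instead of silently assigning them to a class.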
- class brails.processors.image_segmenter.image_segmenter.ImageSegmenter(modelArch='deeplabv3_resnet101')
Bases: object
A class to manage and train an image segmentation model.
This class provides methods to download pretrained models and train them using transfer learning, retrain existing models, and make predictions with these models. The model architecture is selected with modelArch, and computations run on a GPU when one is available.
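Both `predict` and `train` take a `classes` list whose order matches the model's output classes. Semantic segmentation models of this kind typically produce one score per class for every pixel, and the predicted label of a pixel is the class with the highest score. The sketch below shows that per-pixel argmax under those assumptions, using nested lists indexed as `[class][row][col]` in place of tensors; `scores_to_labels` is a hypothetical helper, not part of BRAILS.

```python
from typing import List, Sequence


def scores_to_labels(scores: Sequence[Sequence[Sequence[float]]],
                     classes: List[str]) -> List[List[str]]:
    """Hypothetical helper (not part of BRAILS): per-pixel argmax.

    scores is indexed as [class][row][col]; classes[i] names channel i.
    """
    if len(scores) != len(classes):
        raise ValueError("one score map per class is required")
    n_rows, n_cols = len(scores[0]), len(scores[0][0])
    labels = []
    for r in range(n_rows):
        row = []
        for c in range(n_cols):
            # argmax over the class axis for pixel (r, c)
            best = max(range(len(classes)), key=lambda k: scores[k][r][c])
            row.append(classes[best])
        labels.append(row)
    return labels
```

For example, with two classes and a 2×2 image, a pixel scoring 0.9 for the first class and 0.1 for the second is labelled with the first class name. This is why the order of `classes` must match the order of the model's outputs.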
- predict(imdir: str, classes: List[str], model_path: str = 'tmp/models/trained_seg_model.pth') → Dict[str, str]
Segment images in the specified directory using a pre-trained model.
- Args:
  - imdir (str): Directory containing the images to segment.
  - classes (List[str]): List of class names corresponding to the model's output classes.
  - model_path (str, optional): Path to the trained model file. Defaults to 'tmp/models/trained_seg_model.pth'.
- Returns:
  - Dict[str, str]: A dictionary where keys are image filenames and values are predicted class labels.
- Raises:
  - FileNotFoundError: If the specified model file does not exist.
  - NotADirectoryError: If the specified image directory is not a directory.
  - ValueError: If the image directory is empty or contains no valid images.
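The Raises entries for `predict` describe three input checks. The sketch below reproduces them with `pathlib` purely to illustrate the documented behavior. `check_predict_inputs` is a hypothetical name, the order of the checks is an assumption, and the set of accepted image extensions is a guess, since the docstring only says "no valid images".

```python
from pathlib import Path
from typing import List

# Assumed set of accepted image extensions (not taken from BRAILS)
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}


def check_predict_inputs(imdir: str, model_path: str) -> List[Path]:
    """Hypothetical helper: apply the documented checks, return image paths."""
    # FileNotFoundError: the model file does not exist
    if not Path(model_path).is_file():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # NotADirectoryError: imdir is not a directory
    image_dir = Path(imdir)
    if not image_dir.is_dir():
        raise NotADirectoryError(f"{imdir} is not a directory")
    # ValueError: no files with a recognised image extension
    images = sorted(p for p in image_dir.iterdir()
                    if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS)
    if not images:
        raise ValueError(f"No valid images found in {imdir}")
    return images
```

Running cheap filesystem checks like these before loading the model means a bad path fails immediately rather than after the (potentially slow) model load.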
- train(train_data_dir: str, classes: List[str], batch_size: int = 2, nepochs: int = 100, es_tolerance: int = 10, plot_loss: bool = True) → None
Train a segmentation model using the specified training input.
- Args:
  - train_data_dir (str): Directory containing the training data. It should have subdirectories for each class.
  - classes (List[str]): List of class names. The classes should correspond to the subdirectories in train_data_dir.
  - batch_size (int, optional): Number of samples per batch. Defaults to 2.
  - nepochs (int, optional): Number of epochs for training. Defaults to 100.
  - es_tolerance (int, optional): Number of epochs to wait for improvement before early stopping. Defaults to 10.
  - plot_loss (bool, optional): Whether to plot the training loss curve. Defaults to True.
- Raises:
  - FileNotFoundError: If train_data_dir is not found.
  - ValueError: If classes is empty or the number of epochs is invalid.
- Returns:
  None
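The es_tolerance parameter describes patience-style early stopping, but the docstring does not say which loss is monitored or whether an equal loss counts as an improvement. The sketch below assumes a per-epoch loss and strict improvement over the best value seen so far; `early_stop_epoch` is a hypothetical helper that replays a loss history rather than BRAILS's training loop.

```python
from typing import Optional, Sequence


def early_stop_epoch(val_losses: Sequence[float],
                     es_tolerance: int = 10) -> Optional[int]:
    """Hypothetical helper: 0-based epoch at which training stops, or None.

    Training stops once es_tolerance consecutive epochs pass without the
    loss strictly improving on the best value seen so far.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # New best loss: reset the patience counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= es_tolerance:
                return epoch
    return None  # all epochs ran without triggering early stopping
```

For instance, the loss history `[1.0, 0.8, 0.9, 0.85, 0.95]` with `es_tolerance=2` stops at epoch 3, because epochs 2 and 3 both fail to beat 0.8. With the default tolerance of 10, the same short history never triggers early stopping.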