brails.processors.image_segmenter.image_segmenter module
Class objects to train and use image segmentation models.
- class brails.processors.image_segmenter.image_segmenter.DatasetBinary(root, imageFolder, maskFolder, transforms=None)
Bases:
VisionDataset
A PyTorch dataset class for loading binary masks paired with images.
- Parameters:
root (str) – The root directory containing the dataset.
imageFolder (str) – The name of the folder containing the images.
maskFolder (str) – The name of the folder containing the binary masks.
transforms (transforms.Compose or None, optional) – A composition of transforms to apply to both the images and masks. If None, no transform will be applied. Default is None.
- image_names
A sorted list of file paths to the images in the dataset.
- Type:
list[Path]
- mask_names
A sorted list of file paths to the binary masks in the dataset.
- Type:
list[Path]
- __len__()
Returns the number of samples in the dataset.
- __getitem__(index)
Retrieves the sample (image and corresponding binary mask) at the specified index (int).
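The pairing of image_names and mask_names can be illustrated with a minimal sketch. It assumes that images and masks are matched purely by sorted order, which the two sorted attributes suggest but the documentation does not state outright. The helper list_pairs and the folder names "images" and "masks" are hypothetical and not part of BRAILS; the real class may also filter files by extension.

```python
import tempfile
from pathlib import Path


def list_pairs(root, image_folder, mask_folder):
    """Return sorted (image, mask) path pairs from two folders under root."""
    image_names = sorted(p for p in (Path(root) / image_folder).iterdir() if p.is_file())
    mask_names = sorted(p for p in (Path(root) / mask_folder).iterdir() if p.is_file())
    # Assumption: a mismatch in counts means the dataset is malformed.
    if len(image_names) != len(mask_names):
        raise ValueError("Image and mask folders hold different numbers of files.")
    return list(zip(image_names, mask_names))


# Build a tiny throwaway dataset layout to demonstrate the pairing.
with tempfile.TemporaryDirectory() as root:
    for folder in ("images", "masks"):
        (Path(root) / folder).mkdir()
    # Files are created out of order on purpose; sorting restores alignment.
    for stem in ("b", "a", "c"):
        (Path(root) / "images" / f"{stem}.jpg").touch()
        (Path(root) / "masks" / f"{stem}.png").touch()

    pairs = list_pairs(root, "images", "masks")
    pair_names = [(img.name, mask.name) for img, mask in pairs]
    print(pair_names)
```

Because the match is positional under this assumption, giving each mask the same file stem as its image keeps the two sorted lists aligned.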
- class brails.processors.image_segmenter.image_segmenter.DatasetRGB(root, imageFolder, maskFolder, transforms=None)
Bases:
VisionDataset
A PyTorch dataset class for loading RGB masks paired with images.
- Parameters:
root (str) – The root directory containing the dataset.
imageFolder (str) – The name of the folder containing the images.
maskFolder (str) – The name of the folder containing the masks.
transforms (transforms.Compose or None, optional) – A function or transform to apply to both the images and masks. If None, no transformation will be applied. Default is None.
- image_names
A sorted list of file paths to the images in the dataset.
- Type:
list[Path]
- mask_names
A sorted list of file paths to the masks in the dataset.
- Type:
list[Path]
- class brails.processors.image_segmenter.image_segmenter.ImageSegmenter(model_arch='deeplabv3_resnet101')
Bases:
object
A class to manage and train an image segmentation model.
This class provides methods to train models using transfer learning, retrain existing models, and make predictions with the trained models. It supports different model architectures and runs on a GPU when one is available.
- model_arch
The model architecture (e.g., “deeplabv3_resnet101”).
- Type:
str
- device
The device for computation, either “cuda:0” for GPU or “cpu”.
- Type:
torch.device
- batch_size
The batch size for training, initialized to None.
- Type:
int, optional
- nepochs
The number of epochs for training, initialized to None.
- Type:
int, optional
- train_data_dir
The directory containing training data, initialized to None.
- Type:
str, optional
- classes
List of class names for segmentation, initialized to None.
- Type:
list[str], optional
- loss_history
History of loss values during training, initialized to None.
- Type:
list[float], optional
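A typical workflow might look like the sketch below, which relies only on the constructor, train, and predict signatures documented on this page. The wrapper function run_segmentation_workflow is hypothetical, and the sketch assumes that train saves its weights to the location predict reads by default, which this page does not confirm.

```python
def run_segmentation_workflow(train_data_dir, image_dir, classes):
    """Train a segmentation model, then segment the images in image_dir."""
    # Imported lazily so this sketch can be loaded without BRAILS installed.
    from brails.processors.image_segmenter.image_segmenter import ImageSegmenter

    segmenter = ImageSegmenter(model_arch="deeplabv3_resnet101")
    segmenter.train(
        train_data_dir,
        classes,
        batch_size=2,
        nepochs=100,
        es_tolerance=10,
        plot_loss=False,
    )
    # Assumption: train() writes its weights to predict()'s default model_path.
    return segmenter.predict(image_dir, classes)
```

If the trained weights are stored elsewhere, pass that location explicitly through the model_path argument of predict.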
- predict(imdir, classes, model_path='tmp/models/trained_seg_model.pth')
Segment images in the specified directory using a pre-trained model.
- Parameters:
imdir (str) – The directory containing the images to be predicted.
classes (list[str]) – List of class names corresponding to the model’s output classes.
model_path (str, optional) – Path to the trained model file. Defaults to ‘tmp/models/trained_seg_model.pth’.
- Returns:
A dictionary where keys are image filenames and values are predicted class labels.
- Return type:
dict[str, str]
- Raises:
FileNotFoundError – If the specified model file does not exist.
NotADirectoryError – If the specified image directory is not a directory.
ValueError – If the image directory is empty or no valid images are found.
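The three documented exceptions can be anticipated with a pre-flight check before calling predict. The sketch below is a hypothetical helper, not part of BRAILS: the name check_predict_inputs, the order of the checks, and the set of accepted extensions are all assumptions.

```python
import tempfile
from pathlib import Path

# Assumed set of image extensions; the library's own filter may differ.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def check_predict_inputs(imdir, model_path="tmp/models/trained_seg_model.pth"):
    """Hypothetical pre-flight check that returns the sorted image paths."""
    if not Path(model_path).is_file():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not Path(imdir).is_dir():
        raise NotADirectoryError(f"Not a directory: {imdir}")
    images = sorted(
        p for p in Path(imdir).iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS
    )
    if not images:
        raise ValueError(f"No valid images found in: {imdir}")
    return images


# Demonstrate the empty-directory case with a placeholder model file.
with tempfile.TemporaryDirectory() as tmp:
    model_file = Path(tmp) / "model.pth"
    model_file.touch()
    empty_dir = Path(tmp) / "empty"
    empty_dir.mkdir()
    try:
        check_predict_inputs(empty_dir, model_file)
        outcome = "ok"
    except ValueError:
        outcome = "ValueError"
    print(outcome)
```

Running such a check first surfaces path problems before any model weights are loaded.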
- train(train_data_dir, classes, batch_size=2, nepochs=100, es_tolerance=10, plot_loss=True)
Train a segmentation model using the specified training input.
- Parameters:
train_data_dir (str) – Directory containing the training data. It should have subdirectories for each class.
classes (list[str]) – List of class names. The classes should correspond to the subdirectories in train_data_dir.
batch_size (int, optional) – Number of samples per batch. Defaults to 2.
nepochs (int, optional) – Number of epochs for training. Defaults to 100.
es_tolerance (int, optional) – Number of epochs to wait for improvement before early stopping. Defaults to 10.
plot_loss (bool, optional) – Whether to plot the training loss curve. Defaults to True.
- Raises:
FileNotFoundError – If the train_data_dir is not found.
ValueError – If classes is empty or if the number of epochs is invalid.
- Returns:
None
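The effect of es_tolerance can be sketched in plain Python. This illustrates patience-based early stopping in general, under the assumption that training halts once the monitored loss has failed to improve for es_tolerance consecutive epochs. It is not the library's own training loop, and the function epochs_run and the loss values are made up for the example.

```python
def epochs_run(losses, es_tolerance=10):
    """Return how many epochs run before early stopping ends training."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            # A strictly lower loss resets the patience counter.
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= es_tolerance:
            return epoch
    # The loss kept improving often enough to use every epoch.
    return len(losses)


# Illustrative loss values: the best loss (0.6) is reached at epoch 3.
losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.63]
stopped_at = epochs_run(losses, es_tolerance=3)
print(stopped_at)
```

With a tolerance of 3, the run ends at epoch 6, three epochs after the best loss. With the default tolerance of 10, all seven epochs would run.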