brails.utils.model_utils module
This module provides a utility class for computer vision models in BRAILS.
- class brails.utils.model_utils.ModelUtils
Bases: object
Utility class for computer vision models in BRAILS.
This class provides static methods that ensure the necessary model files are available locally, downloading them if needed. It is intended for use in applications that rely on pre-trained model weights.
To use ModelUtils, include this import statement in your code:
from brails.utils import ModelUtils
- static get_model_path(model_path='', default_filename='', download_url='', model_description='model', overwrite=False)
Check whether a model file is available locally and download it if necessary.
- Parameters:
model_path (str, optional) – Custom path to the model file. If provided, no download occurs.
default_filename (str, optional) – Filename to use if downloading the model.
download_url (str, optional) – URL from which to download the model if it does not exist locally.
model_description (str, optional) – Human-readable description of the model (default: 'model').
overwrite (bool, optional) – If True, re-download and overwrite the model file even if it exists. Defaults to False.
- Returns:
Absolute path to the model file.
- Return type:
str
- Raises:
ValueError – If model_path is not provided and either default_filename or download_url is missing.
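The parameters above interact as a short decision procedure: a custom model_path short-circuits any download, otherwise both default_filename and download_url are required, and the file is fetched only if it is missing or overwrite is True. The function below is a minimal, hypothetical sketch of that procedure, not BRAILS's implementation. The name resolve_model_path, the injectable download callable, and the model_dir default of tmp/models (inferred from the download examples in this section) are assumptions. The documentation does not say whether a custom model_path is checked for existence, so the sketch does not check it.

```python
import os
import urllib.request
from typing import Callable

# Hypothetical default directory, inferred from the example output
# ("tmp/models/default_model.pth"); not confirmed as BRAILS's setting.
DEFAULT_MODEL_DIR = os.path.join("tmp", "models")


def resolve_model_path(
    model_path: str = "",
    default_filename: str = "",
    download_url: str = "",
    overwrite: bool = False,
    download: Callable[[str, str], object] = urllib.request.urlretrieve,
    model_dir: str = DEFAULT_MODEL_DIR,
) -> str:
    """Illustrative sketch of the documented get_model_path behaviour."""
    # A custom path wins outright: no download is attempted.
    if model_path:
        return os.path.abspath(model_path)

    # Without a custom path, both the filename and the URL are required.
    if not default_filename or not download_url:
        raise ValueError(
            "Provide model_path, or both default_filename and download_url."
        )

    target = os.path.abspath(os.path.join(model_dir, default_filename))

    # Fetch only when the file is missing, or when overwrite=True forces it.
    if overwrite or not os.path.exists(target):
        os.makedirs(os.path.dirname(target), exist_ok=True)
        download(download_url, target)

    return target
```

Injecting the download callable keeps the logic testable offline: a stand-in that simply writes a file is enough to exercise the "download once, then reuse" and overwrite paths.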
Examples
Use a custom model path:
>>> from brails.utils import ModelUtils
>>> path = ModelUtils.get_model_path(
...     model_path='my_models/custom.pth'
... )
Inferences will be performed using the custom model in my_models/custom.pth
Download a default model if not already available:
>>> path = ModelUtils.get_model_path(
...     default_filename='default_model.pth',
...     download_url=(
...         'https://zenodo.org/record/7271554/files/'
...         'trained_model_rooftype.pth'
...     ),
...     model_description='roof classification model'
... )
Downloading default roof classification model to tmp/models/default_model.pth...
100%|██████████| 77.9M/77.9M [04:11<00:00, 325kB/s]
Default roof classification model successfully downloaded.
Force overwrite an existing model file:
>>> path = ModelUtils.get_model_path(
...     default_filename='default_model.pth',
...     download_url=(
...         'https://zenodo.org/record/7271554/files/'
...         'trained_model_rooftype.pth'
...     ),
...     model_description='roof classification model',
...     overwrite=True
... )
Re-downloading default roof classification model to tmp/models/default_model.pth...
100%|██████████| 77.9M/77.9M [04:12<00:00, 324kB/s]
Default roof classification model successfully downloaded.
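The download examples above stream a roughly 78 MB weights file while reporting progress; the output resembles a tqdm-style bar. As a rough, standard-library-only illustration of such a step (not BRAILS's code), the hypothetical download_file below reads the response in chunks, prints a simple byte counter, and writes to a temporary ".part" file (an assumed naming convention) that is moved into place only once the transfer completes, so an interrupted download never leaves a truncated file at the destination path.

```python
import os
import urllib.request


def download_file(url: str, dest: str, chunk_size: int = 1 << 16) -> int:
    """Stream url to dest in chunks and return the number of bytes written."""
    tmp_path = dest + ".part"  # hypothetical temporary-file naming
    written = 0
    try:
        with urllib.request.urlopen(url) as response, open(tmp_path, "wb") as out:
            total = response.headers.get("Content-Length")
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                out.write(chunk)
                written += len(chunk)
                if total:
                    # Minimal progress report, overwriting the same line.
                    print(f"\r{written}/{total} bytes", end="")
            if total:
                print()
    except BaseException:
        # Remove the partial file if anything fails mid-transfer.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    # Move the finished file into place; the destination is only ever
    # the complete download.
    os.replace(tmp_path, dest)
    return written
```

Because urlopen also accepts file:// URLs, the function can be exercised locally without network access.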