brails.processors.FoundationClassifier.csail_segmentation_tool.csail_seg.lib.nn.modules.replicate module

class brails.processors.FoundationClassifier.csail_segmentation_tool.csail_seg.lib.nn.modules.replicate.CallbackContext

Bases: object

class brails.processors.FoundationClassifier.csail_segmentation_tool.csail_seg.lib.nn.modules.replicate.DataParallelWithCallback(module: T, device_ids: Sequence[int | device] | None = None, output_device: int | device | None = None, dim: int = 0)

Bases: DataParallel

Data Parallel with a replication callback.

A replication callback, __data_parallel_replicate__, of each module is invoked after the module has been created by the original replicate function. The callback is invoked with the arguments __data_parallel_replicate__(ctx, copy_id).

Examples:

    > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    # sync_bn.__data_parallel_replicate__ will be invoked.
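What a module does inside the callback is up to the module. The following is a minimal sketch of the module side of the hook, written in plain Python so it runs without PyTorch or GPUs. SharedContext and SyncedLayer are hypothetical stand-ins, not classes from this package, and the loop at the end calls the hook by hand only to imitate what the wrapper would do after replication.

```python
class SharedContext:
    """Hypothetical stand-in for the context object passed to the hook."""


class SyncedLayer:
    """Hypothetical layer; deliberately not a real torch.nn.Module."""

    def __init__(self):
        self.is_master = None
        self.channel = None

    def __data_parallel_replicate__(self, ctx, copy_id):
        # Assumption for this sketch: copy_id 0 identifies the first copy.
        self.is_master = (copy_id == 0)
        if self.is_master:
            # The first copy publishes shared state on the context.
            ctx.channel = []
        # Every copy keeps a reference to the same shared object.
        self.channel = ctx.channel
        self.channel.append(copy_id)


# Imitate replication onto three devices by calling the hook manually.
ctx = SharedContext()
copies = [SyncedLayer() for _ in range(3)]
for copy_id, layer in enumerate(copies):
    layer.__data_parallel_replicate__(ctx, copy_id)
```

Because every copy stores a reference to the same list, anything one copy appends is visible to the others. This is one simple way a shared context could carry information across devices.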

replicate(module, device_ids)
brails.processors.FoundationClassifier.csail_segmentation_tool.csail_seg.lib.nn.modules.replicate.execute_replication_callbacks(modules)

Execute a replication callback, __data_parallel_replicate__, on each module created by the original replication.

The callback will be invoked with arguments __data_parallel_replicate__(ctx, copy_id)

Note that, because all copies of a module are isomorphic, each sub-module is assigned a context that is shared among the copies of that sub-module on different devices. Through this context, the copies can share information.

The callback on the master copy (the first copy) is guaranteed to be called before the callback on any slave copy.
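The context assignment and ordering described above can be sketched as follows. This is an illustration under stated assumptions, not the package's implementation: Context, Node, run_callbacks and make_tree are hypothetical names, and Node.modules() only mimics the fixed traversal order that makes position-based matching of isomorphic copies possible.

```python
class Context:
    """Hypothetical stand-in for CallbackContext: a plain attribute bag."""


class Node:
    """Hypothetical stand-in for a module tree; not a torch.nn.Module."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.calls = []
        self.ctx = None

    def modules(self):
        # Yield this node, then all descendants, in a fixed pre-order.
        yield self
        for child in self.children:
            yield from child.modules()

    def __data_parallel_replicate__(self, ctx, copy_id):
        if copy_id == 0:
            ctx.owner = self.name  # the first copy publishes before others read
        self.calls.append((ctx.owner, copy_id))
        self.ctx = ctx


def run_callbacks(copies):
    # One context per sub-module position, sized from the first copy.
    ctxs = [Context() for _ in copies[0].modules()]
    for copy_id, copy in enumerate(copies):  # copy 0 is visited first
        for position, sub in enumerate(copy.modules()):
            if hasattr(sub, "__data_parallel_replicate__"):
                sub.__data_parallel_replicate__(ctxs[position], copy_id)


def make_tree():
    return Node("root", [Node("left"), Node("right")])


copies = [make_tree() for _ in range(2)]
run_callbacks(copies)
```

In this sketch the second copy reads ctx.owner, which exists only because the first copy ran earlier. Visiting the copies in the opposite order would raise AttributeError, which is why the ordering guarantee matters to callbacks that depend on state set up by the first copy.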

brails.processors.FoundationClassifier.csail_segmentation_tool.csail_seg.lib.nn.modules.replicate.patch_replication_callback(data_parallel)

Monkey-patch an existing DataParallel object to add the replication callback. Useful when you have a customized DataParallel implementation.

Examples:

    > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
    > patch_replication_callback(sync_bn)
    # this is equivalent to
    > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
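The general monkey-patching pattern behind such a function can be illustrated without PyTorch. The sketch below assumes a hypothetical FakeDataParallel class and a hypothetical helper patch_with_callback. It shows only how a replicate method on one existing object might be wrapped so that a callback runs on the freshly created copies, and it makes no claim about how patch_replication_callback is actually written.

```python
import functools


class FakeDataParallel:
    """Hypothetical stand-in for a DataParallel-like wrapper."""

    def replicate(self, module, device_ids):
        # Pretend replication: one labelled copy per device.
        return [f"{module}@{device}" for device in device_ids]


def patch_with_callback(wrapper, callback):
    original = wrapper.replicate  # bound method captured before patching

    @functools.wraps(original)
    def replicate_then_notify(module, device_ids):
        copies = original(module, device_ids)
        callback(copies)  # runs once per replicate call
        return copies

    # The instance attribute shadows the class method for this object only.
    wrapper.replicate = replicate_then_notify


seen = []
dp = FakeDataParallel()
patch_with_callback(dp, seen.append)
result = dp.replicate("net", [0, 1])
```

Patching the instance rather than the class leaves other wrapper objects untouched, which suits the case of a customized implementation that cannot easily be subclassed.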