Source code for dcnum.segm.segm_torch.segm_torch_sto

from dcnum.segm import STOSegmenter
import numpy as np

from .segm_torch_base import TorchSegmenterBase
from .torch_model import load_model
from .torch_preproc import preprocess_images
from .torch_postproc import postprocess_masks
from .torch_setup import torch


class SegmentTorchSTO(TorchSegmenterBase, STOSegmenter):
    """PyTorch segmentation (GPU version)"""

    def log_info(self, logger, gpu_id=None):
        """Log CUDA/GPU details and the segmentation batch size

        Parameters
        ----------
        logger:
            object providing an `info` method (e.g. `logging.Logger`)
        gpu_id: str
            optional argument specifying the GPU to use; if None,
            the device "cuda" is used
        """
        model_file = self.kwargs["model_file"]
        device = torch.device(gpu_id if gpu_id is not None else "cuda")
        logger.info(f"CUDA version: {torch.version.cuda}")
        # Query the selected device explicitly. Without the argument, the
        # name of the current default device would be logged, which may
        # differ from the device described by the following lines.
        logger.info(f"GPU name: {torch.cuda.get_device_name(device)}")
        compute_capability = ".".join(
            str(cc) for cc in torch.cuda.get_device_capability(device))
        logger.info(f"GPU compute capability: {compute_capability}")
        # `mem_get_info` returns (free, total); the total is logged here.
        _, total = torch.cuda.mem_get_info(device)
        logger.info(f"Available GPU memory: {total/1024**3:.1f}GB")
        # Only the metadata is needed; the model itself is discarded.
        _, model_meta = load_model(model_file, device)
        batch_size = model_meta["estimated_batch_size_cuda"]
        logger.info(f"GPU segmentation batch size: {batch_size}")

    @staticmethod
    def is_available():
        """Return True if the torch base segmenter and CUDA are available"""
        available = False
        if TorchSegmenterBase.is_available():
            try:
                available = torch.cuda.is_available()
            except Exception:
                # Best-effort probe: any regular error means "no CUDA".
                # KeyboardInterrupt and SystemExit are not swallowed.
                available = False
        return available

    @staticmethod
    def _segment_in_batches(images, model, model_meta, device):
        """Segment image data in batches

        Parameters
        ----------
        images: 3d ndarray
            array of N event images of shape (N, H, W)
        model:
            callable model as returned by `load_model`
        model_meta: dict
            model metadata as returned by `load_model`; the keys
            "estimated_batch_size_cuda" and "preprocessing" are used
        device:
            torch device on which inference is performed

        Returns
        -------
        mask: 3d boolean ndarray
            mask array with same shape as `images`
        """
        size = len(images)

        # In dcnum <= 0.27.0, we had a fixed batch size of 50 which
        # resulted in a small speed penalty. Here, we use a batch size
        # that is tailored to the GPU memory.
        # Note that a batch size for segmentation larger than the chunk size
        # will result in an effective batch size that is identical to the
        # chunk size. The for-loop below will only have one iteration.
        batch_size = model_meta["estimated_batch_size_cuda"]

        # Preprocess the first image chunk
        batch_next = preprocess_images(images[0:batch_size],
                                       **model_meta["preprocessing"])

        # Create empty array to fill up with segmented batches
        mask = np.empty((size, *batch_next.shape[-2:]), dtype=bool)

        for start_idx in range(0, size, batch_size):
            # Move image tensors to cuda
            batch = torch.tensor(batch_next, device=device)
            # Model inference (no gradients are needed, so we avoid
            # building the autograd graph and save GPU memory)
            with torch.no_grad():
                batch_seg = model(batch)
            # perform thresholding on GPU
            batch_seg_bool = batch_seg >= 0.5
            # For debugging and profiling, uncomment the next line.
            # torch.cuda.synchronize()

            # While we are waiting for the GPU, we can load the
            # next batch into memory (model(batch) runs async).
            im_next = images[start_idx + batch_size:start_idx + 2 * batch_size]
            if im_next.size:
                batch_next = preprocess_images(im_next,
                                               **model_meta["preprocessing"])

            # Remove extra dim [B, C, H, W] --> [B, H, W]
            batch_seg_bool = batch_seg_bool.squeeze(1)
            # Convert cuda-tensor to numpy array and fill masks array
            # (This will lock until the GPU computation is complete).
            mask[start_idx:start_idx + batch_size] \
                = batch_seg_bool.detach().cpu().numpy()

        # Perform postprocessing in cases where the image shapes don't match
        if mask.shape[1:] != images.shape[1:]:
            # This is inefficient, because `postprocess_masks` requires
            # us to convert mask to labels.
            labels = postprocess_masks(
                masks=mask,
                original_image_shape=images.shape[1:])
            return labels > 0
        else:
            return mask

    @staticmethod
    def segment_algorithm(images,
                          gpu_id: str | None = None,
                          *,
                          model_file: str | None = None):
        """Segment event images with a torch model on the GPU

        Parameters
        ----------
        images: 3d ndarray
            array of N event images of shape (N, H, W)
        gpu_id: str
            optional argument specifying the GPU to use
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        mask: 3d boolean ndarray
            mask images of shape (N, H, W)

        Raises
        ------
        ValueError
            if `model_file` is not specified
        """
        if model_file is None:
            raise ValueError("Please specify a .dcnm model file!")

        # Determine device to use
        device = torch.device(gpu_id if gpu_id is not None else "cuda")

        # Load model and metadata
        model, model_meta = load_model(model_file, device)

        # Model inference
        # The returned `mask` array has the shape (len(images), H, W);
        # `_segment_in_batches` already takes care of restoring the
        # original image shape if the model output shape differs.
        mask = SegmentTorchSTO._segment_in_batches(
            images=images,
            model=model,
            model_meta=model_meta,
            device=device,
        )

        return mask