Source code for dcnum.segm.segm_torch.segm_torch_mpo

import numpy as np

from ..segmenter_mpo import MPOSegmenter

from .segm_torch_base import TorchSegmenterBase
from .torch_model import load_model
from .torch_preproc import preprocess_images
from .torch_postproc import postprocess_masks
from .torch_setup import torch


class SegmentTorchMPO(TorchSegmenterBase, MPOSegmenter):
    """PyTorch segmentation (multiprocessing version)"""

    @staticmethod
    def segment_algorithm(image, *, model_file: str | None = None):
        """Segment a single event image with a torch model on the CPU

        Parameters
        ----------
        image: 2d ndarray
            event image
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        mask: 2d boolean ndarray
            mask for the given image, obtained by thresholding the
            model prediction at 0.5; if the prediction does not have
            the shape of `image`, the mask is passed through
            `postprocess_masks` first

        Raises
        ------
        ValueError
            if `model_file` is not specified
        AssertionError
            if the predicted mask and `image` differ in their number
            of dimensions
        """
        if model_file is None:
            raise ValueError("Please specify a .dcnm model file!")

        # Set number of pytorch threads to 1, because dcnum is doing
        # all the multiprocessing. Only touch the settings when they
        # differ from the desired value.
        # https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads
        if torch.get_num_threads() != 1:
            torch.set_num_threads(1)
        if torch.get_num_interop_threads() != 1:
            torch.set_num_interop_threads(1)
        device = torch.device("cpu")

        # Load model and metadata
        model, model_meta = load_model(model_file, device)

        # The preprocessing operates on a batch of images; add a
        # leading batch axis to the single input image.
        image_preproc = preprocess_images(image[np.newaxis, :, :],
                                          **model_meta["preprocessing"])

        image_ten = torch.from_numpy(image_preproc)

        # Move image tensors to device
        image_ten_on_device = image_ten.to(device)

        # Model inference. Gradients are never needed here, so disable
        # autograd to avoid recording a graph for every event image.
        with torch.no_grad():
            pred_tensor = model(image_ten_on_device)

        # Convert the prediction tensor into a numpy mask array. The
        # `pred_tensor` array is still of the shape (1, 1, H, W). The
        # `mask` array is of shape (1, H, W).
        mask = pred_tensor.detach().cpu().numpy()[0] >= 0.5

        # Perform postprocessing in cases where the image shapes don't match
        assert len(mask[0].shape) == len(image.shape), "sanity check"
        if mask[0].shape != image.shape:
            # This is inefficient, because `postprocess_masks` requires
            # us to convert mask to labels.
            labels = postprocess_masks(
                masks=mask,
                original_image_shape=image.shape,
            )
            return labels[0] > 0
        else:
            return mask[0]