Source code for dcnum.segm.segm_torch.torch_postproc

from ...common import LazyLoader
from ..segmenter import Segmenter, STRUCTURING_ELEMENT

import numpy as np


# Module-level handle to ``scipy.ndimage`` used for connected-component
# labeling below. Presumably ``LazyLoader`` defers the actual import until
# first attribute access (NOTE(review): confirm in ``...common``).
ndi = LazyLoader("scipy.ndimage")


def postprocess_masks(masks,
                      original_image_shape: tuple[int, int]) -> np.ndarray:
    """Postprocess mask images from ML segmenters

    The transformation includes:

    - Revert the cropping and padding operations done in
      :func:`.preprocess_images` by cropping and padding with zeros.
    - If the original image shape is larger than the mask image shape
      in at least one dimension, also clear borders in an intermediate
      step (mask postprocessing using :func:`Segmenter.process_labels`).

    Parameters
    ----------
    masks: 3d or 4d ndarray
        Mask data in shape (batch_size, 1, image_height, image_width)
        or (batch_size, image_height, image_width). For 4d input, only
        the first channel (index 0) is used.
    original_image_shape: tuple of (int, int)
        The required output mask shape (height, width) for one event.
        This is required for doing the inverse of what is done in
        :func:`.preprocess_images`.

    Returns
    -------
    labels_proc: np.ndarray
        A ``uint16`` label array with the same dimensions as the original
        image data passed to :func:`.preprocess_images`. The shape of this
        array is (batch_size, original_image_shape[0],
        original_image_shape[1]). If only cropping was necessary, this is
        a view into an intermediate array rather than a fresh copy.

    Notes
    -----
    This method is only called by the overarching logic when the
    preprocessing/model output produces images of different shape.
    It causes an obvious overhead that we want to avoid.
    """
    # If output of model is 4d, remove channel axis
    if masks.ndim == 4:
        masks = masks[:, 0, :, :]

    # Label the mask image; `ndi.label` writes directly into each
    # (contiguous) 2d slice of the preallocated `labels` array.
    labels = np.empty(masks.shape, dtype=np.uint16)
    for ii in range(masks.shape[0]):
        ndi.label(
            input=masks[ii],
            output=labels[ii],
            structure=STRUCTURING_ELEMENT)

    batch_size = labels.shape[0]

    # Revert padding and cropping from preprocessing.
    # A positive difference means the mask is smaller than the original
    # image (we must pad); a negative one means it is larger (we must crop).
    mask_shape_ret = labels.shape[1:]
    # height
    s0diff = original_image_shape[0] - mask_shape_ret[0]
    s0t = abs(s0diff) // 2
    s0b = abs(s0diff) - s0t
    # width
    s1diff = original_image_shape[1] - mask_shape_ret[1]
    s1l = abs(s1diff) // 2
    s1r = abs(s1diff) - s1l

    needs_padding = s0diff > 0 or s1diff > 0

    if needs_padding:
        # The masks that we have must be padded. Before we do that, we have
        # to remove events on the edges, otherwise we will have half-segmented
        # cell events in the output array. This is done on the full mask,
        # i.e. before any cropping below.
        for ii in range(batch_size):
            labels[ii] = Segmenter.process_labels(labels[ii],
                                                  clear_border=True,
                                                  fill_holes=False,
                                                  closing_disk=0)

    # Crop first (basic slicing, no copy). Whenever a crop happens,
    # abs(diff) >= 1 and therefore s0b/s1r >= 1, so `-s0b`/`-s1r` are
    # never `-0` (which would yield an empty slice).
    if s0diff < 0:
        labels = labels[:, s0t:-s0b, :]
    if s1diff < 0:
        labels = labels[:, :, s1l:-s1r]

    # Only then pad with zeros (background), using a single allocation
    # for both axes. Axes that were cropped (or unchanged) get offset 0.
    if needs_padding:
        top = s0t if s0diff > 0 else 0
        left = s1l if s1diff > 0 else 0
        labels_pad = np.zeros((batch_size,
                               original_image_shape[0],
                               original_image_shape[1]),
                              dtype=np.uint16)
        labels_pad[:,
                   top:top + labels.shape[1],
                   left:left + labels.shape[2]] = labels
        labels = labels_pad

    return labels