Source code for dcnum.segm.segm_torch.torch_preproc

import numpy as np


[docs] def preprocess_images(images: np.ndarray, norm_mean: float | None, norm_std: float | None, image_shape: tuple[int, int] | None = None, ): """Transform image data to something torch models expect The transformation includes: - normalization (division by 255, subtraction of mean, division by std) - cropping and padding of the input images to `image_shape`. For padding, the median of each *individual* image is used. - casting the input images to four dimensions (batch_size, 1, height, width) where the second axis is "channels" Parameters ---------- images: Input image array (batch_size, height_in, width_in). If this is a 2D image, it will be reshaped to a 3D image with a batch_size of 1. norm_mean: Mean value used for standard score data normalization, i.e. `normalized = `(images / 255 - norm_mean) / norm_std`; Set to None to disable normalization. norm_std: Standard deviation used for standard score data normalization; Set to None to disable normalization (see above). image_shape Image shape for which the model was created (height, width). If the image shape does not match the input image shape, then the input images are padded/cropped to fit the image shape of the model. Returns ------- image_proc: 3D array with preprocessed image data of shape (batch_size, 1, height, width) """ if len(images.shape) == 2: # Insert indexing axis (batch dimension) images = images[np.newaxis, :, :] batch_size = images.shape[0] # crop and pad the images based on what the model expects image_shape_act: tuple[int, int] = images.shape[1:] if image_shape is None: # model fits perfectly to input data image_shape = image_shape_act # height hdiff = image_shape_act[0] - image_shape[0] ht = abs(hdiff) // 2 hb = abs(hdiff) - ht # width wdiff = image_shape_act[1] - image_shape[1] wl = abs(wdiff) // 2 wr = abs(wdiff) - wl # helper variables wpad = wdiff < 0 wcrp = wdiff > 0 hpad = hdiff < 0 hcrp = hdiff > 0 # The easy part is the cropping if hcrp or wcrp: # define slices for width and height slice_hc = slice(ht, -hb) if hcrp else slice(None, None) slice_wc = slice(wl, -wr) if wcrp else slice(None, None) img_proc = images[:, slice_hc, slice_wc] else: img_proc = images # The hard part is the padding if hpad or wpad: # compute median for each original input image img_med = np.median(images, axis=(1, 2)) # broadcast the median array from 1D to 3D img_med = img_med[:, None, None] # define slices for width and height slice_hp = slice(ht, -hb) if hpad else slice(None, None) slice_wp = slice(wl, -wr) if wpad else slice(None, None) # empty padding image stack with the shape required for the model img_pad = np.empty(shape=(batch_size, image_shape[0], image_shape[1]), dtype=np.float32) # fill in original data img_pad[:, slice_hp, slice_wp] = img_proc # fill in background data for height if hpad: img_pad[:, :ht, :] = img_med img_pad[:, -hb:, :] = img_med # fill in background data for width if wpad: img_pad[:, :, :wl] = img_med img_pad[:, :, -wr:] = img_med # Replace img_norm img_proc = img_pad if norm_mean is None or norm_std is None: # convert to float32 img_norm = img_proc.astype(np.float32) else: # normalize images img_norm = (img_proc.astype(np.float32) / 255 - norm_mean) / norm_std # Add a "channels" axis for the ML models. return img_norm[:, np.newaxis, :, :]