Source code for dcnum.segm.segmenter_mpo

import abc
import multiprocessing as mp
import time
import threading

import numpy as np

from ..common import cpu_count, start_workers_threaded
from ..os_env_st import RequestSingleThreaded, confirm_single_threaded

from .segmenter import Segmenter


# All subprocesses should use 'spawn' to avoid issues with threads
# and 'fork' on POSIX systems.
mp_spawn = mp.get_context('spawn')


[docs] class MPOSegmenter(Segmenter, abc.ABC): hardware_processor = "cpu" def __init__(self, *, num_workers: int | None = None, kwargs_mask: dict | None = None, debug: bool = False, **kwargs): """Segmenter with multiprocessing operation Parameters ---------- num_workers Number of workers (processes) to spawn kwargs_mask: dict Keyword arguments for mask post-processing (see `process_labels`) debug: bool Debugging parameters kwargs: Additional, optional keyword arguments for ``segment_algorithm`` defined in the subclass. """ super(MPOSegmenter, self).__init__(kwargs_mask=kwargs_mask, debug=debug, **kwargs) self.num_workers = num_workers or cpu_count() self.slot_list = None """List of ChunkSlot instances""" self.mp_slot_index = mp_spawn.Value("I", 0) """The slot that is currently being worked on""" self.mp_active = mp_spawn.Event() """Event that defines whether the workers are allowed to do work""" self.mp_num_workers_done = mp_spawn.Value("I", 0) """Number of workers that are done processing the slot""" self.mp_shutdown = mp_spawn.Event() """Shutdown event tells workers to stop when set to != 0""" # workers self._worker_starter = None self._workers = []
[docs] def __enter__(self): return self
[docs] def __exit__(self, exc_type, exc_val, exc_tb): self.join_workers()
[docs] def __getstate__(self): # Copy the object's state from self.__dict__ which contains # all our instance attributes. Always use the dict.copy() # method to avoid modifying the original state. # This is important when using "spawn" to create new processes, # because then the entire object gets pickled and some things # cannot be pickled! state = self.__dict__.copy() # Remove the unpicklable entries. del state["logger"] del state["_workers"] del state["_worker_starter"] return state
[docs] def __setstate__(self, state): # Restore instance attributes self.__dict__.update(state)
[docs] def join_workers(self): """Ask all workers to stop and join them""" if self._worker_starter is not None: self._worker_starter.join() if self._workers: self.mp_shutdown.set() for w in self._workers: w.join() if hasattr(w, "close"): w.close() self._workers.clear()
[docs] def reinitialize_workers(self, slot_list): self.join_workers() self.slot_list = slot_list self.mp_shutdown.clear() if self.debug: worker_cls = MPOSegmenterWorkerThread num_workers = 1 self.logger.debug("Running with one worker in main thread") else: worker_cls = MPOSegmenterWorkerProcess num_workers = self.num_workers self.logger.debug(f"Running with {num_workers} workers") chunk_size = self.slot_list[0].length self.num_workers = min(num_workers, chunk_size) step_size = chunk_size // self.num_workers rest = chunk_size % self.num_workers w_start = 0 for ii in range(self.num_workers): # Give every worker the same-sized workload and add one # from the rest until there is no more. w_stop = w_start + step_size if rest: w_stop += 1 rest -= 1 args = [self, w_start, w_stop] w = worker_cls(*args) self._workers.append(w) w_start = w_stop self._worker_starter = start_workers_threaded( worker_list=self._workers, logger=self.logger, name="SegmenterWorker")
[docs] def segment_batch(self, images: np.ndarray, bg_off: np.ndarray | None = None, ): """Perform batch segmentation of `images` Before segmentation, an optional background offset correction with ``bg_off`` is performed. After segmentation, mask postprocessing is performed according to the segmenter class definition. Parameters ---------- images: 3d np.ndarray of shape (N, Y, X) The time-series image data. First axis is time. bg_off: 1D np.ndarray of length N Optional 1D numpy array with background offset Notes ----- - If the segmentation algorithm only accepts background-corrected images, then `images` must already be background-corrected, except for the optional `bg_off`. """ from ..logic.chunk_slot_data import ChunkSlotData available_features = ["image_bg"] if bg_off is not None: if not self.requires_background_correction: raise ValueError(f"The segmenter {self.__class__.__name__} " f"does not employ background correction, " f"but the `bg_off` keyword argument was " f"passed to `segment_batch`. Please check " f"your analysis pipeline.") if bg_off is not None: available_features.append("bg_off") cs = ChunkSlotData(shape=images.shape, available_features=available_features) if self.requires_background_correction: cs.image_corr[:] = images if bg_off is not None: cs.bg_off[:] = bg_off else: cs.image[:] = images cs.chunk = 0 self.segment_chunk(0, [cs]) return cs.mask[:]
[docs] def segment_chunk(self, chunk: int, # noqa: F821 slot_list: list, ): """Segment the image data of one `ChunkSlot` Parameters ---------- chunk: The data chunk index to perform segmentation on slot_list: List of `ChunkSlotData` instances (e.g. `SlotRegister.slots`) Returns ------- mask: np.array The `chunk_slot.mask` numpy view on the shared boolean mask array. """ self.mp_active.clear() # Find the slot that we are supposed to be working on. for cs in slot_list: if cs.chunk == chunk: break else: raise ValueError(f"Could not find slot for {chunk=}") # Prepare everything for the workers, so they can already start # segmenting when they are created. slot_index = slot_list.index(cs) self.mp_slot_index.value = slot_index self.mp_num_workers_done.value = 0 self.mp_active.set() if self.slot_list is not None: for cs1, cs2 in zip(slot_list, self.slot_list): if cs1 is not cs2: # Something changed. We have to respawn the workers. self.slot_list = slot_list self.reinitialize_workers(slot_list) break else: self.slot_list = slot_list self.reinitialize_workers(slot_list) while self.mp_num_workers_done.value != self.num_workers: time.sleep(.01) self.mp_active.clear() return cs.mask
[docs] def segment_single(self, image, bg_off: float | None = None): """Return the boolean mask image for an input image Before segmentation, an optional background offset correction with ``bg_off`` is performed. After segmentation, mask postprocessing is performed according to the class definition. """ segm_wrap = self.segment_algorithm_wrapper() # optional subtraction of background offset if bg_off is not None: image = image - bg_off # obtain mask mask = segm_wrap(image) return mask
[docs] def close(self): self.join_workers()
[docs] class MPOSegmenterWorker: def __init__(self, segmenter, sl_start: int, sl_stop: int, ): """Worker process for CPU-based segmentation Parameters ---------- segmenter: .segmenter_mpo.MPOSegmenter The segmentation instance sl_start: int Start of slice of input array to process sl_stop: int Stop of slice of input array to process """ # Must call super init, otherwise Thread or Process are not initialized super(MPOSegmenterWorker, self).__init__() self.segmenter = segmenter self.slot_list = segmenter.slot_list """List of ChunkSlot instances""" self.mp_slot_index = segmenter.mp_slot_index """The slot that is currently being worked on""" self.mp_active = segmenter.mp_active """Whether the workers are allowed to do work""" self.mp_num_workers_done = segmenter.mp_num_workers_done """Number of workers that are done processing the slot""" self.mp_shutdown = segmenter.mp_shutdown """Shutdown bit tells workers to stop when set to != 0""" self.sl_start = sl_start self.sl_stop = sl_stop
[docs] def run(self): # confirm single-threadedness (prints to log) confirm_single_threaded() last_chunk = -1 while True: if self.mp_shutdown.is_set(): break if self.mp_active.wait(timeout=1): # Get the current slot cs = self.slot_list[self.mp_slot_index.value] if cs.chunk == last_chunk: # We processed this chunk already time.sleep(.01) continue elif self.sl_start >= cs.length: # We have no data to process pass else: if self.segmenter.requires_background_correction: images = cs.image_corr bg_offs = cs.bg_off else: images = cs.image bg_offs = None # Iterate over the chunks in that slot for idx in range(self.sl_start, min(self.sl_stop, cs.length)): cs.mask[idx] = self.segmenter.segment_single( image=images[idx], bg_off=None if bg_offs is None else bg_offs[idx], ) with self.mp_num_workers_done: self.mp_num_workers_done.value += 1 last_chunk = cs.chunk
[docs] class MPOSegmenterWorkerProcess(MPOSegmenterWorker, mp_spawn.Process): def __init__(self, *args, **kwargs): super(MPOSegmenterWorkerProcess, self).__init__(*args, **kwargs)
[docs] def start(self): # Set all relevant os environment variables such libraries in the # new process only use single-threaded computation. with RequestSingleThreaded(): mp_spawn.Process.start(self)
[docs] class MPOSegmenterWorkerThread(MPOSegmenterWorker, threading.Thread): def __init__(self, *args, **kwargs): super(MPOSegmenterWorkerThread, self).__init__(*args, **kwargs)