from __future__ import annotations
import functools
import logging
import multiprocessing as mp
import time
import traceback
import numpy as np
from ..common import LazyLoader
from ..feat import Gate, QueueEventExtractor
from ..read import HDF5Data
from ..segm.segmenter import STRUCTURING_ELEMENT
from .chunk_slot import ChunkSlot, ChunkSlotData
from .job import DCNumPipelineJob
# NOTE(review): presumably defers the import of `scipy.ndimage` until the
# module is first used (inferred from the `LazyLoader` name) - confirm
# against `..common.LazyLoader`.
ndi = LazyLoader("scipy.ndimage")
# Multiprocessing context that uses the "spawn" start method. All shared
# values, arrays and queues in this module are created from this context.
mp_spawn = mp.get_context("spawn")
[docs]
class count_time:
    """Decorator for counting execution time

    Use as ``@count_time()`` on a method. The decorated method's instance
    must provide a mapping ``timers``. If ``timers`` has an entry for the
    name of the decorated method, the wall-clock duration of every call
    is added to that entry's ``value`` while holding the lock returned by
    its ``get_lock()`` method. Methods without a ``timers`` entry are
    executed without any bookkeeping.

    The duration is recorded even if the decorated method raises, so
    that time spent in failing calls is not lost; the exception is
    propagated unchanged.
    """

    def __call__(self, func):
        """Return `func` wrapped with execution time bookkeeping"""
        # The timer key is the name of the decorated method.
        fn = func.__name__

        @functools.wraps(func)
        def method(inst, *args, **kwargs):
            t0 = time.perf_counter()
            try:
                return func(inst, *args, **kwargs)
            finally:
                elapsed = time.perf_counter() - t0
                # update the time counter for this method
                if fn in inst.timers:
                    with inst.timers[fn].get_lock():
                        inst.timers[fn].value += elapsed

        return method
[docs]
class SlotRegister:
    def __init__(self,
                 job: DCNumPipelineJob,
                 data: HDF5Data,
                 event_queue: mp.Queue | None = None,
                 num_slots: int = 3):
        """A register for `ChunkSlot`s for shared memory access

        The `SlotRegister` manages all `ChunkSlot` instances and
        implements methods to interact with individual `ChunkSlot`s.

        Parameters
        ----------
        job:
            Pipeline job; handed to every `ChunkSlot` and used by
            `task_process_labels` to instantiate the segmenter.
        data:
            Input data; determines the chunk size, the number of
            chunks, and the total number of frames.
        event_queue:
            Queue stored as `self.event_queue`; if None, a new queue
            is created from the "spawn" multiprocessing context.
        num_slots:
            Number of regular slots to create (at least one is always
            created). One additional slot for the remainder chunk is
            always appended.
        """
        self.job = job
        self.data = data
        if event_queue is None:
            event_queue = mp_spawn.Queue()
        self.event_queue = event_queue
        self.chunk_size = data.image_chunk_size
        self.num_chunks = data.image_num_chunks
        self._slots = []

        # Accumulated execution time per task; the keys must match the
        # names of the methods decorated with `count_time`.
        self.timers = {
            "task_load_all": mp_spawn.Value("d", 0.0),
            "task_label_masks": mp_spawn.Value("d", 0.0),
            "task_process_labels": mp_spawn.Value("d", 0.0),
            "task_extract_features": mp_spawn.Value("d", 0.0),
        }

        # Counters are created with recursive locks, which means that the
        # same process may acquire multiple locks on the object, and only
        # after releasing all of them, may the lock be acquired by another
        # process.
        self.counters = {
            "chunks_loaded": mp_spawn.Value("Q", 0),
            "masks_dropped": mp_spawn.Value("Q", 0),
            "write_queue_size": mp_spawn.Value("Q", 0),
        }

        self._state = mp_spawn.Value("u", "w")

        self.num_frames = len(self.data)
        """Total number of frames to process"""

        self.feat_nevents = mp_spawn.RawArray("l", self.num_frames)
        """Number of events per frame

        Shared RawArray of length `len(data)` into which the number of
        events per frame is written.
        """
        # Initialize feat_nevents with -1
        self.feat_nevents[:] = np.full(self.num_frames, -1)

        # Generate all requested slots.
        for _ in range(max(1, num_slots)):
            self._slots.append(ChunkSlot(job=job, data=data))
        # Add a slot for the remainder. The size of the remainder chunks
        # slot is smaller or equal to the others.
        chunk_slot_remainder = ChunkSlot(job=job, data=data, is_remainder=True)
        self._slots.append(chunk_slot_remainder)

    def __getitem__(self, idx):
        """Return the slot at position `idx` (order of creation)"""
        return self.slots[idx]

    def __iter__(self):
        """Iterate over slots, sorted by current chunk number"""
        # Take one snapshot so that sorting and yielding operate on the
        # same list.
        slots = self.slots
        slots_indices = np.argsort([sc.chunk for sc in slots])
        for idx in slots_indices:
            yield slots[idx]

    def __len__(self):
        """Return the total number of slots, including the remainder slot"""
        return len(self.slots)

    @property
    def chunks_loaded(self):
        """A process-safe counter for the number of chunks loaded

        This number increments as `SlotRegister.task_load_all` is called.
        """
        return self.counters["chunks_loaded"].value

    @chunks_loaded.setter
    def chunks_loaded(self, value):
        self.counters["chunks_loaded"].value = value

    @property
    def masks_dropped(self):
        """A process-safe counter for the number of masks dropped

        Segmentation may drop invalid masks/events.
        """
        return self.counters["masks_dropped"].value

    @masks_dropped.setter
    def masks_dropped(self, value):
        self.counters["masks_dropped"].value = value

    @property
    def write_queue_size(self):
        """A process-safe counter for the number of chunks in the writer queue

        A large number indicates a slow writer, which can be the
        result of a slow hard disk or a slow CPU (since compression
        is used). Used for preventing OOM events by stalling data
        processing when the writer is slow.
        """
        return self.counters["write_queue_size"].value

    @property
    def slots(self):
        """A list of all `ChunkSlots`

        A new list is returned on every access; modifying it does not
        modify the register.
        """
        return list(self._slots)

    @property
    def state(self):
        """State of the `SlotRegister`, used for communication with workers

        - "w": initialized (workers work)
        - "p": paused (all workers pause)
        - "q": quit (all workers stop)
        """
        return self._state.value

    @state.setter
    def state(self, value):
        self._state.value = value

    def close(self):
        """Set the register state to "q" (quit)"""
        # Let everyone know we are closing
        self._state.value = "q"

    def find_slot(self,
                  state: str,
                  chunk: int | None = None
                  ) -> ChunkSlot | None:
        """Return the first `ChunkSlot` that has the given state

        The slots are visited in the order of their chunk number, so
        that the slot with the smallest chunk number is always
        returned first.

        Parameters
        ----------
        state:
            State the slot must currently be in
        chunk:
            If not None, the slot must additionally hold this chunk
            number.

        Returns
        -------
        chunk_slot
            The matching `ChunkSlot` or None if no matching slot exists
        """
        for sc in self:
            if sc.state == state and (chunk is None or sc.chunk == chunk):
                return sc
        # fallback to nothing found
        return None

    def get_counter_lock(self, name):
        """Return the lock of the counter `name`

        Raises
        ------
        KeyError
            If no counter with that name exists
        """
        if name not in self.counters:
            raise KeyError(f"No counter lock defined for {name}")
        return self.counters[name].get_lock()

    def get_time(self, method_name):
        """Return accumulative time for the given method"""
        return self.timers[method_name].value

    def reserve_slot_for_task(self,
                              current_state: str,
                              next_state: str,
                              chunk_slot: ChunkSlot | None = None,
                              batch_size: int | None = None,
                              ) -> "StateWarden | None":
        """Return slot with the specified state and lowest chunk index

        Parameters
        ----------
        current_state:
            State required for the task to start
        next_state:
            State that will be set after the task is done
        chunk_slot:
            Optional `ChunkSlot` to operate on; if set to None, search
            for a matching one, and if none can be found, return None
        batch_size:
            Number of frames to reserve for performing the task. Defaults
            to the entire chunk.

        Returns
        -------
        state_warden
            Context manager that enforces setting the next state or
            None if no `ChunkSlot` could be reserved.

            Usage:

                if state_warden is not None:
                    with state_warden as (chunk_slot, batch_range):
                        perform_task(chunk_slot,
                                     start_index=batch_range[0],
                                     stop_index=batch_range[1]
                                     )

            The `batch_range` indices are defined by the `batch_size`
            parameter.

            This context manager will automatically set the slot
            state to `next_state` when the context exits
            without exceptions.
        """
        if chunk_slot is not None:
            # The caller chose the slot; it must check `batch_size` of
            # the returned warden itself.
            return StateWarden(chunk_slot,
                               current_state=current_state,
                               next_state=next_state,
                               batch_size=batch_size)

        for sc in self:
            if sc.state == current_state:
                sw = StateWarden(sc,
                                 current_state=current_state,
                                 next_state=next_state,
                                 batch_size=batch_size)
                # Only the matching slot with the lowest chunk number is
                # tried; if nothing could be reserved there, give up.
                return sw if sw.batch_size else None
        # fallback to nothing found
        return None

    ##############################################################
    # Tasks (ordered according to their sequence of application) #
    ##############################################################

    @count_time()
    def task_load_all(self,
                      logger: logging.Logger | None = None
                      ) -> bool:
        """Load chunk data into memory for as many slots as possible

        The method returns immediately if the "chunks_loaded" counter
        lock is held elsewhere. Exceptions raised while loading are
        logged and not propagated.

        Parameters
        ----------
        logger:
            Logger used for reporting errors; defaults to the logger
            of this module.

        Returns
        -------
        did_something : bool
            Whether data were loaded into memory
        """
        did_something = False
        logger = logger or logging.getLogger(__name__)
        lock = self.get_counter_lock("chunks_loaded")
        if not lock.acquire(block=False):
            # Somebody else is loading data right now.
            return did_something

        # From here on we own the lock and must release it on every
        # path, also when there is nothing left to load. Otherwise the
        # recursive lock would stay acquired by this process and block
        # every other process that accesses the counter.
        try:
            if self.chunks_loaded < self.num_chunks:
                for cs in self:
                    # The number of self.chunks_loaded is identical to the
                    # chunk index we want to load next.
                    # The cs.chunk number will be set to the new chunk. All
                    # ChunkSlots have a cs.chunk number that is smaller than
                    # `self.chunks_loaded`.
                    # We are interested in chunks with the state "i" and
                    # will transform them into chunks in state "s".
                    if cs.state != "i":
                        continue
                    # We have at least two chunk slots: One
                    # or more that handle the majority of the frames,
                    # and one that handles the final (remainder) chunk.
                    slot_fits_chunk = (
                        (self.chunks_loaded < self.num_chunks - 1
                         and not cs.is_remainder)
                        or (self.chunks_loaded == self.num_chunks - 1
                            and cs.is_remainder))
                    if not slot_fits_chunk:
                        continue
                    state_warden = self.reserve_slot_for_task(
                        current_state="i",
                        next_state="s",
                        chunk_slot=cs)
                    if state_warden is not None and state_warden.batch_size:
                        if state_warden.batch_size != cs.length:
                            raise ValueError(
                                f"Batch size must match chunk size "
                                f"({state_warden.batch_range=} vs. "
                                f"{cs.length=}) for {cs=}")
                        with state_warden:
                            cs.load(self.chunks_loaded)
                            self.chunks_loaded += 1
                        did_something = True
        except BaseException:
            # Deliberately broad: errors are reported via the logger
            # instead of being propagated to the caller.
            logger.error(traceback.format_exc())
        finally:
            lock.release()
        return did_something

    @count_time()
    def task_label_masks(self,
                         logger: logging.Logger | None = None
                         ) -> bool:
        """Perform labeling of mask images for a `ChunkSlot`

        This method is process-safe. Multiple processes may call it
        concurrently, working on the same `ChunkSlot`.

        Parameters
        ----------
        logger:
            Currently unused; accepted for a uniform task signature.

        Returns
        -------
        did_something : bool
            Whether masks were converted to labels
        """
        did_something = False
        cs = self.find_slot(state="m")
        # Compute labels from this ChunkSlot, one batch of frames at a time
        batch_size = 100
        if cs is not None and cs.state == "m":
            while True:
                state_warden = self.reserve_slot_for_task(
                    current_state="m",
                    next_state="l",
                    chunk_slot=cs,
                    batch_size=batch_size)
                if state_warden is not None and state_warden.batch_size:
                    with state_warden as (_, batch_range):
                        for idx in range(*batch_range):
                            cs.labels[idx], _ = ndi.label(
                                input=cs.mask[idx],
                                structure=STRUCTURING_ELEMENT)
                        did_something = True
                else:
                    # No more frames could be reserved in this slot.
                    break
        return did_something

    @count_time()
    def task_process_labels(self,
                            logger: logging.Logger | None = None
                            ) -> bool:
        """Perform label processing (e.g. binary closing) for a `ChunkSlot`

        This method is process-safe. Multiple processes may call it
        concurrently, working on the same `ChunkSlot`.

        Parameters
        ----------
        logger:
            Currently unused; accepted for a uniform task signature.

        Returns
        -------
        did_something : bool
            Whether labels were processed
        """
        did_something = False
        cs = self.find_slot(state="l")
        # Process labels of this ChunkSlot, one batch of frames at a time
        batch_size = 100
        if cs is not None and cs.state == "l":
            while True:
                state_warden = self.reserve_slot_for_task(
                    current_state="l",
                    next_state="e",
                    chunk_slot=cs,
                    batch_size=batch_size)
                if state_warden is not None and state_warden.batch_size:
                    with state_warden as (_, batch_range):
                        # TODO: This segmentation juggling looks clunky.
                        km = self.job.kwargs["segmenter_kwargs"]["kwargs_mask"]
                        segm = self.job.get_segmenter_class()(kwargs_mask=km)
                        try:
                            if segm.mask_postprocessing:
                                for idx in range(*batch_range):
                                    cs.labels[idx] = \
                                        segm.process_labels(cs.labels[idx],
                                                            **segm.kwargs_mask)
                        finally:
                            # Close the segmenter even if processing
                            # a frame failed.
                            segm.close()
                        did_something = True
                else:
                    # No more frames could be reserved in this slot.
                    break
        return did_something
[docs]
class StateWarden:
    """Context manager for changing the state of a `ChunkSlot`

    On instantiation, a batch of frames is reserved in the slot via
    `acquire_task_lock`. Entering the context yields the slot and the
    reserved index range. Leaving the context releases the reservation
    and, once the slot reports a progress of 1, sets the slot state
    to `next_state`.
    """

    def __init__(self,
                 chunk_slot: ChunkSlot | ChunkSlotData,
                 current_state: str,
                 next_state: str,
                 batch_size: int | None = None,
                 ):
        # Reserve the batch right away; `batch_range` holds the
        # (start, stop) indices returned by the slot.
        reserved = chunk_slot.acquire_task_lock(
            req_state=current_state,
            batch_size=batch_size)
        self.batch_range = reserved
        # Number of frames actually reserved (may differ from the
        # requested `batch_size`).
        self.batch_size = reserved[1] - reserved[0]
        self.chunk_slot = chunk_slot
        self.current_state = current_state
        self.next_state = next_state

    def __enter__(self):
        """Return `(chunk_slot, batch_range)` if the slot state is unchanged

        Raises
        ------
        ValueError
            If the slot is no longer in `current_state`; the reserved
            batch is handed back (not marked as done) before raising.
        """
        if self.chunk_slot.state == self.current_state:
            return self.chunk_slot, self.batch_range

        # The state changed in the meantime. Give the batch back,
        # because somebody else might need it.
        self.chunk_slot.release_task_lock(*self.batch_range,
                                          task_done=False)
        raise ValueError(
            f"Current state of slot {self.chunk_slot} "
            f"({self.chunk_slot.state}) does not match "
            f"expected state {self.current_state}.")

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the batch and advance the slot state when complete"""
        # The batch only counts as done if the body did not raise.
        succeeded = exc_type is None
        self.chunk_slot.release_task_lock(*self.batch_range,
                                          task_done=succeeded)
        # Switch the state only after the slot reports full progress.
        if self.chunk_slot.get_progress() == 1:
            self.chunk_slot.state = self.next_state

    def __repr__(self):
        return (f"<StateWarden {self.current_state}->{self.next_state} "
                f"at {hex(id(self))}>")