import functools
import pathlib
import re
import numpy as np
from ...meta import paths
from ..segmenter import Segmenter, SegmenterNotApplicableError
from .torch_model import load_model
from .torch_setup import torch
[docs]
class TorchSegmenterBase(Segmenter):
"""Torch segmenters that use a pretrained model for segmentation"""
requires_background_correction = False
mask_postprocessing = True
mask_default_kwargs = {
"clear_border": True,
"fill_holes": True,
"closing_disk": 0,
}
[docs]
@classmethod
def get_ppid_from_ppkw(cls, kwargs, kwargs_mask=None):
kwargs_new = kwargs.copy()
# Make sure that the `model_file` kwarg is actually just a filename
# so that the pipeline identifier only contains the name, but not
# the full path.
if "model_file" in kwargs:
model_file = kwargs["model_file"]
mpath = pathlib.Path(model_file)
if mpath.exists():
# register the location of the file in the search path
# registry so other threads/processes will find it.
paths.register_search_path("torch_model_files", mpath.parent)
kwargs_new["model_file"] = mpath.name
return super(TorchSegmenterBase, cls).get_ppid_from_ppkw(kwargs_new,
kwargs_mask)
[docs]
@classmethod
def validate_applicability(cls,
segmenter_kwargs: dict,
meta: dict | None = None,
logs: dict | None = None):
"""Validate the applicability of this segmenter for a dataset
The applicability is defined by the metadata in the segmentation
model.
Parameters
----------
segmenter_kwargs: dict
Keyword arguments for the segmenter
meta: dict
Dictionary of metadata from an :class:`.hdf5_data.HDF5Data`
instance
logs: dict
Dictionary of logs from an :class:`.hdf5_data.HDF5Data` instance
Returns
-------
applicable: bool
True if the segmenter is applicable to the dataset
Raises
------
SegmenterNotApplicable
If the segmenter is not applicable to the dataset
"""
if "model_file" not in segmenter_kwargs:
raise ValueError("A `model_file` must be provided in the "
"`segmenter_kwargs` to validate applicability")
logs = logs or {}
model_file = segmenter_kwargs["model_file"]
_, model_meta = load_model(model_file, device="cpu")
reasons_list = []
validators = {
"meta": functools.partial(
cls._validate_applicability_item,
data_dict=meta,
reasons_list=reasons_list),
"logs": functools.partial(
cls._validate_applicability_item,
# convert logs to strings
data_dict={key: "\n".join(val) for key, val in logs.items()},
reasons_list=reasons_list)
}
for item in model_meta.get("validation", []):
it = item["type"]
if it in validators:
validators[it](item)
else:
reasons_list.append(
f"invalid validation type {it} in {model_file}")
if reasons_list:
raise SegmenterNotApplicableError(segmenter_class=cls,
reasons_list=reasons_list)
return True
[docs]
@staticmethod
def _validate_applicability_item(item, data_dict, reasons_list):
"""Populate `reasons_list` with invalid entries
Example `data_dict`::
{"type": "meta",
"key": "setup:region",
"allow-missing-key": False,
"regexp": "^channel$",
"regexp-negate": False,
"reason": "only channel region supported",
}
"""
key = item["key"]
if key in data_dict:
valid = True
if "regexp" in item:
re_match = bool(re.search(item["regexp"], data_dict[key],
re.MULTILINE))
negate = item.get("regexp-negate", False)
valid = valid and (re_match if not negate else not re_match)
if "value" in item:
valid = valid and np.allclose(item["value"], data_dict[key],
atol=0, rtol=0.01)
if not valid:
reasons_list.append(item.get("reason", "unknown reason"))
elif not item.get("allow-missing-key", False):
reasons_list.append(f"Key '{key}' missing in {item['type']}")
[docs]
@staticmethod
def is_available():
try:
torch.__version__
except BaseException:
available = False
else:
available = True
return available