Source code for dcnum.segm.segm_torch.torch_model

import errno
import functools
import hashlib
import json
import logging
import os
import pathlib

import numpy as np

from ...meta import paths

from .torch_setup import torch


logger = logging.getLogger(__name__)


def check_md5sum(path):
    """Check that a file name carries the start of the file's MD5 hash

    The last underscore-separated token of the file stem is compared
    with the first five hexadecimal characters of the MD5 digest of
    the file contents.

    Parameters
    ----------
    path: pathlib.Path
        file to verify; its contents are read into memory in one go

    Raises
    ------
    ValueError
        if the token in the file name does not match the beginning
        of the MD5 digest
    """
    digest = hashlib.md5(path.read_bytes()).hexdigest()
    expected_tag = digest[:5]
    # only the part after the last underscore of the stem is relevant
    *_, name_tag = path.stem.split("_")
    if name_tag == expected_tag:
        return
    raise ValueError(f"MD5 mismatch for {path} ({digest})! Expected the "
                     f"input file to end with '{expected_tag}{path.suffix}'.")
def _estimate_cuda_batch_size(model_jit, device, image_shape):
    """Estimate a batch size for `model_jit` by probing CUDA memory usage

    Batches of increasing size (100, 200, ..., 5000 images) are passed
    through the model until less than 10% of the device memory is
    reported as free. The largest probed batch size that still left
    at least 10% of the memory free is returned, but never less
    than 50.

    Parameters
    ----------
    model_jit: torch.jit.ScriptModule
        model to probe; called with a float32 tensor of shape
        `(batch, 1, sy, sx)`
    device: torch.device
        CUDA device on which the probing takes place
    image_shape: tuple of int
        image shape `(sy, sx)` used for the probing input

    Returns
    -------
    size: int
        estimated batch size
    """
    import gc

    sy, sx = image_shape
    step = 100
    # Largest probed batch size that left enough free memory. It is
    # only updated after a successful probe, so it never refers to a
    # batch size that has not actually been run on the device.
    size = 0
    # NOTE(review): an out-of-memory error raised while probing is not
    # caught here and propagates to the caller.
    for probe in range(step, 50 * step + 1, step):
        # Allocate the zero-valued input directly on the device; there
        # is no need for an intermediate array in host memory.
        data = torch.zeros((probe, 1, sy, sx),
                           dtype=torch.float32,
                           device=device)
        data_seg = model_jit(data)
        data_seg_bin = data_seg > 0.5  # noqa: F841
        torch.cuda.synchronize()
        free, total = torch.cuda.mem_get_info(device)
        if free / total < 0.1:
            # leave a bit of space for other things
            break
        size = probe
    # release the probing tensors before clearing the CUDA cache
    del data, data_seg, data_seg_bin
    gc.collect()
    torch.cuda.empty_cache()
    # 50 images should fit in any GPU
    return max(size, 50)


@functools.cache
def load_model(path_or_name, device):
    """Load a PyTorch model + metadata from a TorchScript jit checkpoint

    The result is cached with :func:`functools.cache`: Calling this
    function again with the same arguments returns the very same
    model and metadata objects (the metadata dictionary is shared
    between all callers).

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        jit checkpoint file; For dcnum, these files have the suffix .dcnm
        and contain a special `_extra_files["dcnum_meta.json"]` extra
        file that can be loaded via `torch.jit.load` (see below).
    device: str or torch.device
        device on which to run the model

    Returns
    -------
    model_jit: torch.jit.ScriptModule
        loaded PyTorch model stored as a TorchScript module
    model_meta: dict
        metadata associated with the loaded model; if `device` is a
        CUDA device, the key "estimated_batch_size_cuda" is added
    """
    device = torch.device(device)
    model_path = retrieve_model_file(path_or_name)
    # define an extra files mapping dictionary that loads the model's metadata
    extra_files = {"dcnum_meta.json": ""}
    # load model (this also populates `extra_files`)
    model_jit = torch.jit.load(model_path,
                               _extra_files=extra_files,
                               map_location=device)
    # load model metadata
    model_meta = json.loads(extra_files["dcnum_meta.json"])
    # set model to evaluation mode
    model_jit.eval()
    # optimize for inference on device
    model_jit = torch.jit.optimize_for_inference(model_jit)

    if device.type == "cuda":
        # Estimate the batch size for the current device.
        # In principle, we would be fine with a batch size of 50, but
        # there is a slight improvement in performance when going to
        # higher batch sizes and users will also see the GPU usage
        # in the task manager (to perform a sanity check).
        # We estimate the batch size by determining the memory usage.
        model_meta["estimated_batch_size_cuda"] = _estimate_cuda_batch_size(
            model_jit=model_jit,
            device=device,
            image_shape=model_meta["preprocessing"]["image_shape"],
        )

    return model_jit, model_meta
@functools.cache
def retrieve_model_file(path_or_name):
    """Retrieve a dcnum torch model file

    If a path to an existing model file is given, then this path is
    returned directly. If the path does not exist, or if a file name
    is given that does not point to an existing file, then look for
    the file (by name) with :func:`dcnum.meta.paths.find_file` using
    the "torch_model_files" topic.

    The file that was found is verified with :func:`check_md5sum`.
    Successful lookups are cached with :func:`functools.cache`.

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        path to a model file or name of a model file

    Returns
    -------
    path: pathlib.Path
        location of the model file (for names that are not existing
        paths, this is whatever `find_file` returns)

    Raises
    ------
    FileNotFoundError
        if a non-existent `pathlib.Path` was passed and the lookup by
        file name failed for any reason
    ValueError
        if `path_or_name` is neither a string nor a path, or if
        :func:`check_md5sum` rejects the file
    """
    # Did the user already pass a path?
    if isinstance(path_or_name, pathlib.Path):
        if path_or_name.exists():
            path = path_or_name
        else:
            # The path does not exist; try to locate a file with the
            # same name via the string-based lookup below.
            try:
                return retrieve_model_file(path_or_name.name)
            except Exception as exc:
                # Only regular errors are translated. KeyboardInterrupt
                # and SystemExit must propagate unchanged.
                raise FileNotFoundError(errno.ENOENT,
                                        os.strerror(errno.ENOENT),
                                        str(path_or_name)) from exc
    elif isinstance(path_or_name, str):
        name = path_or_name.strip()
        # We now have a string for a filename, and we have to figure out what
        # the path is. There are several options, including cached files.
        if pathlib.Path(name).exists():
            path = pathlib.Path(name)
        else:
            path = paths.find_file("torch_model_files", name)
    else:
        raise ValueError(
            f"Please pass a string or a path, got {type(path_or_name)}!")

    logger.info(f"Found dcnum model file {path}")
    check_md5sum(path)
    return path