Skip to content

Commit

Permalink
ENH: fastapi implementation for running inference remotely
Browse files Browse the repository at this point in the history
- model database class that handles all local models including their download and/or deletion
- extracted functions into utils module
- fastapi server can be run from Slicer directly or from commandline
- make sure loaded terminologies are searched
  • Loading branch information
che85 committed Sep 18, 2024
1 parent e37d719 commit a088ebb
Show file tree
Hide file tree
Showing 12 changed files with 1,291 additions and 495 deletions.
9 changes: 9 additions & 0 deletions MONAIAuto3DSeg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ set(MODULE_NAME MONAIAuto3DSeg)
#-----------------------------------------------------------------------------
set(MODULE_PYTHON_SCRIPTS
${MODULE_NAME}.py
${MODULE_NAME}Lib/__init__.py
${MODULE_NAME}Lib/constants.py
${MODULE_NAME}Lib/dependency_handler.py
${MODULE_NAME}Lib/log_handler.py
${MODULE_NAME}Lib/model_database.py
${MODULE_NAME}Lib/server.py
${MODULE_NAME}Lib/utils.py
auto3dseg/__init__py
auto3dseg/main.py
)

set(MODULE_PYTHON_RESOURCES
Expand Down
966 changes: 493 additions & 473 deletions MONAIAuto3DSeg/MONAIAuto3DSeg.py

Large diffs are not rendered by default.

Empty file.
1 change: 1 addition & 0 deletions MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
APPLICATION_NAME = "MONAIAuto3DSeg"
130 changes: 130 additions & 0 deletions MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import sys

import shutil
import subprocess
import logging

from MONAIAuto3DSegLib.constants import APPLICATION_NAME
logger = logging.getLogger(APPLICATION_NAME)


from abc import ABC, abstractmethod


class DependenciesBase(ABC):

minimumTorchVersion = "1.12"

def __init__(self):
self.dependenciesInstalled = False # we don't know yet if dependencies have been installed

@abstractmethod
def installedMONAIPythonPackageInfo(self):
pass

@abstractmethod
def setupPythonRequirements(self, upgrade=False):
pass


class LocalPythonDependencies(DependenciesBase):

def installedMONAIPythonPackageInfo(self):
versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode()
return versionInfo

def _checkModuleInstalled(self, moduleName):
try:
import importlib
importlib.import_module(moduleName)
return True
except ModuleNotFoundError:
return False

def setupPythonRequirements(self, upgrade=False):
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])

logger.info("Initializing PyTorch...")

packageName = "torch"
if not self._checkModuleInstalled(packageName):
logger.info("PyTorch Python package is required. Installing... (it may take several minutes)")
install(packageName)
if not self._checkModuleInstalled(packageName):
raise ValueError("pytorch needs to be installed to use this module.")
else: # torch is installed, check version
from packaging import version
import torch
if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion):
raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module."
+ f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
+ f" with version requirement set to: >={self.minimumTorchVersion}")

logger.info("Initializing MONAI...")
monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
if upgrade:
monaiInstallString += " --upgrade"
install(monaiInstallString)

self.dependenciesInstalled = True
logger.info("Dependencies are set up successfully.")


class RemotePythonDependencies(DependenciesBase):

def installedMONAIPythonPackageInfo(self, server_address):
if not server_address:
return []
else:
import json
import requests
response = requests.get(server_address + "/monaiinfo")
json_data = json.loads(response.text)
return json_data

def setupPythonRequirements(self, upgrade=False):
logger.error("No permission to update remote python packages. Please contact developer.")


class SlicerPythonDependencies(DependenciesBase):

def installedMONAIPythonPackageInfo(self):
versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode()
return versionInfo

def setupPythonRequirements(self, upgrade=False):
# Install PyTorch
try:
import PyTorchUtils
except ModuleNotFoundError as e:
raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.")

logger.info("Initializing PyTorch...")

torchLogic = PyTorchUtils.PyTorchUtilsLogic()
if not torchLogic.torchInstalled():
logger.info("PyTorch Python package is required. Installing... (it may take several minutes)")
torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement=f">={self.minimumTorchVersion}")
if torch is None:
raise ValueError("PyTorch extension needs to be installed to use this module.")
else: # torch is installed, check version
from packaging import version
if version.parse(torchLogic.torch.__version__) < version.parse(self.minimumTorchVersion):
raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module."
+ f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
+ f" with version requirement set to: >={self.minimumTorchVersion}")

# Install MONAI with required components
logger.info("Initializing MONAI...")
# Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too).
# Without this, for some users monai-0.9.0 got installed, which failed with this error:
# "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'"
monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
if upgrade:
monaiInstallString += " --upgrade"
import slicer
slicer.util.pip_install(monaiInstallString)

self.dependenciesInstalled = True
logger.info("Dependencies are set up successfully.")
27 changes: 27 additions & 0 deletions MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import logging
from typing import Callable


class LogHandler(logging.Handler):
"""
code:
logger = logging.getLogger("XYZ")
callback = ... # any callable
# NB: only catching info level messages and forwarding it to callback
handler = LogHandler(callback, logging.INFO)
# can format log messages
formatter = logging.Formatter('%(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
"""
def __init__(self, callback: Callable, level=logging.NOTSET):
self._callback = callback
super().__init__(level)

def emit(self, record):
msg = self.format(record)
self._callback(msg)
196 changes: 196 additions & 0 deletions MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import json
import logging
import os
from pathlib import Path

from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec
from MONAIAuto3DSegLib.constants import APPLICATION_NAME


logger = logging.getLogger(APPLICATION_NAME)


class ModelDatabase:

@property
def defaultModel(self):
return self.models[0]["id"]

@property
def models(self):
if not self._models:
self._models = self.loadModelsDescription()
return self._models

@property
def modelsPath(self):
modelsPath = self.fileCachePath.joinpath("models")
modelsPath.mkdir(exist_ok=True, parents=True)
return modelsPath

@property
def modelsDescriptionJsonFilePath(self):
return os.path.join(self.moduleDir, "Resources", "Models.json")

def __init__(self):
self.fileCachePath = Path.home().joinpath(f".{APPLICATION_NAME}")
self.moduleDir = Path(__file__).parent.parent

# Disabling this flag preserves input and output data after execution is completed,
# which can be useful for troubleshooting.
self.clearOutputFolder = True
self._models = []

def model(self, modelId):
for model in self.models:
if model["id"] == modelId:
return model
raise RuntimeError(f"Model {modelId} not found")

def loadModelsDescription(self):
modelsJsonFilePath = self.modelsDescriptionJsonFilePath
try:
models = []
import json
import re
with open(modelsJsonFilePath) as f:
modelsTree = json.load(f)["models"]
for model in modelsTree:
deprecated = False
for version in model["versions"]:
url = version["url"]
# URL format: <path>/<filename>-v<version>.zip
# Example URL: https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/Models/17-segments-TotalSegmentator-v1.0.3.zip
match = re.search(r"(?P<filename>[^/]+)-v(?P<version>\d+\.\d+\.\d+)", url)
if match:
filename = match.group("filename")
version = match.group("version")
else:
logger.error(f"Failed to extract model id and version from url: {url}")
if "inputs" in model:
# Contains a list of dict. One dict for each input.
# Currently, only "title" (user-displayable name) and "namePattern" of the input are specified.
# In the future, inputs could have additional properties, such as name, type, optional, ...
inputs = model["inputs"]
else:
# Inputs are not defined, use default (single input volume)
inputs = [{"title": "Input volume"}]
segmentNames = model.get('segmentNames')
if not segmentNames:
segmentNames = "N/A"
models.append({
"id": f"{filename}-v{version}",
"title": model['title'],
"version": version,
"inputs": inputs,
"imagingModality": model["imagingModality"],
"description": model["description"],
"sampleData": model.get("sampleData"),
"segmentNames": model.get("segmentNames"),
"details":
f"<p><b>Model:</b> {model['title']} (v{version})"
f"<p><b>Description:</b> {model['description']}\n"
f"<p><b>Computation time on GPU:</b> {humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n"
f"<br><b>Computation time on CPU:</b> {humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n"
f"<p><b>Imaging modality:</b> {model['imagingModality']}\n"
f"<p><b>Subject:</b> {model['subject']}\n"
f"<p><b>Segments:</b> {', '.join(segmentNames)}",
"url": url,
"deprecated": deprecated
})
# First version is not deprecated, all subsequent versions are deprecated
deprecated = True
return models
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}")

def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath):
modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath

with open(modelsTestResultsJsonFilePath) as f:
modelsTestResults = json.load(f)

with open(modelsDescriptionJsonFilePath) as f:
modelsDescription = json.load(f)

for model in modelsDescription["models"]:
title = model["title"]
for modelTestResult in modelsTestResults:
if modelTestResult["title"] == title:
for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]:
fieldValue = modelTestResult.get(fieldName)
if fieldValue:
model[fieldName] = fieldValue
break

with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f:
json.dump(modelsDescription, f, indent=2)

def createModelsDir(self):
modelsDir = self.modelsPath
if not os.path.exists(modelsDir):
os.makedirs(modelsDir)

def modelPath(self, modelName):
try:
return self._modelPath(modelName)
except RuntimeError:
self.downloadModel(modelName)
return self._modelPath(modelName)

def _modelPath(self, modelName):
modelRoot = self.modelsPath.joinpath(modelName)
# find labels.csv file within the modelRoot folder and subfolders
for path in Path(modelRoot).rglob("labels.csv"):
return path.parent
raise RuntimeError(f"Model {modelName} path not found")

def deleteAllModels(self):
if self.modelsPath.exists():
import shutil
shutil.rmtree(self.modelsPath)

def downloadModel(self, modelName):
url = self.model(modelName)["url"]
import zipfile
import requests
from tempfile import TemporaryDirectory
with TemporaryDirectory() as td:
tempDir = Path(td)
modelDir = self.modelsPath.joinpath(modelName)
Path(modelDir).mkdir(exist_ok=True)
modelZipFile = tempDir.joinpath("autoseg3d_model.zip")
logger.info(f"Downloading model '{modelName}' from {url}...")
logger.debug(f"Downloading from {url} to {modelZipFile}...")
try:
with open(modelZipFile, 'wb') as f:
with requests.get(url, stream=True) as r:
r.raise_for_status()
total_size = int(r.headers.get('content-length', 0))
reporting_increment_percent = 1.0
last_reported_download_percent = -reporting_increment_percent
downloaded_size = 0
for chunk in r.iter_content(chunk_size=8192 * 16):
f.write(chunk)
downloaded_size += len(chunk)
downloaded_percent = 100.0 * downloaded_size / total_size
if downloaded_percent - last_reported_download_percent > reporting_increment_percent:
logger.info(
f"Downloading model: {downloaded_size / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB ({downloaded_percent:.1f}%)")
last_reported_download_percent = downloaded_percent

logger.info(f"Download finished. Extracting to {modelDir}...")
with zipfile.ZipFile(modelZipFile, 'r') as zip_f:
zip_f.extractall(modelDir)
except Exception as e:
raise e
finally:
if self.clearOutputFolder:
logger.info("Cleaning up temporary model download folder...")
if os.path.isdir(tempDir):
import shutil
shutil.rmtree(tempDir)
else:
logger.info(f"Not cleaning up temporary model download folder: {tempDir}")
Loading

0 comments on commit a088ebb

Please sign in to comment.