-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: fastapi implementation for running inference remotely
- model database class that handles all local models including their download and/or deletion - extracted functions into utils module - fastapi server can be run from Slicer directly or from commandline - make sure loaded terminologies are searched
- Loading branch information
Showing
12 changed files
with
1,291 additions
and
495 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
APPLICATION_NAME = "MONAIAuto3DSeg" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import sys | ||
|
||
import shutil | ||
import subprocess | ||
import logging | ||
|
||
from MONAIAuto3DSegLib.constants import APPLICATION_NAME | ||
logger = logging.getLogger(APPLICATION_NAME) | ||
|
||
|
||
from abc import ABC, abstractmethod | ||
|
||
|
||
class DependenciesBase(ABC): | ||
|
||
minimumTorchVersion = "1.12" | ||
|
||
def __init__(self): | ||
self.dependenciesInstalled = False # we don't know yet if dependencies have been installed | ||
|
||
@abstractmethod | ||
def installedMONAIPythonPackageInfo(self): | ||
pass | ||
|
||
@abstractmethod | ||
def setupPythonRequirements(self, upgrade=False): | ||
pass | ||
|
||
|
||
class LocalPythonDependencies(DependenciesBase): | ||
|
||
def installedMONAIPythonPackageInfo(self): | ||
versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode() | ||
return versionInfo | ||
|
||
def _checkModuleInstalled(self, moduleName): | ||
try: | ||
import importlib | ||
importlib.import_module(moduleName) | ||
return True | ||
except ModuleNotFoundError: | ||
return False | ||
|
||
def setupPythonRequirements(self, upgrade=False): | ||
def install(package): | ||
subprocess.check_call([sys.executable, "-m", "pip", "install", package]) | ||
|
||
logger.info("Initializing PyTorch...") | ||
|
||
packageName = "torch" | ||
if not self._checkModuleInstalled(packageName): | ||
logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") | ||
install(packageName) | ||
if not self._checkModuleInstalled(packageName): | ||
raise ValueError("pytorch needs to be installed to use this module.") | ||
else: # torch is installed, check version | ||
from packaging import version | ||
import torch | ||
if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion): | ||
raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module." | ||
+ f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" | ||
+ f" with version requirement set to: >={self.minimumTorchVersion}") | ||
|
||
logger.info("Initializing MONAI...") | ||
monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" | ||
if upgrade: | ||
monaiInstallString += " --upgrade" | ||
install(monaiInstallString) | ||
|
||
self.dependenciesInstalled = True | ||
logger.info("Dependencies are set up successfully.") | ||
|
||
|
||
class RemotePythonDependencies(DependenciesBase): | ||
|
||
def installedMONAIPythonPackageInfo(self, server_address): | ||
if not server_address: | ||
return [] | ||
else: | ||
import json | ||
import requests | ||
response = requests.get(server_address + "/monaiinfo") | ||
json_data = json.loads(response.text) | ||
return json_data | ||
|
||
def setupPythonRequirements(self, upgrade=False): | ||
logger.error("No permission to update remote python packages. Please contact developer.") | ||
|
||
|
||
class SlicerPythonDependencies(DependenciesBase): | ||
|
||
def installedMONAIPythonPackageInfo(self): | ||
versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode() | ||
return versionInfo | ||
|
||
def setupPythonRequirements(self, upgrade=False): | ||
# Install PyTorch | ||
try: | ||
import PyTorchUtils | ||
except ModuleNotFoundError as e: | ||
raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.") | ||
|
||
logger.info("Initializing PyTorch...") | ||
|
||
torchLogic = PyTorchUtils.PyTorchUtilsLogic() | ||
if not torchLogic.torchInstalled(): | ||
logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") | ||
torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement=f">={self.minimumTorchVersion}") | ||
if torch is None: | ||
raise ValueError("PyTorch extension needs to be installed to use this module.") | ||
else: # torch is installed, check version | ||
from packaging import version | ||
if version.parse(torchLogic.torch.__version__) < version.parse(self.minimumTorchVersion): | ||
raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module." | ||
+ f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" | ||
+ f" with version requirement set to: >={self.minimumTorchVersion}") | ||
|
||
# Install MONAI with required components | ||
logger.info("Initializing MONAI...") | ||
# Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too). | ||
# Without this, for some users monai-0.9.0 got installed, which failed with this error: | ||
# "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'" | ||
monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" | ||
if upgrade: | ||
monaiInstallString += " --upgrade" | ||
import slicer | ||
slicer.util.pip_install(monaiInstallString) | ||
|
||
self.dependenciesInstalled = True | ||
logger.info("Dependencies are set up successfully.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import logging | ||
from typing import Callable | ||
|
||
|
||
class LogHandler(logging.Handler): | ||
""" | ||
code: | ||
logger = logging.getLogger("XYZ") | ||
callback = ... # any callable | ||
# NB: only catching info level messages and forwarding it to callback | ||
handler = LogHandler(callback, logging.INFO) | ||
# can format log messages | ||
formatter = logging.Formatter('%(levelname)s - %(message)s') | ||
handler.setFormatter(formatter) | ||
logger.addHandler(handler) | ||
""" | ||
def __init__(self, callback: Callable, level=logging.NOTSET): | ||
self._callback = callback | ||
super().__init__(level) | ||
|
||
def emit(self, record): | ||
msg = self.format(record) | ||
self._callback(msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
import json | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec | ||
from MONAIAuto3DSegLib.constants import APPLICATION_NAME | ||
|
||
|
||
logger = logging.getLogger(APPLICATION_NAME) | ||
|
||
|
||
class ModelDatabase: | ||
|
||
@property | ||
def defaultModel(self): | ||
return self.models[0]["id"] | ||
|
||
@property | ||
def models(self): | ||
if not self._models: | ||
self._models = self.loadModelsDescription() | ||
return self._models | ||
|
||
@property | ||
def modelsPath(self): | ||
modelsPath = self.fileCachePath.joinpath("models") | ||
modelsPath.mkdir(exist_ok=True, parents=True) | ||
return modelsPath | ||
|
||
@property | ||
def modelsDescriptionJsonFilePath(self): | ||
return os.path.join(self.moduleDir, "Resources", "Models.json") | ||
|
||
def __init__(self): | ||
self.fileCachePath = Path.home().joinpath(f".{APPLICATION_NAME}") | ||
self.moduleDir = Path(__file__).parent.parent | ||
|
||
# Disabling this flag preserves input and output data after execution is completed, | ||
# which can be useful for troubleshooting. | ||
self.clearOutputFolder = True | ||
self._models = [] | ||
|
||
def model(self, modelId): | ||
for model in self.models: | ||
if model["id"] == modelId: | ||
return model | ||
raise RuntimeError(f"Model {modelId} not found") | ||
|
||
def loadModelsDescription(self): | ||
modelsJsonFilePath = self.modelsDescriptionJsonFilePath | ||
try: | ||
models = [] | ||
import json | ||
import re | ||
with open(modelsJsonFilePath) as f: | ||
modelsTree = json.load(f)["models"] | ||
for model in modelsTree: | ||
deprecated = False | ||
for version in model["versions"]: | ||
url = version["url"] | ||
# URL format: <path>/<filename>-v<version>.zip | ||
# Example URL: https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/Models/17-segments-TotalSegmentator-v1.0.3.zip | ||
match = re.search(r"(?P<filename>[^/]+)-v(?P<version>\d+\.\d+\.\d+)", url) | ||
if match: | ||
filename = match.group("filename") | ||
version = match.group("version") | ||
else: | ||
logger.error(f"Failed to extract model id and version from url: {url}") | ||
if "inputs" in model: | ||
# Contains a list of dict. One dict for each input. | ||
# Currently, only "title" (user-displayable name) and "namePattern" of the input are specified. | ||
# In the future, inputs could have additional properties, such as name, type, optional, ... | ||
inputs = model["inputs"] | ||
else: | ||
# Inputs are not defined, use default (single input volume) | ||
inputs = [{"title": "Input volume"}] | ||
segmentNames = model.get('segmentNames') | ||
if not segmentNames: | ||
segmentNames = "N/A" | ||
models.append({ | ||
"id": f"{filename}-v{version}", | ||
"title": model['title'], | ||
"version": version, | ||
"inputs": inputs, | ||
"imagingModality": model["imagingModality"], | ||
"description": model["description"], | ||
"sampleData": model.get("sampleData"), | ||
"segmentNames": model.get("segmentNames"), | ||
"details": | ||
f"<p><b>Model:</b> {model['title']} (v{version})" | ||
f"<p><b>Description:</b> {model['description']}\n" | ||
f"<p><b>Computation time on GPU:</b> {humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n" | ||
f"<br><b>Computation time on CPU:</b> {humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n" | ||
f"<p><b>Imaging modality:</b> {model['imagingModality']}\n" | ||
f"<p><b>Subject:</b> {model['subject']}\n" | ||
f"<p><b>Segments:</b> {', '.join(segmentNames)}", | ||
"url": url, | ||
"deprecated": deprecated | ||
}) | ||
# First version is not deprecated, all subsequent versions are deprecated | ||
deprecated = True | ||
return models | ||
except Exception as e: | ||
import traceback | ||
traceback.print_exc() | ||
raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}") | ||
|
||
def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath): | ||
modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath | ||
|
||
with open(modelsTestResultsJsonFilePath) as f: | ||
modelsTestResults = json.load(f) | ||
|
||
with open(modelsDescriptionJsonFilePath) as f: | ||
modelsDescription = json.load(f) | ||
|
||
for model in modelsDescription["models"]: | ||
title = model["title"] | ||
for modelTestResult in modelsTestResults: | ||
if modelTestResult["title"] == title: | ||
for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]: | ||
fieldValue = modelTestResult.get(fieldName) | ||
if fieldValue: | ||
model[fieldName] = fieldValue | ||
break | ||
|
||
with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f: | ||
json.dump(modelsDescription, f, indent=2) | ||
|
||
def createModelsDir(self): | ||
modelsDir = self.modelsPath | ||
if not os.path.exists(modelsDir): | ||
os.makedirs(modelsDir) | ||
|
||
def modelPath(self, modelName): | ||
try: | ||
return self._modelPath(modelName) | ||
except RuntimeError: | ||
self.downloadModel(modelName) | ||
return self._modelPath(modelName) | ||
|
||
def _modelPath(self, modelName): | ||
modelRoot = self.modelsPath.joinpath(modelName) | ||
# find labels.csv file within the modelRoot folder and subfolders | ||
for path in Path(modelRoot).rglob("labels.csv"): | ||
return path.parent | ||
raise RuntimeError(f"Model {modelName} path not found") | ||
|
||
def deleteAllModels(self): | ||
if self.modelsPath.exists(): | ||
import shutil | ||
shutil.rmtree(self.modelsPath) | ||
|
||
def downloadModel(self, modelName): | ||
url = self.model(modelName)["url"] | ||
import zipfile | ||
import requests | ||
from tempfile import TemporaryDirectory | ||
with TemporaryDirectory() as td: | ||
tempDir = Path(td) | ||
modelDir = self.modelsPath.joinpath(modelName) | ||
Path(modelDir).mkdir(exist_ok=True) | ||
modelZipFile = tempDir.joinpath("autoseg3d_model.zip") | ||
logger.info(f"Downloading model '{modelName}' from {url}...") | ||
logger.debug(f"Downloading from {url} to {modelZipFile}...") | ||
try: | ||
with open(modelZipFile, 'wb') as f: | ||
with requests.get(url, stream=True) as r: | ||
r.raise_for_status() | ||
total_size = int(r.headers.get('content-length', 0)) | ||
reporting_increment_percent = 1.0 | ||
last_reported_download_percent = -reporting_increment_percent | ||
downloaded_size = 0 | ||
for chunk in r.iter_content(chunk_size=8192 * 16): | ||
f.write(chunk) | ||
downloaded_size += len(chunk) | ||
downloaded_percent = 100.0 * downloaded_size / total_size | ||
if downloaded_percent - last_reported_download_percent > reporting_increment_percent: | ||
logger.info( | ||
f"Downloading model: {downloaded_size / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB ({downloaded_percent:.1f}%)") | ||
last_reported_download_percent = downloaded_percent | ||
|
||
logger.info(f"Download finished. Extracting to {modelDir}...") | ||
with zipfile.ZipFile(modelZipFile, 'r') as zip_f: | ||
zip_f.extractall(modelDir) | ||
except Exception as e: | ||
raise e | ||
finally: | ||
if self.clearOutputFolder: | ||
logger.info("Cleaning up temporary model download folder...") | ||
if os.path.isdir(tempDir): | ||
import shutil | ||
shutil.rmtree(tempDir) | ||
else: | ||
logger.info(f"Not cleaning up temporary model download folder: {tempDir}") |
Oops, something went wrong.