Model: {model['title']} (v{version})" - f"
Description: {model['description']}\n" - f"
Computation time on GPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n"
- f"
Computation time on CPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n"
- f"
Imaging modality: {model['imagingModality']}\n" - f"
Subject: {model['subject']}\n" - f"
Segments: {', '.join(segmentNames)}",
- "url": url,
- "deprecated": deprecated
- })
- # First version is not deprecated, all subsequent versions are deprecated
- deprecated = True
- return models
- except Exception as e:
- import traceback
- traceback.print_exc()
- raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}")
+ DEPENDENCY_HANDLER = SlicerPythonDependencies()
@staticmethod
- def humanReadableTimeFromSec(seconds):
- import math
- if not seconds:
- return "N/A"
- if seconds < 55:
- # if less than a minute, round up to the nearest 5 seconds
- return f"{math.ceil(seconds/5) * 5} sec"
- elif seconds < 60 * 60:
- # if less then 1 hour, round up to the nearest minute
- return f"{math.ceil(seconds/60)} min"
- # Otherwise round up to the nearest 0.1 hour
- return f"{seconds/3600:.1f} h"
-
- def modelsPath(self):
- import pathlib
- return self.fileCachePath.joinpath("models")
-
- def createModelsDir(self):
- modelsDir = self.modelsPath()
- if not os.path.exists(modelsDir):
- os.makedirs(modelsDir)
-
- def modelPath(self, modelName):
- import pathlib
- modelRoot = self.modelsPath().joinpath(modelName)
- # find labels.csv file within the modelRoot folder and subfolders
- for path in pathlib.Path(modelRoot).rglob("labels.csv"):
- return path.parent
- raise RuntimeError(f"Model {modelName} path not found")
-
- def deleteAllModels(self):
- if self.modelsPath().exists():
- import shutil
- shutil.rmtree(self.modelsPath())
-
- def downloadAllModels(self):
- for model in self.models:
- slicer.app.processEvents()
- self.downloadModel(model["id"])
-
- def downloadModel(self, modelName):
+ def assignInputNodesByName(inputs, loadedSampleNodes):
- url = self.model(modelName)["url"]
+ def findFirstNodeByNamePattern(namePattern, nodes):
+ import fnmatch
+ for node in nodes:
+ if fnmatch.fnmatchcase(node.GetName(), namePattern):
+ return node
+ return None
- import zipfile
- import requests
- import pathlib
+ inputNodes = []
+ for inputIndex, input in enumerate(inputs):
+ namePattern = input.get("namePattern")
+ if namePattern:
+ matchingNode = findFirstNodeByNamePattern(namePattern, loadedSampleNodes)
+ else:
+ matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else \
+ loadedSampleNodes[0]
+ inputNodes.append(matchingNode)
+ return inputNodes
- tempDir = pathlib.Path(slicer.util.tempDirectory())
- modelDir = self.modelsPath().joinpath(modelName)
- if not os.path.exists(modelDir):
- os.makedirs(modelDir)
+ @staticmethod
+ def getLoadedTerminologyNames():
+ import vtk
+ terminologyNames = vtk.vtkStringArray()
+ terminologiesLogic = slicer.util.getModuleLogic("Terminologies")
+ terminologiesLogic.GetLoadedTerminologyNames(terminologyNames)
- modelZipFile = tempDir.joinpath("autoseg3d_model.zip")
- self.log(f"Downloading model '{modelName}' from {url}...")
- logging.debug(f"Downloading from {url} to {modelZipFile}...")
+ return [terminologyNames.GetValue(idx) for idx in range(terminologyNames.GetNumberOfValues())]
- try:
- with open(modelZipFile, 'wb') as f:
- with requests.get(url, stream=True) as r:
- r.raise_for_status()
- total_size = int(r.headers.get('content-length', 0))
- reporting_increment_percent = 1.0
- last_reported_download_percent = -reporting_increment_percent
- downloaded_size = 0
- for chunk in r.iter_content(chunk_size=8192 * 16):
- f.write(chunk)
- downloaded_size += len(chunk)
- downloaded_percent = 100.0 * downloaded_size / total_size
- if downloaded_percent - last_reported_download_percent > reporting_increment_percent:
- self.log(f"Downloading model: {downloaded_size/1024/1024:.1f}MB / {total_size/1024/1024:.1f}MB ({downloaded_percent:.1f}%)")
- last_reported_download_percent = downloaded_percent
-
- self.log(f"Download finished. Extracting to {modelDir}...")
- with zipfile.ZipFile(modelZipFile, 'r') as zip_f:
- zip_f.extractall(modelDir)
- except Exception as e:
- raise e
- finally:
- if self.clearOutputFolder:
- self.log("Cleaning up temporary model download folder...")
- if os.path.isdir(tempDir):
- import shutil
- shutil.rmtree(tempDir)
- else:
- self.log(f"Not cleaning up temporary model download folder: {tempDir}")
+ @staticmethod
+ def getLoadedAnatomicContextNames():
+ import vtk
+ anatomicContextNames = vtk.vtkStringArray()
+ terminologiesLogic = slicer.util.getModuleLogic("Terminologies")
+ terminologiesLogic.GetLoadedAnatomicContextNames(anatomicContextNames)
+ return [anatomicContextNames.GetValue(idx) for idx in range(anatomicContextNames.GetNumberOfValues())]
- def _MONAIAuto3DSegTerminologyPropertyTypes(self):
+ @staticmethod
+ def _terminologyPropertyTypes(terminologyName):
"""Get label terminology property types defined in from MONAI Auto3DSeg terminology.
Terminology entries are either in DICOM or MONAI Auto3DSeg "Segmentation category and type".
- """
+ # List of property type codes that are specified by in the terminology.
+ #
+ # Codes are stored as a list of strings containing coding scheme designator and code value of the property type,
+ # separated by "^" character. For example "SCT^123456".
+
+ """
terminologiesLogic = slicer.util.getModuleLogic("Terminologies")
- MONAIAuto3DSegTerminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName
+ terminologyPropertyTypes = []
# Get anatomicalStructureCategory from the MONAI Auto3DSeg terminology
anatomicalStructureCategory = slicer.vtkSlicerTerminologyCategory()
- numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(MONAIAuto3DSegTerminologyName)
- for i in range(numberOfCategories):
- terminologiesLogic.GetNthCategoryInTerminology(MONAIAuto3DSegTerminologyName, i, anatomicalStructureCategory)
- if anatomicalStructureCategory.GetCodingSchemeDesignator() == "SCT" and anatomicalStructureCategory.GetCodeValue() == "123037004":
- # Found the (123037004, SCT, "Anatomical Structure") category within DICOM master list
- break
-
- # Retrieve all anatomicalStructureCategory property type codes
- terminologyPropertyTypes = []
- terminologyType = slicer.vtkSlicerTerminologyType()
- numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory)
- for i in range(numberOfTypes):
- if terminologiesLogic.GetNthTypeInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory, i, terminologyType):
- terminologyPropertyTypes.append(terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue())
+ numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(terminologyName)
+ for cIdx in range(numberOfCategories):
+ terminologiesLogic.GetNthCategoryInTerminology(terminologyName, cIdx, anatomicalStructureCategory)
+
+ # Retrieve all anatomicalStructureCategory property type codes
+ terminologyType = slicer.vtkSlicerTerminologyType()
+ numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(terminologyName,
+ anatomicalStructureCategory)
+ for tIdx in range(numberOfTypes):
+ if terminologiesLogic.GetNthTypeInTerminologyCategory(terminologyName, anatomicalStructureCategory, tIdx,
+ terminologyType):
+ terminologyPropertyTypes.append(
+ terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue())
return terminologyPropertyTypes
- def _MONAIAuto3DSegAnatomicRegions(self):
+ @staticmethod
+ def _anatomicRegions(anatomicContextName):
"""Get anatomic regions defined in from MONAI Auto3DSeg terminology.
Terminology entries are either in DICOM or MONAI Auto3DSeg "Anatomic codes".
"""
@@ -884,102 +948,24 @@ def _MONAIAuto3DSegAnatomicRegions(self):
# when editing the terminology on the GUI)
return anatomicRegions
- MONAIAuto3DSegAnatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName
-
# Retrieve all anatomical region codes
-
regionObject = slicer.vtkSlicerTerminologyType()
- numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(MONAIAuto3DSegAnatomicContextName)
+ numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(anatomicContextName)
for i in range(numberOfRegions):
- if terminologiesLogic.GetNthRegionInAnatomicContext(MONAIAuto3DSegAnatomicContextName, i, regionObject):
+ if terminologiesLogic.GetNthRegionInAnatomicContext(anatomicContextName, i, regionObject):
anatomicRegions.append(regionObject.GetCodingSchemeDesignator() + "^" + regionObject.GetCodeValue())
return anatomicRegions
- def labelDescriptions(self, modelName):
- """Return mapping from label value to label description.
- Label description is a dict containing "name" and "terminology".
- Terminology string uses Slicer terminology entry format - see specification at
- https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag
- """
-
- # Helper function to get code string from CSV file row
- def getCodeString(field, columnNames, row):
- columnValues = []
- for fieldName in ["CodingSchemeDesignator", "CodeValue", "CodeMeaning"]:
- columnIndex = columnNames.index(f"{field}.{fieldName}")
- try:
- columnValue = row[columnIndex]
- except IndexError:
- # Probably the line in the CSV file was not terminated by multiple commas (,)
- columnValue = ""
- columnValues.append(columnValue)
- return columnValues
-
- labelDescriptions = {}
- labelsFilePath = self.modelPath(modelName).joinpath("labels.csv")
- import csv
- with open(labelsFilePath, "r") as f:
- reader = csv.reader(f)
- columnNames = next(reader)
- data = {}
- # Loop through the rows of the csv file
- for row in reader:
-
- # Determine segmentation category (DICOM or MONAIAuto3DSeg)
- terminologyPropertyTypeStr = ( # Example: SCT^23451007
- row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")]
- + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")])
- if terminologyPropertyTypeStr in self.MONAIAuto3DSegTerminologyPropertyTypes:
- terminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName
- else:
- terminologyName = "Segmentation category and type - DICOM master list"
-
- # Determine the anatomic context name (DICOM or MONAIAuto3DSeg)
- anatomicRegionStr = ( # Example: SCT^279245009
- row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")]
- + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")])
- if anatomicRegionStr in self.MONAIAuto3DSegAnatomicRegions:
- anatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName
- else:
- anatomicContextName = "Anatomic codes - DICOM master list"
-
- terminologyEntryStr = (
- terminologyName
- +"~"
- # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure"
- + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row))
- + "~"
- # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ...
- + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row))
- + "~"
- # Property type modifier: "SCT^7771000^Left", ...
- + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row))
- + "~"
- + anatomicContextName
- + "~"
- # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ...
- + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row))
- + "~"
- # Anatomic region modifier: "SCT^7771000^Left", ...
- + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row))
- )
-
- # Store the terminology string for this structure
- labelValue = int(row[columnNames.index("LabelValue")])
- name = row[columnNames.index("Name")]
- labelDescriptions[labelValue] = { "name": name, "terminology": terminologyEntryStr }
-
- return labelDescriptions
-
- def getSegmentLabelColor(self, terminologyEntryStr):
+ @staticmethod
+ def getSegmentLabelColor(terminologyEntryStr):
"""Get segment label and color from terminology"""
def labelColorFromTypeObject(typeObject):
"""typeObject is a terminology type or type modifier"""
label = typeObject.GetSlicerLabel() if typeObject.GetSlicerLabel() else typeObject.GetCodeMeaning()
rgb = typeObject.GetRecommendedDisplayRGBValue()
- return label, (rgb[0]/255.0, rgb[1]/255.0, rgb[2]/255.0)
+ return label, (rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0)
tlogic = slicer.modules.terminologies.logic()
@@ -987,21 +973,29 @@ def labelColorFromTypeObject(typeObject):
if not tlogic.DeserializeTerminologyEntry(terminologyEntryStr, terminologyEntry):
raise RuntimeError(f"Failed to deserialize terminology string: {terminologyEntryStr}")
- numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject())
+ numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(),
+ terminologyEntry.GetCategoryObject())
foundTerminologyEntry = slicer.vtkSlicerTerminologyEntry()
for typeIndex in range(numberOfTypes):
- tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), typeIndex, foundTerminologyEntry.GetTypeObject())
+ tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(),
+ terminologyEntry.GetCategoryObject(), typeIndex,
+ foundTerminologyEntry.GetTypeObject())
if terminologyEntry.GetTypeObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeObject().GetCodingSchemeDesignator():
continue
if terminologyEntry.GetTypeObject().GetCodeValue() != foundTerminologyEntry.GetTypeObject().GetCodeValue():
continue
if terminologyEntry.GetTypeModifierObject() and terminologyEntry.GetTypeModifierObject().GetCodeValue():
# Type has a modifier, get the color from there
- numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject())
+ numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType(
+ terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(),
+ terminologyEntry.GetTypeObject())
foundMatchingModifier = False
for modifierIndex in range(numberOfModifiers):
- tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject(),
- modifierIndex, foundTerminologyEntry.GetTypeModifierObject())
+ tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(),
+ terminologyEntry.GetCategoryObject(),
+ terminologyEntry.GetTypeObject(),
+ modifierIndex,
+ foundTerminologyEntry.GetTypeModifierObject())
if terminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator():
continue
if terminologyEntry.GetTypeModifierObject().GetCodeValue() != foundTerminologyEntry.GetTypeModifierObject().GetCodeValue():
@@ -1012,77 +1006,127 @@ def labelColorFromTypeObject(typeObject):
raise RuntimeError(f"Color was not found for terminology {terminologyEntryStr}")
- @staticmethod
- def _findFirstNodeBynamePattern(namePattern, nodes):
- import fnmatch
- for node in nodes:
- if fnmatch.fnmatchcase(node.GetName(), namePattern):
- return node
- return None
+ def __init__(self):
+ """
+ Called when the logic class is instantiated. Can be used for initializing member variables.
+ """
+ ScriptedLoadableModuleLogic.__init__(self)
+ ModelDatabase.__init__(self)
- @staticmethod
- def assignInputNodesByName(inputs, loadedSampleNodes):
- inputNodes = []
- for inputIndex, input in enumerate(inputs):
- namePattern = input.get("namePattern")
- if namePattern:
- matchingNode = MONAIAuto3DSegLogic._findFirstNodeBynamePattern(namePattern, loadedSampleNodes)
- else:
- matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else loadedSampleNodes[0]
- inputNodes.append(matchingNode)
- return inputNodes
+ self.logCallback = None
+ self.startResultImportCallback = None
+ self.endResultImportCallback = None
+ self.processingCompletedCallback = None
+ self.useStandardSegmentNames = True
+
+ # process that will used to run inference either remotely or locally
+ self._bgProcess = None
+
+ # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location
+ # where inference result is stored and set self.debugSkipInference to True.
+ # Disabling this flag preserves input and output data after execution is completed,
+ # which can be useful for troubleshooting.
+ self.clearOutputFolder = True
+ self.debugSkipInference = False
+ self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624"
def log(self, text):
logging.info(text)
if self.logCallback:
self.logCallback(text)
- def installedMONAIPythonPackageInfo(self):
- import shutil
- import subprocess
- versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode()
- return versionInfo
+ def getMONAIPythonPackageInfo(self):
+ return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo()
def setupPythonRequirements(self, upgrade=False):
- import importlib.metadata
- import importlib.util
- import packaging
+ self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade)
+ return True
- # Install PyTorch
- try:
- import PyTorchUtils
- except ModuleNotFoundError as e:
- raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.")
-
- self.log("Initializing PyTorch...")
- minimumTorchVersion = "1.12"
- torchLogic = PyTorchUtils.PyTorchUtilsLogic()
- if not torchLogic.torchInstalled():
- self.log("PyTorch Python package is required. Installing... (it may take several minutes)")
- torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement = f">={minimumTorchVersion}")
- if torch is None:
- raise ValueError("PyTorch extension needs to be installed to use this module.")
- else:
- # torch is installed, check version
- from packaging import version
- if version.parse(torchLogic.torch.__version__) < version.parse(minimumTorchVersion):
- raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module."
- + f" Minimum required version is {minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
- + f" with version requirement set to: >={minimumTorchVersion}")
-
- # Install MONAI with required components
- self.log("Initializing MONAI...")
- # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too).
- # Without this, for some users monai-0.9.0 got installed, which failed with this error:
- # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'"
- monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
- if upgrade:
- monaiInstallString += " --upgrade"
- slicer.util.pip_install(monaiInstallString)
-
- self.dependenciesInstalled = True
- self.log("Dependencies are set up successfully.")
+ def labelDescriptions(self, modelName):
+ """Return mapping from label value to label description.
+ Label description is a dict containing "name" and "terminology".
+ Terminology string uses Slicer terminology entry format - see specification at
+ https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag
+ """
+ labelsFilePath = self.modelPath(modelName).joinpath("labels.csv")
+ return self._labelDescriptions(labelsFilePath)
+ def _labelDescriptions(self, labelsFilePath):
+ # Helper function to get code string from CSV file row
+ def getCodeString(field, columnNames, row):
+ columnValues = []
+ for fieldName in ["CodingSchemeDesignator", "CodeValue", "CodeMeaning"]:
+ columnIndex = columnNames.index(f"{field}.{fieldName}")
+ try:
+ columnValue = row[columnIndex]
+ except IndexError:
+ # Probably the line in the CSV file was not terminated by multiple commas (,)
+ columnValue = ""
+ columnValues.append(columnValue)
+ return columnValues
+
+ labelDescriptions = {}
+ import csv
+ with open(labelsFilePath, "r") as f:
+ reader = csv.reader(f)
+ columnNames = next(reader)
+ # Loop through the rows of the csv file
+ for row in reader:
+ # Determine segmentation category (DICOM or MONAIAuto3DSeg)
+ terminologyPropertyTypeStr = ( # Example: SCT^23451007
+ row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")]
+ + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")])
+ terminologyName = None
+
+ # If property the code is found in this list then the terminology will be used,
+ for tName in self.getLoadedTerminologyNames():
+ propertyTypes = self._terminologyPropertyTypes(tName)
+ if terminologyPropertyTypeStr in propertyTypes:
+ terminologyName = tName
+ break
+
+ # NB: DICOM terminology will be used otherwise. Note: the DICOM terminology does not contain all the
+ # necessary items and some items are incomplete (e.g., don't have color or 3D Slicer label).
+ if not terminologyName:
+ terminologyName = "Segmentation category and type - DICOM master list"
+
+ # Determine the anatomic context name (DICOM or MONAIAuto3DSeg)
+ anatomicRegionStr = ( # Example: SCT^279245009
+ row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")]
+ + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")])
+ anatomicContextName = None
+ for aName in self.getLoadedAnatomicContextNames():
+ if anatomicRegionStr in self._anatomicRegions(aName):
+ anatomicContextName = aName
+ if not anatomicContextName:
+ anatomicContextName = "Anatomic codes - DICOM master list"
+
+ terminologyEntryStr = (
+ terminologyName
+ + "~"
+ # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure"
+ + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row))
+ + "~"
+ # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ...
+ + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row))
+ + "~"
+ # Property type modifier: "SCT^7771000^Left", ...
+ + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row))
+ + "~"
+ + anatomicContextName
+ + "~"
+ # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ...
+ + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row))
+ + "~"
+ # Anatomic region modifier: "SCT^7771000^Left", ...
+ + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row))
+ )
+
+ # Store the terminology string for this structure
+ labelValue = int(row[columnNames.index("LabelValue")])
+ name = row[columnNames.index("Name")]
+ labelDescriptions[labelValue] = {"name": name, "terminology": terminologyEntryStr}
+ return labelDescriptions
def setDefaultParameters(self, parameterNode):
"""
@@ -1092,72 +1136,48 @@ def setDefaultParameters(self, parameterNode):
parameterNode.SetParameter("Model", self.defaultModel)
if not parameterNode.GetParameter("UseStandardSegmentNames"):
parameterNode.SetParameter("UseStandardSegmentNames", "true")
-
- def logProcessOutputUntilCompleted(self, segmentationProcessInfo):
- # Wait for the process to end and forward output to the log
- from subprocess import CalledProcessError
- proc = segmentationProcessInfo["proc"]
- while True:
- try:
- line = proc.stdout.readline()
- if not line:
- break
- self.log(line.rstrip())
- except UnicodeDecodeError as e:
- # Code page conversion happens because `universal_newlines=True` sets process output to text mode,
- # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string,
- # as we only guarantee correct behavior if an UTF8 locale is used.
- pass
- proc.wait()
- retcode = proc.returncode
- segmentationProcessInfo["procReturnCode"] = retcode
- if retcode != 0:
- raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr)
+ if not parameterNode.GetParameter("ServerPort"):
+ parameterNode.SetParameter("ServerPort", str(8891))
def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitForCompletion=True, customData=None):
-
"""
Run the processing algorithm.
Can be used without GUI widget.
:param inputNodes: input nodes in a list
- :param outputVolume: thresholding result
+ :param outputSegmentation: output segmentation to write to
:param model: one of self.models
:param cpu: use CPU instead of GPU
:param waitForCompletion: if True then the method waits for the processing to finish
:param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False
"""
+ if not self.DEPENDENCY_HANDLER.dependenciesInstalled:
+ with slicer.util.tryWithErrorDisplay("Failed to install required dependencies.", waitCursor=True):
+ self.DEPENDENCY_HANDLER.setupPythonRequirements()
+
if not inputNodes:
raise ValueError("Input nodes are invalid")
if not outputSegmentation:
raise ValueError("Output segmentation is invalid")
- if model == None:
+ if model is None:
model = self.defaultModel
- try:
- modelPath = self.modelPath(model)
- except:
- self.downloadModel(model)
- modelPath = self.modelPath(model)
+ modelPath = self.modelPath(model)
- segmentationProcessInfo = {}
+ segmentationProcessInfo = SegmentationProcessInfo()
- import time
- startTime = time.time()
- self.log("Processing started")
+ logging.info("Processing started")
if self.debugSkipInference:
+ self.clearOutputFolder = False
# For debugging, use a fixed temporary folder
tempDir = self.debugSkipInferenceTempDir
else:
# Create new empty folder
tempDir = slicer.util.tempDirectory()
- import pathlib
- tempDirPath = pathlib.Path(tempDir)
-
# Get Python executable path
import shutil
pythonSlicerExecutablePath = shutil.which("PythonSlicer")
@@ -1169,7 +1189,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor
for inputIndex, inputNode in enumerate(inputNodes):
if inputNode.IsA('vtkMRMLScalarVolumeNode'):
inputImageFile = tempDir + f"/input-volume{inputIndex}.nrrd"
- self.log(f"Writing input file to {inputImageFile}")
+ logging.info(f"Writing input file to {inputImageFile}")
volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode")
volumeStorageNode.SetFileName(inputImageFile)
volumeStorageNode.UseCompressionOff()
@@ -1190,141 +1210,46 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor
auto3DSegCommand.append(f"--image-file-{inputIndex+1}")
auto3DSegCommand.append(inputFiles[inputIndex])
- self.log("Creating segmentations with MONAIAuto3DSeg AI...")
- self.log(f"Auto3DSeg command: {auto3DSegCommand}")
+ logging.info("Creating segmentations with MONAIAuto3DSeg AI...")
+ logging.info(f"Auto3DSeg command: {auto3DSegCommand}")
additionalEnvironmentVariables = None
if cpu:
additionalEnvironmentVariables = {"CUDA_VISIBLE_DEVICES": "-1"}
- self.log(f"Additional environment variables: {additionalEnvironmentVariables}")
+ logging.info(f"Additional environment variables: {additionalEnvironmentVariables}")
+
+ segmentationProcessInfo.tempDir = tempDir
+ segmentationProcessInfo.inputNodes = inputNodes
+ segmentationProcessInfo.outputSegmentation = outputSegmentation
+ segmentationProcessInfo.outputSegmentationFile = outputSegmentationFile
+ segmentationProcessInfo.model = model
+ segmentationProcessInfo.customData = customData
+ self._bgProcess = LocalInference(processInfo=segmentationProcessInfo, logCallback=self.log, completedCallback=self.onSegmentationProcessCompleted)
if self.debugSkipInference:
- proc = None
- else:
- proc = slicer.util.launchConsoleProcess(auto3DSegCommand, updateEnvironment=additionalEnvironmentVariables)
-
- segmentationProcessInfo["proc"] = proc
- segmentationProcessInfo["procReturnCode"] = MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN
- segmentationProcessInfo["cancelRequested"] = False
- segmentationProcessInfo["startTime"] = startTime
- segmentationProcessInfo["tempDir"] = tempDir
- segmentationProcessInfo["segmentationProcess"] = proc
- segmentationProcessInfo["inputNodes"] = inputNodes
- segmentationProcessInfo["outputSegmentation"] = outputSegmentation
- segmentationProcessInfo["outputSegmentationFile"] = outputSegmentationFile
- segmentationProcessInfo["model"] = model
- segmentationProcessInfo["customData"] = customData
-
- if proc:
- if waitForCompletion:
- # Wait for the process to end before returning
- self.logProcessOutputUntilCompleted(segmentationProcessInfo)
- self.onSegmentationProcessCompleted(segmentationProcessInfo)
- else:
- # Run the process in the background
- self.startSegmentationProcessMonitoring(segmentationProcessInfo)
- else:
- # Debugging
+ segmentationProcessInfo.procReturnCode = 0
self.onSegmentationProcessCompleted(segmentationProcessInfo)
-
- return segmentationProcessInfo
-
- def cancelProcessing(self, segmentationProcessInfo):
- self.log("Cancel is requested.")
- segmentationProcessInfo["cancelRequested"] = True
- proc = segmentationProcessInfo.get("proc")
- if proc:
- # Simple proc.kill() would not work, that would only stop the launcher
- import psutil
- psProcess = psutil.Process(proc.pid)
- for psChildProcess in psProcess.children(recursive=True):
- psChildProcess.kill()
- if psProcess.is_running():
- psProcess.kill()
else:
- self.onSegmentationProcessCompleted(segmentationProcessInfo)
+ self._bgProcess.run(auto3DSegCommand, additionalEnvironmentVariables=additionalEnvironmentVariables, waitForCompletion=waitForCompletion)
- @staticmethod
- def _handleProcessOutputThreadProcess(segmentationProcessInfo):
- # Wait for the process to end and forward output to the log
- proc = segmentationProcessInfo["proc"]
- from subprocess import CalledProcessError
- while True:
- try:
- line = proc.stdout.readline()
- if not line:
- break
- segmentationProcessInfo["procOutputQueue"].put(line.rstrip())
- except UnicodeDecodeError as e:
- # Code page conversion happens because `universal_newlines=True` sets process output to text mode,
- # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string,
- # as we only guarantee correct behavior if an UTF8 locale is used.
- pass
- proc.wait()
- retcode = proc.returncode # non-zero return code means error
- segmentationProcessInfo["procReturnCode"] = retcode
-
-
- def startSegmentationProcessMonitoring(self, segmentationProcessInfo):
- import queue
- import sys
- import threading
-
- segmentationProcessInfo["procOutputQueue"] = queue.Queue()
- segmentationProcessInfo["procThread"] = threading.Thread(target=MONAIAuto3DSegLogic._handleProcessOutputThreadProcess, args=[segmentationProcessInfo])
- segmentationProcessInfo["procThread"].start()
-
- self.checkSegmentationProcessOutput(segmentationProcessInfo)
-
-
- def checkSegmentationProcessOutput(self, segmentationProcessInfo):
-
- import queue
- outputQueue = segmentationProcessInfo["procOutputQueue"]
- while outputQueue:
- if segmentationProcessInfo.get("procReturnCode") != MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN:
- self.onSegmentationProcessCompleted(segmentationProcessInfo)
- return
- try:
- line = outputQueue.get_nowait()
- self.log(line)
- except queue.Empty:
- break
-
- # No more outputs to process now, check again later
- import qt
- qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda segmentationProcessInfo=segmentationProcessInfo: self.checkSegmentationProcessOutput(segmentationProcessInfo))
-
-
- def onSegmentationProcessCompleted(self, segmentationProcessInfo):
-
- startTime = segmentationProcessInfo["startTime"]
- tempDir = segmentationProcessInfo["tempDir"]
- inputNodes = segmentationProcessInfo["inputNodes"]
- outputSegmentation = segmentationProcessInfo["outputSegmentation"]
- outputSegmentationFile = segmentationProcessInfo["outputSegmentationFile"]
- model = segmentationProcessInfo["model"]
- customData = segmentationProcessInfo["customData"]
- procReturnCode = segmentationProcessInfo["procReturnCode"]
- cancelRequested = segmentationProcessInfo["cancelRequested"]
+ return segmentationProcessInfo
- if cancelRequested:
- procReturnCode = MONAIAuto3DSegLogic.EXIT_CODE_USER_CANCELLED
- self.log(f"Processing was cancelled.")
- else:
+ def onSegmentationProcessCompleted(self, segmentationProcessInfo: SegmentationProcessInfo):
+ procReturnCode = segmentationProcessInfo.procReturnCode
+ customData = segmentationProcessInfo.customData
+ cancelRequested = procReturnCode == ExitCode.USER_CANCELLED
+ if not cancelRequested:
if procReturnCode == 0:
-
+ outputSegmentation = segmentationProcessInfo.outputSegmentation
if self.startResultImportCallback:
self.startResultImportCallback(customData)
- try:
-
- # Load result
- self.log("Importing segmentation results...")
- self.readSegmentation(outputSegmentation, outputSegmentationFile, model)
+ try: # Load result
+ logging.info("Importing segmentation results...")
+ self.readSegmentation(outputSegmentation, segmentationProcessInfo.outputSegmentationFile, segmentationProcessInfo.model)
# Set source volume - required for DICOM Segmentation export
- inputVolume = inputNodes[0]
+ inputVolume = segmentationProcessInfo.inputNodes[0]
if not inputVolume.IsA('vtkMRMLScalarVolumeNode'):
raise ValueError("First input node must be a scalar volume")
outputSegmentation.SetNodeReferenceID(outputSegmentation.GetReferenceImageGeometryReferenceRole(), inputVolume.GetID())
@@ -1338,40 +1263,39 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo):
shNode.SetItemParent(segmentationShItem, studyShItem)
finally:
-
if self.endResultImportCallback:
self.endResultImportCallback(customData)
-
else:
- self.log(f"Processing failed with return code {procReturnCode}")
+ logging.info(f"Processing failed with return code {procReturnCode}")
+ tempDir = segmentationProcessInfo.tempDir
if self.clearOutputFolder:
- self.log("Cleaning up temporary folder.")
+ logging.info("Cleaning up temporary folder.")
if os.path.isdir(tempDir):
import shutil
shutil.rmtree(tempDir)
else:
- self.log(f"Not cleaning up temporary folder: {tempDir}")
+ logging.info(f"Not cleaning up temporary folder: {tempDir}")
# Report total elapsed time
- import time
- stopTime = time.time()
- segmentationProcessInfo["stopTime"] = stopTime
- elapsedTime = stopTime - startTime
+ elapsedTime = time.time() - segmentationProcessInfo.startTime
if cancelRequested:
- self.log(f"Processing was cancelled after {elapsedTime:.2f} seconds.")
+ logging.info(f"Processing was cancelled after {elapsedTime:.2f} seconds.")
else:
if procReturnCode == 0:
- self.log(f"Processing was completed in {elapsedTime:.2f} seconds.")
+ logging.info(f"Processing was completed in {elapsedTime:.2f} seconds.")
else:
- self.log(f"Processing failed after {elapsedTime:.2f} seconds.")
+ logging.info(f"Processing failed after {elapsedTime:.2f} seconds.")
if self.processingCompletedCallback:
self.processingCompletedCallback(procReturnCode, customData)
+ def cancelProcessing(self):
+ if not self._bgProcess:
+ return
+ self._bgProcess.stop()
def readSegmentation(self, outputSegmentation, outputSegmentationFile, model):
-
labelValueToDescription = self.labelDescriptions(model)
# Get label descriptions
@@ -1405,12 +1329,11 @@ def readSegmentation(self, outputSegmentation, outputSegmentationFile, model):
# Set terminology and color
for labelValue in labelValueToDescription:
- segmentName = labelValueToDescription[labelValue]["name"]
terminologyEntryStr = labelValueToDescription[labelValue]["terminology"]
- segmentId = segmentName
- self.setTerminology(outputSegmentation, segmentName, segmentId, terminologyEntryStr)
+ segmentId = labelValueToDescription[labelValue]["name"]
+ self.setTerminology(outputSegmentation, segmentId, terminologyEntryStr)
- def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryStr):
+ def setTerminology(self, segmentation, segmentId, terminologyEntryStr):
segment = segmentation.GetSegmentation().GetSegment(segmentId)
if not segment:
# Segment is not present in this segmentation
@@ -1423,31 +1346,145 @@ def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryS
segment.SetName(label)
segment.SetColor(color)
except RuntimeError as e:
- self.log(str(e))
+ logging.info(str(e))
- def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath):
- import json
- modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath()
+class RemoteMONAIAuto3DSegLogic(MONAIAuto3DSegLogic):
- with open(modelsTestResultsJsonFilePath) as f:
- modelsTestResults = json.load(f)
+ DEPENDENCY_HANDLER = RemotePythonDependencies()
+
+ @property
+ def server_address(self):
+ return self._server_address
- with open(modelsDescriptionJsonFilePath) as f:
- modelsDescription = json.load(f)
+ @server_address.setter
+ def server_address(self, address):
+ self.DEPENDENCY_HANDLER.server_address = address
+ self._server_address = address
+ self._models = []
- for model in modelsDescription["models"]:
- title = model["title"]
- for modelTestResult in modelsTestResults:
- if modelTestResult["title"] == title:
- for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]:
- fieldValue = modelTestResult.get(fieldName)
- if fieldValue:
- model[fieldName] = fieldValue
- break
+ def __init__(self):
+ self._server_address = None
+ MONAIAuto3DSegLogic.__init__(self)
+ self._models = []
+
+ def getMONAIPythonPackageInfo(self):
+ return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo()
+
+ def setupPythonRequirements(self, upgrade=False):
+ self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade)
+ return False
+
+ def loadModelsDescription(self):
+ if not self._server_address:
+ return []
+ else:
+ response = requests.get(self._server_address + "/models")
+ json_data = json.loads(response.text)
+ return json_data
+
+ def labelDescriptions(self, modelName):
+ """Return mapping from label value to label description.
+ Label description is a dict containing "name" and "terminology".
+ Terminology string uses Slicer terminology entry format - see specification at
+ https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag
+ """
+ if not self._server_address:
+ return {}
+ else:
+ from pathlib import Path
+ tempDir = slicer.util.tempDirectory()
+ tempDir = Path(tempDir)
+ outfile = tempDir / "labelDescriptions.csv"
+ with requests.get(self._server_address + f"/labelDescriptions?id={modelName}", stream=True) as r:
+ r.raise_for_status()
+
+ with open(outfile, 'wb') as f:
+ for chunk in r.iter_content(chunk_size=8192):
+ f.write(chunk)
+
+ labelDescriptions = self._labelDescriptions(outfile)
+
+ import shutil
+ shutil.rmtree(tempDir)
+ return labelDescriptions
+
+ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitForCompletion=True, customData=None):
+ """
+ Run the processing algorithm.
+ Can be used without GUI widget.
+ :param inputNodes: input nodes in a list
+ :param outputVolume: thresholding result
+ :param modelId: one of self.models
+ :param cpu: use CPU instead of GPU
+ :param waitForCompletion: if True then the method waits for the processing to finish
+ :param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False
+ """
+
+ segmentationProcessInfo = SegmentationProcessInfo()
+ logging.info("Processing started")
+
+ tempDir = slicer.util.tempDirectory()
+ outputSegmentationFile = tempDir + "/output-segmentation.nrrd"
+
+ from tempfile import TemporaryDirectory
+ with TemporaryDirectory(dir=tempDir) as temp_dir:
+ # Write input volume to file
+ from pathlib import Path
+ tempDir = Path(temp_dir)
+ inputFiles = []
+ for inputIndex, inputNode in enumerate(inputNodes):
+ if inputNode.IsA('vtkMRMLScalarVolumeNode'):
+ inputImageFile = tempDir / f"input-volume{inputIndex}.nrrd"
+ logging.info(f"Writing input file to {inputImageFile}")
+ volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode")
+ volumeStorageNode.SetFileName(inputImageFile)
+ volumeStorageNode.UseCompressionOff()
+ volumeStorageNode.WriteData(inputNode)
+ slicer.mrmlScene.RemoveNode(volumeStorageNode)
+ inputFiles.append(inputImageFile)
+ else:
+ raise ValueError(f"Input node type {inputNode.GetClassName()} is not supported")
+
+ logging.info(f"Initiating Inference on {self._server_address}")
+ files = {}
+
+ for idx, inputFile in enumerate(inputFiles, start=1):
+ name = "image_file"
+ if idx > 1:
+ name = f"{name}_{idx}"
+ files[name] = open(inputFile, 'rb')
+
+ r = None
+ try:
+ with requests.post(self._server_address + f"/infer?model_name={modelId}", files=files) as r:
+ r.raise_for_status()
+
+ with open(outputSegmentationFile, "wb") as binary_file:
+ for chunk in r.iter_content(chunk_size=8192):
+ binary_file.write(chunk)
+
+ segmentationProcessInfo.procReturnCode = 0
+ except requests.exceptions.HTTPError as e:
+ from http import HTTPStatus
+ status = HTTPStatus(e.response.status_code)
+ logging.debug(f"Server response content: {r.content}")
+ raise RuntimeError(status.description)
+ finally:
+ for f in files.values():
+ f.close()
+
+ segmentationProcessInfo.tempDir = tempDir
+ segmentationProcessInfo.inputNodes = inputNodes
+ segmentationProcessInfo.outputSegmentation = outputSegmentation
+ segmentationProcessInfo.outputSegmentationFile = outputSegmentationFile
+ segmentationProcessInfo.model = modelId
+ segmentationProcessInfo.customData = customData
+
+ self.onSegmentationProcessCompleted(segmentationProcessInfo)
+
+ return segmentationProcessInfo
- with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f:
- json.dump(modelsDescription, f, indent=2)
#
# MONAIAuto3DSegTest
@@ -1556,17 +1593,14 @@ def test_MONAIAuto3DSeg1(self):
raise RuntimeError(f"Failed to load sample data set '{sampleDataName}'.")
# Set model inputs
-
- inputNodes = []
inputs = model.get("inputs")
- inputNodes = MONAIAuto3DSegLogic.assignInputNodesByName(inputs, loadedSampleNodes)
+ inputNodes = logic.assignInputNodesByName(inputs, loadedSampleNodes)
outputSegmentation = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
# Run the segmentation
self.delayDisplay(f"Running segmentation for {model['title']}...")
- import time
startTime = time.time()
logic.process(inputNodes, outputSegmentation, model["id"], forceUseCpu)
segmentationTimeSec = time.time() - startTime
@@ -1684,14 +1718,12 @@ def _writeScreenshots(self, segmentationNode, outputPath, baseName, numberOfImag
return sliceScreenshotFilename, rotate3dScreenshotFilename
def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestResultsMarkdownFilePath=None, screenshotUrlBase=None):
-
if modelsTestResultsMarkdownFilePath is None:
modelsTestResultsMarkdownFilePath = modelsTestResultsJsonFilePath.replace(".json", ".md")
if screenshotUrlBase is None:
screenshotUrlBase = "https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/ModelsTestResults/"
import json
- from MONAIAuto3DSeg import MONAIAuto3DSegLogic
with open(modelsTestResultsJsonFilePath) as f:
modelsTestResults = json.load(f)
@@ -1718,7 +1750,7 @@ def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestR
title = f"{model['title']} (v{model['version']})"
f.write(f"## {title}\n")
f.write(f"{model['description']}\n\n")
- f.write(f"Processing time: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n")
+ f.write(f"Processing time: {humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n")
f.write(f"Segment names: {', '.join(model['segmentNames'])}\n\n")
f.write(f"![2D view]({screenshotUrlBase}{model['segmentationResultsScreenshot2D']})\n")
- f.write(f"![3D view]({screenshotUrlBase}{model['segmentationResultsScreenshot3D']})\n")
\ No newline at end of file
+ f.write(f"![3D view]({screenshotUrlBase}{model['segmentationResultsScreenshot3D']})\n")
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py
new file mode 100644
index 0000000..52fda3a
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py
@@ -0,0 +1,149 @@
+import shutil
+import subprocess
+import logging
+import sys
+
+
+from abc import ABC, abstractmethod
+
+
+class DependenciesBase(ABC):
+
+ minimumTorchVersion = "1.12"
+
+ def __init__(self):
+ self.dependenciesInstalled = False # we don't know yet if dependencies have been installed
+
+ @abstractmethod
+ def installedMONAIPythonPackageInfo(self):
+ pass
+
+ @abstractmethod
+ def setupPythonRequirements(self, upgrade=False):
+ pass
+
+
+class RemotePythonDependencies(DependenciesBase):
+
+ def __init__(self):
+ super().__init__()
+ self._server_address = None
+
+ @property
+ def server_address(self):
+ return self._server_address
+
+ @server_address.setter
+ def server_address(self, address):
+ self._server_address = address
+
+ def installedMONAIPythonPackageInfo(self):
+ if not self._server_address:
+ return []
+ else:
+ import json
+ import requests
+ response = requests.get(self._server_address + "/monaiinfo")
+ json_data = json.loads(response.text)
+ return json_data
+
+ def setupPythonRequirements(self, upgrade=False):
+ logging.error("No permission to update remote python packages. Please contact developer.")
+
+
+class SlicerPythonDependencies(DependenciesBase):
+ """ Dependency handler when being used within 3D Slicer (SlicerPython) environment. """
+
+ def installedMONAIPythonPackageInfo(self):
+ versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode()
+ return versionInfo
+
+ def setupPythonRequirements(self, upgrade=False):
+ # Install PyTorch
+ try:
+ import PyTorchUtils
+ except ModuleNotFoundError as e:
+ raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.")
+
+ logging.info("Initializing PyTorch...")
+
+ torchLogic = PyTorchUtils.PyTorchUtilsLogic()
+ if not torchLogic.torchInstalled():
+ logging.info("PyTorch Python package is required. Installing... (it may take several minutes)")
+ torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement=f">={self.minimumTorchVersion}")
+ if torch is None:
+ raise ValueError("PyTorch extension needs to be installed to use this module.")
+ else: # torch is installed, check version
+ from packaging import version
+ if version.parse(torchLogic.torch.__version__) < version.parse(self.minimumTorchVersion):
+ raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module."
+ + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
+ + f" with version requirement set to: >={self.minimumTorchVersion}")
+
+ # Install MONAI with required components
+ logging.info("Initializing MONAI...")
+ # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too).
+ # Without this, for some users monai-0.9.0 got installed, which failed with this error:
+ # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'"
+ monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
+ if upgrade:
+ monaiInstallString += " --upgrade"
+ import slicer
+ slicer.util.pip_install(monaiInstallString)
+
+ self.dependenciesInstalled = True
+ logging.info("Dependencies are set up successfully.")
+
+
+class NonSlicerPythonDependencies(DependenciesBase):
+ """ Dependency handler when running locally (not within 3D Slicer)
+
+ code:
+
+ from MONAIAuto3DSegLib.dependency_handler import LocalPythonDependencies
+ dependencies = LocalPythonDependencies()
+ dependencies.installedMONAIPythonPackageInfo()
+
+ # dependencies.setupPythonRequirements(upgrade=True)
+ """
+
+ def installedMONAIPythonPackageInfo(self):
+ versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode()
+ return versionInfo
+
+ def _checkModuleInstalled(self, moduleName):
+ try:
+ import importlib
+ importlib.import_module(moduleName)
+ return True
+ except ModuleNotFoundError:
+ return False
+
+ def setupPythonRequirements(self, upgrade=False):
+ def install(package):
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+ logging.debug("Initializing PyTorch...")
+
+ packageName = "torch"
+ if not self._checkModuleInstalled(packageName):
+ logging.debug("PyTorch Python package is required. Installing... (it may take several minutes)")
+ install(packageName)
+ if not self._checkModuleInstalled(packageName):
+ raise ValueError("pytorch needs to be installed to use this module.")
+ else: # torch is installed, check version
+ from packaging import version
+ import torch
+ if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion):
+ raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module."
+ + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
+ + f" with version requirement set to: >={self.minimumTorchVersion}")
+
+ logging.debug("Initializing MONAI...")
+ monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
+ if upgrade:
+ monaiInstallString += " --upgrade"
+ install(monaiInstallString)
+
+ self.dependenciesInstalled = True
+ logging.debug("Dependencies are set up successfully.")
\ No newline at end of file
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py
new file mode 100644
index 0000000..c558f0e
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py
@@ -0,0 +1,200 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec
+
+
+class ModelDatabase:
+ """ Retrieve model information, download and store models in model cache directory. Model information is stored
+ in local Models.json.
+ """
+
+ DEFAULT_CACHE_DIR_NAME = ".MONAIAuto3DSeg"
+
+ @property
+ def defaultModel(self):
+ return self.models[0]["id"]
+
+ @property
+ def models(self):
+ if not self._models:
+ self._models = self.loadModelsDescription()
+ return self._models
+
+ @property
+ def modelsPath(self):
+ modelsPath = self.fileCachePath.joinpath("models")
+ modelsPath.mkdir(exist_ok=True, parents=True)
+ return modelsPath
+
+ @property
+ def modelsDescriptionJsonFilePath(self):
+ return os.path.join(self.moduleDir, "Resources", "Models.json")
+
+ def __init__(self):
+ self.fileCachePath = Path.home().joinpath(f"{self.DEFAULT_CACHE_DIR_NAME}")
+ self.moduleDir = Path(__file__).parent.parent
+ self._models = []
+ self._clearTempDownloadFolder = True
+
+ def model(self, modelId):
+ for model in self.models:
+ if model["id"] == modelId:
+ return model
+ raise RuntimeError(f"Model {modelId} not found")
+
+ def loadModelsDescription(self):
+ modelsJsonFilePath = self.modelsDescriptionJsonFilePath
+ try:
+ models = []
+ import json
+ import re
+ with open(modelsJsonFilePath) as f:
+ modelsTree = json.load(f)["models"]
+ for model in modelsTree:
+ deprecated = False
+ for version in model["versions"]:
+ url = version["url"]
+ # URL format: Model: {model['title']} (v{version})"
+ f" Description: {model['description']}\n"
+ f" Computation time on GPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n"
+ f" Imaging modality: {model['imagingModality']}\n"
+ f" Subject: {model['subject']}\n"
+ f" Segments: {', '.join(segmentNames)}",
+ "url": url,
+ "deprecated": deprecated
+ })
+ # First version is not deprecated, all subsequent versions are deprecated
+ deprecated = True
+ return models
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}")
+
+ def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath):
+ modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath
+
+ with open(modelsTestResultsJsonFilePath) as f:
+ modelsTestResults = json.load(f)
+
+ with open(modelsDescriptionJsonFilePath) as f:
+ modelsDescription = json.load(f)
+
+ for model in modelsDescription["models"]:
+ title = model["title"]
+ for modelTestResult in modelsTestResults:
+ if modelTestResult["title"] == title:
+ for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]:
+ fieldValue = modelTestResult.get(fieldName)
+ if fieldValue:
+ model[fieldName] = fieldValue
+ break
+
+ with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f:
+ json.dump(modelsDescription, f, indent=2)
+
+ def createModelsDir(self):
+ modelsDir = self.modelsPath
+ if not os.path.exists(modelsDir):
+ os.makedirs(modelsDir)
+
+ def modelPath(self, modelName):
+ try:
+ return self._modelPath(modelName)
+ except RuntimeError:
+ self.downloadModel(modelName)
+ return self._modelPath(modelName)
+
+ def _modelPath(self, modelName):
+ modelRoot = self.modelsPath.joinpath(modelName)
+ # find labels.csv file within the modelRoot folder and subfolders
+ for path in Path(modelRoot).rglob("labels.csv"):
+ return path.parent
+ raise RuntimeError(f"Model {modelName} path not found")
+
+ def deleteAllModels(self):
+ if self.modelsPath.exists():
+ import shutil
+ shutil.rmtree(self.modelsPath)
+
+ def downloadAllModels(self):
+ # TODO: add some progress here since this could take a while for all models
+ for model in self.models:
+ #slicer.app.processEvents()
+ self.downloadModel(model["id"])
+
+ def downloadModel(self, modelName):
+ url = self.model(modelName)["url"]
+ import zipfile
+ import requests
+ from tempfile import TemporaryDirectory
+ with TemporaryDirectory() as td:
+ tempDir = Path(td)
+ modelDir = self.modelsPath.joinpath(modelName)
+ Path(modelDir).mkdir(exist_ok=True)
+ modelZipFile = tempDir.joinpath("autoseg3d_model.zip")
+ logging.info(f"Downloading model '{modelName}' from {url}...")
+ logging.debug(f"Downloading from {url} to {modelZipFile}...")
+ try:
+ with open(modelZipFile, 'wb') as f:
+ with requests.get(url, stream=True) as r:
+ r.raise_for_status()
+ total_size = int(r.headers.get('content-length', 0))
+ reporting_increment_percent = 1.0
+ last_reported_download_percent = -reporting_increment_percent
+ downloaded_size = 0
+ for chunk in r.iter_content(chunk_size=8192 * 16):
+ f.write(chunk)
+ downloaded_size += len(chunk)
+ downloaded_percent = 100.0 * downloaded_size / total_size
+ if downloaded_percent - last_reported_download_percent > reporting_increment_percent:
+ logging.info(
+ f"Downloading model: {downloaded_size / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB ({downloaded_percent:.1f}%)")
+ last_reported_download_percent = downloaded_percent
+
+ logging.info(f"Download finished. Extracting to {modelDir}...")
+ with zipfile.ZipFile(modelZipFile, 'r') as zip_f:
+ zip_f.extractall(modelDir)
+ except Exception as e:
+ raise e
+ finally:
+ if self._clearTempDownloadFolder:
+ logging.info("Cleaning up temporary model download folder...")
+ if os.path.isdir(tempDir):
+ import shutil
+ shutil.rmtree(tempDir)
+ else:
+ logging.info(f"Not cleaning up temporary model download folder: {tempDir}")
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py
new file mode 100644
index 0000000..ba330bf
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py
@@ -0,0 +1,265 @@
+import subprocess
+
+import slicer
+
+import sys
+import logging
+import queue
+import threading
+
+import qt
+from pathlib import Path
+from typing import Callable, Any
+import time
+from dataclasses import dataclass, field
+
+from enum import Enum
+
+
+class ExitCode(Enum):
+ USER_CANCELLED = 1001
+ DID_NOT_RUN = 1002
+
+
+@dataclass
+class ProcessInfo:
+ proc: subprocess.Popen = None
+ startTime: float = field(default_factory=time.time)
+ procReturnCode: ExitCode = ExitCode.DID_NOT_RUN
+ procOutputQueue: queue.Queue = queue.Queue()
+ procThread: threading.Thread = None
+
+
+class SegmentationProcessInfo(ProcessInfo):
+ tempDir: str = ""
+ inputNodes: list = None
+ outputSegmentation: slicer.vtkMRMLSegmentationNode = None
+ outputSegmentationFile: str = ""
+ model: str = ""
+ customData: Any = None
+
+
+class BackgroundProcess:
+ """ Any kind of process with threads and continuous checking until stopped"""
+
+ # Timer for checking the output of the process that is running in the background
+ CHECK_TIMER_INTERVAL = 1000
+
+ @property
+ def proc(self):
+ return self.processInfo.proc
+
+ @proc.setter
+ def proc(self, proc):
+ self.processInfo.proc = proc
+
+ @property
+ def procThread(self):
+ return self.processInfo.procThread
+
+ @procThread.setter
+ def procThread(self, procThread):
+ self.processInfo.procThread = procThread
+
+ @property
+ def procOutputQueue(self):
+ return self.processInfo.procOutputQueue
+
+ @procOutputQueue.setter
+ def procOutputQueue(self, procOutputQueue):
+ self.processInfo.procOutputQueue = procOutputQueue
+
+ @staticmethod
+ def getPSProcess(pid):
+ import psutil
+ try:
+ return psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ return None
+
+ def __init__(self, processInfo: ProcessInfo = None, logCallback: Callable = None, completedCallback: Callable = None):
+ self.processInfo = processInfo if processInfo else ProcessInfo()
+ self.logCallback = logCallback
+ self.completedCallback = completedCallback
+
+ # NB: making sure that the following values were not set previously
+ self.proc = None
+ self.procThread = None
+ self.procOutputQueue = None
+
+ def __del__(self):
+ self._killProcess()
+
+ def handleSubProcessLogging(self, text):
+ logging.info(text)
+
+ def cleanup(self):
+ if self.procThread:
+ self.procThread.join()
+ if self.completedCallback:
+ self.completedCallback(self.processInfo)
+ self.proc = None
+ self.procThread = None
+ self.procOutputQueue = None
+
+ def isRunning(self):
+ if self.proc is not None:
+ psProcess = self.getPSProcess(self.proc.pid)
+ if psProcess:
+ return psProcess.is_running()
+ return False
+
+ def _killProcess(self):
+ if not self.proc:
+ return
+ psProcess = self.getPSProcess(self.proc.pid) # proc.kill() does not work, that would only stop the launcher
+ if not psProcess:
+ return
+ for psChildProcess in psProcess.children(recursive=True):
+ psChildProcess.kill()
+ if psProcess.is_running():
+ psProcess.kill()
+
+ def stop(self):
+ self._killProcess()
+ self._setProcReturnCode(ExitCode.USER_CANCELLED)
+
+ def _startHandleProcessOutputThread(self):
+ self.procOutputQueue = queue.Queue()
+ self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess)
+ self.procThread.start()
+ self.checkProcessOutput()
+
+ def _handleProcessOutputThreadProcess(self):
+ while True:
+ try:
+ line = self.proc.stdout.readline()
+ if not line:
+ break
+ text = line.rstrip()
+ self.procOutputQueue.put(text)
+ except UnicodeDecodeError as e:
+ pass
+ self.proc.wait()
+ self._setProcReturnCode(self.proc.returncode) # non-zero return code means error
+
+ def checkProcessOutput(self):
+ outputQueue = self.procOutputQueue
+ while outputQueue:
+ try:
+ line = outputQueue.get_nowait()
+ self.handleSubProcessLogging(line)
+ except queue.Empty:
+ break
+ if self.processInfo.procReturnCode != ExitCode.DID_NOT_RUN:
+ self.completedCallback(self.processInfo)
+ return
+
+ psProcess = self.getPSProcess(self.proc.pid)
+ if psProcess and psProcess.is_running(): # No more outputs to process now, check again later
+ qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput)
+ else:
+ self.cleanup()
+
+ def addLog(self, text):
+ if self.logCallback:
+ self.logCallback(text)
+
+ def _setProcReturnCode(self, rcode):
+ # if user cancelled, leave it at that and don't change it
+ if self.processInfo.procReturnCode == ExitCode.USER_CANCELLED:
+ return
+ self.processInfo.procReturnCode = rcode
+
+
+
+class InferenceServer(BackgroundProcess):
+ """ Web server class to be used from 3D Slicer. Upon starting the server with the given cmd command, it will be
+ checked every {CHECK_TIMER_INTERVAL} milliseconds for process status/outputs.
+
+ code:
+
+ cmd = [sys.executable, "main.py", "--host", hostName, "--port", port]
+
+ from MONAIAuto3DSegLib.process import InferenceServer
+ server = InferenceServer(completedCallback=...)
+ server.start()
+
+ ...
+
+ server.stop()
+
+ """
+
+ def __init__(self, processInfo: ProcessInfo = None, logCallback: Callable = None,
+ completedCallback: Callable = None):
+ super().__init__(processInfo, logCallback, completedCallback)
+ self.hostName = "127.0.0.1"
+ self.port = 8891
+
+ def getAddressUrl(self):
+ return f"http://{self.hostName}:{self.port}"
+
+ def start(self):
+ cmd = [
+ sys.executable,
+ str(Path(__file__).parent.parent / "MONAIAuto3DSegServer" / "main.py"),
+ "--host", self.hostName,
+ "--port", self.port
+ ]
+
+ logging.debug(f"Launching process: {cmd}")
+ self.proc = slicer.util.launchConsoleProcess(cmd, useStartupEnvironment=False)
+ self._startHandleProcessOutputThread()
+
+ def handleSubProcessLogging(self, text):
+ # NB: let upper level handle if it should be logged to console or UI
+ self.addLog(text)
+
+
+class LocalInference(BackgroundProcess):
+ """ Running local inference until finished or cancelled. """
+
+ def __init__(self, processInfo: SegmentationProcessInfo, logCallback: Callable = None,
+ completedCallback: Callable = None, waitForCompletion: bool = True):
+ super().__init__(processInfo, logCallback, completedCallback)
+ self.waitForCompletion = waitForCompletion
+
+ def run(self, cmd, additionalEnvironmentVariables=None, waitForCompletion=True):
+ logging.debug(f"Launching process: {cmd}")
+ self.proc = slicer.util.launchConsoleProcess(cmd, updateEnvironment=additionalEnvironmentVariables)
+
+ if self.isRunning():
+ self.addLog("Process Started")
+
+ if waitForCompletion:
+ self.logProcessOutputUntilCompleted()
+ self.completedCallback(self.processInfo)
+ else:
+ self._startHandleProcessOutputThread()
+
+ def handleSubProcessLogging(self, text):
+ self.addLog(text)
+ logging.info(text)
+
+ def logProcessOutputUntilCompleted(self):
+ # Wait for the process to end and forward output to the log
+ proc = self.proc
+ while True:
+ try:
+ line = proc.stdout.readline()
+ if not line:
+ break
+ logging.info(line.rstrip())
+ except UnicodeDecodeError as e:
+ # Code page conversion happens because `universal_newlines=True` sets process output to text mode,
+ # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string,
+ # as we only guarantee correct behavior if an UTF8 locale is used.
+ pass
+ proc.wait()
+ retcode = proc.returncode
+ self._setProcReturnCode(retcode)
+
+ if retcode != 0:
+ from subprocess import CalledProcessError
+ raise CalledProcessError(proc.returncode, proc.args, output=proc.stdout, stderr=proc.stderr)
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py
new file mode 100644
index 0000000..4575b22
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py
@@ -0,0 +1,14 @@
+
+
+def humanReadableTimeFromSec(seconds):
+ import math
+ if not seconds:
+ return "N/A"
+ if seconds < 55:
+ # if less than a minute, round up to the nearest 5 seconds
+ return f"{math.ceil(seconds / 5) * 5} sec"
+ elif seconds < 60 * 60:
+ # if less than 1 hour, round up to the nearest minute
+ return f"{math.ceil(seconds / 60)} min"
+ # Otherwise round up to the nearest 0.1 hour
+ return f"{seconds / 3600:.1f} h"
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/__init__.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py
new file mode 100644
index 0000000..7e3d309
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py
@@ -0,0 +1,156 @@
+# pip install python-multipart fastapi slowapi uvicorn[standard]
+
+# usage: uvicorn main:app --host example.com --port 8891
+# usage: uvicorn main:app --host localhost --port 8891
+
+
+import os
+import logging
+import sys
+from pathlib import Path
+
+
+
+paths = [str(Path(__file__).parent.parent)]
+for path in paths:
+ if not path in sys.path:
+ sys.path.insert(0, path)
+
+from MONAIAuto3DSegLib.model_database import ModelDatabase
+
+import shutil
+import asyncio
+import subprocess
+from fastapi import FastAPI, UploadFile, Request
+from fastapi.responses import FileResponse, JSONResponse
+from fastapi.background import BackgroundTasks
+
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+
+limiter = Limiter(key_func=lambda request: "request_per_route_per_minute")
+app = FastAPI()
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+modelDB = ModelDatabase()
+
+# deciding which dependencies to choose
+if "python-real" in Path(sys.executable).name:
+ from MONAIAuto3DSegLib.dependency_handler import SlicerPythonDependencies
+ dependencyHandler = SlicerPythonDependencies()
+else:
+ from MONAIAuto3DSegLib.dependency_handler import NonSlicerPythonDependencies
+ dependencyHandler = NonSlicerPythonDependencies()
+ dependencyHandler.setupPythonRequirements()
+
+logging.debug(f"Using {dependencyHandler.__class__.__name__} as dependency handler")
+
+
+def upload(file, session_dir, identifier):
+ extension = "".join(Path(file.filename).suffixes)
+ file_location = str(Path(session_dir) / f"{identifier}{extension}")
+ with open(file_location, "wb+") as file_object:
+ file_object.write(file.file.read())
+ return file_location
+
+
+@app.get("/monaiinfo")
+def monaiInfo():
+ return dependencyHandler.installedMONAIPythonPackageInfo()
+
+
+@app.get("/models")
+async def models():
+ return modelDB.models
+
+
+@app.get("/modelinfo")
+async def getModelInfo(id: str):
+ return modelDB.model(id)
+
+
+@app.get("/labelDescriptions")
+def getLabelsFile(id: str):
+ return FileResponse(modelDB.modelPath(id).joinpath("labels.csv"), media_type = 'application/octet-stream', filename="labels.csv")
+
+
+@app.post("/infer")
+@limiter.limit("5/minute")
+async def infer(
+ request: Request,
+ background_tasks: BackgroundTasks,
+ image_file: UploadFile,
+ model_name: str,
+ image_file_2: UploadFile = None,
+ image_file_3: UploadFile = None,
+ image_file_4: UploadFile = None
+):
+ import tempfile
+ session_dir = tempfile.mkdtemp(dir=tempfile.gettempdir())
+ background_tasks.add_task(shutil.rmtree, session_dir, ignore_errors=False)
+
+ logging.debug(session_dir)
+ inputFiles = list()
+ inputFiles.append(upload(image_file, session_dir, "image_file"))
+ if image_file_2:
+ inputFiles.append(upload(image_file_2, session_dir, "image_file_2"))
+ if image_file_3:
+ inputFiles.append(upload(image_file_3, session_dir, "image_file_3"))
+ if image_file_4:
+ inputFiles.append(upload(image_file_4, session_dir, "image_file_4"))
+
+ # logging.info("Input Files: ", inputFiles)
+
+ outputSegmentationFile = str(Path(session_dir) / "output-segmentation.nrrd")
+
+ modelPath = modelDB.modelPath(model_name)
+ modelPtFile = modelPath.joinpath("model.pt")
+
+ assert os.path.exists(modelPtFile)
+
+ moduleDir = Path(__file__).parent.parent
+ inferenceScriptPyFile = os.path.join(moduleDir, "Scripts", "auto3dseg_segresnet_inference.py")
+ auto3DSegCommand = [sys.executable, str(inferenceScriptPyFile),
+ "--model-file", str(modelPtFile),
+ "--image-file", inputFiles[0],
+ "--result-file", outputSegmentationFile]
+ for inputIndex in range(1, len(inputFiles)):
+ auto3DSegCommand.append(f"--image-file-{inputIndex + 1}")
+ auto3DSegCommand.append(inputFiles[inputIndex])
+
+ try:
+ logging.debug(auto3DSegCommand)
+ proc = await asyncio.create_subprocess_exec(*auto3DSegCommand)
+ await proc.wait()
+ if proc.returncode != 0:
+ raise subprocess.CalledProcessError(proc.returncode, auto3DSegCommand)
+ return FileResponse(outputSegmentationFile, media_type='application/octet-stream', background=background_tasks)
+ except Exception as err:
+ logging.info(err)
+ shutil.rmtree(session_dir)
+ import traceback
+ return JSONResponse(
+ content={
+ "error": "An unexpected error occurred",
+ "message": str(err),
+ "traceback": traceback.format_exc()
+ },
+ status_code=500
+ )
+
+
+def main(argv):
+ import argparse
+ parser = argparse.ArgumentParser(description="MONAIAuto3DSeg server")
+ parser.add_argument("-ip", "--host", type=str, metavar="PATH", required=False, default="localhost", help="host name")
+ parser.add_argument("-p", "--port", type=int, metavar="PATH", required=True, help="port")
+
+ args = parser.parse_args(argv)
+
+ import uvicorn
+ # NB: reload=True causing issues on Windows (https://stackoverflow.com/a/70570250)
+ uvicorn.run("main:app", host=args.host, port=args.port, log_level="debug", reload=False)
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui
index d57b851..a4f7e61 100644
--- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui
+++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui
@@ -6,8 +6,8 @@
Computation time on CPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n"
+ f"