Model: {model['title']} (v{version})" - f"
Description: {model['description']}\n" - f"
Computation time on GPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n"
- f"
Computation time on CPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n"
- f"
Imaging modality: {model['imagingModality']}\n" - f"
Subject: {model['subject']}\n" - f"
Segments: {', '.join(segmentNames)}",
- "url": url,
- "deprecated": deprecated
- })
- # First version is not deprecated, all subsequent versions are deprecated
- deprecated = True
- return models
- except Exception as e:
- import traceback
- traceback.print_exc()
- raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}")
+ return [terminologyNames.GetValue(idx) for idx in range(terminologyNames.GetNumberOfValues())]
@staticmethod
- def humanReadableTimeFromSec(seconds):
- import math
- if not seconds:
- return "N/A"
- if seconds < 55:
- # if less than a minute, round up to the nearest 5 seconds
- return f"{math.ceil(seconds/5) * 5} sec"
- elif seconds < 60 * 60:
- # if less then 1 hour, round up to the nearest minute
- return f"{math.ceil(seconds/60)} min"
- # Otherwise round up to the nearest 0.1 hour
- return f"{seconds/3600:.1f} h"
-
- def modelsPath(self):
- import pathlib
- return self.fileCachePath.joinpath("models")
-
- def createModelsDir(self):
- modelsDir = self.modelsPath()
- if not os.path.exists(modelsDir):
- os.makedirs(modelsDir)
-
- def modelPath(self, modelName):
- import pathlib
- modelRoot = self.modelsPath().joinpath(modelName)
- # find labels.csv file within the modelRoot folder and subfolders
- for path in pathlib.Path(modelRoot).rglob("labels.csv"):
- return path.parent
- raise RuntimeError(f"Model {modelName} path not found")
-
- def deleteAllModels(self):
- if self.modelsPath().exists():
- import shutil
- shutil.rmtree(self.modelsPath())
+ def getLoadedAnatomicContextNames():
+ import vtk
+ anatomicContextNames = vtk.vtkStringArray()
+ terminologiesLogic = slicer.util.getModuleLogic("Terminologies")
+ terminologiesLogic.GetLoadedAnatomicContextNames(anatomicContextNames)
+
+ return [anatomicContextNames.GetValue(idx) for idx in range(anatomicContextNames.GetNumberOfValues())]
def downloadAllModels(self):
for model in self.models:
slicer.app.processEvents()
self.downloadModel(model["id"])
- def downloadModel(self, modelName):
-
- url = self.model(modelName)["url"]
-
- import zipfile
- import requests
- import pathlib
-
- tempDir = pathlib.Path(slicer.util.tempDirectory())
- modelDir = self.modelsPath().joinpath(modelName)
- if not os.path.exists(modelDir):
- os.makedirs(modelDir)
-
- modelZipFile = tempDir.joinpath("autoseg3d_model.zip")
- self.log(f"Downloading model '{modelName}' from {url}...")
- logging.debug(f"Downloading from {url} to {modelZipFile}...")
-
- try:
- with open(modelZipFile, 'wb') as f:
- with requests.get(url, stream=True) as r:
- r.raise_for_status()
- total_size = int(r.headers.get('content-length', 0))
- reporting_increment_percent = 1.0
- last_reported_download_percent = -reporting_increment_percent
- downloaded_size = 0
- for chunk in r.iter_content(chunk_size=8192 * 16):
- f.write(chunk)
- downloaded_size += len(chunk)
- downloaded_percent = 100.0 * downloaded_size / total_size
- if downloaded_percent - last_reported_download_percent > reporting_increment_percent:
- self.log(f"Downloading model: {downloaded_size/1024/1024:.1f}MB / {total_size/1024/1024:.1f}MB ({downloaded_percent:.1f}%)")
- last_reported_download_percent = downloaded_percent
-
- self.log(f"Download finished. Extracting to {modelDir}...")
- with zipfile.ZipFile(modelZipFile, 'r') as zip_f:
- zip_f.extractall(modelDir)
- except Exception as e:
- raise e
- finally:
- if self.clearOutputFolder:
- self.log("Cleaning up temporary model download folder...")
- if os.path.isdir(tempDir):
- import shutil
- shutil.rmtree(tempDir)
- else:
- self.log(f"Not cleaning up temporary model download folder: {tempDir}")
-
-
- def _MONAIAuto3DSegTerminologyPropertyTypes(self):
+ @staticmethod
+ def _terminologyPropertyTypes(terminologyName):
"""Get label terminology property types defined in from MONAI Auto3DSeg terminology.
Terminology entries are either in DICOM or MONAI Auto3DSeg "Segmentation category and type".
- """
+ # List of property type codes that are specified by in the terminology.
+ #
+ # Codes are stored as a list of strings containing coding scheme designator and code value of the property type,
+ # separated by "^" character. For example "SCT^123456".
+
+ """
terminologiesLogic = slicer.util.getModuleLogic("Terminologies")
- MONAIAuto3DSegTerminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName
+ terminologyPropertyTypes = []
# Get anatomicalStructureCategory from the MONAI Auto3DSeg terminology
anatomicalStructureCategory = slicer.vtkSlicerTerminologyCategory()
- numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(MONAIAuto3DSegTerminologyName)
- for i in range(numberOfCategories):
- terminologiesLogic.GetNthCategoryInTerminology(MONAIAuto3DSegTerminologyName, i, anatomicalStructureCategory)
- if anatomicalStructureCategory.GetCodingSchemeDesignator() == "SCT" and anatomicalStructureCategory.GetCodeValue() == "123037004":
- # Found the (123037004, SCT, "Anatomical Structure") category within DICOM master list
- break
-
- # Retrieve all anatomicalStructureCategory property type codes
- terminologyPropertyTypes = []
- terminologyType = slicer.vtkSlicerTerminologyType()
- numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory)
- for i in range(numberOfTypes):
- if terminologiesLogic.GetNthTypeInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory, i, terminologyType):
- terminologyPropertyTypes.append(terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue())
+ numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(terminologyName)
+ for cIdx in range(numberOfCategories):
+ terminologiesLogic.GetNthCategoryInTerminology(terminologyName, cIdx, anatomicalStructureCategory)
+
+ # Retrieve all anatomicalStructureCategory property type codes
+ terminologyType = slicer.vtkSlicerTerminologyType()
+ numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(terminologyName,
+ anatomicalStructureCategory)
+ for tIdx in range(numberOfTypes):
+ if terminologiesLogic.GetNthTypeInTerminologyCategory(terminologyName, anatomicalStructureCategory, tIdx,
+ terminologyType):
+ terminologyPropertyTypes.append(
+ terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue())
return terminologyPropertyTypes
- def _MONAIAuto3DSegAnatomicRegions(self):
+ @staticmethod
+ def _anatomicRegions(anatomicContextName):
"""Get anatomic regions defined in from MONAI Auto3DSeg terminology.
Terminology entries are either in DICOM or MONAI Auto3DSeg "Anatomic codes".
"""
@@ -884,25 +849,101 @@ def _MONAIAuto3DSegAnatomicRegions(self):
# when editing the terminology on the GUI)
return anatomicRegions
- MONAIAuto3DSegAnatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName
-
# Retrieve all anatomical region codes
-
regionObject = slicer.vtkSlicerTerminologyType()
- numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(MONAIAuto3DSegAnatomicContextName)
+ numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(anatomicContextName)
for i in range(numberOfRegions):
- if terminologiesLogic.GetNthRegionInAnatomicContext(MONAIAuto3DSegAnatomicContextName, i, regionObject):
+ if terminologiesLogic.GetNthRegionInAnatomicContext(anatomicContextName, i, regionObject):
anatomicRegions.append(regionObject.GetCodingSchemeDesignator() + "^" + regionObject.GetCodeValue())
return anatomicRegions
+ @staticmethod
+ def getSegmentLabelColor(terminologyEntryStr):
+ """Get segment label and color from terminology"""
+
+ def labelColorFromTypeObject(typeObject):
+ """typeObject is a terminology type or type modifier"""
+ label = typeObject.GetSlicerLabel() if typeObject.GetSlicerLabel() else typeObject.GetCodeMeaning()
+ rgb = typeObject.GetRecommendedDisplayRGBValue()
+ return label, (rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0)
+
+ tlogic = slicer.modules.terminologies.logic()
+
+ terminologyEntry = slicer.vtkSlicerTerminologyEntry()
+ if not tlogic.DeserializeTerminologyEntry(terminologyEntryStr, terminologyEntry):
+ raise RuntimeError(f"Failed to deserialize terminology string: {terminologyEntryStr}")
+
+ numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(),
+ terminologyEntry.GetCategoryObject())
+ foundTerminologyEntry = slicer.vtkSlicerTerminologyEntry()
+ for typeIndex in range(numberOfTypes):
+ tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(),
+ terminologyEntry.GetCategoryObject(), typeIndex,
+ foundTerminologyEntry.GetTypeObject())
+ if terminologyEntry.GetTypeObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeObject().GetCodingSchemeDesignator():
+ continue
+ if terminologyEntry.GetTypeObject().GetCodeValue() != foundTerminologyEntry.GetTypeObject().GetCodeValue():
+ continue
+ if terminologyEntry.GetTypeModifierObject() and terminologyEntry.GetTypeModifierObject().GetCodeValue():
+ # Type has a modifier, get the color from there
+ numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType(
+ terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(),
+ terminologyEntry.GetTypeObject())
+ foundMatchingModifier = False
+ for modifierIndex in range(numberOfModifiers):
+ tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(),
+ terminologyEntry.GetCategoryObject(),
+ terminologyEntry.GetTypeObject(),
+ modifierIndex,
+ foundTerminologyEntry.GetTypeModifierObject())
+ if terminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator():
+ continue
+ if terminologyEntry.GetTypeModifierObject().GetCodeValue() != foundTerminologyEntry.GetTypeModifierObject().GetCodeValue():
+ continue
+ return labelColorFromTypeObject(foundTerminologyEntry.GetTypeModifierObject())
+ continue
+ return labelColorFromTypeObject(foundTerminologyEntry.GetTypeObject())
+
+ raise RuntimeError(f"Color was not found for terminology {terminologyEntryStr}")
+
+ def __init__(self):
+ """
+ Called when the logic class is instantiated. Can be used for initializing member variables.
+ """
+ ScriptedLoadableModuleLogic.__init__(self)
+ ModelDatabase.__init__(self)
+
+ self.processingCompletedCallback = None
+ self.startResultImportCallback = None
+ self.endResultImportCallback = None
+ self.useStandardSegmentNames = True
+
+ # Timer for checking the output of the segmentation process that is running in the background
+ self.processOutputCheckTimerIntervalMsec = 1000
+
+ # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location
+ # where inference result is stored and set self.debugSkipInference to True.
+ self.debugSkipInference = False
+ self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624"
+
+ def getMONAIPythonPackageInfo(self):
+ return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo()
+
+ def setupPythonRequirements(self, upgrade=False):
+ self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade)
+ return True
+
def labelDescriptions(self, modelName):
"""Return mapping from label value to label description.
Label description is a dict containing "name" and "terminology".
Terminology string uses Slicer terminology entry format - see specification at
https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag
"""
+ labelsFilePath = self.modelPath(modelName).joinpath("labels.csv")
+ return self._labelDescriptions(labelsFilePath)
+ def _labelDescriptions(self, labelsFilePath):
# Helper function to get code string from CSV file row
def getCodeString(field, columnNames, row):
columnValues = []
@@ -917,173 +958,68 @@ def getCodeString(field, columnNames, row):
return columnValues
labelDescriptions = {}
- labelsFilePath = self.modelPath(modelName).joinpath("labels.csv")
import csv
with open(labelsFilePath, "r") as f:
reader = csv.reader(f)
columnNames = next(reader)
- data = {}
# Loop through the rows of the csv file
for row in reader:
-
# Determine segmentation category (DICOM or MONAIAuto3DSeg)
terminologyPropertyTypeStr = ( # Example: SCT^23451007
- row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")]
- + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")])
- if terminologyPropertyTypeStr in self.MONAIAuto3DSegTerminologyPropertyTypes:
- terminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName
- else:
+ row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")]
+ + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")])
+ terminologyName = None
+
+ # If property the code is found in this list then the terminology will be used,
+ for tName in self.getLoadedTerminologyNames():
+ propertyTypes = self._terminologyPropertyTypes(tName)
+ if terminologyPropertyTypeStr in propertyTypes:
+ terminologyName = tName
+ break
+
+ # NB: DICOM terminology will be used otherwise. Note: the DICOM terminology does not contain all the
+ # necessary items and some items are incomplete (e.g., don't have color or 3D Slicer label).
+ if not terminologyName:
terminologyName = "Segmentation category and type - DICOM master list"
# Determine the anatomic context name (DICOM or MONAIAuto3DSeg)
anatomicRegionStr = ( # Example: SCT^279245009
- row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")]
- + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")])
- if anatomicRegionStr in self.MONAIAuto3DSegAnatomicRegions:
- anatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName
- else:
+ row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")]
+ + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")])
+ anatomicContextName = None
+ for aName in self.getLoadedAnatomicContextNames():
+ if anatomicRegionStr in self._anatomicRegions(aName):
+ anatomicContextName = aName
+ if not anatomicContextName:
anatomicContextName = "Anatomic codes - DICOM master list"
terminologyEntryStr = (
- terminologyName
- +"~"
- # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure"
- + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row))
- + "~"
- # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ...
- + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row))
- + "~"
- # Property type modifier: "SCT^7771000^Left", ...
- + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row))
- + "~"
- + anatomicContextName
- + "~"
- # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ...
- + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row))
- + "~"
- # Anatomic region modifier: "SCT^7771000^Left", ...
- + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row))
- )
+ terminologyName
+ + "~"
+ # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure"
+ + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row))
+ + "~"
+ # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ...
+ + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row))
+ + "~"
+ # Property type modifier: "SCT^7771000^Left", ...
+ + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row))
+ + "~"
+ + anatomicContextName
+ + "~"
+ # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ...
+ + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row))
+ + "~"
+ # Anatomic region modifier: "SCT^7771000^Left", ...
+ + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row))
+ )
# Store the terminology string for this structure
labelValue = int(row[columnNames.index("LabelValue")])
name = row[columnNames.index("Name")]
- labelDescriptions[labelValue] = { "name": name, "terminology": terminologyEntryStr }
-
+ labelDescriptions[labelValue] = {"name": name, "terminology": terminologyEntryStr}
return labelDescriptions
- def getSegmentLabelColor(self, terminologyEntryStr):
- """Get segment label and color from terminology"""
-
- def labelColorFromTypeObject(typeObject):
- """typeObject is a terminology type or type modifier"""
- label = typeObject.GetSlicerLabel() if typeObject.GetSlicerLabel() else typeObject.GetCodeMeaning()
- rgb = typeObject.GetRecommendedDisplayRGBValue()
- return label, (rgb[0]/255.0, rgb[1]/255.0, rgb[2]/255.0)
-
- tlogic = slicer.modules.terminologies.logic()
-
- terminologyEntry = slicer.vtkSlicerTerminologyEntry()
- if not tlogic.DeserializeTerminologyEntry(terminologyEntryStr, terminologyEntry):
- raise RuntimeError(f"Failed to deserialize terminology string: {terminologyEntryStr}")
-
- numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject())
- foundTerminologyEntry = slicer.vtkSlicerTerminologyEntry()
- for typeIndex in range(numberOfTypes):
- tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), typeIndex, foundTerminologyEntry.GetTypeObject())
- if terminologyEntry.GetTypeObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeObject().GetCodingSchemeDesignator():
- continue
- if terminologyEntry.GetTypeObject().GetCodeValue() != foundTerminologyEntry.GetTypeObject().GetCodeValue():
- continue
- if terminologyEntry.GetTypeModifierObject() and terminologyEntry.GetTypeModifierObject().GetCodeValue():
- # Type has a modifier, get the color from there
- numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject())
- foundMatchingModifier = False
- for modifierIndex in range(numberOfModifiers):
- tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject(),
- modifierIndex, foundTerminologyEntry.GetTypeModifierObject())
- if terminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator():
- continue
- if terminologyEntry.GetTypeModifierObject().GetCodeValue() != foundTerminologyEntry.GetTypeModifierObject().GetCodeValue():
- continue
- return labelColorFromTypeObject(foundTerminologyEntry.GetTypeModifierObject())
- continue
- return labelColorFromTypeObject(foundTerminologyEntry.GetTypeObject())
-
- raise RuntimeError(f"Color was not found for terminology {terminologyEntryStr}")
-
- @staticmethod
- def _findFirstNodeBynamePattern(namePattern, nodes):
- import fnmatch
- for node in nodes:
- if fnmatch.fnmatchcase(node.GetName(), namePattern):
- return node
- return None
-
- @staticmethod
- def assignInputNodesByName(inputs, loadedSampleNodes):
- inputNodes = []
- for inputIndex, input in enumerate(inputs):
- namePattern = input.get("namePattern")
- if namePattern:
- matchingNode = MONAIAuto3DSegLogic._findFirstNodeBynamePattern(namePattern, loadedSampleNodes)
- else:
- matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else loadedSampleNodes[0]
- inputNodes.append(matchingNode)
- return inputNodes
-
- def log(self, text):
- logging.info(text)
- if self.logCallback:
- self.logCallback(text)
-
- def installedMONAIPythonPackageInfo(self):
- import shutil
- import subprocess
- versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode()
- return versionInfo
-
- def setupPythonRequirements(self, upgrade=False):
- import importlib.metadata
- import importlib.util
- import packaging
-
- # Install PyTorch
- try:
- import PyTorchUtils
- except ModuleNotFoundError as e:
- raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.")
-
- self.log("Initializing PyTorch...")
- minimumTorchVersion = "1.12"
- torchLogic = PyTorchUtils.PyTorchUtilsLogic()
- if not torchLogic.torchInstalled():
- self.log("PyTorch Python package is required. Installing... (it may take several minutes)")
- torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement = f">={minimumTorchVersion}")
- if torch is None:
- raise ValueError("PyTorch extension needs to be installed to use this module.")
- else:
- # torch is installed, check version
- from packaging import version
- if version.parse(torchLogic.torch.__version__) < version.parse(minimumTorchVersion):
- raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module."
- + f" Minimum required version is {minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
- + f" with version requirement set to: >={minimumTorchVersion}")
-
- # Install MONAI with required components
- self.log("Initializing MONAI...")
- # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too).
- # Without this, for some users monai-0.9.0 got installed, which failed with this error:
- # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'"
- monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
- if upgrade:
- monaiInstallString += " --upgrade"
- slicer.util.pip_install(monaiInstallString)
-
- self.dependenciesInstalled = True
- self.log("Dependencies are set up successfully.")
-
-
def setDefaultParameters(self, parameterNode):
"""
Initialize parameter node with default settings.
@@ -1092,6 +1028,8 @@ def setDefaultParameters(self, parameterNode):
parameterNode.SetParameter("Model", self.defaultModel)
if not parameterNode.GetParameter("UseStandardSegmentNames"):
parameterNode.SetParameter("UseStandardSegmentNames", "true")
+ if not parameterNode.GetParameter("ServerPort"):
+ parameterNode.SetParameter("ServerPort", str(8891))
def logProcessOutputUntilCompleted(self, segmentationProcessInfo):
# Wait for the process to end and forward output to the log
@@ -1102,7 +1040,7 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo):
line = proc.stdout.readline()
if not line:
break
- self.log(line.rstrip())
+ logger.info(line.rstrip())
except UnicodeDecodeError as e:
# Code page conversion happens because `universal_newlines=True` sets process output to text mode,
# and it fails because probably system locale is not UTF8. We just ignore the error and discard the string,
@@ -1115,7 +1053,6 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo):
raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr)
def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitForCompletion=True, customData=None):
-
"""
Run the processing algorithm.
Can be used without GUI widget.
@@ -1127,26 +1064,26 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor
:param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False
"""
+ if not self.DEPENDENCY_HANDLER.dependenciesInstalled:
+ with slicer.util.tryWithErrorDisplay("Failed to install required dependencies.", waitCursor=True):
+ self.DEPENDENCY_HANDLER.setupPythonRequirements()
+
if not inputNodes:
raise ValueError("Input nodes are invalid")
if not outputSegmentation:
raise ValueError("Output segmentation is invalid")
- if model == None:
+ if model is None:
model = self.defaultModel
- try:
- modelPath = self.modelPath(model)
- except:
- self.downloadModel(model)
- modelPath = self.modelPath(model)
+ modelPath = self.modelPath(model)
segmentationProcessInfo = {}
import time
startTime = time.time()
- self.log("Processing started")
+ logger.info("Processing started")
if self.debugSkipInference:
# For debugging, use a fixed temporary folder
@@ -1155,9 +1092,6 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor
# Create new empty folder
tempDir = slicer.util.tempDirectory()
- import pathlib
- tempDirPath = pathlib.Path(tempDir)
-
# Get Python executable path
import shutil
pythonSlicerExecutablePath = shutil.which("PythonSlicer")
@@ -1169,7 +1103,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor
for inputIndex, inputNode in enumerate(inputNodes):
if inputNode.IsA('vtkMRMLScalarVolumeNode'):
inputImageFile = tempDir + f"/input-volume{inputIndex}.nrrd"
- self.log(f"Writing input file to {inputImageFile}")
+ logger.info(f"Writing input file to {inputImageFile}")
volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode")
volumeStorageNode.SetFileName(inputImageFile)
volumeStorageNode.UseCompressionOff()
@@ -1190,13 +1124,13 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor
auto3DSegCommand.append(f"--image-file-{inputIndex+1}")
auto3DSegCommand.append(inputFiles[inputIndex])
- self.log("Creating segmentations with MONAIAuto3DSeg AI...")
- self.log(f"Auto3DSeg command: {auto3DSegCommand}")
+ logger.info("Creating segmentations with MONAIAuto3DSeg AI...")
+ logger.info(f"Auto3DSeg command: {auto3DSegCommand}")
additionalEnvironmentVariables = None
if cpu:
additionalEnvironmentVariables = {"CUDA_VISIBLE_DEVICES": "-1"}
- self.log(f"Additional environment variables: {additionalEnvironmentVariables}")
+ logger.info(f"Additional environment variables: {additionalEnvironmentVariables}")
if self.debugSkipInference:
proc = None
@@ -1230,7 +1164,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor
return segmentationProcessInfo
def cancelProcessing(self, segmentationProcessInfo):
- self.log("Cancel is requested.")
+ logger.info("Cancel is requested.")
segmentationProcessInfo["cancelRequested"] = True
proc = segmentationProcessInfo.get("proc")
if proc:
@@ -1248,7 +1182,6 @@ def cancelProcessing(self, segmentationProcessInfo):
def _handleProcessOutputThreadProcess(segmentationProcessInfo):
# Wait for the process to end and forward output to the log
proc = segmentationProcessInfo["proc"]
- from subprocess import CalledProcessError
while True:
try:
line = proc.stdout.readline()
@@ -1264,10 +1197,8 @@ def _handleProcessOutputThreadProcess(segmentationProcessInfo):
retcode = proc.returncode # non-zero return code means error
segmentationProcessInfo["procReturnCode"] = retcode
-
def startSegmentationProcessMonitoring(self, segmentationProcessInfo):
import queue
- import sys
import threading
segmentationProcessInfo["procOutputQueue"] = queue.Queue()
@@ -1276,9 +1207,7 @@ def startSegmentationProcessMonitoring(self, segmentationProcessInfo):
self.checkSegmentationProcessOutput(segmentationProcessInfo)
-
def checkSegmentationProcessOutput(self, segmentationProcessInfo):
-
import queue
outputQueue = segmentationProcessInfo["procOutputQueue"]
while outputQueue:
@@ -1287,17 +1216,14 @@ def checkSegmentationProcessOutput(self, segmentationProcessInfo):
return
try:
line = outputQueue.get_nowait()
- self.log(line)
+ logger.info(line)
except queue.Empty:
break
# No more outputs to process now, check again later
- import qt
qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda segmentationProcessInfo=segmentationProcessInfo: self.checkSegmentationProcessOutput(segmentationProcessInfo))
-
def onSegmentationProcessCompleted(self, segmentationProcessInfo):
-
startTime = segmentationProcessInfo["startTime"]
tempDir = segmentationProcessInfo["tempDir"]
inputNodes = segmentationProcessInfo["inputNodes"]
@@ -1310,17 +1236,14 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo):
if cancelRequested:
procReturnCode = MONAIAuto3DSegLogic.EXIT_CODE_USER_CANCELLED
- self.log(f"Processing was cancelled.")
+ logger.info(f"Processing was cancelled.")
else:
if procReturnCode == 0:
-
if self.startResultImportCallback:
self.startResultImportCallback(customData)
- try:
-
- # Load result
- self.log("Importing segmentation results...")
+ try: # Load result
+ logger.info("Importing segmentation results...")
self.readSegmentation(outputSegmentation, outputSegmentationFile, model)
# Set source volume - required for DICOM Segmentation export
@@ -1338,20 +1261,18 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo):
shNode.SetItemParent(segmentationShItem, studyShItem)
finally:
-
if self.endResultImportCallback:
self.endResultImportCallback(customData)
-
else:
- self.log(f"Processing failed with return code {procReturnCode}")
+ logger.info(f"Processing failed with return code {procReturnCode}")
if self.clearOutputFolder:
- self.log("Cleaning up temporary folder.")
+ logger.info("Cleaning up temporary folder.")
if os.path.isdir(tempDir):
import shutil
shutil.rmtree(tempDir)
else:
- self.log(f"Not cleaning up temporary folder: {tempDir}")
+ logger.info(f"Not cleaning up temporary folder: {tempDir}")
# Report total elapsed time
import time
@@ -1359,19 +1280,17 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo):
segmentationProcessInfo["stopTime"] = stopTime
elapsedTime = stopTime - startTime
if cancelRequested:
- self.log(f"Processing was cancelled after {elapsedTime:.2f} seconds.")
+ logger.info(f"Processing was cancelled after {elapsedTime:.2f} seconds.")
else:
if procReturnCode == 0:
- self.log(f"Processing was completed in {elapsedTime:.2f} seconds.")
+ logger.info(f"Processing was completed in {elapsedTime:.2f} seconds.")
else:
- self.log(f"Processing failed after {elapsedTime:.2f} seconds.")
+ logger.info(f"Processing failed after {elapsedTime:.2f} seconds.")
if self.processingCompletedCallback:
self.processingCompletedCallback(procReturnCode, customData)
-
def readSegmentation(self, outputSegmentation, outputSegmentationFile, model):
-
labelValueToDescription = self.labelDescriptions(model)
# Get label descriptions
@@ -1405,12 +1324,11 @@ def readSegmentation(self, outputSegmentation, outputSegmentationFile, model):
# Set terminology and color
for labelValue in labelValueToDescription:
- segmentName = labelValueToDescription[labelValue]["name"]
terminologyEntryStr = labelValueToDescription[labelValue]["terminology"]
- segmentId = segmentName
- self.setTerminology(outputSegmentation, segmentName, segmentId, terminologyEntryStr)
+ segmentId = labelValueToDescription[labelValue]["name"]
+ self.setTerminology(outputSegmentation, segmentId, terminologyEntryStr)
- def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryStr):
+ def setTerminology(self, segmentation, segmentId, terminologyEntryStr):
segment = segmentation.GetSegmentation().GetSegment(segmentId)
if not segment:
# Segment is not present in this segmentation
@@ -1423,31 +1341,137 @@ def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryS
segment.SetName(label)
segment.SetColor(color)
except RuntimeError as e:
- self.log(str(e))
+ logger.info(str(e))
- def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath):
- import json
- modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath()
+class RemoteMONAIAuto3DSegLogic(MONAIAuto3DSegLogic):
- with open(modelsTestResultsJsonFilePath) as f:
- modelsTestResults = json.load(f)
+ DEPENDENCY_HANDLER = RemotePythonDependencies()
- with open(modelsDescriptionJsonFilePath) as f:
- modelsDescription = json.load(f)
-
- for model in modelsDescription["models"]:
- title = model["title"]
- for modelTestResult in modelsTestResults:
- if modelTestResult["title"] == title:
- for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]:
- fieldValue = modelTestResult.get(fieldName)
- if fieldValue:
- model[fieldName] = fieldValue
- break
+ @property
+ def server_address(self):
+ return self._server_address
+
+ @server_address.setter
+ def server_address(self, address):
+ self._server_address = address
+ self._models = []
+
+ def __init__(self):
+ self._server_address = None
+ MONAIAuto3DSegLogic.__init__(self)
+ self._models = []
+
+ def getMONAIPythonPackageInfo(self):
+ return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo(self._server_address)
- with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f:
- json.dump(modelsDescription, f, indent=2)
+ def setupPythonRequirements(self, upgrade=False):
+ self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade)
+ return False
+
+ def loadModelsDescription(self):
+ if not self._server_address:
+ return []
+ else:
+ response = requests.get(self._server_address + "/models")
+ json_data = json.loads(response.text)
+ return json_data
+
+ def labelDescriptions(self, modelId):
+ """Return mapping from label value to label description.
+ Label description is a dict containing "name" and "terminology".
+ Terminology string uses Slicer terminology entry format - see specification at
+ https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag
+ """
+ if not self._server_address:
+ return []
+ else:
+ import tempfile
+ with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile:
+ with requests.get(self._server_address + f"/labelDescriptions?id={modelId}", stream=True) as r:
+ r.raise_for_status()
+
+ with open(tmpfile.name, 'wb') as f:
+ for chunk in r.iter_content(chunk_size=8192):
+ f.write(chunk)
+
+ return self._labelDescriptions(tmpfile.name)
+
+ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitForCompletion=True, customData=None):
+ """
+ Run the processing algorithm.
+ Can be used without GUI widget.
+ :param inputNodes: input nodes in a list
+ :param outputVolume: thresholding result
+ :param modelId: one of self.models
+ :param cpu: use CPU instead of GPU
+ :param waitForCompletion: if True then the method waits for the processing to finish
+ :param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False
+ """
+
+ import time
+ startTime = time.time()
+ logger.info("Processing started")
+
+ tempDir = slicer.util.tempDirectory()
+ outputSegmentationFile = tempDir + "/output-segmentation.nrrd"
+
+ from tempfile import TemporaryDirectory
+ with TemporaryDirectory(dir=tempDir) as temp_dir:
+ # Write input volume to file
+ from pathlib import Path
+ tempDir = Path(temp_dir)
+ inputFiles = []
+ for inputIndex, inputNode in enumerate(inputNodes):
+ if inputNode.IsA('vtkMRMLScalarVolumeNode'):
+ inputImageFile = tempDir / f"input-volume{inputIndex}.nrrd"
+ logger.info(f"Writing input file to {inputImageFile}")
+ volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode")
+ volumeStorageNode.SetFileName(inputImageFile)
+ volumeStorageNode.UseCompressionOff()
+ volumeStorageNode.WriteData(inputNode)
+ slicer.mrmlScene.RemoveNode(volumeStorageNode)
+ inputFiles.append(inputImageFile)
+ else:
+ raise ValueError(f"Input node type {inputNode.GetClassName()} is not supported")
+
+ segmentationProcessInfo = {}
+ segmentationProcessInfo["procReturnCode"] = MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN
+
+ logger.info(f"Initiating Inference on {self._server_address}")
+ files = {}
+
+ try:
+ for idx, inputFile in enumerate(inputFiles, start=1):
+ name = "image_file"
+ if idx > 1:
+ name = f"{name}_{idx}"
+ files[name] = open(inputFile, 'rb')
+
+ with requests.post(self._server_address + f"/infer?model_name={modelId}", files=files) as r:
+ r.raise_for_status()
+
+ with open(outputSegmentationFile, "wb") as binary_file:
+ for chunk in r.iter_content(chunk_size=8192):
+ binary_file.write(chunk)
+
+ segmentationProcessInfo["procReturnCode"] = 0
+ finally:
+ for f in files.values():
+ f.close()
+
+ segmentationProcessInfo["cancelRequested"] = False
+ segmentationProcessInfo["startTime"] = startTime
+ segmentationProcessInfo["tempDir"] = tempDir
+ segmentationProcessInfo["inputNodes"] = inputNodes
+ segmentationProcessInfo["outputSegmentation"] = outputSegmentation
+ segmentationProcessInfo["outputSegmentationFile"] = outputSegmentationFile
+ segmentationProcessInfo["model"] = modelId
+ segmentationProcessInfo["customData"] = customData
+
+ self.onSegmentationProcessCompleted(segmentationProcessInfo)
+
+ return segmentationProcessInfo
#
# MONAIAuto3DSegTest
@@ -1556,10 +1580,8 @@ def test_MONAIAuto3DSeg1(self):
raise RuntimeError(f"Failed to load sample data set '{sampleDataName}'.")
# Set model inputs
-
- inputNodes = []
inputs = model.get("inputs")
- inputNodes = MONAIAuto3DSegLogic.assignInputNodesByName(inputs, loadedSampleNodes)
+ inputNodes = assignInputNodesByName(inputs, loadedSampleNodes)
outputSegmentation = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
@@ -1684,14 +1706,12 @@ def _writeScreenshots(self, segmentationNode, outputPath, baseName, numberOfImag
return sliceScreenshotFilename, rotate3dScreenshotFilename
def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestResultsMarkdownFilePath=None, screenshotUrlBase=None):
-
if modelsTestResultsMarkdownFilePath is None:
modelsTestResultsMarkdownFilePath = modelsTestResultsJsonFilePath.replace(".json", ".md")
if screenshotUrlBase is None:
screenshotUrlBase = "https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/ModelsTestResults/"
import json
- from MONAIAuto3DSeg import MONAIAuto3DSegLogic
with open(modelsTestResultsJsonFilePath) as f:
modelsTestResults = json.load(f)
@@ -1718,7 +1738,7 @@ def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestR
title = f"{model['title']} (v{model['version']})"
f.write(f"## {title}\n")
f.write(f"{model['description']}\n\n")
- f.write(f"Processing time: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n")
+ f.write(f"Processing time: {humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n")
f.write(f"Segment names: {', '.join(model['segmentNames'])}\n\n")
f.write(f"![2D view]({screenshotUrlBase}{model['segmentationResultsScreenshot2D']})\n")
f.write(f"![3D view]({screenshotUrlBase}{model['segmentationResultsScreenshot3D']})\n")
\ No newline at end of file
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py
new file mode 100644
index 0000000..63c5e7d
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py
@@ -0,0 +1 @@
+APPLICATION_NAME = "MONAIAuto3DSeg"
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py
new file mode 100644
index 0000000..438f120
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py
@@ -0,0 +1,130 @@
+import sys
+
+import shutil
+import subprocess
+import logging
+
+from MONAIAuto3DSegLib.constants import APPLICATION_NAME
+logger = logging.getLogger(APPLICATION_NAME)
+
+
+from abc import ABC, abstractmethod
+
+
+class DependenciesBase(ABC):
+
+ minimumTorchVersion = "1.12"
+
+ def __init__(self):
+ self.dependenciesInstalled = False # we don't know yet if dependencies have been installed
+
+ @abstractmethod
+ def installedMONAIPythonPackageInfo(self):
+ pass
+
+ @abstractmethod
+ def setupPythonRequirements(self, upgrade=False):
+ pass
+
+
+class LocalPythonDependencies(DependenciesBase):
+
+ def installedMONAIPythonPackageInfo(self):
+ versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode()
+ return versionInfo
+
+ def _checkModuleInstalled(self, moduleName):
+ try:
+ import importlib
+ importlib.import_module(moduleName)
+ return True
+ except ModuleNotFoundError:
+ return False
+
+ def setupPythonRequirements(self, upgrade=False):
+ def install(package):
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+ logger.info("Initializing PyTorch...")
+
+ packageName = "torch"
+ if not self._checkModuleInstalled(packageName):
+ logger.info("PyTorch Python package is required. Installing... (it may take several minutes)")
+ install(packageName)
+ if not self._checkModuleInstalled(packageName):
+ raise ValueError("pytorch needs to be installed to use this module.")
+ else: # torch is installed, check version
+ from packaging import version
+ import torch
+ if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion):
+ raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module."
+ + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
+ + f" with version requirement set to: >={self.minimumTorchVersion}")
+
+ logger.info("Initializing MONAI...")
+ monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
+ if upgrade:
+ monaiInstallString += " --upgrade"
+ install(monaiInstallString)
+
+ self.dependenciesInstalled = True
+ logger.info("Dependencies are set up successfully.")
+
+
+class RemotePythonDependencies(DependenciesBase):
+
+ def installedMONAIPythonPackageInfo(self, server_address):
+ if not server_address:
+ return []
+ else:
+ import json
+ import requests
+ response = requests.get(server_address + "/monaiinfo")
+ json_data = json.loads(response.text)
+ return json_data
+
+ def setupPythonRequirements(self, upgrade=False):
+ logger.error("No permission to update remote python packages. Please contact developer.")
+
+
+class SlicerPythonDependencies(DependenciesBase):
+
+ def installedMONAIPythonPackageInfo(self):
+ versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode()
+ return versionInfo
+
+ def setupPythonRequirements(self, upgrade=False):
+ # Install PyTorch
+ try:
+ import PyTorchUtils
+ except ModuleNotFoundError as e:
+ raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.")
+
+ logger.info("Initializing PyTorch...")
+
+ torchLogic = PyTorchUtils.PyTorchUtilsLogic()
+ if not torchLogic.torchInstalled():
+ logger.info("PyTorch Python package is required. Installing... (it may take several minutes)")
+ torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement=f">={self.minimumTorchVersion}")
+ if torch is None:
+ raise ValueError("PyTorch extension needs to be installed to use this module.")
+ else: # torch is installed, check version
+ from packaging import version
+ if version.parse(torchLogic.torch.__version__) < version.parse(self.minimumTorchVersion):
+ raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module."
+ + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch"
+ + f" with version requirement set to: >={self.minimumTorchVersion}")
+
+ # Install MONAI with required components
+ logger.info("Initializing MONAI...")
+ # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too).
+ # Without this, for some users monai-0.9.0 got installed, which failed with this error:
+ # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'"
+ monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3"
+ if upgrade:
+ monaiInstallString += " --upgrade"
+ import slicer
+ slicer.util.pip_install(monaiInstallString)
+
+ self.dependenciesInstalled = True
+ logger.info("Dependencies are set up successfully.")
\ No newline at end of file
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py
new file mode 100644
index 0000000..4e3540f
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py
@@ -0,0 +1,27 @@
+import logging
+from typing import Callable
+
+
+class LogHandler(logging.Handler):
+ """
+
+ code:
+
+ logger = logging.getLogger("XYZ")
+
+ callback = ... # any callable
+ # NB: only catching info level messages and forwarding it to callback
+ handler = LogHandler(callback, logging.INFO)
+ # can format log messages
+ formatter = logging.Formatter('%(levelname)s - %(message)s')
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+
+ """
+ def __init__(self, callback: Callable, level=logging.NOTSET):
+ self._callback = callback
+ super().__init__(level)
+
+ def emit(self, record):
+ msg = self.format(record)
+ self._callback(msg)
\ No newline at end of file
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py
new file mode 100644
index 0000000..d9ef612
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py
@@ -0,0 +1,196 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec
+from MONAIAuto3DSegLib.constants import APPLICATION_NAME
+
+
+logger = logging.getLogger(APPLICATION_NAME)
+
+
+class ModelDatabase:
+
+ @property
+ def defaultModel(self):
+ return self.models[0]["id"]
+
+ @property
+ def models(self):
+ if not self._models:
+ self._models = self.loadModelsDescription()
+ return self._models
+
+ @property
+ def modelsPath(self):
+ modelsPath = self.fileCachePath.joinpath("models")
+ modelsPath.mkdir(exist_ok=True, parents=True)
+ return modelsPath
+
+ @property
+ def modelsDescriptionJsonFilePath(self):
+ return os.path.join(self.moduleDir, "Resources", "Models.json")
+
+ def __init__(self):
+ self.fileCachePath = Path.home().joinpath(f".{APPLICATION_NAME}")
+ self.moduleDir = Path(__file__).parent.parent
+
+ # Disabling this flag preserves input and output data after execution is completed,
+ # which can be useful for troubleshooting.
+ self.clearOutputFolder = True
+ self._models = []
+
+ def model(self, modelId):
+ for model in self.models:
+ if model["id"] == modelId:
+ return model
+ raise RuntimeError(f"Model {modelId} not found")
+
+ def loadModelsDescription(self):
+ modelsJsonFilePath = self.modelsDescriptionJsonFilePath
+ try:
+ models = []
+ import json
+ import re
+ with open(modelsJsonFilePath) as f:
+ modelsTree = json.load(f)["models"]
+ for model in modelsTree:
+ deprecated = False
+ for version in model["versions"]:
+ url = version["url"]
+ # URL format: Model: {model['title']} (v{version})"
+ f" Description: {model['description']}\n"
+ f" Computation time on GPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n"
+ f" Imaging modality: {model['imagingModality']}\n"
+ f" Subject: {model['subject']}\n"
+ f" Segments: {', '.join(segmentNames)}",
+ "url": url,
+ "deprecated": deprecated
+ })
+ # First version is not deprecated, all subsequent versions are deprecated
+ deprecated = True
+ return models
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}")
+
+ def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath):
+ modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath
+
+ with open(modelsTestResultsJsonFilePath) as f:
+ modelsTestResults = json.load(f)
+
+ with open(modelsDescriptionJsonFilePath) as f:
+ modelsDescription = json.load(f)
+
+ for model in modelsDescription["models"]:
+ title = model["title"]
+ for modelTestResult in modelsTestResults:
+ if modelTestResult["title"] == title:
+ for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]:
+ fieldValue = modelTestResult.get(fieldName)
+ if fieldValue:
+ model[fieldName] = fieldValue
+ break
+
+ with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f:
+ json.dump(modelsDescription, f, indent=2)
+
+ def createModelsDir(self):
+ modelsDir = self.modelsPath
+ if not os.path.exists(modelsDir):
+ os.makedirs(modelsDir)
+
+ def modelPath(self, modelName):
+ try:
+ return self._modelPath(modelName)
+ except RuntimeError:
+ self.downloadModel(modelName)
+ return self._modelPath(modelName)
+
+ def _modelPath(self, modelName):
+ modelRoot = self.modelsPath.joinpath(modelName)
+ # find labels.csv file within the modelRoot folder and subfolders
+ for path in Path(modelRoot).rglob("labels.csv"):
+ return path.parent
+ raise RuntimeError(f"Model {modelName} path not found")
+
+ def deleteAllModels(self):
+ if self.modelsPath.exists():
+ import shutil
+ shutil.rmtree(self.modelsPath)
+
+ def downloadModel(self, modelName):
+ url = self.model(modelName)["url"]
+ import zipfile
+ import requests
+ from tempfile import TemporaryDirectory
+ with TemporaryDirectory() as td:
+ tempDir = Path(td)
+ modelDir = self.modelsPath.joinpath(modelName)
+ Path(modelDir).mkdir(exist_ok=True)
+ modelZipFile = tempDir.joinpath("autoseg3d_model.zip")
+ logger.info(f"Downloading model '{modelName}' from {url}...")
+ logger.debug(f"Downloading from {url} to {modelZipFile}...")
+ try:
+ with open(modelZipFile, 'wb') as f:
+ with requests.get(url, stream=True) as r:
+ r.raise_for_status()
+ total_size = int(r.headers.get('content-length', 0))
+ reporting_increment_percent = 1.0
+ last_reported_download_percent = -reporting_increment_percent
+ downloaded_size = 0
+ for chunk in r.iter_content(chunk_size=8192 * 16):
+ f.write(chunk)
+ downloaded_size += len(chunk)
+ downloaded_percent = 100.0 * downloaded_size / total_size
+ if downloaded_percent - last_reported_download_percent > reporting_increment_percent:
+ logger.info(
+ f"Downloading model: {downloaded_size / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB ({downloaded_percent:.1f}%)")
+ last_reported_download_percent = downloaded_percent
+
+ logger.info(f"Download finished. Extracting to {modelDir}...")
+ with zipfile.ZipFile(modelZipFile, 'r') as zip_f:
+ zip_f.extractall(modelDir)
+ except Exception as e:
+ raise e
+ finally:
+ if self.clearOutputFolder:
+ logger.info("Cleaning up temporary model download folder...")
+ if os.path.isdir(tempDir):
+ import shutil
+ shutil.rmtree(tempDir)
+ else:
+ logger.info(f"Not cleaning up temporary model download folder: {tempDir}")
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py
new file mode 100644
index 0000000..ec5bdfc
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py
@@ -0,0 +1,95 @@
+import slicer
+import psutil
+
+import logging
+import queue
+import threading
+
+import qt
+from pathlib import Path
+from typing import Callable
+
+from MONAIAuto3DSegLib.constants import APPLICATION_NAME
+logger = logging.getLogger(APPLICATION_NAME)
+
+
+class WebServer:
+
+ CHECK_TIMER_INTERVAL = 1000
+
+ @staticmethod
+ def getPSProcess(pid):
+ try:
+ return psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ return None
+
+ def __init__(self, completedCallback: Callable):
+ self.completedCallback = completedCallback
+ self.procThread = None
+ self.serverProc = None
+ self.queue = None
+
+ def isRunning(self):
+ if self.serverProc is not None:
+ psProcess = self.getPSProcess(self.serverProc.pid)
+ if psProcess:
+ return psProcess.is_running()
+ return False
+
+ def __del__(self):
+ self.killProcess()
+
+ def killProcess(self):
+ if not self.serverProc:
+ return
+ psProcess = self.getPSProcess(self.serverProc.pid)
+ if not psProcess:
+ return
+ for psChildProcess in psProcess.children(recursive=True):
+ psChildProcess.kill()
+ if psProcess.is_running():
+ psProcess.kill()
+
+ def launchConsoleProcess(self, cmd):
+ self.serverProc = \
+ slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "auto3dseg", useStartupEnvironment=False)
+
+ self.queue = queue.Queue()
+ self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess)
+ self.procThread.start()
+ self.checkProcessOutput()
+
+ def cleanup(self):
+ if self.procThread:
+ self.procThread.join()
+ self.completedCallback()
+ self.serverProc = None
+ self.procThread = None
+ self.queue = None
+
+ def _handleProcessOutputThreadProcess(self):
+ while True:
+ try:
+ line = self.serverProc.stdout.readline()
+ if not line:
+ break
+ self.queue.put(line.rstrip())
+ except UnicodeDecodeError as e:
+ pass
+ self.serverProc.wait()
+
+ def checkProcessOutput(self):
+ outputQueue = self.queue
+ while outputQueue:
+ try:
+ line = outputQueue.get_nowait()
+ logger.info(line)
+ except queue.Empty:
+ break
+
+ psProcess = self.getPSProcess(self.serverProc.pid)
+ if psProcess and psProcess.is_running(): # No more outputs to process now, check again later
+ qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput)
+ else:
+ self.cleanup()
\ No newline at end of file
diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py
new file mode 100644
index 0000000..186c9cc
--- /dev/null
+++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py
@@ -0,0 +1,35 @@
+
+
+def humanReadableTimeFromSec(seconds):
+ import math
+ if not seconds:
+ return "N/A"
+ if seconds < 55:
+ # if less than a minute, round up to the nearest 5 seconds
+ return f"{math.ceil(seconds / 5) * 5} sec"
+ elif seconds < 60 * 60:
+ # if less than 1 hour, round up to the nearest minute
+ return f"{math.ceil(seconds / 60)} min"
+ # Otherwise round up to the nearest 0.1 hour
+ return f"{seconds / 3600:.1f} h"
+
+
+def assignInputNodesByName(inputs, loadedSampleNodes):
+ inputNodes = []
+ for inputIndex, input in enumerate(inputs):
+ namePattern = input.get("namePattern")
+ if namePattern:
+ matchingNode = findFirstNodeByNamePattern(namePattern, loadedSampleNodes)
+ else:
+ matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else \
+ loadedSampleNodes[0]
+ inputNodes.append(matchingNode)
+ return inputNodes
+
+
+def findFirstNodeByNamePattern(namePattern, nodes):
+ import fnmatch
+ for node in nodes:
+ if fnmatch.fnmatchcase(node.GetName(), namePattern):
+ return node
+ return None
\ No newline at end of file
diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui
index d57b851..1c4fd3c 100644
--- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui
+++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui
@@ -6,8 +6,8 @@
Computation time on CPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n"
+ f"