diff --git a/MONAIAuto3DSeg/CMakeLists.txt b/MONAIAuto3DSeg/CMakeLists.txt index 18fc468..3d8a9c9 100644 --- a/MONAIAuto3DSeg/CMakeLists.txt +++ b/MONAIAuto3DSeg/CMakeLists.txt @@ -4,6 +4,15 @@ set(MODULE_NAME MONAIAuto3DSeg) #----------------------------------------------------------------------------- set(MODULE_PYTHON_SCRIPTS ${MODULE_NAME}.py + ${MODULE_NAME}Lib/__init__.py + ${MODULE_NAME}Lib/constants.py + ${MODULE_NAME}Lib/dependency_handler.py + ${MODULE_NAME}Lib/log_handler.py + ${MODULE_NAME}Lib/model_database.py + ${MODULE_NAME}Lib/server.py + ${MODULE_NAME}Lib/utils.py + auto3dseg/__init__py + auto3dseg/main.py ) set(MODULE_PYTHON_RESOURCES diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 04eb64d..17533aa 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -1,12 +1,22 @@ import logging import os -import re +import json +import sys import vtk +import qt import slicer +import requests from slicer.ScriptedLoadableModule import * from slicer.util import VTKObservationMixin +from MONAIAuto3DSegLib.model_database import ModelDatabase +from MONAIAuto3DSegLib.constants import APPLICATION_NAME +from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec, assignInputNodesByName +from MONAIAuto3DSegLib.dependency_handler import SlicerPythonDependencies, RemotePythonDependencies + + +logger = logging.getLogger(APPLICATION_NAME) # @@ -24,7 +34,7 @@ def __init__(self, parent): self.parent.title = "MONAI Auto3DSeg" self.parent.categories = ["Segmentation"] self.parent.dependencies = [] - self.parent.contributors = ["Andras Lasso (PerkLab, Queen's University)", "Andres Diaz-Pinto (NVIDIA & KCL)", "Rudolf Bumm (KSGR Switzerland)"] + self.parent.contributors = ["Andras Lasso (PerkLab, Queen's University)", "Andres Diaz-Pinto (NVIDIA & KCL)", "Rudolf Bumm (KSGR Switzerland), Christian Herz (CHOP)"] self.parent.helpText = """ 3D Slicer extension for segmentation using MONAI Auto3DSeg AI model. See more information in the extension documentation. @@ -181,6 +191,18 @@ def __init__(self, parent=None): self._updatingGUIFromParameterNode = False self._processingState = MONAIAuto3DSegWidget.PROCESSING_IDLE self._segmentationProcessInfo = None + self._webServer = None + + from MONAIAuto3DSegLib.log_handler import LogHandler + handler = LogHandler(self.addLog, logging.INFO) + formatter = logging.Formatter('%(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + def onReload(self): + if self._webServer: + self._webServer.killProcess() + super().onReload() def setup(self): """ @@ -194,7 +216,6 @@ def setup(self): self.layout.addWidget(uiWidget) self.ui = slicer.util.childWidgetVariables(uiWidget) - import qt self.ui.downloadSampleDataToolButton.setIcon(qt.QIcon(self.resourcePath("Icons/radiology.svg"))) self.inputNodeSelectors = [self.ui.inputNodeSelector0, self.ui.inputNodeSelector1, self.ui.inputNodeSelector2, self.ui.inputNodeSelector3] @@ -208,7 +229,6 @@ def setup(self): # Create logic class. Logic implements all computations that should be possible to run # in batch mode, without a graphical user interface. self.logic = MONAIAuto3DSegLogic() - self.logic.logCallback = self.addLog self.logic.processingCompletedCallback = self.onProcessingCompleted self.logic.startResultImportCallback = self.onProcessImportStarted self.logic.endResultImportCallback = self.onProcessImportEnded @@ -241,8 +261,15 @@ def setup(self): self.ui.browseToModelsFolderButton.connect("clicked(bool)", self.onBrowseModelsFolder) self.ui.deleteAllModelsButton.connect("clicked(bool)", self.onClearModelsFolder) + self.ui.serverComboBox.lineEdit().setPlaceholderText("enter server address or leave empty to use default") + self.ui.serverComboBox.currentIndexChanged.connect(self.onRemoteServerButtonToggled) + self.ui.remoteServerButton.toggled.connect(self.onRemoteServerButtonToggled) + + self.ui.serverButton.toggled.connect(self.onServerButtonToggled) + # Make sure parameter node is initialized (needed for module reload) self.initializeParameterNode() + self.updateServerUrlGUIFromSettings() self.updateGUIFromParameterNode() @@ -325,15 +352,12 @@ def updateGUIFromParameterNode(self, caller=None, event=None): This method is called whenever parameter node is changed. The module GUI is updated to show the current state of the parameter node. """ - import qt - if self._parameterNode is None or self._updatingGUIFromParameterNode: return # Make sure GUI changes do not call updateParameterNodeFromGUI (it could cause infinite loop) self._updatingGUIFromParameterNode = True try: - self.ui.modelSearchBox.text = self._parameterNode.GetParameter("ModelSearchText") searchWords = self._parameterNode.GetParameter("ModelSearchText").lower().split() @@ -440,6 +464,23 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.applyButton.toolTip = "Please wait for the segmentation to be cancelled" self.ui.applyButton.enabled = False + remoteConnection = self.ui.remoteServerButton.checked + + # if remoteConnection: + # self.ui.serverCollapsibleButton.collapsed = True + + self.ui.portSpinBox.value = int(self._parameterNode.GetParameter("ServerPort")) + + self.ui.browseToModelsFolderButton.enabled = not remoteConnection + self.ui.useStandardSegmentNamesCheckBox.enabled = not remoteConnection + self.ui.cpuCheckBox.enabled = not remoteConnection + self.ui.showAllModelsCheckBox.enabled = not remoteConnection + self.ui.deleteAllModelsButton.enabled = not remoteConnection + self.ui.packageUpgradeButton.enabled = not remoteConnection + + serverRunning = self._webServer is not None and self._webServer.isRunning() + self.ui.serverButton.checked = serverRunning + self.ui.serverButton.text = "Running ..." if serverRunning else "Start server" finally: # All the GUI updates are done self._updatingGUIFromParameterNode = False @@ -470,6 +511,7 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): self._parameterNode.SetParameter("ShowAllModels", "true" if self.ui.showAllModelsCheckBox.checked else "false") self._parameterNode.SetParameter("UseStandardSegmentNames", "true" if self.ui.useStandardSegmentNamesCheckBox.checked else "false") self._parameterNode.SetNodeReferenceID("OutputSegmentation", self.ui.outputSegmentationSelector.currentNodeID) + self._parameterNode.SetParameter("ServerPort", str(self.ui.portSpinBox.value)) finally: self._parameterNode.EndModify(wasModified) @@ -477,8 +519,23 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): def addLog(self, text): """Append text to log window """ - self.ui.statusLabel.appendPlainText(text) - slicer.app.processEvents() # force update + if len(self.ui.statusLabel.html) > 1024 * 256: + self.ui.statusLabel.clear() + self.ui.statusLabel.insertHtml("Log cleared\n") + self.ui.statusLabel.insertHtml(text) + self.ui.statusLabel.insertPlainText("\n") + self.ui.statusLabel.ensureCursorVisible() + self.ui.statusLabel.repaint() + + # self.ui.statusLabel.appendPlainText(text) + # slicer.app.processEvents() # force update + + def addServerLog(self, *args): + for arg in args: + if self.ui.logConsoleCheckBox.checked: + print(arg) + if self.ui.logGuiCheckBox.checked: + self.addLog(arg) def setProcessingState(self, state): self._processingState = state @@ -500,10 +557,6 @@ def onApply(self): self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_STARTING) - if not self.logic.dependenciesInstalled: - with slicer.util.tryWithErrorDisplay("Failed to install required dependencies.", waitCursor=True): - self.logic.setupPythonRequirements() - try: with slicer.util.tryWithErrorDisplay("Failed to start processing.", waitCursor=True): @@ -518,10 +571,10 @@ def onApply(self): for inputNodeSelector in self.inputNodeSelectors: if inputNodeSelector.visible: inputNodes.append(inputNodeSelector.currentNode()) - self._segmentationProcessInfo = self.logic.process(inputNodes, self.ui.outputSegmentationSelector.currentNode(), - self._currentModelId(), self.ui.cpuCheckBox.checked, waitForCompletion=False) self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IN_PROGRESS) + self._segmentationProcessInfo = self.logic.process(inputNodes, self.ui.outputSegmentationSelector.currentNode(), + self._currentModelId(), self.ui.cpuCheckBox.checked, waitForCompletion=False) except Exception as e: self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) @@ -533,22 +586,19 @@ def onCancel(self): def onProcessImportStarted(self, customData): self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IMPORT_RESULTS) - import qt qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor) slicer.app.processEvents() def onProcessImportEnded(self, customData): - import qt qt.QApplication.restoreOverrideCursor() slicer.app.processEvents() def onProcessingCompleted(self, returnCode, customData): - self.ui.statusLabel.appendPlainText("\nProcessing finished.") + self.addLog("\nProcessing finished.") self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) self._segmentationProcessInfo = None def _currentModelId(self): - import qt itemIndex = self.ui.modelComboBox.currentRow item = self.ui.modelComboBox.item(itemIndex) if not item: @@ -556,7 +606,6 @@ def _currentModelId(self): return item.data(qt.Qt.UserRole) def _setCurrentModelId(self, modelId): - import qt for itemIndex in range(self.ui.modelComboBox.count): item = self.ui.modelComboBox.item(itemIndex) if item.data(qt.Qt.UserRole) == modelId: @@ -565,7 +614,9 @@ def _setCurrentModelId(self, modelId): return False def onDownloadSampleData(self): - model = self.logic.model(self._currentModelId()) + with slicer.util.tryWithErrorDisplay("Failed to retrieve model information", waitCursor=True): + model = self.logic.model(self._currentModelId()) + sampleDataName = model.get("sampleData") if not sampleDataName: slicer.util.messageBox("No sample data is available for this model.") @@ -584,7 +635,7 @@ def onDownloadSampleData(self): slicer.util.messageBox(f"Failed to load sample data set '{sampleDataName}'.") return - inputNodes = MONAIAuto3DSegLogic.assignInputNodesByName(inputs, loadedSampleNodes) + inputNodes = assignInputNodesByName(inputs, loadedSampleNodes) for inputIndex, inputNode in enumerate(inputNodes): if inputNode: self.inputNodeSelectors[inputIndex].setCurrentNode(inputNode) @@ -592,24 +643,25 @@ def onDownloadSampleData(self): def onPackageInfoUpdate(self): self.ui.packageInfoTextBrowser.plainText = "" with slicer.util.tryWithErrorDisplay("Failed to get MONAI package version information", waitCursor=True): - self.ui.packageInfoTextBrowser.plainText = self.logic.installedMONAIPythonPackageInfo().rstrip() + self.ui.packageInfoTextBrowser.plainText = self.logic.getMONAIPythonPackageInfo().rstrip() def onPackageUpgrade(self): + restartRequired = True with slicer.util.tryWithErrorDisplay("Failed to upgrade MONAI", waitCursor=True): - self.logic.setupPythonRequirements(upgrade=True) + restartRequired = self.logic.setupPythonRequirements(upgrade=True) self.onPackageInfoUpdate() - if not slicer.util.confirmOkCancelDisplay(f"This MONAI update requires a 3D Slicer restart.","Press OK to restart."): - raise ValueError("Restart was cancelled.") - else: - slicer.util.restart() + if restartRequired: + if not slicer.util.confirmOkCancelDisplay(f"This MONAI update requires a 3D Slicer restart.","Press OK to restart."): + raise ValueError("Restart was cancelled.") + else: + slicer.util.restart() def onBrowseModelsFolder(self): - import qt self.logic.createModelsDir() - qt.QDesktopServices().openUrl(qt.QUrl.fromLocalFile(self.logic.modelsPath())) + qt.QDesktopServices().openUrl(qt.QUrl.fromLocalFile(self.logic.modelsPath)) def onClearModelsFolder(self): - if not os.path.exists(self.logic.modelsPath()): + if not os.path.exists(self.logic.modelsPath): slicer.util.messageBox("There are no downloaded models.") return if not slicer.util.confirmOkCancelDisplay("All downloaded model files will be deleted. The files will be automatically downloaded again as needed."): @@ -617,11 +669,103 @@ def onClearModelsFolder(self): self.logic.deleteAllModels() slicer.util.messageBox("Downloaded models are deleted.") + def onRemoteServerButtonToggled(self): + if self.ui.remoteServerButton.checked: + self.ui.remoteServerButton.text = "Connected" + self.logic = RemoteMONAIAuto3DSegLogic() + self.logic.server_address = self.ui.serverComboBox.currentText + try: + _ = self.logic.models + self.addLog(f"Remote Server Connected {self.logic.server_address}") + except: + slicer.util.warningDisplay( + f"Connection to remote server '{self.logic.server_address}' failed. " + f"Please check address, port, and connection" + ) + self.ui.remoteServerButton.checked = False + return + self.saveServerUrl() + else: + self.ui.remoteServerButton.text = "Connect" + self.logic = MONAIAuto3DSegLogic() + + self.logic.processingCompletedCallback = self.onProcessingCompleted + self.logic.startResultImportCallback = self.onProcessImportStarted + self.logic.endResultImportCallback = self.onProcessImportEnded + self.updateGUIFromParameterNode() + + def onServerButtonToggled(self, toggled): + if toggled: + if not self._webServer or not self._webServer.isRunning() : + import platform + from pathlib import Path + slicer.util.pip_install("python-multipart fastapi uvicorn[standard]") + + hostName = platform.node() + port = str(self.ui.portSpinBox.value) + cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] + + self.ui.serverAddressLineEdit.text = f"http://{hostName}:{port}" + + from MONAIAuto3DSegLib.server import WebServer + self._webServer = WebServer(completedCallback=lambda: self.ui.serverButton.setChecked(False)) + self._webServer.launchConsoleProcess(cmd) + else: + if self._webServer is not None and self._webServer.isRunning(): + logger.info("Server stop requested.") + self._webServer.killProcess() + if not self._webServer.isRunning(): + logger.info("Server stopped.") + + self._webServer = None + self.updateGUIFromParameterNode() + + def serverUrl(self): + serverUrl = self.ui.serverComboBox.currentText.strip() + if not serverUrl: + serverUrl = "http://127.0.0.1:8000" + return serverUrl.rstrip("/") + + def saveServerUrl(self): + settings = qt.QSettings() + serverUrl = self.ui.serverComboBox.currentText + settings.setValue(f"{self.moduleName}/serverUrl", serverUrl) + serverUrlHistory = self._getServerUrlHistory(serverUrl, settings) + serverUrlHistory.insert(0, serverUrl) # Save current server URL to the top of history + serverUrlHistory = serverUrlHistory[:10] # keep up to first 10 elements + settings.setValue(f"{self.moduleName}/serverUrlHistory", ";".join(serverUrlHistory)) + self.updateServerUrlGUIFromSettings() + + def _getServerUrlHistory(self, serverUrl, settings): + serverUrlHistory = settings.value(f"{self.moduleName}/serverUrlHistory") + if serverUrlHistory: + serverUrlHistory = serverUrlHistory.split(";") + else: + serverUrlHistory = [] + try: + serverUrlHistory.remove(serverUrl) + except ValueError: + pass + return serverUrlHistory + + def updateServerUrlGUIFromSettings(self): + settings = qt.QSettings() + serverUrlHistory = settings.value(f"{self.moduleName}/serverUrlHistory") + + wasBlocked = self.ui.serverComboBox.blockSignals(True) + self.ui.serverComboBox.clear() + if serverUrlHistory: + self.ui.serverComboBox.addItems(serverUrlHistory.split(";")) + self.ui.serverComboBox.setCurrentText(settings.value(f"{self.moduleName}/serverUrl")) + self.ui.serverComboBox.blockSignals(wasBlocked) + + # # MONAIAuto3DSegLogic # -class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic): + +class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic, ModelDatabase): """This class should implement all the actual computation done by your module. The interface should be such that other python code can import @@ -634,244 +778,65 @@ class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic): EXIT_CODE_USER_CANCELLED = 1001 EXIT_CODE_DID_NOT_RUN = 1002 - def __init__(self): - """ - Called when the logic class is instantiated. Can be used for initializing member variables. - """ - from collections import OrderedDict - - ScriptedLoadableModuleLogic.__init__(self) - - import pathlib - self.fileCachePath = pathlib.Path.home().joinpath(".MONAIAuto3DSeg") - - self.dependenciesInstalled = False # we don't know yet if dependencies have been installed - - self.moduleDir = os.path.dirname(slicer.util.getModule('MONAIAuto3DSeg').path) - - self.logCallback = None - self.processingCompletedCallback = None - self.startResultImportCallback = None - self.endResultImportCallback = None - self.useStandardSegmentNames = True - - # List of property type codes that are specified by in the MONAIAuto3DSeg terminology. - # - # Codes are stored as a list of strings containing coding scheme designator and code value of the property type, - # separated by "^" character. For example "SCT^123456". - # - # If property the code is found in this list then the MONAIAuto3DSeg terminology will be used, - # otherwise the DICOM terminology will be used. This is necessary because the DICOM terminology - # does not contain all the necessary items and some items are incomplete (e.g., don't have color or 3D Slicer label). - # - self.MONAIAuto3DSegTerminologyPropertyTypes = self._MONAIAuto3DSegTerminologyPropertyTypes() - - # List of anatomic regions that are specified by MONAIAuto3DSeg. - self.MONAIAuto3DSegAnatomicRegions = self._MONAIAuto3DSegAnatomicRegions() - - # Segmentation models specified by in models.json file - self.models = self.loadModelsDescription() - self.defaultModel = self.models[0]["id"] - - # Timer for checking the output of the segmentation process that is running in the background - self.processOutputCheckTimerIntervalMsec = 1000 - - # Disabling this flag preserves input and output data after execution is completed, - # which can be useful for troubleshooting. - self.clearOutputFolder = True - - # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location - # where inference result is stored and set self.debugSkipInference to True. - self.debugSkipInference = False - self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624" - - - def model(self, modelId): - for model in self.models: - if model["id"] == modelId: - return model - raise RuntimeError(f"Model {modelId} not found") + DEPENDENCY_HANDLER = SlicerPythonDependencies() + @staticmethod + def getLoadedTerminologyNames(): + import vtk + terminologyNames = vtk.vtkStringArray() + terminologiesLogic = slicer.util.getModuleLogic("Terminologies") + terminologiesLogic.GetLoadedTerminologyNames(terminologyNames) - def modelsDescriptionJsonFilePath(self): - return os.path.join(self.moduleDir, "Resources", "Models.json") - - def loadModelsDescription(self): - modelsJsonFilePath = self.modelsDescriptionJsonFilePath() - try: - models = [] - import json - import re - with open(modelsJsonFilePath) as f: - modelsTree = json.load(f)["models"] - for model in modelsTree: - deprecated = False - for version in model["versions"]: - url = version["url"] - # URL format: /-v.zip - # Example URL: https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/Models/17-segments-TotalSegmentator-v1.0.3.zip - match = re.search(r"(?P[^/]+)-v(?P\d+\.\d+\.\d+)", url) - if match: - filename = match.group("filename") - version = match.group("version") - else: - logging.error(f"Failed to extract model id and version from url: {url}") - if "inputs" in model: - # Contains a list of dict. One dict for each input. - # Currently, only "title" (user-displayable name) and "namePattern" of the input are specified. - # In the future, inputs could have additional properties, such as name, type, optional, ... - inputs = model["inputs"] - else: - # Inputs are not defined, use default (single input volume) - inputs = [{"title": "Input volume"}] - segmentNames = model.get('segmentNames') - if not segmentNames: - segmentNames = "N/A" - models.append({ - "id": f"{filename}-v{version}", - "title": model['title'], - "version": version, - "inputs": inputs, - "imagingModality": model["imagingModality"], - "description": model["description"], - "sampleData": model.get("sampleData"), - "segmentNames": model.get("segmentNames"), - "details": - f"

Model: {model['title']} (v{version})" - f"

Description: {model['description']}\n" - f"

Computation time on GPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n" - f"
Computation time on CPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n" - f"

Imaging modality: {model['imagingModality']}\n" - f"

Subject: {model['subject']}\n" - f"

Segments: {', '.join(segmentNames)}", - "url": url, - "deprecated": deprecated - }) - # First version is not deprecated, all subsequent versions are deprecated - deprecated = True - return models - except Exception as e: - import traceback - traceback.print_exc() - raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}") + return [terminologyNames.GetValue(idx) for idx in range(terminologyNames.GetNumberOfValues())] @staticmethod - def humanReadableTimeFromSec(seconds): - import math - if not seconds: - return "N/A" - if seconds < 55: - # if less than a minute, round up to the nearest 5 seconds - return f"{math.ceil(seconds/5) * 5} sec" - elif seconds < 60 * 60: - # if less then 1 hour, round up to the nearest minute - return f"{math.ceil(seconds/60)} min" - # Otherwise round up to the nearest 0.1 hour - return f"{seconds/3600:.1f} h" - - def modelsPath(self): - import pathlib - return self.fileCachePath.joinpath("models") - - def createModelsDir(self): - modelsDir = self.modelsPath() - if not os.path.exists(modelsDir): - os.makedirs(modelsDir) - - def modelPath(self, modelName): - import pathlib - modelRoot = self.modelsPath().joinpath(modelName) - # find labels.csv file within the modelRoot folder and subfolders - for path in pathlib.Path(modelRoot).rglob("labels.csv"): - return path.parent - raise RuntimeError(f"Model {modelName} path not found") - - def deleteAllModels(self): - if self.modelsPath().exists(): - import shutil - shutil.rmtree(self.modelsPath()) + def getLoadedAnatomicContextNames(): + import vtk + anatomicContextNames = vtk.vtkStringArray() + terminologiesLogic = slicer.util.getModuleLogic("Terminologies") + terminologiesLogic.GetLoadedAnatomicContextNames(anatomicContextNames) + + return [anatomicContextNames.GetValue(idx) for idx in range(anatomicContextNames.GetNumberOfValues())] def downloadAllModels(self): for model in self.models: slicer.app.processEvents() self.downloadModel(model["id"]) - def downloadModel(self, modelName): - - url = self.model(modelName)["url"] - - import zipfile - import requests - import pathlib - - tempDir = pathlib.Path(slicer.util.tempDirectory()) - modelDir = self.modelsPath().joinpath(modelName) - if not os.path.exists(modelDir): - os.makedirs(modelDir) - - modelZipFile = tempDir.joinpath("autoseg3d_model.zip") - self.log(f"Downloading model '{modelName}' from {url}...") - logging.debug(f"Downloading from {url} to {modelZipFile}...") - - try: - with open(modelZipFile, 'wb') as f: - with requests.get(url, stream=True) as r: - r.raise_for_status() - total_size = int(r.headers.get('content-length', 0)) - reporting_increment_percent = 1.0 - last_reported_download_percent = -reporting_increment_percent - downloaded_size = 0 - for chunk in r.iter_content(chunk_size=8192 * 16): - f.write(chunk) - downloaded_size += len(chunk) - downloaded_percent = 100.0 * downloaded_size / total_size - if downloaded_percent - last_reported_download_percent > reporting_increment_percent: - self.log(f"Downloading model: {downloaded_size/1024/1024:.1f}MB / {total_size/1024/1024:.1f}MB ({downloaded_percent:.1f}%)") - last_reported_download_percent = downloaded_percent - - self.log(f"Download finished. Extracting to {modelDir}...") - with zipfile.ZipFile(modelZipFile, 'r') as zip_f: - zip_f.extractall(modelDir) - except Exception as e: - raise e - finally: - if self.clearOutputFolder: - self.log("Cleaning up temporary model download folder...") - if os.path.isdir(tempDir): - import shutil - shutil.rmtree(tempDir) - else: - self.log(f"Not cleaning up temporary model download folder: {tempDir}") - - - def _MONAIAuto3DSegTerminologyPropertyTypes(self): + @staticmethod + def _terminologyPropertyTypes(terminologyName): """Get label terminology property types defined in from MONAI Auto3DSeg terminology. Terminology entries are either in DICOM or MONAI Auto3DSeg "Segmentation category and type". - """ + # List of property type codes that are specified by in the terminology. + # + # Codes are stored as a list of strings containing coding scheme designator and code value of the property type, + # separated by "^" character. For example "SCT^123456". + + """ terminologiesLogic = slicer.util.getModuleLogic("Terminologies") - MONAIAuto3DSegTerminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName + terminologyPropertyTypes = [] # Get anatomicalStructureCategory from the MONAI Auto3DSeg terminology anatomicalStructureCategory = slicer.vtkSlicerTerminologyCategory() - numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(MONAIAuto3DSegTerminologyName) - for i in range(numberOfCategories): - terminologiesLogic.GetNthCategoryInTerminology(MONAIAuto3DSegTerminologyName, i, anatomicalStructureCategory) - if anatomicalStructureCategory.GetCodingSchemeDesignator() == "SCT" and anatomicalStructureCategory.GetCodeValue() == "123037004": - # Found the (123037004, SCT, "Anatomical Structure") category within DICOM master list - break - - # Retrieve all anatomicalStructureCategory property type codes - terminologyPropertyTypes = [] - terminologyType = slicer.vtkSlicerTerminologyType() - numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory) - for i in range(numberOfTypes): - if terminologiesLogic.GetNthTypeInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory, i, terminologyType): - terminologyPropertyTypes.append(terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue()) + numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(terminologyName) + for cIdx in range(numberOfCategories): + terminologiesLogic.GetNthCategoryInTerminology(terminologyName, cIdx, anatomicalStructureCategory) + + # Retrieve all anatomicalStructureCategory property type codes + terminologyType = slicer.vtkSlicerTerminologyType() + numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(terminologyName, + anatomicalStructureCategory) + for tIdx in range(numberOfTypes): + if terminologiesLogic.GetNthTypeInTerminologyCategory(terminologyName, anatomicalStructureCategory, tIdx, + terminologyType): + terminologyPropertyTypes.append( + terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue()) return terminologyPropertyTypes - def _MONAIAuto3DSegAnatomicRegions(self): + @staticmethod + def _anatomicRegions(anatomicContextName): """Get anatomic regions defined in from MONAI Auto3DSeg terminology. Terminology entries are either in DICOM or MONAI Auto3DSeg "Anatomic codes". """ @@ -884,25 +849,101 @@ def _MONAIAuto3DSegAnatomicRegions(self): # when editing the terminology on the GUI) return anatomicRegions - MONAIAuto3DSegAnatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName - # Retrieve all anatomical region codes - regionObject = slicer.vtkSlicerTerminologyType() - numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(MONAIAuto3DSegAnatomicContextName) + numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(anatomicContextName) for i in range(numberOfRegions): - if terminologiesLogic.GetNthRegionInAnatomicContext(MONAIAuto3DSegAnatomicContextName, i, regionObject): + if terminologiesLogic.GetNthRegionInAnatomicContext(anatomicContextName, i, regionObject): anatomicRegions.append(regionObject.GetCodingSchemeDesignator() + "^" + regionObject.GetCodeValue()) return anatomicRegions + @staticmethod + def getSegmentLabelColor(terminologyEntryStr): + """Get segment label and color from terminology""" + + def labelColorFromTypeObject(typeObject): + """typeObject is a terminology type or type modifier""" + label = typeObject.GetSlicerLabel() if typeObject.GetSlicerLabel() else typeObject.GetCodeMeaning() + rgb = typeObject.GetRecommendedDisplayRGBValue() + return label, (rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0) + + tlogic = slicer.modules.terminologies.logic() + + terminologyEntry = slicer.vtkSlicerTerminologyEntry() + if not tlogic.DeserializeTerminologyEntry(terminologyEntryStr, terminologyEntry): + raise RuntimeError(f"Failed to deserialize terminology string: {terminologyEntryStr}") + + numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), + terminologyEntry.GetCategoryObject()) + foundTerminologyEntry = slicer.vtkSlicerTerminologyEntry() + for typeIndex in range(numberOfTypes): + tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), + terminologyEntry.GetCategoryObject(), typeIndex, + foundTerminologyEntry.GetTypeObject()) + if terminologyEntry.GetTypeObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeObject().GetCodingSchemeDesignator(): + continue + if terminologyEntry.GetTypeObject().GetCodeValue() != foundTerminologyEntry.GetTypeObject().GetCodeValue(): + continue + if terminologyEntry.GetTypeModifierObject() and terminologyEntry.GetTypeModifierObject().GetCodeValue(): + # Type has a modifier, get the color from there + numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType( + terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), + terminologyEntry.GetTypeObject()) + foundMatchingModifier = False + for modifierIndex in range(numberOfModifiers): + tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(), + terminologyEntry.GetCategoryObject(), + terminologyEntry.GetTypeObject(), + modifierIndex, + foundTerminologyEntry.GetTypeModifierObject()) + if terminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator(): + continue + if terminologyEntry.GetTypeModifierObject().GetCodeValue() != foundTerminologyEntry.GetTypeModifierObject().GetCodeValue(): + continue + return labelColorFromTypeObject(foundTerminologyEntry.GetTypeModifierObject()) + continue + return labelColorFromTypeObject(foundTerminologyEntry.GetTypeObject()) + + raise RuntimeError(f"Color was not found for terminology {terminologyEntryStr}") + + def __init__(self): + """ + Called when the logic class is instantiated. Can be used for initializing member variables. + """ + ScriptedLoadableModuleLogic.__init__(self) + ModelDatabase.__init__(self) + + self.processingCompletedCallback = None + self.startResultImportCallback = None + self.endResultImportCallback = None + self.useStandardSegmentNames = True + + # Timer for checking the output of the segmentation process that is running in the background + self.processOutputCheckTimerIntervalMsec = 1000 + + # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location + # where inference result is stored and set self.debugSkipInference to True. + self.debugSkipInference = False + self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624" + + def getMONAIPythonPackageInfo(self): + return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo() + + def setupPythonRequirements(self, upgrade=False): + self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade) + return True + def labelDescriptions(self, modelName): """Return mapping from label value to label description. Label description is a dict containing "name" and "terminology". Terminology string uses Slicer terminology entry format - see specification at https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag """ + labelsFilePath = self.modelPath(modelName).joinpath("labels.csv") + return self._labelDescriptions(labelsFilePath) + def _labelDescriptions(self, labelsFilePath): # Helper function to get code string from CSV file row def getCodeString(field, columnNames, row): columnValues = [] @@ -917,173 +958,68 @@ def getCodeString(field, columnNames, row): return columnValues labelDescriptions = {} - labelsFilePath = self.modelPath(modelName).joinpath("labels.csv") import csv with open(labelsFilePath, "r") as f: reader = csv.reader(f) columnNames = next(reader) - data = {} # Loop through the rows of the csv file for row in reader: - # Determine segmentation category (DICOM or MONAIAuto3DSeg) terminologyPropertyTypeStr = ( # Example: SCT^23451007 - row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")] - + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")]) - if terminologyPropertyTypeStr in self.MONAIAuto3DSegTerminologyPropertyTypes: - terminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName - else: + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")] + + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")]) + terminologyName = None + + # If property the code is found in this list then the terminology will be used, + for tName in self.getLoadedTerminologyNames(): + propertyTypes = self._terminologyPropertyTypes(tName) + if terminologyPropertyTypeStr in propertyTypes: + terminologyName = tName + break + + # NB: DICOM terminology will be used otherwise. Note: the DICOM terminology does not contain all the + # necessary items and some items are incomplete (e.g., don't have color or 3D Slicer label). + if not terminologyName: terminologyName = "Segmentation category and type - DICOM master list" # Determine the anatomic context name (DICOM or MONAIAuto3DSeg) anatomicRegionStr = ( # Example: SCT^279245009 - row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")] - + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")]) - if anatomicRegionStr in self.MONAIAuto3DSegAnatomicRegions: - anatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName - else: + row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")] + + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")]) + anatomicContextName = None + for aName in self.getLoadedAnatomicContextNames(): + if anatomicRegionStr in self._anatomicRegions(aName): + anatomicContextName = aName + if not anatomicContextName: anatomicContextName = "Anatomic codes - DICOM master list" terminologyEntryStr = ( - terminologyName - +"~" - # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure" - + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row)) - + "~" - # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ... - + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row)) - + "~" - # Property type modifier: "SCT^7771000^Left", ... - + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row)) - + "~" - + anatomicContextName - + "~" - # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ... - + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row)) - + "~" - # Anatomic region modifier: "SCT^7771000^Left", ... - + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row)) - ) + terminologyName + + "~" + # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure" + + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row)) + + "~" + # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ... + + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row)) + + "~" + # Property type modifier: "SCT^7771000^Left", ... + + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row)) + + "~" + + anatomicContextName + + "~" + # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ... + + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row)) + + "~" + # Anatomic region modifier: "SCT^7771000^Left", ... + + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row)) + ) # Store the terminology string for this structure labelValue = int(row[columnNames.index("LabelValue")]) name = row[columnNames.index("Name")] - labelDescriptions[labelValue] = { "name": name, "terminology": terminologyEntryStr } - + labelDescriptions[labelValue] = {"name": name, "terminology": terminologyEntryStr} return labelDescriptions - def getSegmentLabelColor(self, terminologyEntryStr): - """Get segment label and color from terminology""" - - def labelColorFromTypeObject(typeObject): - """typeObject is a terminology type or type modifier""" - label = typeObject.GetSlicerLabel() if typeObject.GetSlicerLabel() else typeObject.GetCodeMeaning() - rgb = typeObject.GetRecommendedDisplayRGBValue() - return label, (rgb[0]/255.0, rgb[1]/255.0, rgb[2]/255.0) - - tlogic = slicer.modules.terminologies.logic() - - terminologyEntry = slicer.vtkSlicerTerminologyEntry() - if not tlogic.DeserializeTerminologyEntry(terminologyEntryStr, terminologyEntry): - raise RuntimeError(f"Failed to deserialize terminology string: {terminologyEntryStr}") - - numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject()) - foundTerminologyEntry = slicer.vtkSlicerTerminologyEntry() - for typeIndex in range(numberOfTypes): - tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), typeIndex, foundTerminologyEntry.GetTypeObject()) - if terminologyEntry.GetTypeObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeObject().GetCodingSchemeDesignator(): - continue - if terminologyEntry.GetTypeObject().GetCodeValue() != foundTerminologyEntry.GetTypeObject().GetCodeValue(): - continue - if terminologyEntry.GetTypeModifierObject() and terminologyEntry.GetTypeModifierObject().GetCodeValue(): - # Type has a modifier, get the color from there - numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject()) - foundMatchingModifier = False - for modifierIndex in range(numberOfModifiers): - tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject(), - modifierIndex, foundTerminologyEntry.GetTypeModifierObject()) - if terminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator(): - continue - if terminologyEntry.GetTypeModifierObject().GetCodeValue() != foundTerminologyEntry.GetTypeModifierObject().GetCodeValue(): - continue - return labelColorFromTypeObject(foundTerminologyEntry.GetTypeModifierObject()) - continue - return labelColorFromTypeObject(foundTerminologyEntry.GetTypeObject()) - - raise RuntimeError(f"Color was not found for terminology {terminologyEntryStr}") - - @staticmethod - def _findFirstNodeBynamePattern(namePattern, nodes): - import fnmatch - for node in nodes: - if fnmatch.fnmatchcase(node.GetName(), namePattern): - return node - return None - - @staticmethod - def assignInputNodesByName(inputs, loadedSampleNodes): - inputNodes = [] - for inputIndex, input in enumerate(inputs): - namePattern = input.get("namePattern") - if namePattern: - matchingNode = MONAIAuto3DSegLogic._findFirstNodeBynamePattern(namePattern, loadedSampleNodes) - else: - matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else loadedSampleNodes[0] - inputNodes.append(matchingNode) - return inputNodes - - def log(self, text): - logging.info(text) - if self.logCallback: - self.logCallback(text) - - def installedMONAIPythonPackageInfo(self): - import shutil - import subprocess - versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode() - return versionInfo - - def setupPythonRequirements(self, upgrade=False): - import importlib.metadata - import importlib.util - import packaging - - # Install PyTorch - try: - import PyTorchUtils - except ModuleNotFoundError as e: - raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.") - - self.log("Initializing PyTorch...") - minimumTorchVersion = "1.12" - torchLogic = PyTorchUtils.PyTorchUtilsLogic() - if not torchLogic.torchInstalled(): - self.log("PyTorch Python package is required. Installing... (it may take several minutes)") - torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement = f">={minimumTorchVersion}") - if torch is None: - raise ValueError("PyTorch extension needs to be installed to use this module.") - else: - # torch is installed, check version - from packaging import version - if version.parse(torchLogic.torch.__version__) < version.parse(minimumTorchVersion): - raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module." - + f" Minimum required version is {minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" - + f" with version requirement set to: >={minimumTorchVersion}") - - # Install MONAI with required components - self.log("Initializing MONAI...") - # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too). - # Without this, for some users monai-0.9.0 got installed, which failed with this error: - # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'" - monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" - if upgrade: - monaiInstallString += " --upgrade" - slicer.util.pip_install(monaiInstallString) - - self.dependenciesInstalled = True - self.log("Dependencies are set up successfully.") - - def setDefaultParameters(self, parameterNode): """ Initialize parameter node with default settings. @@ -1092,6 +1028,8 @@ def setDefaultParameters(self, parameterNode): parameterNode.SetParameter("Model", self.defaultModel) if not parameterNode.GetParameter("UseStandardSegmentNames"): parameterNode.SetParameter("UseStandardSegmentNames", "true") + if not parameterNode.GetParameter("ServerPort"): + parameterNode.SetParameter("ServerPort", str(8891)) def logProcessOutputUntilCompleted(self, segmentationProcessInfo): # Wait for the process to end and forward output to the log @@ -1102,7 +1040,7 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo): line = proc.stdout.readline() if not line: break - self.log(line.rstrip()) + logger.info(line.rstrip()) except UnicodeDecodeError as e: # Code page conversion happens because `universal_newlines=True` sets process output to text mode, # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string, @@ -1115,7 +1053,6 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo): raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr) def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitForCompletion=True, customData=None): - """ Run the processing algorithm. Can be used without GUI widget. @@ -1127,26 +1064,26 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor :param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False """ + if not self.DEPENDENCY_HANDLER.dependenciesInstalled: + with slicer.util.tryWithErrorDisplay("Failed to install required dependencies.", waitCursor=True): + self.DEPENDENCY_HANDLER.setupPythonRequirements() + if not inputNodes: raise ValueError("Input nodes are invalid") if not outputSegmentation: raise ValueError("Output segmentation is invalid") - if model == None: + if model is None: model = self.defaultModel - try: - modelPath = self.modelPath(model) - except: - self.downloadModel(model) - modelPath = self.modelPath(model) + modelPath = self.modelPath(model) segmentationProcessInfo = {} import time startTime = time.time() - self.log("Processing started") + logger.info("Processing started") if self.debugSkipInference: # For debugging, use a fixed temporary folder @@ -1155,9 +1092,6 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor # Create new empty folder tempDir = slicer.util.tempDirectory() - import pathlib - tempDirPath = pathlib.Path(tempDir) - # Get Python executable path import shutil pythonSlicerExecutablePath = shutil.which("PythonSlicer") @@ -1169,7 +1103,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor for inputIndex, inputNode in enumerate(inputNodes): if inputNode.IsA('vtkMRMLScalarVolumeNode'): inputImageFile = tempDir + f"/input-volume{inputIndex}.nrrd" - self.log(f"Writing input file to {inputImageFile}") + logger.info(f"Writing input file to {inputImageFile}") volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode") volumeStorageNode.SetFileName(inputImageFile) volumeStorageNode.UseCompressionOff() @@ -1190,13 +1124,13 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor auto3DSegCommand.append(f"--image-file-{inputIndex+1}") auto3DSegCommand.append(inputFiles[inputIndex]) - self.log("Creating segmentations with MONAIAuto3DSeg AI...") - self.log(f"Auto3DSeg command: {auto3DSegCommand}") + logger.info("Creating segmentations with MONAIAuto3DSeg AI...") + logger.info(f"Auto3DSeg command: {auto3DSegCommand}") additionalEnvironmentVariables = None if cpu: additionalEnvironmentVariables = {"CUDA_VISIBLE_DEVICES": "-1"} - self.log(f"Additional environment variables: {additionalEnvironmentVariables}") + logger.info(f"Additional environment variables: {additionalEnvironmentVariables}") if self.debugSkipInference: proc = None @@ -1230,7 +1164,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor return segmentationProcessInfo def cancelProcessing(self, segmentationProcessInfo): - self.log("Cancel is requested.") + logger.info("Cancel is requested.") segmentationProcessInfo["cancelRequested"] = True proc = segmentationProcessInfo.get("proc") if proc: @@ -1248,7 +1182,6 @@ def cancelProcessing(self, segmentationProcessInfo): def _handleProcessOutputThreadProcess(segmentationProcessInfo): # Wait for the process to end and forward output to the log proc = segmentationProcessInfo["proc"] - from subprocess import CalledProcessError while True: try: line = proc.stdout.readline() @@ -1264,10 +1197,8 @@ def _handleProcessOutputThreadProcess(segmentationProcessInfo): retcode = proc.returncode # non-zero return code means error segmentationProcessInfo["procReturnCode"] = retcode - def startSegmentationProcessMonitoring(self, segmentationProcessInfo): import queue - import sys import threading segmentationProcessInfo["procOutputQueue"] = queue.Queue() @@ -1276,9 +1207,7 @@ def startSegmentationProcessMonitoring(self, segmentationProcessInfo): self.checkSegmentationProcessOutput(segmentationProcessInfo) - def checkSegmentationProcessOutput(self, segmentationProcessInfo): - import queue outputQueue = segmentationProcessInfo["procOutputQueue"] while outputQueue: @@ -1287,17 +1216,14 @@ def checkSegmentationProcessOutput(self, segmentationProcessInfo): return try: line = outputQueue.get_nowait() - self.log(line) + logger.info(line) except queue.Empty: break # No more outputs to process now, check again later - import qt qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda segmentationProcessInfo=segmentationProcessInfo: self.checkSegmentationProcessOutput(segmentationProcessInfo)) - def onSegmentationProcessCompleted(self, segmentationProcessInfo): - startTime = segmentationProcessInfo["startTime"] tempDir = segmentationProcessInfo["tempDir"] inputNodes = segmentationProcessInfo["inputNodes"] @@ -1310,17 +1236,14 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): if cancelRequested: procReturnCode = MONAIAuto3DSegLogic.EXIT_CODE_USER_CANCELLED - self.log(f"Processing was cancelled.") + logger.info(f"Processing was cancelled.") else: if procReturnCode == 0: - if self.startResultImportCallback: self.startResultImportCallback(customData) - try: - - # Load result - self.log("Importing segmentation results...") + try: # Load result + logger.info("Importing segmentation results...") self.readSegmentation(outputSegmentation, outputSegmentationFile, model) # Set source volume - required for DICOM Segmentation export @@ -1338,20 +1261,18 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): shNode.SetItemParent(segmentationShItem, studyShItem) finally: - if self.endResultImportCallback: self.endResultImportCallback(customData) - else: - self.log(f"Processing failed with return code {procReturnCode}") + logger.info(f"Processing failed with return code {procReturnCode}") if self.clearOutputFolder: - self.log("Cleaning up temporary folder.") + logger.info("Cleaning up temporary folder.") if os.path.isdir(tempDir): import shutil shutil.rmtree(tempDir) else: - self.log(f"Not cleaning up temporary folder: {tempDir}") + logger.info(f"Not cleaning up temporary folder: {tempDir}") # Report total elapsed time import time @@ -1359,19 +1280,17 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): segmentationProcessInfo["stopTime"] = stopTime elapsedTime = stopTime - startTime if cancelRequested: - self.log(f"Processing was cancelled after {elapsedTime:.2f} seconds.") + logger.info(f"Processing was cancelled after {elapsedTime:.2f} seconds.") else: if procReturnCode == 0: - self.log(f"Processing was completed in {elapsedTime:.2f} seconds.") + logger.info(f"Processing was completed in {elapsedTime:.2f} seconds.") else: - self.log(f"Processing failed after {elapsedTime:.2f} seconds.") + logger.info(f"Processing failed after {elapsedTime:.2f} seconds.") if self.processingCompletedCallback: self.processingCompletedCallback(procReturnCode, customData) - def readSegmentation(self, outputSegmentation, outputSegmentationFile, model): - labelValueToDescription = self.labelDescriptions(model) # Get label descriptions @@ -1405,12 +1324,11 @@ def readSegmentation(self, outputSegmentation, outputSegmentationFile, model): # Set terminology and color for labelValue in labelValueToDescription: - segmentName = labelValueToDescription[labelValue]["name"] terminologyEntryStr = labelValueToDescription[labelValue]["terminology"] - segmentId = segmentName - self.setTerminology(outputSegmentation, segmentName, segmentId, terminologyEntryStr) + segmentId = labelValueToDescription[labelValue]["name"] + self.setTerminology(outputSegmentation, segmentId, terminologyEntryStr) - def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryStr): + def setTerminology(self, segmentation, segmentId, terminologyEntryStr): segment = segmentation.GetSegmentation().GetSegment(segmentId) if not segment: # Segment is not present in this segmentation @@ -1423,31 +1341,137 @@ def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryS segment.SetName(label) segment.SetColor(color) except RuntimeError as e: - self.log(str(e)) + logger.info(str(e)) - def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath): - import json - modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath() +class RemoteMONAIAuto3DSegLogic(MONAIAuto3DSegLogic): - with open(modelsTestResultsJsonFilePath) as f: - modelsTestResults = json.load(f) + DEPENDENCY_HANDLER = RemotePythonDependencies() - with open(modelsDescriptionJsonFilePath) as f: - modelsDescription = json.load(f) - - for model in modelsDescription["models"]: - title = model["title"] - for modelTestResult in modelsTestResults: - if modelTestResult["title"] == title: - for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]: - fieldValue = modelTestResult.get(fieldName) - if fieldValue: - model[fieldName] = fieldValue - break + @property + def server_address(self): + return self._server_address + + @server_address.setter + def server_address(self, address): + self._server_address = address + self._models = [] + + def __init__(self): + self._server_address = None + MONAIAuto3DSegLogic.__init__(self) + self._models = [] + + def getMONAIPythonPackageInfo(self): + return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo(self._server_address) - with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f: - json.dump(modelsDescription, f, indent=2) + def setupPythonRequirements(self, upgrade=False): + self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade) + return False + + def loadModelsDescription(self): + if not self._server_address: + return [] + else: + response = requests.get(self._server_address + "/models") + json_data = json.loads(response.text) + return json_data + + def labelDescriptions(self, modelId): + """Return mapping from label value to label description. + Label description is a dict containing "name" and "terminology". + Terminology string uses Slicer terminology entry format - see specification at + https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag + """ + if not self._server_address: + return [] + else: + import tempfile + with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile: + with requests.get(self._server_address + f"/labelDescriptions?id={modelId}", stream=True) as r: + r.raise_for_status() + + with open(tmpfile.name, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + return self._labelDescriptions(tmpfile.name) + + def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitForCompletion=True, customData=None): + """ + Run the processing algorithm. + Can be used without GUI widget. + :param inputNodes: input nodes in a list + :param outputVolume: thresholding result + :param modelId: one of self.models + :param cpu: use CPU instead of GPU + :param waitForCompletion: if True then the method waits for the processing to finish + :param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False + """ + + import time + startTime = time.time() + logger.info("Processing started") + + tempDir = slicer.util.tempDirectory() + outputSegmentationFile = tempDir + "/output-segmentation.nrrd" + + from tempfile import TemporaryDirectory + with TemporaryDirectory(dir=tempDir) as temp_dir: + # Write input volume to file + from pathlib import Path + tempDir = Path(temp_dir) + inputFiles = [] + for inputIndex, inputNode in enumerate(inputNodes): + if inputNode.IsA('vtkMRMLScalarVolumeNode'): + inputImageFile = tempDir / f"input-volume{inputIndex}.nrrd" + logger.info(f"Writing input file to {inputImageFile}") + volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode") + volumeStorageNode.SetFileName(inputImageFile) + volumeStorageNode.UseCompressionOff() + volumeStorageNode.WriteData(inputNode) + slicer.mrmlScene.RemoveNode(volumeStorageNode) + inputFiles.append(inputImageFile) + else: + raise ValueError(f"Input node type {inputNode.GetClassName()} is not supported") + + segmentationProcessInfo = {} + segmentationProcessInfo["procReturnCode"] = MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN + + logger.info(f"Initiating Inference on {self._server_address}") + files = {} + + try: + for idx, inputFile in enumerate(inputFiles, start=1): + name = "image_file" + if idx > 1: + name = f"{name}_{idx}" + files[name] = open(inputFile, 'rb') + + with requests.post(self._server_address + f"/infer?model_name={modelId}", files=files) as r: + r.raise_for_status() + + with open(outputSegmentationFile, "wb") as binary_file: + for chunk in r.iter_content(chunk_size=8192): + binary_file.write(chunk) + + segmentationProcessInfo["procReturnCode"] = 0 + finally: + for f in files.values(): + f.close() + + segmentationProcessInfo["cancelRequested"] = False + segmentationProcessInfo["startTime"] = startTime + segmentationProcessInfo["tempDir"] = tempDir + segmentationProcessInfo["inputNodes"] = inputNodes + segmentationProcessInfo["outputSegmentation"] = outputSegmentation + segmentationProcessInfo["outputSegmentationFile"] = outputSegmentationFile + segmentationProcessInfo["model"] = modelId + segmentationProcessInfo["customData"] = customData + + self.onSegmentationProcessCompleted(segmentationProcessInfo) + + return segmentationProcessInfo # # MONAIAuto3DSegTest @@ -1556,10 +1580,8 @@ def test_MONAIAuto3DSeg1(self): raise RuntimeError(f"Failed to load sample data set '{sampleDataName}'.") # Set model inputs - - inputNodes = [] inputs = model.get("inputs") - inputNodes = MONAIAuto3DSegLogic.assignInputNodesByName(inputs, loadedSampleNodes) + inputNodes = assignInputNodesByName(inputs, loadedSampleNodes) outputSegmentation = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") @@ -1684,14 +1706,12 @@ def _writeScreenshots(self, segmentationNode, outputPath, baseName, numberOfImag return sliceScreenshotFilename, rotate3dScreenshotFilename def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestResultsMarkdownFilePath=None, screenshotUrlBase=None): - if modelsTestResultsMarkdownFilePath is None: modelsTestResultsMarkdownFilePath = modelsTestResultsJsonFilePath.replace(".json", ".md") if screenshotUrlBase is None: screenshotUrlBase = "https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/ModelsTestResults/" import json - from MONAIAuto3DSeg import MONAIAuto3DSegLogic with open(modelsTestResultsJsonFilePath) as f: modelsTestResults = json.load(f) @@ -1718,7 +1738,7 @@ def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestR title = f"{model['title']} (v{model['version']})" f.write(f"## {title}\n") f.write(f"{model['description']}\n\n") - f.write(f"Processing time: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n") + f.write(f"Processing time: {humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n") f.write(f"Segment names: {', '.join(model['segmentNames'])}\n\n") f.write(f"![2D view]({screenshotUrlBase}{model['segmentationResultsScreenshot2D']})\n") f.write(f"![3D view]({screenshotUrlBase}{model['segmentationResultsScreenshot3D']})\n") \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py new file mode 100644 index 0000000..63c5e7d --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py @@ -0,0 +1 @@ +APPLICATION_NAME = "MONAIAuto3DSeg" diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py new file mode 100644 index 0000000..438f120 --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py @@ -0,0 +1,130 @@ +import sys + +import shutil +import subprocess +import logging + +from MONAIAuto3DSegLib.constants import APPLICATION_NAME +logger = logging.getLogger(APPLICATION_NAME) + + +from abc import ABC, abstractmethod + + +class DependenciesBase(ABC): + + minimumTorchVersion = "1.12" + + def __init__(self): + self.dependenciesInstalled = False # we don't know yet if dependencies have been installed + + @abstractmethod + def installedMONAIPythonPackageInfo(self): + pass + + @abstractmethod + def setupPythonRequirements(self, upgrade=False): + pass + + +class LocalPythonDependencies(DependenciesBase): + + def installedMONAIPythonPackageInfo(self): + versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode() + return versionInfo + + def _checkModuleInstalled(self, moduleName): + try: + import importlib + importlib.import_module(moduleName) + return True + except ModuleNotFoundError: + return False + + def setupPythonRequirements(self, upgrade=False): + def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + logger.info("Initializing PyTorch...") + + packageName = "torch" + if not self._checkModuleInstalled(packageName): + logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") + install(packageName) + if not self._checkModuleInstalled(packageName): + raise ValueError("pytorch needs to be installed to use this module.") + else: # torch is installed, check version + from packaging import version + import torch + if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion): + raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module." + + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" + + f" with version requirement set to: >={self.minimumTorchVersion}") + + logger.info("Initializing MONAI...") + monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" + if upgrade: + monaiInstallString += " --upgrade" + install(monaiInstallString) + + self.dependenciesInstalled = True + logger.info("Dependencies are set up successfully.") + + +class RemotePythonDependencies(DependenciesBase): + + def installedMONAIPythonPackageInfo(self, server_address): + if not server_address: + return [] + else: + import json + import requests + response = requests.get(server_address + "/monaiinfo") + json_data = json.loads(response.text) + return json_data + + def setupPythonRequirements(self, upgrade=False): + logger.error("No permission to update remote python packages. Please contact developer.") + + +class SlicerPythonDependencies(DependenciesBase): + + def installedMONAIPythonPackageInfo(self): + versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode() + return versionInfo + + def setupPythonRequirements(self, upgrade=False): + # Install PyTorch + try: + import PyTorchUtils + except ModuleNotFoundError as e: + raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.") + + logger.info("Initializing PyTorch...") + + torchLogic = PyTorchUtils.PyTorchUtilsLogic() + if not torchLogic.torchInstalled(): + logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") + torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement=f">={self.minimumTorchVersion}") + if torch is None: + raise ValueError("PyTorch extension needs to be installed to use this module.") + else: # torch is installed, check version + from packaging import version + if version.parse(torchLogic.torch.__version__) < version.parse(self.minimumTorchVersion): + raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module." + + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" + + f" with version requirement set to: >={self.minimumTorchVersion}") + + # Install MONAI with required components + logger.info("Initializing MONAI...") + # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too). + # Without this, for some users monai-0.9.0 got installed, which failed with this error: + # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'" + monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" + if upgrade: + monaiInstallString += " --upgrade" + import slicer + slicer.util.pip_install(monaiInstallString) + + self.dependenciesInstalled = True + logger.info("Dependencies are set up successfully.") \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py new file mode 100644 index 0000000..4e3540f --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py @@ -0,0 +1,27 @@ +import logging +from typing import Callable + + +class LogHandler(logging.Handler): + """ + + code: + + logger = logging.getLogger("XYZ") + + callback = ... # any callable + # NB: only catching info level messages and forwarding it to callback + handler = LogHandler(callback, logging.INFO) + # can format log messages + formatter = logging.Formatter('%(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + """ + def __init__(self, callback: Callable, level=logging.NOTSET): + self._callback = callback + super().__init__(level) + + def emit(self, record): + msg = self.format(record) + self._callback(msg) \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py new file mode 100644 index 0000000..d9ef612 --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py @@ -0,0 +1,196 @@ +import json +import logging +import os +from pathlib import Path + +from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec +from MONAIAuto3DSegLib.constants import APPLICATION_NAME + + +logger = logging.getLogger(APPLICATION_NAME) + + +class ModelDatabase: + + @property + def defaultModel(self): + return self.models[0]["id"] + + @property + def models(self): + if not self._models: + self._models = self.loadModelsDescription() + return self._models + + @property + def modelsPath(self): + modelsPath = self.fileCachePath.joinpath("models") + modelsPath.mkdir(exist_ok=True, parents=True) + return modelsPath + + @property + def modelsDescriptionJsonFilePath(self): + return os.path.join(self.moduleDir, "Resources", "Models.json") + + def __init__(self): + self.fileCachePath = Path.home().joinpath(f".{APPLICATION_NAME}") + self.moduleDir = Path(__file__).parent.parent + + # Disabling this flag preserves input and output data after execution is completed, + # which can be useful for troubleshooting. + self.clearOutputFolder = True + self._models = [] + + def model(self, modelId): + for model in self.models: + if model["id"] == modelId: + return model + raise RuntimeError(f"Model {modelId} not found") + + def loadModelsDescription(self): + modelsJsonFilePath = self.modelsDescriptionJsonFilePath + try: + models = [] + import json + import re + with open(modelsJsonFilePath) as f: + modelsTree = json.load(f)["models"] + for model in modelsTree: + deprecated = False + for version in model["versions"]: + url = version["url"] + # URL format: /-v.zip + # Example URL: https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/Models/17-segments-TotalSegmentator-v1.0.3.zip + match = re.search(r"(?P[^/]+)-v(?P\d+\.\d+\.\d+)", url) + if match: + filename = match.group("filename") + version = match.group("version") + else: + logger.error(f"Failed to extract model id and version from url: {url}") + if "inputs" in model: + # Contains a list of dict. One dict for each input. + # Currently, only "title" (user-displayable name) and "namePattern" of the input are specified. + # In the future, inputs could have additional properties, such as name, type, optional, ... + inputs = model["inputs"] + else: + # Inputs are not defined, use default (single input volume) + inputs = [{"title": "Input volume"}] + segmentNames = model.get('segmentNames') + if not segmentNames: + segmentNames = "N/A" + models.append({ + "id": f"{filename}-v{version}", + "title": model['title'], + "version": version, + "inputs": inputs, + "imagingModality": model["imagingModality"], + "description": model["description"], + "sampleData": model.get("sampleData"), + "segmentNames": model.get("segmentNames"), + "details": + f"

Model: {model['title']} (v{version})" + f"

Description: {model['description']}\n" + f"

Computation time on GPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n" + f"
Computation time on CPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n" + f"

Imaging modality: {model['imagingModality']}\n" + f"

Subject: {model['subject']}\n" + f"

Segments: {', '.join(segmentNames)}", + "url": url, + "deprecated": deprecated + }) + # First version is not deprecated, all subsequent versions are deprecated + deprecated = True + return models + except Exception as e: + import traceback + traceback.print_exc() + raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}") + + def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath): + modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath + + with open(modelsTestResultsJsonFilePath) as f: + modelsTestResults = json.load(f) + + with open(modelsDescriptionJsonFilePath) as f: + modelsDescription = json.load(f) + + for model in modelsDescription["models"]: + title = model["title"] + for modelTestResult in modelsTestResults: + if modelTestResult["title"] == title: + for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]: + fieldValue = modelTestResult.get(fieldName) + if fieldValue: + model[fieldName] = fieldValue + break + + with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f: + json.dump(modelsDescription, f, indent=2) + + def createModelsDir(self): + modelsDir = self.modelsPath + if not os.path.exists(modelsDir): + os.makedirs(modelsDir) + + def modelPath(self, modelName): + try: + return self._modelPath(modelName) + except RuntimeError: + self.downloadModel(modelName) + return self._modelPath(modelName) + + def _modelPath(self, modelName): + modelRoot = self.modelsPath.joinpath(modelName) + # find labels.csv file within the modelRoot folder and subfolders + for path in Path(modelRoot).rglob("labels.csv"): + return path.parent + raise RuntimeError(f"Model {modelName} path not found") + + def deleteAllModels(self): + if self.modelsPath.exists(): + import shutil + shutil.rmtree(self.modelsPath) + + def downloadModel(self, modelName): + url = self.model(modelName)["url"] + import zipfile + import requests + from tempfile import TemporaryDirectory + with TemporaryDirectory() as td: + tempDir = Path(td) + modelDir = self.modelsPath.joinpath(modelName) + Path(modelDir).mkdir(exist_ok=True) + modelZipFile = tempDir.joinpath("autoseg3d_model.zip") + logger.info(f"Downloading model '{modelName}' from {url}...") + logger.debug(f"Downloading from {url} to {modelZipFile}...") + try: + with open(modelZipFile, 'wb') as f: + with requests.get(url, stream=True) as r: + r.raise_for_status() + total_size = int(r.headers.get('content-length', 0)) + reporting_increment_percent = 1.0 + last_reported_download_percent = -reporting_increment_percent + downloaded_size = 0 + for chunk in r.iter_content(chunk_size=8192 * 16): + f.write(chunk) + downloaded_size += len(chunk) + downloaded_percent = 100.0 * downloaded_size / total_size + if downloaded_percent - last_reported_download_percent > reporting_increment_percent: + logger.info( + f"Downloading model: {downloaded_size / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB ({downloaded_percent:.1f}%)") + last_reported_download_percent = downloaded_percent + + logger.info(f"Download finished. Extracting to {modelDir}...") + with zipfile.ZipFile(modelZipFile, 'r') as zip_f: + zip_f.extractall(modelDir) + except Exception as e: + raise e + finally: + if self.clearOutputFolder: + logger.info("Cleaning up temporary model download folder...") + if os.path.isdir(tempDir): + import shutil + shutil.rmtree(tempDir) + else: + logger.info(f"Not cleaning up temporary model download folder: {tempDir}") diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py new file mode 100644 index 0000000..ec5bdfc --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py @@ -0,0 +1,95 @@ +import slicer +import psutil + +import logging +import queue +import threading + +import qt +from pathlib import Path +from typing import Callable + +from MONAIAuto3DSegLib.constants import APPLICATION_NAME +logger = logging.getLogger(APPLICATION_NAME) + + +class WebServer: + + CHECK_TIMER_INTERVAL = 1000 + + @staticmethod + def getPSProcess(pid): + try: + return psutil.Process(pid) + except psutil.NoSuchProcess: + return None + + def __init__(self, completedCallback: Callable): + self.completedCallback = completedCallback + self.procThread = None + self.serverProc = None + self.queue = None + + def isRunning(self): + if self.serverProc is not None: + psProcess = self.getPSProcess(self.serverProc.pid) + if psProcess: + return psProcess.is_running() + return False + + def __del__(self): + self.killProcess() + + def killProcess(self): + if not self.serverProc: + return + psProcess = self.getPSProcess(self.serverProc.pid) + if not psProcess: + return + for psChildProcess in psProcess.children(recursive=True): + psChildProcess.kill() + if psProcess.is_running(): + psProcess.kill() + + def launchConsoleProcess(self, cmd): + self.serverProc = \ + slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "auto3dseg", useStartupEnvironment=False) + + self.queue = queue.Queue() + self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess) + self.procThread.start() + self.checkProcessOutput() + + def cleanup(self): + if self.procThread: + self.procThread.join() + self.completedCallback() + self.serverProc = None + self.procThread = None + self.queue = None + + def _handleProcessOutputThreadProcess(self): + while True: + try: + line = self.serverProc.stdout.readline() + if not line: + break + self.queue.put(line.rstrip()) + except UnicodeDecodeError as e: + pass + self.serverProc.wait() + + def checkProcessOutput(self): + outputQueue = self.queue + while outputQueue: + try: + line = outputQueue.get_nowait() + logger.info(line) + except queue.Empty: + break + + psProcess = self.getPSProcess(self.serverProc.pid) + if psProcess and psProcess.is_running(): # No more outputs to process now, check again later + qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput) + else: + self.cleanup() \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py new file mode 100644 index 0000000..186c9cc --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py @@ -0,0 +1,35 @@ + + +def humanReadableTimeFromSec(seconds): + import math + if not seconds: + return "N/A" + if seconds < 55: + # if less than a minute, round up to the nearest 5 seconds + return f"{math.ceil(seconds / 5) * 5} sec" + elif seconds < 60 * 60: + # if less than 1 hour, round up to the nearest minute + return f"{math.ceil(seconds / 60)} min" + # Otherwise round up to the nearest 0.1 hour + return f"{seconds / 3600:.1f} h" + + +def assignInputNodesByName(inputs, loadedSampleNodes): + inputNodes = [] + for inputIndex, input in enumerate(inputs): + namePattern = input.get("namePattern") + if namePattern: + matchingNode = findFirstNodeByNamePattern(namePattern, loadedSampleNodes) + else: + matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else \ + loadedSampleNodes[0] + inputNodes.append(matchingNode) + return inputNodes + + +def findFirstNodeByNamePattern(namePattern, nodes): + import fnmatch + for node in nodes: + if fnmatch.fnmatchcase(node.GetName(), namePattern): + return node + return None \ No newline at end of file diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index d57b851..1c4fd3c 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -6,8 +6,8 @@ 0 0 - 402 - 653 + 408 + 674 @@ -17,7 +17,7 @@ Inputs - + @@ -39,14 +39,14 @@ - + Input volume 1: - + @@ -67,14 +67,14 @@ - + Input volume 2: - + @@ -95,14 +95,14 @@ - + Input volume 3: - + @@ -123,14 +123,14 @@ - + Input volume 4: - + @@ -151,7 +151,7 @@ - + Download sample data set for the current segmentation model @@ -161,7 +161,7 @@ - + @@ -182,13 +182,47 @@ - + Segmentation model: + + + + + + + 0 + 0 + + + + true + + + + + + + QPushButton:checked { + /* Checked style */ + background-color: green; /* Green background when checked */ + border-color: darkgreen; /* Darker border color when checked */ +} + + + Connect + + + true + + + + + @@ -401,18 +435,106 @@ - - - QPlainTextEdit::NoWrap - - - true + + + Server Address: - - Qt::TextSelectableByMouse + + false + + + + + Port: + + + + + + + 0 + + + 65535 + + + 1 + + + + + + + Log to Console: + + + + + + + + + + + + + + Log to GUI: + + + + + + + + + + + + + + + + + + + + + QPushButton:checked { + /* Checked style */ + background-color: green; /* Green background when checked */ + border-color: darkgreen; /* Darker border color when checked */ +} + + + Start server + + + true + + + + + + + true + + + + + + + Server address: + + + + + + + @@ -437,6 +559,12 @@ QLineEdit

ctkSearchBox.h
+ + qMRMLCollapsibleButton + ctkCollapsibleButton +
qMRMLCollapsibleButton.h
+ 1 +
qMRMLNodeComboBox QWidget @@ -536,5 +664,21 @@ + + remoteServerButton + toggled(bool) + serverCollapsibleButton + setDisabled(bool) + + + 356 + 49 + + + 203 + 533 + + + diff --git a/MONAIAuto3DSeg/auto3dseg/__init__.py b/MONAIAuto3DSeg/auto3dseg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MONAIAuto3DSeg/auto3dseg/main.py b/MONAIAuto3DSeg/auto3dseg/main.py new file mode 100644 index 0000000..9be35ab --- /dev/null +++ b/MONAIAuto3DSeg/auto3dseg/main.py @@ -0,0 +1,139 @@ +# pip install python-multipart +# pip install fastapi +# pip install "uvicorn[standard]" + +# usage: uvicorn main:app --reload --host reslnjolleyws03.research.chop.edu --port 8891 +# usage: uvicorn main:app --reload --host localhost --port 8891 + + +import os +import logging +import sys +from pathlib import Path + +paths = [str(Path(__file__).parent.parent)] +for path in paths: + if not path in sys.path: + sys.path.insert(0, path) + +from MONAIAuto3DSegLib.model_database import ModelDatabase +from MONAIAuto3DSegLib.dependency_handler import LocalPythonDependencies +from MONAIAuto3DSegLib.constants import APPLICATION_NAME + +import shutil +import asyncio +import subprocess +from fastapi import FastAPI, UploadFile +from fastapi.responses import FileResponse +from fastapi import HTTPException +from fastapi.background import BackgroundTasks + + +app = FastAPI() +modelDB = ModelDatabase() +dependencyHandler = LocalPythonDependencies() + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(APPLICATION_NAME) + + +def upload(file, session_dir, identifier): + extension = "".join(Path(file.filename).suffixes) + file_location = f"{session_dir}/{identifier}{extension}" + with open(file_location, "wb+") as file_object: + file_object.write(file.file.read()) + return file_location + + +@app.get("/monaiinfo") +def monaiInfo(): + return dependencyHandler.installedMONAIPythonPackageInfo() + + +@app.get("/models") +async def models(): + return modelDB.models + + +@app.get("/modelinfo") +async def getModelInfo(id: str): + return modelDB.model(id) + + +@app.get("/labelDescriptions") +def getLabelsFile(id: str): + return FileResponse(modelDB.modelPath(id).joinpath("labels.csv"), media_type = 'application/octet-stream', filename="labels.csv") + + +@app.post("/infer") +async def infer( + background_tasks: BackgroundTasks, + image_file: UploadFile, + model_name: str, + image_file_2: UploadFile = None, + image_file_3: UploadFile = None, + image_file_4: UploadFile = None +): + import tempfile + session_dir = tempfile.mkdtemp(dir=tempfile.gettempdir()) + background_tasks.add_task(shutil.rmtree, session_dir) + + logging.debug(session_dir) + inputFiles = list() + inputFiles.append(upload(image_file, session_dir, "image_file")) + if image_file_2: + inputFiles.append(upload(image_file_2, session_dir, "image_file_2")) + if image_file_3: + inputFiles.append(upload(image_file_3, session_dir, "image_file_3")) + if image_file_4: + inputFiles.append(upload(image_file_4, session_dir, "image_file_4")) + + # logging.info("Input Files: ", inputFiles) + + outputSegmentationFile = f"{session_dir}/output-segmentation.nrrd" + + modelPath = modelDB.modelPath(model_name) + modelPtFile = modelPath.joinpath("model.pt") + + assert os.path.exists(modelPtFile) + + dependencyHandler.setupPythonRequirements() + + moduleDir = Path(__file__).parent.parent + inferenceScriptPyFile = os.path.join(moduleDir, "Scripts", "auto3dseg_segresnet_inference.py") + auto3DSegCommand = [sys.executable, str(inferenceScriptPyFile), + "--model-file", str(modelPtFile), + "--image-file", inputFiles[0], + "--result-file", str(outputSegmentationFile)] + for inputIndex in range(1, len(inputFiles)): + auto3DSegCommand.append(f"--image-file-{inputIndex + 1}") + auto3DSegCommand.append(inputFiles[inputIndex]) + + try: + # logger.info(auto3DSegCommand) + proc = await asyncio.create_subprocess_shell(" ".join(auto3DSegCommand)) + await proc.wait() + if proc.returncode != 0: + raise subprocess.CalledProcessError(proc.returncode, " ".join(auto3DSegCommand)) + return FileResponse(outputSegmentationFile, media_type='application/octet-stream', background=background_tasks) + except Exception as e: + shutil.rmtree(session_dir) + raise HTTPException(status_code=500, detail=f"Failed to run CMD command: {str(e)}") + + +def main(argv): + import argparse + parser = argparse.ArgumentParser(description="MONAIAuto3DSeg server") + parser.add_argument("-ip", "--host", type=str, metavar="PATH", required=False, default="localhost", help="host name") + parser.add_argument("-p", "--port", type=int, metavar="PATH", required=True, help="port") + + args = parser.parse_args(argv) + + import uvicorn + + uvicorn.run("main:app", host=args.host, port=args.port, reload=True, log_level="debug") + + +if __name__ == "__main__": + main(sys.argv[1:])