From a088ebb324be7dc7f27e36050bf9fb54e4f21a17 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Fri, 5 Jul 2024 15:18:20 -0400 Subject: [PATCH 01/27] ENH: fastapi implementation for running inference remotely - model database class that handles all local models including their download and/or deletion - extracted functions into utils module - fastapi server can be run from Slicer directly or from commandline - make sure loaded terminologies are searched --- MONAIAuto3DSeg/CMakeLists.txt | 9 + MONAIAuto3DSeg/MONAIAuto3DSeg.py | 966 +++++++++--------- MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py | 0 MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py | 1 + .../MONAIAuto3DSegLib/dependency_handler.py | 130 +++ .../MONAIAuto3DSegLib/log_handler.py | 27 + .../MONAIAuto3DSegLib/model_database.py | 196 ++++ MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py | 95 ++ MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py | 35 + MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 188 +++- MONAIAuto3DSeg/auto3dseg/__init__.py | 0 MONAIAuto3DSeg/auto3dseg/main.py | 139 +++ 12 files changed, 1291 insertions(+), 495 deletions(-) create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py create mode 100644 MONAIAuto3DSeg/auto3dseg/__init__.py create mode 100644 MONAIAuto3DSeg/auto3dseg/main.py diff --git a/MONAIAuto3DSeg/CMakeLists.txt b/MONAIAuto3DSeg/CMakeLists.txt index 18fc468..3d8a9c9 100644 --- a/MONAIAuto3DSeg/CMakeLists.txt +++ b/MONAIAuto3DSeg/CMakeLists.txt @@ -4,6 +4,15 @@ set(MODULE_NAME MONAIAuto3DSeg) #----------------------------------------------------------------------------- set(MODULE_PYTHON_SCRIPTS ${MODULE_NAME}.py + ${MODULE_NAME}Lib/__init__.py + ${MODULE_NAME}Lib/constants.py + ${MODULE_NAME}Lib/dependency_handler.py + ${MODULE_NAME}Lib/log_handler.py + ${MODULE_NAME}Lib/model_database.py + ${MODULE_NAME}Lib/server.py + ${MODULE_NAME}Lib/utils.py + auto3dseg/__init__py + auto3dseg/main.py ) set(MODULE_PYTHON_RESOURCES diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 04eb64d..17533aa 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -1,12 +1,22 @@ import logging import os -import re +import json +import sys import vtk +import qt import slicer +import requests from slicer.ScriptedLoadableModule import * from slicer.util import VTKObservationMixin +from MONAIAuto3DSegLib.model_database import ModelDatabase +from MONAIAuto3DSegLib.constants import APPLICATION_NAME +from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec, assignInputNodesByName +from MONAIAuto3DSegLib.dependency_handler import SlicerPythonDependencies, RemotePythonDependencies + + +logger = logging.getLogger(APPLICATION_NAME) # @@ -24,7 +34,7 @@ def __init__(self, parent): self.parent.title = "MONAI Auto3DSeg" self.parent.categories = ["Segmentation"] self.parent.dependencies = [] - self.parent.contributors = ["Andras Lasso (PerkLab, Queen's University)", "Andres Diaz-Pinto (NVIDIA & KCL)", "Rudolf Bumm (KSGR Switzerland)"] + self.parent.contributors = ["Andras Lasso (PerkLab, Queen's University)", "Andres Diaz-Pinto (NVIDIA & KCL)", "Rudolf Bumm (KSGR Switzerland), Christian Herz (CHOP)"] self.parent.helpText = """ 3D Slicer extension for segmentation using MONAI Auto3DSeg AI model. See more information in the extension documentation. @@ -181,6 +191,18 @@ def __init__(self, parent=None): self._updatingGUIFromParameterNode = False self._processingState = MONAIAuto3DSegWidget.PROCESSING_IDLE self._segmentationProcessInfo = None + self._webServer = None + + from MONAIAuto3DSegLib.log_handler import LogHandler + handler = LogHandler(self.addLog, logging.INFO) + formatter = logging.Formatter('%(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + def onReload(self): + if self._webServer: + self._webServer.killProcess() + super().onReload() def setup(self): """ @@ -194,7 +216,6 @@ def setup(self): self.layout.addWidget(uiWidget) self.ui = slicer.util.childWidgetVariables(uiWidget) - import qt self.ui.downloadSampleDataToolButton.setIcon(qt.QIcon(self.resourcePath("Icons/radiology.svg"))) self.inputNodeSelectors = [self.ui.inputNodeSelector0, self.ui.inputNodeSelector1, self.ui.inputNodeSelector2, self.ui.inputNodeSelector3] @@ -208,7 +229,6 @@ def setup(self): # Create logic class. Logic implements all computations that should be possible to run # in batch mode, without a graphical user interface. self.logic = MONAIAuto3DSegLogic() - self.logic.logCallback = self.addLog self.logic.processingCompletedCallback = self.onProcessingCompleted self.logic.startResultImportCallback = self.onProcessImportStarted self.logic.endResultImportCallback = self.onProcessImportEnded @@ -241,8 +261,15 @@ def setup(self): self.ui.browseToModelsFolderButton.connect("clicked(bool)", self.onBrowseModelsFolder) self.ui.deleteAllModelsButton.connect("clicked(bool)", self.onClearModelsFolder) + self.ui.serverComboBox.lineEdit().setPlaceholderText("enter server address or leave empty to use default") + self.ui.serverComboBox.currentIndexChanged.connect(self.onRemoteServerButtonToggled) + self.ui.remoteServerButton.toggled.connect(self.onRemoteServerButtonToggled) + + self.ui.serverButton.toggled.connect(self.onServerButtonToggled) + # Make sure parameter node is initialized (needed for module reload) self.initializeParameterNode() + self.updateServerUrlGUIFromSettings() self.updateGUIFromParameterNode() @@ -325,15 +352,12 @@ def updateGUIFromParameterNode(self, caller=None, event=None): This method is called whenever parameter node is changed. The module GUI is updated to show the current state of the parameter node. """ - import qt - if self._parameterNode is None or self._updatingGUIFromParameterNode: return # Make sure GUI changes do not call updateParameterNodeFromGUI (it could cause infinite loop) self._updatingGUIFromParameterNode = True try: - self.ui.modelSearchBox.text = self._parameterNode.GetParameter("ModelSearchText") searchWords = self._parameterNode.GetParameter("ModelSearchText").lower().split() @@ -440,6 +464,23 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.applyButton.toolTip = "Please wait for the segmentation to be cancelled" self.ui.applyButton.enabled = False + remoteConnection = self.ui.remoteServerButton.checked + + # if remoteConnection: + # self.ui.serverCollapsibleButton.collapsed = True + + self.ui.portSpinBox.value = int(self._parameterNode.GetParameter("ServerPort")) + + self.ui.browseToModelsFolderButton.enabled = not remoteConnection + self.ui.useStandardSegmentNamesCheckBox.enabled = not remoteConnection + self.ui.cpuCheckBox.enabled = not remoteConnection + self.ui.showAllModelsCheckBox.enabled = not remoteConnection + self.ui.deleteAllModelsButton.enabled = not remoteConnection + self.ui.packageUpgradeButton.enabled = not remoteConnection + + serverRunning = self._webServer is not None and self._webServer.isRunning() + self.ui.serverButton.checked = serverRunning + self.ui.serverButton.text = "Running ..." if serverRunning else "Start server" finally: # All the GUI updates are done self._updatingGUIFromParameterNode = False @@ -470,6 +511,7 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): self._parameterNode.SetParameter("ShowAllModels", "true" if self.ui.showAllModelsCheckBox.checked else "false") self._parameterNode.SetParameter("UseStandardSegmentNames", "true" if self.ui.useStandardSegmentNamesCheckBox.checked else "false") self._parameterNode.SetNodeReferenceID("OutputSegmentation", self.ui.outputSegmentationSelector.currentNodeID) + self._parameterNode.SetParameter("ServerPort", str(self.ui.portSpinBox.value)) finally: self._parameterNode.EndModify(wasModified) @@ -477,8 +519,23 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): def addLog(self, text): """Append text to log window """ - self.ui.statusLabel.appendPlainText(text) - slicer.app.processEvents() # force update + if len(self.ui.statusLabel.html) > 1024 * 256: + self.ui.statusLabel.clear() + self.ui.statusLabel.insertHtml("Log cleared\n") + self.ui.statusLabel.insertHtml(text) + self.ui.statusLabel.insertPlainText("\n") + self.ui.statusLabel.ensureCursorVisible() + self.ui.statusLabel.repaint() + + # self.ui.statusLabel.appendPlainText(text) + # slicer.app.processEvents() # force update + + def addServerLog(self, *args): + for arg in args: + if self.ui.logConsoleCheckBox.checked: + print(arg) + if self.ui.logGuiCheckBox.checked: + self.addLog(arg) def setProcessingState(self, state): self._processingState = state @@ -500,10 +557,6 @@ def onApply(self): self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_STARTING) - if not self.logic.dependenciesInstalled: - with slicer.util.tryWithErrorDisplay("Failed to install required dependencies.", waitCursor=True): - self.logic.setupPythonRequirements() - try: with slicer.util.tryWithErrorDisplay("Failed to start processing.", waitCursor=True): @@ -518,10 +571,10 @@ def onApply(self): for inputNodeSelector in self.inputNodeSelectors: if inputNodeSelector.visible: inputNodes.append(inputNodeSelector.currentNode()) - self._segmentationProcessInfo = self.logic.process(inputNodes, self.ui.outputSegmentationSelector.currentNode(), - self._currentModelId(), self.ui.cpuCheckBox.checked, waitForCompletion=False) self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IN_PROGRESS) + self._segmentationProcessInfo = self.logic.process(inputNodes, self.ui.outputSegmentationSelector.currentNode(), + self._currentModelId(), self.ui.cpuCheckBox.checked, waitForCompletion=False) except Exception as e: self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) @@ -533,22 +586,19 @@ def onCancel(self): def onProcessImportStarted(self, customData): self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IMPORT_RESULTS) - import qt qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor) slicer.app.processEvents() def onProcessImportEnded(self, customData): - import qt qt.QApplication.restoreOverrideCursor() slicer.app.processEvents() def onProcessingCompleted(self, returnCode, customData): - self.ui.statusLabel.appendPlainText("\nProcessing finished.") + self.addLog("\nProcessing finished.") self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) self._segmentationProcessInfo = None def _currentModelId(self): - import qt itemIndex = self.ui.modelComboBox.currentRow item = self.ui.modelComboBox.item(itemIndex) if not item: @@ -556,7 +606,6 @@ def _currentModelId(self): return item.data(qt.Qt.UserRole) def _setCurrentModelId(self, modelId): - import qt for itemIndex in range(self.ui.modelComboBox.count): item = self.ui.modelComboBox.item(itemIndex) if item.data(qt.Qt.UserRole) == modelId: @@ -565,7 +614,9 @@ def _setCurrentModelId(self, modelId): return False def onDownloadSampleData(self): - model = self.logic.model(self._currentModelId()) + with slicer.util.tryWithErrorDisplay("Failed to retrieve model information", waitCursor=True): + model = self.logic.model(self._currentModelId()) + sampleDataName = model.get("sampleData") if not sampleDataName: slicer.util.messageBox("No sample data is available for this model.") @@ -584,7 +635,7 @@ def onDownloadSampleData(self): slicer.util.messageBox(f"Failed to load sample data set '{sampleDataName}'.") return - inputNodes = MONAIAuto3DSegLogic.assignInputNodesByName(inputs, loadedSampleNodes) + inputNodes = assignInputNodesByName(inputs, loadedSampleNodes) for inputIndex, inputNode in enumerate(inputNodes): if inputNode: self.inputNodeSelectors[inputIndex].setCurrentNode(inputNode) @@ -592,24 +643,25 @@ def onDownloadSampleData(self): def onPackageInfoUpdate(self): self.ui.packageInfoTextBrowser.plainText = "" with slicer.util.tryWithErrorDisplay("Failed to get MONAI package version information", waitCursor=True): - self.ui.packageInfoTextBrowser.plainText = self.logic.installedMONAIPythonPackageInfo().rstrip() + self.ui.packageInfoTextBrowser.plainText = self.logic.getMONAIPythonPackageInfo().rstrip() def onPackageUpgrade(self): + restartRequired = True with slicer.util.tryWithErrorDisplay("Failed to upgrade MONAI", waitCursor=True): - self.logic.setupPythonRequirements(upgrade=True) + restartRequired = self.logic.setupPythonRequirements(upgrade=True) self.onPackageInfoUpdate() - if not slicer.util.confirmOkCancelDisplay(f"This MONAI update requires a 3D Slicer restart.","Press OK to restart."): - raise ValueError("Restart was cancelled.") - else: - slicer.util.restart() + if restartRequired: + if not slicer.util.confirmOkCancelDisplay(f"This MONAI update requires a 3D Slicer restart.","Press OK to restart."): + raise ValueError("Restart was cancelled.") + else: + slicer.util.restart() def onBrowseModelsFolder(self): - import qt self.logic.createModelsDir() - qt.QDesktopServices().openUrl(qt.QUrl.fromLocalFile(self.logic.modelsPath())) + qt.QDesktopServices().openUrl(qt.QUrl.fromLocalFile(self.logic.modelsPath)) def onClearModelsFolder(self): - if not os.path.exists(self.logic.modelsPath()): + if not os.path.exists(self.logic.modelsPath): slicer.util.messageBox("There are no downloaded models.") return if not slicer.util.confirmOkCancelDisplay("All downloaded model files will be deleted. The files will be automatically downloaded again as needed."): @@ -617,11 +669,103 @@ def onClearModelsFolder(self): self.logic.deleteAllModels() slicer.util.messageBox("Downloaded models are deleted.") + def onRemoteServerButtonToggled(self): + if self.ui.remoteServerButton.checked: + self.ui.remoteServerButton.text = "Connected" + self.logic = RemoteMONAIAuto3DSegLogic() + self.logic.server_address = self.ui.serverComboBox.currentText + try: + _ = self.logic.models + self.addLog(f"Remote Server Connected {self.logic.server_address}") + except: + slicer.util.warningDisplay( + f"Connection to remote server '{self.logic.server_address}' failed. " + f"Please check address, port, and connection" + ) + self.ui.remoteServerButton.checked = False + return + self.saveServerUrl() + else: + self.ui.remoteServerButton.text = "Connect" + self.logic = MONAIAuto3DSegLogic() + + self.logic.processingCompletedCallback = self.onProcessingCompleted + self.logic.startResultImportCallback = self.onProcessImportStarted + self.logic.endResultImportCallback = self.onProcessImportEnded + self.updateGUIFromParameterNode() + + def onServerButtonToggled(self, toggled): + if toggled: + if not self._webServer or not self._webServer.isRunning() : + import platform + from pathlib import Path + slicer.util.pip_install("python-multipart fastapi uvicorn[standard]") + + hostName = platform.node() + port = str(self.ui.portSpinBox.value) + cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] + + self.ui.serverAddressLineEdit.text = f"http://{hostName}:{port}" + + from MONAIAuto3DSegLib.server import WebServer + self._webServer = WebServer(completedCallback=lambda: self.ui.serverButton.setChecked(False)) + self._webServer.launchConsoleProcess(cmd) + else: + if self._webServer is not None and self._webServer.isRunning(): + logger.info("Server stop requested.") + self._webServer.killProcess() + if not self._webServer.isRunning(): + logger.info("Server stopped.") + + self._webServer = None + self.updateGUIFromParameterNode() + + def serverUrl(self): + serverUrl = self.ui.serverComboBox.currentText.strip() + if not serverUrl: + serverUrl = "http://127.0.0.1:8000" + return serverUrl.rstrip("/") + + def saveServerUrl(self): + settings = qt.QSettings() + serverUrl = self.ui.serverComboBox.currentText + settings.setValue(f"{self.moduleName}/serverUrl", serverUrl) + serverUrlHistory = self._getServerUrlHistory(serverUrl, settings) + serverUrlHistory.insert(0, serverUrl) # Save current server URL to the top of history + serverUrlHistory = serverUrlHistory[:10] # keep up to first 10 elements + settings.setValue(f"{self.moduleName}/serverUrlHistory", ";".join(serverUrlHistory)) + self.updateServerUrlGUIFromSettings() + + def _getServerUrlHistory(self, serverUrl, settings): + serverUrlHistory = settings.value(f"{self.moduleName}/serverUrlHistory") + if serverUrlHistory: + serverUrlHistory = serverUrlHistory.split(";") + else: + serverUrlHistory = [] + try: + serverUrlHistory.remove(serverUrl) + except ValueError: + pass + return serverUrlHistory + + def updateServerUrlGUIFromSettings(self): + settings = qt.QSettings() + serverUrlHistory = settings.value(f"{self.moduleName}/serverUrlHistory") + + wasBlocked = self.ui.serverComboBox.blockSignals(True) + self.ui.serverComboBox.clear() + if serverUrlHistory: + self.ui.serverComboBox.addItems(serverUrlHistory.split(";")) + self.ui.serverComboBox.setCurrentText(settings.value(f"{self.moduleName}/serverUrl")) + self.ui.serverComboBox.blockSignals(wasBlocked) + + # # MONAIAuto3DSegLogic # -class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic): + +class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic, ModelDatabase): """This class should implement all the actual computation done by your module. The interface should be such that other python code can import @@ -634,244 +778,65 @@ class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic): EXIT_CODE_USER_CANCELLED = 1001 EXIT_CODE_DID_NOT_RUN = 1002 - def __init__(self): - """ - Called when the logic class is instantiated. Can be used for initializing member variables. - """ - from collections import OrderedDict - - ScriptedLoadableModuleLogic.__init__(self) - - import pathlib - self.fileCachePath = pathlib.Path.home().joinpath(".MONAIAuto3DSeg") - - self.dependenciesInstalled = False # we don't know yet if dependencies have been installed - - self.moduleDir = os.path.dirname(slicer.util.getModule('MONAIAuto3DSeg').path) - - self.logCallback = None - self.processingCompletedCallback = None - self.startResultImportCallback = None - self.endResultImportCallback = None - self.useStandardSegmentNames = True - - # List of property type codes that are specified by in the MONAIAuto3DSeg terminology. - # - # Codes are stored as a list of strings containing coding scheme designator and code value of the property type, - # separated by "^" character. For example "SCT^123456". - # - # If property the code is found in this list then the MONAIAuto3DSeg terminology will be used, - # otherwise the DICOM terminology will be used. This is necessary because the DICOM terminology - # does not contain all the necessary items and some items are incomplete (e.g., don't have color or 3D Slicer label). - # - self.MONAIAuto3DSegTerminologyPropertyTypes = self._MONAIAuto3DSegTerminologyPropertyTypes() - - # List of anatomic regions that are specified by MONAIAuto3DSeg. - self.MONAIAuto3DSegAnatomicRegions = self._MONAIAuto3DSegAnatomicRegions() - - # Segmentation models specified by in models.json file - self.models = self.loadModelsDescription() - self.defaultModel = self.models[0]["id"] - - # Timer for checking the output of the segmentation process that is running in the background - self.processOutputCheckTimerIntervalMsec = 1000 - - # Disabling this flag preserves input and output data after execution is completed, - # which can be useful for troubleshooting. - self.clearOutputFolder = True - - # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location - # where inference result is stored and set self.debugSkipInference to True. - self.debugSkipInference = False - self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624" - - - def model(self, modelId): - for model in self.models: - if model["id"] == modelId: - return model - raise RuntimeError(f"Model {modelId} not found") + DEPENDENCY_HANDLER = SlicerPythonDependencies() + @staticmethod + def getLoadedTerminologyNames(): + import vtk + terminologyNames = vtk.vtkStringArray() + terminologiesLogic = slicer.util.getModuleLogic("Terminologies") + terminologiesLogic.GetLoadedTerminologyNames(terminologyNames) - def modelsDescriptionJsonFilePath(self): - return os.path.join(self.moduleDir, "Resources", "Models.json") - - def loadModelsDescription(self): - modelsJsonFilePath = self.modelsDescriptionJsonFilePath() - try: - models = [] - import json - import re - with open(modelsJsonFilePath) as f: - modelsTree = json.load(f)["models"] - for model in modelsTree: - deprecated = False - for version in model["versions"]: - url = version["url"] - # URL format: /-v.zip - # Example URL: https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/Models/17-segments-TotalSegmentator-v1.0.3.zip - match = re.search(r"(?P[^/]+)-v(?P\d+\.\d+\.\d+)", url) - if match: - filename = match.group("filename") - version = match.group("version") - else: - logging.error(f"Failed to extract model id and version from url: {url}") - if "inputs" in model: - # Contains a list of dict. One dict for each input. - # Currently, only "title" (user-displayable name) and "namePattern" of the input are specified. - # In the future, inputs could have additional properties, such as name, type, optional, ... - inputs = model["inputs"] - else: - # Inputs are not defined, use default (single input volume) - inputs = [{"title": "Input volume"}] - segmentNames = model.get('segmentNames') - if not segmentNames: - segmentNames = "N/A" - models.append({ - "id": f"{filename}-v{version}", - "title": model['title'], - "version": version, - "inputs": inputs, - "imagingModality": model["imagingModality"], - "description": model["description"], - "sampleData": model.get("sampleData"), - "segmentNames": model.get("segmentNames"), - "details": - f"

Model: {model['title']} (v{version})" - f"

Description: {model['description']}\n" - f"

Computation time on GPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n" - f"
Computation time on CPU: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n" - f"

Imaging modality: {model['imagingModality']}\n" - f"

Subject: {model['subject']}\n" - f"

Segments: {', '.join(segmentNames)}", - "url": url, - "deprecated": deprecated - }) - # First version is not deprecated, all subsequent versions are deprecated - deprecated = True - return models - except Exception as e: - import traceback - traceback.print_exc() - raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}") + return [terminologyNames.GetValue(idx) for idx in range(terminologyNames.GetNumberOfValues())] @staticmethod - def humanReadableTimeFromSec(seconds): - import math - if not seconds: - return "N/A" - if seconds < 55: - # if less than a minute, round up to the nearest 5 seconds - return f"{math.ceil(seconds/5) * 5} sec" - elif seconds < 60 * 60: - # if less then 1 hour, round up to the nearest minute - return f"{math.ceil(seconds/60)} min" - # Otherwise round up to the nearest 0.1 hour - return f"{seconds/3600:.1f} h" - - def modelsPath(self): - import pathlib - return self.fileCachePath.joinpath("models") - - def createModelsDir(self): - modelsDir = self.modelsPath() - if not os.path.exists(modelsDir): - os.makedirs(modelsDir) - - def modelPath(self, modelName): - import pathlib - modelRoot = self.modelsPath().joinpath(modelName) - # find labels.csv file within the modelRoot folder and subfolders - for path in pathlib.Path(modelRoot).rglob("labels.csv"): - return path.parent - raise RuntimeError(f"Model {modelName} path not found") - - def deleteAllModels(self): - if self.modelsPath().exists(): - import shutil - shutil.rmtree(self.modelsPath()) + def getLoadedAnatomicContextNames(): + import vtk + anatomicContextNames = vtk.vtkStringArray() + terminologiesLogic = slicer.util.getModuleLogic("Terminologies") + terminologiesLogic.GetLoadedAnatomicContextNames(anatomicContextNames) + + return [anatomicContextNames.GetValue(idx) for idx in range(anatomicContextNames.GetNumberOfValues())] def downloadAllModels(self): for model in self.models: slicer.app.processEvents() self.downloadModel(model["id"]) - def downloadModel(self, modelName): - - url = self.model(modelName)["url"] - - import zipfile - import requests - import pathlib - - tempDir = pathlib.Path(slicer.util.tempDirectory()) - modelDir = self.modelsPath().joinpath(modelName) - if not os.path.exists(modelDir): - os.makedirs(modelDir) - - modelZipFile = tempDir.joinpath("autoseg3d_model.zip") - self.log(f"Downloading model '{modelName}' from {url}...") - logging.debug(f"Downloading from {url} to {modelZipFile}...") - - try: - with open(modelZipFile, 'wb') as f: - with requests.get(url, stream=True) as r: - r.raise_for_status() - total_size = int(r.headers.get('content-length', 0)) - reporting_increment_percent = 1.0 - last_reported_download_percent = -reporting_increment_percent - downloaded_size = 0 - for chunk in r.iter_content(chunk_size=8192 * 16): - f.write(chunk) - downloaded_size += len(chunk) - downloaded_percent = 100.0 * downloaded_size / total_size - if downloaded_percent - last_reported_download_percent > reporting_increment_percent: - self.log(f"Downloading model: {downloaded_size/1024/1024:.1f}MB / {total_size/1024/1024:.1f}MB ({downloaded_percent:.1f}%)") - last_reported_download_percent = downloaded_percent - - self.log(f"Download finished. Extracting to {modelDir}...") - with zipfile.ZipFile(modelZipFile, 'r') as zip_f: - zip_f.extractall(modelDir) - except Exception as e: - raise e - finally: - if self.clearOutputFolder: - self.log("Cleaning up temporary model download folder...") - if os.path.isdir(tempDir): - import shutil - shutil.rmtree(tempDir) - else: - self.log(f"Not cleaning up temporary model download folder: {tempDir}") - - - def _MONAIAuto3DSegTerminologyPropertyTypes(self): + @staticmethod + def _terminologyPropertyTypes(terminologyName): """Get label terminology property types defined in from MONAI Auto3DSeg terminology. Terminology entries are either in DICOM or MONAI Auto3DSeg "Segmentation category and type". - """ + # List of property type codes that are specified by in the terminology. + # + # Codes are stored as a list of strings containing coding scheme designator and code value of the property type, + # separated by "^" character. For example "SCT^123456". + + """ terminologiesLogic = slicer.util.getModuleLogic("Terminologies") - MONAIAuto3DSegTerminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName + terminologyPropertyTypes = [] # Get anatomicalStructureCategory from the MONAI Auto3DSeg terminology anatomicalStructureCategory = slicer.vtkSlicerTerminologyCategory() - numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(MONAIAuto3DSegTerminologyName) - for i in range(numberOfCategories): - terminologiesLogic.GetNthCategoryInTerminology(MONAIAuto3DSegTerminologyName, i, anatomicalStructureCategory) - if anatomicalStructureCategory.GetCodingSchemeDesignator() == "SCT" and anatomicalStructureCategory.GetCodeValue() == "123037004": - # Found the (123037004, SCT, "Anatomical Structure") category within DICOM master list - break - - # Retrieve all anatomicalStructureCategory property type codes - terminologyPropertyTypes = [] - terminologyType = slicer.vtkSlicerTerminologyType() - numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory) - for i in range(numberOfTypes): - if terminologiesLogic.GetNthTypeInTerminologyCategory(MONAIAuto3DSegTerminologyName, anatomicalStructureCategory, i, terminologyType): - terminologyPropertyTypes.append(terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue()) + numberOfCategories = terminologiesLogic.GetNumberOfCategoriesInTerminology(terminologyName) + for cIdx in range(numberOfCategories): + terminologiesLogic.GetNthCategoryInTerminology(terminologyName, cIdx, anatomicalStructureCategory) + + # Retrieve all anatomicalStructureCategory property type codes + terminologyType = slicer.vtkSlicerTerminologyType() + numberOfTypes = terminologiesLogic.GetNumberOfTypesInTerminologyCategory(terminologyName, + anatomicalStructureCategory) + for tIdx in range(numberOfTypes): + if terminologiesLogic.GetNthTypeInTerminologyCategory(terminologyName, anatomicalStructureCategory, tIdx, + terminologyType): + terminologyPropertyTypes.append( + terminologyType.GetCodingSchemeDesignator() + "^" + terminologyType.GetCodeValue()) return terminologyPropertyTypes - def _MONAIAuto3DSegAnatomicRegions(self): + @staticmethod + def _anatomicRegions(anatomicContextName): """Get anatomic regions defined in from MONAI Auto3DSeg terminology. Terminology entries are either in DICOM or MONAI Auto3DSeg "Anatomic codes". """ @@ -884,25 +849,101 @@ def _MONAIAuto3DSegAnatomicRegions(self): # when editing the terminology on the GUI) return anatomicRegions - MONAIAuto3DSegAnatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName - # Retrieve all anatomical region codes - regionObject = slicer.vtkSlicerTerminologyType() - numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(MONAIAuto3DSegAnatomicContextName) + numberOfRegions = terminologiesLogic.GetNumberOfRegionsInAnatomicContext(anatomicContextName) for i in range(numberOfRegions): - if terminologiesLogic.GetNthRegionInAnatomicContext(MONAIAuto3DSegAnatomicContextName, i, regionObject): + if terminologiesLogic.GetNthRegionInAnatomicContext(anatomicContextName, i, regionObject): anatomicRegions.append(regionObject.GetCodingSchemeDesignator() + "^" + regionObject.GetCodeValue()) return anatomicRegions + @staticmethod + def getSegmentLabelColor(terminologyEntryStr): + """Get segment label and color from terminology""" + + def labelColorFromTypeObject(typeObject): + """typeObject is a terminology type or type modifier""" + label = typeObject.GetSlicerLabel() if typeObject.GetSlicerLabel() else typeObject.GetCodeMeaning() + rgb = typeObject.GetRecommendedDisplayRGBValue() + return label, (rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0) + + tlogic = slicer.modules.terminologies.logic() + + terminologyEntry = slicer.vtkSlicerTerminologyEntry() + if not tlogic.DeserializeTerminologyEntry(terminologyEntryStr, terminologyEntry): + raise RuntimeError(f"Failed to deserialize terminology string: {terminologyEntryStr}") + + numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), + terminologyEntry.GetCategoryObject()) + foundTerminologyEntry = slicer.vtkSlicerTerminologyEntry() + for typeIndex in range(numberOfTypes): + tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), + terminologyEntry.GetCategoryObject(), typeIndex, + foundTerminologyEntry.GetTypeObject()) + if terminologyEntry.GetTypeObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeObject().GetCodingSchemeDesignator(): + continue + if terminologyEntry.GetTypeObject().GetCodeValue() != foundTerminologyEntry.GetTypeObject().GetCodeValue(): + continue + if terminologyEntry.GetTypeModifierObject() and terminologyEntry.GetTypeModifierObject().GetCodeValue(): + # Type has a modifier, get the color from there + numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType( + terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), + terminologyEntry.GetTypeObject()) + foundMatchingModifier = False + for modifierIndex in range(numberOfModifiers): + tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(), + terminologyEntry.GetCategoryObject(), + terminologyEntry.GetTypeObject(), + modifierIndex, + foundTerminologyEntry.GetTypeModifierObject()) + if terminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator(): + continue + if terminologyEntry.GetTypeModifierObject().GetCodeValue() != foundTerminologyEntry.GetTypeModifierObject().GetCodeValue(): + continue + return labelColorFromTypeObject(foundTerminologyEntry.GetTypeModifierObject()) + continue + return labelColorFromTypeObject(foundTerminologyEntry.GetTypeObject()) + + raise RuntimeError(f"Color was not found for terminology {terminologyEntryStr}") + + def __init__(self): + """ + Called when the logic class is instantiated. Can be used for initializing member variables. + """ + ScriptedLoadableModuleLogic.__init__(self) + ModelDatabase.__init__(self) + + self.processingCompletedCallback = None + self.startResultImportCallback = None + self.endResultImportCallback = None + self.useStandardSegmentNames = True + + # Timer for checking the output of the segmentation process that is running in the background + self.processOutputCheckTimerIntervalMsec = 1000 + + # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location + # where inference result is stored and set self.debugSkipInference to True. + self.debugSkipInference = False + self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624" + + def getMONAIPythonPackageInfo(self): + return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo() + + def setupPythonRequirements(self, upgrade=False): + self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade) + return True + def labelDescriptions(self, modelName): """Return mapping from label value to label description. Label description is a dict containing "name" and "terminology". Terminology string uses Slicer terminology entry format - see specification at https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag """ + labelsFilePath = self.modelPath(modelName).joinpath("labels.csv") + return self._labelDescriptions(labelsFilePath) + def _labelDescriptions(self, labelsFilePath): # Helper function to get code string from CSV file row def getCodeString(field, columnNames, row): columnValues = [] @@ -917,173 +958,68 @@ def getCodeString(field, columnNames, row): return columnValues labelDescriptions = {} - labelsFilePath = self.modelPath(modelName).joinpath("labels.csv") import csv with open(labelsFilePath, "r") as f: reader = csv.reader(f) columnNames = next(reader) - data = {} # Loop through the rows of the csv file for row in reader: - # Determine segmentation category (DICOM or MONAIAuto3DSeg) terminologyPropertyTypeStr = ( # Example: SCT^23451007 - row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")] - + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")]) - if terminologyPropertyTypeStr in self.MONAIAuto3DSegTerminologyPropertyTypes: - terminologyName = slicer.modules.MONAIAuto3DSegInstance.terminologyName - else: + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodingSchemeDesignator")] + + "^" + row[columnNames.index("SegmentedPropertyTypeCodeSequence.CodeValue")]) + terminologyName = None + + # If property the code is found in this list then the terminology will be used, + for tName in self.getLoadedTerminologyNames(): + propertyTypes = self._terminologyPropertyTypes(tName) + if terminologyPropertyTypeStr in propertyTypes: + terminologyName = tName + break + + # NB: DICOM terminology will be used otherwise. Note: the DICOM terminology does not contain all the + # necessary items and some items are incomplete (e.g., don't have color or 3D Slicer label). + if not terminologyName: terminologyName = "Segmentation category and type - DICOM master list" # Determine the anatomic context name (DICOM or MONAIAuto3DSeg) anatomicRegionStr = ( # Example: SCT^279245009 - row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")] - + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")]) - if anatomicRegionStr in self.MONAIAuto3DSegAnatomicRegions: - anatomicContextName = slicer.modules.MONAIAuto3DSegInstance.anatomicContextName - else: + row[columnNames.index("AnatomicRegionSequence.CodingSchemeDesignator")] + + "^" + row[columnNames.index("AnatomicRegionSequence.CodeValue")]) + anatomicContextName = None + for aName in self.getLoadedAnatomicContextNames(): + if anatomicRegionStr in self._anatomicRegions(aName): + anatomicContextName = aName + if not anatomicContextName: anatomicContextName = "Anatomic codes - DICOM master list" terminologyEntryStr = ( - terminologyName - +"~" - # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure" - + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row)) - + "~" - # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ... - + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row)) - + "~" - # Property type modifier: "SCT^7771000^Left", ... - + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row)) - + "~" - + anatomicContextName - + "~" - # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ... - + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row)) - + "~" - # Anatomic region modifier: "SCT^7771000^Left", ... - + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row)) - ) + terminologyName + + "~" + # Property category: "SCT^123037004^Anatomical Structure" or "SCT^49755003^Morphologically Altered Structure" + + "^".join(getCodeString("SegmentedPropertyCategoryCodeSequence", columnNames, row)) + + "~" + # Property type: "SCT^23451007^Adrenal gland", "SCT^367643001^Cyst", ... + + "^".join(getCodeString("SegmentedPropertyTypeCodeSequence", columnNames, row)) + + "~" + # Property type modifier: "SCT^7771000^Left", ... + + "^".join(getCodeString("SegmentedPropertyTypeModifierCodeSequence", columnNames, row)) + + "~" + + anatomicContextName + + "~" + # Anatomic region (set if category is not anatomical structure): "SCT^64033007^Kidney", ... + + "^".join(getCodeString("AnatomicRegionSequence", columnNames, row)) + + "~" + # Anatomic region modifier: "SCT^7771000^Left", ... + + "^".join(getCodeString("AnatomicRegionModifierSequence", columnNames, row)) + ) # Store the terminology string for this structure labelValue = int(row[columnNames.index("LabelValue")]) name = row[columnNames.index("Name")] - labelDescriptions[labelValue] = { "name": name, "terminology": terminologyEntryStr } - + labelDescriptions[labelValue] = {"name": name, "terminology": terminologyEntryStr} return labelDescriptions - def getSegmentLabelColor(self, terminologyEntryStr): - """Get segment label and color from terminology""" - - def labelColorFromTypeObject(typeObject): - """typeObject is a terminology type or type modifier""" - label = typeObject.GetSlicerLabel() if typeObject.GetSlicerLabel() else typeObject.GetCodeMeaning() - rgb = typeObject.GetRecommendedDisplayRGBValue() - return label, (rgb[0]/255.0, rgb[1]/255.0, rgb[2]/255.0) - - tlogic = slicer.modules.terminologies.logic() - - terminologyEntry = slicer.vtkSlicerTerminologyEntry() - if not tlogic.DeserializeTerminologyEntry(terminologyEntryStr, terminologyEntry): - raise RuntimeError(f"Failed to deserialize terminology string: {terminologyEntryStr}") - - numberOfTypes = tlogic.GetNumberOfTypesInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject()) - foundTerminologyEntry = slicer.vtkSlicerTerminologyEntry() - for typeIndex in range(numberOfTypes): - tlogic.GetNthTypeInTerminologyCategory(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), typeIndex, foundTerminologyEntry.GetTypeObject()) - if terminologyEntry.GetTypeObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeObject().GetCodingSchemeDesignator(): - continue - if terminologyEntry.GetTypeObject().GetCodeValue() != foundTerminologyEntry.GetTypeObject().GetCodeValue(): - continue - if terminologyEntry.GetTypeModifierObject() and terminologyEntry.GetTypeModifierObject().GetCodeValue(): - # Type has a modifier, get the color from there - numberOfModifiers = tlogic.GetNumberOfTypeModifiersInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject()) - foundMatchingModifier = False - for modifierIndex in range(numberOfModifiers): - tlogic.GetNthTypeModifierInTerminologyType(terminologyEntry.GetTerminologyContextName(), terminologyEntry.GetCategoryObject(), terminologyEntry.GetTypeObject(), - modifierIndex, foundTerminologyEntry.GetTypeModifierObject()) - if terminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator() != foundTerminologyEntry.GetTypeModifierObject().GetCodingSchemeDesignator(): - continue - if terminologyEntry.GetTypeModifierObject().GetCodeValue() != foundTerminologyEntry.GetTypeModifierObject().GetCodeValue(): - continue - return labelColorFromTypeObject(foundTerminologyEntry.GetTypeModifierObject()) - continue - return labelColorFromTypeObject(foundTerminologyEntry.GetTypeObject()) - - raise RuntimeError(f"Color was not found for terminology {terminologyEntryStr}") - - @staticmethod - def _findFirstNodeBynamePattern(namePattern, nodes): - import fnmatch - for node in nodes: - if fnmatch.fnmatchcase(node.GetName(), namePattern): - return node - return None - - @staticmethod - def assignInputNodesByName(inputs, loadedSampleNodes): - inputNodes = [] - for inputIndex, input in enumerate(inputs): - namePattern = input.get("namePattern") - if namePattern: - matchingNode = MONAIAuto3DSegLogic._findFirstNodeBynamePattern(namePattern, loadedSampleNodes) - else: - matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else loadedSampleNodes[0] - inputNodes.append(matchingNode) - return inputNodes - - def log(self, text): - logging.info(text) - if self.logCallback: - self.logCallback(text) - - def installedMONAIPythonPackageInfo(self): - import shutil - import subprocess - versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode() - return versionInfo - - def setupPythonRequirements(self, upgrade=False): - import importlib.metadata - import importlib.util - import packaging - - # Install PyTorch - try: - import PyTorchUtils - except ModuleNotFoundError as e: - raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.") - - self.log("Initializing PyTorch...") - minimumTorchVersion = "1.12" - torchLogic = PyTorchUtils.PyTorchUtilsLogic() - if not torchLogic.torchInstalled(): - self.log("PyTorch Python package is required. Installing... (it may take several minutes)") - torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement = f">={minimumTorchVersion}") - if torch is None: - raise ValueError("PyTorch extension needs to be installed to use this module.") - else: - # torch is installed, check version - from packaging import version - if version.parse(torchLogic.torch.__version__) < version.parse(minimumTorchVersion): - raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module." - + f" Minimum required version is {minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" - + f" with version requirement set to: >={minimumTorchVersion}") - - # Install MONAI with required components - self.log("Initializing MONAI...") - # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too). - # Without this, for some users monai-0.9.0 got installed, which failed with this error: - # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'" - monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" - if upgrade: - monaiInstallString += " --upgrade" - slicer.util.pip_install(monaiInstallString) - - self.dependenciesInstalled = True - self.log("Dependencies are set up successfully.") - - def setDefaultParameters(self, parameterNode): """ Initialize parameter node with default settings. @@ -1092,6 +1028,8 @@ def setDefaultParameters(self, parameterNode): parameterNode.SetParameter("Model", self.defaultModel) if not parameterNode.GetParameter("UseStandardSegmentNames"): parameterNode.SetParameter("UseStandardSegmentNames", "true") + if not parameterNode.GetParameter("ServerPort"): + parameterNode.SetParameter("ServerPort", str(8891)) def logProcessOutputUntilCompleted(self, segmentationProcessInfo): # Wait for the process to end and forward output to the log @@ -1102,7 +1040,7 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo): line = proc.stdout.readline() if not line: break - self.log(line.rstrip()) + logger.info(line.rstrip()) except UnicodeDecodeError as e: # Code page conversion happens because `universal_newlines=True` sets process output to text mode, # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string, @@ -1115,7 +1053,6 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo): raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr) def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitForCompletion=True, customData=None): - """ Run the processing algorithm. Can be used without GUI widget. @@ -1127,26 +1064,26 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor :param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False """ + if not self.DEPENDENCY_HANDLER.dependenciesInstalled: + with slicer.util.tryWithErrorDisplay("Failed to install required dependencies.", waitCursor=True): + self.DEPENDENCY_HANDLER.setupPythonRequirements() + if not inputNodes: raise ValueError("Input nodes are invalid") if not outputSegmentation: raise ValueError("Output segmentation is invalid") - if model == None: + if model is None: model = self.defaultModel - try: - modelPath = self.modelPath(model) - except: - self.downloadModel(model) - modelPath = self.modelPath(model) + modelPath = self.modelPath(model) segmentationProcessInfo = {} import time startTime = time.time() - self.log("Processing started") + logger.info("Processing started") if self.debugSkipInference: # For debugging, use a fixed temporary folder @@ -1155,9 +1092,6 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor # Create new empty folder tempDir = slicer.util.tempDirectory() - import pathlib - tempDirPath = pathlib.Path(tempDir) - # Get Python executable path import shutil pythonSlicerExecutablePath = shutil.which("PythonSlicer") @@ -1169,7 +1103,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor for inputIndex, inputNode in enumerate(inputNodes): if inputNode.IsA('vtkMRMLScalarVolumeNode'): inputImageFile = tempDir + f"/input-volume{inputIndex}.nrrd" - self.log(f"Writing input file to {inputImageFile}") + logger.info(f"Writing input file to {inputImageFile}") volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode") volumeStorageNode.SetFileName(inputImageFile) volumeStorageNode.UseCompressionOff() @@ -1190,13 +1124,13 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor auto3DSegCommand.append(f"--image-file-{inputIndex+1}") auto3DSegCommand.append(inputFiles[inputIndex]) - self.log("Creating segmentations with MONAIAuto3DSeg AI...") - self.log(f"Auto3DSeg command: {auto3DSegCommand}") + logger.info("Creating segmentations with MONAIAuto3DSeg AI...") + logger.info(f"Auto3DSeg command: {auto3DSegCommand}") additionalEnvironmentVariables = None if cpu: additionalEnvironmentVariables = {"CUDA_VISIBLE_DEVICES": "-1"} - self.log(f"Additional environment variables: {additionalEnvironmentVariables}") + logger.info(f"Additional environment variables: {additionalEnvironmentVariables}") if self.debugSkipInference: proc = None @@ -1230,7 +1164,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor return segmentationProcessInfo def cancelProcessing(self, segmentationProcessInfo): - self.log("Cancel is requested.") + logger.info("Cancel is requested.") segmentationProcessInfo["cancelRequested"] = True proc = segmentationProcessInfo.get("proc") if proc: @@ -1248,7 +1182,6 @@ def cancelProcessing(self, segmentationProcessInfo): def _handleProcessOutputThreadProcess(segmentationProcessInfo): # Wait for the process to end and forward output to the log proc = segmentationProcessInfo["proc"] - from subprocess import CalledProcessError while True: try: line = proc.stdout.readline() @@ -1264,10 +1197,8 @@ def _handleProcessOutputThreadProcess(segmentationProcessInfo): retcode = proc.returncode # non-zero return code means error segmentationProcessInfo["procReturnCode"] = retcode - def startSegmentationProcessMonitoring(self, segmentationProcessInfo): import queue - import sys import threading segmentationProcessInfo["procOutputQueue"] = queue.Queue() @@ -1276,9 +1207,7 @@ def startSegmentationProcessMonitoring(self, segmentationProcessInfo): self.checkSegmentationProcessOutput(segmentationProcessInfo) - def checkSegmentationProcessOutput(self, segmentationProcessInfo): - import queue outputQueue = segmentationProcessInfo["procOutputQueue"] while outputQueue: @@ -1287,17 +1216,14 @@ def checkSegmentationProcessOutput(self, segmentationProcessInfo): return try: line = outputQueue.get_nowait() - self.log(line) + logger.info(line) except queue.Empty: break # No more outputs to process now, check again later - import qt qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda segmentationProcessInfo=segmentationProcessInfo: self.checkSegmentationProcessOutput(segmentationProcessInfo)) - def onSegmentationProcessCompleted(self, segmentationProcessInfo): - startTime = segmentationProcessInfo["startTime"] tempDir = segmentationProcessInfo["tempDir"] inputNodes = segmentationProcessInfo["inputNodes"] @@ -1310,17 +1236,14 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): if cancelRequested: procReturnCode = MONAIAuto3DSegLogic.EXIT_CODE_USER_CANCELLED - self.log(f"Processing was cancelled.") + logger.info(f"Processing was cancelled.") else: if procReturnCode == 0: - if self.startResultImportCallback: self.startResultImportCallback(customData) - try: - - # Load result - self.log("Importing segmentation results...") + try: # Load result + logger.info("Importing segmentation results...") self.readSegmentation(outputSegmentation, outputSegmentationFile, model) # Set source volume - required for DICOM Segmentation export @@ -1338,20 +1261,18 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): shNode.SetItemParent(segmentationShItem, studyShItem) finally: - if self.endResultImportCallback: self.endResultImportCallback(customData) - else: - self.log(f"Processing failed with return code {procReturnCode}") + logger.info(f"Processing failed with return code {procReturnCode}") if self.clearOutputFolder: - self.log("Cleaning up temporary folder.") + logger.info("Cleaning up temporary folder.") if os.path.isdir(tempDir): import shutil shutil.rmtree(tempDir) else: - self.log(f"Not cleaning up temporary folder: {tempDir}") + logger.info(f"Not cleaning up temporary folder: {tempDir}") # Report total elapsed time import time @@ -1359,19 +1280,17 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): segmentationProcessInfo["stopTime"] = stopTime elapsedTime = stopTime - startTime if cancelRequested: - self.log(f"Processing was cancelled after {elapsedTime:.2f} seconds.") + logger.info(f"Processing was cancelled after {elapsedTime:.2f} seconds.") else: if procReturnCode == 0: - self.log(f"Processing was completed in {elapsedTime:.2f} seconds.") + logger.info(f"Processing was completed in {elapsedTime:.2f} seconds.") else: - self.log(f"Processing failed after {elapsedTime:.2f} seconds.") + logger.info(f"Processing failed after {elapsedTime:.2f} seconds.") if self.processingCompletedCallback: self.processingCompletedCallback(procReturnCode, customData) - def readSegmentation(self, outputSegmentation, outputSegmentationFile, model): - labelValueToDescription = self.labelDescriptions(model) # Get label descriptions @@ -1405,12 +1324,11 @@ def readSegmentation(self, outputSegmentation, outputSegmentationFile, model): # Set terminology and color for labelValue in labelValueToDescription: - segmentName = labelValueToDescription[labelValue]["name"] terminologyEntryStr = labelValueToDescription[labelValue]["terminology"] - segmentId = segmentName - self.setTerminology(outputSegmentation, segmentName, segmentId, terminologyEntryStr) + segmentId = labelValueToDescription[labelValue]["name"] + self.setTerminology(outputSegmentation, segmentId, terminologyEntryStr) - def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryStr): + def setTerminology(self, segmentation, segmentId, terminologyEntryStr): segment = segmentation.GetSegmentation().GetSegment(segmentId) if not segment: # Segment is not present in this segmentation @@ -1423,31 +1341,137 @@ def setTerminology(self, segmentation, segmentName, segmentId, terminologyEntryS segment.SetName(label) segment.SetColor(color) except RuntimeError as e: - self.log(str(e)) + logger.info(str(e)) - def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath): - import json - modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath() +class RemoteMONAIAuto3DSegLogic(MONAIAuto3DSegLogic): - with open(modelsTestResultsJsonFilePath) as f: - modelsTestResults = json.load(f) + DEPENDENCY_HANDLER = RemotePythonDependencies() - with open(modelsDescriptionJsonFilePath) as f: - modelsDescription = json.load(f) - - for model in modelsDescription["models"]: - title = model["title"] - for modelTestResult in modelsTestResults: - if modelTestResult["title"] == title: - for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]: - fieldValue = modelTestResult.get(fieldName) - if fieldValue: - model[fieldName] = fieldValue - break + @property + def server_address(self): + return self._server_address + + @server_address.setter + def server_address(self, address): + self._server_address = address + self._models = [] + + def __init__(self): + self._server_address = None + MONAIAuto3DSegLogic.__init__(self) + self._models = [] + + def getMONAIPythonPackageInfo(self): + return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo(self._server_address) - with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f: - json.dump(modelsDescription, f, indent=2) + def setupPythonRequirements(self, upgrade=False): + self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade) + return False + + def loadModelsDescription(self): + if not self._server_address: + return [] + else: + response = requests.get(self._server_address + "/models") + json_data = json.loads(response.text) + return json_data + + def labelDescriptions(self, modelId): + """Return mapping from label value to label description. + Label description is a dict containing "name" and "terminology". + Terminology string uses Slicer terminology entry format - see specification at + https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag + """ + if not self._server_address: + return [] + else: + import tempfile + with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile: + with requests.get(self._server_address + f"/labelDescriptions?id={modelId}", stream=True) as r: + r.raise_for_status() + + with open(tmpfile.name, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + return self._labelDescriptions(tmpfile.name) + + def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitForCompletion=True, customData=None): + """ + Run the processing algorithm. + Can be used without GUI widget. + :param inputNodes: input nodes in a list + :param outputVolume: thresholding result + :param modelId: one of self.models + :param cpu: use CPU instead of GPU + :param waitForCompletion: if True then the method waits for the processing to finish + :param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False + """ + + import time + startTime = time.time() + logger.info("Processing started") + + tempDir = slicer.util.tempDirectory() + outputSegmentationFile = tempDir + "/output-segmentation.nrrd" + + from tempfile import TemporaryDirectory + with TemporaryDirectory(dir=tempDir) as temp_dir: + # Write input volume to file + from pathlib import Path + tempDir = Path(temp_dir) + inputFiles = [] + for inputIndex, inputNode in enumerate(inputNodes): + if inputNode.IsA('vtkMRMLScalarVolumeNode'): + inputImageFile = tempDir / f"input-volume{inputIndex}.nrrd" + logger.info(f"Writing input file to {inputImageFile}") + volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode") + volumeStorageNode.SetFileName(inputImageFile) + volumeStorageNode.UseCompressionOff() + volumeStorageNode.WriteData(inputNode) + slicer.mrmlScene.RemoveNode(volumeStorageNode) + inputFiles.append(inputImageFile) + else: + raise ValueError(f"Input node type {inputNode.GetClassName()} is not supported") + + segmentationProcessInfo = {} + segmentationProcessInfo["procReturnCode"] = MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN + + logger.info(f"Initiating Inference on {self._server_address}") + files = {} + + try: + for idx, inputFile in enumerate(inputFiles, start=1): + name = "image_file" + if idx > 1: + name = f"{name}_{idx}" + files[name] = open(inputFile, 'rb') + + with requests.post(self._server_address + f"/infer?model_name={modelId}", files=files) as r: + r.raise_for_status() + + with open(outputSegmentationFile, "wb") as binary_file: + for chunk in r.iter_content(chunk_size=8192): + binary_file.write(chunk) + + segmentationProcessInfo["procReturnCode"] = 0 + finally: + for f in files.values(): + f.close() + + segmentationProcessInfo["cancelRequested"] = False + segmentationProcessInfo["startTime"] = startTime + segmentationProcessInfo["tempDir"] = tempDir + segmentationProcessInfo["inputNodes"] = inputNodes + segmentationProcessInfo["outputSegmentation"] = outputSegmentation + segmentationProcessInfo["outputSegmentationFile"] = outputSegmentationFile + segmentationProcessInfo["model"] = modelId + segmentationProcessInfo["customData"] = customData + + self.onSegmentationProcessCompleted(segmentationProcessInfo) + + return segmentationProcessInfo # # MONAIAuto3DSegTest @@ -1556,10 +1580,8 @@ def test_MONAIAuto3DSeg1(self): raise RuntimeError(f"Failed to load sample data set '{sampleDataName}'.") # Set model inputs - - inputNodes = [] inputs = model.get("inputs") - inputNodes = MONAIAuto3DSegLogic.assignInputNodesByName(inputs, loadedSampleNodes) + inputNodes = assignInputNodesByName(inputs, loadedSampleNodes) outputSegmentation = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") @@ -1684,14 +1706,12 @@ def _writeScreenshots(self, segmentationNode, outputPath, baseName, numberOfImag return sliceScreenshotFilename, rotate3dScreenshotFilename def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestResultsMarkdownFilePath=None, screenshotUrlBase=None): - if modelsTestResultsMarkdownFilePath is None: modelsTestResultsMarkdownFilePath = modelsTestResultsJsonFilePath.replace(".json", ".md") if screenshotUrlBase is None: screenshotUrlBase = "https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/ModelsTestResults/" import json - from MONAIAuto3DSeg import MONAIAuto3DSegLogic with open(modelsTestResultsJsonFilePath) as f: modelsTestResults = json.load(f) @@ -1718,7 +1738,7 @@ def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestR title = f"{model['title']} (v{model['version']})" f.write(f"## {title}\n") f.write(f"{model['description']}\n\n") - f.write(f"Processing time: {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {MONAIAuto3DSegLogic.humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n") + f.write(f"Processing time: {humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n") f.write(f"Segment names: {', '.join(model['segmentNames'])}\n\n") f.write(f"![2D view]({screenshotUrlBase}{model['segmentationResultsScreenshot2D']})\n") f.write(f"![3D view]({screenshotUrlBase}{model['segmentationResultsScreenshot3D']})\n") \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py new file mode 100644 index 0000000..63c5e7d --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py @@ -0,0 +1 @@ +APPLICATION_NAME = "MONAIAuto3DSeg" diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py new file mode 100644 index 0000000..438f120 --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py @@ -0,0 +1,130 @@ +import sys + +import shutil +import subprocess +import logging + +from MONAIAuto3DSegLib.constants import APPLICATION_NAME +logger = logging.getLogger(APPLICATION_NAME) + + +from abc import ABC, abstractmethod + + +class DependenciesBase(ABC): + + minimumTorchVersion = "1.12" + + def __init__(self): + self.dependenciesInstalled = False # we don't know yet if dependencies have been installed + + @abstractmethod + def installedMONAIPythonPackageInfo(self): + pass + + @abstractmethod + def setupPythonRequirements(self, upgrade=False): + pass + + +class LocalPythonDependencies(DependenciesBase): + + def installedMONAIPythonPackageInfo(self): + versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode() + return versionInfo + + def _checkModuleInstalled(self, moduleName): + try: + import importlib + importlib.import_module(moduleName) + return True + except ModuleNotFoundError: + return False + + def setupPythonRequirements(self, upgrade=False): + def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + logger.info("Initializing PyTorch...") + + packageName = "torch" + if not self._checkModuleInstalled(packageName): + logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") + install(packageName) + if not self._checkModuleInstalled(packageName): + raise ValueError("pytorch needs to be installed to use this module.") + else: # torch is installed, check version + from packaging import version + import torch + if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion): + raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module." + + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" + + f" with version requirement set to: >={self.minimumTorchVersion}") + + logger.info("Initializing MONAI...") + monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" + if upgrade: + monaiInstallString += " --upgrade" + install(monaiInstallString) + + self.dependenciesInstalled = True + logger.info("Dependencies are set up successfully.") + + +class RemotePythonDependencies(DependenciesBase): + + def installedMONAIPythonPackageInfo(self, server_address): + if not server_address: + return [] + else: + import json + import requests + response = requests.get(server_address + "/monaiinfo") + json_data = json.loads(response.text) + return json_data + + def setupPythonRequirements(self, upgrade=False): + logger.error("No permission to update remote python packages. Please contact developer.") + + +class SlicerPythonDependencies(DependenciesBase): + + def installedMONAIPythonPackageInfo(self): + versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode() + return versionInfo + + def setupPythonRequirements(self, upgrade=False): + # Install PyTorch + try: + import PyTorchUtils + except ModuleNotFoundError as e: + raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.") + + logger.info("Initializing PyTorch...") + + torchLogic = PyTorchUtils.PyTorchUtilsLogic() + if not torchLogic.torchInstalled(): + logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") + torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement=f">={self.minimumTorchVersion}") + if torch is None: + raise ValueError("PyTorch extension needs to be installed to use this module.") + else: # torch is installed, check version + from packaging import version + if version.parse(torchLogic.torch.__version__) < version.parse(self.minimumTorchVersion): + raise ValueError(f"PyTorch version {torchLogic.torch.__version__} is not compatible with this module." + + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" + + f" with version requirement set to: >={self.minimumTorchVersion}") + + # Install MONAI with required components + logger.info("Initializing MONAI...") + # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too). + # Without this, for some users monai-0.9.0 got installed, which failed with this error: + # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'" + monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" + if upgrade: + monaiInstallString += " --upgrade" + import slicer + slicer.util.pip_install(monaiInstallString) + + self.dependenciesInstalled = True + logger.info("Dependencies are set up successfully.") \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py new file mode 100644 index 0000000..4e3540f --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py @@ -0,0 +1,27 @@ +import logging +from typing import Callable + + +class LogHandler(logging.Handler): + """ + + code: + + logger = logging.getLogger("XYZ") + + callback = ... # any callable + # NB: only catching info level messages and forwarding it to callback + handler = LogHandler(callback, logging.INFO) + # can format log messages + formatter = logging.Formatter('%(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + """ + def __init__(self, callback: Callable, level=logging.NOTSET): + self._callback = callback + super().__init__(level) + + def emit(self, record): + msg = self.format(record) + self._callback(msg) \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py new file mode 100644 index 0000000..d9ef612 --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py @@ -0,0 +1,196 @@ +import json +import logging +import os +from pathlib import Path + +from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec +from MONAIAuto3DSegLib.constants import APPLICATION_NAME + + +logger = logging.getLogger(APPLICATION_NAME) + + +class ModelDatabase: + + @property + def defaultModel(self): + return self.models[0]["id"] + + @property + def models(self): + if not self._models: + self._models = self.loadModelsDescription() + return self._models + + @property + def modelsPath(self): + modelsPath = self.fileCachePath.joinpath("models") + modelsPath.mkdir(exist_ok=True, parents=True) + return modelsPath + + @property + def modelsDescriptionJsonFilePath(self): + return os.path.join(self.moduleDir, "Resources", "Models.json") + + def __init__(self): + self.fileCachePath = Path.home().joinpath(f".{APPLICATION_NAME}") + self.moduleDir = Path(__file__).parent.parent + + # Disabling this flag preserves input and output data after execution is completed, + # which can be useful for troubleshooting. + self.clearOutputFolder = True + self._models = [] + + def model(self, modelId): + for model in self.models: + if model["id"] == modelId: + return model + raise RuntimeError(f"Model {modelId} not found") + + def loadModelsDescription(self): + modelsJsonFilePath = self.modelsDescriptionJsonFilePath + try: + models = [] + import json + import re + with open(modelsJsonFilePath) as f: + modelsTree = json.load(f)["models"] + for model in modelsTree: + deprecated = False + for version in model["versions"]: + url = version["url"] + # URL format: /-v.zip + # Example URL: https://github.com/lassoan/SlicerMONAIAuto3DSeg/releases/download/Models/17-segments-TotalSegmentator-v1.0.3.zip + match = re.search(r"(?P[^/]+)-v(?P\d+\.\d+\.\d+)", url) + if match: + filename = match.group("filename") + version = match.group("version") + else: + logger.error(f"Failed to extract model id and version from url: {url}") + if "inputs" in model: + # Contains a list of dict. One dict for each input. + # Currently, only "title" (user-displayable name) and "namePattern" of the input are specified. + # In the future, inputs could have additional properties, such as name, type, optional, ... + inputs = model["inputs"] + else: + # Inputs are not defined, use default (single input volume) + inputs = [{"title": "Input volume"}] + segmentNames = model.get('segmentNames') + if not segmentNames: + segmentNames = "N/A" + models.append({ + "id": f"{filename}-v{version}", + "title": model['title'], + "version": version, + "inputs": inputs, + "imagingModality": model["imagingModality"], + "description": model["description"], + "sampleData": model.get("sampleData"), + "segmentNames": model.get("segmentNames"), + "details": + f"

Model: {model['title']} (v{version})" + f"

Description: {model['description']}\n" + f"

Computation time on GPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecGPU'))}\n" + f"
Computation time on CPU: {humanReadableTimeFromSec(model.get('segmentationTimeSecCPU'))}\n" + f"

Imaging modality: {model['imagingModality']}\n" + f"

Subject: {model['subject']}\n" + f"

Segments: {', '.join(segmentNames)}", + "url": url, + "deprecated": deprecated + }) + # First version is not deprecated, all subsequent versions are deprecated + deprecated = True + return models + except Exception as e: + import traceback + traceback.print_exc() + raise RuntimeError(f"Failed to load models description from {modelsJsonFilePath}") + + def updateModelsDescriptionJsonFilePathFromTestResults(self, modelsTestResultsJsonFilePath): + modelsDescriptionJsonFilePath = self.modelsDescriptionJsonFilePath + + with open(modelsTestResultsJsonFilePath) as f: + modelsTestResults = json.load(f) + + with open(modelsDescriptionJsonFilePath) as f: + modelsDescription = json.load(f) + + for model in modelsDescription["models"]: + title = model["title"] + for modelTestResult in modelsTestResults: + if modelTestResult["title"] == title: + for fieldName in ["segmentationTimeSecGPU", "segmentationTimeSecCPU", "segmentNames"]: + fieldValue = modelTestResult.get(fieldName) + if fieldValue: + model[fieldName] = fieldValue + break + + with open(modelsDescriptionJsonFilePath, 'w', newline="\n") as f: + json.dump(modelsDescription, f, indent=2) + + def createModelsDir(self): + modelsDir = self.modelsPath + if not os.path.exists(modelsDir): + os.makedirs(modelsDir) + + def modelPath(self, modelName): + try: + return self._modelPath(modelName) + except RuntimeError: + self.downloadModel(modelName) + return self._modelPath(modelName) + + def _modelPath(self, modelName): + modelRoot = self.modelsPath.joinpath(modelName) + # find labels.csv file within the modelRoot folder and subfolders + for path in Path(modelRoot).rglob("labels.csv"): + return path.parent + raise RuntimeError(f"Model {modelName} path not found") + + def deleteAllModels(self): + if self.modelsPath.exists(): + import shutil + shutil.rmtree(self.modelsPath) + + def downloadModel(self, modelName): + url = self.model(modelName)["url"] + import zipfile + import requests + from tempfile import TemporaryDirectory + with TemporaryDirectory() as td: + tempDir = Path(td) + modelDir = self.modelsPath.joinpath(modelName) + Path(modelDir).mkdir(exist_ok=True) + modelZipFile = tempDir.joinpath("autoseg3d_model.zip") + logger.info(f"Downloading model '{modelName}' from {url}...") + logger.debug(f"Downloading from {url} to {modelZipFile}...") + try: + with open(modelZipFile, 'wb') as f: + with requests.get(url, stream=True) as r: + r.raise_for_status() + total_size = int(r.headers.get('content-length', 0)) + reporting_increment_percent = 1.0 + last_reported_download_percent = -reporting_increment_percent + downloaded_size = 0 + for chunk in r.iter_content(chunk_size=8192 * 16): + f.write(chunk) + downloaded_size += len(chunk) + downloaded_percent = 100.0 * downloaded_size / total_size + if downloaded_percent - last_reported_download_percent > reporting_increment_percent: + logger.info( + f"Downloading model: {downloaded_size / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB ({downloaded_percent:.1f}%)") + last_reported_download_percent = downloaded_percent + + logger.info(f"Download finished. Extracting to {modelDir}...") + with zipfile.ZipFile(modelZipFile, 'r') as zip_f: + zip_f.extractall(modelDir) + except Exception as e: + raise e + finally: + if self.clearOutputFolder: + logger.info("Cleaning up temporary model download folder...") + if os.path.isdir(tempDir): + import shutil + shutil.rmtree(tempDir) + else: + logger.info(f"Not cleaning up temporary model download folder: {tempDir}") diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py new file mode 100644 index 0000000..ec5bdfc --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py @@ -0,0 +1,95 @@ +import slicer +import psutil + +import logging +import queue +import threading + +import qt +from pathlib import Path +from typing import Callable + +from MONAIAuto3DSegLib.constants import APPLICATION_NAME +logger = logging.getLogger(APPLICATION_NAME) + + +class WebServer: + + CHECK_TIMER_INTERVAL = 1000 + + @staticmethod + def getPSProcess(pid): + try: + return psutil.Process(pid) + except psutil.NoSuchProcess: + return None + + def __init__(self, completedCallback: Callable): + self.completedCallback = completedCallback + self.procThread = None + self.serverProc = None + self.queue = None + + def isRunning(self): + if self.serverProc is not None: + psProcess = self.getPSProcess(self.serverProc.pid) + if psProcess: + return psProcess.is_running() + return False + + def __del__(self): + self.killProcess() + + def killProcess(self): + if not self.serverProc: + return + psProcess = self.getPSProcess(self.serverProc.pid) + if not psProcess: + return + for psChildProcess in psProcess.children(recursive=True): + psChildProcess.kill() + if psProcess.is_running(): + psProcess.kill() + + def launchConsoleProcess(self, cmd): + self.serverProc = \ + slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "auto3dseg", useStartupEnvironment=False) + + self.queue = queue.Queue() + self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess) + self.procThread.start() + self.checkProcessOutput() + + def cleanup(self): + if self.procThread: + self.procThread.join() + self.completedCallback() + self.serverProc = None + self.procThread = None + self.queue = None + + def _handleProcessOutputThreadProcess(self): + while True: + try: + line = self.serverProc.stdout.readline() + if not line: + break + self.queue.put(line.rstrip()) + except UnicodeDecodeError as e: + pass + self.serverProc.wait() + + def checkProcessOutput(self): + outputQueue = self.queue + while outputQueue: + try: + line = outputQueue.get_nowait() + logger.info(line) + except queue.Empty: + break + + psProcess = self.getPSProcess(self.serverProc.pid) + if psProcess and psProcess.is_running(): # No more outputs to process now, check again later + qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput) + else: + self.cleanup() \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py new file mode 100644 index 0000000..186c9cc --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py @@ -0,0 +1,35 @@ + + +def humanReadableTimeFromSec(seconds): + import math + if not seconds: + return "N/A" + if seconds < 55: + # if less than a minute, round up to the nearest 5 seconds + return f"{math.ceil(seconds / 5) * 5} sec" + elif seconds < 60 * 60: + # if less than 1 hour, round up to the nearest minute + return f"{math.ceil(seconds / 60)} min" + # Otherwise round up to the nearest 0.1 hour + return f"{seconds / 3600:.1f} h" + + +def assignInputNodesByName(inputs, loadedSampleNodes): + inputNodes = [] + for inputIndex, input in enumerate(inputs): + namePattern = input.get("namePattern") + if namePattern: + matchingNode = findFirstNodeByNamePattern(namePattern, loadedSampleNodes) + else: + matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else \ + loadedSampleNodes[0] + inputNodes.append(matchingNode) + return inputNodes + + +def findFirstNodeByNamePattern(namePattern, nodes): + import fnmatch + for node in nodes: + if fnmatch.fnmatchcase(node.GetName(), namePattern): + return node + return None \ No newline at end of file diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index d57b851..1c4fd3c 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -6,8 +6,8 @@ 0 0 - 402 - 653 + 408 + 674 @@ -17,7 +17,7 @@ Inputs - + @@ -39,14 +39,14 @@ - + Input volume 1: - + @@ -67,14 +67,14 @@ - + Input volume 2: - + @@ -95,14 +95,14 @@ - + Input volume 3: - + @@ -123,14 +123,14 @@ - + Input volume 4: - + @@ -151,7 +151,7 @@ - + Download sample data set for the current segmentation model @@ -161,7 +161,7 @@ - + @@ -182,13 +182,47 @@ - + Segmentation model: + + + + + + + 0 + 0 + + + + true + + + + + + + QPushButton:checked { + /* Checked style */ + background-color: green; /* Green background when checked */ + border-color: darkgreen; /* Darker border color when checked */ +} + + + Connect + + + true + + + + + @@ -401,18 +435,106 @@ - - - QPlainTextEdit::NoWrap - - - true + + + Server Address: - - Qt::TextSelectableByMouse + + false + + + + + Port: + + + + + + + 0 + + + 65535 + + + 1 + + + + + + + Log to Console: + + + + + + + + + + + + + + Log to GUI: + + + + + + + + + + + + + + + + + + + + + QPushButton:checked { + /* Checked style */ + background-color: green; /* Green background when checked */ + border-color: darkgreen; /* Darker border color when checked */ +} + + + Start server + + + true + + + + + + + true + + + + + + + Server address: + + + + + + + @@ -437,6 +559,12 @@ QLineEdit

ctkSearchBox.h
+ + qMRMLCollapsibleButton + ctkCollapsibleButton +
qMRMLCollapsibleButton.h
+ 1 +
qMRMLNodeComboBox QWidget @@ -536,5 +664,21 @@ + + remoteServerButton + toggled(bool) + serverCollapsibleButton + setDisabled(bool) + + + 356 + 49 + + + 203 + 533 + + + diff --git a/MONAIAuto3DSeg/auto3dseg/__init__.py b/MONAIAuto3DSeg/auto3dseg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MONAIAuto3DSeg/auto3dseg/main.py b/MONAIAuto3DSeg/auto3dseg/main.py new file mode 100644 index 0000000..9be35ab --- /dev/null +++ b/MONAIAuto3DSeg/auto3dseg/main.py @@ -0,0 +1,139 @@ +# pip install python-multipart +# pip install fastapi +# pip install "uvicorn[standard]" + +# usage: uvicorn main:app --reload --host reslnjolleyws03.research.chop.edu --port 8891 +# usage: uvicorn main:app --reload --host localhost --port 8891 + + +import os +import logging +import sys +from pathlib import Path + +paths = [str(Path(__file__).parent.parent)] +for path in paths: + if not path in sys.path: + sys.path.insert(0, path) + +from MONAIAuto3DSegLib.model_database import ModelDatabase +from MONAIAuto3DSegLib.dependency_handler import LocalPythonDependencies +from MONAIAuto3DSegLib.constants import APPLICATION_NAME + +import shutil +import asyncio +import subprocess +from fastapi import FastAPI, UploadFile +from fastapi.responses import FileResponse +from fastapi import HTTPException +from fastapi.background import BackgroundTasks + + +app = FastAPI() +modelDB = ModelDatabase() +dependencyHandler = LocalPythonDependencies() + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(APPLICATION_NAME) + + +def upload(file, session_dir, identifier): + extension = "".join(Path(file.filename).suffixes) + file_location = f"{session_dir}/{identifier}{extension}" + with open(file_location, "wb+") as file_object: + file_object.write(file.file.read()) + return file_location + + +@app.get("/monaiinfo") +def monaiInfo(): + return dependencyHandler.installedMONAIPythonPackageInfo() + + +@app.get("/models") +async def models(): + return modelDB.models + + +@app.get("/modelinfo") +async def getModelInfo(id: str): + return modelDB.model(id) + + +@app.get("/labelDescriptions") +def getLabelsFile(id: str): + return FileResponse(modelDB.modelPath(id).joinpath("labels.csv"), media_type = 'application/octet-stream', filename="labels.csv") + + +@app.post("/infer") +async def infer( + background_tasks: BackgroundTasks, + image_file: UploadFile, + model_name: str, + image_file_2: UploadFile = None, + image_file_3: UploadFile = None, + image_file_4: UploadFile = None +): + import tempfile + session_dir = tempfile.mkdtemp(dir=tempfile.gettempdir()) + background_tasks.add_task(shutil.rmtree, session_dir) + + logging.debug(session_dir) + inputFiles = list() + inputFiles.append(upload(image_file, session_dir, "image_file")) + if image_file_2: + inputFiles.append(upload(image_file_2, session_dir, "image_file_2")) + if image_file_3: + inputFiles.append(upload(image_file_3, session_dir, "image_file_3")) + if image_file_4: + inputFiles.append(upload(image_file_4, session_dir, "image_file_4")) + + # logging.info("Input Files: ", inputFiles) + + outputSegmentationFile = f"{session_dir}/output-segmentation.nrrd" + + modelPath = modelDB.modelPath(model_name) + modelPtFile = modelPath.joinpath("model.pt") + + assert os.path.exists(modelPtFile) + + dependencyHandler.setupPythonRequirements() + + moduleDir = Path(__file__).parent.parent + inferenceScriptPyFile = os.path.join(moduleDir, "Scripts", "auto3dseg_segresnet_inference.py") + auto3DSegCommand = [sys.executable, str(inferenceScriptPyFile), + "--model-file", str(modelPtFile), + "--image-file", inputFiles[0], + "--result-file", str(outputSegmentationFile)] + for inputIndex in range(1, len(inputFiles)): + auto3DSegCommand.append(f"--image-file-{inputIndex + 1}") + auto3DSegCommand.append(inputFiles[inputIndex]) + + try: + # logger.info(auto3DSegCommand) + proc = await asyncio.create_subprocess_shell(" ".join(auto3DSegCommand)) + await proc.wait() + if proc.returncode != 0: + raise subprocess.CalledProcessError(proc.returncode, " ".join(auto3DSegCommand)) + return FileResponse(outputSegmentationFile, media_type='application/octet-stream', background=background_tasks) + except Exception as e: + shutil.rmtree(session_dir) + raise HTTPException(status_code=500, detail=f"Failed to run CMD command: {str(e)}") + + +def main(argv): + import argparse + parser = argparse.ArgumentParser(description="MONAIAuto3DSeg server") + parser.add_argument("-ip", "--host", type=str, metavar="PATH", required=False, default="localhost", help="host name") + parser.add_argument("-p", "--port", type=int, metavar="PATH", required=True, help="port") + + args = parser.parse_args(argv) + + import uvicorn + + uvicorn.run("main:app", host=args.host, port=args.port, reload=True, log_level="debug") + + +if __name__ == "__main__": + main(sys.argv[1:]) From 24c200f6dc904c25d0254513d2dc901829655e0e Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Thu, 19 Sep 2024 12:15:01 -0400 Subject: [PATCH 02/27] BUG: kill web server when Slicer is closed Signed-off-by: Christian Herz --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 17533aa..70f0e56 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -280,6 +280,8 @@ def cleanup(self): """ Called when the application closes and the module widget is destroyed. """ + if self._webServer: + self._webServer.killProcess() self.removeObservers() def enter(self): From 7e3d099bbd5912b390ef247b80368d8f7301d4b5 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Thu, 19 Sep 2024 15:45:15 -0400 Subject: [PATCH 03/27] BUG: fix server address connection --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 70f0e56..6320b04 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -261,7 +261,7 @@ def setup(self): self.ui.browseToModelsFolderButton.connect("clicked(bool)", self.onBrowseModelsFolder) self.ui.deleteAllModelsButton.connect("clicked(bool)", self.onClearModelsFolder) - self.ui.serverComboBox.lineEdit().setPlaceholderText("enter server address or leave empty to use default") + self.ui.serverComboBox.lineEdit().setPlaceholderText("Enter server address") self.ui.serverComboBox.currentIndexChanged.connect(self.onRemoteServerButtonToggled) self.ui.remoteServerButton.toggled.connect(self.onRemoteServerButtonToggled) @@ -672,7 +672,7 @@ def onClearModelsFolder(self): slicer.util.messageBox("Downloaded models are deleted.") def onRemoteServerButtonToggled(self): - if self.ui.remoteServerButton.checked: + if self.ui.remoteServerButton.checked and self.ui.serverComboBox.currentText != '': self.ui.remoteServerButton.text = "Connected" self.logic = RemoteMONAIAuto3DSegLogic() self.logic.server_address = self.ui.serverComboBox.currentText @@ -688,6 +688,7 @@ def onRemoteServerButtonToggled(self): return self.saveServerUrl() else: + self.ui.remoteServerButton.checked = False self.ui.remoteServerButton.text = "Connect" self.logic = MONAIAuto3DSegLogic() From b85eea58f1065f1a9ab72cbc1bd9a544a3ad3852 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Fri, 4 Oct 2024 15:35:25 -0400 Subject: [PATCH 04/27] STYLE: preventing to show all logging to user and addressed other comments --- MONAIAuto3DSeg/CMakeLists.txt | 6 +- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 96 +++++---- MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py | 1 - .../MONAIAuto3DSegLib/dependency_handler.py | 74 ++----- .../MONAIAuto3DSegLib/log_handler.py | 27 --- .../MONAIAuto3DSegLib/model_database.py | 25 +-- MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py | 184 ++++++++++-------- MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py | 21 -- .../__init__.py | 0 .../main.py | 65 ++++++- 10 files changed, 248 insertions(+), 251 deletions(-) delete mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py delete mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py rename MONAIAuto3DSeg/{auto3dseg => MONAIAuto3DSegServer}/__init__.py (100%) rename MONAIAuto3DSeg/{auto3dseg => MONAIAuto3DSegServer}/main.py (62%) diff --git a/MONAIAuto3DSeg/CMakeLists.txt b/MONAIAuto3DSeg/CMakeLists.txt index 3d8a9c9..322c24e 100644 --- a/MONAIAuto3DSeg/CMakeLists.txt +++ b/MONAIAuto3DSeg/CMakeLists.txt @@ -5,14 +5,12 @@ set(MODULE_NAME MONAIAuto3DSeg) set(MODULE_PYTHON_SCRIPTS ${MODULE_NAME}.py ${MODULE_NAME}Lib/__init__.py - ${MODULE_NAME}Lib/constants.py ${MODULE_NAME}Lib/dependency_handler.py - ${MODULE_NAME}Lib/log_handler.py ${MODULE_NAME}Lib/model_database.py ${MODULE_NAME}Lib/server.py ${MODULE_NAME}Lib/utils.py - auto3dseg/__init__py - auto3dseg/main.py + ${MODULE_NAME}Server/__init__py + ${MODULE_NAME}Server/main.py ) set(MODULE_PYTHON_RESOURCES diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 70f0e56..84ef7c4 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -8,17 +8,14 @@ import qt import slicer import requests +from typing import Callable from slicer.ScriptedLoadableModule import * from slicer.util import VTKObservationMixin from MONAIAuto3DSegLib.model_database import ModelDatabase -from MONAIAuto3DSegLib.constants import APPLICATION_NAME -from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec, assignInputNodesByName +from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec from MONAIAuto3DSegLib.dependency_handler import SlicerPythonDependencies, RemotePythonDependencies -logger = logging.getLogger(APPLICATION_NAME) - - # # MONAIAuto3DSeg # @@ -193,12 +190,6 @@ def __init__(self, parent=None): self._segmentationProcessInfo = None self._webServer = None - from MONAIAuto3DSegLib.log_handler import LogHandler - handler = LogHandler(self.addLog, logging.INFO) - formatter = logging.Formatter('%(levelname)s - %(message)s') - handler.setFormatter(formatter) - logger.addHandler(handler) - def onReload(self): if self._webServer: self._webServer.killProcess() @@ -261,7 +252,7 @@ def setup(self): self.ui.browseToModelsFolderButton.connect("clicked(bool)", self.onBrowseModelsFolder) self.ui.deleteAllModelsButton.connect("clicked(bool)", self.onClearModelsFolder) - self.ui.serverComboBox.lineEdit().setPlaceholderText("enter server address or leave empty to use default") + self.ui.serverComboBox.lineEdit().setPlaceholderText("Enter server address") self.ui.serverComboBox.currentIndexChanged.connect(self.onRemoteServerButtonToggled) self.ui.remoteServerButton.toggled.connect(self.onRemoteServerButtonToggled) @@ -637,7 +628,7 @@ def onDownloadSampleData(self): slicer.util.messageBox(f"Failed to load sample data set '{sampleDataName}'.") return - inputNodes = assignInputNodesByName(inputs, loadedSampleNodes) + inputNodes = self.logic.assignInputNodesByName(inputs, loadedSampleNodes) for inputIndex, inputNode in enumerate(inputNodes): if inputNode: self.inputNodeSelectors[inputIndex].setCurrentNode(inputNode) @@ -672,7 +663,7 @@ def onClearModelsFolder(self): slicer.util.messageBox("Downloaded models are deleted.") def onRemoteServerButtonToggled(self): - if self.ui.remoteServerButton.checked: + if self.ui.remoteServerButton.checked and self.ui.serverComboBox.currentText != '': self.ui.remoteServerButton.text = "Connected" self.logic = RemoteMONAIAuto3DSegLogic() self.logic.server_address = self.ui.serverComboBox.currentText @@ -688,6 +679,7 @@ def onRemoteServerButtonToggled(self): return self.saveServerUrl() else: + self.ui.remoteServerButton.checked = False self.ui.remoteServerButton.text = "Connect" self.logic = MONAIAuto3DSegLogic() @@ -710,15 +702,14 @@ def onServerButtonToggled(self, toggled): self.ui.serverAddressLineEdit.text = f"http://{hostName}:{port}" from MONAIAuto3DSegLib.server import WebServer - self._webServer = WebServer(completedCallback=lambda: self.ui.serverButton.setChecked(False)) + self._webServer = WebServer( + logCallback=self.addLog, + completedCallback=lambda: self.ui.serverButton.setChecked(False) + ) self._webServer.launchConsoleProcess(cmd) else: if self._webServer is not None and self._webServer.isRunning(): - logger.info("Server stop requested.") self._webServer.killProcess() - if not self._webServer.isRunning(): - logger.info("Server stopped.") - self._webServer = None self.updateGUIFromParameterNode() @@ -782,6 +773,27 @@ class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic, ModelDatabase): DEPENDENCY_HANDLER = SlicerPythonDependencies() + @staticmethod + def assignInputNodesByName(inputs, loadedSampleNodes): + + def findFirstNodeByNamePattern(namePattern, nodes): + import fnmatch + for node in nodes: + if fnmatch.fnmatchcase(node.GetName(), namePattern): + return node + return None + + inputNodes = [] + for inputIndex, input in enumerate(inputs): + namePattern = input.get("namePattern") + if namePattern: + matchingNode = findFirstNodeByNamePattern(namePattern, loadedSampleNodes) + else: + matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else \ + loadedSampleNodes[0] + inputNodes.append(matchingNode) + return inputNodes + @staticmethod def getLoadedTerminologyNames(): import vtk @@ -1042,7 +1054,7 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo): line = proc.stdout.readline() if not line: break - logger.info(line.rstrip()) + logging.info(line.rstrip()) except UnicodeDecodeError as e: # Code page conversion happens because `universal_newlines=True` sets process output to text mode, # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string, @@ -1085,7 +1097,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor import time startTime = time.time() - logger.info("Processing started") + logging.info("Processing started") if self.debugSkipInference: # For debugging, use a fixed temporary folder @@ -1105,7 +1117,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor for inputIndex, inputNode in enumerate(inputNodes): if inputNode.IsA('vtkMRMLScalarVolumeNode'): inputImageFile = tempDir + f"/input-volume{inputIndex}.nrrd" - logger.info(f"Writing input file to {inputImageFile}") + logging.info(f"Writing input file to {inputImageFile}") volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode") volumeStorageNode.SetFileName(inputImageFile) volumeStorageNode.UseCompressionOff() @@ -1126,13 +1138,13 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor auto3DSegCommand.append(f"--image-file-{inputIndex+1}") auto3DSegCommand.append(inputFiles[inputIndex]) - logger.info("Creating segmentations with MONAIAuto3DSeg AI...") - logger.info(f"Auto3DSeg command: {auto3DSegCommand}") + logging.info("Creating segmentations with MONAIAuto3DSeg AI...") + logging.info(f"Auto3DSeg command: {auto3DSegCommand}") additionalEnvironmentVariables = None if cpu: additionalEnvironmentVariables = {"CUDA_VISIBLE_DEVICES": "-1"} - logger.info(f"Additional environment variables: {additionalEnvironmentVariables}") + logging.info(f"Additional environment variables: {additionalEnvironmentVariables}") if self.debugSkipInference: proc = None @@ -1166,7 +1178,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor return segmentationProcessInfo def cancelProcessing(self, segmentationProcessInfo): - logger.info("Cancel is requested.") + logging.info("Cancel is requested.") segmentationProcessInfo["cancelRequested"] = True proc = segmentationProcessInfo.get("proc") if proc: @@ -1218,7 +1230,7 @@ def checkSegmentationProcessOutput(self, segmentationProcessInfo): return try: line = outputQueue.get_nowait() - logger.info(line) + logging.info(line) except queue.Empty: break @@ -1238,14 +1250,14 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): if cancelRequested: procReturnCode = MONAIAuto3DSegLogic.EXIT_CODE_USER_CANCELLED - logger.info(f"Processing was cancelled.") + logging.info(f"Processing was cancelled.") else: if procReturnCode == 0: if self.startResultImportCallback: self.startResultImportCallback(customData) try: # Load result - logger.info("Importing segmentation results...") + logging.info("Importing segmentation results...") self.readSegmentation(outputSegmentation, outputSegmentationFile, model) # Set source volume - required for DICOM Segmentation export @@ -1266,15 +1278,15 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): if self.endResultImportCallback: self.endResultImportCallback(customData) else: - logger.info(f"Processing failed with return code {procReturnCode}") + logging.info(f"Processing failed with return code {procReturnCode}") if self.clearOutputFolder: - logger.info("Cleaning up temporary folder.") + logging.info("Cleaning up temporary folder.") if os.path.isdir(tempDir): import shutil shutil.rmtree(tempDir) else: - logger.info(f"Not cleaning up temporary folder: {tempDir}") + logging.info(f"Not cleaning up temporary folder: {tempDir}") # Report total elapsed time import time @@ -1282,12 +1294,12 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): segmentationProcessInfo["stopTime"] = stopTime elapsedTime = stopTime - startTime if cancelRequested: - logger.info(f"Processing was cancelled after {elapsedTime:.2f} seconds.") + logging.info(f"Processing was cancelled after {elapsedTime:.2f} seconds.") else: if procReturnCode == 0: - logger.info(f"Processing was completed in {elapsedTime:.2f} seconds.") + logging.info(f"Processing was completed in {elapsedTime:.2f} seconds.") else: - logger.info(f"Processing failed after {elapsedTime:.2f} seconds.") + logging.info(f"Processing failed after {elapsedTime:.2f} seconds.") if self.processingCompletedCallback: self.processingCompletedCallback(procReturnCode, customData) @@ -1343,7 +1355,7 @@ def setTerminology(self, segmentation, segmentId, terminologyEntryStr): segment.SetName(label) segment.SetColor(color) except RuntimeError as e: - logger.info(str(e)) + logging.info(str(e)) class RemoteMONAIAuto3DSegLogic(MONAIAuto3DSegLogic): @@ -1356,6 +1368,7 @@ def server_address(self): @server_address.setter def server_address(self, address): + self.DEPENDENCY_HANDLER.server_address = address self._server_address = address self._models = [] @@ -1413,7 +1426,7 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF import time startTime = time.time() - logger.info("Processing started") + logging.info("Processing started") tempDir = slicer.util.tempDirectory() outputSegmentationFile = tempDir + "/output-segmentation.nrrd" @@ -1427,7 +1440,7 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF for inputIndex, inputNode in enumerate(inputNodes): if inputNode.IsA('vtkMRMLScalarVolumeNode'): inputImageFile = tempDir / f"input-volume{inputIndex}.nrrd" - logger.info(f"Writing input file to {inputImageFile}") + logging.info(f"Writing input file to {inputImageFile}") volumeStorageNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLVolumeArchetypeStorageNode") volumeStorageNode.SetFileName(inputImageFile) volumeStorageNode.UseCompressionOff() @@ -1440,7 +1453,7 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF segmentationProcessInfo = {} segmentationProcessInfo["procReturnCode"] = MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN - logger.info(f"Initiating Inference on {self._server_address}") + logging.info(f"Initiating Inference on {self._server_address}") files = {} try: @@ -1475,6 +1488,7 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF return segmentationProcessInfo + # # MONAIAuto3DSegTest # @@ -1583,7 +1597,7 @@ def test_MONAIAuto3DSeg1(self): # Set model inputs inputs = model.get("inputs") - inputNodes = assignInputNodesByName(inputs, loadedSampleNodes) + inputNodes = logic.assignInputNodesByName(inputs, loadedSampleNodes) outputSegmentation = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") @@ -1743,4 +1757,4 @@ def _writeTestResultsToMarkdown(self, modelsTestResultsJsonFilePath, modelsTestR f.write(f"Processing time: {humanReadableTimeFromSec(model['segmentationTimeSecGPU'])} on GPU, {humanReadableTimeFromSec(model['segmentationTimeSecCPU'])} on CPU\n\n") f.write(f"Segment names: {', '.join(model['segmentNames'])}\n\n") f.write(f"![2D view]({screenshotUrlBase}{model['segmentationResultsScreenshot2D']})\n") - f.write(f"![3D view]({screenshotUrlBase}{model['segmentationResultsScreenshot3D']})\n") \ No newline at end of file + f.write(f"![3D view]({screenshotUrlBase}{model['segmentationResultsScreenshot3D']})\n") diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py deleted file mode 100644 index 63c5e7d..0000000 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/constants.py +++ /dev/null @@ -1 +0,0 @@ -APPLICATION_NAME = "MONAIAuto3DSeg" diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py index 438f120..9c75887 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py @@ -1,12 +1,7 @@ -import sys - import shutil import subprocess import logging -from MONAIAuto3DSegLib.constants import APPLICATION_NAME -logger = logging.getLogger(APPLICATION_NAME) - from abc import ABC, abstractmethod @@ -27,67 +22,36 @@ def setupPythonRequirements(self, upgrade=False): pass -class LocalPythonDependencies(DependenciesBase): - - def installedMONAIPythonPackageInfo(self): - versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode() - return versionInfo - - def _checkModuleInstalled(self, moduleName): - try: - import importlib - importlib.import_module(moduleName) - return True - except ModuleNotFoundError: - return False - - def setupPythonRequirements(self, upgrade=False): - def install(package): - subprocess.check_call([sys.executable, "-m", "pip", "install", package]) - - logger.info("Initializing PyTorch...") - - packageName = "torch" - if not self._checkModuleInstalled(packageName): - logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") - install(packageName) - if not self._checkModuleInstalled(packageName): - raise ValueError("pytorch needs to be installed to use this module.") - else: # torch is installed, check version - from packaging import version - import torch - if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion): - raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module." - + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" - + f" with version requirement set to: >={self.minimumTorchVersion}") - - logger.info("Initializing MONAI...") - monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" - if upgrade: - monaiInstallString += " --upgrade" - install(monaiInstallString) +class RemotePythonDependencies(DependenciesBase): - self.dependenciesInstalled = True - logger.info("Dependencies are set up successfully.") + def __init__(self): + super().__init__() + self._server_address = None + @property + def server_address(self): + return self._server_address -class RemotePythonDependencies(DependenciesBase): + @server_address.setter + def server_address(self, address): + self._server_address = address - def installedMONAIPythonPackageInfo(self, server_address): - if not server_address: + def installedMONAIPythonPackageInfo(self): + if not self._server_address: return [] else: import json import requests - response = requests.get(server_address + "/monaiinfo") + response = requests.get(self._server_address + "/monaiinfo") json_data = json.loads(response.text) return json_data def setupPythonRequirements(self, upgrade=False): - logger.error("No permission to update remote python packages. Please contact developer.") + logging.error("No permission to update remote python packages. Please contact developer.") class SlicerPythonDependencies(DependenciesBase): + """ Dependency handler when being used within 3D Slicer (SlicerPython) environment. """ def installedMONAIPythonPackageInfo(self): versionInfo = subprocess.check_output([shutil.which("PythonSlicer"), "-m", "pip", "show", "MONAI"]).decode() @@ -100,11 +64,11 @@ def setupPythonRequirements(self, upgrade=False): except ModuleNotFoundError as e: raise RuntimeError("This module requires PyTorch extension. Install it from the Extensions Manager.") - logger.info("Initializing PyTorch...") + logging.info("Initializing PyTorch...") torchLogic = PyTorchUtils.PyTorchUtilsLogic() if not torchLogic.torchInstalled(): - logger.info("PyTorch Python package is required. Installing... (it may take several minutes)") + logging.info("PyTorch Python package is required. Installing... (it may take several minutes)") torch = torchLogic.installTorch(askConfirmation=True, torchVersionRequirement=f">={self.minimumTorchVersion}") if torch is None: raise ValueError("PyTorch extension needs to be installed to use this module.") @@ -116,7 +80,7 @@ def setupPythonRequirements(self, upgrade=False): + f" with version requirement set to: >={self.minimumTorchVersion}") # Install MONAI with required components - logger.info("Initializing MONAI...") + logging.info("Initializing MONAI...") # Specify minimum version 1.3, as this is a known working version (it is possible that an earlier version works, too). # Without this, for some users monai-0.9.0 got installed, which failed with this error: # "ImportError: cannot import name ‘MetaKeys’ from 'monai.utils'" @@ -127,4 +91,4 @@ def setupPythonRequirements(self, upgrade=False): slicer.util.pip_install(monaiInstallString) self.dependenciesInstalled = True - logger.info("Dependencies are set up successfully.") \ No newline at end of file + logging.info("Dependencies are set up successfully.") \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py deleted file mode 100644 index 4e3540f..0000000 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/log_handler.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging -from typing import Callable - - -class LogHandler(logging.Handler): - """ - - code: - - logger = logging.getLogger("XYZ") - - callback = ... # any callable - # NB: only catching info level messages and forwarding it to callback - handler = LogHandler(callback, logging.INFO) - # can format log messages - formatter = logging.Formatter('%(levelname)s - %(message)s') - handler.setFormatter(formatter) - logger.addHandler(handler) - - """ - def __init__(self, callback: Callable, level=logging.NOTSET): - self._callback = callback - super().__init__(level) - - def emit(self, record): - msg = self.format(record) - self._callback(msg) \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py index d9ef612..c9c599c 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py @@ -4,13 +4,14 @@ from pathlib import Path from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec -from MONAIAuto3DSegLib.constants import APPLICATION_NAME - - -logger = logging.getLogger(APPLICATION_NAME) class ModelDatabase: + """ Retrieve model information, download and store models in model cache directory. Model information is stored + in local Models.json. + """ + + DEFAULT_CACHE_DIR_NAME = ".MONAIAuto3DSeg" @property def defaultModel(self): @@ -33,7 +34,7 @@ def modelsDescriptionJsonFilePath(self): return os.path.join(self.moduleDir, "Resources", "Models.json") def __init__(self): - self.fileCachePath = Path.home().joinpath(f".{APPLICATION_NAME}") + self.fileCachePath = Path.home().joinpath(f"{self.DEFAULT_CACHE_DIR_NAME}") self.moduleDir = Path(__file__).parent.parent # Disabling this flag preserves input and output data after execution is completed, @@ -66,7 +67,7 @@ def loadModelsDescription(self): filename = match.group("filename") version = match.group("version") else: - logger.error(f"Failed to extract model id and version from url: {url}") + logging.error(f"Failed to extract model id and version from url: {url}") if "inputs" in model: # Contains a list of dict. One dict for each input. # Currently, only "title" (user-displayable name) and "namePattern" of the input are specified. @@ -162,8 +163,8 @@ def downloadModel(self, modelName): modelDir = self.modelsPath.joinpath(modelName) Path(modelDir).mkdir(exist_ok=True) modelZipFile = tempDir.joinpath("autoseg3d_model.zip") - logger.info(f"Downloading model '{modelName}' from {url}...") - logger.debug(f"Downloading from {url} to {modelZipFile}...") + logging.info(f"Downloading model '{modelName}' from {url}...") + logging.debug(f"Downloading from {url} to {modelZipFile}...") try: with open(modelZipFile, 'wb') as f: with requests.get(url, stream=True) as r: @@ -177,20 +178,20 @@ def downloadModel(self, modelName): downloaded_size += len(chunk) downloaded_percent = 100.0 * downloaded_size / total_size if downloaded_percent - last_reported_download_percent > reporting_increment_percent: - logger.info( + logging.info( f"Downloading model: {downloaded_size / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB ({downloaded_percent:.1f}%)") last_reported_download_percent = downloaded_percent - logger.info(f"Download finished. Extracting to {modelDir}...") + logging.info(f"Download finished. Extracting to {modelDir}...") with zipfile.ZipFile(modelZipFile, 'r') as zip_f: zip_f.extractall(modelDir) except Exception as e: raise e finally: if self.clearOutputFolder: - logger.info("Cleaning up temporary model download folder...") + logging.info("Cleaning up temporary model download folder...") if os.path.isdir(tempDir): import shutil shutil.rmtree(tempDir) else: - logger.info(f"Not cleaning up temporary model download folder: {tempDir}") + logging.info(f"Not cleaning up temporary model download folder: {tempDir}") diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py index ec5bdfc..55b2cd2 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py @@ -9,87 +9,107 @@ from pathlib import Path from typing import Callable -from MONAIAuto3DSegLib.constants import APPLICATION_NAME -logger = logging.getLogger(APPLICATION_NAME) - class WebServer: - - CHECK_TIMER_INTERVAL = 1000 - - @staticmethod - def getPSProcess(pid): - try: - return psutil.Process(pid) - except psutil.NoSuchProcess: - return None - - def __init__(self, completedCallback: Callable): - self.completedCallback = completedCallback - self.procThread = None - self.serverProc = None - self.queue = None - - def isRunning(self): - if self.serverProc is not None: - psProcess = self.getPSProcess(self.serverProc.pid) - if psProcess: - return psProcess.is_running() - return False - - def __del__(self): - self.killProcess() - - def killProcess(self): - if not self.serverProc: - return - psProcess = self.getPSProcess(self.serverProc.pid) - if not psProcess: - return - for psChildProcess in psProcess.children(recursive=True): - psChildProcess.kill() - if psProcess.is_running(): - psProcess.kill() - - def launchConsoleProcess(self, cmd): - self.serverProc = \ - slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "auto3dseg", useStartupEnvironment=False) - - self.queue = queue.Queue() - self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess) - self.procThread.start() - self.checkProcessOutput() - - def cleanup(self): - if self.procThread: - self.procThread.join() - self.completedCallback() - self.serverProc = None - self.procThread = None - self.queue = None - - def _handleProcessOutputThreadProcess(self): - while True: - try: - line = self.serverProc.stdout.readline() - if not line: - break - self.queue.put(line.rstrip()) - except UnicodeDecodeError as e: - pass - self.serverProc.wait() - - def checkProcessOutput(self): - outputQueue = self.queue - while outputQueue: - try: - line = outputQueue.get_nowait() - logger.info(line) - except queue.Empty: - break - - psProcess = self.getPSProcess(self.serverProc.pid) - if psProcess and psProcess.is_running(): # No more outputs to process now, check again later - qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput) - else: - self.cleanup() \ No newline at end of file + """ Web server class to be used from 3D Slicer. Upon starting the server with the given cmd command, it will be + checked every {CHECK_TIMER_INTERVAL} milliseconds for process status/outputs. + + code: + + cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] + + from MONAIAuto3DSegLib.server import WebServer + server = WebServer(completedCallback=...) + server.launchConsoleProcess(cmd) + + ... + + server.killProcess() + + """ + + CHECK_TIMER_INTERVAL = 1000 + + @staticmethod + def getPSProcess(pid): + try: + return psutil.Process(pid) + except psutil.NoSuchProcess: + return None + + def __init__(self, logCallback: Callable=None, completedCallback: Callable=None): + self.logCallback = logCallback + self.completedCallback = completedCallback + self.procThread = None + self.serverProc = None + self.queue = None + + def isRunning(self): + if self.serverProc is not None: + psProcess = self.getPSProcess(self.serverProc.pid) + if psProcess: + return psProcess.is_running() + return False + + def __del__(self): + self.killProcess() + + def killProcess(self): + if not self.serverProc: + return + psProcess = self.getPSProcess(self.serverProc.pid) # proc.kill() does not work, that would only stop the launcher + if not psProcess: + return + for psChildProcess in psProcess.children(recursive=True): + psChildProcess.kill() + if psProcess.is_running(): + psProcess.kill() + + def launchConsoleProcess(self, cmd): + self.serverProc = \ + slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "MONAIAuto3DSegServer", useStartupEnvironment=False) + + if self.logCallback and self.isRunning(): + self.logCallback("Server Started") + + self.queue = queue.Queue() + self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess) + self.procThread.start() + self.checkProcessOutput() + + def cleanup(self): + if self.procThread: + self.procThread.join() + if self.completedCallback: + self.completedCallback() + if self.logCallback: + self.logCallback("Server Stopped") + self.serverProc = None + self.procThread = None + self.queue = None + + def _handleProcessOutputThreadProcess(self): + while True: + try: + line = self.serverProc.stdout.readline() + if not line: + break + self.queue.put(line.rstrip()) + except UnicodeDecodeError as e: + pass + self.serverProc.wait() + + def checkProcessOutput(self): + outputQueue = self.queue + while outputQueue: + try: + line = outputQueue.get_nowait() + logging.info(line) + except queue.Empty: + break + + psProcess = self.getPSProcess(self.serverProc.pid) + if psProcess and psProcess.is_running(): # No more outputs to process now, check again later + qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput) + else: + self.cleanup() \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py index 186c9cc..4575b22 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/utils.py @@ -12,24 +12,3 @@ def humanReadableTimeFromSec(seconds): return f"{math.ceil(seconds / 60)} min" # Otherwise round up to the nearest 0.1 hour return f"{seconds / 3600:.1f} h" - - -def assignInputNodesByName(inputs, loadedSampleNodes): - inputNodes = [] - for inputIndex, input in enumerate(inputs): - namePattern = input.get("namePattern") - if namePattern: - matchingNode = findFirstNodeByNamePattern(namePattern, loadedSampleNodes) - else: - matchingNode = loadedSampleNodes[inputIndex] if inputIndex < len(loadedSampleNodes) else \ - loadedSampleNodes[0] - inputNodes.append(matchingNode) - return inputNodes - - -def findFirstNodeByNamePattern(namePattern, nodes): - import fnmatch - for node in nodes: - if fnmatch.fnmatchcase(node.GetName(), namePattern): - return node - return None \ No newline at end of file diff --git a/MONAIAuto3DSeg/auto3dseg/__init__.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/__init__.py similarity index 100% rename from MONAIAuto3DSeg/auto3dseg/__init__.py rename to MONAIAuto3DSeg/MONAIAuto3DSegServer/__init__.py diff --git a/MONAIAuto3DSeg/auto3dseg/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py similarity index 62% rename from MONAIAuto3DSeg/auto3dseg/main.py rename to MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py index 9be35ab..ea747ea 100644 --- a/MONAIAuto3DSeg/auto3dseg/main.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py @@ -2,7 +2,7 @@ # pip install fastapi # pip install "uvicorn[standard]" -# usage: uvicorn main:app --reload --host reslnjolleyws03.research.chop.edu --port 8891 +# usage: uvicorn main:app --reload --host example.com --port 8891 # usage: uvicorn main:app --reload --host localhost --port 8891 @@ -17,8 +17,7 @@ sys.path.insert(0, path) from MONAIAuto3DSegLib.model_database import ModelDatabase -from MONAIAuto3DSegLib.dependency_handler import LocalPythonDependencies -from MONAIAuto3DSegLib.constants import APPLICATION_NAME +from MONAIAuto3DSegLib.dependency_handler import DependenciesBase import shutil import asyncio @@ -29,15 +28,65 @@ from fastapi.background import BackgroundTasks +class LocalPythonDependencies(DependenciesBase): + """ Dependency handler when running locally (not within 3D Slicer) + + code: + + from MONAIAuto3DSegLib.dependency_handler import LocalPythonDependencies + dependencies = LocalPythonDependencies() + dependencies.installedMONAIPythonPackageInfo() + + # dependencies.setupPythonRequirements(upgrade=True) + """ + + def installedMONAIPythonPackageInfo(self): + versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode() + return versionInfo + + def _checkModuleInstalled(self, moduleName): + try: + import importlib + importlib.import_module(moduleName) + return True + except ModuleNotFoundError: + return False + + def setupPythonRequirements(self, upgrade=False): + def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + logging.debug("Initializing PyTorch...") + + packageName = "torch" + if not self._checkModuleInstalled(packageName): + logging.debug("PyTorch Python package is required. Installing... (it may take several minutes)") + install(packageName) + if not self._checkModuleInstalled(packageName): + raise ValueError("pytorch needs to be installed to use this module.") + else: # torch is installed, check version + from packaging import version + import torch + if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion): + raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module." + + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" + + f" with version requirement set to: >={self.minimumTorchVersion}") + + logging.debug("Initializing MONAI...") + monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" + if upgrade: + monaiInstallString += " --upgrade" + install(monaiInstallString) + + self.dependenciesInstalled = True + logging.debug("Dependencies are set up successfully.") + + app = FastAPI() modelDB = ModelDatabase() dependencyHandler = LocalPythonDependencies() -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(APPLICATION_NAME) - - def upload(file, session_dir, identifier): extension = "".join(Path(file.filename).suffixes) file_location = f"{session_dir}/{identifier}{extension}" @@ -111,7 +160,7 @@ async def infer( auto3DSegCommand.append(inputFiles[inputIndex]) try: - # logger.info(auto3DSegCommand) + # logging.debug(auto3DSegCommand) proc = await asyncio.create_subprocess_shell(" ".join(auto3DSegCommand)) await proc.wait() if proc.returncode != 0: From bb0574ff86232f0d7980666d4da8d4bd67889e21 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Mon, 7 Oct 2024 11:50:45 -0400 Subject: [PATCH 05/27] ENH: requirements installation and reloading submodules if requested - requesting server from Slicer will use SlicerPythonDependencies and check for PyTorch extension prior to server start. Other modules (e.g. monai) will also be installed prior to server start. - using fastapi from commandline will use NonSlicerPythonDependencies. --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 71 +++++++++++++------ .../MONAIAuto3DSegLib/dependency_handler.py | 57 ++++++++++++++- MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py | 23 ++++-- MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py | 70 ++++-------------- 4 files changed, 132 insertions(+), 89 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 84ef7c4..c19fabd 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -8,7 +8,6 @@ import qt import slicer import requests -from typing import Callable from slicer.ScriptedLoadableModule import * from slicer.util import VTKObservationMixin from MONAIAuto3DSegLib.model_database import ModelDatabase @@ -191,8 +190,21 @@ def __init__(self, parent=None): self._webServer = None def onReload(self): + logging.debug(f"Reloading {self.moduleName}") if self._webServer: self._webServer.killProcess() + + packageName ="MONAIAuto3DSegLib" + submoduleNames = ['dependency_handler', 'model_database', 'server', 'utils'] + import imp + f, filename, description = imp.find_module(packageName) + package = imp.load_module(packageName, f, filename, description) + for submoduleName in submoduleNames: + f, filename, description = imp.find_module(submoduleName, package.__path__) + try: + imp.load_module(packageName + '.' + submoduleName, f, filename, description) + finally: + f.close() super().onReload() def setup(self): @@ -689,30 +701,43 @@ def onRemoteServerButtonToggled(self): self.updateGUIFromParameterNode() def onServerButtonToggled(self, toggled): - if toggled: - if not self._webServer or not self._webServer.isRunning() : - import platform - from pathlib import Path - slicer.util.pip_install("python-multipart fastapi uvicorn[standard]") - - hostName = platform.node() - port = str(self.ui.portSpinBox.value) - cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] - - self.ui.serverAddressLineEdit.text = f"http://{hostName}:{port}" - - from MONAIAuto3DSegLib.server import WebServer - self._webServer = WebServer( - logCallback=self.addLog, - completedCallback=lambda: self.ui.serverButton.setChecked(False) - ) - self._webServer.launchConsoleProcess(cmd) - else: - if self._webServer is not None and self._webServer.isRunning(): - self._webServer.killProcess() - self._webServer = None + with slicer.util.tryWithErrorDisplay("Failed to start server.", waitCursor=True): + if toggled: + if not hasattr(slicer.modules, 'pytorchutils'): + raise ModuleNotFoundError ("This modules requires the PyTorch extension. Install PyTorch and restart Slicer.") + + # TODO: improve error reporting if installation of requirements fails + self.logic.setupPythonRequirements() + + if not self._webServer or not self._webServer.isRunning() : + import platform + from pathlib import Path + slicer.util.pip_install("python-multipart fastapi uvicorn[standard]") + + hostName = platform.node() + port = str(self.ui.portSpinBox.value) + + cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] + + self.ui.serverAddressLineEdit.text = f"http://{hostName}:{port}" + + from MONAIAuto3DSegLib.server import WebServer + self._webServer = WebServer( + logCallback=self.addLog, + completedCallback=self.onServerCompleted + ) + self._webServer.launchConsoleProcess(cmd) + else: + if self._webServer is not None and self._webServer.isRunning(): + self._webServer.killProcess() + self._webServer = None self.updateGUIFromParameterNode() + def onServerCompleted(self, text=None): + if text: + self.addLog(text) + self.ui.serverButton.setChecked(False) + def serverUrl(self): serverUrl = self.ui.serverComboBox.currentText.strip() if not serverUrl: diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py index 9c75887..52fda3a 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/dependency_handler.py @@ -1,6 +1,7 @@ import shutil import subprocess import logging +import sys from abc import ABC, abstractmethod @@ -91,4 +92,58 @@ def setupPythonRequirements(self, upgrade=False): slicer.util.pip_install(monaiInstallString) self.dependenciesInstalled = True - logging.info("Dependencies are set up successfully.") \ No newline at end of file + logging.info("Dependencies are set up successfully.") + + +class NonSlicerPythonDependencies(DependenciesBase): + """ Dependency handler when running locally (not within 3D Slicer) + + code: + + from MONAIAuto3DSegLib.dependency_handler import LocalPythonDependencies + dependencies = LocalPythonDependencies() + dependencies.installedMONAIPythonPackageInfo() + + # dependencies.setupPythonRequirements(upgrade=True) + """ + + def installedMONAIPythonPackageInfo(self): + versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode() + return versionInfo + + def _checkModuleInstalled(self, moduleName): + try: + import importlib + importlib.import_module(moduleName) + return True + except ModuleNotFoundError: + return False + + def setupPythonRequirements(self, upgrade=False): + def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + logging.debug("Initializing PyTorch...") + + packageName = "torch" + if not self._checkModuleInstalled(packageName): + logging.debug("PyTorch Python package is required. Installing... (it may take several minutes)") + install(packageName) + if not self._checkModuleInstalled(packageName): + raise ValueError("pytorch needs to be installed to use this module.") + else: # torch is installed, check version + from packaging import version + import torch + if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion): + raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module." + + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" + + f" with version requirement set to: >={self.minimumTorchVersion}") + + logging.debug("Initializing MONAI...") + monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" + if upgrade: + monaiInstallString += " --upgrade" + install(monaiInstallString) + + self.dependenciesInstalled = True + logging.debug("Dependencies are set up successfully.") \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py index 55b2cd2..2aea1f6 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py @@ -66,11 +66,12 @@ def killProcess(self): psProcess.kill() def launchConsoleProcess(self, cmd): + logging.debug(f"Launching process: {cmd}") self.serverProc = \ slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "MONAIAuto3DSegServer", useStartupEnvironment=False) - if self.logCallback and self.isRunning(): - self.logCallback("Server Started") + if self.isRunning(): + self.addLog("Server Started") self.queue = queue.Queue() self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess) @@ -81,20 +82,24 @@ def cleanup(self): if self.procThread: self.procThread.join() if self.completedCallback: - self.completedCallback() - if self.logCallback: - self.logCallback("Server Stopped") + if self.serverProc.returncode not in [-9, 0]: # killed or stopped cleanly + self.addLog(self._err) + self.completedCallback("Server Stopped") self.serverProc = None self.procThread = None self.queue = None def _handleProcessOutputThreadProcess(self): + self._err = None while True: try: line = self.serverProc.stdout.readline() if not line: break - self.queue.put(line.rstrip()) + text = line.rstrip() + self.queue.put(text) + if "ERROR" in text: + self._err = text except UnicodeDecodeError as e: pass self.serverProc.wait() @@ -112,4 +117,8 @@ def checkProcessOutput(self): if psProcess and psProcess.is_running(): # No more outputs to process now, check again later qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput) else: - self.cleanup() \ No newline at end of file + self.cleanup() + + def addLog(self, text): + if self.logCallback: + self.logCallback(text) \ No newline at end of file diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py index ea747ea..490a50d 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py @@ -11,13 +11,13 @@ import sys from pathlib import Path + paths = [str(Path(__file__).parent.parent)] for path in paths: if not path in sys.path: sys.path.insert(0, path) from MONAIAuto3DSegLib.model_database import ModelDatabase -from MONAIAuto3DSegLib.dependency_handler import DependenciesBase import shutil import asyncio @@ -28,63 +28,19 @@ from fastapi.background import BackgroundTasks -class LocalPythonDependencies(DependenciesBase): - """ Dependency handler when running locally (not within 3D Slicer) - - code: - - from MONAIAuto3DSegLib.dependency_handler import LocalPythonDependencies - dependencies = LocalPythonDependencies() - dependencies.installedMONAIPythonPackageInfo() - - # dependencies.setupPythonRequirements(upgrade=True) - """ - - def installedMONAIPythonPackageInfo(self): - versionInfo = subprocess.check_output([sys.executable, "-m", "pip", "show", "MONAI"]).decode() - return versionInfo - - def _checkModuleInstalled(self, moduleName): - try: - import importlib - importlib.import_module(moduleName) - return True - except ModuleNotFoundError: - return False - - def setupPythonRequirements(self, upgrade=False): - def install(package): - subprocess.check_call([sys.executable, "-m", "pip", "install", package]) - - logging.debug("Initializing PyTorch...") - - packageName = "torch" - if not self._checkModuleInstalled(packageName): - logging.debug("PyTorch Python package is required. Installing... (it may take several minutes)") - install(packageName) - if not self._checkModuleInstalled(packageName): - raise ValueError("pytorch needs to be installed to use this module.") - else: # torch is installed, check version - from packaging import version - import torch - if version.parse(torch.__version__) < version.parse(self.minimumTorchVersion): - raise ValueError(f"PyTorch version {torch.__version__} is not compatible with this module." - + f" Minimum required version is {self.minimumTorchVersion}. You can use 'PyTorch Util' module to install PyTorch" - + f" with version requirement set to: >={self.minimumTorchVersion}") - - logging.debug("Initializing MONAI...") - monaiInstallString = "monai[fire,pyyaml,nibabel,pynrrd,psutil,tensorboard,skimage,itk,tqdm]>=1.3" - if upgrade: - monaiInstallString += " --upgrade" - install(monaiInstallString) - - self.dependenciesInstalled = True - logging.debug("Dependencies are set up successfully.") - - app = FastAPI() modelDB = ModelDatabase() -dependencyHandler = LocalPythonDependencies() + +# deciding which dependencies to choose +if "python-real" in Path(sys.executable).name: + from MONAIAuto3DSegLib.dependency_handler import SlicerPythonDependencies + dependencyHandler = SlicerPythonDependencies() +else: + from MONAIAuto3DSegLib.dependency_handler import NonSlicerPythonDependencies + dependencyHandler = NonSlicerPythonDependencies() + dependencyHandler.setupPythonRequirements() + +logging.debug(f"Using {dependencyHandler.__class__.__name__} as dependency handler") def upload(file, session_dir, identifier): @@ -147,8 +103,6 @@ async def infer( assert os.path.exists(modelPtFile) - dependencyHandler.setupPythonRequirements() - moduleDir = Path(__file__).parent.parent inferenceScriptPyFile = os.path.join(moduleDir, "Scripts", "auto3dseg_segresnet_inference.py") auto3DSegCommand = [sys.executable, str(inferenceScriptPyFile), From 9f41e2638099d55205a33cf479cc9b4f87bb8a7f Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Mon, 7 Oct 2024 11:56:14 -0400 Subject: [PATCH 06/27] ENH: if attempting server start from Slicer, install requirements and check for PyTorch extension --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index c19fabd..68f03b6 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -703,9 +703,6 @@ def onRemoteServerButtonToggled(self): def onServerButtonToggled(self, toggled): with slicer.util.tryWithErrorDisplay("Failed to start server.", waitCursor=True): if toggled: - if not hasattr(slicer.modules, 'pytorchutils'): - raise ModuleNotFoundError ("This modules requires the PyTorch extension. Install PyTorch and restart Slicer.") - # TODO: improve error reporting if installation of requirements fails self.logic.setupPythonRequirements() From 8d2efb1b854d3bca9fd22baf8514040ae990f6ed Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Mon, 7 Oct 2024 13:39:26 -0400 Subject: [PATCH 07/27] Minor tweaks --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 21 +++++---- MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py | 23 +++++++--- MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 45 ++++++++++--------- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 68f03b6..a27777a 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -476,6 +476,11 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.portSpinBox.value = int(self._parameterNode.GetParameter("ServerPort")) + self.ui.serverAddressTitleLabel.visible = self._webServer is not None + self.ui.serverAddressLabel.visible = self._webServer is not None + if self._webServer: + self.ui.serverAddressLabel.text = self._webServer.getAddressUrl() + self.ui.browseToModelsFolderButton.enabled = not remoteConnection self.ui.useStandardSegmentNamesCheckBox.enabled = not remoteConnection self.ui.cpuCheckBox.enabled = not remoteConnection @@ -680,12 +685,12 @@ def onRemoteServerButtonToggled(self): self.logic = RemoteMONAIAuto3DSegLogic() self.logic.server_address = self.ui.serverComboBox.currentText try: - _ = self.logic.models - self.addLog(f"Remote Server Connected {self.logic.server_address}") + models = self.logic.models + self.addLog(f"Remote Server Connected {self.logic.server_address}. {len(models)} models are available.") except: slicer.util.warningDisplay( f"Connection to remote server '{self.logic.server_address}' failed. " - f"Please check address, port, and connection" + f"Please check address, port, and connection." ) self.ui.remoteServerButton.checked = False return @@ -714,19 +719,17 @@ def onServerButtonToggled(self, toggled): hostName = platform.node() port = str(self.ui.portSpinBox.value) - cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] - - self.ui.serverAddressLineEdit.text = f"http://{hostName}:{port}" - from MONAIAuto3DSegLib.server import WebServer self._webServer = WebServer( logCallback=self.addLog, completedCallback=self.onServerCompleted ) - self._webServer.launchConsoleProcess(cmd) + self._webServer.hostName = hostName + self._webServer.port = port + self._webServer.start() else: if self._webServer is not None and self._webServer.isRunning(): - self._webServer.killProcess() + self._webServer.stop() self._webServer = None self.updateGUIFromParameterNode() diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py index 2aea1f6..d90a7eb 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py @@ -20,11 +20,11 @@ class WebServer: from MONAIAuto3DSegLib.server import WebServer server = WebServer(completedCallback=...) - server.launchConsoleProcess(cmd) + server.start() ... - server.killProcess() + server.stop() """ @@ -43,6 +43,8 @@ def __init__(self, logCallback: Callable=None, completedCallback: Callable=None) self.procThread = None self.serverProc = None self.queue = None + self.hostName = "127.0.0.1" + self.port = 8891 def isRunning(self): if self.serverProc is not None: @@ -52,9 +54,9 @@ def isRunning(self): return False def __del__(self): - self.killProcess() + self._killProcess() - def killProcess(self): + def _killProcess(self): if not self.serverProc: return psProcess = self.getPSProcess(self.serverProc.pid) # proc.kill() does not work, that would only stop the launcher @@ -65,7 +67,18 @@ def killProcess(self): if psProcess.is_running(): psProcess.kill() - def launchConsoleProcess(self, cmd): + def start(self): + import sys + cmd = [sys.executable, "main.py", "--host", self.hostName, "--port", self.port] + self._launchConsoleProcess(cmd) + + def stop(self): + self._killProcess() + + def getAddressUrl(self): + return f"http://{self.hostName}:{self.port}" + + def _launchConsoleProcess(self, cmd): logging.debug(f"Launching process: {cmd}") self.serverProc = \ slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "MONAIAuto3DSegServer", useStartupEnvironment=False) diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index 1c4fd3c..1b0e290 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -7,7 +7,7 @@ 0 0 408 - 674 + 841 @@ -437,7 +437,7 @@ - Server Address: + Local segmentation server false @@ -463,42 +463,35 @@ - + Log to Console: - + - + Log to GUI: - + - - - - - - - - + QPushButton:checked { @@ -515,17 +508,27 @@ - - - - true + + + + Server address: - - + + - Server address: + + + + Qt::LinksAccessibleByMouse|Qt::TextSelectableByKeyboard|Qt::TextSelectableByMouse + + + + + + + From d3a0bc82a228e9a072c02c37d92c404039df01d5 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Mon, 7 Oct 2024 15:35:09 -0400 Subject: [PATCH 08/27] BUG: minor issues --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index a27777a..63ddb8f 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -1181,7 +1181,6 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor segmentationProcessInfo["cancelRequested"] = False segmentationProcessInfo["startTime"] = startTime segmentationProcessInfo["tempDir"] = tempDir - segmentationProcessInfo["segmentationProcess"] = proc segmentationProcessInfo["inputNodes"] = inputNodes segmentationProcessInfo["outputSegmentation"] = outputSegmentation segmentationProcessInfo["outputSegmentationFile"] = outputSegmentationFile @@ -1260,7 +1259,7 @@ def checkSegmentationProcessOutput(self, segmentationProcessInfo): break # No more outputs to process now, check again later - qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda segmentationProcessInfo=segmentationProcessInfo: self.checkSegmentationProcessOutput(segmentationProcessInfo)) + qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda: self.checkSegmentationProcessOutput(segmentationProcessInfo=segmentationProcessInfo)) def onSegmentationProcessCompleted(self, segmentationProcessInfo): startTime = segmentationProcessInfo["startTime"] @@ -1403,7 +1402,7 @@ def __init__(self): self._models = [] def getMONAIPythonPackageInfo(self): - return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo(self._server_address) + return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo() def setupPythonRequirements(self, upgrade=False): self.DEPENDENCY_HANDLER.setupPythonRequirements(upgrade) @@ -1417,7 +1416,7 @@ def loadModelsDescription(self): json_data = json.loads(response.text) return json_data - def labelDescriptions(self, modelId): + def labelDescriptions(self, modelName): """Return mapping from label value to label description. Label description is a dict containing "name" and "terminology". Terminology string uses Slicer terminology entry format - see specification at @@ -1428,7 +1427,7 @@ def labelDescriptions(self, modelId): else: import tempfile with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile: - with requests.get(self._server_address + f"/labelDescriptions?id={modelId}", stream=True) as r: + with requests.get(self._server_address + f"/labelDescriptions?id={modelName}", stream=True) as r: r.raise_for_status() with open(tmpfile.name, 'wb') as f: From 77e11da17463f7a9988c8c19b99dd08fe5b290af Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Tue, 8 Oct 2024 15:47:27 -0400 Subject: [PATCH 09/27] STYLE: moved download models into ModelDatabase - extracted SegmentationProcessInfo dataclass instead of using dict --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 145 +++++++++--------- .../MONAIAuto3DSegLib/model_database.py | 5 + 2 files changed, 79 insertions(+), 71 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 63ddb8f..5aa6049 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -2,12 +2,14 @@ import os import json import sys +import time import vtk import qt import slicer import requests + from slicer.ScriptedLoadableModule import * from slicer.util import VTKObservationMixin from MONAIAuto3DSegLib.model_database import ModelDatabase @@ -15,6 +17,36 @@ from MONAIAuto3DSegLib.dependency_handler import SlicerPythonDependencies, RemotePythonDependencies +import subprocess +import queue +import threading +from dataclasses import dataclass +from typing import Any + +from enum import Enum + + +class ExitCode(Enum): + USER_CANCELLED = 1001 + DID_NOT_RUN = 1002 + + +@dataclass +class SegmentationProcessInfo: + proc: subprocess.Popen = None + startTime: float = time.time() + cancelRequested: bool = False + tempDir: str = "" + inputNodes: list = None + outputSegmentation: slicer.vtkMRMLSegmentationNode = None + outputSegmentationFile: str = "" + model: str = "" + procReturnCode: ExitCode = ExitCode.DID_NOT_RUN + procOutputQueue: queue.Queue = queue.Queue() + procThread: threading.Thread = None + customData: Any = None + + # # MONAIAuto3DSeg # @@ -793,9 +825,6 @@ class MONAIAuto3DSegLogic(ScriptedLoadableModuleLogic, ModelDatabase): https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py """ - EXIT_CODE_USER_CANCELLED = 1001 - EXIT_CODE_DID_NOT_RUN = 1002 - DEPENDENCY_HANDLER = SlicerPythonDependencies() @staticmethod @@ -837,11 +866,6 @@ def getLoadedAnatomicContextNames(): return [anatomicContextNames.GetValue(idx) for idx in range(anatomicContextNames.GetNumberOfValues())] - def downloadAllModels(self): - for model in self.models: - slicer.app.processEvents() - self.downloadModel(model["id"]) - @staticmethod def _terminologyPropertyTypes(terminologyName): """Get label terminology property types defined in from MONAI Auto3DSeg terminology. @@ -1070,10 +1094,11 @@ def setDefaultParameters(self, parameterNode): if not parameterNode.GetParameter("ServerPort"): parameterNode.SetParameter("ServerPort", str(8891)) - def logProcessOutputUntilCompleted(self, segmentationProcessInfo): + @staticmethod + def logProcessOutputUntilCompleted(segmentationProcessInfo: SegmentationProcessInfo): # Wait for the process to end and forward output to the log from subprocess import CalledProcessError - proc = segmentationProcessInfo["proc"] + proc = segmentationProcessInfo.proc while True: try: line = proc.stdout.readline() @@ -1087,7 +1112,7 @@ def logProcessOutputUntilCompleted(self, segmentationProcessInfo): pass proc.wait() retcode = proc.returncode - segmentationProcessInfo["procReturnCode"] = retcode + segmentationProcessInfo.procReturnCode = retcode if retcode != 0: raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr) @@ -1096,7 +1121,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor Run the processing algorithm. Can be used without GUI widget. :param inputNodes: input nodes in a list - :param outputVolume: thresholding result + :param outputSegmentation: output segmentation to write to :param model: one of self.models :param cpu: use CPU instead of GPU :param waitForCompletion: if True then the method waits for the processing to finish @@ -1118,10 +1143,8 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor modelPath = self.modelPath(model) - segmentationProcessInfo = {} + segmentationProcessInfo = SegmentationProcessInfo() - import time - startTime = time.time() logging.info("Processing started") if self.debugSkipInference: @@ -1176,16 +1199,13 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor else: proc = slicer.util.launchConsoleProcess(auto3DSegCommand, updateEnvironment=additionalEnvironmentVariables) - segmentationProcessInfo["proc"] = proc - segmentationProcessInfo["procReturnCode"] = MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN - segmentationProcessInfo["cancelRequested"] = False - segmentationProcessInfo["startTime"] = startTime - segmentationProcessInfo["tempDir"] = tempDir - segmentationProcessInfo["inputNodes"] = inputNodes - segmentationProcessInfo["outputSegmentation"] = outputSegmentation - segmentationProcessInfo["outputSegmentationFile"] = outputSegmentationFile - segmentationProcessInfo["model"] = model - segmentationProcessInfo["customData"] = customData + segmentationProcessInfo.proc = proc + segmentationProcessInfo.tempDir = tempDir + segmentationProcessInfo.inputNodes = inputNodes + segmentationProcessInfo.outputSegmentation = outputSegmentation + segmentationProcessInfo.outputSegmentationFile = outputSegmentationFile + segmentationProcessInfo.model = model + segmentationProcessInfo.customData = customData if proc: if waitForCompletion: @@ -1201,10 +1221,10 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor return segmentationProcessInfo - def cancelProcessing(self, segmentationProcessInfo): + def cancelProcessing(self, segmentationProcessInfo: SegmentationProcessInfo): logging.info("Cancel is requested.") - segmentationProcessInfo["cancelRequested"] = True - proc = segmentationProcessInfo.get("proc") + segmentationProcessInfo.cancelRequested = True + proc = segmentationProcessInfo.proc if proc: # Simple proc.kill() would not work, that would only stop the launcher import psutil @@ -1217,15 +1237,15 @@ def cancelProcessing(self, segmentationProcessInfo): self.onSegmentationProcessCompleted(segmentationProcessInfo) @staticmethod - def _handleProcessOutputThreadProcess(segmentationProcessInfo): + def _handleProcessOutputThreadProcess(segmentationProcessInfo: SegmentationProcessInfo): # Wait for the process to end and forward output to the log - proc = segmentationProcessInfo["proc"] + proc = segmentationProcessInfo.proc while True: try: line = proc.stdout.readline() if not line: break - segmentationProcessInfo["procOutputQueue"].put(line.rstrip()) + segmentationProcessInfo.procOutputQueue.put(line.rstrip()) except UnicodeDecodeError as e: # Code page conversion happens because `universal_newlines=True` sets process output to text mode, # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string, @@ -1233,23 +1253,22 @@ def _handleProcessOutputThreadProcess(segmentationProcessInfo): pass proc.wait() retcode = proc.returncode # non-zero return code means error - segmentationProcessInfo["procReturnCode"] = retcode + segmentationProcessInfo.procReturnCode = retcode def startSegmentationProcessMonitoring(self, segmentationProcessInfo): import queue import threading - segmentationProcessInfo["procOutputQueue"] = queue.Queue() - segmentationProcessInfo["procThread"] = threading.Thread(target=MONAIAuto3DSegLogic._handleProcessOutputThreadProcess, args=[segmentationProcessInfo]) - segmentationProcessInfo["procThread"].start() + segmentationProcessInfo.procThread = threading.Thread(target=MONAIAuto3DSegLogic._handleProcessOutputThreadProcess, args=[segmentationProcessInfo]) + segmentationProcessInfo.procThread.start() self.checkSegmentationProcessOutput(segmentationProcessInfo) def checkSegmentationProcessOutput(self, segmentationProcessInfo): import queue - outputQueue = segmentationProcessInfo["procOutputQueue"] + outputQueue = segmentationProcessInfo.procOutputQueue while outputQueue: - if segmentationProcessInfo.get("procReturnCode") != MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN: + if segmentationProcessInfo.procReturnCode != ExitCode.DID_NOT_RUN: self.onSegmentationProcessCompleted(segmentationProcessInfo) return try: @@ -1261,31 +1280,24 @@ def checkSegmentationProcessOutput(self, segmentationProcessInfo): # No more outputs to process now, check again later qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda: self.checkSegmentationProcessOutput(segmentationProcessInfo=segmentationProcessInfo)) - def onSegmentationProcessCompleted(self, segmentationProcessInfo): - startTime = segmentationProcessInfo["startTime"] - tempDir = segmentationProcessInfo["tempDir"] - inputNodes = segmentationProcessInfo["inputNodes"] - outputSegmentation = segmentationProcessInfo["outputSegmentation"] - outputSegmentationFile = segmentationProcessInfo["outputSegmentationFile"] - model = segmentationProcessInfo["model"] - customData = segmentationProcessInfo["customData"] - procReturnCode = segmentationProcessInfo["procReturnCode"] - cancelRequested = segmentationProcessInfo["cancelRequested"] - - if cancelRequested: - procReturnCode = MONAIAuto3DSegLogic.EXIT_CODE_USER_CANCELLED + def onSegmentationProcessCompleted(self, segmentationProcessInfo: SegmentationProcessInfo): + procReturnCode = segmentationProcessInfo.procReturnCode + customData = segmentationProcessInfo.customData + if cancelRequested := segmentationProcessInfo.cancelRequested: + procReturnCode = ExitCode.USER_CANCELLED logging.info(f"Processing was cancelled.") else: if procReturnCode == 0: + outputSegmentation = segmentationProcessInfo.outputSegmentation if self.startResultImportCallback: self.startResultImportCallback(customData) try: # Load result logging.info("Importing segmentation results...") - self.readSegmentation(outputSegmentation, outputSegmentationFile, model) + self.readSegmentation(outputSegmentation, segmentationProcessInfo.outputSegmentationFile, segmentationProcessInfo.model) # Set source volume - required for DICOM Segmentation export - inputVolume = inputNodes[0] + inputVolume = segmentationProcessInfo.inputNodes[0] if not inputVolume.IsA('vtkMRMLScalarVolumeNode'): raise ValueError("First input node must be a scalar volume") outputSegmentation.SetNodeReferenceID(outputSegmentation.GetReferenceImageGeometryReferenceRole(), inputVolume.GetID()) @@ -1304,6 +1316,7 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): else: logging.info(f"Processing failed with return code {procReturnCode}") + tempDir = segmentationProcessInfo.tempDir if self.clearOutputFolder: logging.info("Cleaning up temporary folder.") if os.path.isdir(tempDir): @@ -1313,10 +1326,7 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo): logging.info(f"Not cleaning up temporary folder: {tempDir}") # Report total elapsed time - import time - stopTime = time.time() - segmentationProcessInfo["stopTime"] = stopTime - elapsedTime = stopTime - startTime + elapsedTime = time.time() - segmentationProcessInfo.startTime if cancelRequested: logging.info(f"Processing was cancelled after {elapsedTime:.2f} seconds.") else: @@ -1448,8 +1458,7 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF :param customData: any custom data to identify or describe this processing request, it will be returned in the process completed callback when waitForCompletion is False """ - import time - startTime = time.time() + segmentationProcessInfo = SegmentationProcessInfo() logging.info("Processing started") tempDir = slicer.util.tempDirectory() @@ -1474,9 +1483,6 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF else: raise ValueError(f"Input node type {inputNode.GetClassName()} is not supported") - segmentationProcessInfo = {} - segmentationProcessInfo["procReturnCode"] = MONAIAuto3DSegLogic.EXIT_CODE_DID_NOT_RUN - logging.info(f"Initiating Inference on {self._server_address}") files = {} @@ -1494,19 +1500,17 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF for chunk in r.iter_content(chunk_size=8192): binary_file.write(chunk) - segmentationProcessInfo["procReturnCode"] = 0 + segmentationProcessInfo.procReturnCode = 0 finally: for f in files.values(): f.close() - segmentationProcessInfo["cancelRequested"] = False - segmentationProcessInfo["startTime"] = startTime - segmentationProcessInfo["tempDir"] = tempDir - segmentationProcessInfo["inputNodes"] = inputNodes - segmentationProcessInfo["outputSegmentation"] = outputSegmentation - segmentationProcessInfo["outputSegmentationFile"] = outputSegmentationFile - segmentationProcessInfo["model"] = modelId - segmentationProcessInfo["customData"] = customData + segmentationProcessInfo.tempDir = tempDir + segmentationProcessInfo.inputNodes = inputNodes + segmentationProcessInfo.outputSegmentation = outputSegmentation + segmentationProcessInfo.outputSegmentationFile = outputSegmentationFile + segmentationProcessInfo.model = modelId + segmentationProcessInfo.customData = customData self.onSegmentationProcessCompleted(segmentationProcessInfo) @@ -1628,7 +1632,6 @@ def test_MONAIAuto3DSeg1(self): # Run the segmentation self.delayDisplay(f"Running segmentation for {model['title']}...") - import time startTime = time.time() logic.process(inputNodes, outputSegmentation, model["id"], forceUseCpu) segmentationTimeSec = time.time() - startTime diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py index c9c599c..ef1e865 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py @@ -153,6 +153,11 @@ def deleteAllModels(self): import shutil shutil.rmtree(self.modelsPath) + def downloadAllModels(self): + for model in self.models: + slicer.app.processEvents() + self.downloadModel(model["id"]) + def downloadModel(self, modelName): url = self.model(modelName)["url"] import zipfile From 0954af9f4a83de51114a9706c11808db062a80a4 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Thu, 10 Oct 2024 15:41:56 -0400 Subject: [PATCH 10/27] ENH: introducing LocalInference, InferenceServer and their parent class BackgroundProcess --- MONAIAuto3DSeg/CMakeLists.txt | 2 +- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 174 +++----------- MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py | 254 ++++++++++++++++++++ MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py | 137 ----------- 4 files changed, 289 insertions(+), 278 deletions(-) create mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py delete mode 100644 MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py diff --git a/MONAIAuto3DSeg/CMakeLists.txt b/MONAIAuto3DSeg/CMakeLists.txt index 322c24e..40f5d0c 100644 --- a/MONAIAuto3DSeg/CMakeLists.txt +++ b/MONAIAuto3DSeg/CMakeLists.txt @@ -7,7 +7,7 @@ set(MODULE_PYTHON_SCRIPTS ${MODULE_NAME}Lib/__init__.py ${MODULE_NAME}Lib/dependency_handler.py ${MODULE_NAME}Lib/model_database.py - ${MODULE_NAME}Lib/server.py + ${MODULE_NAME}Lib/process.py ${MODULE_NAME}Lib/utils.py ${MODULE_NAME}Server/__init__py ${MODULE_NAME}Server/main.py diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 5aa6049..83bfceb 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -15,36 +15,9 @@ from MONAIAuto3DSegLib.model_database import ModelDatabase from MONAIAuto3DSegLib.utils import humanReadableTimeFromSec from MONAIAuto3DSegLib.dependency_handler import SlicerPythonDependencies, RemotePythonDependencies +from MONAIAuto3DSegLib.process import InferenceServer, LocalInference, ExitCode, SegmentationProcessInfo -import subprocess -import queue -import threading -from dataclasses import dataclass -from typing import Any - -from enum import Enum - - -class ExitCode(Enum): - USER_CANCELLED = 1001 - DID_NOT_RUN = 1002 - - -@dataclass -class SegmentationProcessInfo: - proc: subprocess.Popen = None - startTime: float = time.time() - cancelRequested: bool = False - tempDir: str = "" - inputNodes: list = None - outputSegmentation: slicer.vtkMRMLSegmentationNode = None - outputSegmentationFile: str = "" - model: str = "" - procReturnCode: ExitCode = ExitCode.DID_NOT_RUN - procOutputQueue: queue.Queue = queue.Queue() - procThread: threading.Thread = None - customData: Any = None # @@ -227,7 +200,7 @@ def onReload(self): self._webServer.killProcess() packageName ="MONAIAuto3DSegLib" - submoduleNames = ['dependency_handler', 'model_database', 'server', 'utils'] + submoduleNames = ['dependency_handler', 'model_database', 'process', 'utils'] import imp f, filename, description = imp.find_module(packageName) package = imp.load_module(packageName, f, filename, description) @@ -623,7 +596,8 @@ def onApply(self): def onCancel(self): with slicer.util.tryWithErrorDisplay("Failed to cancel processing.", waitCursor=True): - self.logic.cancelProcessing(self._segmentationProcessInfo) + # TODO: needed here?? self._segmentationProcessInfo + self.logic.cancelProcessing() self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_CANCEL_REQUESTED) def onProcessImportStarted(self, customData): @@ -636,7 +610,13 @@ def onProcessImportEnded(self, customData): slicer.app.processEvents() def onProcessingCompleted(self, returnCode, customData): - self.addLog("\nProcessing finished.") + if returnCode == 0: + m = "\nProcessing finished." + elif returnCode == ExitCode.USER_CANCELLED: + m = "\nProcessing was cancelled." + else: + m = f"\nProcessing failed with error code {returnCode}." + self.addLog(m) self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) self._segmentationProcessInfo = None @@ -741,6 +721,8 @@ def onServerButtonToggled(self, toggled): with slicer.util.tryWithErrorDisplay("Failed to start server.", waitCursor=True): if toggled: # TODO: improve error reporting if installation of requirements fails + + self.ui.statusLabel.plainText = "" self.logic.setupPythonRequirements() if not self._webServer or not self._webServer.isRunning() : @@ -751,9 +733,8 @@ def onServerButtonToggled(self, toggled): hostName = platform.node() port = str(self.ui.portSpinBox.value) - from MONAIAuto3DSegLib.server import WebServer - self._webServer = WebServer( - logCallback=self.addLog, + self._webServer = InferenceServer( + logCallback=self.addServerLog, completedCallback=self.onServerCompleted ) self._webServer.hostName = hostName @@ -765,9 +746,13 @@ def onServerButtonToggled(self, toggled): self._webServer = None self.updateGUIFromParameterNode() - def onServerCompleted(self, text=None): - if text: - self.addLog(text) + def onServerCompleted(self, processInfo=None): + returnCode = processInfo.procReturnCode + if returnCode == ExitCode.USER_CANCELLED: + m = "\nServer was stopped." + else: + m = f"\nProcessing failed with error code {returnCode}." + self.addLog(m) self.ui.serverButton.setChecked(False) def serverUrl(self): @@ -982,8 +967,7 @@ def __init__(self): self.endResultImportCallback = None self.useStandardSegmentNames = True - # Timer for checking the output of the segmentation process that is running in the background - self.processOutputCheckTimerIntervalMsec = 1000 + self._bgProcess = None # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location # where inference result is stored and set self.debugSkipInference to True. @@ -1094,28 +1078,6 @@ def setDefaultParameters(self, parameterNode): if not parameterNode.GetParameter("ServerPort"): parameterNode.SetParameter("ServerPort", str(8891)) - @staticmethod - def logProcessOutputUntilCompleted(segmentationProcessInfo: SegmentationProcessInfo): - # Wait for the process to end and forward output to the log - from subprocess import CalledProcessError - proc = segmentationProcessInfo.proc - while True: - try: - line = proc.stdout.readline() - if not line: - break - logging.info(line.rstrip()) - except UnicodeDecodeError as e: - # Code page conversion happens because `universal_newlines=True` sets process output to text mode, - # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string, - # as we only guarantee correct behavior if an UTF8 locale is used. - pass - proc.wait() - retcode = proc.returncode - segmentationProcessInfo.procReturnCode = retcode - if retcode != 0: - raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr) - def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitForCompletion=True, customData=None): """ Run the processing algorithm. @@ -1194,12 +1156,6 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor additionalEnvironmentVariables = {"CUDA_VISIBLE_DEVICES": "-1"} logging.info(f"Additional environment variables: {additionalEnvironmentVariables}") - if self.debugSkipInference: - proc = None - else: - proc = slicer.util.launchConsoleProcess(auto3DSegCommand, updateEnvironment=additionalEnvironmentVariables) - - segmentationProcessInfo.proc = proc segmentationProcessInfo.tempDir = tempDir segmentationProcessInfo.inputNodes = inputNodes segmentationProcessInfo.outputSegmentation = outputSegmentation @@ -1207,86 +1163,19 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor segmentationProcessInfo.model = model segmentationProcessInfo.customData = customData - if proc: - if waitForCompletion: - # Wait for the process to end before returning - self.logProcessOutputUntilCompleted(segmentationProcessInfo) - self.onSegmentationProcessCompleted(segmentationProcessInfo) - else: - # Run the process in the background - self.startSegmentationProcessMonitoring(segmentationProcessInfo) - else: - # Debugging + self._bgProcess = LocalInference(processInfo=segmentationProcessInfo, completedCallback=self.onSegmentationProcessCompleted) + if self.debugSkipInference: self.onSegmentationProcessCompleted(segmentationProcessInfo) - - return segmentationProcessInfo - - def cancelProcessing(self, segmentationProcessInfo: SegmentationProcessInfo): - logging.info("Cancel is requested.") - segmentationProcessInfo.cancelRequested = True - proc = segmentationProcessInfo.proc - if proc: - # Simple proc.kill() would not work, that would only stop the launcher - import psutil - psProcess = psutil.Process(proc.pid) - for psChildProcess in psProcess.children(recursive=True): - psChildProcess.kill() - if psProcess.is_running(): - psProcess.kill() else: - self.onSegmentationProcessCompleted(segmentationProcessInfo) + self._bgProcess.run(auto3DSegCommand, additionalEnvironmentVariables=additionalEnvironmentVariables, waitForCompletion=waitForCompletion) - @staticmethod - def _handleProcessOutputThreadProcess(segmentationProcessInfo: SegmentationProcessInfo): - # Wait for the process to end and forward output to the log - proc = segmentationProcessInfo.proc - while True: - try: - line = proc.stdout.readline() - if not line: - break - segmentationProcessInfo.procOutputQueue.put(line.rstrip()) - except UnicodeDecodeError as e: - # Code page conversion happens because `universal_newlines=True` sets process output to text mode, - # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string, - # as we only guarantee correct behavior if an UTF8 locale is used. - pass - proc.wait() - retcode = proc.returncode # non-zero return code means error - segmentationProcessInfo.procReturnCode = retcode - - def startSegmentationProcessMonitoring(self, segmentationProcessInfo): - import queue - import threading - - segmentationProcessInfo.procThread = threading.Thread(target=MONAIAuto3DSegLogic._handleProcessOutputThreadProcess, args=[segmentationProcessInfo]) - segmentationProcessInfo.procThread.start() - - self.checkSegmentationProcessOutput(segmentationProcessInfo) - - def checkSegmentationProcessOutput(self, segmentationProcessInfo): - import queue - outputQueue = segmentationProcessInfo.procOutputQueue - while outputQueue: - if segmentationProcessInfo.procReturnCode != ExitCode.DID_NOT_RUN: - self.onSegmentationProcessCompleted(segmentationProcessInfo) - return - try: - line = outputQueue.get_nowait() - logging.info(line) - except queue.Empty: - break - - # No more outputs to process now, check again later - qt.QTimer.singleShot(self.processOutputCheckTimerIntervalMsec, lambda: self.checkSegmentationProcessOutput(segmentationProcessInfo=segmentationProcessInfo)) + return segmentationProcessInfo def onSegmentationProcessCompleted(self, segmentationProcessInfo: SegmentationProcessInfo): procReturnCode = segmentationProcessInfo.procReturnCode customData = segmentationProcessInfo.customData - if cancelRequested := segmentationProcessInfo.cancelRequested: - procReturnCode = ExitCode.USER_CANCELLED - logging.info(f"Processing was cancelled.") - else: + cancelRequested = procReturnCode == ExitCode.USER_CANCELLED + if not cancelRequested: if procReturnCode == 0: outputSegmentation = segmentationProcessInfo.outputSegmentation if self.startResultImportCallback: @@ -1338,6 +1227,11 @@ def onSegmentationProcessCompleted(self, segmentationProcessInfo: SegmentationPr if self.processingCompletedCallback: self.processingCompletedCallback(procReturnCode, customData) + def cancelProcessing(self): + if not self._bgProcess: + return + self._bgProcess.stop() + def readSegmentation(self, outputSegmentation, outputSegmentationFile, model): labelValueToDescription = self.labelDescriptions(model) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py new file mode 100644 index 0000000..8fa52c5 --- /dev/null +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py @@ -0,0 +1,254 @@ +import subprocess + +import slicer +import psutil + +import sys +import logging +import queue +import threading + +import qt +from pathlib import Path +from typing import Callable, Any +import time +from dataclasses import dataclass, field + +from enum import Enum + + +class ExitCode(Enum): + USER_CANCELLED = 1001 + DID_NOT_RUN = 1002 + + +@dataclass +class ProcessInfo: + proc: subprocess.Popen = None + startTime: float = field(default_factory=time.time) + procReturnCode: ExitCode = ExitCode.DID_NOT_RUN + procOutputQueue: queue.Queue = queue.Queue() + procThread: threading.Thread = None + + +class SegmentationProcessInfo(ProcessInfo): + tempDir: str = "" + inputNodes: list = None + outputSegmentation: slicer.vtkMRMLSegmentationNode = None + outputSegmentationFile: str = "" + model: str = "" + customData: Any = None + + +class BackgroundProcess: + """ Any kind of process with threads and continuous checking until stopped""" + + # Timer for checking the output of the process that is running in the background + CHECK_TIMER_INTERVAL = 1000 + + @property + def proc(self): + return self.processInfo.proc + + @proc.setter + def proc(self, proc): + self.processInfo.proc = proc + + @property + def procThread(self): + return self.processInfo.procThread + + @procThread.setter + def procThread(self, procThread): + self.processInfo.procThread = procThread + + @property + def procOutputQueue(self): + return self.processInfo.procOutputQueue + + @procOutputQueue.setter + def procOutputQueue(self, procOutputQueue): + self.processInfo.procOutputQueue = procOutputQueue + + @staticmethod + def getPSProcess(pid): + try: + return psutil.Process(pid) + except psutil.NoSuchProcess: + return None + + def __init__(self, processInfo: ProcessInfo = None, logCallback: Callable = None, completedCallback: Callable = None): + self.processInfo = processInfo if processInfo else ProcessInfo() + self.logCallback = logCallback + self.completedCallback = completedCallback + + # NB: making sure that the following values were not set previously + self.proc = None + self.procThread = None + self.procOutputQueue = None + + def __del__(self): + self._killProcess() + + def cleanup(self): + if self.procThread: + self.procThread.join() + if self.completedCallback: + if self.proc.returncode not in [-9, 0]: # killed or stopped cleanly + self.addLog(self._err) + + self.completedCallback(self.processInfo) + self.proc = None + self.procThread = None + self.procOutputQueue = None + + def isRunning(self): + if self.proc is not None: + psProcess = self.getPSProcess(self.proc.pid) + if psProcess: + return psProcess.is_running() + return False + + def _killProcess(self): + if not self.proc: + return + psProcess = self.getPSProcess(self.proc.pid) # proc.kill() does not work, that would only stop the launcher + if not psProcess: + return + for psChildProcess in psProcess.children(recursive=True): + psChildProcess.kill() + if psProcess.is_running(): + psProcess.kill() + + def stop(self): + logging.info("Cancel is requested.") + self._killProcess() + self.processInfo.procReturnCode = ExitCode.USER_CANCELLED + + def _startHandleProcessOutputThread(self): + self.procOutputQueue = queue.Queue() + self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess) + self.procThread.start() + self.checkProcessOutput() + + def _handleProcessOutputThreadProcess(self): + while True: + try: + line = self.proc.stdout.readline() + if not line: + break + text = line.rstrip() + self.procOutputQueue.put(text) + except UnicodeDecodeError as e: + pass + self.proc.wait() + self.processInfo.procReturnCode = self.proc.returncode # non-zero return code means error + + def checkProcessOutput(self): + outputQueue = self.procOutputQueue + while outputQueue: + if self.processInfo.procReturnCode != ExitCode.DID_NOT_RUN: + self.completedCallback(self.processInfo) + return + try: + line = outputQueue.get_nowait() + logging.info(line) + except queue.Empty: + break + + psProcess = self.getPSProcess(self.proc.pid) + if psProcess and psProcess.is_running(): # No more outputs to process now, check again later + qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput) + else: + self.cleanup() + + def addLog(self, text): + if self.logCallback: + self.logCallback(text) + + +class InferenceServer(BackgroundProcess): + """ Web server class to be used from 3D Slicer. Upon starting the server with the given cmd command, it will be + checked every {CHECK_TIMER_INTERVAL} milliseconds for process status/outputs. + + code: + + cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] + + from MONAIAuto3DSegLib.server import WebServer + server = WebServer(completedCallback=...) + server.start() + + ... + + server.stop() + + """ + + def __init__(self, processInfo: ProcessInfo = None, logCallback: Callable = None, + completedCallback: Callable = None): + super().__init__(processInfo, logCallback, completedCallback) + self.hostName = "127.0.0.1" + self.port = 8891 + + def getAddressUrl(self): + return f"http://{self.hostName}:{self.port}" + + def start(self): + cmd = [ + sys.executable, + Path(__file__).parent.parent / "MONAIAuto3DSegServer" / "main.py", "--host", + self.hostName, "--port", self.port + ] + + logging.debug(f"Launching process: {cmd}") + self.proc = slicer.util.launchConsoleProcess(cmd, useStartupEnvironment=False) + + if self.isRunning(): + self.addLog("Server Started") + + self._startHandleProcessOutputThread() + + +class LocalInference(BackgroundProcess): + """ Running local inference until finished or cancelled. """ + + def __init__(self, processInfo: ProcessInfo = None, logCallback: Callable = None, + completedCallback: Callable = None, waitForCompletion: bool = True): + super().__init__(processInfo, logCallback, completedCallback) + self.waitForCompletion = waitForCompletion + + def run(self, cmd, additionalEnvironmentVariables=None, waitForCompletion=True): + logging.debug(f"Launching process: {cmd}") + self.proc = slicer.util.launchConsoleProcess(cmd, updateEnvironment=additionalEnvironmentVariables) + + if self.isRunning(): + self.addLog("Process Started") + + if waitForCompletion: + self.logProcessOutputUntilCompleted() + self.completedCallback(self.processInfo) + else: + self._startHandleProcessOutputThread() + + def logProcessOutputUntilCompleted(self): + # Wait for the process to end and forward output to the log + proc = self.proc + while True: + try: + line = proc.stdout.readline() + if not line: + break + logging.info(line.rstrip()) + except UnicodeDecodeError as e: + # Code page conversion happens because `universal_newlines=True` sets process output to text mode, + # and it fails because probably system locale is not UTF8. We just ignore the error and discard the string, + # as we only guarantee correct behavior if an UTF8 locale is used. + pass + proc.wait() + retcode = proc.returncode + self.processInfo.procReturnCode = retcode + + if retcode != 0: + from subprocess import CalledProcessError + raise CalledProcessError(proc.returncode, proc.args, output=proc.stdout, stderr=proc.stderr) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py deleted file mode 100644 index d90a7eb..0000000 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/server.py +++ /dev/null @@ -1,137 +0,0 @@ -import slicer -import psutil - -import logging -import queue -import threading - -import qt -from pathlib import Path -from typing import Callable - - -class WebServer: - """ Web server class to be used from 3D Slicer. Upon starting the server with the given cmd command, it will be - checked every {CHECK_TIMER_INTERVAL} milliseconds for process status/outputs. - - code: - - cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] - - from MONAIAuto3DSegLib.server import WebServer - server = WebServer(completedCallback=...) - server.start() - - ... - - server.stop() - - """ - - CHECK_TIMER_INTERVAL = 1000 - - @staticmethod - def getPSProcess(pid): - try: - return psutil.Process(pid) - except psutil.NoSuchProcess: - return None - - def __init__(self, logCallback: Callable=None, completedCallback: Callable=None): - self.logCallback = logCallback - self.completedCallback = completedCallback - self.procThread = None - self.serverProc = None - self.queue = None - self.hostName = "127.0.0.1" - self.port = 8891 - - def isRunning(self): - if self.serverProc is not None: - psProcess = self.getPSProcess(self.serverProc.pid) - if psProcess: - return psProcess.is_running() - return False - - def __del__(self): - self._killProcess() - - def _killProcess(self): - if not self.serverProc: - return - psProcess = self.getPSProcess(self.serverProc.pid) # proc.kill() does not work, that would only stop the launcher - if not psProcess: - return - for psChildProcess in psProcess.children(recursive=True): - psChildProcess.kill() - if psProcess.is_running(): - psProcess.kill() - - def start(self): - import sys - cmd = [sys.executable, "main.py", "--host", self.hostName, "--port", self.port] - self._launchConsoleProcess(cmd) - - def stop(self): - self._killProcess() - - def getAddressUrl(self): - return f"http://{self.hostName}:{self.port}" - - def _launchConsoleProcess(self, cmd): - logging.debug(f"Launching process: {cmd}") - self.serverProc = \ - slicer.util.launchConsoleProcess(cmd, cwd=Path(__file__).parent.parent / "MONAIAuto3DSegServer", useStartupEnvironment=False) - - if self.isRunning(): - self.addLog("Server Started") - - self.queue = queue.Queue() - self.procThread = threading.Thread(target=self._handleProcessOutputThreadProcess) - self.procThread.start() - self.checkProcessOutput() - - def cleanup(self): - if self.procThread: - self.procThread.join() - if self.completedCallback: - if self.serverProc.returncode not in [-9, 0]: # killed or stopped cleanly - self.addLog(self._err) - self.completedCallback("Server Stopped") - self.serverProc = None - self.procThread = None - self.queue = None - - def _handleProcessOutputThreadProcess(self): - self._err = None - while True: - try: - line = self.serverProc.stdout.readline() - if not line: - break - text = line.rstrip() - self.queue.put(text) - if "ERROR" in text: - self._err = text - except UnicodeDecodeError as e: - pass - self.serverProc.wait() - - def checkProcessOutput(self): - outputQueue = self.queue - while outputQueue: - try: - line = outputQueue.get_nowait() - logging.info(line) - except queue.Empty: - break - - psProcess = self.getPSProcess(self.serverProc.pid) - if psProcess and psProcess.is_running(): # No more outputs to process now, check again later - qt.QTimer.singleShot(self.CHECK_TIMER_INTERVAL, self.checkProcessOutput) - else: - self.cleanup() - - def addLog(self, text): - if self.logCallback: - self.logCallback(text) \ No newline at end of file From 0f6adfd58277e1346122c09fd9ce4787880a4409 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Thu, 10 Oct 2024 15:46:05 -0400 Subject: [PATCH 11/27] STYLE: for local inference, processInfo is required and not optional --- MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py index 8fa52c5..2c44def 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py @@ -213,7 +213,7 @@ def start(self): class LocalInference(BackgroundProcess): """ Running local inference until finished or cancelled. """ - def __init__(self, processInfo: ProcessInfo = None, logCallback: Callable = None, + def __init__(self, processInfo: SegmentationProcessInfo, logCallback: Callable = None, completedCallback: Callable = None, waitForCompletion: bool = True): super().__init__(processInfo, logCallback, completedCallback) self.waitForCompletion = waitForCompletion From 0de1609a1b9eaf5424e3f5c405e945d37c53031a Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Mon, 14 Oct 2024 12:45:39 -0400 Subject: [PATCH 12/27] BUG: server port spinbox changes not saved into parameter node --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 1 + MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 83bfceb..d5defc5 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -274,6 +274,7 @@ def setup(self): self.ui.remoteServerButton.toggled.connect(self.onRemoteServerButtonToggled) self.ui.serverButton.toggled.connect(self.onServerButtonToggled) + self.ui.portSpinBox.valueChanged.connect(self.updateParameterNodeFromGUI) # Make sure parameter node is initialized (needed for module reload) self.initializeParameterNode() diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py index 2c44def..f8990d0 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py @@ -175,8 +175,8 @@ class InferenceServer(BackgroundProcess): cmd = [sys.executable, "main.py", "--host", hostName, "--port", port] - from MONAIAuto3DSegLib.server import WebServer - server = WebServer(completedCallback=...) + from MONAIAuto3DSegLib.process import InferenceServer + server = InferenceServer(completedCallback=...) server.start() ... From 9cc838057216df2a4be024e692e0a03fd67529c2 Mon Sep 17 00:00:00 2001 From: che85 Date: Tue, 15 Oct 2024 22:36:38 -0400 Subject: [PATCH 13/27] BUG: fastapi server on Windows causing issues when reload set to True --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 26 +++++++++++++-------- MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py | 2 +- MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py | 21 ++++++++--------- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index d5defc5..f135604 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -729,7 +729,7 @@ def onServerButtonToggled(self, toggled): if not self._webServer or not self._webServer.isRunning() : import platform from pathlib import Path - slicer.util.pip_install("python-multipart fastapi uvicorn[standard]") + slicer.util.pip_install("psutil python-multipart fastapi uvicorn[standard]") hostName = platform.node() port = str(self.ui.portSpinBox.value) @@ -1328,18 +1328,24 @@ def labelDescriptions(self, modelName): https://slicer.readthedocs.io/en/latest/developer_guide/modules/segmentations.html#terminologyentry-tag """ if not self._server_address: - return [] + return {} else: - import tempfile - with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile: - with requests.get(self._server_address + f"/labelDescriptions?id={modelName}", stream=True) as r: - r.raise_for_status() + from pathlib import Path + tempDir = slicer.util.tempDirectory() + tempDir = Path(tempDir) + outfile = tempDir / "labelDescriptions.csv" + with requests.get(self._server_address + f"/labelDescriptions?id={modelName}", stream=True) as r: + r.raise_for_status() - with open(tmpfile.name, 'wb') as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) + with open(outfile, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + labelDescriptions = self._labelDescriptions(outfile) - return self._labelDescriptions(tmpfile.name) + import shutil + shutil.rmtree(tempDir) + return labelDescriptions def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitForCompletion=True, customData=None): """ diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py index f8990d0..89ccfad 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py @@ -1,7 +1,6 @@ import subprocess import slicer -import psutil import sys import logging @@ -72,6 +71,7 @@ def procOutputQueue(self, procOutputQueue): @staticmethod def getPSProcess(pid): + import psutil try: return psutil.Process(pid) except psutil.NoSuchProcess: diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py index 490a50d..0b38579 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py @@ -1,9 +1,7 @@ -# pip install python-multipart -# pip install fastapi -# pip install "uvicorn[standard]" +# pip install python-multipart fastapi uvicorn[standard] -# usage: uvicorn main:app --reload --host example.com --port 8891 -# usage: uvicorn main:app --reload --host localhost --port 8891 +# usage: uvicorn main:app --host example.com --port 8891 +# usage: uvicorn main:app --host localhost --port 8891 import os @@ -45,7 +43,7 @@ def upload(file, session_dir, identifier): extension = "".join(Path(file.filename).suffixes) - file_location = f"{session_dir}/{identifier}{extension}" + file_location = str(Path(session_dir) / f"{identifier}{extension}") with open(file_location, "wb+") as file_object: file_object.write(file.file.read()) return file_location @@ -82,7 +80,7 @@ async def infer( ): import tempfile session_dir = tempfile.mkdtemp(dir=tempfile.gettempdir()) - background_tasks.add_task(shutil.rmtree, session_dir) + background_tasks.add_task(shutil.rmtree, session_dir, ignore_errors=False) logging.debug(session_dir) inputFiles = list() @@ -96,7 +94,7 @@ async def infer( # logging.info("Input Files: ", inputFiles) - outputSegmentationFile = f"{session_dir}/output-segmentation.nrrd" + outputSegmentationFile = str(Path(session_dir) / "output-segmentation.nrrd") modelPath = modelDB.modelPath(model_name) modelPtFile = modelPath.joinpath("model.pt") @@ -108,7 +106,7 @@ async def infer( auto3DSegCommand = [sys.executable, str(inferenceScriptPyFile), "--model-file", str(modelPtFile), "--image-file", inputFiles[0], - "--result-file", str(outputSegmentationFile)] + "--result-file", outputSegmentationFile] for inputIndex in range(1, len(inputFiles)): auto3DSegCommand.append(f"--image-file-{inputIndex + 1}") auto3DSegCommand.append(inputFiles[inputIndex]) @@ -121,6 +119,7 @@ async def infer( raise subprocess.CalledProcessError(proc.returncode, " ".join(auto3DSegCommand)) return FileResponse(outputSegmentationFile, media_type='application/octet-stream', background=background_tasks) except Exception as e: + logging.info(e) shutil.rmtree(session_dir) raise HTTPException(status_code=500, detail=f"Failed to run CMD command: {str(e)}") @@ -134,8 +133,8 @@ def main(argv): args = parser.parse_args(argv) import uvicorn - - uvicorn.run("main:app", host=args.host, port=args.port, reload=True, log_level="debug") + # NB: reload=True causing issues on Windows (https://stackoverflow.com/a/70570250) + uvicorn.run("main:app", host=args.host, port=args.port, log_level="debug", reload=False) if __name__ == "__main__": From 3f65b4cb97a2b7fb00444c66b77dbf4bb1c836bd Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Wed, 16 Oct 2024 22:00:43 -0400 Subject: [PATCH 14/27] ENH: added progress bar for more progress feedback --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 49 ++++++++++++++++--- .../MONAIAuto3DSegLib/model_database.py | 4 -- MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 7 +++ 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index f135604..9406d78 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -179,7 +179,24 @@ class MONAIAuto3DSegWidget(ScriptedLoadableModuleWidget, VTKObservationMixin): PROCESSING_STARTING = 1 PROCESSING_IN_PROGRESS = 2 PROCESSING_IMPORT_RESULTS = 3 - PROCESSING_CANCEL_REQUESTED = 4 + PROCESSING_COMPLETED = 4 + PROCESSING_CANCEL_REQUESTED = 5 + + PROCESSING_STATES = { + PROCESSING_IDLE: "Idle", + PROCESSING_STARTING: "Starting...", + PROCESSING_IN_PROGRESS: "In Progress", + PROCESSING_IMPORT_RESULTS: "Importing Results", + PROCESSING_COMPLETED: "Processing Finished", + PROCESSING_CANCEL_REQUESTED: "Cancelling..." + } + + @staticmethod + def getHumanReadableProcessingState(state): + try: + return MONAIAuto3DSegWidget.PROCESSING_STATES[state] + except KeyError: + return "Unknown State" def __init__(self, parent=None): """ @@ -237,9 +254,9 @@ def setup(self): # Create logic class. Logic implements all computations that should be possible to run # in batch mode, without a graphical user interface. self.logic = MONAIAuto3DSegLogic() - self.logic.processingCompletedCallback = self.onProcessingCompleted self.logic.startResultImportCallback = self.onProcessImportStarted self.logic.endResultImportCallback = self.onProcessImportEnded + self.logic.processingCompletedCallback = self.onProcessingCompleted # Connections @@ -282,6 +299,8 @@ def setup(self): self.updateGUIFromParameterNode() + self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) + # Make the model search box in focus by default so users can just start typing to find the model they need qt.QTimer.singleShot(0, self.ui.modelSearchBox.setFocus) @@ -459,7 +478,6 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.applyButton.enabled = True elif state == MONAIAuto3DSegWidget.PROCESSING_STARTING: - self.ui.applyButton.text = "Starting..." self.ui.applyButton.toolTip = "Please wait while the segmentation is being initialized" self.ui.applyButton.enabled = False elif state == MONAIAuto3DSegWidget.PROCESSING_IN_PROGRESS: @@ -467,11 +485,9 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.applyButton.toolTip = "Cancel in-progress segmentation" self.ui.applyButton.enabled = True elif state == MONAIAuto3DSegWidget.PROCESSING_IMPORT_RESULTS: - self.ui.applyButton.text = "Importing results..." self.ui.applyButton.toolTip = "Please wait while the segmentation result is being imported" self.ui.applyButton.enabled = False elif state == MONAIAuto3DSegWidget.PROCESSING_CANCEL_REQUESTED: - self.ui.applyButton.text = "Cancelling..." self.ui.applyButton.toolTip = "Please wait for the segmentation to be cancelled" self.ui.applyButton.enabled = False @@ -546,6 +562,17 @@ def addLog(self, text): # self.ui.statusLabel.appendPlainText(text) # slicer.app.processEvents() # force update + def updateProgress(self, state): + if state == self.PROCESSING_IDLE: + qt.QTimer.singleShot(1000, self.ui.progressBar.hide) + self.ui.progressBar.setRange(0,0) + else: + self.ui.progressBar.setRange(0,4) + self.ui.progressBar.show() + self.ui.progressBar.value = state + self.ui.progressBar.setFormat(text := self.getHumanReadableProcessingState(state)) + self.addLog(text) + def addServerLog(self, *args): for arg in args: if self.ui.logConsoleCheckBox.checked: @@ -556,6 +583,7 @@ def addServerLog(self, *args): def setProcessingState(self, state): self._processingState = state self.updateGUIFromParameterNode() + self.updateProgress(state) slicer.app.processEvents() def onApplyButton(self): @@ -597,7 +625,6 @@ def onApply(self): def onCancel(self): with slicer.util.tryWithErrorDisplay("Failed to cancel processing.", waitCursor=True): - # TODO: needed here?? self._segmentationProcessInfo self.logic.cancelProcessing() self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_CANCEL_REQUESTED) @@ -713,9 +740,9 @@ def onRemoteServerButtonToggled(self): self.ui.remoteServerButton.text = "Connect" self.logic = MONAIAuto3DSegLogic() - self.logic.processingCompletedCallback = self.onProcessingCompleted self.logic.startResultImportCallback = self.onProcessImportStarted self.logic.endResultImportCallback = self.onProcessImportEnded + self.logic.processingCompletedCallback = self.onProcessingCompleted self.updateGUIFromParameterNode() def onServerButtonToggled(self, toggled): @@ -963,15 +990,19 @@ def __init__(self): ScriptedLoadableModuleLogic.__init__(self) ModelDatabase.__init__(self) - self.processingCompletedCallback = None self.startResultImportCallback = None self.endResultImportCallback = None + self.processingCompletedCallback = None self.useStandardSegmentNames = True + # process that will used to run inference either remotely or locally self._bgProcess = None # For testing the logic without actually running inference, set self.debugSkipInferenceTempDir to the location # where inference result is stored and set self.debugSkipInference to True. + # Disabling this flag preserves input and output data after execution is completed, + # which can be useful for troubleshooting. + self.clearOutputFolder = True self.debugSkipInference = False self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624" @@ -1111,6 +1142,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor logging.info("Processing started") if self.debugSkipInference: + self.clearOutputFolder = False # For debugging, use a fixed temporary folder tempDir = self.debugSkipInferenceTempDir else: @@ -1166,6 +1198,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor self._bgProcess = LocalInference(processInfo=segmentationProcessInfo, completedCallback=self.onSegmentationProcessCompleted) if self.debugSkipInference: + segmentationProcessInfo.procReturnCode = 0 self.onSegmentationProcessCompleted(segmentationProcessInfo) else: self._bgProcess.run(auto3DSegCommand, additionalEnvironmentVariables=additionalEnvironmentVariables, waitForCompletion=waitForCompletion) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py index ef1e865..9f6df38 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py @@ -36,10 +36,6 @@ def modelsDescriptionJsonFilePath(self): def __init__(self): self.fileCachePath = Path.home().joinpath(f"{self.DEFAULT_CACHE_DIR_NAME}") self.moduleDir = Path(__file__).parent.parent - - # Disabling this flag preserves input and output data after execution is completed, - # which can be useful for troubleshooting. - self.clearOutputFolder = True self._models = [] def model(self, modelId): diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index 1b0e290..f6036b7 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -297,6 +297,13 @@ + + + + 24 + + + From 42b3d497153d98597f152ad3a2ed3c99f4a48bf2 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Thu, 17 Oct 2024 11:10:18 -0400 Subject: [PATCH 15/27] ENH: improved user feedback for local inference and server process --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 4 ++- MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py | 36 ++++++++++++--------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 9406d78..e3bd29f 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -643,7 +643,7 @@ def onProcessingCompleted(self, returnCode, customData): elif returnCode == ExitCode.USER_CANCELLED: m = "\nProcessing was cancelled." else: - m = f"\nProcessing failed with error code {returnCode}." + m = f"\nProcessing failed with error code {returnCode}. Please check logs for further information." self.addLog(m) self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) self._segmentationProcessInfo = None @@ -768,6 +768,8 @@ def onServerButtonToggled(self, toggled): self._webServer.hostName = hostName self._webServer.port = port self._webServer.start() + if self._webServer.isRunning(): + self.addLog("Server started") else: if self._webServer is not None and self._webServer.isRunning(): self._webServer.stop() diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py index 89ccfad..ae327ab 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py @@ -90,13 +90,13 @@ def __init__(self, processInfo: ProcessInfo = None, logCallback: Callable = None def __del__(self): self._killProcess() + def handleSubProcessLogging(self, text): + logging.info(text) + def cleanup(self): if self.procThread: self.procThread.join() if self.completedCallback: - if self.proc.returncode not in [-9, 0]: # killed or stopped cleanly - self.addLog(self._err) - self.completedCallback(self.processInfo) self.proc = None self.procThread = None @@ -121,9 +121,8 @@ def _killProcess(self): psProcess.kill() def stop(self): - logging.info("Cancel is requested.") self._killProcess() - self.processInfo.procReturnCode = ExitCode.USER_CANCELLED + self._setProcReturnCode(ExitCode.USER_CANCELLED) def _startHandleProcessOutputThread(self): self.procOutputQueue = queue.Queue() @@ -142,19 +141,19 @@ def _handleProcessOutputThreadProcess(self): except UnicodeDecodeError as e: pass self.proc.wait() - self.processInfo.procReturnCode = self.proc.returncode # non-zero return code means error + self._setProcReturnCode(self.proc.returncode) # non-zero return code means error def checkProcessOutput(self): outputQueue = self.procOutputQueue while outputQueue: - if self.processInfo.procReturnCode != ExitCode.DID_NOT_RUN: - self.completedCallback(self.processInfo) - return try: line = outputQueue.get_nowait() - logging.info(line) + self.handleSubProcessLogging(line) except queue.Empty: break + if self.processInfo.procReturnCode != ExitCode.DID_NOT_RUN: + self.completedCallback(self.processInfo) + return psProcess = self.getPSProcess(self.proc.pid) if psProcess and psProcess.is_running(): # No more outputs to process now, check again later @@ -166,6 +165,13 @@ def addLog(self, text): if self.logCallback: self.logCallback(text) + def _setProcReturnCode(self, rcode): + # if user cancelled, leave it at that and don't change it + if self.processInfo.procReturnCode == ExitCode.USER_CANCELLED: + return + self.processInfo.procReturnCode = rcode + + class InferenceServer(BackgroundProcess): """ Web server class to be used from 3D Slicer. Upon starting the server with the given cmd command, it will be @@ -203,12 +209,12 @@ def start(self): logging.debug(f"Launching process: {cmd}") self.proc = slicer.util.launchConsoleProcess(cmd, useStartupEnvironment=False) - - if self.isRunning(): - self.addLog("Server Started") - self._startHandleProcessOutputThread() + def handleSubProcessLogging(self, text): + # NB: let upper level handle if it should be logged to console or UI + self.addLog(text) + class LocalInference(BackgroundProcess): """ Running local inference until finished or cancelled. """ @@ -247,7 +253,7 @@ def logProcessOutputUntilCompleted(self): pass proc.wait() retcode = proc.returncode - self.processInfo.procReturnCode = retcode + self._setProcReturnCode(retcode) if retcode != 0: from subprocess import CalledProcessError From 657da6aea72cee62112d57b3e9a5193da548e9f5 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Thu, 17 Oct 2024 11:26:34 -0400 Subject: [PATCH 16/27] BUG: clearOutputFolder shouldn't be used in case of model download - separating out clearOutputFolder (for inference) and temp model download --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 2 +- MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index e3bd29f..7848e9c 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -781,7 +781,7 @@ def onServerCompleted(self, processInfo=None): if returnCode == ExitCode.USER_CANCELLED: m = "\nServer was stopped." else: - m = f"\nProcessing failed with error code {returnCode}." + m = f"\nProcessing failed with error code {returnCode}. Try again with `Log to GUI` for more details." self.addLog(m) self.ui.serverButton.setChecked(False) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py index 9f6df38..c558f0e 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/model_database.py @@ -37,6 +37,7 @@ def __init__(self): self.fileCachePath = Path.home().joinpath(f"{self.DEFAULT_CACHE_DIR_NAME}") self.moduleDir = Path(__file__).parent.parent self._models = [] + self._clearTempDownloadFolder = True def model(self, modelId): for model in self.models: @@ -150,8 +151,9 @@ def deleteAllModels(self): shutil.rmtree(self.modelsPath) def downloadAllModels(self): + # TODO: add some progress here since this could take a while for all models for model in self.models: - slicer.app.processEvents() + #slicer.app.processEvents() self.downloadModel(model["id"]) def downloadModel(self, modelName): @@ -189,7 +191,7 @@ def downloadModel(self, modelName): except Exception as e: raise e finally: - if self.clearOutputFolder: + if self._clearTempDownloadFolder: logging.info("Cleaning up temporary model download folder...") if os.path.isdir(tempDir): import shutil From 99282efe54cdd1facc1c4813bfd282f40c8036b7 Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Fri, 18 Oct 2024 11:42:31 -0400 Subject: [PATCH 17/27] GUI tweaks Add "remote processing" checkbox and collapse local server section by default. --- MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 254 +++++++++++------- 1 file changed, 151 insertions(+), 103 deletions(-) diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index f6036b7..c6e57e7 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -6,7 +6,7 @@ 0 0 - 408 + 409 841 @@ -17,7 +17,127 @@ Inputs - + + + + Remote processing: + + + + + + + + + + + + true + + + + + + + QFrame::NoFrame + + + QFrame::Raised + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Server: + + + + + + + + 0 + 0 + + + + true + + + + + + + QPushButton:checked { + /* Checked style */ + background-color: green; /* Green background when checked */ + border-color: darkgreen; /* Darker border color when checked */ +} + + + Connect + + + true + + + + + + + + + + + + Segmentation model: + + + + + + + + + List models that contain all the specified words + + + + + + + Search in full text of the segmentation model description. Uncheck to search only in the model names. + + + Full text + + + + + + + + + Download sample data set for the current segmentation model + + + ... + + + + @@ -39,14 +159,14 @@ - + Input volume 1: - + @@ -67,14 +187,14 @@ - + Input volume 2: - + @@ -95,14 +215,14 @@ - + Input volume 3: - + @@ -123,14 +243,14 @@ - + Input volume 4: - + @@ -151,78 +271,6 @@ - - - - Download sample data set for the current segmentation model - - - ... - - - - - - - - - List models that contain all the specified words - - - - - - - Search in full text of the segmentation model description. Uncheck to search only in the model names. - - - Full text - - - - - - - - - Segmentation model: - - - - - - - - - - 0 - 0 - - - - true - - - - - - - QPushButton:checked { - /* Checked style */ - background-color: green; /* Green background when checked */ - border-color: darkgreen; /* Darker border color when checked */ -} - - - Connect - - - true - - - - - @@ -232,6 +280,9 @@ Outputs + + + @@ -275,9 +326,6 @@ - - - @@ -447,7 +495,7 @@ Local segmentation server - false + true @@ -605,8 +653,8 @@ 132 - 388 - 87 + 389 + 227 @@ -621,8 +669,8 @@ 8 - 355 - 233 + 389 + 373 @@ -637,8 +685,8 @@ 74 - 323 - 104 + 389 + 253 @@ -653,8 +701,8 @@ 133 - 383 - 132 + 389 + 279 @@ -669,24 +717,24 @@ 179 - 370 - 154 + 389 + 305 - remoteServerButton + remoteProcessingCheckBox toggled(bool) - serverCollapsibleButton - setDisabled(bool) + serverConnectionFrame + setVisible(bool) - 356 - 49 + 134 + 53 - 203 - 533 + 240 + 44 From 416d085eaff5ae6eb47f7330474bdeb7d46b8ee5 Mon Sep 17 00:00:00 2001 From: che85 Date: Fri, 18 Oct 2024 14:25:53 -0400 Subject: [PATCH 18/27] BUG: proper quoting when creating cmd line for server processes --- MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py index 0b38579..e628d2f 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py @@ -112,8 +112,9 @@ async def infer( auto3DSegCommand.append(inputFiles[inputIndex]) try: - # logging.debug(auto3DSegCommand) - proc = await asyncio.create_subprocess_shell(" ".join(auto3DSegCommand)) + cmd = ' '.join(f'"{arg}"' for arg in auto3DSegCommand) + logging.debug(cmd) + proc = await asyncio.create_subprocess_shell(cmd) await proc.wait() if proc.returncode != 0: raise subprocess.CalledProcessError(proc.returncode, " ".join(auto3DSegCommand)) @@ -121,7 +122,7 @@ async def infer( except Exception as e: logging.info(e) shutil.rmtree(session_dir) - raise HTTPException(status_code=500, detail=f"Failed to run CMD command: {str(e)}") + raise def main(argv): From a7d516a60d4c93a89c0e1e62f912fedadaca8855 Mon Sep 17 00:00:00 2001 From: che85 Date: Fri, 18 Oct 2024 15:45:52 -0400 Subject: [PATCH 19/27] BUG: run inference process as exec vs shell - shell was causing issues on machines with conda initialization at shell startup --- MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py index e628d2f..9a0173b 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py @@ -112,12 +112,11 @@ async def infer( auto3DSegCommand.append(inputFiles[inputIndex]) try: - cmd = ' '.join(f'"{arg}"' for arg in auto3DSegCommand) - logging.debug(cmd) - proc = await asyncio.create_subprocess_shell(cmd) + logging.debug(auto3DSegCommand) + proc = await asyncio.create_subprocess_exec(*auto3DSegCommand) await proc.wait() if proc.returncode != 0: - raise subprocess.CalledProcessError(proc.returncode, " ".join(auto3DSegCommand)) + raise subprocess.CalledProcessError(proc.returncode, auto3DSegCommand) return FileResponse(outputSegmentationFile, media_type='application/octet-stream', background=background_tasks) except Exception as e: logging.info(e) From 317740e2cb937d8d49551ab9afe2ff0170ba4be1 Mon Sep 17 00:00:00 2001 From: che85 Date: Mon, 21 Oct 2024 11:44:33 -0400 Subject: [PATCH 20/27] ENH: improved error feedback on client side --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 33 +++++++++++++-------- MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py | 18 +++++++---- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 7848e9c..17a17d5 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -181,6 +181,7 @@ class MONAIAuto3DSegWidget(ScriptedLoadableModuleWidget, VTKObservationMixin): PROCESSING_IMPORT_RESULTS = 3 PROCESSING_COMPLETED = 4 PROCESSING_CANCEL_REQUESTED = 5 + PROCESSING_FAILED = 6 PROCESSING_STATES = { PROCESSING_IDLE: "Idle", @@ -188,7 +189,8 @@ class MONAIAuto3DSegWidget(ScriptedLoadableModuleWidget, VTKObservationMixin): PROCESSING_IN_PROGRESS: "In Progress", PROCESSING_IMPORT_RESULTS: "Importing Results", PROCESSING_COMPLETED: "Processing Finished", - PROCESSING_CANCEL_REQUESTED: "Cancelling..." + PROCESSING_CANCEL_REQUESTED: "Cancelling...", + PROCESSING_FAILED: "Processing Failed" } @staticmethod @@ -601,9 +603,8 @@ def onApply(self): self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_STARTING) - try: - with slicer.util.tryWithErrorDisplay("Failed to start processing.", waitCursor=True): - + with slicer.util.tryWithErrorDisplay("Processing Failed. Check logs for more information.", waitCursor=True): + try: # Create new segmentation node, if not selected yet if not self.ui.outputSegmentationSelector.currentNode(): self.ui.outputSegmentationSelector.addNode() @@ -620,8 +621,10 @@ def onApply(self): self._segmentationProcessInfo = self.logic.process(inputNodes, self.ui.outputSegmentationSelector.currentNode(), self._currentModelId(), self.ui.cpuCheckBox.checked, waitForCompletion=False) - except Exception as e: - self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) + except Exception as e: + self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_FAILED) + self.setProcessingState(MONAIAuto3DSegWidget.PROCESSING_IDLE) + raise def onCancel(self): with slicer.util.tryWithErrorDisplay("Failed to cancel processing.", waitCursor=True): @@ -1422,13 +1425,14 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF logging.info(f"Initiating Inference on {self._server_address}") files = {} - try: - for idx, inputFile in enumerate(inputFiles, start=1): - name = "image_file" - if idx > 1: - name = f"{name}_{idx}" - files[name] = open(inputFile, 'rb') + for idx, inputFile in enumerate(inputFiles, start=1): + name = "image_file" + if idx > 1: + name = f"{name}_{idx}" + files[name] = open(inputFile, 'rb') + r = None + try: with requests.post(self._server_address + f"/infer?model_name={modelId}", files=files) as r: r.raise_for_status() @@ -1437,6 +1441,11 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF binary_file.write(chunk) segmentationProcessInfo.procReturnCode = 0 + except Exception as e: + logging.debug(f"Error occurred: {e}") + if hasattr(r, "content"): + logging.debug(f"Response content: {r.content}") + raise finally: for f in files.values(): f.close() diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py index 9a0173b..e97fc4d 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py @@ -10,6 +10,7 @@ from pathlib import Path + paths = [str(Path(__file__).parent.parent)] for path in paths: if not path in sys.path: @@ -21,8 +22,7 @@ import asyncio import subprocess from fastapi import FastAPI, UploadFile -from fastapi.responses import FileResponse -from fastapi import HTTPException +from fastapi.responses import FileResponse, JSONResponse from fastapi.background import BackgroundTasks @@ -118,10 +118,18 @@ async def infer( if proc.returncode != 0: raise subprocess.CalledProcessError(proc.returncode, auto3DSegCommand) return FileResponse(outputSegmentationFile, media_type='application/octet-stream', background=background_tasks) - except Exception as e: - logging.info(e) + except Exception as err: + logging.info(err) shutil.rmtree(session_dir) - raise + import traceback + return JSONResponse( + content={ + "error": "An unexpected error occurred", + "message": str(err), + "traceback": traceback.format_exc() + }, + status_code=500 + ) def main(argv): From 5cdb7b085dd15a402ca6073b07d520923ebfa925 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Mon, 21 Oct 2024 12:02:42 -0400 Subject: [PATCH 21/27] BUG: unchecking remoteServerButton whenever status of remote checkbox changes --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 1 + MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 24 ++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 17a17d5..b117c0a 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -290,6 +290,7 @@ def setup(self): self.ui.serverComboBox.lineEdit().setPlaceholderText("Enter server address") self.ui.serverComboBox.currentIndexChanged.connect(self.onRemoteServerButtonToggled) + self.ui.remoteProcessingCheckBox.toggled.connect(lambda t: self.ui.remoteServerButton.setChecked(False)) self.ui.remoteServerButton.toggled.connect(self.onRemoteServerButtonToggled) self.ui.serverButton.toggled.connect(self.onServerButtonToggled) diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index c6e57e7..d5be06b 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -28,8 +28,14 @@ + + + 0 + 0 + + - + true @@ -38,6 +44,13 @@ + + QPushButton:checked { + /* Checked style */ + background-color: green; /* Green background when checked */ + border-color: darkgreen; /* Darker border color when checked */ +} + QFrame::NoFrame @@ -79,6 +92,15 @@ + + + 0 + 0 + + + + true + QPushButton:checked { /* Checked style */ From 89550c218ba3782f177503a37a8d3483bebf7414 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Mon, 21 Oct 2024 12:06:50 -0400 Subject: [PATCH 22/27] BUG: stylesheet of remote server button was not adapted after changing type --- MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index d5be06b..0ad3c9d 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -99,10 +99,10 @@ - true + false - QPushButton:checked { + QToolButton:checked { /* Checked style */ background-color: green; /* Green background when checked */ border-color: darkgreen; /* Darker border color when checked */ From 35da3574d23f8fce821b7982aef670e026b31ecd Mon Sep 17 00:00:00 2001 From: che85 Date: Mon, 21 Oct 2024 14:52:09 -0400 Subject: [PATCH 23/27] ENH: limit number of inference requests to 5 requests per minute --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 2 +- MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index b117c0a..d271804 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -760,7 +760,7 @@ def onServerButtonToggled(self, toggled): if not self._webServer or not self._webServer.isRunning() : import platform from pathlib import Path - slicer.util.pip_install("psutil python-multipart fastapi uvicorn[standard]") + slicer.util.pip_install("psutil python-multipart fastapi slowapi uvicorn[standard]") hostName = platform.node() port = str(self.ui.portSpinBox.value) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py index e97fc4d..7e3d309 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegServer/main.py @@ -1,4 +1,4 @@ -# pip install python-multipart fastapi uvicorn[standard] +# pip install python-multipart fastapi slowapi uvicorn[standard] # usage: uvicorn main:app --host example.com --port 8891 # usage: uvicorn main:app --host localhost --port 8891 @@ -21,12 +21,17 @@ import shutil import asyncio import subprocess -from fastapi import FastAPI, UploadFile +from fastapi import FastAPI, UploadFile, Request from fastapi.responses import FileResponse, JSONResponse from fastapi.background import BackgroundTasks +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +limiter = Limiter(key_func=lambda request: "request_per_route_per_minute") app = FastAPI() +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) modelDB = ModelDatabase() # deciding which dependencies to choose @@ -70,7 +75,9 @@ def getLabelsFile(id: str): @app.post("/infer") +@limiter.limit("5/minute") async def infer( + request: Request, background_tasks: BackgroundTasks, image_file: UploadFile, model_name: str, From 970fc1f04f67bb3e640f041e20cc6a0453fa2d1a Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Mon, 21 Oct 2024 15:35:40 -0400 Subject: [PATCH 24/27] ENH: improved error display and making statusLabel readonly --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 11 ++++++----- MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 6 +++++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index d271804..865a443 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -3,6 +3,7 @@ import json import sys import time +from urllib.error import HTTPError import vtk @@ -1442,11 +1443,11 @@ def process(self, inputNodes, outputSegmentation, modelId=None, cpu=False, waitF binary_file.write(chunk) segmentationProcessInfo.procReturnCode = 0 - except Exception as e: - logging.debug(f"Error occurred: {e}") - if hasattr(r, "content"): - logging.debug(f"Response content: {r.content}") - raise + except requests.exceptions.HTTPError as e: + from http import HTTPStatus + status = HTTPStatus(e.response.status_code) + logging.debug(f"Server response content: {r.content}") + raise RuntimeError(status.description) finally: for f in files.values(): f.close() diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index 0ad3c9d..cbacac2 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -613,7 +613,11 @@ - + + + true + + From 06c938f566ae60302a1c128bc8d1dfed18ee8b31 Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Mon, 21 Oct 2024 16:51:35 -0400 Subject: [PATCH 25/27] Fix server launching on Windows --- MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py index ae327ab..5726450 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py @@ -203,8 +203,9 @@ def getAddressUrl(self): def start(self): cmd = [ sys.executable, - Path(__file__).parent.parent / "MONAIAuto3DSegServer" / "main.py", "--host", - self.hostName, "--port", self.port + str(Path(__file__).parent.parent / "MONAIAuto3DSegServer" / "main.py"), + "--host", self.hostName, + "--port", self.port ] logging.debug(f"Launching process: {cmd}") From c789d04985dd16d8fe67f02945658a26e3090049 Mon Sep 17 00:00:00 2001 From: Christian Herz Date: Tue, 22 Oct 2024 09:32:15 -0400 Subject: [PATCH 26/27] BUG: display process information in GUI when running local inference --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 9 ++++++++- MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index 865a443..e9d7075 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -257,6 +257,7 @@ def setup(self): # Create logic class. Logic implements all computations that should be possible to run # in batch mode, without a graphical user interface. self.logic = MONAIAuto3DSegLogic() + self.logic.logCallback = self.addLog self.logic.startResultImportCallback = self.onProcessImportStarted self.logic.endResultImportCallback = self.onProcessImportEnded self.logic.processingCompletedCallback = self.onProcessingCompleted @@ -997,6 +998,7 @@ def __init__(self): ScriptedLoadableModuleLogic.__init__(self) ModelDatabase.__init__(self) + self.logCallback = None self.startResultImportCallback = None self.endResultImportCallback = None self.processingCompletedCallback = None @@ -1013,6 +1015,11 @@ def __init__(self): self.debugSkipInference = False self.debugSkipInferenceTempDir = r"c:\Users\andra\AppData\Local\Temp\Slicer\__SlicerTemp__2024-01-16_15+26+25.624" + def log(self, text): + logging.info(text) + if self.logCallback: + self.logCallback(text) + def getMONAIPythonPackageInfo(self): return self.DEPENDENCY_HANDLER.installedMONAIPythonPackageInfo() @@ -1203,7 +1210,7 @@ def process(self, inputNodes, outputSegmentation, model=None, cpu=False, waitFor segmentationProcessInfo.model = model segmentationProcessInfo.customData = customData - self._bgProcess = LocalInference(processInfo=segmentationProcessInfo, completedCallback=self.onSegmentationProcessCompleted) + self._bgProcess = LocalInference(processInfo=segmentationProcessInfo, logCallback=self.log, completedCallback=self.onSegmentationProcessCompleted) if self.debugSkipInference: segmentationProcessInfo.procReturnCode = 0 self.onSegmentationProcessCompleted(segmentationProcessInfo) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py index 5726450..ba330bf 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSegLib/process.py @@ -238,6 +238,10 @@ def run(self, cmd, additionalEnvironmentVariables=None, waitForCompletion=True): else: self._startHandleProcessOutputThread() + def handleSubProcessLogging(self, text): + self.addLog(text) + logging.info(text) + def logProcessOutputUntilCompleted(self): # Wait for the process to end and forward output to the log proc = self.proc From 0ba44364183e70118554bc21c5c15af0873f64fe Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Tue, 22 Oct 2024 12:55:16 -0400 Subject: [PATCH 27/27] Improve Remote processing button Make Remote processing button state persistent. Disable model selection and Apply button if remote processing is enabled but not connected to a server yet. --- MONAIAuto3DSeg/MONAIAuto3DSeg.py | 17 ++++++++++++++++- MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui | 6 +----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/MONAIAuto3DSeg/MONAIAuto3DSeg.py b/MONAIAuto3DSeg/MONAIAuto3DSeg.py index e9d7075..1678345 100644 --- a/MONAIAuto3DSeg/MONAIAuto3DSeg.py +++ b/MONAIAuto3DSeg/MONAIAuto3DSeg.py @@ -262,6 +262,10 @@ def setup(self): self.logic.endResultImportCallback = self.onProcessImportEnded self.logic.processingCompletedCallback = self.onProcessingCompleted + self.ui.remoteProcessingCheckBox.checked = qt.QSettings().value(f"{self.moduleName}/remoteProcessing", False) + + self.ui.progressBar.hide() + # Connections # These connections ensure that we update parameter node when scene is closed @@ -292,7 +296,7 @@ def setup(self): self.ui.serverComboBox.lineEdit().setPlaceholderText("Enter server address") self.ui.serverComboBox.currentIndexChanged.connect(self.onRemoteServerButtonToggled) - self.ui.remoteProcessingCheckBox.toggled.connect(lambda t: self.ui.remoteServerButton.setChecked(False)) + self.ui.remoteProcessingCheckBox.toggled.connect(self.onRemoteProcessingCheckBoxToggled) self.ui.remoteServerButton.toggled.connect(self.onRemoteServerButtonToggled) self.ui.serverButton.toggled.connect(self.onServerButtonToggled) @@ -447,6 +451,11 @@ def updateGUIFromParameterNode(self, caller=None, event=None): if state == MONAIAuto3DSegWidget.PROCESSING_IDLE: self.ui.applyButton.text = "Apply" inputErrorMessages = [] # it will contain text if the inputs are not valid + if self.ui.remoteProcessingCheckBox.checked and not self.ui.remoteServerButton.checked: + inputErrorMessages.append("Connect to server or disable remote processing.") + self.ui.modelComboBox.enabled = False + else: + self.ui.modelComboBox.enabled = True if modelId: modelInputs = self.logic.model(modelId)["inputs"] else: @@ -830,6 +839,12 @@ def updateServerUrlGUIFromSettings(self): self.ui.serverComboBox.setCurrentText(settings.value(f"{self.moduleName}/serverUrl")) self.ui.serverComboBox.blockSignals(wasBlocked) + def onRemoteProcessingCheckBoxToggled(self, checked): + # Disconnect remote server button if remote processing state is changed + self.ui.remoteServerButton.setChecked(False) + settings = qt.QSettings() + settings.setValue(f"{self.moduleName}/remoteProcessing", checked) + self.updateGUIFromParameterNode() # # MONAIAuto3DSegLogic diff --git a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui index cbacac2..a4f7e61 100644 --- a/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui +++ b/MONAIAuto3DSeg/Resources/UI/MONAIAuto3DSeg.ui @@ -368,11 +368,7 @@ - - - 24 - - +