diff --git a/.gitignore b/.gitignore index ebd7434d1..77e0a8f9a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ __pycache__/ /Dockerfile **.whl build/*/docker.built +build/plugins/**/*.built build/*/requirements.txt build/*/specific_requirements.txt build/*/dlogs.*.txt diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..5153a0099 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "plugin/graph_processing/graph-transformer/tmi2022"] + path = plugin/graph_processing/graph-transformer/tmi2022 + url = git@github.com:CarlinLiao/tmi2022.git diff --git a/Makefile b/Makefile index 1cf037749..ff98e4b4b 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ PACKAGE_NAME := spatialprofilingtoolbox export PYTHON := python export BUILD_SCRIPTS_LOCATION_ABSOLUTE := ${PWD}/build/build_scripts SOURCE_LOCATION := ${PACKAGE_NAME} +PLUGIN_SOURCE_LOCATION := plugin BUILD_LOCATION := build BUILD_LOCATION_ABSOLUTE := ${PWD}/build export TEST_LOCATION := test @@ -64,21 +65,29 @@ export DOCKER_REPO_PREFIX := spt export DOCKER_SCAN_SUGGEST:=false SUBMODULES := apiserver graphs ondemand db workflow DOCKERIZED_SUBMODULES := apiserver db ondemand +PLUGINS := graph_processing/cg-gnn graph_processing/graph-transformer +CUDA_PLUGINS := graph_processing/cg-gnn DOCKERFILES := $(foreach submodule,$(DOCKERIZED_SUBMODULES),${BUILD_LOCATION}/$(submodule)/Dockerfile) -DOCKER_BUILD_TARGETS := $(foreach submodule,$(DOCKERIZED_SUBMODULES),${BUILD_LOCATION_ABSOLUTE}/$(submodule)/docker.built) +DOCKER_BUILD_SUBMODULE_TARGETS := $(foreach submodule,$(DOCKERIZED_SUBMODULES),${BUILD_LOCATION_ABSOLUTE}/$(submodule)/docker.built) +DOCKER_BUILD_PLUGIN_TARGETS := $(foreach plugin,$(PLUGINS),${BUILD_LOCATION_ABSOLUTE}/plugins/$(plugin).docker.built) +DOCKER_BUILD_PLUGIN_CUDA_TARGETS := $(foreach plugin,$(CUDA_PLUGINS),${BUILD_LOCATION_ABSOLUTE}/plugins/$(plugin)-cuda.docker.built) _UNPUSHABLES := db PUSHABLE_SUBMODULES := $(filter-out ${_UNPUSHABLES},$(DOCKERIZED_SUBMODULES)) -DOCKER_PUSH_TARGETS := $(foreach submodule,$(PUSHABLE_SUBMODULES),docker-push-${PACKAGE_NAME}/$(submodule)) -DOCKER_PUSH_DEV_TARGETS := $(foreach submodule,$(DOCKERIZED_SUBMODULES),docker-push-dev-${PACKAGE_NAME}/$(submodule)) +DOCKER_PUSH_SUBMODULE_TARGETS := $(foreach submodule,$(PUSHABLE_SUBMODULES),docker-push-${PACKAGE_NAME}/$(submodule)) +DOCKER_PUSH_DEV_SUBMODULE_TARGETS := $(foreach submodule,$(DOCKERIZED_SUBMODULES),docker-push-dev-${PACKAGE_NAME}/$(submodule)) +DOCKER_PUSH_PLUGIN_TARGETS := $(foreach plugin,$(PLUGINS),docker-push-${PACKAGE_NAME}/$(plugin)) +DOCKER_PUSH_DEV_PLUGIN_TARGETS := $(foreach plugin,$(PLUGINS),docker-push-dev-${PACKAGE_NAME}/$(plugin)) +DOCKER_PUSH_PLUGIN_CUDA_TARGETS := $(foreach plugin,$(CUDA_PLUGINS),docker-push-${PACKAGE_NAME}/$(plugin)-cuda) +DOCKER_PUSH_DEV_PLUGIN_CUDA_TARGETS := $(foreach plugin,$(CUDA_PLUGINS),docker-push-dev-${PACKAGE_NAME}/$(plugin)-cuda) MODULE_TEST_TARGETS := $(foreach submodule,$(SUBMODULES),module-test-$(submodule)) UNIT_TEST_TARGETS := $(foreach submodule,$(SUBMODULES),unit-test-$(submodule)) SINGLETON_TEST_TARGETS := $(foreach submodule,$(SUBMODULES),singleton-test-$(submodule)) DLI := force-rebuild-data-loaded-image # Define PHONY targets -.PHONY: help release-package check-for-pypi-credentials print-source-files build-and-push-docker-images ${DOCKER_PUSH_TARGETS} build-docker-images test module-tests ${MODULE_TEST_TARGETS} ${UNIT_TEST_TARGETS} clean clean-files docker-compositions-rm clean-network-environment generic-spt-push-target 
data-loaded-images-push-target +.PHONY: help release-package check-for-pypi-credentials print-source-files build-and-push-docker-images ${DOCKER_PUSH_SUBMODULE_TARGETS} ${DOCKER_PUSH_PLUGIN_TARGETS} ${DOCKER_PUSH_PLUGIN_CUDA_TARGETS} build-docker-images test module-tests ${MODULE_TEST_TARGETS} ${UNIT_TEST_TARGETS} clean clean-files docker-compositions-rm clean-network-environment generic-spt-push-target data-loaded-images-push-target ensure-plugin-submodules-are-populated # Submodule-specific variables export DB_SOURCE_LOCATION_ABSOLUTE := ${PWD}/${SOURCE_LOCATION}/db @@ -167,11 +176,11 @@ pyproject.toml: pyproject.toml.unversioned ${BUILD_SCRIPTS_LOCATION_ABSOLUTE}/cr print-source-files: >@echo "${PACKAGE_SOURCE_FILES}" | tr ' ' '\n' -build-and-push-docker-images: ${DOCKER_PUSH_TARGETS} generic-spt-push-target data-loaded-images-push-target +build-and-push-docker-images: ${DOCKER_PUSH_SUBMODULE_TARGETS} ${DOCKER_PUSH_PLUGIN_TARGETS} ${DOCKER_PUSH_PLUGIN_CUDA_TARGETS} generic-spt-push-target data-loaded-images-push-target -build-and-push-docker-images-dev: ${DOCKER_PUSH_DEV_TARGETS} +build-and-push-docker-images-dev: ${DOCKER_PUSH_DEV_SUBMODULE_TARGETS} ${DOCKER_PUSH_DEV_PLUGIN_TARGETS} ${DOCKER_PUSH_DEV_PLUGIN_CUDA_TARGETS} -${DOCKER_PUSH_TARGETS}: build-docker-images check-for-docker-credentials +${DOCKER_PUSH_SUBMODULE_TARGETS}: build-docker-images check-for-docker-credentials >@submodule_directory=$$(echo $@ | sed 's/^docker-push-//g') ; \ submodule_version=$$(grep '^__version__ = ' $$submodule_directory/__init__.py | grep -o '[0-9]\+\.[0-9]\+\.[0-9]\+') ;\ submodule_name=$$(echo $$submodule_directory | sed 's/spatialprofilingtoolbox\///g') ; \ @@ -191,7 +200,7 @@ ${DOCKER_PUSH_TARGETS}: build-docker-images check-for-docker-credentials exit_code=$$(( exit_code1 + exit_code2 )); echo "$$exit_code" > status_code >@${MESSAGE} end "Pushed." "Not pushed." -${DOCKER_PUSH_DEV_TARGETS}: build-docker-images check-for-docker-credentials +${DOCKER_PUSH_DEV_SUBMODULE_TARGETS}: build-docker-images check-for-docker-credentials >@submodule_directory=$$(echo $@ | sed 's/^docker-push-dev-//g') ; \ submodule_version=$$(grep '^__version__ = ' $$submodule_directory/__init__.py | grep -o '[0-9]\+\.[0-9]\+\.[0-9]\+') ;\ submodule_name=$$(echo $$submodule_directory | sed 's/spatialprofilingtoolbox\///g') ; \ @@ -209,6 +218,78 @@ ${DOCKER_PUSH_DEV_TARGETS}: build-docker-images check-for-docker-credentials echo "$$exit_code1" > status_code >@${MESSAGE} end "Pushed." "Not pushed." +${DOCKER_PUSH_PLUGIN_TARGETS}: build-docker-images check-for-docker-credentials +>@plugin_name=$$(basename $@) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + ${MESSAGE} start "Pushing Docker container $$repository_name" +>@plugin_name=$$(basename $@) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + plugin_relative_directory=$$(dirname $@ | sed 's,docker-push-${PACKAGE_NAME}\/,,g')/$$plugin_name ; \ + source_directory=${PLUGIN_SOURCE_LOCATION}/$$plugin_relative_directory ; \ + plugin_version=$$(cat $$source_directory/version.txt) ; \ + echo "$$plugin_version"; \ + echo "$$plugin_name"; \ + echo "$$repository_name"; \ + docker push $$repository_name:$$plugin_version ; \ + exit_code1=$$?; \ + docker push $$repository_name:latest ; \ + exit_code2=$$?; \ + exit_code=$$(( exit_code1 + exit_code2 )); echo "$$exit_code" > status_code +>@${MESSAGE} end "Pushed." "Not pushed." 
+ +${DOCKER_PUSH_DEV_PLUGIN_TARGETS}: build-docker-images check-for-docker-credentials +>@plugin_name=$$(basename $@) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + ${MESSAGE} start "Pushing Docker container $$repository_name" +>@plugin_name=$$(basename $@) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + plugin_relative_directory=$$(dirname $@ | sed 's,docker-push-dev-${PACKAGE_NAME}\/,,g')/$$plugin_name ; \ + source_directory=${PLUGIN_SOURCE_LOCATION}/$$plugin_relative_directory ; \ + plugin_version=$$(cat $$source_directory/version.txt) ; \ + echo "$$plugin_version"; \ + echo "$$plugin_name"; \ + echo "$$repository_name"; \ + docker push $$repository_name:dev ; \ + exit_code1=$$?; \ + echo "$$exit_code1" > status_code +>@${MESSAGE} end "Pushed." "Not pushed." + +${DOCKER_PUSH_PLUGIN_CUDA_TARGETS}: build-docker-images check-for-docker-credentials +>@plugin_name=$$(basename $@ -cuda) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + ${MESSAGE} start "Pushing Docker container $$repository_name:cuda" +>@plugin_name=$$(basename $@ -cuda) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + plugin_relative_directory=$$(dirname $@ | sed 's,docker-push-${PACKAGE_NAME}\/,,g')/$$plugin_name ; \ + source_directory=${PLUGIN_SOURCE_LOCATION}/$$plugin_relative_directory ; \ + plugin_version=$$(cat $$source_directory/version.txt) ; \ + echo "$$plugin_version"; \ + echo "$$plugin_name"; \ + echo "$$repository_name"; \ + docker push $$repository_name:cuda-$$plugin_version ; \ + exit_code1=$$?; \ + docker push $$repository_name:cuda-latest ; \ + exit_code2=$$?; \ + exit_code=$$(( exit_code1 + exit_code2 )); echo "$$exit_code" > status_code +>@${MESSAGE} end "Pushed." "Not pushed." + +${DOCKER_PUSH_DEV_PLUGIN_CUDA_TARGETS}: build-docker-images check-for-docker-credentials +>@plugin_name=$$(basename $@ -cuda) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + ${MESSAGE} start "Pushing Docker container $$repository_name:cuda" +>@plugin_name=$$(basename $@ -cuda) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + plugin_relative_directory=$$(dirname $@ | sed 's,docker-push-dev-${PACKAGE_NAME}\/,,g')/$$plugin_name ; \ + source_directory=${PLUGIN_SOURCE_LOCATION}/$$plugin_relative_directory ; \ + plugin_version=$$(cat $$source_directory/version.txt) ; \ + echo "$$plugin_version"; \ + echo "$$plugin_name"; \ + echo "$$repository_name"; \ + docker push $$repository_name:cuda-dev ; \ + exit_code1=$$?; \ + echo "$$exit_code1" > status_code +>@${MESSAGE} end "Pushed." "Not pushed." + generic-spt-push-target: build-docker-images check-for-docker-credentials >@repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX} ; \ ${MESSAGE} start "Pushing Docker container $$repository_name" @@ -252,7 +333,10 @@ check-dockerfiles-consistency: >@status_code=$$(cat status_code); if [[ "$$status_code" == "0" && ( ! -f check-dockerfiles-consistency ) ]]; then touch check-dockerfiles-consistency; fi; >@${MESSAGE} end "Consistent." "Something missing." -build-docker-images: ${DOCKER_BUILD_TARGETS} +ensure-plugin-submodules-are-populated: +>@git submodule update --init --recursive + +build-docker-images: ${DOCKER_BUILD_SUBMODULE_TARGETS} ${DOCKER_BUILD_PLUGIN_TARGETS} ${DOCKER_BUILD_PLUGIN_CUDA_TARGETS} # Build the Docker container for each submodule by doing the following: # 1. 
Identify the submodule being built @@ -260,7 +344,7 @@ build-docker-images: ${DOCKER_BUILD_TARGETS} # 3. Copy relevant files to the build folder # 4. docker build the container # 5. Remove copied files -${DOCKER_BUILD_TARGETS}: ${DOCKERFILES} development-image check-docker-daemon-running check-for-docker-credentials check-dockerfiles-consistency +${DOCKER_BUILD_SUBMODULE_TARGETS}: ${DOCKERFILES} development-image check-docker-daemon-running check-for-docker-credentials check-dockerfiles-consistency >@submodule_directory=$$(echo $@ | sed 's/\/docker.built//g') ; \ dockerfile=$${submodule_directory}/Dockerfile ; \ submodule_name=$$(echo $$submodule_directory | sed 's,${BUILD_LOCATION_ABSOLUTE}\/,,g') ; \ @@ -295,6 +379,70 @@ ${DOCKER_BUILD_TARGETS}: ${DOCKERFILES} development-image check-docker-daemon-ru rm ./Dockerfile ; \ rm ./.dockerignore ; \ +${DOCKER_BUILD_PLUGIN_TARGETS}: check-docker-daemon-running check-for-docker-credentials check-dockerfiles-consistency ensure-plugin-submodules-are-populated +>@plugin_name=$$(basename $@ .docker.built) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + ${MESSAGE} start "Building Docker image $$repository_name" +>@plugin_name=$$(basename $@ .docker.built) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + plugin_relative_directory=$$(dirname $@ | sed 's,${BUILD_LOCATION_ABSOLUTE}\/plugins\/,,g')/$$plugin_name ; \ + source_directory=${PLUGIN_SOURCE_LOCATION}/$$plugin_relative_directory ; \ + plugin_version=$$(cat $$source_directory/version.txt) ; \ + plugin_directory=$$(dirname $@)/$$plugin_name ; \ + mkdir -p $$plugin_directory ; \ + cp -r $$source_directory/* $$plugin_directory ; \ + cp $$(dirname $@)/$$(basename $@ .docker.built).dockerfile ./Dockerfile ; \ + cp ${BUILD_SCRIPTS_LOCATION_ABSOLUTE}/.dockerignore . ; \ + docker build \ + ${NO_CACHE_FLAG} \ + -f ./Dockerfile \ + -t $$repository_name:$$plugin_version \ + -t $$repository_name:latest \ + -t $$repository_name:dev \ + --build-arg version=$$plugin_version \ + --build-arg service_name=$$plugin_name \ + $$plugin_directory ; echo "$$?" > status_code; \ + if [[ "$$(cat status_code)" == "0" ]]; \ + then \ + touch $@ ; \ + fi +>@${MESSAGE} end "Built." "Build failed." +>@plugin_name=$$(basename $@ .docker.built) ; \ + plugin_directory=$$(dirname $@)/$$plugin_name ; \ + rm -r $$plugin_directory ; \ + +${DOCKER_BUILD_PLUGIN_CUDA_TARGETS}: check-docker-daemon-running check-for-docker-credentials check-dockerfiles-consistency ensure-plugin-submodules-are-populated +>@plugin_name=$$(basename $@ -cuda.docker.built) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + ${MESSAGE} start "Building Docker image $$repository_name:cuda" +>@plugin_name=$$(basename $@ -cuda.docker.built) ; \ + repository_name=${DOCKER_ORG_NAME}/${DOCKER_REPO_PREFIX}-$$plugin_name ; \ + plugin_relative_directory=$$(dirname $@ | sed 's,${BUILD_LOCATION_ABSOLUTE}\/plugins\/,,g')/$$plugin_name ; \ + source_directory=${PLUGIN_SOURCE_LOCATION}/$$plugin_relative_directory ; \ + plugin_version=$$(cat $$source_directory/version.txt) ; \ + plugin_directory=$$(dirname $@)/$$plugin_name-cuda ; \ + mkdir -p $$plugin_directory ; \ + cp $$source_directory/* $$plugin_directory ; \ + cp $$(dirname $@)/$$(basename $@ .docker.built).dockerfile ./Dockerfile ; \ + cp ${BUILD_SCRIPTS_LOCATION_ABSOLUTE}/.dockerignore . 
; \ + docker build \ + ${NO_CACHE_FLAG} \ + -f ./Dockerfile \ + -t $$repository_name:cuda-$$plugin_version \ + -t $$repository_name:cuda-latest \ + -t $$repository_name:cuda-dev \ + --build-arg version=$$plugin_version \ + --build-arg service_name=$$plugin_name \ + $$plugin_directory ; echo "$$?" > status_code; \ + if [[ "$$(cat status_code)" == "0" ]]; \ + then \ + touch $@ ; \ + fi +>@${MESSAGE} end "Built." "Build failed." +>@plugin_name=$$(basename $@ -cuda.docker.built) ; \ + plugin_directory=$$(dirname $@)/$$plugin_name-cuda ; \ + rm -r $$plugin_directory ; \ + check-docker-daemon-running: >@${MESSAGE} start "Checking that Docker daemon is running" >@docker stats --no-stream ; echo "$$?" > status_code @@ -325,17 +473,17 @@ test: unit-tests module-tests module-tests: ${MODULE_TEST_TARGETS} -${MODULE_TEST_TARGETS}: development-image data-loaded-image-1smallnointensity data-loaded-image-1small data-loaded-image-1 data-loaded-image-1and2 ${DOCKER_BUILD_TARGETS} clean-network-environment .initial_time.txt +${MODULE_TEST_TARGETS}: development-image data-loaded-image-1smallnointensity data-loaded-image-1small data-loaded-image-1 data-loaded-image-1and2 ${DOCKER_BUILD_SUBMODULE_TARGETS} clean-network-environment .initial_time.txt >@submodule_directory=$$(echo $@ | sed 's/^module-test-/${BUILD_LOCATION}\//g') ; \ ${MAKE} SHELL=$(SHELL) --no-print-directory -C $$submodule_directory module-tests ; unit-tests: ${UNIT_TEST_TARGETS} -${UNIT_TEST_TARGETS}: development-image data-loaded-image-1smallnointensity data-loaded-image-1small data-loaded-image-1 data-loaded-image-1and2 ${DOCKER_BUILD_TARGETS} clean-network-environment .initial_time.txt +${UNIT_TEST_TARGETS}: development-image data-loaded-image-1smallnointensity data-loaded-image-1small data-loaded-image-1 data-loaded-image-1and2 ${DOCKER_BUILD_SUBMODULE_TARGETS} clean-network-environment .initial_time.txt >@submodule_directory=$$(echo $@ | sed 's/^unit-test-/${BUILD_LOCATION}\//g') ; \ ${MAKE} SHELL=$(SHELL) --no-print-directory -C $$submodule_directory unit-tests ; -${SINGLETON_TEST_TARGETS}: development-image data-loaded-image-1small data-loaded-image-1 data-loaded-image-1and2 ${DOCKER_BUILD_TARGETS} clean-network-environment .initial_time.txt +${SINGLETON_TEST_TARGETS}: development-image data-loaded-image-1small data-loaded-image-1 data-loaded-image-1and2 ${DOCKER_BUILD_SUBMODULE_TARGETS} clean-network-environment .initial_time.txt >@submodule_directory=$$(echo $@ | sed 's/^singleton-test-/${BUILD_LOCATION}\//g') ; \ ${MAKE} SHELL=$(SHELL) --no-print-directory -C $$submodule_directory singleton-tests ; diff --git a/build/plugins/graph_processing/cg-gnn-cuda.dockerfile b/build/plugins/graph_processing/cg-gnn-cuda.dockerfile new file mode 100644 index 000000000..71c75ae12 --- /dev/null +++ b/build/plugins/graph_processing/cg-gnn-cuda.dockerfile @@ -0,0 +1,32 @@ +FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime +WORKDIR /app + +# Install apt packages you need here, and then clean up afterward +RUN apt-get update +RUN apt-get install -y \ + libhdf5-serial-dev \ + libatlas-base-dev \ + libblas-dev \ + liblapack-dev \ + gfortran \ + libpq-dev +RUN rm -rf /var/lib/apt/lists/* + +# Install python packages you need here +ENV PIP_NO_CACHE_DIR=1 +RUN pip install h5py==3.10.0 +RUN pip install numpy==1.24.3 +RUN pip install scipy==1.10.1 +RUN pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html +RUN pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html +ENV DGLBACKEND=pytorch +RUN pip install cg-gnn==0.3.2 + +# Make the files you 
need in this directory available everywhere in the container +ADD . /app +RUN chmod +x train.py +RUN mv train.py /usr/local/bin/spt-plugin-train-on-graphs +RUN chmod +x /app/print_graph_config.sh +RUN mv /app/print_graph_config.sh /usr/local/bin/spt-plugin-print-graph-request-configuration +RUN chmod +x /app/print_training_config.sh +RUN mv /app/print_training_config.sh /usr/local/bin/spt-plugin-print-training-configuration diff --git a/build/plugins/graph_processing/cg-gnn.dockerfile b/build/plugins/graph_processing/cg-gnn.dockerfile new file mode 100644 index 000000000..9320164ff --- /dev/null +++ b/build/plugins/graph_processing/cg-gnn.dockerfile @@ -0,0 +1,34 @@ +# Use cuda.Dockerfile if you have a CUDA-enabled GPU +FROM python:3.11-slim-buster +WORKDIR /app + +# Install apt packages you need here, and then clean up afterward +RUN apt-get update +RUN apt-get install -y \ + libhdf5-serial-dev \ + libatlas-base-dev \ + libblas-dev \ + liblapack-dev \ + gfortran \ + libpq-dev +RUN rm -rf /var/lib/apt/lists/* + +# Install python packages you need here +ENV PIP_NO_CACHE_DIR=1 +RUN pip install h5py==3.10.0 +RUN pip install numpy==1.24.3 +RUN pip install scipy==1.10.1 +RUN pip install torch --index-url https://download.pytorch.org/whl/cpu +RUN pip install dgl -f https://data.dgl.ai/wheels/repo.html +RUN pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html +ENV DGLBACKEND=pytorch +RUN pip install cg-gnn==0.3.2 + +# Make the files you need in this directory available everywhere in the container +ADD . /app +RUN chmod +x train.py +RUN mv train.py /usr/local/bin/spt-plugin-train-on-graphs +RUN chmod +x /app/print_graph_config.sh +RUN mv /app/print_graph_config.sh /usr/local/bin/spt-plugin-print-graph-request-configuration +RUN chmod +x /app/print_training_config.sh +RUN mv /app/print_training_config.sh /usr/local/bin/spt-plugin-print-training-configuration diff --git a/build/plugins/graph_processing/graph-transformer.dockerfile b/build/plugins/graph_processing/graph-transformer.dockerfile new file mode 100644 index 000000000..44c70c198 --- /dev/null +++ b/build/plugins/graph_processing/graph-transformer.dockerfile @@ -0,0 +1,39 @@ +FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime +WORKDIR /app + +# Install apt packages you need here, and then clean up afterward +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && apt-get install -y --no-install-recommends \ + libhdf5-dev \ + libopenblas0 libopenblas-dev \ + libprotobuf-dev \ + libjpeg8-dev \ + libpng-dev \ + libtiff-dev \ + libwebp-dev \ + libopenjp2-7-dev \ + libtbb-dev \ + libeigen3-dev \ + tesseract-ocr tesseract-ocr-por libtesseract-dev && \ + rm -rf /var/lib/apt/lists/* + +# Install python packages you need here +ENV PIP_NO_CACHE_DIR=1 +RUN pip install h5py==3.10.0 +RUN pip install numpy==1.24.3 +RUN pip install scipy==1.10.1 +RUN pip install pandas +RUN pip install pillow +RUN pip install tensorboardX +RUN pip install opencv-python +RUN pip install einops +RUN pip install torch-geometric + +# Make the files you need in this directory available everywhere in the container +ADD . 
/app +RUN chmod +x train.py +RUN mv train.py /usr/local/bin/spt-plugin-train-on-graphs +RUN chmod +x /app/print_graph_config.sh +RUN mv /app/print_graph_config.sh /usr/local/bin/spt-plugin-print-graph-request-configuration +RUN chmod +x /app/print_training_config.sh +RUN mv /app/print_training_config.sh /usr/local/bin/spt-plugin-print-training-configuration diff --git a/plugin/README.md b/plugin/README.md new file mode 100644 index 000000000..8a550532a --- /dev/null +++ b/plugin/README.md @@ -0,0 +1,3 @@ +# Plugins + +This directory contains various plugins and plugin archetypes for the SPT platform. diff --git a/plugin/graph_processing/README.md b/plugin/graph_processing/README.md new file mode 100644 index 000000000..316446265 --- /dev/null +++ b/plugin/graph_processing/README.md @@ -0,0 +1,15 @@ +# Graph processing + +These plugins create and process cell graphs that are used to train prediction models and extract features from the models. + +New plugins can be contributed by modifying the [template](template/) or implementation ([cg-gnn](cg-gnn/), [graph-transformer](graph-transformer/)) source to use alternative processing methods. + +Graph processing plugins are Docker images (template [Dockerfile](template/Dockerfile), `cg-gnn` [Dockerfile](../../build/plugins/graph_processing/cg-gnn.dockerfile), and `graph-transformer` [Dockerfile](../../build/plugins/graph_processing/graph-transformer.dockerfile)). + +Each plugin is expected to have the following commands available on the path: +* `spt-plugin-print-graph-request-configuration`, which prints to `stdout` the configuration file intended to be used by this plugin to fetch graphs from an SPT instance to use for model training. An empty configuration file and a shell script to do this are provided in the template, along with the command needed to make this available in the template `Dockerfile`. +* `spt-plugin-train-on-graphs`, which trains the model and outputs a CSV of importance scores that can be read by `spt graphs upload-importances`. A template [`train.py`](template/train.py) is provided that uses a command line interface specified in `train_cli.py`. Its arguments are: + 1. `--input_directory`, the path to the directory containing the graphs to train on. + 2. `--config_file`, the path to the configuration file. This should be optional, and if not provided `spt-plugin-train-on-graphs` should use reasonable defaults. + 3. `--output_directory`, the path to the directory in which to save the trained model, importance scores, and any other artifacts deemed important enough to save, like performance reports. +* `spt-plugin-print-training-configuration`, which prints to `stdout` an example configuration file for running `spt-plugin-train-on-graphs`, populated either with example values or the reasonable defaults used by the command. An empty configuration file and a shell script to do this are provided in the template, along with the command needed to make this available in the template `Dockerfile`. diff --git a/plugin/graph_processing/cg-gnn/README.md b/plugin/graph_processing/cg-gnn/README.md new file mode 100644 index 000000000..53de3cf9d --- /dev/null +++ b/plugin/graph_processing/cg-gnn/README.md @@ -0,0 +1,3 @@ +# spt-cg-gnn + +This builds the cg-gnn SPT plugin as a Docker image.
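To make the plugin contract described above concrete, here is a hedged usage sketch for a built plugin image. The image name follows the Makefile tagging scheme (`<org>` is whatever `DOCKER_ORG_NAME` resolves to), and the host-side directory names are purely illustrative.

```sh
# Ask the plugin which graphs it wants extracted from an SPT instance.
docker run --rm <org>/spt-cg-gnn:latest \
    spt-plugin-print-graph-request-configuration > graph.config

# Inspect the training settings the plugin would use by default.
docker run --rm <org>/spt-cg-gnn:latest \
    spt-plugin-print-training-configuration > training.config

# Train on previously extracted graphs; the CSV of importance scores written to
# the output directory can then be uploaded with `spt graphs upload-importances`.
docker run --rm --gpus all \
    -v "$PWD/graphs":/data/in -v "$PWD/results":/data/out \
    <org>/spt-cg-gnn:cuda-latest \
    spt-plugin-train-on-graphs \
        --input_directory /data/in \
        --config_file /data/in/training.config \
        --output_directory /data/out
```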
diff --git a/plugin/graph_processing/cg-gnn/graph.config b/plugin/graph_processing/cg-gnn/graph.config new file mode 100644 index 000000000..695fdfcaa --- /dev/null +++ b/plugin/graph_processing/cg-gnn/graph.config @@ -0,0 +1,7 @@ +[graph-generation] +validation_data_percent = 0 +test_data_percent = 15 +; use_channels = true +; use_phenotypes = true +; cells_per_roi_target = 5000 +target_name = diff --git a/plugin/graph_processing/cg-gnn/print_graph_config.sh b/plugin/graph_processing/cg-gnn/print_graph_config.sh new file mode 100644 index 000000000..7f07d1a80 --- /dev/null +++ b/plugin/graph_processing/cg-gnn/print_graph_config.sh @@ -0,0 +1,2 @@ +#!/bin/sh +cat /app/graph.config diff --git a/plugin/graph_processing/cg-gnn/print_training_config.sh b/plugin/graph_processing/cg-gnn/print_training_config.sh new file mode 100644 index 000000000..8fee4a4ae --- /dev/null +++ b/plugin/graph_processing/cg-gnn/print_training_config.sh @@ -0,0 +1,2 @@ +#!/bin/sh +cat /app/training.config diff --git a/plugin/graph_processing/cg-gnn/train.py b/plugin/graph_processing/cg-gnn/train.py new file mode 100644 index 000000000..5a1bb841c --- /dev/null +++ b/plugin/graph_processing/cg-gnn/train.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +"""Convert SPT graph objects to CG-GNN graph objects and run training and evaluation with them.""" + +from sys import path +from configparser import ConfigParser +from os import remove +from os.path import join, exists +from configparser import ConfigParser +from warnings import warn + +from numpy import nonzero # type: ignore +from networkx import to_scipy_sparse_array # type: ignore +from torch import ( + FloatTensor, + IntTensor, # type: ignore +) +from dgl import DGLGraph, graph +from cggnn.util import GraphData, save_cell_graphs, load_cell_graphs +from cggnn.util.constants import INDICES, CENTROIDS, FEATURES, IMPORTANCES +from cggnn.run import train_and_evaluate + +path.append('/app') # noqa +from train_cli import parse_arguments, DEFAULT_CONFIG_FILE +from util import HSGraph, GraphData as SPTGraphData, load_hs_graphs, save_hs_graphs + + +def _convert_spt_graph(g_spt: HSGraph) -> DGLGraph: + """Convert a SPT HSGraph to a CG-GNN cell graph.""" + num_nodes = g_spt.node_features.shape[0] + g_dgl = graph([]) + g_dgl.add_nodes(num_nodes) + g_dgl.ndata[INDICES] = IntTensor(g_spt.histological_structure_ids) + g_dgl.ndata[CENTROIDS] = FloatTensor(g_spt.centroids) + g_dgl.ndata[FEATURES] = FloatTensor(g_spt.node_features) + # Note: channels and phenotypes are binary variables, but DGL only supports FloatTensors + edge_list = nonzero(g_spt.adj.toarray()) + g_dgl.add_edges(list(edge_list[0]), list(edge_list[1])) + return g_dgl + + +def _convert_spt_graph_data(g_spt: SPTGraphData) -> GraphData: + """Convert a SPT GraphData object to a CG-GNN/DGL GraphData object.""" + return GraphData( + graph=_convert_spt_graph(g_spt.graph), + label=g_spt.label, + name=g_spt.name, + specimen=g_spt.specimen, + set=g_spt.set, + ) + + +def _convert_spt_graphs_data(graphs_data: list[SPTGraphData]) -> list[GraphData]: + """Convert a list of SPT HSGraphs to CG-GNN cell graphs.""" + return [_convert_spt_graph_data(g_spt) for g_spt in graphs_data] + + +def _convert_dgl_graph(g_dgl: DGLGraph) -> HSGraph: + """Convert a DGLGraph to a CG-GNN cell graph.""" + return HSGraph( + adj=to_scipy_sparse_array(g_dgl.to_networkx()), + node_features=g_dgl.ndata[FEATURES].detach().cpu().numpy(), + centroids=g_dgl.ndata[CENTROIDS].detach().cpu().numpy(), + 
histological_structure_ids=g_dgl.ndata[INDICES].detach().cpu().numpy(), + importances=g_dgl.ndata[IMPORTANCES].detach().cpu().numpy() if (IMPORTANCES in g_dgl.ndata) + else None, + ) + + +def _convert_dgl_graph_data(g_dgl: GraphData) -> SPTGraphData: + return SPTGraphData( + graph=_convert_dgl_graph(g_dgl.graph), + label=g_dgl.label, + name=g_dgl.name, + specimen=g_dgl.specimen, + set=g_dgl.set, + ) + + +def _convert_dgl_graphs_data(graphs_data: list[GraphData]) -> list[SPTGraphData]: + """Convert a list of DGLGraphs to CG-GNN cell graphs.""" + return [_convert_dgl_graph_data(g_dgl) for g_dgl in graphs_data] + + +def _handle_random_seed_values(random_seed_value: str | None) -> int | None: + if (random_seed_value is not None) and (str(random_seed_value).strip().lower() != "none"): + return int(random_seed_value) + return None + + +if __name__ == '__main__': + args = parse_arguments() + config_file = ConfigParser() + config_file.read(args.config_file) + random_seed: int | None = None + if 'general' in config_file: + random_seed = _handle_random_seed_values(config_file['general'].get('random_seed', None)) + if 'cg-gnn' not in config_file: + warn('No cg-gnn section in config file. Using default values.') + config_file.read(DEFAULT_CONFIG_FILE) + config = config_file['cg-gnn'] + + in_ram: bool = config.getboolean('in_ram', True) + batch_size: int = config.getint('batch_size', 32) + epochs: int = config.getint('epochs', 10) + learning_rate: float = config.getfloat('learning_rate', 1e-3) + k_folds: int = config.getint('k_folds', 5) + explainer: str = config.get('explainer', 'pp') + merge_rois: bool = config.getboolean('merge_rois', True) + if random_seed is None: + random_seed = _handle_random_seed_values(config.get('random_seed', None)) + + spt_graphs, _ = load_hs_graphs(args.input_directory) + save_cell_graphs(_convert_spt_graphs_data(spt_graphs), args.output_directory) + + model, graphs_data, hs_id_to_importances = train_and_evaluate(args.output_directory, + in_ram, + batch_size, + epochs, + learning_rate, + k_folds, + explainer, + merge_rois, + random_seed) + + save_hs_graphs(_convert_dgl_graphs_data(load_cell_graphs(args.output_directory)[0]), + args.output_directory) + for filename in ('graphs.bin', 'graph_info.pkl'): + graphs_file = join(args.output_directory, filename) + if exists(graphs_file): + remove(graphs_file) diff --git a/plugin/graph_processing/cg-gnn/train_cli.py b/plugin/graph_processing/cg-gnn/train_cli.py new file mode 100644 index 000000000..0cc09e14c --- /dev/null +++ b/plugin/graph_processing/cg-gnn/train_cli.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +"""Process arguments to training command.""" + +from argparse import ArgumentParser + +DEFAULT_CONFIG_FILE = 'training.config' + + +def parse_arguments(): + """Parse arguments.""" + arg_parser = ArgumentParser() + arg_parser.add_argument( + '--input_directory', + type=str, + help='Path to the directory containing the cell graphs to be used for training.', + ) + arg_parser.add_argument( + '--config_file', + type=str, + help='Path to config file.', + default=DEFAULT_CONFIG_FILE, + ) + arg_parser.add_argument( + '--output_directory', + type=str, + help='Path to the directory containing the cell graphs to be used for training.', + ) + return arg_parser.parse_args() diff --git a/plugin/graph_processing/cg-gnn/training.config b/plugin/graph_processing/cg-gnn/training.config new file mode 100644 index 000000000..20e5088ba --- /dev/null +++ b/plugin/graph_processing/cg-gnn/training.config @@ -0,0 +1,8 @@ +[cg-gnn] +; in_ram = true 
+batch_size = 1 +epochs = 5 +learning_rate = 1e-3 +k_folds = 0 +; explainer_model = pp +merge_rois = true diff --git a/plugin/graph_processing/cg-gnn/util.py b/plugin/graph_processing/cg-gnn/util.py new file mode 100644 index 000000000..e2829ffa8 --- /dev/null +++ b/plugin/graph_processing/cg-gnn/util.py @@ -0,0 +1,173 @@ +"""Histological structure graph dataset utility functions. + +This is a copy of spatialprofilingtoolbox.graphs.util. +""" + +from os import listdir, makedirs +from os.path import join +from json import load as json_load +from random import seed +from typing import NamedTuple, Literal +from dataclasses import dataclass, field + +from numpy import ( + savetxt, # type: ignore + loadtxt, + int_, + float_, +) +from numpy.random import seed as np_seed +from numpy.typing import NDArray +from scipy.sparse import spmatrix, isspmatrix_csr, csr_matrix # type: ignore +from h5py import File # type: ignore + +SETS = ('train', 'validation', 'test') +SETS_type = Literal['train', 'validation', 'test'] + + +@dataclass +class HSGraph: + """A histological structure graph instance.""" + adj: spmatrix + node_features: NDArray[float_] + centroids: NDArray[float_] + histological_structure_ids: NDArray[int_] + importances: NDArray[float_] | None = field(default=None) + + +class GraphData(NamedTuple): + """Data relevant to a histological structure graph instance.""" + graph: HSGraph + label: int | None + name: str + specimen: str + set: SETS_type | None + + +def save_hs_graphs(graphs_data: list[GraphData], output_directory: str) -> None: + """Save histological structure graphs to a directory. + + Saves the adjacency graph separately from the rest of the graph data for compatibility. + """ + makedirs(output_directory, exist_ok=True) + for gd in graphs_data: + save_graph_data(gd, join(output_directory, f'{gd.name}.h5')) + + +def load_hs_graphs(graph_directory: str) -> tuple[list[GraphData], list[str]]: + """Load histological structure graphs from a directory. + + Assumes directory contains the files `graphs.pkl`, `feature_names.txt`, and a sparse array for + every graph in `graphs.pkl`. 
+ """ + graphs_data: list[GraphData] = [] + for filename in listdir(graph_directory): + if filename.endswith('.h5'): + try: + graphs_data.append(load_graph_data(join(graph_directory, filename))) + except KeyError: + raise ValueError(f'Graph data file {filename} is missing required fields.') + feature_names: list[str] = loadtxt( + join(graph_directory, 'feature_names.txt'), + dtype=str, + delimiter=',', + ).tolist() + return graphs_data, feature_names + + +def save_graph_data(graph_data: GraphData, filename: str): + """Save GraphData to an HDF5 file.""" + if not isspmatrix_csr(graph_data.graph.adj): + raise ValueError('Graph adjacency matrix must be a CSR matrix.') + + with File(filename, 'w') as f: + f.create_dataset('graph/adj/data', data=graph_data.graph.adj.data) + f.create_dataset('graph/adj/indices', data=graph_data.graph.adj.indices) + f.create_dataset('graph/adj/indptr', data=graph_data.graph.adj.indptr) + f.create_dataset('graph/adj/shape', data=graph_data.graph.adj.shape) + + f.create_dataset('graph/node_features', data=graph_data.graph.node_features) + f.create_dataset('graph/centroids', data=graph_data.graph.centroids) + f.create_dataset( + 'graph/histological_structure_ids', + data=graph_data.graph.histological_structure_ids, + ) + if graph_data.graph.importances is not None: + f.create_dataset('graph/importances', data=graph_data.graph.importances) + + f.create_dataset('label', data=graph_data.label) + f.create_dataset('name', data=graph_data.name) + f.create_dataset('specimen', data=graph_data.specimen) + f.create_dataset('set', data=graph_data.set) + + +def load_graph_data(filename: str) -> GraphData: + """Load GraphData from an HDF5 file.""" + with File(filename, 'r') as f: + adj_data = f['graph/adj/data'][()] + adj_indices = f['graph/adj/indices'][()] + adj_indptr = f['graph/adj/indptr'][()] + adj_shape = f['graph/adj/shape'][()] + adj = csr_matrix((adj_data, adj_indices, adj_indptr), shape=adj_shape) + + node_features: NDArray[float_] = f['graph/node_features'][()] + centroids: NDArray[float_] = f['graph/centroids'][()] + histological_structure_ids: NDArray[int_] = f['graph/histological_structure_ids'][()] + importances: NDArray[float_] = \ + f['graph/importances'][()] if 'graph/importances' in f else None + + # h5 files store strings as byte arrays + label: int | None = f['label'][()] + name: str = f['name'][()].decode() + specimen: str = f['specimen'][()].decode() + set: SETS_type = f['set'][()].decode() + + graph = HSGraph(adj, node_features, centroids, histological_structure_ids, importances) + return GraphData(graph, label, name, specimen, set) + + +def save_graph_data_and_feature_names( + graphs_data: list[GraphData], + features_to_use: list[str], + output_directory: str, +) -> None: + """Save graph data and feature names to disk.""" + save_hs_graphs(graphs_data, output_directory) + savetxt(join(output_directory, 'feature_names.txt'), features_to_use, fmt='%s', delimiter=',') + + +def load_label_to_result(path: str) -> dict[int, str]: + """Read in label_to_result JSON.""" + return {int(label): result for label, result in json_load( + open(path, encoding='utf-8')).items()} + + +def split_graph_sets(graphs_data: list[GraphData]) -> tuple[ + tuple[list[HSGraph], list[int]], + tuple[list[HSGraph], list[int]], + tuple[list[HSGraph], list[int]], + list[HSGraph], +]: + """Split graph data list into train, validation, test, and unlabeled sets.""" + cg_train: tuple[list[HSGraph], list[int]] = ([], []) + cg_val: tuple[list[HSGraph], list[int]] = ([], []) + cg_test: 
tuple[list[HSGraph], list[int]] = ([], []) + cg_unlabeled: list[HSGraph] = [] + for gd in graphs_data: + if gd.label is None: + cg_unlabeled.append(gd.graph) + continue + which_set: tuple[list[HSGraph], list[int]] = cg_train + if gd.set == 'validation': + which_set = cg_val + elif gd.set == 'test': + which_set = cg_test + which_set[0].append(gd.graph) + which_set[1].append(gd.label) + return cg_train, cg_val, cg_test, cg_unlabeled + + +def set_seeds(random_seed: int) -> None: + """Set random seeds for all libraries.""" + seed(random_seed) + np_seed(random_seed) diff --git a/plugin/graph_processing/cg-gnn/version.txt b/plugin/graph_processing/cg-gnn/version.txt new file mode 100644 index 000000000..6812f8122 --- /dev/null +++ b/plugin/graph_processing/cg-gnn/version.txt @@ -0,0 +1 @@ +0.0.3 \ No newline at end of file diff --git a/plugin/graph_processing/graph-transformer/README.md b/plugin/graph_processing/graph-transformer/README.md new file mode 100644 index 000000000..1bb788080 --- /dev/null +++ b/plugin/graph_processing/graph-transformer/README.md @@ -0,0 +1,3 @@ +# spt-cg-gnn + +This builds the graph-transfomer SPT plugin as a Docker image. diff --git a/plugin/graph_processing/graph-transformer/graph.config b/plugin/graph_processing/graph-transformer/graph.config new file mode 100644 index 000000000..695fdfcaa --- /dev/null +++ b/plugin/graph_processing/graph-transformer/graph.config @@ -0,0 +1,7 @@ +[graph-generation] +validation_data_percent = 0 +test_data_percent = 15 +; use_channels = true +; use_phenotypes = true +; cells_per_roi_target = 5000 +target_name = diff --git a/plugin/graph_processing/graph-transformer/print_graph_config.sh b/plugin/graph_processing/graph-transformer/print_graph_config.sh new file mode 100644 index 000000000..7f07d1a80 --- /dev/null +++ b/plugin/graph_processing/graph-transformer/print_graph_config.sh @@ -0,0 +1,2 @@ +#!/bin/sh +cat /app/graph.config diff --git a/plugin/graph_processing/graph-transformer/print_training_config.sh b/plugin/graph_processing/graph-transformer/print_training_config.sh new file mode 100644 index 000000000..8fee4a4ae --- /dev/null +++ b/plugin/graph_processing/graph-transformer/print_training_config.sh @@ -0,0 +1,2 @@ +#!/bin/sh +cat /app/training.config diff --git a/plugin/graph_processing/graph-transformer/tmi2022 b/plugin/graph_processing/graph-transformer/tmi2022 new file mode 160000 index 000000000..4cba801e8 --- /dev/null +++ b/plugin/graph_processing/graph-transformer/tmi2022 @@ -0,0 +1 @@ +Subproject commit 4cba801e8096c17835c168f475bb34a8e90f1901 diff --git a/plugin/graph_processing/graph-transformer/train.py b/plugin/graph_processing/graph-transformer/train.py new file mode 100644 index 000000000..a1838f1bc --- /dev/null +++ b/plugin/graph_processing/graph-transformer/train.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +"""Train a model.""" + +from sys import path +from configparser import ConfigParser +from warnings import warn +from os import makedirs +from os.path import join +from subprocess import run +from typing import DefaultDict + +from numpy import mean +from scipy.sparse import coo_matrix +from torch import from_numpy, tensor, long, float as torch_float, Size, sparse_coo_tensor, save, \ + load, sum, stack, mm +from torch.cuda import is_available +from torch.nn.functional import softmax +from pandas import Series +from tqdm import tqdm + +path.append('/app') # noqa +from tmi2022.main import main +from train_cli import parse_arguments, DEFAULT_CONFIG_FILE +from util import GraphData, load_hs_graphs, 
save_hs_graphs + +TMP_DIRECTORY = 'tmp' + + +def _translate_spt_graphs(spt_graphs: list[GraphData], output_directory: str, + ) -> tuple[int, int, str, str, str, str]: + """Translate the SPT graphs into tmi2022 graphs.""" + makedirs(output_directory, exist_ok=True) + ids_train: list[str] = [] + ids_val: list[str] = [] + ids_test: list[str] = [] + ids_unlabeled: list[str] = [] + for graph_data in spt_graphs: + graph_id = _convert_graph_to_tmi2022(graph_data, output_directory) + match graph_data.set: + case 'train': + ids_train.append(graph_id) + case 'validation': + ids_val.append(graph_id) + case 'test': + ids_test.append(graph_id) + case None: + ids_unlabeled.append(graph_id) + case _: + raise ValueError(f'Unknown set {graph_data.set}') + path_to_train_ids: str = join(output_directory, 'train_set.txt') + path_to_val_ids: str = join(output_directory, 'validation_set.txt') + path_to_test_ids: str = join(output_directory, 'test_set.txt') + path_to_unlabeled_ids: str = join(output_directory, 'unlabeled_set.txt') + with open(path_to_train_ids, 'w', encoding='utf-8') as f: + f.write('\n'.join(ids_train)) + with open(path_to_val_ids, 'w', encoding='utf-8') as f: + f.write('\n'.join(ids_val)) + with open(path_to_test_ids, 'w', encoding='utf-8') as f: + f.write('\n'.join(ids_test)) + with open(path_to_unlabeled_ids, 'w', encoding='utf-8') as f: + f.write('\n'.join(ids_unlabeled)) + + # Find the number of classes + unique_labels: set[int] = set() + for graph_data in spt_graphs: + if graph_data.label is not None: + unique_labels.add(graph_data.label) + n_classes = len(unique_labels) + assert unique_labels == set(range(len(unique_labels))), \ + "Labels are not zero-indexed and non-missing" + n_features = spt_graphs[0].graph.node_features.shape[1] + + return n_classes, n_features, \ + path_to_train_ids, path_to_val_ids, path_to_test_ids, path_to_unlabeled_ids + + +def _convert_graph_to_tmi2022(graph_data: GraphData, data_directory: str) -> str: + """Convert an SPT graph to a tmi2022 graph.""" + # Extract data from the GraphData instance + adj = graph_data.graph.adj + node_features = graph_data.graph.node_features + centroids = graph_data.graph.centroids + histological_structure_ids = graph_data.graph.histological_structure_ids + label = graph_data.label + name = graph_data.name + specimen = graph_data.specimen + + # Convert the adjacency matrix to a PyTorch tensor + adj = coo_matrix(adj) + indices = tensor([adj.row, adj.col], dtype=long) + values = tensor(adj.data, dtype=torch_float) + shape = Size(adj.shape) + adj_s = sparse_coo_tensor(indices, values, shape) + + # Convert the node features to a PyTorch tensor + features = from_numpy(node_features) + centroids = from_numpy(centroids) + histological_structure_ids = from_numpy(histological_structure_ids) + + # Create the directory structure + graph_path = join(data_directory, f'{specimen}_features', 'simclr_files', name) + makedirs(graph_path, exist_ok=True) + + # Save the tensors to disk + save(features, join(graph_path, 'features.pt')) + save(adj_s, join(graph_path, 'adj_s.pt')) + save(centroids, join(graph_path, 'centroids.pt')) + save(histological_structure_ids, join(graph_path, 'histological_structure_ids.pt')) + + # Return the id and label in the format expected by GraphDataset + return f'{specimen}/{name}\t{label}' + + +def run_tmi2022(n_class: int, + n_features: int, + data_path: str, + val_set: str, + train: bool, + train_set: str | None = None, + model_path: str = join(TMP_DIRECTORY, "saved_models"), + log_path: str = join(TMP_DIRECTORY, 
"runs"), + task_name: str = "GraphCAM", + batch_size: int = 8, + log_interval_local: int = 6, + resume: str = "../graph_transformer/saved_models/GraphCAM.pth", + ) -> None: + """Train or test tmi2022 (the latter for creating GraphCAM ratings).""" + + # Set the CUDA_VISIBLE_DEVICES environment variable + if not is_available(): + raise ValueError("A CUDA-supporting GPU is required.") + + # Call the main function with the appropriate parameters + if train: + assert train_set is not None + main(n_class, + n_features, + data_path, + model_path, + log_path, + task_name, + batch_size, + log_interval_local, + train_set, + val_set, + train=True) + else: # test + if val_set.endswith('.txt'): # test, no graphcam + main(n_class, + n_features, + data_path, + model_path, + log_path, + task_name, + batch_size, + log_interval_local, + val_set=val_set, + test=True, + resume=resume) + else: # we're finding graphcam for one graph + id_txt = write_one_id_to_file(val_set) + main(n_class, + n_features, + data_path, + model_path, + log_path, + task_name, + batch_size, + log_interval_local, + val_set=id_txt, + test=True, + graphcam=True, + resume=resume) + + +def write_one_id_to_file(val_set: str) -> str: + """Write val_set to a txt file in TMP_DIRECTORY and return the file path.""" + file_path = join(TMP_DIRECTORY, 'id.txt') + with open(file_path, 'w', encoding='utf-8') as file: + file.write(val_set) + return file_path + + +# def _convert_from_graphdataset_format(id: str, +# data_directory: str, +# importances: NDArray[float_],) -> GraphData: +# """Convert a tmi2022 graph of id to an SPT graph.""" +# specimen, name = id.split('/') + +# # Load the tensors from disk +# graph_path = join(data_directory, f'{specimen}_features', 'simclr_files', name) +# features = load(join(graph_path, 'features.pt')).numpy() +# adj_s = load(join(graph_path, 'adj_s.pt')).to_dense().numpy() +# centroids = load(join(graph_path, 'centroids.pt')).numpy() +# histological_structure_ids = load(join(graph_path, 'histological_structure_ids.pt')).numpy() + +# # Convert the adjacency matrix back to a sparse matrix +# adj = csr_matrix(adj_s) + +# # Extract the label from the id +# label = int(id.split('\t')[1]) + +# # Create a GraphData instance +# return GraphData(HSGraph(adj, features, centroids, histological_structure_ids, importances), +# label, name, specimen, None) + + +def _handle_random_seed_values(random_seed_value: str | None) -> int | None: + if (random_seed_value is not None) and (str(random_seed_value).strip().lower() != "none"): + return int(random_seed_value) + return None + + +if __name__ == '__main__': + args = parse_arguments() + config_file = ConfigParser() + config_file.read(args.config_file) + random_seed: int | None = None + if 'general' in config_file: + random_seed = _handle_random_seed_values(config_file['general'].get('random_seed', None)) + if 'graph-transformer' not in config_file: + warn('No cg-gnn section in config file. 
Using default values.') + config_file.read(DEFAULT_CONFIG_FILE) + config = config_file['graph-transformer'] + + # Parse config file + task_name = config.get('task_name', 'GraphCAM') + batch_size = config.getint('batch_size', 8) + log_interval_local = config.getint('log_interval_local', 6) + + spt_graphs, _ = load_hs_graphs(args.input_directory) + + # Call the function with the current args.input_directory and graph_directory + graph_directory = join(TMP_DIRECTORY, 'graphs') + n_classes, n_features, path_to_train_ids, path_to_val_ids, path_to_test_ids, \ + path_to_unlabeled_ids = _translate_spt_graphs(spt_graphs, graph_directory) + # Consider deleting spt_graphs and reloading later to save memory + + # Train tmi2022 + run_tmi2022(n_classes, + n_features, + graph_directory, + path_to_val_ids, + True, + train_set=path_to_train_ids, + model_path=args.output_directory, + log_path=TMP_DIRECTORY, + task_name=task_name, + batch_size=batch_size, + log_interval_local=log_interval_local, + ) + + # Report test results + run_tmi2022(n_classes, + n_features, + graph_directory, + path_to_test_ids, + False, + model_path=args.output_directory, + log_path=TMP_DIRECTORY, + task_name=task_name, + batch_size=1, + log_interval_local=log_interval_local, + resume=join(args.output_directory, f'{task_name}.pth'), + ) + + # Find the importance scores + importance_scores: dict[int, list[float]] = DefaultDict(list) + with open(path_to_test_ids, 'r', encoding='utf-8') as f: + test_ids = f.read().splitlines() + with open(path_to_val_ids, 'r', encoding='utf-8') as f: + val_ids = f.read().splitlines() + with open(path_to_train_ids, 'r', encoding='utf-8') as f: + train_ids = f.read().splitlines() + for single_id in tqdm(test_ids + val_ids + train_ids): + run_tmi2022(n_classes, + n_features, + graph_directory, + single_id, + False, + model_path=args.output_directory, + log_path=TMP_DIRECTORY, + task_name=task_name, + batch_size=1, + log_interval_local=1, + resume=join(args.output_directory, f'{task_name}.pth'), + ) + + # Load the CAM scores and convert them into an importance vector + cams = [load(join('graphcam', f'cam_{i}.pt')).detach().cpu() for i in range(n_classes)] + unified_cam = sum(stack(cams), dim=0) + assign_matrix = load(join('graphcam', 's_matrix_ori.pt')).detach().cpu() + assign_matrix = softmax(assign_matrix, dim=1) + node_importance = mm(assign_matrix, unified_cam.transpose(1, 0)) + node_importance = node_importance.flatten().numpy() + + # Save the importance vector back to the graph in GraphData format + for graph_data in spt_graphs: + if graph_data.name == single_id.split('\t')[0].split('/')[1]: + assert graph_data.graph.node_features.shape[0] == node_importance.shape[0] + graph_data.graph.importances = node_importance + for i, importance in enumerate(node_importance): + importance_scores[graph_data.graph.histological_structure_ids[i]].append( + importance) + break + else: + raise RuntimeError(f'Couldn\'t find graph associated with {single_id}') + + save_hs_graphs(spt_graphs, args.output_directory) + hs_id_to_importance: dict[int, float] = {k: mean(v) for k, v in importance_scores.items()} + s = Series(hs_id_to_importance).sort_index() + s.name = 'importance' + s.to_csv(join(args.output_directory, 'importances.csv')) diff --git a/plugin/graph_processing/graph-transformer/train_cli.py b/plugin/graph_processing/graph-transformer/train_cli.py new file mode 100644 index 000000000..0cc09e14c --- /dev/null +++ b/plugin/graph_processing/graph-transformer/train_cli.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 
+"""Process arguments to training command.""" + +from argparse import ArgumentParser + +DEFAULT_CONFIG_FILE = 'training.config' + + +def parse_arguments(): + """Parse arguments.""" + arg_parser = ArgumentParser() + arg_parser.add_argument( + '--input_directory', + type=str, + help='Path to the directory containing the cell graphs to be used for training.', + ) + arg_parser.add_argument( + '--config_file', + type=str, + help='Path to config file.', + default=DEFAULT_CONFIG_FILE, + ) + arg_parser.add_argument( + '--output_directory', + type=str, + help='Path to the directory containing the cell graphs to be used for training.', + ) + return arg_parser.parse_args() diff --git a/plugin/graph_processing/graph-transformer/training.config b/plugin/graph_processing/graph-transformer/training.config new file mode 100644 index 000000000..df366d0ef --- /dev/null +++ b/plugin/graph_processing/graph-transformer/training.config @@ -0,0 +1,4 @@ +[graph-transformer] +task_name = GraphCAM +batch_size = 8 +log_interval_local = 6 diff --git a/plugin/graph_processing/graph-transformer/util.py b/plugin/graph_processing/graph-transformer/util.py new file mode 100644 index 000000000..e2829ffa8 --- /dev/null +++ b/plugin/graph_processing/graph-transformer/util.py @@ -0,0 +1,173 @@ +"""Histological structure graph dataset utility functions. + +This is a copy of spatialprofilingtoolbox.graphs.util. +""" + +from os import listdir, makedirs +from os.path import join +from json import load as json_load +from random import seed +from typing import NamedTuple, Literal +from dataclasses import dataclass, field + +from numpy import ( + savetxt, # type: ignore + loadtxt, + int_, + float_, +) +from numpy.random import seed as np_seed +from numpy.typing import NDArray +from scipy.sparse import spmatrix, isspmatrix_csr, csr_matrix # type: ignore +from h5py import File # type: ignore + +SETS = ('train', 'validation', 'test') +SETS_type = Literal['train', 'validation', 'test'] + + +@dataclass +class HSGraph: + """A histological structure graph instance.""" + adj: spmatrix + node_features: NDArray[float_] + centroids: NDArray[float_] + histological_structure_ids: NDArray[int_] + importances: NDArray[float_] | None = field(default=None) + + +class GraphData(NamedTuple): + """Data relevant to a histological structure graph instance.""" + graph: HSGraph + label: int | None + name: str + specimen: str + set: SETS_type | None + + +def save_hs_graphs(graphs_data: list[GraphData], output_directory: str) -> None: + """Save histological structure graphs to a directory. + + Saves the adjacency graph separately from the rest of the graph data for compatibility. + """ + makedirs(output_directory, exist_ok=True) + for gd in graphs_data: + save_graph_data(gd, join(output_directory, f'{gd.name}.h5')) + + +def load_hs_graphs(graph_directory: str) -> tuple[list[GraphData], list[str]]: + """Load histological structure graphs from a directory. + + Assumes directory contains the files `graphs.pkl`, `feature_names.txt`, and a sparse array for + every graph in `graphs.pkl`. 
+ """ + graphs_data: list[GraphData] = [] + for filename in listdir(graph_directory): + if filename.endswith('.h5'): + try: + graphs_data.append(load_graph_data(join(graph_directory, filename))) + except KeyError: + raise ValueError(f'Graph data file {filename} is missing required fields.') + feature_names: list[str] = loadtxt( + join(graph_directory, 'feature_names.txt'), + dtype=str, + delimiter=',', + ).tolist() + return graphs_data, feature_names + + +def save_graph_data(graph_data: GraphData, filename: str): + """Save GraphData to an HDF5 file.""" + if not isspmatrix_csr(graph_data.graph.adj): + raise ValueError('Graph adjacency matrix must be a CSR matrix.') + + with File(filename, 'w') as f: + f.create_dataset('graph/adj/data', data=graph_data.graph.adj.data) + f.create_dataset('graph/adj/indices', data=graph_data.graph.adj.indices) + f.create_dataset('graph/adj/indptr', data=graph_data.graph.adj.indptr) + f.create_dataset('graph/adj/shape', data=graph_data.graph.adj.shape) + + f.create_dataset('graph/node_features', data=graph_data.graph.node_features) + f.create_dataset('graph/centroids', data=graph_data.graph.centroids) + f.create_dataset( + 'graph/histological_structure_ids', + data=graph_data.graph.histological_structure_ids, + ) + if graph_data.graph.importances is not None: + f.create_dataset('graph/importances', data=graph_data.graph.importances) + + f.create_dataset('label', data=graph_data.label) + f.create_dataset('name', data=graph_data.name) + f.create_dataset('specimen', data=graph_data.specimen) + f.create_dataset('set', data=graph_data.set) + + +def load_graph_data(filename: str) -> GraphData: + """Load GraphData from an HDF5 file.""" + with File(filename, 'r') as f: + adj_data = f['graph/adj/data'][()] + adj_indices = f['graph/adj/indices'][()] + adj_indptr = f['graph/adj/indptr'][()] + adj_shape = f['graph/adj/shape'][()] + adj = csr_matrix((adj_data, adj_indices, adj_indptr), shape=adj_shape) + + node_features: NDArray[float_] = f['graph/node_features'][()] + centroids: NDArray[float_] = f['graph/centroids'][()] + histological_structure_ids: NDArray[int_] = f['graph/histological_structure_ids'][()] + importances: NDArray[float_] = \ + f['graph/importances'][()] if 'graph/importances' in f else None + + # h5 files store strings as byte arrays + label: int | None = f['label'][()] + name: str = f['name'][()].decode() + specimen: str = f['specimen'][()].decode() + set: SETS_type = f['set'][()].decode() + + graph = HSGraph(adj, node_features, centroids, histological_structure_ids, importances) + return GraphData(graph, label, name, specimen, set) + + +def save_graph_data_and_feature_names( + graphs_data: list[GraphData], + features_to_use: list[str], + output_directory: str, +) -> None: + """Save graph data and feature names to disk.""" + save_hs_graphs(graphs_data, output_directory) + savetxt(join(output_directory, 'feature_names.txt'), features_to_use, fmt='%s', delimiter=',') + + +def load_label_to_result(path: str) -> dict[int, str]: + """Read in label_to_result JSON.""" + return {int(label): result for label, result in json_load( + open(path, encoding='utf-8')).items()} + + +def split_graph_sets(graphs_data: list[GraphData]) -> tuple[ + tuple[list[HSGraph], list[int]], + tuple[list[HSGraph], list[int]], + tuple[list[HSGraph], list[int]], + list[HSGraph], +]: + """Split graph data list into train, validation, test, and unlabeled sets.""" + cg_train: tuple[list[HSGraph], list[int]] = ([], []) + cg_val: tuple[list[HSGraph], list[int]] = ([], []) + cg_test: 
tuple[list[HSGraph], list[int]] = ([], []) + cg_unlabeled: list[HSGraph] = [] + for gd in graphs_data: + if gd.label is None: + cg_unlabeled.append(gd.graph) + continue + which_set: tuple[list[HSGraph], list[int]] = cg_train + if gd.set == 'validation': + which_set = cg_val + elif gd.set == 'test': + which_set = cg_test + which_set[0].append(gd.graph) + which_set[1].append(gd.label) + return cg_train, cg_val, cg_test, cg_unlabeled + + +def set_seeds(random_seed: int) -> None: + """Set random seeds for all libraries.""" + seed(random_seed) + np_seed(random_seed) diff --git a/plugin/graph_processing/graph-transformer/version.txt b/plugin/graph_processing/graph-transformer/version.txt new file mode 100644 index 000000000..8a9ecc2ea --- /dev/null +++ b/plugin/graph_processing/graph-transformer/version.txt @@ -0,0 +1 @@ +0.0.1 \ No newline at end of file diff --git a/plugin/graph_processing/template/Dockerfile b/plugin/graph_processing/template/Dockerfile new file mode 100644 index 000000000..22076ac83 --- /dev/null +++ b/plugin/graph_processing/template/Dockerfile @@ -0,0 +1,23 @@ +# Choose an appropriate base image +FROM python:3.11-slim-buster +WORKDIR /app + +# Install apt packages you need here, and then clean up afterward +RUN apt-get update +# RUN apt-get install -y +RUN rm -rf /var/lib/apt/lists/* + +# Install python packages you need here +ENV PIP_NO_CACHE_DIR=1 +RUN pip install h5py==3.10.0 +RUN pip install numpy==1.24.3 +RUN pip install scipy==1.10.1 + +# Make the files you need in this directory available everywhere in the container +ADD . /app +RUN chmod +x train.py +RUN mv train.py /usr/local/bin/spt-plugin-train-on-graphs +RUN chmod +x /app/print_graph_config.sh +RUN mv /app/print_graph_config.sh /usr/local/bin/spt-plugin-print-graph-request-configuration +RUN chmod +x /app/print_training_config.sh +RUN mv /app/print_training_config.sh /usr/local/bin/spt-plugin-print-training-configuration diff --git a/plugin/graph_processing/template/README.md b/plugin/graph_processing/template/README.md new file mode 100644 index 000000000..c07a9f8f3 --- /dev/null +++ b/plugin/graph_processing/template/README.md @@ -0,0 +1,3 @@ +# Template + +Template for SPT plugin development. 
\ No newline at end of file
diff --git a/plugin/graph_processing/template/graph.config b/plugin/graph_processing/template/graph.config
new file mode 100644
index 000000000..e69de29bb
diff --git a/plugin/graph_processing/template/print_graph_config.sh b/plugin/graph_processing/template/print_graph_config.sh
new file mode 100644
index 000000000..7f07d1a80
--- /dev/null
+++ b/plugin/graph_processing/template/print_graph_config.sh
@@ -0,0 +1,2 @@
+#!/bin/sh
+cat /app/graph.config
diff --git a/plugin/graph_processing/template/print_training_config.sh b/plugin/graph_processing/template/print_training_config.sh
new file mode 100644
index 000000000..8fee4a4ae
--- /dev/null
+++ b/plugin/graph_processing/template/print_training_config.sh
@@ -0,0 +1,2 @@
+#!/bin/sh
+cat /app/training.config
diff --git a/plugin/graph_processing/template/train.py b/plugin/graph_processing/template/train.py
new file mode 100644
index 000000000..f143b4340
--- /dev/null
+++ b/plugin/graph_processing/template/train.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+"""Train a model."""
+
+from sys import path
+from configparser import ConfigParser
+from warnings import warn
+
+path.append('/app')  # noqa
+from train_cli import parse_arguments, DEFAULT_CONFIG_FILE
+from util import HSGraph, GraphData, load_hs_graphs, save_hs_graphs
+
+
+def _handle_random_seed_values(random_seed_value: str | None) -> int | None:
+    if (random_seed_value is not None) and (str(random_seed_value).strip().lower() != "none"):
+        return int(random_seed_value)
+    return None
+
+
+if __name__ == '__main__':
+    args = parse_arguments()
+    config_file = ConfigParser()
+    config_file.read(args.config_file)
+    random_seed: int | None = None
+    if 'general' in config_file:
+        random_seed = _handle_random_seed_values(config_file['general'].get('random_seed', None))
+    if 'plugin' not in config_file:
+        warn('No plugin section in config file. Using default values.')
+        config_file.read(DEFAULT_CONFIG_FILE)
+    config = config_file['plugin']
+
+    spt_graphs, _ = load_hs_graphs(args.input_directory)
diff --git a/plugin/graph_processing/template/train_cli.py b/plugin/graph_processing/template/train_cli.py
new file mode 100644
index 000000000..0cc09e14c
--- /dev/null
+++ b/plugin/graph_processing/template/train_cli.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+"""Process arguments to training command."""
+
+from argparse import ArgumentParser
+
+DEFAULT_CONFIG_FILE = 'training.config'
+
+
+def parse_arguments():
+    """Parse arguments."""
+    arg_parser = ArgumentParser()
+    arg_parser.add_argument(
+        '--input_directory',
+        type=str,
+        help='Path to the directory containing the cell graphs to be used for training.',
+    )
+    arg_parser.add_argument(
+        '--config_file',
+        type=str,
+        help='Path to config file.',
+        default=DEFAULT_CONFIG_FILE,
+    )
+    arg_parser.add_argument(
+        '--output_directory',
+        type=str,
+        help='Path to the directory to which training output will be saved.',
+    )
+    return arg_parser.parse_args()
diff --git a/plugin/graph_processing/template/training.config b/plugin/graph_processing/template/training.config
new file mode 100644
index 000000000..e69de29bb
diff --git a/plugin/graph_processing/template/util.py b/plugin/graph_processing/template/util.py
new file mode 100644
index 000000000..e2829ffa8
--- /dev/null
+++ b/plugin/graph_processing/template/util.py
@@ -0,0 +1,173 @@
+"""Histological structure graph dataset utility functions.
+
+This is a copy of spatialprofilingtoolbox.graphs.util.
+""" + +from os import listdir, makedirs +from os.path import join +from json import load as json_load +from random import seed +from typing import NamedTuple, Literal +from dataclasses import dataclass, field + +from numpy import ( + savetxt, # type: ignore + loadtxt, + int_, + float_, +) +from numpy.random import seed as np_seed +from numpy.typing import NDArray +from scipy.sparse import spmatrix, isspmatrix_csr, csr_matrix # type: ignore +from h5py import File # type: ignore + +SETS = ('train', 'validation', 'test') +SETS_type = Literal['train', 'validation', 'test'] + + +@dataclass +class HSGraph: + """A histological structure graph instance.""" + adj: spmatrix + node_features: NDArray[float_] + centroids: NDArray[float_] + histological_structure_ids: NDArray[int_] + importances: NDArray[float_] | None = field(default=None) + + +class GraphData(NamedTuple): + """Data relevant to a histological structure graph instance.""" + graph: HSGraph + label: int | None + name: str + specimen: str + set: SETS_type | None + + +def save_hs_graphs(graphs_data: list[GraphData], output_directory: str) -> None: + """Save histological structure graphs to a directory. + + Saves the adjacency graph separately from the rest of the graph data for compatibility. + """ + makedirs(output_directory, exist_ok=True) + for gd in graphs_data: + save_graph_data(gd, join(output_directory, f'{gd.name}.h5')) + + +def load_hs_graphs(graph_directory: str) -> tuple[list[GraphData], list[str]]: + """Load histological structure graphs from a directory. + + Assumes directory contains the files `graphs.pkl`, `feature_names.txt`, and a sparse array for + every graph in `graphs.pkl`. + """ + graphs_data: list[GraphData] = [] + for filename in listdir(graph_directory): + if filename.endswith('.h5'): + try: + graphs_data.append(load_graph_data(join(graph_directory, filename))) + except KeyError: + raise ValueError(f'Graph data file {filename} is missing required fields.') + feature_names: list[str] = loadtxt( + join(graph_directory, 'feature_names.txt'), + dtype=str, + delimiter=',', + ).tolist() + return graphs_data, feature_names + + +def save_graph_data(graph_data: GraphData, filename: str): + """Save GraphData to an HDF5 file.""" + if not isspmatrix_csr(graph_data.graph.adj): + raise ValueError('Graph adjacency matrix must be a CSR matrix.') + + with File(filename, 'w') as f: + f.create_dataset('graph/adj/data', data=graph_data.graph.adj.data) + f.create_dataset('graph/adj/indices', data=graph_data.graph.adj.indices) + f.create_dataset('graph/adj/indptr', data=graph_data.graph.adj.indptr) + f.create_dataset('graph/adj/shape', data=graph_data.graph.adj.shape) + + f.create_dataset('graph/node_features', data=graph_data.graph.node_features) + f.create_dataset('graph/centroids', data=graph_data.graph.centroids) + f.create_dataset( + 'graph/histological_structure_ids', + data=graph_data.graph.histological_structure_ids, + ) + if graph_data.graph.importances is not None: + f.create_dataset('graph/importances', data=graph_data.graph.importances) + + f.create_dataset('label', data=graph_data.label) + f.create_dataset('name', data=graph_data.name) + f.create_dataset('specimen', data=graph_data.specimen) + f.create_dataset('set', data=graph_data.set) + + +def load_graph_data(filename: str) -> GraphData: + """Load GraphData from an HDF5 file.""" + with File(filename, 'r') as f: + adj_data = f['graph/adj/data'][()] + adj_indices = f['graph/adj/indices'][()] + adj_indptr = f['graph/adj/indptr'][()] + adj_shape = f['graph/adj/shape'][()] + 
+
+        node_features: NDArray[float_] = f['graph/node_features'][()]
+        centroids: NDArray[float_] = f['graph/centroids'][()]
+        histological_structure_ids: NDArray[int_] = f['graph/histological_structure_ids'][()]
+        importances: NDArray[float_] | None = \
+            f['graph/importances'][()] if 'graph/importances' in f else None
+
+        # h5 files store strings as byte arrays
+        label: int | None = f['label'][()]
+        name: str = f['name'][()].decode()
+        specimen: str = f['specimen'][()].decode()
+        set: SETS_type = f['set'][()].decode()
+
+    graph = HSGraph(adj, node_features, centroids, histological_structure_ids, importances)
+    return GraphData(graph, label, name, specimen, set)
+
+
+def save_graph_data_and_feature_names(
+    graphs_data: list[GraphData],
+    features_to_use: list[str],
+    output_directory: str,
+) -> None:
+    """Save graph data and feature names to disk."""
+    save_hs_graphs(graphs_data, output_directory)
+    savetxt(join(output_directory, 'feature_names.txt'), features_to_use, fmt='%s', delimiter=',')
+
+
+def load_label_to_result(path: str) -> dict[int, str]:
+    """Read in label_to_result JSON."""
+    return {int(label): result for label, result in json_load(
+        open(path, encoding='utf-8')).items()}
+
+
+def split_graph_sets(graphs_data: list[GraphData]) -> tuple[
+    tuple[list[HSGraph], list[int]],
+    tuple[list[HSGraph], list[int]],
+    tuple[list[HSGraph], list[int]],
+    list[HSGraph],
+]:
+    """Split graph data list into train, validation, test, and unlabeled sets."""
+    cg_train: tuple[list[HSGraph], list[int]] = ([], [])
+    cg_val: tuple[list[HSGraph], list[int]] = ([], [])
+    cg_test: tuple[list[HSGraph], list[int]] = ([], [])
+    cg_unlabeled: list[HSGraph] = []
+    for gd in graphs_data:
+        if gd.label is None:
+            cg_unlabeled.append(gd.graph)
+            continue
+        which_set: tuple[list[HSGraph], list[int]] = cg_train
+        if gd.set == 'validation':
+            which_set = cg_val
+        elif gd.set == 'test':
+            which_set = cg_test
+        which_set[0].append(gd.graph)
+        which_set[1].append(gd.label)
+    return cg_train, cg_val, cg_test, cg_unlabeled
+
+
+def set_seeds(random_seed: int) -> None:
+    """Set random seeds for all libraries."""
+    seed(random_seed)
+    np_seed(random_seed)
diff --git a/spatialprofilingtoolbox/workflow/graph_plugin/__init__.py b/spatialprofilingtoolbox/workflow/graph_plugin/__init__.py
index 18d941846..737a4156a 100644
--- a/spatialprofilingtoolbox/workflow/graph_plugin/__init__.py
+++ b/spatialprofilingtoolbox/workflow/graph_plugin/__init__.py
@@ -16,7 +16,7 @@
     'cg-gnn': '0.0.3',
     'graph-transformer': '0.0.1',
 }
-CUDA_REQUIRED: tuple[str, ...] = ('graph transformer', )
+CUDA_REQUIRED: tuple[str, ...] = ('graph-transformer', )
 CPU_REQUIRED: tuple[str, ...] = ()
 
 
@@ -98,7 +98,8 @@ def _handle_image_params(params: dict[str, str | bool], plugin_name: str) -> Non
             'For graph plugin workflows, the container_platform must be either `docker` or '
             f'`singularity`, not `{params["container_platform"]}`.')
     params['graph_plugin_image'] = f'{PLUGIN_DOCKER_IMAGES[plugin_name]}:' \
-        f'{"cuda-" if params["cuda"] else ""}{PLUGIN_DOCKER_TAGS[plugin_name]}'
+        f'{"cuda-" if (params["cuda"] and (plugin_name not in CUDA_REQUIRED)) else ""}' \
+        f'{PLUGIN_DOCKER_TAGS[plugin_name]}'
     params['graph_plugin_singularity_run_options'] = '--nv' if \
         ((params['container_platform'] == 'singularity') and params['cuda']) else ''
 
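
For orientation, the sketch below exercises the HDF5 graph round trip defined by the plugin `util.py` above. It is illustrative only: the two-node graph, the feature values, and the `./example_graphs` directory are made-up placeholders, and the script assumes it is run alongside the template's `util.py`.

# Illustrative use of the plugin utility API (placeholder data, not part of any file above).
from numpy import array, float_, int_
from scipy.sparse import csr_matrix

from util import (
    GraphData,
    HSGraph,
    load_hs_graphs,
    save_graph_data_and_feature_names,
    split_graph_sets,
)

# A tiny two-node graph with made-up features and centroids.
graph = HSGraph(
    adj=csr_matrix(array([[0.0, 1.0], [1.0, 0.0]])),
    node_features=array([[0.1, 0.2], [0.3, 0.4]], dtype=float_),
    centroids=array([[0.0, 0.0], [1.0, 1.0]], dtype=float_),
    histological_structure_ids=array([101, 102], dtype=int_),
)
graph_data = GraphData(
    graph=graph, label=1, name='example_graph', specimen='specimen_1', set='train',
)

# Write one .h5 file per graph plus feature_names.txt, then read everything back.
save_graph_data_and_feature_names([graph_data], ['feature_a', 'feature_b'], './example_graphs')
graphs, feature_names = load_hs_graphs('./example_graphs')
train, validation, test, unlabeled = split_graph_sets(graphs)
assert len(train[0]) == 1 and feature_names == ['feature_a', 'feature_b']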
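
The revised tag selection in `_handle_image_params` applies the `cuda-` prefix only when CUDA is requested for a plugin that is not listed in CUDA_REQUIRED. The restatement below is a sketch: the repository names are hypothetical, and it assumes the version dictionary visible in the hunk is what backs PLUGIN_DOCKER_TAGS.

# Stand-alone restatement of the image/tag selection; repository names are hypothetical.
PLUGIN_DOCKER_IMAGES = {
    'cg-gnn': 'example/spt-cg-gnn',
    'graph-transformer': 'example/spt-graph-transformer',
}
PLUGIN_DOCKER_TAGS = {'cg-gnn': '0.0.3', 'graph-transformer': '0.0.1'}
CUDA_REQUIRED: tuple[str, ...] = ('graph-transformer', )


def graph_plugin_image(plugin_name: str, cuda: bool) -> str:
    """Prefix the tag with 'cuda-' only for CUDA runs of plugins not in CUDA_REQUIRED."""
    prefix = 'cuda-' if (cuda and (plugin_name not in CUDA_REQUIRED)) else ''
    return f'{PLUGIN_DOCKER_IMAGES[plugin_name]}:{prefix}{PLUGIN_DOCKER_TAGS[plugin_name]}'


assert graph_plugin_image('cg-gnn', cuda=True).endswith(':cuda-0.0.3')
assert graph_plugin_image('graph-transformer', cuda=True).endswith(':0.0.1')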