From 86294180757c70c8c7117020601663089043f7a3 Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Sat, 15 Jul 2023 21:31:11 -0400
Subject: [PATCH] REF: use torchscript models from huggingface hub (#144)

* use pytorch 2.0.0 as base image
* install g++
* do not remove gcc
* import torch to please jit compiler
* move custom model impls to custom_models namespace
* refactor to use wsinfer-zoo
* run isort and then black
* rm modeldefs + make modellib and patchlib public (no underscore)
* do not use torch.compile on torchscript models
* Fix/issue 131 (#133)
* use tifffile in lieu of large_image
* run isort
* make outputs float or None
* changes to please mypy
* add newline at end of document
* add openslide-python and tifffile to core deps
* add back roi support and mps device caps
* black formatting
* rm unused file
* add wsinfer-zoo to deps
* predownload registry JSON + install system deps in early layer
* scale step size and print info
  Fixes https://github.com/SBU-BMI/wsinfer/issues/135
* add patchlib presets to package data and rm modeldefs
* set default step_size to None
* only allow step-size=patch-size
* allow custom step sizes
* update mpp print logs to slide mpp
* add tiff mpp via openslide
* resize patches to prescribed patch size and spacing
* add model config schema
* add schemas to package data
* fix error messages
  Replace `--model-name` with `--model`.
* create OpenSlide obj in worker_init func
  Fixes https://github.com/SBU-BMI/wsinfer/issues/137
  The OpenSlide object is no longer created in `__init__`. Previously the
  OpenSlide object was shared across workers. Now each worker creates its own
  OpenSlide object. I hypothesize that this will allow multi-worker data
  loading on Windows.
* handle num_workers=0
* ADD choice of backends (tiffslide or openslide) (#139)
* replace openslide with tiffslide
* patch zarr to avoid decoding tiles in duplicate
  This implements the change proposed in
  https://github.com/zarr-developers/zarr-python/pull/1454
* rm openslide-python and add tiffslide
* do not stitch because it imposes a performance penalty
* ignore types in vis_params
* add isort and tiffslide to dev deps
* add NoBackendException
* run isort
* use wsinfer.wsi module instead of slide_utils and add tiffslide and openslide backends
* use wsinfer.wsi.WSI as generic entrypoint for whole slides
* replace PathType with "str | Path"
* add logging and backend selection to cli
* add "from __future__ import annotations"
* TST: update tests for dev branch (#143)
* begin to update tests
* do not resize images prior to transform
  This introduces subtle differences from the current stable version of
  wsinfer.
* fix for issue #125
* do not save slide path in model outputs csv
* add test_cli_run_with_registered_models
* add reference model outputs
  These reference outputs were created using a patched version of wsinfer
  0.3.6. The patches involved padding the patches from large-image to the
  expected patch size. large-image does not pad images by default, whereas
  OpenSlide and tiffslide pad with black.
* skip jit tests and cli with custom config
* deprecate python 3.7
* install openslide and tiffslide
* remove WSIType object
* remove dense grid creation
  fixes #138
* remove timm and custom models
  We will focus on using TorchScript models only. In the future, we can also
  look into using ONNX as a backend.
  fixes #140
* limit click versions to please mypy
  related to https://github.com/pallets/click/issues/2558
* satisfy mypy
* fix cli args for wsinfer run
* fail loudly with dev pytorch + fix jit compile tests
* fix test of issue 89
* move wsinfer imports to beginning of file
* add test of mutually exclusive cli args
* use -p shorthand for model-path
* mark that we support typing
* add py.typed to package data
* run test-package on windows, macos, and linux
* fix test of patching
* install openslide differently on different systems
* close the case statement
* fix the way we install openslide on different envs
* fix matrix.os test
* get line length with python for cross-platform
* test "wsinfer run" differently for unix and windows
* fix windows test
* fix path to csv
* skip windows tests for now because tissue segmentation is different
* run "wsinfer run" on windows but do not test file length
* add test of local model with config
---
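Note on "create OpenSlide obj in worker_init func" and "handle num_workers=0"
above: the pattern is to open the slide lazily, once per DataLoader worker,
instead of in `__init__`. A minimal sketch of that pattern, with illustrative
names only (not the actual wsinfer/modellib/data.py code):

```python
import numpy as np
import openslide
import torch.utils.data


class WSIPatches(torch.utils.data.Dataset):
    """Illustrative patch dataset; `coords` holds level-0 (x, y) offsets."""

    def __init__(self, wsi_path: str, coords, patch_size: int) -> None:
        self.wsi_path = wsi_path
        self.coords = coords
        self.patch_size = patch_size
        self.slide = None  # opened per worker, never in __init__

    def __len__(self) -> int:
        return len(self.coords)

    def __getitem__(self, idx: int):
        if self.slide is None:  # covers the num_workers=0 case
            self.slide = openslide.OpenSlide(self.wsi_path)
        x, y = self.coords[idx]
        size = (self.patch_size, self.patch_size)
        patch = self.slide.read_region((x, y), 0, size).convert("RGB")
        return np.asarray(patch)


def open_slide_in_worker(worker_id: int) -> None:
    # Runs inside each worker process, so no OpenSlide handle ever crosses
    # a process boundary (the sharing is what broke loading on Windows).
    info = torch.utils.data.get_worker_info()
    info.dataset.slide = openslide.OpenSlide(info.dataset.wsi_path)


# Hypothetical usage; the slide path and coordinates are placeholders.
loader = torch.utils.data.DataLoader(
    WSIPatches("slide.svs", [(0, 0), (0, 350)], 350),
    batch_size=32,
    num_workers=4,
    worker_init_fn=open_slide_in_worker,
)
```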
 .github/workflows/ci.yml | 73 +-
 .gitignore | 2 +-
 .readthedocs.yaml | 2 +-
 Dockerfile | 23 +-
 README.md | 8 +-
 docs/installing.rst | 8 +-
 setup.cfg | 32 +-
 .../purple.csv | 145 +++
 .../purple.csv | 145 +++
 .../purple.csv | 145 +++
 .../lung-tumor-resnet34.tcga-luad/purple.csv | 37 +
 .../purple.csv | 442 +++++++
 .../purple.csv | 5 +
 .../purple.csv | 145 +++
 tests/test_all.py | 1027 ++++-------------
 wsinfer/__init__.py | 25 +-
 wsinfer/__main__.py | 2 +
 wsinfer/_modellib/inceptionv4_no_batchnorm.py | 366 ------
 wsinfer/_modellib/models.py | 309 -----
 wsinfer/_modellib/resnet_preact.py | 79 --
 wsinfer/_modellib/run_inference.py | 438 -------
 wsinfer/_modellib/transforms.py | 66 --
 wsinfer/_modellib/vgg16mod.py | 20 -
 wsinfer/_patchlib/create_dense_patch_grid.py | 66 --
 wsinfer/_patchlib/utils/__init__.py | 0
 wsinfer/_patchlib/utils/utils.py | 238 ----
 wsinfer/_patchlib/wsi_core/__init__.py | 0
 wsinfer/_version.py | 7 +-
 wsinfer/cli/__init__.py | 1 +
 wsinfer/cli/cli.py | 34 +-
 wsinfer/cli/convert_csv_to_geojson.py | 4 +-
 wsinfer/cli/convert_csv_to_sbubmi.py | 67 +-
 wsinfer/cli/infer.py | 193 ++--
 wsinfer/cli/list_models_and_weights.py | 25 -
 wsinfer/cli/patch.py | 4 +-
 wsinfer/errors.py | 35 +
 wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml | 24 -
 wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml | 30 -
 wsinfer/modeldefs/preactresnet34_tcga-paad-v1.yaml | 26 -
 wsinfer/modeldefs/resnet34_tcga-brca-v1.yaml | 24 -
 wsinfer/modeldefs/resnet34_tcga-luad-v1.yaml | 28 -
 wsinfer/modeldefs/resnet34_tcga-prad-v1.yaml | 25 -
 wsinfer/modeldefs/vgg16_tcga-tils-v1.yaml | 20 -
 wsinfer/modeldefs/vgg16mod_tcga-BRCA-v1.yaml | 29 -
 wsinfer/modellib/__init__.py | 1 +
 wsinfer/modellib/data.py | 172 +++
 wsinfer/modellib/models.py | 81 ++
 wsinfer/modellib/run_inference.py | 211 ++++
 wsinfer/modellib/transforms.py | 28 +
 wsinfer/{_patchlib => patchlib}/README.md | 0
 wsinfer/{_patchlib => patchlib}/__init__.py | 2 +
 .../create_patches_fp.py | 69 +-
 wsinfer/{_patchlib => patchlib}/presets/tcga.csv | 0
 wsinfer/patchlib/utils/__init__.py | 1 +
 .../utils/file_utils.py | 3 +
 .../wsi_core/WholeSlideImage.py | 41 +-
 wsinfer/patchlib/wsi_core/__init__.py | 1 +
 .../wsi_core/batch_process_utils.py | 4 +-
 .../wsi_core/util_classes.py | 4 +-
 .../wsi_core/wsi_utils.py | 9 +-
 wsinfer/{_modellib/__init__.py => py.typed} | 0
 wsinfer/schemas/model-config.schema.json | 59 +
 wsinfer/wsi.py | 200 ++++
 63 files changed, 2428 insertions(+), 2882 deletions(-)
 create mode 100644 tests/reference/breast-tumor-inception_v4.tcga-brca/purple.csv
 create mode 100644 tests/reference/breast-tumor-resnet34.tcga-brca/purple.csv
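For orientation before the diffs: the new wsinfer/wsi.py and wsinfer/errors.py
implement the backend choice described in the message. A rough sketch of the
shape of that entrypoint, simplified and with assumed details (the real module
also wires in logging and a configurable default backend):

```python
from __future__ import annotations

from pathlib import Path


class NoBackendException(Exception):
    """Raised when the requested whole-slide backend is unavailable."""


def open_slide(path: str | Path, backend: str = "tiffslide"):
    # tiffslide mirrors the openslide-python API (read_region,
    # level_dimensions, properties), so callers can treat the returned
    # object uniformly regardless of backend.
    if backend == "tiffslide":
        try:
            import tiffslide
        except ImportError as err:
            raise NoBackendException("tiffslide is not installed") from err
        return tiffslide.TiffSlide(path)
    if backend == "openslide":
        try:
            import openslide
        except ImportError as err:
            raise NoBackendException("openslide-python is not installed") from err
        return openslide.OpenSlide(path)
    raise NoBackendException(f"unknown backend: {backend}")
```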
 create mode 100644 tests/reference/breast-tumor-vgg16mod.tcga-brca/purple.csv
 create mode 100644 tests/reference/lung-tumor-resnet34.tcga-luad/purple.csv
 create mode 100644 tests/reference/pancancer-lymphocytes-inceptionv4.tcga/purple.csv
 create mode 100644 tests/reference/pancreas-tumor-preactresnet34.tcga-paad/purple.csv
 create mode 100644 tests/reference/prostate-tumor-resnet34.tcga-prad/purple.csv
 delete mode 100644 wsinfer/_modellib/inceptionv4_no_batchnorm.py
 delete mode 100644 wsinfer/_modellib/models.py
 delete mode 100644 wsinfer/_modellib/resnet_preact.py
 delete mode 100644 wsinfer/_modellib/run_inference.py
 delete mode 100644 wsinfer/_modellib/transforms.py
 delete mode 100644 wsinfer/_modellib/vgg16mod.py
 delete mode 100644 wsinfer/_patchlib/create_dense_patch_grid.py
 delete mode 100644 wsinfer/_patchlib/utils/__init__.py
 delete mode 100644 wsinfer/_patchlib/utils/utils.py
 delete mode 100644 wsinfer/_patchlib/wsi_core/__init__.py
 delete mode 100644 wsinfer/cli/list_models_and_weights.py
 create mode 100644 wsinfer/errors.py
 delete mode 100644 wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml
 delete mode 100644 wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml
 delete mode 100644 wsinfer/modeldefs/preactresnet34_tcga-paad-v1.yaml
 delete mode 100644 wsinfer/modeldefs/resnet34_tcga-brca-v1.yaml
 delete mode 100644 wsinfer/modeldefs/resnet34_tcga-luad-v1.yaml
 delete mode 100644 wsinfer/modeldefs/resnet34_tcga-prad-v1.yaml
 delete mode 100644 wsinfer/modeldefs/vgg16_tcga-tils-v1.yaml
 delete mode 100644 wsinfer/modeldefs/vgg16mod_tcga-BRCA-v1.yaml
 create mode 100644 wsinfer/modellib/__init__.py
 create mode 100644 wsinfer/modellib/data.py
 create mode 100644 wsinfer/modellib/models.py
 create mode 100644 wsinfer/modellib/run_inference.py
 create mode 100644 wsinfer/modellib/transforms.py
 rename wsinfer/{_patchlib => patchlib}/README.md (100%)
 rename wsinfer/{_patchlib => patchlib}/__init__.py (96%)
 rename wsinfer/{_patchlib => patchlib}/create_patches_fp.py (90%)
 rename wsinfer/{_patchlib => patchlib}/presets/tcga.csv (100%)
 create mode 100644 wsinfer/patchlib/utils/__init__.py
 rename wsinfer/{_patchlib => patchlib}/utils/file_utils.py (97%)
 rename wsinfer/{_patchlib => patchlib}/wsi_core/WholeSlideImage.py (97%)
 create mode 100644 wsinfer/patchlib/wsi_core/__init__.py
 rename wsinfer/{_patchlib => patchlib}/wsi_core/batch_process_utils.py (99%)
 rename wsinfer/{_patchlib => patchlib}/wsi_core/util_classes.py (99%)
 rename wsinfer/{_patchlib => patchlib}/wsi_core/wsi_utils.py (99%)
 rename wsinfer/{_modellib/__init__.py => py.typed} (100%)
 create mode 100644 wsinfer/schemas/model-config.schema.json
 create mode 100644 wsinfer/wsi.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index f21bd01..afe6331 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -11,7 +11,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
     steps:
     - uses: actions/checkout@v3
     - name: Set up Python ${{ matrix.python-version }}
@@ -20,11 +20,14 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install the package
      run: |
+        sudo apt update
+        sudo apt install -y libopenslide0
        python -m pip install --upgrade pip setuptools wheel
-        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
-        python -m pip install --editable .[dev] --find-links https://girder.github.io/large_image_wheels
+        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu openslide-python tiffslide
+        python -m pip install --editable .[dev]
    - name: Run tests
      run: python -m pytest --verbose tests/
+
  test-pytorch-nightly:
    runs-on: ubuntu-latest
    steps:
@@ -35,14 +38,16 @@ jobs:
        python-version: "3.10"
    - name: Install the package
      run: |
+        sudo apt update
+        sudo apt install -y libopenslide0
        python -m pip install --upgrade pip setuptools wheel
-        python -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu
-        python -m pip install --editable .[dev] --find-links https://girder.github.io/large_image_wheels
+        python -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu openslide-python tiffslide
+        python -m pip install --editable .[dev]
    - name: Check types
      run: python -m mypy --install-types --non-interactive wsinfer/
    - name: Run tests
-      continue-on-error: true
      run: python -m pytest --verbose tests/
+
  test-docker:
    runs-on: ubuntu-latest
    steps:
@@ -60,35 +65,62 @@ jobs:
        wget -q https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/JP2K-33003-1.svs
        cd ..
        docker run --rm --shm-size=512m --volume $(pwd):/work --workdir /work wsinferimage run \
-          --wsi-dir slides/ --results-dir results/ --model resnet34 --weights TCGA-BRCA-v1
+          --wsi-dir slides/ --results-dir results/ --model breast-tumor-resnet34.tcga-brca
        test -f results/run_metadata_*.json
        test -f results/patches/JP2K-33003-1.h5
        test -f results/model-outputs/JP2K-33003-1.csv
        test $(wc -l < results/model-outputs/JP2K-33003-1.csv) -eq 653
+
+  # This is run on multiple operating systems.
  test-package:
-    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+    runs-on: ${{ matrix.os }}
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python 3.10
      uses: actions/setup-python@v4
      with:
        python-version: "3.10"
-    - name: Install the package
+    - name: Install OpenSlide on Ubuntu
+      if: matrix.os == 'ubuntu-latest'
+      run: sudo apt update && sudo apt install -y libopenslide0
+    - name: Install OpenSlide on macOS
+      if: matrix.os == 'macos-latest'
+      run: brew install openslide
+    - name: Install the wsinfer python package
      run: |
        python -m pip install --upgrade pip setuptools wheel
-        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
-        python -m pip install . --find-links https://girder.github.io/large_image_wheels
-    - name: Run the wsinfer command in a new directory
+        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu openslide-python tiffslide
+        python -m pip install .
+    - name: Run 'wsinfer run' on Unix
+      if: matrix.os != 'windows-latest'
      run: |
        mkdir newdir && cd newdir
        mkdir slides && cd slides
        wget -q https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/JP2K-33003-1.svs
        cd ..
-        wsinfer run --wsi-dir slides/ --results-dir results/ --model resnet34 --weights TCGA-BRCA-v1
+        wsinfer run --wsi-dir slides/ --results-dir results/ --model breast-tumor-resnet34.tcga-brca
        test -f results/run_metadata_*.json
        test -f results/patches/JP2K-33003-1.h5
        test -f results/model-outputs/JP2K-33003-1.csv
        test $(wc -l < results/model-outputs/JP2K-33003-1.csv) -eq 653
+    # FIXME: tissue segmentation has different outputs on Windows. The patch sizes
+    # are the same but the coordinates found are different.
+    - name: Run 'wsinfer run' on Windows
+      if: matrix.os == 'windows-latest'
+      run: |
+        mkdir newdir && cd newdir
+        mkdir slides && cd slides
+        Invoke-WebRequest -URI https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/JP2K-33003-1.svs -OutFile JP2K-33003-1.svs
+        cd ..
+        wsinfer run --wsi-dir slides/ --results-dir results/ --model breast-tumor-resnet34.tcga-brca
+        Test-Path -Path results/run_metadata_*.json -PathType Leaf
+        Test-Path -Path results/patches/JP2K-33003-1.h5 -PathType Leaf
+        Test-Path -Path results/model-outputs/JP2K-33003-1.csv -PathType Leaf
+        # test $(python -c "print(sum(1 for _ in open('results/model-outputs/JP2K-33003-1.csv')))") -eq 653
+
  style-and-types:
    runs-on: ubuntu-latest
    steps:
@@ -99,28 +131,33 @@ jobs:
        python-version: "3.10"
    - name: Install the package
      run: |
+        sudo apt update
+        sudo apt install -y libopenslide0
        python -m pip install --upgrade pip setuptools wheel
-        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
-        python -m pip install .[dev] --find-links https://girder.github.io/large_image_wheels
+        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu openslide-python tiffslide
+        python -m pip install .[dev]
    - name: Check style (flake8)
      run: python -m flake8 wsinfer/
    - name: Check style (black)
      run: python -m black --check wsinfer/
    - name: Check types
      run: python -m mypy --install-types --non-interactive wsinfer/
+
  docs:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v3
-    - name: Set up Python 3.10
+    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: "3.10"
    - name: Install the package
      run: |
+        sudo apt update
+        sudo apt install -y libopenslide0
        python -m pip install --upgrade pip setuptools wheel
-        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
-        python -m pip install .[docs] --find-links https://girder.github.io/large_image_wheels
+        python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu openslide-python tiffslide
+        python -m pip install .[docs]
    - name: Build docs
      run: |
        cd docs
diff --git a/.gitignore b/.gitignore
index 5698338..7c66fc5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,4 +170,4 @@ cython_debug/
 .idea/
 
 # Extras
-.DS_Store
\ No newline at end of file
+.DS_Store
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 713aecd..5f478f0 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -10,7 +10,7 @@ build:
   jobs:
     post_create_environment:
       - python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
-      - python -m pip install .[docs] --find-links https://girder.github.io/large_image_wheels
+      - python -m pip install .[docs]
     post_install:
       # Re-run the installation to ensure we have an appropriate version of sphinx.
       # We might not want to use the latest version.
diff --git a/Dockerfile b/Dockerfile
index 23070a9..65ebeb5 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,17 +1,32 @@
+# FIXME: when using the torch 2.0.1 image, we get an error
+#   OSError: /lib/x86_64-linux-gnu/libgobject-2.0.so.0: undefined symbol: ffi_type_uint32, version LIBFFI_BASE_7.0
+# The error is fixed by
+#   LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libffi.so.7
+
 FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
-WORKDIR /opt/wsinfer
-COPY . .
+ARG DEBIAN_FRONTEND="noninteractive"
+ENV TZ=Etc/UTC
+
+# Install system dependencies.
 RUN apt-get update \
     && apt-get install -y --no-install-recommends gcc g++ git libopenslide0 \
-    && python -m pip install --no-cache-dir --editable . \
-    --find-links https://girder.github.io/large_image_wheels \
     && rm -rf /var/lib/apt/lists/*
+
+# Install wsinfer.
+WORKDIR /opt/wsinfer
+COPY . .
+RUN python -m pip install --no-cache-dir --editable . openslide-python tiffslide
+
 # Use a writable directory for downloading model weights. Default is ~/.cache, which is
 # not guaranteed to be writable in a Docker container.
 ENV TORCH_HOME=/var/lib/wsinfer
 RUN mkdir -p "$TORCH_HOME" \
     && chmod 777 "$TORCH_HOME" \
     && chmod a+s "$TORCH_HOME"
+
+# Test that the program runs (and also download the registry JSON file).
+RUN wsinfer --help
+
 WORKDIR /work
 ENTRYPOINT ["wsinfer"]
 CMD ["--help"]
diff --git a/README.md b/README.md
index 98cd79d..ebb8374 100644
--- a/README.md
+++ b/README.md
@@ -23,15 +23,13 @@ We do not install these dependencies automatically because their installation ca
 on a user's system. Then use the command below to install this package.
 
 ```
-python -m pip install --find-links https://girder.github.io/large_image_wheels wsinfer
+python -m pip install wsinfer
 ```
 
 To use the _bleeding edge_, use
 
 ```
-python -m pip install \
-    --find-links https://girder.github.io/large_image_wheels \
-    git+https://github.com/SBU-BMI/wsinfer.git
+python -m pip install git+https://github.com/SBU-BMI/wsinfer.git
 ```
 
 ## Developers
@@ -41,7 +39,7 @@ Clone this GitHub repository and install the package (in editable mode with the
 ```
 git clone https://github.com/SBU-BMI/wsinfer.git
 cd wsinfer
-python -m pip install --editable .[dev] --find-links https://girder.github.io/large_image_wheels
+python -m pip install --editable .[dev]
 ```
 
 # Cutting a release
diff --git a/docs/installing.rst b/docs/installing.rst
index 0cb0637..e8916c8 100644
--- a/docs/installing.rst
+++ b/docs/installing.rst
@@ -6,7 +6,7 @@ Installing and getting started
 Prerequisites
 -------------
 
-WSInfer supports Python 3.7+ and has been tested on Linux.
+WSInfer supports Python 3.8+ and has been tested on Linux.
 
 Install PyTorch before installing WSInfer. Please see
 `PyTorch's installation instructions <https://pytorch.org/get-started/locally/>`_.
@@ -21,11 +21,9 @@ the type of hardware a user has.
 Manual installation
 -------------------
 
-After having installed PyTorch, install releases of WSInfer from `PyPI <https://pypi.org/project/wsinfer/>`_.
-Be sure to include the line :code:`--find-links https://girder.github.io/large_image_wheels` to ensure
-dependencies are installed properly. ::
+After having installed PyTorch, install releases of WSInfer from `PyPI <https://pypi.org/project/wsinfer/>`_. ::
 
-    pip install wsinfer --find-links https://girder.github.io/large_image_wheels
+    pip install wsinfer
 
 This installs the :code:`wsinfer` Python package and the :code:`wsinfer` command line program. ::
diff --git a/setup.cfg b/setup.cfg
index 7fa3960..d49db10 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,7 +18,6 @@ classifiers =
     Operating System :: OS Independent
     Programming Language :: Python :: 3
     Programming Language :: Python :: 3 :: Only
-    Programming Language :: Python :: 3.7
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
@@ -29,35 +28,36 @@ classifiers =
 
 [options]
 packages = find:
-python_requires = >= 3.7
+python_requires = >= 3.8
 install_requires =
-    click>=8.0,<9
+    click>=8.0,<9,!=8.1.4,!=8.1.5
     geojson
     h5py
-    # OpenSlide and TIFF readers should handle all images we will encounter.
-    large-image[openslide,tiff]>=1.8.0
     numpy
     opencv-python-headless>=4.0.0
     pandas
     pillow
     pyyaml
     shapely
-    timm
+    tifffile
+    tiffslide
     # The installation of torch and torchvision can differ by hardware. Users are
     # advised to install torch and torchvision for their given hardware and then install
     # wsinfer. See https://pytorch.org/get-started/locally/.
     torch>=1.7
     torchvision
     tqdm
+    wsinfer-zoo
 
 [options.extras_require]
 dev =
     black
     flake8
     imagecodecs  # for tifffile
+    isort
     mypy
     pytest
-    tifffile
+    tiffslide
     types-Pillow
     types-PyYAML
     types-tqdm
@@ -73,8 +73,9 @@ console_scripts =
 
 [options.package_data]
 wsinfer =
-    _patchlib/presets/*.csv
-    modeldefs/*.yaml
+    py.typed
+    patchlib/presets/*.csv
+    schemas/*.json
 
 [flake8]
 max-line-length = 88
@@ -84,8 +85,6 @@ exclude = wsinfer/_version.py
 [mypy]
 [mypy-h5py]
 ignore_missing_imports = True
-[mypy-large_image]
-ignore_missing_imports = True
 [mypy-cv2]
 ignore_missing_imports = True
 [mypy-geojson]
 ignore_missing_imports = True
 [mypy-pandas]
 ignore_missing_imports = True
-[mypy-timm.*]
+[mypy-safetensors.*]
 ignore_missing_imports = True
 [mypy-scipy.stats]
 ignore_missing_imports = True
 [mypy-shapely.*]
 ignore_missing_imports = True
+[mypy-tifffile]
+ignore_missing_imports = True
+[mypy-zarr.storage]
+ignore_missing_imports = True
 
 [versioneer]
 VCS = git
@@ -110,3 +113,8 @@ versionfile_source = wsinfer/_version.py
 versionfile_build = wsinfer/_version.py
 tag_prefix = v
 parentdir_prefix = wsinfer
+
+
+[isort]
+profile = black
+force_single_line = True
diff --git a/tests/reference/breast-tumor-inception_v4.tcga-brca/purple.csv b/tests/reference/breast-tumor-inception_v4.tcga-brca/purple.csv
new file mode 100644
index 0000000..5e527ca
--- /dev/null
+++ b/tests/reference/breast-tumor-inception_v4.tcga-brca/purple.csv
@@ -0,0 +1,145 @@
+minx,miny,width,height,prob_notumor,prob_tumor
+0,0,350,350,0.9564111828804016,0.0435887686908245
+0,350,350,350,0.9564111828804016,0.0435887686908245
+0,700,350,350,0.9564111828804016,0.0435887686908245
+0,1050,350,350,0.9564111828804016,0.0435887686908245
+0,1400,350,350,0.9564111828804016,0.0435887686908245
+0,1750,350,350,0.9564111828804016,0.0435887686908245
+0,2100,350,350,0.9564111828804016,0.0435887686908245
+0,2450,350,350,0.9564111828804016,0.0435887686908245
+0,2800,350,350,0.9564111828804016,0.0435887686908245
+0,3150,350,350,0.9564111828804016,0.0435887686908245
+0,3500,350,350,0.9564111828804016,0.0435887686908245
+0,3850,350,350,0.9600746631622314,0.0399253033101558
+350,0,350,350,0.9564111828804016,0.0435887686908245
+350,350,350,350,0.9564111828804016,0.0435887686908245
+350,700,350,350,0.9564111828804016,0.0435887686908245
+350,1050,350,350,0.9564111828804016,0.0435887686908245
+350,1400,350,350,0.9564111828804016,0.0435887686908245
+350,1750,350,350,0.9564111828804016,0.0435887686908245
+350,2100,350,350,0.9564111828804016,0.0435887686908245
+350,2450,350,350,0.9564111828804016,0.0435887686908245
+350,2800,350,350,0.9564111828804016,0.0435887686908245
+350,3150,350,350,0.9564111828804016,0.0435887686908245
+350,3500,350,350,0.9564111828804016,0.0435887686908245
+350,3850,350,350,0.9600746631622314,0.0399253033101558
+700,0,350,350,0.9564111828804016,0.0435887686908245
+700,350,350,350,0.9564111828804016,0.0435887686908245
+700,700,350,350,0.9564111828804016,0.0435887686908245
+700,1050,350,350,0.9564111828804016,0.0435887686908245
+700,1400,350,350,0.9564111828804016,0.0435887686908245
+700,1750,350,350,0.9564111828804016,0.0435887686908245
+700,2100,350,350,0.9564111828804016,0.0435887686908245 +700,2450,350,350,0.9564111828804016,0.0435887686908245 +700,2800,350,350,0.9564111828804016,0.0435887686908245 +700,3150,350,350,0.9564111828804016,0.0435887686908245 +700,3500,350,350,0.9564111828804016,0.0435887686908245 +700,3850,350,350,0.9600746631622314,0.0399253033101558 +1050,0,350,350,0.9564111828804016,0.0435887686908245 +1050,350,350,350,0.9564111828804016,0.0435887686908245 +1050,700,350,350,0.9564111828804016,0.0435887686908245 +1050,1050,350,350,0.9564111828804016,0.0435887686908245 +1050,1400,350,350,0.9564111828804016,0.0435887686908245 +1050,1750,350,350,0.9564111828804016,0.0435887686908245 +1050,2100,350,350,0.9564111828804016,0.0435887686908245 +1050,2450,350,350,0.9564111828804016,0.0435887686908245 +1050,2800,350,350,0.9564111828804016,0.0435887686908245 +1050,3150,350,350,0.9564111828804016,0.0435887686908245 +1050,3500,350,350,0.9564111828804016,0.0435887686908245 +1050,3850,350,350,0.9600746631622314,0.0399253033101558 +1400,0,350,350,0.9564111828804016,0.0435887686908245 +1400,350,350,350,0.9564111828804016,0.0435887686908245 +1400,700,350,350,0.9564111828804016,0.0435887686908245 +1400,1050,350,350,0.9564111828804016,0.0435887686908245 +1400,1400,350,350,0.9564111828804016,0.0435887686908245 +1400,1750,350,350,0.9564111828804016,0.0435887686908245 +1400,2100,350,350,0.9564111828804016,0.0435887686908245 +1400,2450,350,350,0.9564111828804016,0.0435887686908245 +1400,2800,350,350,0.9564111828804016,0.0435887686908245 +1400,3150,350,350,0.9564111828804016,0.0435887686908245 +1400,3500,350,350,0.9564111828804016,0.0435887686908245 +1400,3850,350,350,0.9600746631622314,0.0399253033101558 +1750,0,350,350,0.9564111828804016,0.0435887686908245 +1750,350,350,350,0.9564111828804016,0.0435887686908245 +1750,700,350,350,0.9564111828804016,0.0435887686908245 +1750,1050,350,350,0.9564111828804016,0.0435887686908245 +1750,1400,350,350,0.9564111828804016,0.0435887686908245 +1750,1750,350,350,0.9564111828804016,0.0435887686908245 +1750,2100,350,350,0.9564111828804016,0.0435887686908245 +1750,2450,350,350,0.9564111828804016,0.0435887686908245 +1750,2800,350,350,0.9564111828804016,0.0435887686908245 +1750,3150,350,350,0.9564111828804016,0.0435887686908245 +1750,3500,350,350,0.9564111828804016,0.0435887686908245 +1750,3850,350,350,0.9600746631622314,0.0399253033101558 +2100,0,350,350,0.9564111828804016,0.0435887686908245 +2100,350,350,350,0.9564111828804016,0.0435887686908245 +2100,700,350,350,0.9564111828804016,0.0435887686908245 +2100,1050,350,350,0.9564111828804016,0.0435887686908245 +2100,1400,350,350,0.9564111828804016,0.0435887686908245 +2100,1750,350,350,0.9564111828804016,0.0435887686908245 +2100,2100,350,350,0.9564111828804016,0.0435887686908245 +2100,2450,350,350,0.9564111828804016,0.0435887686908245 +2100,2800,350,350,0.9564111828804016,0.0435887686908245 +2100,3150,350,350,0.9564111828804016,0.0435887686908245 +2100,3500,350,350,0.9564111828804016,0.0435887686908245 +2100,3850,350,350,0.9600746631622314,0.0399253033101558 +2450,0,350,350,0.9564111828804016,0.0435887686908245 +2450,350,350,350,0.9564111828804016,0.0435887686908245 +2450,700,350,350,0.9564111828804016,0.0435887686908245 +2450,1050,350,350,0.9564111828804016,0.0435887686908245 +2450,1400,350,350,0.9564111828804016,0.0435887686908245 +2450,1750,350,350,0.9564111828804016,0.0435887686908245 +2450,2100,350,350,0.9564111828804016,0.0435887686908245 +2450,2450,350,350,0.9564111828804016,0.0435887686908245 
+2450,2800,350,350,0.9564111828804016,0.0435887686908245 +2450,3150,350,350,0.9564111828804016,0.0435887686908245 +2450,3500,350,350,0.9564111828804016,0.0435887686908245 +2450,3850,350,350,0.9600746631622314,0.0399253033101558 +2800,0,350,350,0.9564111828804016,0.0435887686908245 +2800,350,350,350,0.9564111828804016,0.0435887686908245 +2800,700,350,350,0.9564111828804016,0.0435887686908245 +2800,1050,350,350,0.9564111828804016,0.0435887686908245 +2800,1400,350,350,0.9564111828804016,0.0435887686908245 +2800,1750,350,350,0.9564111828804016,0.0435887686908245 +2800,2100,350,350,0.9564111828804016,0.0435887686908245 +2800,2450,350,350,0.9564111828804016,0.0435887686908245 +2800,2800,350,350,0.9564111828804016,0.0435887686908245 +2800,3150,350,350,0.9564111828804016,0.0435887686908245 +2800,3500,350,350,0.9564111828804016,0.0435887686908245 +2800,3850,350,350,0.9600746631622314,0.0399253033101558 +3150,0,350,350,0.9564111828804016,0.0435887686908245 +3150,350,350,350,0.9564111828804016,0.0435887686908245 +3150,700,350,350,0.9564111828804016,0.0435887686908245 +3150,1050,350,350,0.9564111828804016,0.0435887686908245 +3150,1400,350,350,0.9564111828804016,0.0435887686908245 +3150,1750,350,350,0.9564111828804016,0.0435887686908245 +3150,2100,350,350,0.9564111828804016,0.0435887686908245 +3150,2450,350,350,0.9564111828804016,0.0435887686908245 +3150,2800,350,350,0.9564111828804016,0.0435887686908245 +3150,3150,350,350,0.9564111828804016,0.0435887686908245 +3150,3500,350,350,0.9564111828804016,0.0435887686908245 +3150,3850,350,350,0.9600746631622314,0.0399253033101558 +3500,0,350,350,0.9564111828804016,0.0435887686908245 +3500,350,350,350,0.9564111828804016,0.0435887686908245 +3500,700,350,350,0.9564111828804016,0.0435887686908245 +3500,1050,350,350,0.9564111828804016,0.0435887686908245 +3500,1400,350,350,0.9564111828804016,0.0435887686908245 +3500,1750,350,350,0.9564111828804016,0.0435887686908245 +3500,2100,350,350,0.9564111828804016,0.0435887686908245 +3500,2450,350,350,0.9564111828804016,0.0435887686908245 +3500,2800,350,350,0.9564111828804016,0.0435887463390827 +3500,3150,350,350,0.9564111828804016,0.0435887463390827 +3500,3500,350,350,0.9564111828804016,0.0435887463390827 +3500,3850,350,350,0.9600746631622314,0.0399253405630588 +3850,0,350,350,0.967224657535553,0.0327753536403179 +3850,350,350,350,0.967224657535553,0.0327753536403179 +3850,700,350,350,0.967224657535553,0.0327753536403179 +3850,1050,350,350,0.967224657535553,0.0327753536403179 +3850,1400,350,350,0.967224657535553,0.0327753536403179 +3850,1750,350,350,0.967224657535553,0.0327753536403179 +3850,2100,350,350,0.967224657535553,0.0327753536403179 +3850,2450,350,350,0.967224657535553,0.0327753536403179 +3850,2800,350,350,0.967224657535553,0.0327753536403179 +3850,3150,350,350,0.967224657535553,0.0327753536403179 +3850,3500,350,350,0.967224657535553,0.0327753536403179 +3850,3850,350,350,0.9655932784080504,0.0344067215919494 diff --git a/tests/reference/breast-tumor-resnet34.tcga-brca/purple.csv b/tests/reference/breast-tumor-resnet34.tcga-brca/purple.csv new file mode 100644 index 0000000..27cc80d --- /dev/null +++ b/tests/reference/breast-tumor-resnet34.tcga-brca/purple.csv @@ -0,0 +1,145 @@ +minx,miny,width,height,prob_notumor,prob_tumor +0,0,350,350,0.9525965452194214,0.0474034734070301 +0,350,350,350,0.9525965452194214,0.0474034734070301 +0,700,350,350,0.9525965452194214,0.0474034734070301 +0,1050,350,350,0.9525965452194214,0.0474034734070301 +0,1400,350,350,0.9525965452194214,0.0474034734070301 
+0,1750,350,350,0.9525965452194214,0.0474034734070301 +0,2100,350,350,0.9525965452194214,0.0474034734070301 +0,2450,350,350,0.9525965452194214,0.0474034734070301 +0,2800,350,350,0.9525965452194214,0.0474034734070301 +0,3150,350,350,0.9525965452194214,0.0474034734070301 +0,3500,350,350,0.9525965452194214,0.0474034734070301 +0,3850,350,350,0.983870267868042,0.0161297302693128 +350,0,350,350,0.9525965452194214,0.0474034734070301 +350,350,350,350,0.9525965452194214,0.0474034734070301 +350,700,350,350,0.9525965452194214,0.0474034734070301 +350,1050,350,350,0.9525965452194214,0.0474034734070301 +350,1400,350,350,0.9525965452194214,0.0474034734070301 +350,1750,350,350,0.9525965452194214,0.0474034734070301 +350,2100,350,350,0.9525965452194214,0.0474034734070301 +350,2450,350,350,0.9525965452194214,0.0474034734070301 +350,2800,350,350,0.9525965452194214,0.0474034734070301 +350,3150,350,350,0.9525965452194214,0.0474034734070301 +350,3500,350,350,0.9525965452194214,0.0474034734070301 +350,3850,350,350,0.983870267868042,0.0161297302693128 +700,0,350,350,0.9525965452194214,0.0474034734070301 +700,350,350,350,0.9525965452194214,0.0474034734070301 +700,700,350,350,0.9525965452194214,0.0474034734070301 +700,1050,350,350,0.9525965452194214,0.0474034734070301 +700,1400,350,350,0.9525965452194214,0.0474034734070301 +700,1750,350,350,0.9525965452194214,0.0474034734070301 +700,2100,350,350,0.9525965452194214,0.0474034734070301 +700,2450,350,350,0.9525965452194214,0.0474034734070301 +700,2800,350,350,0.9525965452194214,0.0474034734070301 +700,3150,350,350,0.9525965452194214,0.0474034734070301 +700,3500,350,350,0.9525965452194214,0.0474034734070301 +700,3850,350,350,0.983870267868042,0.0161297302693128 +1050,0,350,350,0.9525965452194214,0.0474034734070301 +1050,350,350,350,0.9525965452194214,0.0474034734070301 +1050,700,350,350,0.9525965452194214,0.0474034734070301 +1050,1050,350,350,0.9525965452194214,0.0474034734070301 +1050,1400,350,350,0.9525965452194214,0.0474034734070301 +1050,1750,350,350,0.9525965452194214,0.0474034734070301 +1050,2100,350,350,0.9525965452194214,0.0474034734070301 +1050,2450,350,350,0.9525965452194214,0.0474034734070301 +1050,2800,350,350,0.9525965452194214,0.0474034734070301 +1050,3150,350,350,0.9525965452194214,0.0474034734070301 +1050,3500,350,350,0.9525965452194214,0.0474034734070301 +1050,3850,350,350,0.983870267868042,0.0161297302693128 +1400,0,350,350,0.9525965452194214,0.0474034734070301 +1400,350,350,350,0.9525965452194214,0.0474034734070301 +1400,700,350,350,0.9525965452194214,0.0474034734070301 +1400,1050,350,350,0.9525965452194214,0.0474034734070301 +1400,1400,350,350,0.9525965452194214,0.0474034734070301 +1400,1750,350,350,0.9525965452194214,0.0474034734070301 +1400,2100,350,350,0.9525965452194214,0.0474034734070301 +1400,2450,350,350,0.9525965452194214,0.0474034734070301 +1400,2800,350,350,0.9525965452194214,0.0474034734070301 +1400,3150,350,350,0.9525965452194214,0.0474034734070301 +1400,3500,350,350,0.9525965452194214,0.0474034734070301 +1400,3850,350,350,0.983870267868042,0.0161297302693128 +1750,0,350,350,0.9525965452194214,0.0474034734070301 +1750,350,350,350,0.9525965452194214,0.0474034734070301 +1750,700,350,350,0.9525965452194214,0.0474034734070301 +1750,1050,350,350,0.9525965452194214,0.0474034734070301 +1750,1400,350,350,0.9525965452194214,0.0474034734070301 +1750,1750,350,350,0.9525965452194214,0.0474034734070301 +1750,2100,350,350,0.9525965452194214,0.0474034734070301 +1750,2450,350,350,0.9525965452194214,0.0474034734070301 
+1750,2800,350,350,0.9525965452194214,0.0474034734070301 +1750,3150,350,350,0.9525965452194214,0.0474034734070301 +1750,3500,350,350,0.9525965452194214,0.0474034734070301 +1750,3850,350,350,0.983870267868042,0.0161297302693128 +2100,0,350,350,0.9525965452194214,0.0474034734070301 +2100,350,350,350,0.9525965452194214,0.0474034734070301 +2100,700,350,350,0.9525965452194214,0.0474034734070301 +2100,1050,350,350,0.9525965452194214,0.0474034734070301 +2100,1400,350,350,0.9525965452194214,0.0474034734070301 +2100,1750,350,350,0.9525965452194214,0.0474034734070301 +2100,2100,350,350,0.9525965452194214,0.0474034734070301 +2100,2450,350,350,0.9525965452194214,0.0474034734070301 +2100,2800,350,350,0.9525965452194214,0.0474034734070301 +2100,3150,350,350,0.9525965452194214,0.0474034734070301 +2100,3500,350,350,0.9525965452194214,0.0474034734070301 +2100,3850,350,350,0.983870267868042,0.0161297302693128 +2450,0,350,350,0.9525965452194214,0.0474034734070301 +2450,350,350,350,0.9525965452194214,0.0474034734070301 +2450,700,350,350,0.9525965452194214,0.0474034734070301 +2450,1050,350,350,0.9525965452194214,0.0474034734070301 +2450,1400,350,350,0.9525965452194214,0.0474034734070301 +2450,1750,350,350,0.9525965452194214,0.0474034734070301 +2450,2100,350,350,0.9525965452194214,0.0474034734070301 +2450,2450,350,350,0.9525965452194214,0.0474034734070301 +2450,2800,350,350,0.9525965452194214,0.0474034734070301 +2450,3150,350,350,0.9525965452194214,0.0474034734070301 +2450,3500,350,350,0.9525965452194214,0.0474034734070301 +2450,3850,350,350,0.983870267868042,0.0161297302693128 +2800,0,350,350,0.9525965452194214,0.0474034734070301 +2800,350,350,350,0.9525965452194214,0.0474034734070301 +2800,700,350,350,0.9525965452194214,0.0474034734070301 +2800,1050,350,350,0.9525965452194214,0.0474034734070301 +2800,1400,350,350,0.9525965452194214,0.0474034734070301 +2800,1750,350,350,0.9525965452194214,0.0474034734070301 +2800,2100,350,350,0.9525965452194214,0.0474034734070301 +2800,2450,350,350,0.9525965452194214,0.0474034734070301 +2800,2800,350,350,0.9525965452194214,0.0474034734070301 +2800,3150,350,350,0.9525965452194214,0.0474034734070301 +2800,3500,350,350,0.9525965452194214,0.0474034734070301 +2800,3850,350,350,0.983870267868042,0.0161297302693128 +3150,0,350,350,0.9525965452194214,0.0474034734070301 +3150,350,350,350,0.9525965452194214,0.0474034734070301 +3150,700,350,350,0.9525965452194214,0.0474034734070301 +3150,1050,350,350,0.9525965452194214,0.0474034734070301 +3150,1400,350,350,0.9525965452194214,0.0474034734070301 +3150,1750,350,350,0.9525965452194214,0.0474034734070301 +3150,2100,350,350,0.9525965452194214,0.0474034734070301 +3150,2450,350,350,0.9525965452194214,0.0474034734070301 +3150,2800,350,350,0.9525965452194214,0.0474034734070301 +3150,3150,350,350,0.9525965452194214,0.0474034734070301 +3150,3500,350,350,0.9525965452194214,0.0474034734070301 +3150,3850,350,350,0.983870267868042,0.0161297302693128 +3500,0,350,350,0.9525965452194214,0.0474034734070301 +3500,350,350,350,0.9525965452194214,0.0474034734070301 +3500,700,350,350,0.9525965452194214,0.0474034734070301 +3500,1050,350,350,0.9525965452194214,0.0474034734070301 +3500,1400,350,350,0.9525965452194214,0.0474034734070301 +3500,1750,350,350,0.9525965452194214,0.0474034734070301 +3500,2100,350,350,0.9525965452194214,0.0474034734070301 +3500,2450,350,350,0.9525965452194214,0.0474034734070301 +3500,2800,350,350,0.9525965452194214,0.0474034063518047 +3500,3150,350,350,0.9525965452194214,0.0474034063518047 
+3500,3500,350,350,0.9525965452194214,0.0474034063518047 +3500,3850,350,350,0.983870267868042,0.0161297619342803 +3850,0,350,350,0.9536041617393494,0.0463958717882633 +3850,350,350,350,0.9536041617393494,0.0463958717882633 +3850,700,350,350,0.9536041617393494,0.0463958717882633 +3850,1050,350,350,0.9536041617393494,0.0463958717882633 +3850,1400,350,350,0.9536041617393494,0.0463958717882633 +3850,1750,350,350,0.9536041617393494,0.0463958717882633 +3850,2100,350,350,0.9536041617393494,0.0463958717882633 +3850,2450,350,350,0.9536041617393494,0.0463958717882633 +3850,2800,350,350,0.9536041617393494,0.0463958717882633 +3850,3150,350,350,0.9536041617393494,0.0463958717882633 +3850,3500,350,350,0.9536041617393494,0.0463958717882633 +3850,3850,350,350,0.9831942915916444,0.0168057028204202 diff --git a/tests/reference/breast-tumor-vgg16mod.tcga-brca/purple.csv b/tests/reference/breast-tumor-vgg16mod.tcga-brca/purple.csv new file mode 100644 index 0000000..15678d0 --- /dev/null +++ b/tests/reference/breast-tumor-vgg16mod.tcga-brca/purple.csv @@ -0,0 +1,145 @@ +minx,miny,width,height,prob_notumor,prob_tumor +0,0,350,350,0.9108285307884216,0.0891714394092559 +0,350,350,350,0.9108285307884216,0.0891714394092559 +0,700,350,350,0.9108285307884216,0.0891714394092559 +0,1050,350,350,0.9108285307884216,0.0891714394092559 +0,1400,350,350,0.9108285307884216,0.0891714394092559 +0,1750,350,350,0.9108285307884216,0.0891714394092559 +0,2100,350,350,0.9108285307884216,0.0891714394092559 +0,2450,350,350,0.9108285307884216,0.0891714394092559 +0,2800,350,350,0.9108285307884216,0.0891714394092559 +0,3150,350,350,0.9108285307884216,0.0891714394092559 +0,3500,350,350,0.9108285307884216,0.0891714394092559 +0,3850,350,350,0.8815857172012329,0.1184142604470253 +350,0,350,350,0.9108285307884216,0.0891714394092559 +350,350,350,350,0.9108285307884216,0.0891714394092559 +350,700,350,350,0.9108285307884216,0.0891714394092559 +350,1050,350,350,0.9108285307884216,0.0891714394092559 +350,1400,350,350,0.9108285307884216,0.0891714394092559 +350,1750,350,350,0.9108285307884216,0.0891714394092559 +350,2100,350,350,0.9108285307884216,0.0891714394092559 +350,2450,350,350,0.9108285307884216,0.0891714394092559 +350,2800,350,350,0.9108285307884216,0.0891714394092559 +350,3150,350,350,0.9108285307884216,0.0891714394092559 +350,3500,350,350,0.9108285307884216,0.0891714394092559 +350,3850,350,350,0.8815857172012329,0.1184142604470253 +700,0,350,350,0.9108285307884216,0.0891714394092559 +700,350,350,350,0.9108285307884216,0.0891714394092559 +700,700,350,350,0.9108285307884216,0.0891714394092559 +700,1050,350,350,0.9108285307884216,0.0891714394092559 +700,1400,350,350,0.9108285307884216,0.0891714394092559 +700,1750,350,350,0.9108285307884216,0.0891714394092559 +700,2100,350,350,0.9108285307884216,0.0891714394092559 +700,2450,350,350,0.9108285307884216,0.0891714394092559 +700,2800,350,350,0.9108285307884216,0.0891714394092559 +700,3150,350,350,0.9108285307884216,0.0891714394092559 +700,3500,350,350,0.9108285307884216,0.0891714394092559 +700,3850,350,350,0.8815857172012329,0.1184142604470253 +1050,0,350,350,0.9108285307884216,0.0891714394092559 +1050,350,350,350,0.9108285307884216,0.0891714394092559 +1050,700,350,350,0.9108285307884216,0.0891714394092559 +1050,1050,350,350,0.9108285307884216,0.0891714394092559 +1050,1400,350,350,0.9108285307884216,0.0891714394092559 +1050,1750,350,350,0.9108285307884216,0.0891714394092559 +1050,2100,350,350,0.9108285307884216,0.0891714394092559 +1050,2450,350,350,0.9108285307884216,0.0891714394092559 
+1050,2800,350,350,0.9108285307884216,0.0891714394092559 +1050,3150,350,350,0.9108285307884216,0.0891714394092559 +1050,3500,350,350,0.9108285307884216,0.0891714394092559 +1050,3850,350,350,0.8815857172012329,0.1184142604470253 +1400,0,350,350,0.9108285307884216,0.0891714394092559 +1400,350,350,350,0.9108285307884216,0.0891714394092559 +1400,700,350,350,0.9108285307884216,0.0891714394092559 +1400,1050,350,350,0.9108285307884216,0.0891714394092559 +1400,1400,350,350,0.9108285307884216,0.0891714394092559 +1400,1750,350,350,0.9108285307884216,0.0891714394092559 +1400,2100,350,350,0.9108285307884216,0.0891714394092559 +1400,2450,350,350,0.9108285307884216,0.0891714394092559 +1400,2800,350,350,0.9108285307884216,0.0891714394092559 +1400,3150,350,350,0.9108285307884216,0.0891714394092559 +1400,3500,350,350,0.9108285307884216,0.0891714394092559 +1400,3850,350,350,0.8815857172012329,0.1184142604470253 +1750,0,350,350,0.9108285307884216,0.0891714394092559 +1750,350,350,350,0.9108285307884216,0.0891714394092559 +1750,700,350,350,0.9108285307884216,0.0891714394092559 +1750,1050,350,350,0.9108285307884216,0.0891714394092559 +1750,1400,350,350,0.9108285307884216,0.0891714394092559 +1750,1750,350,350,0.9108285307884216,0.0891714394092559 +1750,2100,350,350,0.9108285307884216,0.0891714394092559 +1750,2450,350,350,0.9108285307884216,0.0891714394092559 +1750,2800,350,350,0.9108285307884216,0.0891714394092559 +1750,3150,350,350,0.9108285307884216,0.0891714394092559 +1750,3500,350,350,0.9108285307884216,0.0891714394092559 +1750,3850,350,350,0.8815857172012329,0.1184142604470253 +2100,0,350,350,0.9108285307884216,0.0891714394092559 +2100,350,350,350,0.9108285307884216,0.0891714394092559 +2100,700,350,350,0.9108285307884216,0.0891714394092559 +2100,1050,350,350,0.9108285307884216,0.0891714394092559 +2100,1400,350,350,0.9108285307884216,0.0891714394092559 +2100,1750,350,350,0.9108285307884216,0.0891714394092559 +2100,2100,350,350,0.9108285307884216,0.0891714394092559 +2100,2450,350,350,0.9108285307884216,0.0891714394092559 +2100,2800,350,350,0.9108285307884216,0.0891714394092559 +2100,3150,350,350,0.9108285307884216,0.0891714394092559 +2100,3500,350,350,0.9108285307884216,0.0891714394092559 +2100,3850,350,350,0.8815857172012329,0.1184142604470253 +2450,0,350,350,0.9108285307884216,0.0891714394092559 +2450,350,350,350,0.9108285307884216,0.0891714394092559 +2450,700,350,350,0.9108285307884216,0.0891714394092559 +2450,1050,350,350,0.9108285307884216,0.0891714394092559 +2450,1400,350,350,0.9108285307884216,0.0891714394092559 +2450,1750,350,350,0.9108285307884216,0.0891714394092559 +2450,2100,350,350,0.9108285307884216,0.0891714394092559 +2450,2450,350,350,0.9108285307884216,0.0891714394092559 +2450,2800,350,350,0.9108285307884216,0.0891714394092559 +2450,3150,350,350,0.9108285307884216,0.0891714394092559 +2450,3500,350,350,0.9108285307884216,0.0891714394092559 +2450,3850,350,350,0.8815857172012329,0.1184142604470253 +2800,0,350,350,0.9108285307884216,0.0891714394092559 +2800,350,350,350,0.9108285307884216,0.0891714394092559 +2800,700,350,350,0.9108285307884216,0.0891714394092559 +2800,1050,350,350,0.9108285307884216,0.0891714394092559 +2800,1400,350,350,0.9108285307884216,0.0891714394092559 +2800,1750,350,350,0.9108285307884216,0.0891714394092559 +2800,2100,350,350,0.9108285307884216,0.0891714394092559 +2800,2450,350,350,0.9108285307884216,0.0891714394092559 +2800,2800,350,350,0.9108285307884216,0.0891714394092559 +2800,3150,350,350,0.9108285307884216,0.0891714394092559 
+2800,3500,350,350,0.9108285307884216,0.0891714394092559 +2800,3850,350,350,0.8815857172012329,0.1184142604470253 +3150,0,350,350,0.9108285307884216,0.0891714394092559 +3150,350,350,350,0.9108285307884216,0.0891714394092559 +3150,700,350,350,0.9108285307884216,0.0891714394092559 +3150,1050,350,350,0.9108285307884216,0.0891714394092559 +3150,1400,350,350,0.9108285307884216,0.0891714394092559 +3150,1750,350,350,0.9108285307884216,0.0891714394092559 +3150,2100,350,350,0.9108285307884216,0.0891714394092559 +3150,2450,350,350,0.9108285307884216,0.0891714394092559 +3150,2800,350,350,0.9108285307884216,0.0891714394092559 +3150,3150,350,350,0.9108285307884216,0.0891714394092559 +3150,3500,350,350,0.9108285307884216,0.0891714394092559 +3150,3850,350,350,0.8815857172012329,0.1184142604470253 +3500,0,350,350,0.9108285307884216,0.0891714394092559 +3500,350,350,350,0.9108285307884216,0.0891714394092559 +3500,700,350,350,0.9108285307884216,0.0891714394092559 +3500,1050,350,350,0.9108285307884216,0.0891714394092559 +3500,1400,350,350,0.9108285307884216,0.0891714394092559 +3500,1750,350,350,0.9108285307884216,0.0891714394092559 +3500,2100,350,350,0.9108285307884216,0.0891714394092559 +3500,2450,350,350,0.9108285307884216,0.0891714394092559 +3500,2800,350,350,0.9108285307884216,0.0891714617609977 +3500,3150,350,350,0.9108285307884216,0.0891714617609977 +3500,3500,350,350,0.9108285307884216,0.0891714617609977 +3500,3850,350,350,0.8815857768058777,0.1184142380952835 +3850,0,350,350,0.8739222884178162,0.126077726483345 +3850,350,350,350,0.8739222884178162,0.126077726483345 +3850,700,350,350,0.8739222884178162,0.126077726483345 +3850,1050,350,350,0.8739222884178162,0.126077726483345 +3850,1400,350,350,0.8739222884178162,0.126077726483345 +3850,1750,350,350,0.8739222884178162,0.126077726483345 +3850,2100,350,350,0.8739222884178162,0.126077726483345 +3850,2450,350,350,0.8739222884178162,0.126077726483345 +3850,2800,350,350,0.8739222884178162,0.126077726483345 +3850,3150,350,350,0.8739222884178162,0.126077726483345 +3850,3500,350,350,0.8739222884178162,0.126077726483345 +3850,3850,350,350,0.8631160855293274,0.1368838846683502 diff --git a/tests/reference/lung-tumor-resnet34.tcga-luad/purple.csv b/tests/reference/lung-tumor-resnet34.tcga-luad/purple.csv new file mode 100644 index 0000000..10efb3a --- /dev/null +++ b/tests/reference/lung-tumor-resnet34.tcga-luad/purple.csv @@ -0,0 +1,37 @@ +minx,miny,width,height,prob_lepidic,prob_benign,prob_acinar,prob_micropapillary,prob_mucinous,prob_solid +0,0,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +0,700,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +0,1400,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +0,2100,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +0,2800,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +0,3500,700,700,0.0071091777645051,0.9879971742630004,0.0035794575233012,0.0001476426259614,0.0004695237439591,0.0006970081012696 +700,0,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +700,700,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 
+700,1400,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +700,2100,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +700,2800,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +700,3500,700,700,0.0071091777645051,0.9879971742630004,0.0035794575233012,0.0001476426259614,0.0004695237439591,0.0006970081012696 +1400,0,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +1400,700,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +1400,1400,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +1400,2100,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +1400,2800,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +1400,3500,700,700,0.0071091777645051,0.9879971742630004,0.0035794575233012,0.0001476426259614,0.0004695237439591,0.0006970081012696 +2100,0,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2100,700,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2100,1400,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2100,2100,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2100,2800,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2100,3500,700,700,0.0071091777645051,0.9879971742630004,0.0035794575233012,0.0001476426259614,0.0004695237439591,0.0006970081012696 +2800,0,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2800,700,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2800,1400,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2800,2100,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2800,2800,700,700,0.0127929542213678,0.979295015335083,0.0050891297869384,0.0003837002441287,0.0006556882872246,0.0017834370955824 +2800,3500,700,700,0.0071091777645051,0.9879971742630004,0.0035794575233012,0.0001476426259614,0.0004695237439591,0.0006970081012696 +3500,0,700,700,0.0127342632040381,0.9793447852134703,0.0057523809373378,0.000368564564269,0.0004710552457254,0.0013289707712829 +3500,700,700,700,0.0127342632040381,0.9793447852134703,0.0057523809373378,0.000368564564269,0.0004710552457254,0.0013289707712829 +3500,1400,700,700,0.0127342399209737,0.9793447852134703,0.0057523669674992,0.0003685631672851,0.0004710543144028,0.0013289669295772 +3500,2100,700,700,0.0127342399209737,0.9793447852134703,0.0057523669674992,0.0003685631672851,0.0004710543144028,0.0013289669295772 +3500,2800,700,700,0.0127342399209737,0.9793447852134703,0.0057523669674992,0.0003685631672851,0.0004710543144028,0.0013289669295772 
+3500,3500,700,700,0.008527117781341,0.9865036010742188,0.00379602285102,0.0001453986478736,0.0003295204369351,0.0006984102074056 diff --git a/tests/reference/pancancer-lymphocytes-inceptionv4.tcga/purple.csv b/tests/reference/pancancer-lymphocytes-inceptionv4.tcga/purple.csv new file mode 100644 index 0000000..67f7ee1 --- /dev/null +++ b/tests/reference/pancancer-lymphocytes-inceptionv4.tcga/purple.csv @@ -0,0 +1,442 @@ +minx,miny,width,height,prob_notils,prob_tils +0,0,200,200,1.0,3.427372535086404e-12 +0,200,200,200,1.0,3.427372535086404e-12 +0,400,200,200,1.0,3.427372535086404e-12 +0,600,200,200,1.0,3.427372535086404e-12 +0,800,200,200,1.0,3.427372535086404e-12 +0,1000,200,200,1.0,3.427372535086404e-12 +0,1200,200,200,1.0,3.427372535086404e-12 +0,1400,200,200,1.0,3.427372535086404e-12 +0,1600,200,200,1.0,3.427372535086404e-12 +0,1800,200,200,1.0,3.427372535086404e-12 +0,2000,200,200,1.0,3.427372535086404e-12 +0,2200,200,200,1.0,3.427372535086404e-12 +0,2400,200,200,1.0,3.427372535086404e-12 +0,2600,200,200,1.0,3.427372535086404e-12 +0,2800,200,200,1.0,3.427372535086404e-12 +0,3000,200,200,1.0,3.427372535086404e-12 +0,3200,200,200,1.0,3.427372535086404e-12 +0,3400,200,200,1.0,3.427372535086404e-12 +0,3600,200,200,1.0,3.427372535086404e-12 +0,3800,200,200,1.0,3.427372535086404e-12 +0,4000,200,200,0.9999136924743652,8.626562339486554e-05 +200,0,200,200,1.0,3.427372535086404e-12 +200,200,200,200,1.0,3.427372535086404e-12 +200,400,200,200,1.0,3.427372535086404e-12 +200,600,200,200,1.0,3.427372535086404e-12 +200,800,200,200,1.0,3.427372535086404e-12 +200,1000,200,200,1.0,3.427372535086404e-12 +200,1200,200,200,1.0,3.427372535086404e-12 +200,1400,200,200,1.0,3.427372535086404e-12 +200,1600,200,200,1.0,3.427372535086404e-12 +200,1800,200,200,1.0,3.427372535086404e-12 +200,2000,200,200,1.0,3.427372535086404e-12 +200,2200,200,200,1.0,3.427372535086404e-12 +200,2400,200,200,1.0,3.427372535086404e-12 +200,2600,200,200,1.0,3.427372535086404e-12 +200,2800,200,200,1.0,3.427372535086404e-12 +200,3000,200,200,1.0,3.427372535086404e-12 +200,3200,200,200,1.0,3.427372535086404e-12 +200,3400,200,200,1.0,3.427372535086404e-12 +200,3600,200,200,1.0,3.427372535086404e-12 +200,3800,200,200,1.0,3.427372535086404e-12 +200,4000,200,200,0.9999136924743652,8.626562339486554e-05 +400,0,200,200,1.0,3.427372535086404e-12 +400,200,200,200,1.0,3.427372535086404e-12 +400,400,200,200,1.0,3.427372535086404e-12 +400,600,200,200,1.0,3.427372535086404e-12 +400,800,200,200,1.0,3.427372535086404e-12 +400,1000,200,200,1.0,3.427372535086404e-12 +400,1200,200,200,1.0,3.427372535086404e-12 +400,1400,200,200,1.0,3.427372535086404e-12 +400,1600,200,200,1.0,3.427372535086404e-12 +400,1800,200,200,1.0,3.427372535086404e-12 +400,2000,200,200,1.0,3.427372535086404e-12 +400,2200,200,200,1.0,3.427372535086404e-12 +400,2400,200,200,1.0,3.427372535086404e-12 +400,2600,200,200,1.0,3.427372535086404e-12 +400,2800,200,200,1.0,3.427372535086404e-12 +400,3000,200,200,1.0,3.427372535086404e-12 +400,3200,200,200,1.0,3.427372535086404e-12 +400,3400,200,200,1.0,3.427372535086404e-12 +400,3600,200,200,1.0,3.427372535086404e-12 +400,3800,200,200,1.0,3.427372535086404e-12 +400,4000,200,200,0.9999136924743652,8.626562339486554e-05 +600,0,200,200,1.0,3.427372535086404e-12 +600,200,200,200,1.0,3.427372535086404e-12 +600,400,200,200,1.0,3.427372535086404e-12 +600,600,200,200,1.0,3.427372535086404e-12 +600,800,200,200,1.0,3.427372535086404e-12 +600,1000,200,200,1.0,3.427372535086404e-12 +600,1200,200,200,1.0,3.427372535086404e-12 
+600,1400,200,200,1.0,3.427372535086404e-12 +600,1600,200,200,1.0,3.427372535086404e-12 +600,1800,200,200,1.0,3.427372535086404e-12 +600,2000,200,200,1.0,3.427372535086404e-12 +600,2200,200,200,1.0,3.427372535086404e-12 +600,2400,200,200,1.0,3.427372535086404e-12 +600,2600,200,200,1.0,3.427372535086404e-12 +600,2800,200,200,1.0,3.427372535086404e-12 +600,3000,200,200,1.0,3.427372535086404e-12 +600,3200,200,200,1.0,3.427372535086404e-12 +600,3400,200,200,1.0,3.427372535086404e-12 +600,3600,200,200,1.0,3.427372535086404e-12 +600,3800,200,200,1.0,3.427372535086404e-12 +600,4000,200,200,0.9999136924743652,8.626562339486554e-05 +800,0,200,200,1.0,3.427372535086404e-12 +800,200,200,200,1.0,3.427372535086404e-12 +800,400,200,200,1.0,3.427372535086404e-12 +800,600,200,200,1.0,3.427372535086404e-12 +800,800,200,200,1.0,3.427372535086404e-12 +800,1000,200,200,1.0,3.427372535086404e-12 +800,1200,200,200,1.0,3.427372535086404e-12 +800,1400,200,200,1.0,3.427372535086404e-12 +800,1600,200,200,1.0,3.427372535086404e-12 +800,1800,200,200,1.0,3.427372535086404e-12 +800,2000,200,200,1.0,3.427372535086404e-12 +800,2200,200,200,1.0,3.427372535086404e-12 +800,2400,200,200,1.0,3.427372535086404e-12 +800,2600,200,200,1.0,3.427372535086404e-12 +800,2800,200,200,1.0,3.427372535086404e-12 +800,3000,200,200,1.0,3.427372535086404e-12 +800,3200,200,200,1.0,3.427372535086404e-12 +800,3400,200,200,1.0,3.427372535086404e-12 +800,3600,200,200,1.0,3.427372535086404e-12 +800,3800,200,200,1.0,3.427372535086404e-12 +800,4000,200,200,0.9999136924743652,8.626562339486554e-05 +1000,0,200,200,1.0,3.427372535086404e-12 +1000,200,200,200,1.0,3.427372535086404e-12 +1000,400,200,200,1.0,3.427372535086404e-12 +1000,600,200,200,1.0,3.427372535086404e-12 +1000,800,200,200,1.0,3.427372535086404e-12 +1000,1000,200,200,1.0,3.427372535086404e-12 +1000,1200,200,200,1.0,3.427372535086404e-12 +1000,1400,200,200,1.0,3.427372535086404e-12 +1000,1600,200,200,1.0,3.427372535086404e-12 +1000,1800,200,200,1.0,3.427372535086404e-12 +1000,2000,200,200,1.0,3.427372535086404e-12 +1000,2200,200,200,1.0,3.427372535086404e-12 +1000,2400,200,200,1.0,3.427372535086404e-12 +1000,2600,200,200,1.0,3.427372535086404e-12 +1000,2800,200,200,1.0,3.427372535086404e-12 +1000,3000,200,200,1.0,3.427372535086404e-12 +1000,3200,200,200,1.0,3.427372535086404e-12 +1000,3400,200,200,1.0,3.427372535086404e-12 +1000,3600,200,200,1.0,3.427372535086404e-12 +1000,3800,200,200,1.0,3.427372535086404e-12 +1000,4000,200,200,0.9999136924743652,8.626562339486554e-05 +1200,0,200,200,1.0,3.427372535086404e-12 +1200,200,200,200,1.0,3.427372535086404e-12 +1200,400,200,200,1.0,3.427372535086404e-12 +1200,600,200,200,1.0,3.427372535086404e-12 +1200,800,200,200,1.0,3.427372535086404e-12 +1200,1000,200,200,1.0,3.427372535086404e-12 +1200,1200,200,200,1.0,3.427372535086404e-12 +1200,1400,200,200,1.0,3.427372535086404e-12 +1200,1600,200,200,1.0,3.427372535086404e-12 +1200,1800,200,200,1.0,3.427372535086404e-12 +1200,2000,200,200,1.0,3.427372535086404e-12 +1200,2200,200,200,1.0,3.427372535086404e-12 +1200,2400,200,200,1.0,3.427372535086404e-12 +1200,2600,200,200,1.0,3.427372535086404e-12 +1200,2800,200,200,1.0,3.427372535086404e-12 +1200,3000,200,200,1.0,3.427372535086404e-12 +1200,3200,200,200,1.0,3.427372535086404e-12 +1200,3400,200,200,1.0,3.427372535086404e-12 +1200,3600,200,200,1.0,3.427372535086404e-12 +1200,3800,200,200,1.0,3.427372535086404e-12 +1200,4000,200,200,0.9999136924743652,8.626562339486554e-05 +1400,0,200,200,1.0,3.427372535086404e-12 +1400,200,200,200,1.0,3.427372535086404e-12 
+1400,400,200,200,1.0,3.427372535086404e-12 +1400,600,200,200,1.0,3.427372535086404e-12 +1400,800,200,200,1.0,3.427372535086404e-12 +1400,1000,200,200,1.0,3.427372535086404e-12 +1400,1200,200,200,1.0,3.427372535086404e-12 +1400,1400,200,200,1.0,3.427372535086404e-12 +1400,1600,200,200,1.0,3.427372535086404e-12 +1400,1800,200,200,1.0,3.427372535086404e-12 +1400,2000,200,200,1.0,3.427372535086404e-12 +1400,2200,200,200,1.0,3.427372535086404e-12 +1400,2400,200,200,1.0,3.427372535086404e-12 +1400,2600,200,200,1.0,3.427372535086404e-12 +1400,2800,200,200,1.0,3.427372535086404e-12 +1400,3000,200,200,1.0,3.427372535086404e-12 +1400,3200,200,200,1.0,3.427372535086404e-12 +1400,3400,200,200,1.0,3.427372535086404e-12 +1400,3600,200,200,1.0,3.427372535086404e-12 +1400,3800,200,200,1.0,3.427372535086404e-12 +1400,4000,200,200,0.9999136924743652,8.626562339486554e-05 +1600,0,200,200,1.0,3.427372535086404e-12 +1600,200,200,200,1.0,3.427372535086404e-12 +1600,400,200,200,1.0,3.427372535086404e-12 +1600,600,200,200,1.0,3.427372535086404e-12 +1600,800,200,200,1.0,3.427372535086404e-12 +1600,1000,200,200,1.0,3.427372535086404e-12 +1600,1200,200,200,1.0,3.427372535086404e-12 +1600,1400,200,200,1.0,3.427372535086404e-12 +1600,1600,200,200,1.0,3.427372535086404e-12 +1600,1800,200,200,1.0,3.427372535086404e-12 +1600,2000,200,200,1.0,3.427372535086404e-12 +1600,2200,200,200,1.0,3.427372535086404e-12 +1600,2400,200,200,1.0,3.427372535086404e-12 +1600,2600,200,200,1.0,3.427372535086404e-12 +1600,2800,200,200,1.0,3.427372535086404e-12 +1600,3000,200,200,1.0,3.427372535086404e-12 +1600,3200,200,200,1.0,3.427372535086404e-12 +1600,3400,200,200,1.0,3.427372535086404e-12 +1600,3600,200,200,1.0,3.427372535086404e-12 +1600,3800,200,200,1.0,3.427372535086404e-12 +1600,4000,200,200,0.9999136924743652,8.626562339486554e-05 +1800,0,200,200,1.0,3.427372535086404e-12 +1800,200,200,200,1.0,3.427372535086404e-12 +1800,400,200,200,1.0,3.427372535086404e-12 +1800,600,200,200,1.0,3.427372535086404e-12 +1800,800,200,200,1.0,3.427372535086404e-12 +1800,1000,200,200,1.0,3.427372535086404e-12 +1800,1200,200,200,1.0,3.427372535086404e-12 +1800,1400,200,200,1.0,3.427372535086404e-12 +1800,1600,200,200,1.0,3.427372535086404e-12 +1800,1800,200,200,1.0,3.427372535086404e-12 +1800,2000,200,200,1.0,3.427372535086404e-12 +1800,2200,200,200,1.0,3.427372535086404e-12 +1800,2400,200,200,1.0,3.427372535086404e-12 +1800,2600,200,200,1.0,3.427372535086404e-12 +1800,2800,200,200,1.0,3.427372535086404e-12 +1800,3000,200,200,1.0,3.427372535086404e-12 +1800,3200,200,200,1.0,3.427372535086404e-12 +1800,3400,200,200,1.0,3.427372535086404e-12 +1800,3600,200,200,1.0,3.427372535086404e-12 +1800,3800,200,200,1.0,3.427372535086404e-12 +1800,4000,200,200,0.9999136924743652,8.626562339486554e-05 +2000,0,200,200,1.0,3.427372535086404e-12 +2000,200,200,200,1.0,3.427372535086404e-12 +2000,400,200,200,1.0,3.427372535086404e-12 +2000,600,200,200,1.0,3.427372535086404e-12 +2000,800,200,200,1.0,3.427372535086404e-12 +2000,1000,200,200,1.0,3.427372535086404e-12 +2000,1200,200,200,1.0,3.427372535086404e-12 +2000,1400,200,200,1.0,3.427372535086404e-12 +2000,1600,200,200,1.0,3.427372535086404e-12 +2000,1800,200,200,1.0,3.427372535086404e-12 +2000,2000,200,200,1.0,3.427372535086404e-12 +2000,2200,200,200,1.0,3.427372535086404e-12 +2000,2400,200,200,1.0,3.427372535086404e-12 +2000,2600,200,200,1.0,3.427372535086404e-12 +2000,2800,200,200,1.0,3.427372535086404e-12 +2000,3000,200,200,1.0,3.427372535086404e-12 +2000,3200,200,200,1.0,3.427372535086404e-12 
+2000,3400,200,200,1.0,3.427372535086404e-12 +2000,3600,200,200,1.0,3.427372535086404e-12 +2000,3800,200,200,1.0,3.427372535086404e-12 +2000,4000,200,200,0.9999136924743652,8.626562339486554e-05 +2200,0,200,200,1.0,3.427372535086404e-12 +2200,200,200,200,1.0,3.427372535086404e-12 +2200,400,200,200,1.0,3.427372535086404e-12 +2200,600,200,200,1.0,3.427372535086404e-12 +2200,800,200,200,1.0,3.427372535086404e-12 +2200,1000,200,200,1.0,3.427372535086404e-12 +2200,1200,200,200,1.0,3.427372535086404e-12 +2200,1400,200,200,1.0,3.427372535086404e-12 +2200,1600,200,200,1.0,3.427372535086404e-12 +2200,1800,200,200,1.0,3.427372535086404e-12 +2200,2000,200,200,1.0,3.427372535086404e-12 +2200,2200,200,200,1.0,3.427372535086404e-12 +2200,2400,200,200,1.0,3.427372535086404e-12 +2200,2600,200,200,1.0,3.427372535086404e-12 +2200,2800,200,200,1.0,3.427372535086404e-12 +2200,3000,200,200,1.0,3.427372535086404e-12 +2200,3200,200,200,1.0,3.427372535086404e-12 +2200,3400,200,200,1.0,3.427372535086404e-12 +2200,3600,200,200,1.0,3.427372535086404e-12 +2200,3800,200,200,1.0,3.427372535086404e-12 +2200,4000,200,200,0.9999136924743652,8.626562339486554e-05 +2400,0,200,200,1.0,3.427372535086404e-12 +2400,200,200,200,1.0,3.427372535086404e-12 +2400,400,200,200,1.0,3.427372535086404e-12 +2400,600,200,200,1.0,3.427372535086404e-12 +2400,800,200,200,1.0,3.427372535086404e-12 +2400,1000,200,200,1.0,3.427372535086404e-12 +2400,1200,200,200,1.0,3.427372535086404e-12 +2400,1400,200,200,1.0,3.427372535086404e-12 +2400,1600,200,200,1.0,3.427372535086404e-12 +2400,1800,200,200,1.0,3.427372535086404e-12 +2400,2000,200,200,1.0,3.427372535086404e-12 +2400,2200,200,200,1.0,3.427372535086404e-12 +2400,2400,200,200,1.0,3.427372535086404e-12 +2400,2600,200,200,1.0,3.427372535086404e-12 +2400,2800,200,200,1.0,3.427372535086404e-12 +2400,3000,200,200,1.0,3.427372535086404e-12 +2400,3200,200,200,1.0,3.427372535086404e-12 +2400,3400,200,200,1.0,3.427372535086404e-12 +2400,3600,200,200,1.0,3.427372535086404e-12 +2400,3800,200,200,1.0,3.427372535086404e-12 +2400,4000,200,200,0.9999136924743652,8.626562339486554e-05 +2600,0,200,200,1.0,3.427372535086404e-12 +2600,200,200,200,1.0,3.427372535086404e-12 +2600,400,200,200,1.0,3.427372535086404e-12 +2600,600,200,200,1.0,3.427372535086404e-12 +2600,800,200,200,1.0,3.427372535086404e-12 +2600,1000,200,200,1.0,3.427372535086404e-12 +2600,1200,200,200,1.0,3.427372535086404e-12 +2600,1400,200,200,1.0,3.427372535086404e-12 +2600,1600,200,200,1.0,3.427372535086404e-12 +2600,1800,200,200,1.0,3.427372535086404e-12 +2600,2000,200,200,1.0,3.427372535086404e-12 +2600,2200,200,200,1.0,3.427372535086404e-12 +2600,2400,200,200,1.0,3.427372535086404e-12 +2600,2600,200,200,1.0,3.427372535086404e-12 +2600,2800,200,200,1.0,3.427372535086404e-12 +2600,3000,200,200,1.0,3.427372535086404e-12 +2600,3200,200,200,1.0,3.427372535086404e-12 +2600,3400,200,200,1.0,3.427372535086404e-12 +2600,3600,200,200,1.0,3.427372535086404e-12 +2600,3800,200,200,1.0,3.427372535086404e-12 +2600,4000,200,200,0.9999136924743652,8.626562339486554e-05 +2800,0,200,200,1.0,3.427372535086404e-12 +2800,200,200,200,1.0,3.427372535086404e-12 +2800,400,200,200,1.0,3.427372535086404e-12 +2800,600,200,200,1.0,3.427372535086404e-12 +2800,800,200,200,1.0,3.427372535086404e-12 +2800,1000,200,200,1.0,3.427372535086404e-12 +2800,1200,200,200,1.0,3.427372535086404e-12 +2800,1400,200,200,1.0,3.427372535086404e-12 +2800,1600,200,200,1.0,3.427372535086404e-12 +2800,1800,200,200,1.0,3.427372535086404e-12 +2800,2000,200,200,1.0,3.427372535086404e-12 
+2800,2200,200,200,1.0,3.427372535086404e-12 +2800,2400,200,200,1.0,3.427372535086404e-12 +2800,2600,200,200,1.0,3.427372535086404e-12 +2800,2800,200,200,1.0,3.427372535086404e-12 +2800,3000,200,200,1.0,3.427372535086404e-12 +2800,3200,200,200,1.0,3.427372535086404e-12 +2800,3400,200,200,1.0,3.427372535086404e-12 +2800,3600,200,200,1.0,3.427372535086404e-12 +2800,3800,200,200,1.0,3.427372535086404e-12 +2800,4000,200,200,0.9999136924743652,8.626562339486554e-05 +3000,0,200,200,1.0,3.427372535086404e-12 +3000,200,200,200,1.0,3.427372535086404e-12 +3000,400,200,200,1.0,3.427372535086404e-12 +3000,600,200,200,1.0,3.427372535086404e-12 +3000,800,200,200,1.0,3.427372535086404e-12 +3000,1000,200,200,1.0,3.427372535086404e-12 +3000,1200,200,200,1.0,3.427372535086404e-12 +3000,1400,200,200,1.0,3.427372535086404e-12 +3000,1600,200,200,1.0,3.427372535086404e-12 +3000,1800,200,200,1.0,3.427372535086404e-12 +3000,2000,200,200,1.0,3.427372535086404e-12 +3000,2200,200,200,1.0,3.427372535086404e-12 +3000,2400,200,200,1.0,3.427372535086404e-12 +3000,2600,200,200,1.0,3.427372535086404e-12 +3000,2800,200,200,1.0,3.427372535086404e-12 +3000,3000,200,200,1.0,3.427372535086404e-12 +3000,3200,200,200,1.0,3.427372535086404e-12 +3000,3400,200,200,1.0,3.427372535086404e-12 +3000,3600,200,200,1.0,3.427372535086404e-12 +3000,3800,200,200,1.0,3.427372535086404e-12 +3000,4000,200,200,0.9999136924743652,8.626562339486554e-05 +3200,0,200,200,1.0,3.427372535086404e-12 +3200,200,200,200,1.0,3.427372535086404e-12 +3200,400,200,200,1.0,3.427372535086404e-12 +3200,600,200,200,1.0,3.427372535086404e-12 +3200,800,200,200,1.0,3.427372535086404e-12 +3200,1000,200,200,1.0,3.427372535086404e-12 +3200,1200,200,200,1.0,3.427372535086404e-12 +3200,1400,200,200,1.0,3.427372535086404e-12 +3200,1600,200,200,1.0,3.427372535086404e-12 +3200,1800,200,200,1.0,3.427372535086404e-12 +3200,2000,200,200,1.0,3.427372535086404e-12 +3200,2200,200,200,1.0,3.427372535086404e-12 +3200,2400,200,200,1.0,3.427372535086404e-12 +3200,2600,200,200,1.0,3.427372535086404e-12 +3200,2800,200,200,1.0,3.427372535086404e-12 +3200,3000,200,200,1.0,3.427372535086404e-12 +3200,3200,200,200,1.0,3.427372535086404e-12 +3200,3400,200,200,1.0,3.427372535086404e-12 +3200,3600,200,200,1.0,3.427372535086404e-12 +3200,3800,200,200,1.0,3.427372535086404e-12 +3200,4000,200,200,0.9999136924743652,8.626562339486554e-05 +3400,0,200,200,1.0,3.427372535086404e-12 +3400,200,200,200,1.0,3.427372535086404e-12 +3400,400,200,200,1.0,3.427372535086404e-12 +3400,600,200,200,1.0,3.427372535086404e-12 +3400,800,200,200,1.0,3.427372535086404e-12 +3400,1000,200,200,1.0,3.427372535086404e-12 +3400,1200,200,200,1.0,3.427372535086404e-12 +3400,1400,200,200,1.0,3.427372535086404e-12 +3400,1600,200,200,1.0,3.427372535086404e-12 +3400,1800,200,200,1.0,3.427372535086404e-12 +3400,2000,200,200,1.0,3.427372535086404e-12 +3400,2200,200,200,1.0,3.427372535086404e-12 +3400,2400,200,200,1.0,3.427372535086404e-12 +3400,2600,200,200,1.0,3.427372535086404e-12 +3400,2800,200,200,1.0,3.427372535086404e-12 +3400,3000,200,200,1.0,3.427372535086404e-12 +3400,3200,200,200,1.0,3.427372535086404e-12 +3400,3400,200,200,1.0,3.427372535086404e-12 +3400,3600,200,200,1.0,3.427372535086404e-12 +3400,3800,200,200,1.0,3.427372535086404e-12 +3400,4000,200,200,0.9999136924743652,8.626562339486554e-05 +3600,0,200,200,1.0,3.427372535086404e-12 +3600,200,200,200,1.0,3.427372535086404e-12 +3600,400,200,200,1.0,3.427372535086404e-12 +3600,600,200,200,1.0,3.427372535086404e-12 +3600,800,200,200,1.0,3.427372535086404e-12 
+3600,1000,200,200,1.0,3.427372535086404e-12 +3600,1200,200,200,1.0,3.427372535086404e-12 +3600,1400,200,200,1.0,3.427372535086404e-12 +3600,1600,200,200,1.0,3.427372535086404e-12 +3600,1800,200,200,1.0,3.427372535086404e-12 +3600,2000,200,200,1.0,3.427372535086404e-12 +3600,2200,200,200,1.0,3.427372535086404e-12 +3600,2400,200,200,1.0,3.427372535086404e-12 +3600,2600,200,200,1.0,3.427372535086404e-12 +3600,2800,200,200,1.0,3.427372535086404e-12 +3600,3000,200,200,1.0,3.427372535086404e-12 +3600,3200,200,200,1.0,3.427372535086404e-12 +3600,3400,200,200,1.0,3.427372535086404e-12 +3600,3600,200,200,1.0,3.427372535086404e-12 +3600,3800,200,200,1.0,3.427372535086404e-12 +3600,4000,200,200,0.9999136924743652,8.626562339486554e-05 +3800,0,200,200,1.0,3.427372535086404e-12 +3800,200,200,200,1.0,3.427372535086404e-12 +3800,400,200,200,1.0,3.427372535086404e-12 +3800,600,200,200,1.0,3.427372535086404e-12 +3800,800,200,200,1.0,3.427372535086404e-12 +3800,1000,200,200,1.0,3.427372535086404e-12 +3800,1200,200,200,1.0,3.427372535086404e-12 +3800,1400,200,200,1.0,3.427372535086404e-12 +3800,1600,200,200,1.0,3.427372535086404e-12 +3800,1800,200,200,1.0,3.427372535086404e-12 +3800,2000,200,200,1.0,3.427372535086404e-12 +3800,2200,200,200,1.0,3.427372535086404e-12 +3800,2400,200,200,1.0,3.427372535086404e-12 +3800,2600,200,200,1.0,3.427372535086404e-12 +3800,2800,200,200,1.0,3.427372535086404e-12 +3800,3000,200,200,1.0,3.427372535086404e-12 +3800,3200,200,200,1.0,3.427372535086404e-12 +3800,3400,200,200,1.0,3.427372535086404e-12 +3800,3600,200,200,1.0,3.427372535086404e-12 +3800,3800,200,200,1.0,3.427372535086404e-12 +3800,4000,200,200,0.9999136924743652,8.626562339486554e-05 +4000,0,200,200,0.9996557235717772,0.0003442850720603 +4000,200,200,200,0.9996557235717772,0.0003442850720603 +4000,400,200,200,0.9996557235717772,0.0003442850720603 +4000,600,200,200,0.9996557235717772,0.0003442850720603 +4000,800,200,200,0.9996557235717772,0.0003442850720603 +4000,1000,200,200,0.9996557235717772,0.0003442850720603 +4000,1200,200,200,0.9996557235717772,0.0003442850720603 +4000,1400,200,200,0.9996557235717772,0.0003442850720603 +4000,1600,200,200,0.9996557235717772,0.0003442850720603 +4000,1800,200,200,0.9996557235717772,0.0003442850720603 +4000,2000,200,200,0.9996557235717772,0.0003442850720603 +4000,2200,200,200,0.9996557235717772,0.0003442850720603 +4000,2400,200,200,0.9996557235717772,0.0003442850720603 +4000,2600,200,200,0.9996557235717772,0.0003442850720603 +4000,2800,200,200,0.9996557235717772,0.0003442850720603 +4000,3000,200,200,0.9996557235717772,0.0003442850720603 +4000,3200,200,200,0.9996557235717772,0.0003442850720603 +4000,3400,200,200,0.9996557235717772,0.0003442850720603 +4000,3600,200,200,0.9996557235717772,0.0003442850720603 +4000,3800,200,200,0.9996557235717772,0.0003442850720603 +4000,4000,200,200,0.9894591569900512,0.0105408066883683 diff --git a/tests/reference/pancreas-tumor-preactresnet34.tcga-paad/purple.csv b/tests/reference/pancreas-tumor-preactresnet34.tcga-paad/purple.csv new file mode 100644 index 0000000..5fb7af0 --- /dev/null +++ b/tests/reference/pancreas-tumor-preactresnet34.tcga-paad/purple.csv @@ -0,0 +1,5 @@ +minx,miny,width,height,prob_tumor +0,0,2100,2100,0.014464837 +0,2100,2100,2100,0.008648535 +2100,0,2100,2100,0.019902095 +2100,2100,2100,2100,0.013171201 diff --git a/tests/reference/prostate-tumor-resnet34.tcga-prad/purple.csv b/tests/reference/prostate-tumor-resnet34.tcga-prad/purple.csv new file mode 100644 index 0000000..13b4264 --- /dev/null +++ 
b/tests/reference/prostate-tumor-resnet34.tcga-prad/purple.csv @@ -0,0 +1,145 @@ +minx,miny,width,height,prob_grade3,prob_grade4+5,prob_benign +0,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +0,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +350,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +350,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +700,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +700,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +1050,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 
+1050,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1050,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +1400,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1400,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +1750,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +1750,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +2100,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2100,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +2450,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 
+2450,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2450,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +2800,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +2800,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +3150,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,2800,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,3150,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,3500,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3150,3850,350,350,0.0012054900871589,8.633837569504976e-05,0.9987081289291382 +3500,0,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,350,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,700,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,1050,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,1400,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,1750,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,2100,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,2450,350,350,0.0010944225359708,3.372010542079806e-05,0.9988718628883362 +3500,2800,350,350,0.0010944199748337,3.371991260792129e-05,0.9988718628883362 +3500,3150,350,350,0.0010944199748337,3.371991260792129e-05,0.9988718628883362 +3500,3500,350,350,0.0010944199748337,3.371991260792129e-05,0.9988718628883362 +3500,3850,350,350,0.0012054842663928,8.633788820588961e-05,0.9987081289291382 +3850,0,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,350,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 
+3850,700,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,1050,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,1400,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,1750,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,2100,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,2450,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,2800,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,3150,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,3500,350,350,0.0002117624972015,3.461315645836294e-05,0.9997536540031432 +3850,3850,350,350,0.0006592872668989,5.078692265669815e-05,0.999289870262146 diff --git a/tests/test_all.py b/tests/test_all.py index fa9c101..4c63409 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,14 +1,12 @@ +from __future__ import annotations + import json -import math import os -from pathlib import Path import platform -import subprocess import sys import time -from typing import List +from pathlib import Path -from click.testing import CliRunner import geojson as geojsonlib import h5py import numpy as np @@ -16,10 +14,13 @@ import pytest import tifffile import torch -import yaml +from click.testing import CliRunner -from wsinfer import get_model_weights -from wsinfer import list_all_models_and_weights +from wsinfer.cli.cli import cli +from wsinfer.cli.infer import _get_info_for_save +from wsinfer.modellib.models import get_pretrained_torch_module +from wsinfer.modellib.models import get_registered_model +from wsinfer.modellib.run_inference import jit_compile @pytest.fixture @@ -29,298 +30,84 @@ def tiff_image(tmp_path: Path) -> Path: path = Path(tmp_path / "images" / "purple.tif") path.parent.mkdir(exist_ok=True) - if sys.version_info >= (3, 8): - tifffile.imwrite( - path, - data=x, - compression="zlib", - tile=(256, 256), - # 0.25 micrometers per pixel. - resolution=(40000, 40000), - resolutionunit=tifffile.RESUNIT.CENTIMETER, - ) - else: - # Earlier versions of tifffile do not have resolutionunit kwarg. - tifffile.imwrite( - path, - data=x, - compression="zlib", - tile=(256, 256), - # 0.25 micrometers per pixel. - resolution=(40000, 40000, "CENTIMETER"), - ) - - return path - - -def test_cli_list(tmp_path: Path): - from wsinfer.cli.cli import cli - - runner = CliRunner() - result = runner.invoke(cli, ["list"]) - assert "resnet34" in result.output - assert "TCGA-BRCA-v1" in result.output - assert result.exit_code == 0 - - # Test of WSINFER_PATH registration... check that the models appear in list. - # Test of single WSINFER_PATH. 
- config_root_single = tmp_path / "configs-single" - config_root_single.mkdir() - configs = [ - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - dict( - version="1.0", - name="foo2", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - ] - for i, config in enumerate(configs): - with open(config_root_single / f"{i}.yaml", "w") as f: - yaml.safe_dump(config, f) - - ret = subprocess.run( - [sys.executable, "-m", "wsinfer", "list"], - capture_output=True, - env=dict(WSINFER_PATH=str(config_root_single)), - ) - assert ret.returncode == 0 - output = ret.stdout.decode() - assert configs[0]["name"] in output # type: ignore - assert configs[0]["architecture"] in output # type: ignore - assert configs[1]["name"] in output # type: ignore - assert configs[1]["architecture"] in output # type: ignore - # Negative control. - ret = subprocess.run([sys.executable, "-m", "wsinfer", "list"], capture_output=True) - assert configs[0]["name"] not in ret.stdout.decode() # type: ignore - del config_root_single, output, ret, config - - # Test of WSINFER_PATH registration... check that the models appear in list. - # Test of multiple WSINFER_PATH. - config_root = tmp_path / "configs" - config_root.mkdir() - config_paths = [config_root / "0", config_root / "1"] - for i, config in enumerate(configs): - config_paths[i].mkdir() - with open(config_paths[i] / f"{i}.yaml", "w") as f: - yaml.safe_dump(config, f) - - ret = subprocess.run( - [sys.executable, "-m", "wsinfer", "list"], - capture_output=True, - env=dict(WSINFER_PATH=":".join(str(c) for c in config_paths)), + tifffile.imwrite( + path, + data=x, + compression="zlib", + tile=(256, 256), + # 0.25 micrometers per pixel. + resolution=(40_000, 40_000), + resolutionunit=tifffile.RESUNIT.CENTIMETER, ) - assert ret.returncode == 0 - output = ret.stdout.decode() - assert configs[0]["name"] in output # type: ignore - assert configs[0]["architecture"] in output # type: ignore - assert configs[1]["name"] in output # type: ignore - assert configs[1]["architecture"] in output # type: ignore - ret = subprocess.run([sys.executable, "-m", "wsinfer", "list"], capture_output=True) - assert configs[0]["name"] not in ret.stdout.decode() # type: ignore - - -def test_cli_run_args(tmp_path: Path): - """Test that (model and weights) or config is required.""" - from wsinfer.cli.cli import cli - - wsi_dir = tmp_path / "slides" - wsi_dir.mkdir() - - runner = CliRunner() - args = [ - "run", - "--wsi-dir", - str(wsi_dir), - "--results-dir", - str(tmp_path / "results"), - ] - # No model, weights, or config. - result = runner.invoke(cli, args) - assert result.exit_code != 0 - assert "one of (model and weights) or config is required." in result.output - # Only one of model and weights. - result = runner.invoke(cli, [*args, "--model", "resnet34"]) - assert result.exit_code != 0 - assert "model and weights must both be set if one is set." in result.output - result = runner.invoke(cli, [*args, "--weights", "TCGA-BRCA-v1"]) - assert result.exit_code != 0 - assert "model and weights must both be set if one is set." 
in result.output - - # config and model - result = runner.invoke(cli, [*args, "--config", __file__, "--model", "resnet34"]) - assert result.exit_code != 0 - assert "model and weights are mutually exclusive with config." in result.output - # config and weights - result = runner.invoke( - cli, [*args, "--config", __file__, "--weights", "TCGA-BRCA-v1"] - ) - assert result.exit_code != 0 - assert "model and weights are mutually exclusive with config." in result.output + return path +# The reference data for this test was made using a patched version of wsinfer 0.3.6. +# The patches fixed an issue when calculating strides and added padding to images. +# Large-image (which was the backend in 0.3.6) did not pad images and would return +# tiles that were not fully the requested width and height. @pytest.mark.parametrize( + "model", [ - "model", - "weights", - "class_names", - "expected_probs", - "expected_patch_size", - "expected_num_patches", - ], - [ - # Resnet34 TCGA-BRCA-v1 - ( - "resnet34", - "TCGA-BRCA-v1", - ["notumor", "tumor"], - [0.9525967836380005, 0.04740329459309578], - 350, - 144, - ), - # Resnet34 TCGA-LUAD-v1 - ( - "resnet34", - "TCGA-LUAD-v1", - ["lepidic", "benign", "acinar", "micropapillary", "mucinous", "solid"], - [ - 0.012793001718819141, - 0.9792948961257935, - 0.0050891609862446785, - 0.0003837027761619538, - 0.0006556913140229881, - 0.0017834495520219207, - ], - 700, - 36, - ), - # Resnet34 TCGA-PRAD-v1 - ( - "resnet34", - "TCGA-PRAD-v1", - ["grade3", "grade4or5", "benign"], - [0.0010944147361442447, 3.371985076228157e-05, 0.9988718628883362], - 350, - 144, - ), - # Inception_v4 TCGA-BRCA-v1 - ( - "inception_v4", - "TCGA-BRCA-v1", - ["notumor", "tumor"], - [0.9564113020896912, 0.043588679283857346], - 350, - 144, - ), - # Inceptionv4nobn TCGA-TILs-v1 - ( - "inception_v4nobn", - "TCGA-TILs-v1", - ["notils", "tils"], - [1.0, 3.427359524660334e-12], - 200, - 441, - ), - # VGG16 TCGA-TILs-v1 - ( - "vgg16", - "TCGA-TILs-v1", - ["notils", "tils"], - [0.9987693428993224, 0.0012305785203352], - 200, - 441, - ), - # Vgg16mod TCGA-BRCA-v1 - ( - "vgg16mod", - "TCGA-BRCA-v1", - ["notumor", "tumor"], - [0.9108286499977112, 0.089171402156353], - 350, - 144, - ), - # Preactresnet34 TCGA-PAAD-v1 - ( - "preactresnet34", - "TCGA-PAAD-v1", - ["tumor"], - [0.01446483], - 2100, - 4, - ), + "breast-tumor-resnet34.tcga-brca", + "breast-tumor-inception_v4.tcga-brca", + "breast-tumor-vgg16mod.tcga-brca", + "lung-tumor-resnet34.tcga-luad", + "pancancer-lymphocytes-inceptionv4.tcga", + "pancreas-tumor-preactresnet34.tcga-paad", + "prostate-tumor-resnet34.tcga-prad", ], ) @pytest.mark.parametrize("speedup", [False, True]) -def test_cli_run_regression( +@pytest.mark.parametrize("backend", ["openslide", "tiffslide"]) +def test_cli_run_with_registered_models( model: str, - weights: str, - class_names: List[str], - expected_probs: List[float], - expected_patch_size: int, - expected_num_patches: int, speedup: bool, + backend: str, tiff_image: Path, tmp_path: Path, ): - """A regression test of the command 'wsinfer run', using all registered models.""" - from wsinfer.cli.cli import cli + """A regression test of the command 'wsinfer run'.""" + + reference_csv = Path(__file__).parent / "reference" / model / "purple.csv" + if not reference_csv.exists(): + raise FileNotFoundError(f"reference CSV not found: {reference_csv}") runner = CliRunner() results_dir = tmp_path / "inference" result = runner.invoke( cli, [ + "--backend", + backend, "run", "--wsi-dir", str(tiff_image.parent), - "--model", - model, - 
"--weights", - weights, "--results-dir", str(results_dir), + "--model", + model, "--speedup" if speedup else "--no-speedup", ], ) assert result.exit_code == 0 assert (results_dir / "model-outputs").exists() df = pd.read_csv(results_dir / "model-outputs" / "purple.csv") - class_prob_cols = [f"prob_{c}" for c in class_names] - assert df.columns.tolist() == [ - "slide", - "minx", - "miny", - "width", - "height", - *class_prob_cols, - ] - # TODO: test the metadata.json file as well. - assert df.shape[0] == expected_num_patches - assert (df.loc[:, "slide"] == str(tiff_image)).all() - assert (df.loc[:, "width"] == expected_patch_size).all() - assert (df.loc[:, "height"] == expected_patch_size).all() - # Test probs. - for col, col_prob in zip(class_names, expected_probs): - col = f"prob_{col}" - assert np.allclose(df.loc[:, col], col_prob) + df_ref = pd.read_csv(reference_csv) + + assert set(df.columns) == set(df_ref.columns) + assert df.shape == df_ref.shape + assert np.array_equal(df["minx"], df_ref["minx"]) + assert np.array_equal(df["miny"], df_ref["miny"]) + assert np.array_equal(df["width"], df_ref["width"]) + assert np.array_equal(df["height"], df_ref["height"]) + + prob_cols = df_ref.filter(like="prob_").columns.tolist() + for prob_col in prob_cols: + assert np.allclose( + df[prob_col], df_ref[prob_col], atol=1e-07 + ), f"Column {prob_col} not allclose at atol=1e-07" # Test that metadata path exists. metadata_paths = list(results_dir.glob("run_metadata_*.json")) @@ -329,40 +116,39 @@ def test_cli_run_regression( assert metadata_path.exists() with open(metadata_path) as f: meta = json.load(f) - assert meta.keys() == {"model_weights", "runtime", "timestamp"} - assert meta["model_weights"]["name"] == weights - assert meta["model_weights"]["architecture"] == model - assert meta["model_weights"]["class_names"] == class_names + assert set(meta.keys()) == {"model", "runtime", "timestamp"} + assert "config" in meta["model"] + assert "huggingface_location" in meta["model"] + assert model in meta["model"]["huggingface_location"]["repo_id"] assert meta["runtime"]["python_executable"] == sys.executable assert meta["runtime"]["python_version"] == platform.python_version() assert meta["timestamp"] del metadata_path, meta - # Test conversion scripts. + # Test conversion to geojson. geojson_dir = results_dir / "geojson" result = runner.invoke(cli, ["togeojson", str(results_dir), str(geojson_dir)]) assert result.exit_code == 0 with open(geojson_dir / "purple.json") as f: d: geojsonlib.GeoJSON = geojsonlib.load(f) assert d.is_valid, "geojson not valid!" - assert len(d["features"]) == expected_num_patches + assert len(d["features"]) == len(df_ref) for geojson_row in d["features"]: assert geojson_row["type"] == "Feature" assert geojson_row["id"] == "PathTileObject" assert geojson_row["geometry"]["type"] == "Polygon" - # Check the probability values. - for i, prob in enumerate(expected_probs): - # names have the prefix "prob_". - assert all( - dd["properties"]["measurements"][i]["name"] == class_prob_cols[i] - for dd in d["features"] - ) - assert all( - np.allclose(dd["properties"]["measurements"][i]["value"], prob) - for dd in d["features"] + res = [] + for i, _ in enumerate(prob_cols): + res.append( + np.array( + [dd["properties"]["measurements"][i]["value"] for dd in d["features"]] + ) ) + geojson_probs = np.stack(res, axis=0) + del res + assert np.allclose(df[prob_cols].T, geojson_probs) # Check the coordinate values. 
for df_row, geojson_row in zip(df.itertuples(), d["features"]): @@ -378,535 +164,140 @@ def test_cli_run_regression( assert [df_coords] == geojson_row["geometry"]["coordinates"] -@pytest.mark.xfail -def test_convert_to_sbu(): - # TODO: create a synthetic output and then convert it. Check that it is valid. - assert False - - -def test_cli_run_from_config(tiff_image: Path, tmp_path: Path): - """This is a form of a regression test.""" - import wsinfer - from wsinfer.cli.cli import cli +def test_cli_run_with_local_model(tmp_path: Path, tiff_image: Path): + model = "breast-tumor-resnet34.tcga-brca" + reference_csv = Path(__file__).parent / "reference" / model / "purple.csv" + if not reference_csv.exists(): + raise FileNotFoundError(f"reference CSV not found: {reference_csv}") + w = get_registered_model(model) + + config = { + "spec_version": "1.0", + "architecture": "resnet34", + "num_classes": 2, + "class_names": ["notumor", "tumor"], + "patch_size_pixels": 350, + "spacing_um_px": 0.25, + "transform": [ + {"name": "Resize", "arguments": {"size": 224}}, + {"name": "ToTensor"}, + { + "name": "Normalize", + "arguments": { + "mean": [0.7238, 0.5716, 0.6779], + "std": [0.112, 0.1459, 0.1089], + }, + }, + ], + } - # Use config for resnet34 TCGA-BRCA-v1 weights. - config = Path(wsinfer.__file__).parent / "modeldefs" / "resnet34_tcga-brca-v1.yaml" - assert config.exists() + config_path = tmp_path / "config.json" + with open(config_path, "w") as f: + json.dump(config, f) runner = CliRunner() results_dir = tmp_path / "inference" result = runner.invoke( cli, [ + "--backend", + "openslide", "run", "--wsi-dir", str(tiff_image.parent), - "--config", - str(config), "--results-dir", str(results_dir), + "--model-path", + w.model_path, + "--config", + str(config_path), ], ) assert result.exit_code == 0 assert (results_dir / "model-outputs").exists() df = pd.read_csv(results_dir / "model-outputs" / "purple.csv") - assert df.columns.tolist() == [ - "slide", - "minx", - "miny", - "width", - "height", - "prob_notumor", - "prob_tumor", - ] - assert (df.loc[:, "slide"] == str(tiff_image)).all() - assert (df.loc[:, "width"] == 350).all() - assert (df.loc[:, "height"] == 350).all() - assert (df.loc[:, "width"] == 350).all() - assert np.allclose(df.loc[:, "prob_notumor"], 0.9525967836380005) - assert np.allclose(df.loc[:, "prob_tumor"], 0.04740329459309578) + df_ref = pd.read_csv(reference_csv) + assert set(df.columns) == set(df_ref.columns) + assert df.shape == df_ref.shape + assert np.array_equal(df["minx"], df_ref["minx"]) + assert np.array_equal(df["miny"], df_ref["miny"]) + assert np.array_equal(df["width"], df_ref["width"]) + assert np.array_equal(df["height"], df_ref["height"]) -@pytest.mark.parametrize( - "modeldef", - [ - [], - {}, - dict(name="foo", architecture="resnet34"), - # Missing url - dict( - version="1.0", - name="foo", - architecture="resnet34", - # url="foo", - # url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # missing url_file_name when url is given - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - # url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # url and file used together - dict( - version="1.0", - name="foo", - architecture="resnet34", - file=__file__, - url="foo", - 
url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # nonexistent file - dict( - version="1.0", - name="foo", - architecture="resnet34", - file="path/to/fake/file", - # url="foo", - # url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # num_classes missing - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - # num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # num classes not equal to len of class names - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=2, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform missing - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - # transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.resize_size missing - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.mean missing - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.std missing - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.resize_size non int - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=0.5, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.resize_size non int - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict( - resize_size=[100, 100], mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] - ), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.mean not a list of three floats - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.mean not a list of three floats - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[1, 1, 1], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # 
transform.mean not a list of three floats - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=0.5, std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.std not a list of three floats - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, std=[0.5], mean=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.std not a list of three floats - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, std=[1, 1, 1], mean=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # transform.std not a list of three floats - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, std=0.5, mean=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # invalid patch_size_pixels -- list - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=[350], - spacing_um_px=0.25, - class_names=["tumor"], - ), - # invalid patch_size_pixels -- float - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350.0, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # invalid patch_size_pixels -- negative - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=-100, - spacing_um_px=0.25, - class_names=["tumor"], - ), - # invalid spacing_um_px -- zero - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0, - class_names=["tumor"], - ), - # invalid spacing_um_px -- list - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=[0.25], - class_names=["tumor"], - ), - # invalid class_names -- str - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names="t", - ), - # invalid class_names -- len not equal to num_classes - dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor", "nontumor"], - ), - # invalid class_names -- not list of str - dict( - version="1.0", - name="foo", - architecture="resnet34", - 
url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=[1], - ), - # unknown key - dict( - fakekey="foobar", - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["foo"], - ), - # version != '1.0' - dict( - version="2.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["foo"], - ), - ], -) -def test_invalid_modeldefs(modeldef, tmp_path: Path): - from wsinfer._modellib.models import Weights - - path = tmp_path / "foobar.yaml" - with open(path, "w") as f: - yaml.safe_dump(modeldef, f) + prob_cols = df_ref.filter(like="prob_").columns.tolist() + for prob_col in prob_cols: + assert np.allclose( + df[prob_col], df_ref[prob_col], atol=1e-07 + ), f"Column {prob_col} not allclose at atol=1e-07" - with pytest.raises(Exception): - Weights.from_yaml(path) +def test_cli_run_no_model_or_config(tmp_path: Path): + """Test that --model or (--config and --model-path) is required.""" + wsi_dir = tmp_path / "slides" + wsi_dir.mkdir() -def test_valid_modeldefs(tmp_path: Path): - from wsinfer._modellib.models import Weights - - # Put the weights in a different directory than the config to make sure that - # relative paths work. - weights_file = tmp_path / "ckpts" / "weights.pt" - weights_file.parent.mkdir() - modeldef = dict( - version="1.0", - name="foo", - architecture="resnet34", - file="ckpts/weights.pt", - num_classes=2, - transform=dict(resize_size=224, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["foo", "bar"], - ) - path = tmp_path / "foobar.yaml" - with open(path, "w") as f: - yaml.safe_dump(modeldef, f) - - with pytest.raises(FileNotFoundError): - Weights.from_yaml(path) + runner = CliRunner() + args = [ + "run", + "--wsi-dir", + str(wsi_dir), + "--results-dir", + str(tmp_path / "results"), + ] + # No model, weights, or config. + result = runner.invoke(cli, args) + assert result.exit_code != 0 + assert "one of --model or (--config and --model-path) is required" in result.output - weights_file.touch() - w = Weights.from_yaml(path) - assert w.file is not None - assert Path(w.file).exists() +def test_cli_run_model_and_config(tmp_path: Path): + """Test that (model and weights) or config is required.""" + wsi_dir = tmp_path / "slides" + wsi_dir.mkdir() -def test_model_registration(tmp_path: Path): - from wsinfer._modellib import models + fake_config = tmp_path / "foobar.json" + fake_config.touch() + fake_model_path = tmp_path / "foobar.pt" + fake_model_path.touch() - # Test that registering duplicate weights will error. 
- d = dict( - version="1.0", - name="foo", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["foo"], + runner = CliRunner() + args = [ + "run", + "--wsi-dir", + str(wsi_dir), + "--results-dir", + str(tmp_path / "results"), + "--model", + "colorectal-tiatoolbox-resnet50.kather100k", + "--model-path", + str(fake_model_path), + "--config", + str(fake_config), + ] + # --model passed together with --config and --model-path. + result = runner.invoke(cli, args) + assert result.exit_code != 0 + assert ( + "--config and --model-path are mutually exclusive with --model" in result.output ) - path = tmp_path / "foobar.yaml" - with open(path, "w") as f: - yaml.safe_dump(d, f) - path = tmp_path / "foobardup.yaml" - with open(path, "w") as f: - yaml.safe_dump(d, f) - with pytest.raises(models.DuplicateModelWeights): - models.register_model_weights(tmp_path) - # Test that registering models will put them in the _known_model_weights object. - path = tmp_path / "configs" / "foobar.yaml" - path.parent.mkdir() - d = dict( - version="1.0", - name="foo2", - architecture="resnet34", - url="foo", - url_file_name="foo", - num_classes=1, - transform=dict(resize_size=299, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["foo"], - ) - with open(path, "w") as f: - yaml.safe_dump(d, f) - models.register_model_weights(path.parent) - assert (d["architecture"], d["name"]) in models._known_model_weights.keys() - assert all( - isinstance(m, models.Weights) for m in models._known_model_weights.values() - ) +@pytest.mark.xfail +def test_convert_to_sbu(): + # TODO: create a synthetic output and then convert it. Check that it is valid. + assert False @pytest.mark.parametrize( ["patch_size", "patch_spacing"], - [(256, 0.25), (256, 0.50), (350, 0.25)], + [(256, 0.25), (256, 0.50), (350, 0.25), (100, 0.3)], ) def test_patch_cli( patch_size: int, patch_spacing: float, tmp_path: Path, tiff_image: Path ): - from wsinfer.cli.cli import cli - - orig_slide_width = 4096 - orig_slide_height = 4096 + """Test of 'wsinfer patch'.""" + orig_slide_size = 4096 orig_slide_spacing = 0.25 runner = CliRunner() @@ -933,29 +324,32 @@ def test_patch_cli( assert (savedir / "stitches" / f"{stem}.jpg").exists() expected_patch_size = round(patch_size * patch_spacing / orig_slide_spacing) - expected_num_patches = math.ceil(4096 / expected_patch_size) ** 2 - expected_coords = [] - for x in range(0, orig_slide_width, expected_patch_size): - for y in range(0, orig_slide_height, expected_patch_size): - expected_coords.append([x, y]) - expected_coords_arr = np.array(expected_coords) + sqrt_expected_num_patches = round(orig_slide_size / expected_patch_size) + expected_num_patches = sqrt_expected_num_patches**2 + expected_coords = [] + for x in range(0, orig_slide_size, expected_patch_size): + for y in range(0, orig_slide_size, expected_patch_size): + # Patch is kept if centroid is inside.
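+ # Worked example from the parametrized cases: patch_size=256 at + # patch_spacing=0.50 on this 0.25 um/px slide gives expected_patch_size = + # round(256 * 0.50 / 0.25) = 512, so x and y run over range(0, 4096, 512) and + # every centroid (x + 256, y + 256) stays inside the 4096-px slide (64 kept). + # For patch_size=100 at 0.3 um/px, expected_patch_size = 120; the patches + # starting at 4080 have centroids at 4140 > 4096, so that last row and column + # are dropped, leaving round(4096 / 120) ** 2 == 34 ** 2 patches.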
+ if ( + x + expected_patch_size // 2 <= orig_slide_size + and y + expected_patch_size // 2 <= orig_slide_size + ): + expected_coords.append([x, y]) + assert len(expected_coords) == expected_num_patches with h5py.File(savedir / "patches" / f"{stem}.h5") as f: assert f["/coords"].attrs["patch_size"] == expected_patch_size coords = f["/coords"][()] assert coords.shape == (expected_num_patches, 2) - assert np.array_equal(expected_coords_arr, coords) + assert np.array_equal(expected_coords, coords) -@pytest.mark.parametrize(["model_name", "weights_name"], list_all_models_and_weights()) -def test_jit_compile(model_name: str, weights_name: str): - import time - from wsinfer._modellib.run_inference import jit_compile +# FIXME: parametrize this test across our models. +def test_jit_compile(): + w = get_registered_model("breast-tumor-resnet34.tcga-brca") + model = get_pretrained_torch_module(w) - w = get_model_weights(model_name, weights_name) - size = w.transform.resize_size - x = torch.ones(20, 3, size, size, dtype=torch.float32) - model = w.load_model() + x = torch.ones(20, 3, 224, 224, dtype=torch.float32) model.eval() NUM_SAMPLES = 1 with torch.no_grad(): @@ -984,10 +378,8 @@ def test_jit_compile(model_name: str, weights_name: str): def test_issue_89(): """Do not fail if 'git' is not installed.""" - from wsinfer.cli.infer import _get_info_for_save - - w = get_model_weights("resnet34", "TCGA-BRCA-v1") - d = _get_info_for_save(w) + model_obj = get_registered_model("breast-tumor-resnet34.tcga-brca") + d = _get_info_for_save(model_obj) assert d assert "git" in d["runtime"] assert d["runtime"]["git"] @@ -998,8 +390,7 @@ def test_issue_89(): orig_path = os.environ["PATH"] try: os.environ["PATH"] = "" - w = get_model_weights("resnet34", "TCGA-BRCA-v1") - d = _get_info_for_save(w) + d = _get_info_for_save(model_obj) assert d assert "git" in d["runtime"] assert d["runtime"]["git"] is None @@ -1009,7 +400,6 @@ def test_issue_94(tmp_path: Path, tiff_image: Path): """Gracefully handle unreadable slides.""" - from wsinfer.cli.cli import cli # We have a valid tiff in 'tiff_image.parent'. We put in an unreadable file too. 
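# The 'bad.svs' file created below has a slide extension but holds no readable image data.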
badpath = tiff_image.parent / "bad.svs" @@ -1023,12 +413,10 @@ "run", "--wsi-dir", str(tiff_image.parent), - "--model", - "resnet34", - "--weights", - "TCGA-BRCA-v1", "--results-dir", str(results_dir), + "--model", + "breast-tumor-resnet34.tcga-brca", ], ) # Important part is that we run through all of the files, despite the unreadable @@ -1040,7 +428,6 @@ def test_issue_97(tmp_path: Path, tiff_image: Path): """Write a run_metadata file per run.""" - from wsinfer.cli.cli import cli runner = CliRunner() results_dir = tmp_path / "inference" @@ -1050,12 +437,10 @@ "run", "--wsi-dir", str(tiff_image.parent), - "--model", - "resnet34", - "--weights", - "TCGA-BRCA-v1", "--results-dir", str(results_dir), + "--model", + "breast-tumor-resnet34.tcga-brca", ], ) assert result.exit_code == 0 @@ -1071,12 +456,10 @@ "run", "--wsi-dir", str(tiff_image.parent), - "--model", - "resnet34", - "--weights", - "TCGA-BRCA-v1", "--results-dir", str(results_dir), + "--model", + "breast-tumor-resnet34.tcga-brca", ], ) assert result.exit_code == 0 @@ -1085,24 +468,10 @@ def test_issue_125(tmp_path: Path): - from wsinfer.cli.infer import _get_info_for_save - from wsinfer._modellib.models import Weights - from wsinfer._modellib.transforms import PatchClassification - - w = Weights( - name="foo", - architecture="resnet34", - # We are testing whether we can still save if file is a Path instance. - file=Path(__file__), - num_classes=1, - transform=PatchClassification( - resize_size=299, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) - ), - patch_size_pixels=350, - spacing_um_px=0.25, - class_names=["tumor"], - ) + """Test that the model path can be saved when it is a pathlib.Path object.""" + w = get_registered_model("breast-tumor-resnet34.tcga-brca") + w.model_path = Path(w.model_path) # type: ignore info = _get_info_for_save(w) with open(tmp_path / "foo.json", "w") as f: json.dump(info, f) diff --git a/wsinfer/__init__.py b/wsinfer/__init__.py index 089a28f..84e2ce7 100644 --- a/wsinfer/__init__.py +++ b/wsinfer/__init__.py @@ -1,13 +1,26 @@ """WSInfer is a toolkit for fast patch-based inference on whole slide images.""" +from __future__ import annotations + from . import _version -from ._modellib.models import get_model_weights # noqa -from ._modellib.models import list_all_models_and_weights # noqa -from ._modellib.models import register_model_weights # noqa -from ._modellib.run_inference import run_inference # noqa -from ._modellib.run_inference import WholeSlideImagePatches # noqa -from ._modellib.transforms import PatchClassification # noqa +from .modellib.run_inference import WholeSlideImagePatches # noqa +from .modellib.run_inference import run_inference # noqa __version__ = _version.get_versions()["version"] del _version + + +# Patch Zarr. 
See: +# https://github.com/bayer-science-for-a-better-life/tiffslide/issues/72#issuecomment-1627918238 +# https://github.com/zarr-developers/zarr-python/pull/1454 +def _patch_zarr_kvstore(): + from zarr.storage import KVStore + + def _zarr_KVStore___contains__(self, key): + return key in self._mutable_mapping + + KVStore.__contains__ = _zarr_KVStore___contains__ + + +_patch_zarr_kvstore() diff --git a/wsinfer/__main__.py b/wsinfer/__main__.py index 5aa5ca5..84a858c 100644 --- a/wsinfer/__main__.py +++ b/wsinfer/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .cli.cli import cli if __name__ == "__main__": diff --git a/wsinfer/_modellib/inceptionv4_no_batchnorm.py b/wsinfer/_modellib/inceptionv4_no_batchnorm.py deleted file mode 100644 index 41e5137..0000000 --- a/wsinfer/_modellib/inceptionv4_no_batchnorm.py +++ /dev/null @@ -1,366 +0,0 @@ -# https://raw.githubusercontent.com/rwightman/pytorch-image-models/e9aac412de82310e6905992e802b1ee4dc52b5d1/timm/models/inception_v4.py -""" -Pytorch Inception-V4 implementation -Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is -based upon Google's Tensorflow implementation and pretrained weights -(Apache 2.0 License). - -This source was copied into the wsinfer source code and modified to remove batchnorm. -Bias terms are added wherever batchnorm is removed. -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.models import register_model -from timm.models.helpers import build_model_with_cfg -from timm.models.layers import create_classifier - -default_cfgs = { - "inception_v4nobn": { - "url": "", - "num_classes": 1000, - "input_size": (3, 299, 299), - "pool_size": (8, 8), - "crop_pct": 0.875, - "interpolation": "bicubic", - "mean": IMAGENET_INCEPTION_MEAN, - "std": IMAGENET_INCEPTION_STD, - "first_conv": "features.0.conv", - "classifier": "last_linear", - } -} - - -class BasicConv2d(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): - super(BasicConv2d, self).__init__() - self.conv = nn.Conv2d( - in_planes, - out_planes, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=True, # Set to True after removing BatchNorm. 
- ) - # self.bn = nn.BatchNorm2d(out_planes, eps=0.001) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - x = self.conv(x) - # x = self.bn(x) - x = self.relu(x) - return x - - -class Mixed3a(nn.Module): - def __init__(self): - super(Mixed3a, self).__init__() - self.maxpool = nn.MaxPool2d(3, stride=2) - self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) - - def forward(self, x): - x0 = self.maxpool(x) - x1 = self.conv(x) - out = torch.cat((x0, x1), 1) - return out - - -class Mixed4a(nn.Module): - def __init__(self): - super(Mixed4a, self).__init__() - - self.branch0 = nn.Sequential( - BasicConv2d(160, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1), - ) - - self.branch1 = nn.Sequential( - BasicConv2d(160, 64, kernel_size=1, stride=1), - BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)), - BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)), - BasicConv2d(64, 96, kernel_size=(3, 3), stride=1), - ) - - def forward(self, x): - x0 = self.branch0(x) - x1 = self.branch1(x) - out = torch.cat((x0, x1), 1) - return out - - -class Mixed5a(nn.Module): - def __init__(self): - super(Mixed5a, self).__init__() - self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) - self.maxpool = nn.MaxPool2d(3, stride=2) - - def forward(self, x): - x0 = self.conv(x) - x1 = self.maxpool(x) - out = torch.cat((x0, x1), 1) - return out - - -class InceptionA(nn.Module): - def __init__(self): - super(InceptionA, self).__init__() - self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) - - self.branch1 = nn.Sequential( - BasicConv2d(384, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), - ) - - self.branch2 = nn.Sequential( - BasicConv2d(384, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), - BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1), - ) - - self.branch3 = nn.Sequential( - nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - BasicConv2d(384, 96, kernel_size=1, stride=1), - ) - - def forward(self, x): - x0 = self.branch0(x) - x1 = self.branch1(x) - x2 = self.branch2(x) - x3 = self.branch3(x) - out = torch.cat((x0, x1, x2, x3), 1) - return out - - -class ReductionA(nn.Module): - def __init__(self): - super(ReductionA, self).__init__() - self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) - - self.branch1 = nn.Sequential( - BasicConv2d(384, 192, kernel_size=1, stride=1), - BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), - BasicConv2d(224, 256, kernel_size=3, stride=2), - ) - - self.branch2 = nn.MaxPool2d(3, stride=2) - - def forward(self, x): - x0 = self.branch0(x) - x1 = self.branch1(x) - x2 = self.branch2(x) - out = torch.cat((x0, x1, x2), 1) - return out - - -class InceptionB(nn.Module): - def __init__(self): - super(InceptionB, self).__init__() - self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) - - self.branch1 = nn.Sequential( - BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), - BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)), - ) - - self.branch2 = nn.Sequential( - BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)), - BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), - BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)), - BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), - ) - - 
self.branch3 = nn.Sequential( - nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - BasicConv2d(1024, 128, kernel_size=1, stride=1), - ) - - def forward(self, x): - x0 = self.branch0(x) - x1 = self.branch1(x) - x2 = self.branch2(x) - x3 = self.branch3(x) - out = torch.cat((x0, x1, x2, x3), 1) - return out - - -class ReductionB(nn.Module): - def __init__(self): - super(ReductionB, self).__init__() - - self.branch0 = nn.Sequential( - BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 192, kernel_size=3, stride=2), - ) - - self.branch1 = nn.Sequential( - BasicConv2d(1024, 256, kernel_size=1, stride=1), - BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), - BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)), - BasicConv2d(320, 320, kernel_size=3, stride=2), - ) - - self.branch2 = nn.MaxPool2d(3, stride=2) - - def forward(self, x): - x0 = self.branch0(x) - x1 = self.branch1(x) - x2 = self.branch2(x) - out = torch.cat((x0, x1, x2), 1) - return out - - -class InceptionC(nn.Module): - def __init__(self): - super(InceptionC, self).__init__() - - self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) - - self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) - self.branch1_1a = BasicConv2d( - 384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) - ) - self.branch1_1b = BasicConv2d( - 384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) - ) - - self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) - self.branch2_1 = BasicConv2d( - 384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0) - ) - self.branch2_2 = BasicConv2d( - 448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1) - ) - self.branch2_3a = BasicConv2d( - 512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) - ) - self.branch2_3b = BasicConv2d( - 512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) - ) - - self.branch3 = nn.Sequential( - nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - BasicConv2d(1536, 256, kernel_size=1, stride=1), - ) - - def forward(self, x): - x0 = self.branch0(x) - - x1_0 = self.branch1_0(x) - x1_1a = self.branch1_1a(x1_0) - x1_1b = self.branch1_1b(x1_0) - x1 = torch.cat((x1_1a, x1_1b), 1) - - x2_0 = self.branch2_0(x) - x2_1 = self.branch2_1(x2_0) - x2_2 = self.branch2_2(x2_1) - x2_3a = self.branch2_3a(x2_2) - x2_3b = self.branch2_3b(x2_2) - x2 = torch.cat((x2_3a, x2_3b), 1) - - x3 = self.branch3(x) - - out = torch.cat((x0, x1, x2, x3), 1) - return out - - -class InceptionV4(nn.Module): - def __init__( - self, - num_classes=1000, - in_chans=3, - output_stride=32, - drop_rate=0.0, - global_pool="avg", - ): - super(InceptionV4, self).__init__() - assert output_stride == 32 - self.drop_rate = drop_rate - self.num_classes = num_classes - self.num_features = 1536 - - self.features = nn.Sequential( - BasicConv2d(in_chans, 32, kernel_size=3, stride=2), - BasicConv2d(32, 32, kernel_size=3, stride=1), - BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), - Mixed3a(), - Mixed4a(), - Mixed5a(), - InceptionA(), - InceptionA(), - InceptionA(), - InceptionA(), - ReductionA(), # Mixed6a - InceptionB(), - InceptionB(), - InceptionB(), - InceptionB(), - InceptionB(), - InceptionB(), - InceptionB(), - ReductionB(), # Mixed7a - InceptionC(), - InceptionC(), - InceptionC(), - ) - self.feature_info = [ - dict(num_chs=64, reduction=2, module="features.2"), - dict(num_chs=160, reduction=4, module="features.3"), - dict(num_chs=384, reduction=8, module="features.9"), - dict(num_chs=1024, reduction=16, module="features.17"), 
- dict(num_chs=1536, reduction=32, module="features.21"), - ] - self.global_pool, self.last_linear = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool - ) - - @torch.jit.ignore - def group_matcher(self, coarse=False): - return dict(stem=r"^features\.[012]\.", blocks=r"^features\.(\d+)") - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - assert not enable, "gradient checkpointing not supported" - - @torch.jit.ignore - def get_classifier(self): - return self.last_linear - - def reset_classifier(self, num_classes, global_pool="avg"): - self.num_classes = num_classes - self.global_pool, self.last_linear = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool - ) - - def forward_features(self, x): - return self.features(x) - - def forward_head(self, x, pre_logits: bool = False): - x = self.global_pool(x) - if self.drop_rate > 0: - x = F.dropout(x, p=self.drop_rate, training=self.training) - return x if pre_logits else self.last_linear(x) - - def forward(self, x): - x = self.forward_features(x) - x = self.forward_head(x) - return x - - -def _create_inception_v4(variant, pretrained=False, **kwargs): - return build_model_with_cfg( - InceptionV4, - variant, - pretrained, - feature_cfg=dict(flatten_sequential=True), - **kwargs - ) - - -@register_model -def inception_v4nobn(pretrained=False, **kwargs): - return _create_inception_v4("inception_v4nobn", pretrained, **kwargs) diff --git a/wsinfer/_modellib/models.py b/wsinfer/_modellib/models.py deleted file mode 100644 index 76b4e84..0000000 --- a/wsinfer/_modellib/models.py +++ /dev/null @@ -1,309 +0,0 @@ -import dataclasses -import hashlib -import os -from pathlib import Path -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -import timm -import torch -from torch.hub import load_state_dict_from_url -import yaml - -# Imported for side effects of registering model. -from . import inceptionv4_no_batchnorm as _ # noqa -from .resnet_preact import resnet34_preact as _resnet34_preact -from .vgg16mod import vgg16mod as _vgg16mod -from .transforms import PatchClassification - - -class WsinferException(Exception): - "Base class for wsinfer exceptions." 
- - -class UnknownArchitectureError(WsinferException): - """Architecture is unknown and cannot be found.""" - - -class ModelWeightsNotFound(WsinferException): - """Model weights are not found, likely because they are not in the registry.""" - - -class DuplicateModelWeights(WsinferException): - """A duplicate key was passed to the model weights registry.""" - - -class ModelRegistrationError(WsinferException): - """Error during model registration.""" - - -PathType = Union[str, Path] - - -def _sha256sum(path: PathType) -> str: - """Calculate SHA256 of a file.""" - sha = hashlib.sha256() - with open(path, "rb") as f: - while True: - data = f.read(1024 * 64) # 64 kb - if not data: - break - sha.update(data) - return sha.hexdigest() - - -@dataclasses.dataclass -class Weights: - """Container for data associated with a trained model.""" - - name: str - architecture: str - num_classes: int - transform: PatchClassification - patch_size_pixels: int - spacing_um_px: float - class_names: List[str] - url: Optional[str] = None - url_file_name: Optional[str] = None - file: Optional[Union[str, Path]] = None - metadata: Optional[Dict[str, Any]] = None - - def __post_init__(self): - if len(set(self.class_names)) != len(self.class_names): - raise ValueError("class_names cannot contain duplicates") - if len(self.class_names) != self.num_classes: - raise ValueError("length of class_names must be equal to num_classes") - - @staticmethod - def _validate_input(d: dict, config_path: Path) -> None: - """Raise error if invalid input.""" - - if not isinstance(d, dict): - raise ValueError("expected config to be a dictionary") - - # Validate contents. - # Validate keys. - required_keys = [ - "version", - "name", - "architecture", - "num_classes", - "transform", - "patch_size_pixels", - "spacing_um_px", - "class_names", - ] - optional_keys = ["url", "url_file_name", "file", "metadata"] - all_keys = required_keys + optional_keys - for req_key in required_keys: - if req_key not in d.keys(): - raise KeyError(f"required key not found: '{req_key}'") - unknown_keys = [k for k in d.keys() if k not in all_keys] - if unknown_keys: - raise KeyError(f"unknown keys: {unknown_keys}") - for req_key in ["resize_size", "mean", "std"]: - if req_key not in d["transform"].keys(): - raise KeyError( - f"required key not found in 'transform' section: '{req_key}'" - ) - - # We include a 'version' key so we can handle updates if needed in the future. - # At this point, we only support version 1.0. - if d["version"] != "1.0": - raise ValueError("config file must include version: '1.0'.") - # Either 'url' or 'file' is required. If 'url' is used, then 'url_file_name' is - # required. - if "url" not in d.keys() and "file" not in d.keys(): - raise KeyError("'url' or 'file' must be provided") - if "url" in d.keys() and "file" in d.keys(): - raise KeyError("only on of 'url' and 'file' can be used") - if "url" in d.keys() and "url_file_name" not in d.keys(): - raise KeyError("when using 'url', 'url_file_name' must also be provided") - - # Validate types. 
- if not isinstance("architecture", str): - raise ValueError("'architecture' must be a string") - if not isinstance("name", str): - raise ValueError("'name' must be a string") - if "url" in d.keys() and not isinstance(d["url"], str): - raise ValueError("'url' must be a string") - if "url_file_name" in d.keys() and not isinstance(d["url"], str): - raise ValueError("'url_file_name' must be a string") - if not isinstance(d["num_classes"], int): - raise ValueError("'num_classes' must be an integer") - if not isinstance(d["transform"]["resize_size"], int): - raise ValueError("'transform.resize_size' must be an integer") - if not isinstance(d["transform"]["mean"], list): - raise ValueError("'transform.mean' must be a list") - if not all(isinstance(num, float) for num in d["transform"]["mean"]): - raise ValueError("'transform.mean' must be a list of floats") - if not isinstance(d["transform"]["std"], list): - raise ValueError("'transform.std' must be a list") - if not all(isinstance(num, float) for num in d["transform"]["std"]): - raise ValueError("'transform.std' must be a list of floats") - if not isinstance(d["patch_size_pixels"], int) or d["patch_size_pixels"] <= 0: - raise ValueError("patch_size_pixels must be a positive integer") - if not isinstance(d["spacing_um_px"], float) or d["spacing_um_px"] <= 0: - raise ValueError("spacing_um_px must be a positive float") - if not isinstance(d["class_names"], list): - raise ValueError("'class_names' must be a list") - if not all(isinstance(c, str) for c in d["class_names"]): - raise ValueError("'class_names' must be a list of strings") - - # Validate values. - if len(d["transform"]["mean"]) != 3: - raise ValueError("transform.mean must be a list of three numbers") - if len(d["transform"]["std"]) != 3: - raise ValueError("transform.std must be a list of three numbers") - if len(d["class_names"]) != len(set(d["class_names"])): - raise ValueError("duplicate values found in 'class_names'") - if len(d["class_names"]) != d["num_classes"]: - raise ValueError("mismatch between length of class_names and num_classes.") - if "file" in d.keys(): - file = Path(config_path).parent / d["file"] - file = file.resolve() - if not file.exists(): - raise FileNotFoundError(f"'file' not found: {file}") - - @classmethod - def from_yaml(cls, path): - """Create a new instance of Weights from a YAML file.""" - path = Path(path) - - with open(path) as f: - d = yaml.safe_load(f) - cls._validate_input(d, config_path=path) - - transform = PatchClassification( - resize_size=d["transform"]["resize_size"], - mean=d["transform"]["mean"], - std=d["transform"]["std"], - ) - if d.get("file") is not None: - file = path.parent / d.get("file") - else: - file = None - return Weights( - name=d["name"], - architecture=d["architecture"], - url=d.get("url"), - url_file_name=d.get("url_file_name"), - file=file, - num_classes=d["num_classes"], - transform=transform, - patch_size_pixels=d["patch_size_pixels"], - spacing_um_px=d["spacing_um_px"], - class_names=d["class_names"], - ) - - def load_model(self) -> torch.nn.Module: - """Return the pytorch implementation of the architecture with weights loaded.""" - model = _create_model(name=self.architecture, num_classes=self.num_classes) - - # Load state dict. 
- if self.url and self.url_file_name: - state_dict = load_state_dict_from_url( - url=self.url, - map_location="cpu", - check_hash=True, - file_name=self.url_file_name, - ) - elif self.file: - state_dict = torch.load(self.file, map_location="cpu") - # When training with timm scripts, weights are saved in 'state_dict' key. - if "state_dict" in state_dict.keys(): - state_dict = state_dict["state_dict"] - else: - raise RuntimeError("cannot find weights") - - model.load_state_dict(state_dict, strict=True) - model.eval() - return model - - def get_sha256_of_weights(self) -> str: - """Return the sha256 of the weights file.""" - if self.url and self.url_file_name: - p = Path(torch.hub.get_dir()) / "checkpoints" / self.url_file_name - elif self.file: - p = Path(self.file) - else: - raise RuntimeError("cannot find path to weights") - sha = _sha256sum(p) - return sha - - -# Container for all models we can use that are not in timm. -_model_registry: Dict[str, Callable[[int], torch.nn.Module]] = { - "preactresnet34": _resnet34_preact, - "vgg16mod": _vgg16mod, -} - - -def _create_model(name: str, num_classes: int) -> torch.nn.Module: - """Return a torch model architecture.""" - if name in _model_registry.keys(): - return _model_registry[name](num_classes) - else: - if name not in timm.list_models(): - raise UnknownArchitectureError(f"unknown architecture: '{name}'") - return timm.create_model(name, num_classes=num_classes) - - -# Keys are tuple of (architecture, weights_name). -_known_model_weights: Dict[Tuple[str, str], Weights] = {} - - -def register_model_weights(root: Path): - modeldefs = list(root.glob("*.yml")) + list(root.glob("*.yaml")) - for modeldef in modeldefs: - try: - w = Weights.from_yaml(modeldef) - except Exception as e: - raise ModelRegistrationError( - f"Error registering model from config file ('{modeldef}')\n" - f"Original error is: {e}" - ) - if w.architecture not in timm.list_models() + list(_model_registry.keys()): - raise UnknownArchitectureError(f"{w.architecture} implementation not found") - key = (w.architecture, w.name) - if key in _known_model_weights: - raise DuplicateModelWeights( - f"duplicate models weights: {(w.architecture, w.name)}" - ) - _known_model_weights[key] = w - - -def get_model_weights(architecture: str, name: str) -> Weights: - """Get weights object for an architecture and weights name.""" - key = (architecture, name) - try: - return _known_model_weights[key] - except KeyError: - pairs = " | ".join(" / ".join(p) for p in list_all_models_and_weights()) - raise ModelWeightsNotFound( - f"Invalid model-weight pair: '{architecture}' and '{name}'. Available" - f" models/weight pairs are {pairs}." - ) - - -register_model_weights(Path(__file__).parent / ".." / "modeldefs") - -# Register any user-supplied configurations. 
-wsinfer_path = os.environ.get("WSINFER_PATH") -if wsinfer_path is not None: - for path in wsinfer_path.split(":"): - register_model_weights(Path(path)) - del path -del wsinfer_path - - -def list_all_models_and_weights() -> List[Tuple[str, str]]: - """Return list of tuples of `(model_name, weights_name)` with available pairs.""" - vals = list(_known_model_weights.keys()) - vals.sort() - return vals diff --git a/wsinfer/_modellib/resnet_preact.py b/wsinfer/_modellib/resnet_preact.py deleted file mode 100644 index a112258..0000000 --- a/wsinfer/_modellib/resnet_preact.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Pre-activation ResNet.""" - -import torch.nn as nn -import torch.nn.functional as F - - -class PreActBlock(nn.Module): - """Pre-activation version of the BasicBlock.""" - - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super().__init__() - self.bn1 = nn.BatchNorm2d(in_planes) - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, stride=1, padding=1, bias=False - ) - - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d( - in_planes, - self.expansion * planes, - kernel_size=1, - stride=stride, - bias=False, - ) - ) - - def forward(self, x): - out = F.relu(self.bn1(x)) - shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x - out = self.conv1(out) - out = self.conv2(F.relu(self.bn2(out))) - out += shortcut - return out - - -class PreActResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=1): - super(PreActResNet, self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.linear = nn.Linear(512 * block.expansion, num_classes) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - out = self.conv1(x) - out = self.maxpool(out) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = F.avg_pool2d(out, out.size(2)) - out = out.view(out.size(0), -1) - out = self.linear(out) - return out - - -def resnet34_preact(num_classes: int): - return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes=num_classes) diff --git a/wsinfer/_modellib/run_inference.py b/wsinfer/_modellib/run_inference.py deleted file mode 100644 index b494282..0000000 --- a/wsinfer/_modellib/run_inference.py +++ /dev/null @@ -1,438 +0,0 @@ -"""Run inference. - -From the original paper (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7369575/): -> In the prediction (test) phase, no data augmentation was applied except for the -> normalization of the color channels. 
-""" - -from pathlib import Path -import typing -import warnings - -import h5py -import large_image -import numpy as np -import pandas as pd -from PIL import Image -import torch -import tqdm - -from .models import Weights - -PathType = typing.Union[str, Path] - - -class WholeSlideImageDirectoryNotFound(FileNotFoundError): - ... - - -class WholeSlideImagesNotFound(FileNotFoundError): - ... - - -class ResultsDirectoryNotFound(FileNotFoundError): - ... - - -class PatchDirectoryNotFound(FileNotFoundError): - ... - - -# Set the maximum number of TileSource objects to cache. We use 1 to minimize how many -# file handles we keep open. -large_image.config.setConfig("cache_tilesource_maximum", 1) - - -def _read_patch_coords(path: PathType) -> np.ndarray: - """Read HDF5 file of patch coordinates are return numpy array. - - Returned array has shape (num_patches, 4). Each row has values - [minx, miny, width, height]. - """ - with h5py.File(path, mode="r") as f: - coords = f["/coords"][()] - coords_metadata = f["/coords"].attrs - if "patch_level" not in coords_metadata.keys(): - raise KeyError( - "Could not find required key 'patch_level' in hdf5 of patch " - "coordinates. Has the version of CLAM been updated?" - ) - patch_level = coords_metadata["patch_level"] - if patch_level != 0: - raise NotImplementedError( - f"This script is designed for patch_level=0 but got {patch_level}" - ) - if coords.ndim != 2: - raise ValueError(f"expected coords to have 2 dimensions, got {coords.ndim}") - if coords.shape[1] != 2: - raise ValueError( - f"expected second dim of coords to have len 2 but got {coords.shape[1]}" - ) - - if "patch_size" not in coords_metadata.keys(): - raise KeyError("expected key 'patch_size' in attrs of coords dataset") - # Append width and height values to the coords, so now each row is - # [minx, miny, width, height] - wh = np.full_like(coords, coords_metadata["patch_size"]) - coords = np.concatenate((coords, wh), axis=1) - - return coords - - -def _filter_patches_in_rois( - *, geojson_path: PathType, coords: np.ndarray -) -> np.ndarray: - """Keep the patches that intersect the ROI(s). - - Parameters - ---------- - geojson_path : str, Path - Path to the GeoJSON file that encodes the points of the ROI(s). - coords : ndarray - Two-dimensional array where each row has minx, miny, width, height. - - Returns - ------- - ndarray of filtered coords. - """ - import geojson - from shapely import STRtree - from shapely.geometry import box, shape - - with open(geojson_path) as f: - geo = geojson.load(f) - if not geo.is_valid: - raise ValueError("GeoJSON of ROI is not valid") - for roi in geo["features"]: - assert roi.is_valid, "an ROI geometry is not valid" - geoms_rois = [shape(roi["geometry"]) for roi in geo["features"]] - coords_orig = coords.copy() - coords = coords.copy() - coords[:, 2] += coords[:, 0] # Calculate maxx. - coords[:, 3] += coords[:, 1] # Calculate maxy. - boxes = [box(*coords[idx]) for idx in range(coords.shape[0])] - tree = STRtree(boxes) - _, intersecting_ids = tree.query(geoms_rois, predicate="intersects") - intersecting_ids = np.sort(np.unique(intersecting_ids)) - return coords_orig[intersecting_ids] - - -class WholeSlideImagePatches(torch.utils.data.Dataset): - """Dataset of one whole slide image. - - This object retrieves patches from a whole slide image on the fly. - - Parameters - ---------- - wsi_path : str, Path - Path to whole slide image file. - patch_path : str, Path - Path to HDF5 file with coordinates of input image. - um_px : float - Scale of the resulting patches. 
Use 0.5 for 20x magnification. - transform : callable, optional - A callable to modify a retrieved patch. The callable must accept a - PIL.Image.Image instance and return a torch.Tensor. - roi_path : str, Path, optional - Path to GeoJSON file that outlines the region of interest (ROI). Only patches - within the ROI(s) will be used. - """ - - def __init__( - self, - wsi_path: PathType, - patch_path: PathType, - um_px: float, - transform: typing.Optional[typing.Callable[[Image.Image], torch.Tensor]] = None, - roi_path: typing.Optional[PathType] = None, - ): - self.wsi_path = wsi_path - self.patch_path = patch_path - self.um_px = float(um_px) - self.transform = transform - self.roi_path = roi_path - - assert Path(wsi_path).exists(), "wsi path not found" - assert Path(patch_path).exists(), "patch path not found" - if roi_path is not None: - assert Path(roi_path).exists(), "roi path not found" - - self.tilesource: large_image.tilesource.TileSource = large_image.getTileSource( - self.wsi_path - ) - # Disable the tile cache. We wrap this in a try-except because we are accessing - # a private attribute. It is possible that this attribute will change names - # in the future, and if that happens, we do not want to raise errors. - try: - self.tilesource.cache._Cache__maxsize = 0 - except AttributeError: - pass - - self.patches = _read_patch_coords(self.patch_path) - - # If an ROI is given, keep patches that intersect it. - if self.roi_path is not None: - self.patches = _filter_patches_in_rois( - geojson_path=self.roi_path, coords=self.patches - ) - if self.patches.shape[0] == 0: - raise ValueError("No patches left after taking intersection with ROI") - - assert self.patches.ndim == 2, "expected 2D array of patch coordinates" - # x, y, width, height - assert self.patches.shape[1] == 4, "expected second dimension to have len 4" - - def __len__(self): - return self.patches.shape[0] - - def __getitem__( - self, idx: int - ) -> typing.Tuple[typing.Union[Image.Image, torch.Tensor], torch.Tensor]: - coords: typing.Sequence[int] = self.patches[idx] - assert len(coords) == 4, "expected 4 coords (minx, miny, width, height)" - minx, miny, width, height = coords - source_region = dict( - left=minx, top=miny, width=width, height=height, units="base_pixels" - ) - target_scale = dict(mm_x=self.um_px / 1000) - - patch_im, _ = self.tilesource.getRegionAtAnotherScale( - sourceRegion=source_region, - targetScale=target_scale, - format=large_image.tilesource.TILE_FORMAT_PIL, - ) - patch_im = patch_im.convert("RGB") - if self.transform is not None: - patch_im = self.transform(patch_im) - if not isinstance(patch_im, (Image.Image, torch.Tensor)): - raise TypeError( - f"patch image must be an Image of Tensor, but got {type(patch_im)}" - ) - return patch_im, torch.as_tensor([minx, miny, width, height]) - - -def jit_compile( - model: torch.nn.Module, -) -> typing.Union[torch.jit.ScriptModule, torch.nn.Module, typing.Callable]: - """JIT-compile a model for inference.""" - noncompiled = model - device = next(model.parameters()).device - # Attempt to script. If it fails, return the original. - test_input = torch.ones(1, 3, 224, 224).to(device) - w = "Warning: could not JIT compile the model. Using non-compiled model instead." - # TODO: consider freezing the model as well. - # PyTorch 2.x has torch.compile. - if hasattr(torch, "compile"): - # Try to get the most optimized model. 
- try: - return torch.compile(model, fullgraph=True, mode="max-autotune") - except Exception: - pass - try: - return torch.compile(model, mode="max-autotune") - except Exception: - pass - try: - return torch.compile(model) - except Exception: - warnings.warn(w) - return noncompiled - # For pytorch 1.x, use torch.jit.script. - else: - try: - mjit = torch.jit.script(model) - with torch.no_grad(): - mjit(test_input) - except Exception: - warnings.warn(w) - return noncompiled - # Now that we have scripted the model, try to optimize it further. If that - # fails, return the scripted model. - try: - mjit_frozen = torch.jit.freeze(mjit) - mjit_opt = torch.jit.optimize_for_inference(mjit_frozen) - with torch.no_grad(): - mjit_opt(test_input) - return mjit_opt - except Exception: - return mjit - - -def run_inference( - wsi_dir: PathType, - results_dir: PathType, - weights: Weights, - batch_size: int = 32, - num_workers: int = 0, - speedup: bool = False, - roi_dir: typing.Optional[PathType] = None, -) -> typing.Tuple[typing.List[str], typing.List[str]]: - """Run model inference on a directory of whole slide images and save results to CSV. - - This assumes the patching has already been done and the results are stored in - `results_dir`. An error will be raised otherwise. - - Output CSV files are written to `{results_dir}/model-outputs/`. - - Parameters - ---------- - wsi_dir : str or Path - Directory containing whole slide images. This directory can *only* contain - whole slide images. Otherwise, an error will be raised during model inference. - results_dir : str or Path - Directory containing results of patching. - weights : wsinfer._modellib.models.Weights - Instance of Weights including the model object and information about how to - apply the model to new data. - batch_size : int - The batch size during the forward pass (default is 32). - num_workers : int - Number of workers for data loading (default is 0, meaning use a single thread). - speedup : bool - If True, JIT-compile the model. This has a startup cost but model inference - should be faster (default False). - roi_dir : str, Path, optional - Directory containing GeoJSON files that outlines the regions of interest (ROI). - Only patches within the ROI(s) will be used. The GeoJSON files must have the - extension ".json". - - Returns - ------- - A tuple of two lists of strings. The first list contains the slide IDs for which - patching failed, and the second list contains the slide IDs for which model - inference failed. - """ - # Make sure required directories exist. - wsi_dir = Path(wsi_dir) - if not wsi_dir.exists(): - raise WholeSlideImageDirectoryNotFound(f"directory not found: {wsi_dir}") - wsi_paths = list(wsi_dir.glob("*")) - if not wsi_paths: - raise WholeSlideImagesNotFound(wsi_dir) - results_dir = Path(results_dir) - if not results_dir.exists(): - raise ResultsDirectoryNotFound(results_dir) - - # Check patches directory. - patch_dir = results_dir / "patches" - if not patch_dir.exists(): - raise PatchDirectoryNotFound("Results dir must include 'patches' dir") - # Create the patch paths based on the whole slide image paths. In effect, only - # create patch paths if the whole slide image patch exists. 
- patch_paths = [patch_dir / p.with_suffix(".h5").name for p in wsi_paths] - - model_output_dir = results_dir / "model-outputs" - model_output_dir.mkdir(exist_ok=True) - - model = weights.load_model() - model.eval() - - if torch.cuda.is_available(): - device = torch.device("cuda") - if torch.cuda.device_count() > 1: - model = torch.nn.DataParallel(model) - elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): - device = torch.device("mps") - else: - device = torch.device("cpu") - print(f'Using device "{device}"') - - model.to(device) - - if speedup: - if typing.TYPE_CHECKING: - model = typing.cast(torch.nn.Module, jit_compile(model)) - else: - model = jit_compile(model) - - failed_patching = [p.stem for p in patch_paths if not p.exists()] - failed_inference: typing.List[str] = [] - - # Get paths to ROI geojson files. - if roi_dir is not None: - roi_paths = [Path(roi_dir) / p.with_suffix(".json").name for p in wsi_paths] - else: - roi_paths = None - - # results_for_all_slides: typing.List[pd.DataFrame] = [] - for i, (wsi_path, patch_path) in enumerate(zip(wsi_paths, patch_paths)): - print(f"Slide {i+1} of {len(wsi_paths)}") - print(f" Slide path: {wsi_path}") - print(f" Patch path: {patch_path}") - - slide_csv_name = Path(wsi_path).with_suffix(".csv").name - slide_csv = model_output_dir / slide_csv_name - if slide_csv.exists(): - print("Output CSV exists... skipping.") - print(slide_csv) - continue - - if not patch_path.exists(): - print(f"Skipping because patch file not found: {patch_path}") - continue - - roi_path = None - if roi_paths is not None: - roi_path = roi_paths[i] - # We grab all potential names of ROI paths, but we do not require all of - # them to exist. We only use those that exist. - if not roi_path.exists(): - roi_path = None - else: - print(f" ROI path: {roi_path}") - - try: - dset = WholeSlideImagePatches( - wsi_path=wsi_path, - patch_path=patch_path, - um_px=weights.spacing_um_px, - transform=weights.transform, - roi_path=roi_path, - ) - except Exception: - failed_inference.append(wsi_dir.stem) - continue - - loader = torch.utils.data.DataLoader( - dset, - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - ) - - # Store the coordinates and model probabiltiies of each patch in this slide. - # This lets us know where the probabiltiies map to in the slide. - slide_coords: typing.List[np.ndarray] = [] - slide_probs: typing.List[np.ndarray] = [] - for batch_imgs, batch_coords in tqdm.tqdm(loader): - assert batch_imgs.shape[0] == batch_coords.shape[0], "length mismatch" - with torch.no_grad(): - logits: torch.Tensor = model(batch_imgs.to(device)).detach().cpu() - # probs has shape (batch_size, num_classes) or (batch_size,) - if len(logits.shape) > 1 and logits.shape[1] > 1: - probs = torch.nn.functional.softmax(logits, dim=1) - else: - probs = torch.sigmoid(logits.squeeze(1)) - slide_coords.append(batch_coords.numpy()) - slide_probs.append(probs.numpy()) - - slide_coords_arr = np.concatenate(slide_coords, axis=0) - slide_df = pd.DataFrame( - dict( - slide=wsi_path, - minx=slide_coords_arr[:, 0], - miny=slide_coords_arr[:, 1], - width=slide_coords_arr[:, 2], - height=slide_coords_arr[:, 3], - ) - ) - slide_probs_arr = np.concatenate(slide_probs, axis=0) - # Use 'prob-' prefix for all classes. This should make it clearer that the - # column has probabilities for the class. It also makes it easier for us to - # identify columns associated with probabilities. 
- prob_colnames = [f"prob_{c}" for c in weights.class_names] - slide_df.loc[:, prob_colnames] = slide_probs_arr - slide_df.to_csv(slide_csv, index=False) - print("-" * 40) - - return failed_patching, failed_inference diff --git a/wsinfer/_modellib/transforms.py b/wsinfer/_modellib/transforms.py deleted file mode 100644 index 1f329f4..0000000 --- a/wsinfer/_modellib/transforms.py +++ /dev/null @@ -1,66 +0,0 @@ -"""PyTorch image classification transform. - -From -https://github.com/pytorch/vision/blob/528651a031a08f9f97cc75bd619a326387708219/torchvision/transforms/_presets.py#L1 -""" - -from typing import Tuple -from typing import Union - -from PIL import Image -import torch -from torchvision.transforms import functional as F - -# Get the interpolation mode while catering to older (and newer) versions of -# torchvision and PIL. -if hasattr(F, "InterpolationMode"): - BICUBIC = F.InterpolationMode.BICUBIC - BILINEAR = F.InterpolationMode.BILINEAR - LINEAR = F.InterpolationMode.BILINEAR - NEAREST = F.InterpolationMode.NEAREST -elif hasattr(Image, "Resampling"): - BICUBIC = Image.Resampling.BICUBIC - BILINEAR = Image.Resampling.BILINEAR - LINEAR = Image.Resampling.BILINEAR - NEAREST = Image.Resampling.NEAREST -else: - BICUBIC = Image.BICUBIC - BILINEAR = Image.BILINEAR - NEAREST = Image.NEAREST - LINEAR = Image.LINEAR - - -class PatchClassification(torch.nn.Module): - """Transform module to process RGB patches.""" - - def __init__( - self, - *, - resize_size: int, - mean: Tuple[float, float, float], - std: Tuple[float, float, float], - interpolation=BILINEAR, - ) -> None: - super().__init__() - self.resize_size = resize_size - self.mean = list(mean) - self.std = list(std) - self.interpolation = interpolation - - def __repr__(self): - return ( - f"PatchClassification(resize_size={self.resize_size}, mean={self.mean}," - f" std={self.std}, interpolation={self.interpolation})" - ) - - def forward(self, input: Union[Image.Image, torch.Tensor]) -> torch.Tensor: - img = F.resize( - input, - [self.resize_size, self.resize_size], - interpolation=self.interpolation, - ) - if not isinstance(img, torch.Tensor): - img = F.pil_to_tensor(img) - img = F.convert_image_dtype(img, torch.float) - img = F.normalize(img, mean=self.mean, std=self.std) - return img diff --git a/wsinfer/_modellib/vgg16mod.py b/wsinfer/_modellib/vgg16mod.py deleted file mode 100644 index 0948f36..0000000 --- a/wsinfer/_modellib/vgg16mod.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Implementation of VGG16 in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7369575.""" - -import torch -import torchvision - - -def vgg16mod(num_classes: int) -> torch.nn.Module: - """Create modified VGG16 model. - - The classifier of this model is - Linear (25,088, 4096) - ReLU -> Dropout - Linear (1024, num_classes) - """ - model = torchvision.models.vgg16() - model.classifier = model.classifier[:4] - in_features = model.classifier[0].in_features - model.classifier[0] = torch.nn.Linear(in_features, 1024) - model.classifier[3] = torch.nn.Linear(1024, num_classes) - return model diff --git a/wsinfer/_patchlib/create_dense_patch_grid.py b/wsinfer/_patchlib/create_dense_patch_grid.py deleted file mode 100644 index c2fa429..0000000 --- a/wsinfer/_patchlib/create_dense_patch_grid.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Create a dense grid of patch coordinates. 
This does *not* create a tissue mask.""" - -import itertools -from pathlib import Path -from typing import Tuple - -import h5py -import large_image -import numpy as np - - -def _get_dense_grid( - slide, orig_patch_size: int, patch_spacing_um_px: float -) -> Tuple[np.ndarray, int]: - ts = large_image.getTileSource(slide) - patch_spacing_mm_px = patch_spacing_um_px / 1000 - patch_size = orig_patch_size * patch_spacing_mm_px / ts.getMetadata()["mm_x"] - patch_size = round(patch_size) - step_size = patch_size # non-overlapping patches - cols = ts.getMetadata()["sizeX"] - rows = ts.getMetadata()["sizeY"] - xs = range(0, cols, step_size) - ys = range(0, rows, step_size) - # List of (x, y) coordinates. - return np.asarray(list(itertools.product(xs, ys))), patch_size - - -def create_grid_and_save( - slide, results_dir, orig_patch_size: int, patch_spacing_um_px: float -): - """Create dense grid of (x,y) coordinates and save to HDF5. - - This is similar to the CLAM coordinate code but does not use a tissue mask. - """ - slide = Path(slide) - results_dir = Path(results_dir) - hdf5_path = results_dir / "patches" / f"{slide.stem}.h5" - hdf5_path.parent.mkdir(exist_ok=True) - coords, patch_size = _get_dense_grid( - slide=slide, - orig_patch_size=orig_patch_size, - patch_spacing_um_px=patch_spacing_um_px, - ) - with h5py.File(hdf5_path, "w") as f: - dset = f.create_dataset("/coords", data=coords, compression="gzip") - dset.attrs["name"] = str(hdf5_path.stem) - dset.attrs["patch_level"] = 0 - dset.attrs["patch_size"] = patch_size - dset.attrs["save_path"] = str(hdf5_path.parent) - - -def create_grid_and_save_multi_slides( - wsi_dir, results_dir, orig_patch_size: int, patch_spacing_um_px: float -): - wsi_dir = Path(wsi_dir) - slides = list(wsi_dir.glob("*")) - if not slides: - raise FileNotFoundError("no slides found") - - for slide in slides: - create_grid_and_save( - slide=slide, - results_dir=results_dir, - orig_patch_size=orig_patch_size, - patch_spacing_um_px=patch_spacing_um_px, - ) diff --git a/wsinfer/_patchlib/utils/__init__.py b/wsinfer/_patchlib/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/wsinfer/_patchlib/utils/utils.py b/wsinfer/_patchlib/utils/utils.py deleted file mode 100644 index a9f909f..0000000 --- a/wsinfer/_patchlib/utils/utils.py +++ /dev/null @@ -1,238 +0,0 @@ -import torch -import numpy as np -import torch.nn as nn -from torch.utils.data import ( - DataLoader, - Sampler, - WeightedRandomSampler, - RandomSampler, - SequentialSampler, - sampler, -) -import torch.optim as optim -from itertools import islice -import math -import collections - -if torch.cuda.is_available(): - device = torch.device("cuda") -elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): - device = torch.device("mps") -else: - device = torch.device("cpu") - -class SubsetSequentialSampler(Sampler): - """Samples elements sequentially from a given list of indices, without replacement. 
- - Arguments: - indices (sequence): a sequence of indices - """ - - def __init__(self, indices): - self.indices = indices - - def __iter__(self): - return iter(self.indices) - - def __len__(self): - return len(self.indices) - - -def collate_MIL(batch): - img = torch.cat([item[0] for item in batch], dim=0) - label = torch.LongTensor([item[1] for item in batch]) - return [img, label] - - -def collate_features(batch): - img = torch.cat([item[0] for item in batch], dim=0) - coords = np.vstack([item[1] for item in batch]) - return [img, coords] - - -def get_simple_loader(dataset, batch_size=1, num_workers=1): - kwargs = ( - {"pin_memory": False, "num_workers": num_workers} - if device.type == "cuda" - else {} - ) - loader = DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler.SequentialSampler(dataset), - collate_fn=collate_MIL, - **kwargs - ) - return loader - - -def get_split_loader(split_dataset, training=False, testing=False, weighted=False): - """ - return either the validation loader or training loader - """ - kwargs = {"num_workers": 4} if device.type == "cuda" else {} - if not testing: - if training: - if weighted: - weights = make_weights_for_balanced_classes_split(split_dataset) - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=WeightedRandomSampler(weights, len(weights)), - collate_fn=collate_MIL, - **kwargs - ) - else: - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=RandomSampler(split_dataset), - collate_fn=collate_MIL, - **kwargs - ) - else: - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=SequentialSampler(split_dataset), - collate_fn=collate_MIL, - **kwargs - ) - - else: - ids = np.random.choice( - np.arange(len(split_dataset), int(len(split_dataset) * 0.1)), replace=False - ) - loader = DataLoader( - split_dataset, - batch_size=1, - sampler=SubsetSequentialSampler(ids), - collate_fn=collate_MIL, - **kwargs - ) - - return loader - - -def get_optim(model, args): - if args.opt == "adam": - optimizer = optim.Adam( - filter(lambda p: p.requires_grad, model.parameters()), - lr=args.lr, - weight_decay=args.reg, - ) - elif args.opt == "sgd": - optimizer = optim.SGD( - filter(lambda p: p.requires_grad, model.parameters()), - lr=args.lr, - momentum=0.9, - weight_decay=args.reg, - ) - else: - raise NotImplementedError - return optimizer - - -def print_network(net): - num_params = 0 - num_params_train = 0 - print(net) - - for param in net.parameters(): - n = param.numel() - num_params += n - if param.requires_grad: - num_params_train += n - - print("Total number of parameters: %d" % num_params) - print("Total number of trainable parameters: %d" % num_params_train) - - -def generate_split( - cls_ids, - val_num, - test_num, - samples, - n_splits=5, - seed=7, - label_frac=1.0, - custom_test_ids=None, -): - indices = np.arange(samples).astype(int) - - if custom_test_ids is not None: - indices = np.setdiff1d(indices, custom_test_ids) - - np.random.seed(seed) - for i in range(n_splits): - all_val_ids = [] - all_test_ids = [] - sampled_train_ids = [] - - if custom_test_ids is not None: # pre-built test split, do not need to sample - all_test_ids.extend(custom_test_ids) - - for c in range(len(val_num)): - possible_indices = np.intersect1d( - cls_ids[c], indices - ) # all indices of this class - val_ids = np.random.choice( - possible_indices, val_num[c], replace=False - ) # validation ids - - remaining_ids = np.setdiff1d( - possible_indices, val_ids - ) # indices of this class left after validation - all_val_ids.extend(val_ids) - - if 
custom_test_ids is None: # sample test split - test_ids = np.random.choice(remaining_ids, test_num[c], replace=False) - remaining_ids = np.setdiff1d(remaining_ids, test_ids) - all_test_ids.extend(test_ids) - - if label_frac == 1: - sampled_train_ids.extend(remaining_ids) - - else: - sample_num = math.ceil(len(remaining_ids) * label_frac) - slice_ids = np.arange(sample_num) - sampled_train_ids.extend(remaining_ids[slice_ids]) - - yield sampled_train_ids, all_val_ids, all_test_ids - - -def nth(iterator, n, default=None): - if n is None: - return collections.deque(iterator, maxlen=0) - else: - return next(islice(iterator, n, None), default) - - -def calculate_error(Y_hat, Y): - error = 1.0 - Y_hat.float().eq(Y.float()).float().mean().item() - - return error - - -def make_weights_for_balanced_classes_split(dataset): - N = float(len(dataset)) - weight_per_class = [ - N / len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids)) - ] - weight = [0] * int(N) - for idx in range(len(dataset)): - y = dataset.getlabel(idx) - weight[idx] = weight_per_class[y] - - return torch.DoubleTensor(weight) - - -def initialize_weights(module): - for m in module.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight) - m.bias.data.zero_() - - elif isinstance(m, nn.BatchNorm1d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) diff --git a/wsinfer/_patchlib/wsi_core/__init__.py b/wsinfer/_patchlib/wsi_core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/wsinfer/_version.py b/wsinfer/_version.py index bb00df9..e93924b 100644 --- a/wsinfer/_version.py +++ b/wsinfer/_version.py @@ -9,13 +9,16 @@ """Git implementation of _version.py.""" +from __future__ import annotations + import errno +import functools import os import re import subprocess import sys -from typing import Callable, Dict -import functools +from typing import Callable +from typing import Dict def get_keywords(): diff --git a/wsinfer/cli/__init__.py b/wsinfer/cli/__init__.py index e69de29..9d48db4 100644 --- a/wsinfer/cli/__init__.py +++ b/wsinfer/cli/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/wsinfer/cli/cli.py b/wsinfer/cli/cli.py index 1e32860..c1b7b53 100644 --- a/wsinfer/cli/cli.py +++ b/wsinfer/cli/cli.py @@ -1,24 +1,48 @@ +from __future__ import annotations + +import logging +from typing import Literal + import click -# from .. import __version__ +from ..wsi import set_backend from .convert_csv_to_geojson import togeojson from .convert_csv_to_sbubmi import tosbu from .infer import run -from .list_models_and_weights import list from .patch import patch +_logging_levels = ["debug", "info", "warning", "error", "critical"] + # We use invoke_without_command=True so that 'wsinfer' on its own can be used for # inference on slides. @click.group() +@click.option( + "--backend", + default=None, + help="Backend for loading whole slide images.", + type=click.Choice(["openslide", "tiffslide"]), +) +@click.option( + "--log-level", + default="info", + type=click.Choice(_logging_levels), + help="Set the loudness of logging.", +) @click.version_option() -def cli(): +def cli(backend: Literal["openslide"] | Literal["tiffslide"] | None, log_level: str): """Run patch-level classification inference on whole slide images.""" - pass + + # Configure logger. 
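+    # Map each level name in _logging_levels to its numeric value in the logging module (e.g. "info" -> logging.INFO).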
+ levels = {level: getattr(logging, level.upper()) for level in _logging_levels} + level = levels[log_level] + logging.basicConfig(level=level) + + if backend is not None: + set_backend(backend) cli.add_command(run) cli.add_command(togeojson) cli.add_command(tosbu) cli.add_command(patch) -cli.add_command(list) diff --git a/wsinfer/cli/convert_csv_to_geojson.py b/wsinfer/cli/convert_csv_to_geojson.py index 11c2c40..834ca9f 100644 --- a/wsinfer/cli/convert_csv_to_geojson.py +++ b/wsinfer/cli/convert_csv_to_geojson.py @@ -3,9 +3,11 @@ GeoJSON files can be loaded into whole slide image viewers like QuPath. """ +from __future__ import annotations + import json -from pathlib import Path import typing +from pathlib import Path import click import pandas as pd diff --git a/wsinfer/cli/convert_csv_to_sbubmi.py b/wsinfer/cli/convert_csv_to_sbubmi.py index d98edbf..4c96855 100644 --- a/wsinfer/cli/convert_csv_to_sbubmi.py +++ b/wsinfer/cli/convert_csv_to_sbubmi.py @@ -2,40 +2,40 @@ Output directory tree for single class outputs: ├── heatmap_jsons -│  ├── heatmap-SLIDEID.json -│  └── meta-SLIDEID.json +│ ├── heatmap-SLIDEID.json +│ └── meta-SLIDEID.json └── heatmap_txt ├── color-SLIDEID └── prediction-SLIDEID Output directory tree for multi-class outputs: ├── heatmap_jsons -│  └── CLASS_LABEL -│  ├── heatmap-SLIDEID.json -│  └── meta-SLIDEID.json +│ └── CLASS_LABEL +│ ├── heatmap-SLIDEID.json +│ └── meta-SLIDEID.json └── heatmap_txt └── CLASS_LABEL ├── color-SLIDEID └── prediction-SLIDEID """ +from __future__ import annotations + import json import multiprocessing -from pathlib import Path import pprint import random import shutil import time import typing +from pathlib import Path import click -import large_image import numpy as np import pandas as pd import tqdm - -PathType = typing.Union[str, Path] +from wsinfer.wsi import WSI def _box_to_polygon( @@ -48,11 +48,11 @@ def _box_to_polygon( def write_heatmap_and_meta_json_lines( - input: PathType, - output_heatmap: PathType, - output_meta: PathType, - slide_width: PathType, - slide_height: PathType, + input: str | Path, + output_heatmap: str | Path, + output_meta: str | Path, + slide_width: int, + slide_height: int, execution_id: str, study_id: str, case_id: str, @@ -169,7 +169,9 @@ def row_to_json(row: pd.Series): json.dump(meta_dict, f) -def write_heatmap_txt(input: PathType, output: PathType, class_names: typing.List[str]): +def write_heatmap_txt( + input: str | Path, output: str | Path, class_names: typing.List[str] +): df = pd.read_csv(input) # TODO: should we round and cast to int here? df.loc[:, "x_loc"] = (df.minx + (df.width / 2)).round().astype(int) @@ -183,9 +185,9 @@ def write_heatmap_txt(input: PathType, output: PathType, class_names: typing.Lis def write_color_txt( - input: PathType, - output: PathType, - ts: large_image.tilesource.TileSource, + input: str | Path, + output: str | Path, + slide, num_processes: int = 6, ): def whiteness(arr): @@ -205,15 +207,12 @@ def redness(arr): global get_color # Hack to please multiprocessing. 
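+    # The global declaration binds the nested function at module scope, which keeps it picklable for the multiprocessing workers.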
def get_color(row: pd.Series): - arr, _ = ts.getRegion( - format=large_image.constants.TILE_FORMAT_NUMPY, - region=dict( - left=row["minx"], - top=row["miny"], - width=row["width"], - height=row["height"], - ), + patch_im = slide.read_region( + location=(row["minx"], row["miny"]), + level=0, + size=(row["width"], row["height"]), ) + arr = np.asarray(patch_im) white = whiteness(arr) black = blackness(arr) red = redness(arr) @@ -357,11 +356,9 @@ def tosbu( click.secho(f"WSI file not found: {wsi_file}", bg="red") click.secho("Skipping...", bg="red") continue - ts = large_image.getTileSource(wsi_file) - if ts.sizeX is None or ts.sizeY is None: - click.secho(f"Unknown size for WSI: {wsi_file}", bg="red") - click.secho("Skipping...", bg="red") - continue + slide = WSI(wsi_file) + + slide_width, slide_height = slide.level_dimensions[0] for class_name in class_names: if len(class_names) == 1: @@ -383,8 +380,8 @@ def tosbu( input=input_csv, output_heatmap=output_heatmap, output_meta=output_meta, - slide_width=ts.sizeX, - slide_height=ts.sizeY, + slide_width=slide_width, + slide_height=slide_height, execution_id=execution_id, study_id=study_id, case_id=slide_id, # TODO: should case_id be different? @@ -419,7 +416,7 @@ def tosbu( write_color_txt( input=input_csv, output=output_color, - ts=ts, + slide=slide, num_processes=num_processes, ) else: @@ -430,7 +427,7 @@ def tosbu( write_color_txt( input=input_csv, output=output_color, - ts=ts, + slide=slide, num_processes=num_processes, ) # Copy this color file to all class-specific dirs. diff --git a/wsinfer/cli/infer.py b/wsinfer/cli/infer.py index 0fea2f8..795a7a5 100644 --- a/wsinfer/cli/infer.py +++ b/wsinfer/cli/infer.py @@ -1,24 +1,30 @@ """Detect cancerous regions in a whole slide image.""" -from datetime import datetime +from __future__ import annotations + +import dataclasses import getpass import json import os -from pathlib import Path import platform import shutil import subprocess import sys -import typing +from datetime import datetime +from pathlib import Path +from typing import Optional +from typing import Union import click +import wsinfer_zoo +import wsinfer_zoo.client +import yaml +from wsinfer_zoo.client import HFModel +from wsinfer_zoo.client import ModelConfiguration -from .._modellib.run_inference import run_inference -from .._modellib import models -from .._patchlib.create_dense_patch_grid import create_grid_and_save_multi_slides -from .._patchlib.create_patches_fp import create_patches - -PathType = typing.Union[str, Path] +from ..modellib import models +from ..modellib.run_inference import run_inference +from ..patchlib.create_patches_fp import create_patches def _num_cpus() -> int: @@ -54,6 +60,7 @@ def _print_system_info() -> None: """Print information about the system.""" import torch import torchvision + from .. import __version__ click.secho(f"\nRunning wsinfer version {__version__}", fg="green") @@ -102,10 +109,11 @@ def _print_system_info() -> None: click.secho("*******************************************", fg="yellow") -def _get_info_for_save(weights: models.Weights): +def _get_info_for_save(model_obj: Union[models.LocalModelTorchScript, HFModel]): """Get dictionary with information about the run. To save as JSON in output dir.""" import torch + from .. 
import __version__ here = Path(__file__).parent.resolve() @@ -147,35 +155,15 @@ def get_stdout(args) -> str: if git_installed and is_git_repo: git_info = get_git_info() - weights_file = weights.file - if weights_file is None: - if weights.url_file_name is None: - raise TypeError("url_file_name must not be None if file is None.") - weights_file = str( - Path(torch.hub.get_dir()) / "checkpoints" / weights.url_file_name - ) - else: - # Weights file could have been a pathlib.Path object. - weights_file = str(weights_file) + hf_info = None + if hasattr(model_obj, "hf_info"): + hf_info = dataclasses.asdict(model_obj.hf_info) return { - "model_weights": { - "name": weights.name, - "architecture": weights.architecture, - "weights_url": weights.url, - "weights_url_file_name": weights.url_file_name, - "weights_file": weights_file, - "weights_sha256": weights.get_sha256_of_weights(), - "class_names": weights.class_names, - "num_classes": weights.num_classes, - "patch_size_pixels": weights.patch_size_pixels, - "spacing_um_px": weights.spacing_um_px, - "transform": { - "resize_size": weights.transform.resize_size, - "mean": weights.transform.mean, - "std": weights.transform.std, - }, - "metadata": weights.metadata or None, + "model": { + "config": dataclasses.asdict(model_obj.config), + "huggingface_location": hf_info, + "path": str(model_obj.model_path), }, "runtime": { "version": __version__, @@ -187,6 +175,7 @@ def get_stdout(args) -> str: "pytorch_version": torch.__version__, "cuda_version": torch.version.cuda, "git": git_info, + "wsinfer_zoo_version": wsinfer_zoo.__version__, }, "timestamp": _get_timestamp(), } @@ -195,6 +184,7 @@ def get_stdout(args) -> str: @click.command(context_settings=dict(auto_envvar_prefix="WSINFER")) @click.pass_context @click.option( + "-i", "--wsi-dir", type=click.Path(exists=True, file_okay=False, path_type=Path, resolve_path=True), required=True, @@ -202,6 +192,7 @@ def get_stdout(args) -> str: " whole slide images.", ) @click.option( + "-o", "--results-dir", type=click.Path(file_okay=False, path_type=Path, resolve_path=True), required=True, @@ -209,25 +200,34 @@ def get_stdout(args) -> str: " whole slides for which outputs exist.", ) @click.option( + "-m", "--model", - type=click.Choice(sorted({a for a, _ in models.list_all_models_and_weights()})), - help="Model architecture to use. Not required if 'config' is used.", -) -@click.option( - "--weights", - type=click.Choice(sorted({w for _, w in models.list_all_models_and_weights()})), - help="Name of weights to use for the model. Not required if 'config' is used.", + "model_name", + type=click.Choice(sorted(wsinfer_zoo.client.load_registry().models.keys())), + help="Name of the model to use from WSInfer Model Zoo. Mutually exclusive with" + " --config.", ) @click.option( + "-c", "--config", type=click.Path(exists=True, dir_okay=False, path_type=Path, resolve_path=True), help=( - "Path to configuration for architecture and weights. Use this option if the" + "Path to configuration for the trained model. Use this option if the" " model weights are not registered in wsinfer. Mutually exclusive with" - " 'model' and 'weights'." + "--model" + ), +) +@click.option( + "-p", + "--model-path", + type=click.Path(exists=True, dir_okay=False, path_type=Path, resolve_path=True), + help=( + "Path to the pretrained model. Use only when --config is passed. Mutually " + "exclusive with --model." 
), ) @click.option( + "-b", "--batch-size", type=click.IntRange(min=1), default=32, @@ -236,6 +236,7 @@ def get_stdout(args) -> str: " batch size.", ) @click.option( + "-n", "--num-workers", default=min(_num_cpus(), 8), # Use at most 8 workers by default. show_default=True, @@ -247,7 +248,8 @@ def get_stdout(args) -> str: "--speedup/--no-speedup", default=False, show_default=True, - help="JIT-compile the model for potential speedups.", + help="JIT-compile the model and apply inference optimizations. This imposes a" + " startup cost but may improve performance overall.", ) @click.option( "--roi-dir", @@ -262,26 +264,18 @@ def get_stdout(args) -> str: ), default=None, ) -@click.option( - "--dense-grid/--no-dense-grid", - default=False, - show_default=True, - help="Use a dense grid of patch coordinates. Patches will be present even if no" - " tissue is present", -) def run( ctx: click.Context, *, wsi_dir: Path, results_dir: Path, - model: typing.Optional[str], - weights: typing.Optional[str], - config: typing.Optional[Path], + model_name: Optional[str], + config: Optional[Path], + model_path: Optional[Path], batch_size: int, num_workers: int = 0, speedup: bool = False, - roi_dir: typing.Optional[PathType] = None, - dense_grid: bool = False, + roi_dir: Optional[str | Path] = None, ): """Run model inference on a directory of whole slide images. @@ -291,17 +285,24 @@ def run( Example: - CUDA_VISIBLE_DEVICES=0 wsinfer run --wsi_dir slides/ --results_dir results - --model resnet34 --weights TCGA-BRCA-v1 --batch_size 32 --num_workers 4 + CUDA_VISIBLE_DEVICES=0 wsinfer run --wsi-dir slides/ --results-dir results + --model breast-tumor-resnet34.tcga-brca --batch-size 32 --num-workers 4 - To list all available models and weights, use `wsinfer list`. + To list all available models and weights, use `wsinfer ls`. """ - if model is None and weights is None and config is None: - raise click.UsageError("one of (model and weights) or config is required.") - elif (model is not None or weights is not None) and config is not None: - raise click.UsageError("model and weights are mutually exclusive with config.") - elif (model is not None) ^ (weights is not None): # XOR - raise click.UsageError("model and weights must both be set if one is set.") + + if model_name is None and config is None and model_path is None: + raise click.UsageError( + "one of --model or (--config and --model-path) is required." + ) + elif (config is not None or model_path is not None) and model_name is not None: + raise click.UsageError( + "--config and --model-path are mutually exclusive with --model." + ) + elif (config is not None) ^ (model_path is not None): # XOR + raise click.UsageError( + "--config and --model-path must both be set if one is set." + ) wsi_dir = wsi_dir.resolve() results_dir = results_dir.resolve() @@ -328,36 +329,46 @@ def run( # Get weights object before running the patching script because we need to get the # necessary spacing and patch size. 
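    # The model comes from one of two places: the WSInfer Model Zoo registry
    # (--model NAME), or a local TorchScript file described by a JSON/YAML
    # config (--config plus --model-path). Either way, the resulting object
    # carries a .config with patch_size_pixels and spacing_um_px.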
- if model is not None and weights is not None: - weights_obj = models.get_model_weights(model, name=weights) + model_obj: HFModel | models.LocalModelTorchScript + if model_name is not None: + model_obj = models.get_registered_model(name=model_name) elif config is not None: - weights_obj = models.Weights.from_yaml(config) - - click.secho("\nFinding patch coordinates...\n", fg="green") - if dense_grid: - click.echo("Not using a tissue mask.") - create_grid_and_save_multi_slides( - wsi_dir=wsi_dir, - results_dir=results_dir, - orig_patch_size=weights_obj.patch_size_pixels, - patch_spacing_um_px=weights_obj.spacing_um_px, + assert config.suffix in {".json", ".yaml", ".yml"}, "Unknown file type" + if config.suffix in {".yaml", ".yml"}: + with open(config) as f: + _config_dict = yaml.safe_load(f) + else: + with open(config) as f: + _config_dict = json.load(f) + model_config = ModelConfiguration.from_dict(_config_dict) + model_obj = models.LocalModelTorchScript( + config=model_config, model_path=str(model_path) ) + del _config_dict, model_config else: - create_patches( - source=str(wsi_dir), - save_dir=str(results_dir), - patch_size=weights_obj.patch_size_pixels, - patch_spacing=weights_obj.spacing_um_px, - seg=True, - patch=True, - preset="tcga.csv", - ) + raise click.ClickException("Neither of --config and --model was passed") + + click.secho("\nFinding patch coordinates...\n", fg="green") + + create_patches( + source=str(wsi_dir), + save_dir=str(results_dir), + patch_size=model_obj.config.patch_size_pixels, + patch_spacing=model_obj.config.spacing_um_px, + seg=True, + patch=True, + # Stitching is a bottleneck when using tiffslide. + # TODO: figure out why this is... + stitch=False, + # FIXME: allow customization of this preset + preset="tcga.csv", + ) click.secho("\nRunning model inference.\n", fg="green") failed_patching, failed_inference = run_inference( wsi_dir=wsi_dir, results_dir=results_dir, - weights=weights_obj, + model_info=model_obj, batch_size=batch_size, num_workers=num_workers, speedup=speedup, @@ -376,7 +387,7 @@ def run( timestamp = datetime.now().astimezone().strftime("%Y%m%dT%H%M%S") run_metadata_outpath = results_dir / f"run_metadata_{timestamp}.json" click.echo(f"Saving metadata about run to {run_metadata_outpath}") - run_metadata = _get_info_for_save(weights_obj) + run_metadata = _get_info_for_save(model_obj) with open(run_metadata_outpath, "w") as f: json.dump(run_metadata, f, indent=2) diff --git a/wsinfer/cli/list_models_and_weights.py b/wsinfer/cli/list_models_and_weights.py deleted file mode 100644 index 4155784..0000000 --- a/wsinfer/cli/list_models_and_weights.py +++ /dev/null @@ -1,25 +0,0 @@ -import click - -from .._modellib import models - - -@click.command() -def list(): - """Show all available models and weights.""" - models_weights = models.list_all_models_and_weights() - - weights = [models.get_model_weights(*mw) for mw in models_weights] - - print("+-----------------------------------------------------------+") - click.secho( - "| MODEL WEIGHTS RESOLUTION |", bold=True - ) - print("| ========================================================= |") - _prev_model = models_weights[0][0] - for (model_name, weights_name), weight_obj in zip(models_weights, weights): - if _prev_model != model_name: - print("| --------------------------------------------------------- |") - _prev_model = model_name - r = f"{weight_obj.patch_size_pixels} px @ {weight_obj.spacing_um_px} um/px" - print(f"| {model_name:<18}{weights_name:<15}{r:<25}|") - 
print("+-----------------------------------------------------------+") diff --git a/wsinfer/cli/patch.py b/wsinfer/cli/patch.py index 36bb01c..105b0fe 100644 --- a/wsinfer/cli/patch.py +++ b/wsinfer/cli/patch.py @@ -1,8 +1,10 @@ +from __future__ import annotations + from typing import Optional import click -from .._patchlib.create_patches_fp import create_patches as _create_patches +from ..patchlib.create_patches_fp import create_patches as _create_patches @click.command() diff --git a/wsinfer/errors.py b/wsinfer/errors.py new file mode 100644 index 0000000..c49ac8b --- /dev/null +++ b/wsinfer/errors.py @@ -0,0 +1,35 @@ +"""Exceptions used in WSInfer.""" + +from __future__ import annotations + + +class WsinferException(Exception): + """Base class for wsinfer exceptions.""" + + +class UnknownArchitectureError(WsinferException): + """Architecture is unknown and cannot be found.""" + + +class WholeSlideImageDirectoryNotFound(WsinferException, FileNotFoundError): + ... + + +class WholeSlideImagesNotFound(WsinferException, FileNotFoundError): + ... + + +class ResultsDirectoryNotFound(WsinferException, FileNotFoundError): + ... + + +class PatchDirectoryNotFound(WsinferException, FileNotFoundError): + ... + + +class CannotReadSpacing(WsinferException): + ... + + +class NoBackendException(WsinferException): + ... diff --git a/wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml b/wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml deleted file mode 100644 index a24d779..0000000 --- a/wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# Configuration of a breast cancer tumor detection model. -# The specification version. Only 1.0 is supported at this time. -version: "1.0" -# The models are referenced by the pair of [architecture, weights], so this pair must -# be unique. -architecture: inception_v4 # Must be a string. -name: TCGA-BRCA-v1 # Must be a string. -# Where to get the model weights. Either a URL or path to a file. -# If using a URL, set the url_file_name (the name of the file when it is downloaded). -url: https://stonybrookmedicine.box.com/shared/static/tfwimlf3ygyga1x4fnn03u9y5uio8gqk.pt -url_file_name: inceptionv4-brca-20190613-aef40942.pt -# If using a relative path, the path is relative to the location of the yaml file. -# file: /path/to/weights.pt -num_classes: 2 -transform: - # These are keyword arguments to the PatchClassification class. - resize_size: 299 # Must be a single integer. - mean: [0.5, 0.5, 0.5] # Must be a list of three floats. - std: [0.5, 0.5, 0.5] # Must be a list of three floats. -patch_size_pixels: 350 -spacing_um_px: 0.25 -class_names: - - notumor - - tumor diff --git a/wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml b/wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml deleted file mode 100644 index f542c29..0000000 --- a/wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Configuration of a tumor infiltrating lymphocyte detection model. -# The specification version. Only 1.0 is supported at this time. -version: "1.0" -# The models are referenced by the pair of [architecture, weights], so this pair must -# be unique. -# Inceptionv4 without batch normalization. -architecture: inception_v4nobn # Must be a string. -name: TCGA-TILs-v1 # Must be a string. -# Where to get the model weights. Either a URL or path to a file. -# If using a URL, set the url_file_name (the name of the file when it is downloaded). 
-url: https://stonybrookmedicine.box.com/shared/static/sz1gpc6u3mftadh4g6x3csxnpmztj8po.pt -url_file_name: inceptionv4-tils-v1-20200920-e3e72cd2.pt -# If using a relative path, the path is relative to the location of the yaml file. -# file: /path/to/weights.pt -num_classes: 2 -transform: - # These are keyword arguments to the PatchClassification class. - resize_size: 299 # Must be a single integer. - mean: [0.5, 0.5, 0.5] # Must be a list of three floats. - std: [0.5, 0.5, 0.5] # Must be a list of three floats. -patch_size_pixels: 100 -spacing_um_px: 0.5 -class_names: - - notils - - tils -metadata: - publication: https://doi.org/10.3389/fonc.2021.806603 - notes: | - Implementation does not use batchnorm. Original model was trained with TF Slim - and converted to PyTorch format. diff --git a/wsinfer/modeldefs/preactresnet34_tcga-paad-v1.yaml b/wsinfer/modeldefs/preactresnet34_tcga-paad-v1.yaml deleted file mode 100644 index dc2229e..0000000 --- a/wsinfer/modeldefs/preactresnet34_tcga-paad-v1.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# Configuration of a pancreatic adenocarcinoma tumor detection model. -# The specification version. Only 1.0 is supported at this time. -version: "1.0" -# The models are referenced by the pair of [architecture, weights], so this pair must -# be unique. -architecture: preactresnet34 # Must be a string. -name: TCGA-PAAD-v1 # Must be a string. -# Where to get the model weights. Either a URL or path to a file. -# If using a URL, set the url_file_name (the name of the file when it is downloaded). -url: https://stonybrookmedicine.box.com/shared/static/sol1h9aqrh8lynzc6kidw1lsoeks20hh.pt -url_file_name: preactresnet34-paad-20210101-7892b41f.pt -# If using a relative path, the path is relative to the location of the yaml file. -# file: /path/to/weights.pt -num_classes: 1 -transform: - # These are keyword arguments to the PatchClassification class. - resize_size: 224 # Must be a single integer. - mean: [0.7238, 0.5716, 0.6779] - std: [0.1120, 0.1459, 0.1089] -patch_size_pixels: 350 -# Patches are 525.1106 microns. -# Patch of 2078 pixels @ 0.2527 mpp is 350 pixels at our target spacing. -# (2078 * 0.2527) / 350 -spacing_um_px: 1.500316 -class_names: - - tumor diff --git a/wsinfer/modeldefs/resnet34_tcga-brca-v1.yaml b/wsinfer/modeldefs/resnet34_tcga-brca-v1.yaml deleted file mode 100644 index 096494d..0000000 --- a/wsinfer/modeldefs/resnet34_tcga-brca-v1.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# Configuration of a breast cancer tumor detection model. -# The specification version. Only 1.0 is supported at this time. -version: "1.0" -# The models are referenced by the pair of [architecture, weights], so this pair must -# be unique. -architecture: resnet34 # Must be a string. -name: TCGA-BRCA-v1 # Must be a string. -# Where to get the model weights. Either a URL or path to a file. -# If using a URL, set the url_file_name (the name of the file when it is downloaded). -url: https://stonybrookmedicine.box.com/shared/static/dv5bxk6d15uhmcegs9lz6q70yrmwx96p.pt -url_file_name: resnet34-brca-20190613-01eaf604.pt -# If using a relative path, the path is relative to the location of the yaml file. -# file: /path/to/weights.pt -num_classes: 2 -transform: - # These are keyword arguments to the PatchClassification class. - resize_size: 224 # Must be a single integer. - mean: [0.7238, 0.5716, 0.6779] # Must be a list of three floats. - std: [0.1120, 0.1459, 0.1089] # Must be a list of three floats. 
-patch_size_pixels: 350 -spacing_um_px: 0.25 -class_names: - - notumor - - tumor diff --git a/wsinfer/modeldefs/resnet34_tcga-luad-v1.yaml b/wsinfer/modeldefs/resnet34_tcga-luad-v1.yaml deleted file mode 100644 index bbc5b8d..0000000 --- a/wsinfer/modeldefs/resnet34_tcga-luad-v1.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# Configuration of a lung adenocarcinoma tumor detection model. -# The specification version. Only 1.0 is supported at this time. -version: "1.0" -# The models are referenced by the pair of [architecture, weights], so this pair must -# be unique. -architecture: resnet34 # Must be a string. -name: TCGA-LUAD-v1 # Must be a string. -# Where to get the model weights. Either a URL or path to a file. -# If using a URL, set the url_file_name (the name of the file when it is downloaded). -url: https://stonybrookmedicine.box.com/shared/static/d6g9huv1olfu2mt9yaud9xqf9bdqx38i.pt -url_file_name: resnet34-luad-20210102-93038ae6.pt -# If using a relative path, the path is relative to the location of the yaml file. -# file: /path/to/weights.pt -num_classes: 6 -transform: - # These are keyword arguments to the PatchClassification class. - resize_size: 224 # Must be a single integer. - mean: [0.8301, 0.6600, 0.8054] - std: [0.0864, 0.1602, 0.0647] -patch_size_pixels: 350 -spacing_um_px: 0.5 -class_names: - - lepidic - - benign - - acinar - - micropapillary - - mucinous - - solid diff --git a/wsinfer/modeldefs/resnet34_tcga-prad-v1.yaml b/wsinfer/modeldefs/resnet34_tcga-prad-v1.yaml deleted file mode 100644 index 179ae9f..0000000 --- a/wsinfer/modeldefs/resnet34_tcga-prad-v1.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# Configuration of a prostate adenocarcinoma tumor detection model. -# The specification version. Only 1.0 is supported at this time. -version: "1.0" -# The models are referenced by the pair of [architecture, weights], so this pair must -# be unique. -architecture: resnet34 # Must be a string. -name: TCGA-PRAD-v1 # Must be a string. -# Where to get the model weights. Either a URL or path to a file. -# If using a URL, set the url_file_name (the name of the file when it is downloaded). -url: https://stonybrookmedicine.box.com/shared/static/nxyr5atk2nlvgibck3l0q6rjin2g7n38.pt -url_file_name: resnet34-prad-20210101-ea6c004c.pt -# If using a relative path, the path is relative to the location of the yaml file. -# file: /path/to/weights.pt -num_classes: 3 -transform: - # These are keyword arguments to the PatchClassification class. - resize_size: 224 # Must be a single integer. - mean: [0.6462, 0.5070, 0.8055] - std: [0.1381, 0.1674, 0.1358] -patch_size_pixels: 175 -spacing_um_px: 0.5 -class_names: - - grade3 - - grade4or5 - - benign diff --git a/wsinfer/modeldefs/vgg16_tcga-tils-v1.yaml b/wsinfer/modeldefs/vgg16_tcga-tils-v1.yaml deleted file mode 100644 index 158d933..0000000 --- a/wsinfer/modeldefs/vgg16_tcga-tils-v1.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# Configuration of a tumor infiltrating lymphocyte detection model (VGG16). -version: "1.0" -architecture: vgg16 -name: TCGA-TILs-v1 -url: https://stonybrookmedicine.box.com/shared/static/0orxxw2aai3l3lztetvukwqetvr3z4lr.pt -url_file_name: vgg16-tils-20220112-3088cb70.pt -num_classes: 2 -transform: - resize_size: 224 - # Normalize to [-1, 1] - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] -patch_size_pixels: 100 -spacing_um_px: 0.5 -class_names: - - notils - - tils -metadata: - notes: | - Original code available at https://github.com/SBU-BMI/u24_lymphocyte. 
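The modeldef YAMLs deleted here are superseded by configs validated against wsinfer/schemas/model-config.schema.json (added later in this patch). As an illustrative sketch only (not a file in this patch), the deleted vgg16 TILs definition maps onto the new schema roughly as follows; the transform `arguments` are passed straight through to the torchvision constructors:

# Hypothetical new-style equivalent of the deleted vgg16_tcga-tils-v1.yaml.
config = {
    "num_classes": 2,
    "patch_size_pixels": 100,
    "spacing_um_px": 0.5,
    "class_names": ["notils", "tils"],
    "transform": [
        {"name": "Resize", "arguments": {"size": 224}},
        {"name": "ToTensor"},
        {
            "name": "Normalize",
            "arguments": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
        },
    ],
}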
diff --git a/wsinfer/modeldefs/vgg16mod_tcga-BRCA-v1.yaml b/wsinfer/modeldefs/vgg16mod_tcga-BRCA-v1.yaml
deleted file mode 100644
index 43204d6..0000000
--- a/wsinfer/modeldefs/vgg16mod_tcga-BRCA-v1.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-# Configuration of a tumor infiltrating lymphocyte detection model.
-# The specification version. Only 1.0 is supported at this time.
-version: "1.0"
-# The models are referenced by the pair of [architecture, weights], so this pair must
-# be unique.
-# Inceptionv4 without batch normalization.
-architecture: vgg16mod # Must be a string.
-name: TCGA-BRCA-v1 # Must be a string.
-# Where to get the model weights. Either a URL or path to a file.
-# If using a URL, set the url_file_name (the name of the file when it is downloaded).
-url: https://stonybrookmedicine.box.com/shared/static/197s56yvcrdpan7eu5tq8d4gxvq3xded.pt
-url_file_name: vgg16-modified-brca-20190613-62bc1b41.pt
-# If using a relative path, the path is relative to the location of the yaml file.
-# file: /path/to/weights.pt
-num_classes: 2
-transform:
-  # These are keyword arguments to the PatchClassification class.
-  resize_size: 224 # Must be a single integer.
-  mean: [0.7238, 0.5716, 0.6779] # Must be a list of three floats.
-  std: [0.1120, 0.1459, 0.1089] # Must be a list of three floats.
-patch_size_pixels: 350
-spacing_um_px: 0.25
-class_names:
-  - notumor
-  - tumor
-metadata:
-  notes: |
-    This model is a modified VGG16. The second-to-last linear layer was removed. See
-    https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7369575/table/tbl3/ for details.
diff --git a/wsinfer/modellib/__init__.py b/wsinfer/modellib/__init__.py
new file mode 100644
index 0000000..9d48db4
--- /dev/null
+++ b/wsinfer/modellib/__init__.py
@@ -0,0 +1 @@
+from __future__ import annotations
diff --git a/wsinfer/modellib/data.py b/wsinfer/modellib/data.py
new file mode 100644
index 0000000..ac1e45d
--- /dev/null
+++ b/wsinfer/modellib/data.py
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import h5py
+import numpy as np
+import torch
+from PIL import Image
+
+from wsinfer.wsi import WSI
+
+
+def _read_patch_coords(path: str | Path) -> np.ndarray:
+    """Read HDF5 file of patch coordinates and return a numpy array.
+
+    Returned array has shape (num_patches, 4). Each row has values
+    [minx, miny, width, height].
+    """
+    with h5py.File(path, mode="r") as f:
+        coords = f["/coords"][()]
+        coords_metadata = f["/coords"].attrs
+        if "patch_level" not in coords_metadata.keys():
+            raise KeyError(
+                "Could not find required key 'patch_level' in hdf5 of patch "
+                "coordinates. Has the version of CLAM been updated?"
+            )
+        patch_level = coords_metadata["patch_level"]
+        if patch_level != 0:
+            raise NotImplementedError(
+                f"This script is designed for patch_level=0 but got {patch_level}"
+            )
+        if coords.ndim != 2:
+            raise ValueError(f"expected coords to have 2 dimensions, got {coords.ndim}")
+        if coords.shape[1] != 2:
+            raise ValueError(
+                f"expected second dim of coords to have len 2 but got {coords.shape[1]}"
+            )
+
+        if "patch_size" not in coords_metadata.keys():
+            raise KeyError("expected key 'patch_size' in attrs of coords dataset")
+        # Append width and height values to the coords, so now each row is
+        # [minx, miny, width, height]
+        wh = np.full_like(coords, coords_metadata["patch_size"])
+        coords = np.concatenate((coords, wh), axis=1)
+
+    return coords
+
+
+def _filter_patches_in_rois(
+    *, geojson_path: str | Path, coords: np.ndarray
+) -> np.ndarray:
+    """Keep the patches that intersect the ROI(s).
+
+    Parameters
+    ----------
+    geojson_path : str, Path
+        Path to the GeoJSON file that encodes the points of the ROI(s).
+    coords : ndarray
+        Two-dimensional array where each row has minx, miny, width, height.
+
+    Returns
+    -------
+    ndarray of filtered coords.
+    """
+    import geojson
+    from shapely import STRtree
+    from shapely.geometry import box
+    from shapely.geometry import shape
+
+    with open(geojson_path) as f:
+        geo = geojson.load(f)
+    if not geo.is_valid:
+        raise ValueError("GeoJSON of ROI is not valid")
+    for roi in geo["features"]:
+        assert roi.is_valid, "an ROI geometry is not valid"
+    geoms_rois = [shape(roi["geometry"]) for roi in geo["features"]]
+    coords_orig = coords.copy()
+    coords = coords.copy()
+    coords[:, 2] += coords[:, 0]  # Calculate maxx.
+    coords[:, 3] += coords[:, 1]  # Calculate maxy.
+    boxes = [box(*coords[idx]) for idx in range(coords.shape[0])]
+    tree = STRtree(boxes)
+    _, intersecting_ids = tree.query(geoms_rois, predicate="intersects")
+    intersecting_ids = np.sort(np.unique(intersecting_ids))
+    return coords_orig[intersecting_ids]
+
+
+class WholeSlideImagePatches(torch.utils.data.Dataset):
+    """Dataset of one whole slide image.
+
+    This object retrieves patches from a whole slide image on the fly.
+
+    Parameters
+    ----------
+    wsi_path : str, Path
+        Path to whole slide image file.
+    patch_path : str, Path
+        Path to HDF5 file with patch coordinates for the slide.
+    um_px : float
+        Scale of the resulting patches. For example, 0.5 for ~20x magnification.
+    patch_size : int
+        The size of patches in pixels.
+    transform : callable, optional
+        A callable to modify a retrieved patch. The callable must accept a
+        PIL.Image.Image instance and return a torch.Tensor.
+    roi_path : str, Path, optional
+        Path to GeoJSON file that outlines the region of interest (ROI). Only patches
+        within the ROI(s) will be used.
+    """
+
+    def __init__(
+        self,
+        wsi_path: str | Path,
+        patch_path: str | Path,
+        um_px: float,
+        patch_size: int,
+        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
+        roi_path: Optional[str | Path] = None,
+    ):
+        self.wsi_path = wsi_path
+        self.patch_path = patch_path
+        self.um_px = float(um_px)
+        self.patch_size = int(patch_size)
+        self.transform = transform
+        self.roi_path = roi_path
+
+        assert Path(wsi_path).exists(), "wsi path not found"
+        assert Path(patch_path).exists(), "patch path not found"
+        if roi_path is not None:
+            assert Path(roi_path).exists(), "roi path not found"
+
+        self.patches = _read_patch_coords(self.patch_path)
+
+        # If an ROI is given, keep patches that intersect it.
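+        # (The ROI file is plain GeoJSON; feature geometries are interpreted
+        # in level-0 pixel coordinates, and a patch is kept when its bounding
+        # box intersects any ROI geometry.)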
+ if self.roi_path is not None: + self.patches = _filter_patches_in_rois( + geojson_path=self.roi_path, coords=self.patches + ) + if self.patches.shape[0] == 0: + raise ValueError("No patches left after taking intersection with ROI") + + assert self.patches.ndim == 2, "expected 2D array of patch coordinates" + # x, y, width, height + assert self.patches.shape[1] == 4, "expected second dimension to have len 4" + + def worker_init(self, *_): + self.slide = WSI(self.wsi_path) + + def __len__(self): + return self.patches.shape[0] + + def __getitem__( + self, idx: int + ) -> Tuple[Union[Image.Image, torch.Tensor], torch.Tensor]: + coords: Sequence[int] = self.patches[idx] + assert len(coords) == 4, "expected 4 coords (minx, miny, width, height)" + minx, miny, width, height = coords + + patch_im = self.slide.read_region( + location=(minx, miny), level=0, size=(width, height) + ) + patch_im = patch_im.convert("RGB") + + if self.transform is not None: + patch_im = self.transform(patch_im) + + return patch_im, torch.as_tensor([minx, miny, width, height]) diff --git a/wsinfer/modellib/models.py b/wsinfer/modellib/models.py new file mode 100644 index 0000000..bef488c --- /dev/null +++ b/wsinfer/modellib/models.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import dataclasses +import warnings +from typing import Callable +from typing import Union + + +import torch +import wsinfer_zoo +from wsinfer_zoo.client import HFModelTorchScript +from wsinfer_zoo.client import Model + + +@dataclasses.dataclass +class LocalModelTorchScript(Model): + ... + + +def get_registered_model(name: str) -> HFModelTorchScript: + registry = wsinfer_zoo.client.load_registry() + model = registry.get_model_by_name(name=name) + return model.load_model_torchscript() + + +def get_pretrained_torch_module( + model: HFModelTorchScript | LocalModelTorchScript, +) -> torch.nn.Module: + """Get a PyTorch Module with weights loaded.""" + return torch.jit.load(model.model_path, map_location="cpu") + + +def jit_compile( + model: torch.nn.Module, +) -> Union[torch.jit.ScriptModule, torch.nn.Module, Callable]: + """JIT-compile a model for inference. + + A torchscript model may be JIT compiled here as well. + """ + noncompiled = model + device = next(model.parameters()).device + # Attempt to script. If it fails, return the original. + test_input = torch.ones(1, 3, 224, 224).to(device) + w = "Warning: could not JIT compile the model. Using non-compiled model instead." + + # PyTorch 2.x has torch.compile but it does not work when applied + # to TorchScript models. + if hasattr(torch, "compile") and not isinstance(model, torch.jit.ScriptModule): + # Try to get the most optimized model. + try: + return torch.compile(model, fullgraph=True, mode="max-autotune") + except Exception: + pass + try: + return torch.compile(model, mode="max-autotune") + except Exception: + pass + try: + return torch.compile(model) + except Exception: + warnings.warn(w) + return noncompiled + # For pytorch 1.x, use torch.jit.script. + else: + try: + mjit = torch.jit.script(model) + with torch.no_grad(): + mjit(test_input) + except Exception: + warnings.warn(w) + return noncompiled + # Now that we have scripted the model, try to optimize it further. If that + # fails, return the scripted model. 
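+    # torch.jit.freeze inlines parameters and attributes into the graph, and
+    # torch.jit.optimize_for_inference adds inference-only rewrites such as
+    # conv-batchnorm folding. The test forward pass verifies the optimized
+    # module still runs before it is returned.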
+    try:
+        mjit_frozen = torch.jit.freeze(mjit)
+        mjit_opt = torch.jit.optimize_for_inference(mjit_frozen)
+        with torch.no_grad():
+            mjit_opt(test_input)
+        return mjit_opt
+    except Exception:
+        return mjit
diff --git a/wsinfer/modellib/run_inference.py b/wsinfer/modellib/run_inference.py
new file mode 100644
index 0000000..b8f5fb9
--- /dev/null
+++ b/wsinfer/modellib/run_inference.py
@@ -0,0 +1,211 @@
+"""Run inference.
+
+From the original paper (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7369575/):
+> In the prediction (test) phase, no data augmentation was applied except for the
+> normalization of the color channels.
+"""
+from __future__ import annotations
+
+import typing
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+import tqdm
+import wsinfer_zoo.client
+
+from .. import errors
+from .data import WholeSlideImagePatches
+from .models import LocalModelTorchScript
+from .models import get_pretrained_torch_module
+from .models import jit_compile
+from .transforms import make_compose_from_transform_config
+
+
+def run_inference(
+    wsi_dir: str | Path,
+    results_dir: str | Path,
+    model_info: wsinfer_zoo.client.HFModelTorchScript | LocalModelTorchScript,
+    batch_size: int = 32,
+    num_workers: int = 0,
+    speedup: bool = False,
+    roi_dir: Optional[str | Path] = None,
+) -> Tuple[List[str], List[str]]:
+    """Run model inference on a directory of whole slide images and save results to CSV.
+
+    This assumes the patching has already been done and the results are stored in
+    `results_dir`. An error will be raised otherwise.
+
+    Output CSV files are written to `{results_dir}/model-outputs/`.
+
+    Parameters
+    ----------
+    wsi_dir : str or Path
+        Directory containing whole slide images. This directory can *only* contain
+        whole slide images. Otherwise, an error will be raised during model inference.
+    results_dir : str or Path
+        Directory containing results of patching.
+    model_info :
+        Instance of HFModelTorchScript or LocalModelTorchScript, including the path
+        to the model and the configuration for applying it to new data.
+    batch_size : int
+        The batch size during the forward pass (default is 32).
+    num_workers : int
+        Number of workers for data loading (default is 0, meaning use a single thread).
+    speedup : bool
+        If True, JIT-compile the model. This has a startup cost but model inference
+        should be faster (default False).
+
+    Returns
+    -------
+    A tuple of two lists of strings. The first list contains the slide IDs for which
+    patching failed, and the second list contains the slide IDs for which model
+    inference failed.
+    """
+    # Make sure required directories exist.
+    wsi_dir = Path(wsi_dir)
+    if not wsi_dir.exists():
+        raise errors.WholeSlideImageDirectoryNotFound(f"directory not found: {wsi_dir}")
+    wsi_paths = list(wsi_dir.glob("*"))
+    if not wsi_paths:
+        raise errors.WholeSlideImagesNotFound(wsi_dir)
+    results_dir = Path(results_dir)
+    if not results_dir.exists():
+        raise errors.ResultsDirectoryNotFound(results_dir)
+
+    # Check patches directory.
+    patch_dir = results_dir / "patches"
+    if not patch_dir.exists():
+        raise errors.PatchDirectoryNotFound("Results dir must include 'patches' dir")
+    # Create the patch paths based on the whole slide image paths. In effect, only
+    # create patch paths if the whole slide image path exists.
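+    # e.g. {wsi_dir}/slide-01.svs -> {results_dir}/patches/slide-01.h5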
+    patch_paths = [patch_dir / p.with_suffix(".h5").name for p in wsi_paths]
+
+    model_output_dir = results_dir / "model-outputs"
+    model_output_dir.mkdir(exist_ok=True)
+
+    model = get_pretrained_torch_module(model=model_info)
+    model.eval()
+
+    # Set the device.
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        if torch.cuda.device_count() > 1:
+            model = torch.nn.DataParallel(model)
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+    print(f'Using device "{device}"')
+
+    model.to(device)
+
+    if speedup:
+        if typing.TYPE_CHECKING:
+            model = typing.cast(torch.nn.Module, jit_compile(model))
+        else:
+            model = jit_compile(model)
+
+    transform = make_compose_from_transform_config(model_info.config.transform)
+
+    failed_patching = [p.stem for p in patch_paths if not p.exists()]
+    failed_inference: List[str] = []
+
+    # Get paths to ROI geojson files.
+    if roi_dir is not None:
+        roi_paths = [Path(roi_dir) / p.with_suffix(".json").name for p in wsi_paths]
+    else:
+        roi_paths = None
+
+    # results_for_all_slides: typing.List[pd.DataFrame] = []
+    for i, (wsi_path, patch_path) in enumerate(zip(wsi_paths, patch_paths)):
+        print(f"Slide {i+1} of {len(wsi_paths)}")
+        print(f" Slide path: {wsi_path}")
+        print(f" Patch path: {patch_path}")
+
+        slide_csv_name = Path(wsi_path).with_suffix(".csv").name
+        slide_csv = model_output_dir / slide_csv_name
+        if slide_csv.exists():
+            print("Output CSV exists... skipping.")
+            print(slide_csv)
+            continue
+
+        if not patch_path.exists():
+            print(f"Skipping because patch file not found: {patch_path}")
+            continue
+
+        roi_path = None
+        if roi_paths is not None:
+            roi_path = roi_paths[i]
+            # We grab all potential names of ROI paths, but we do not require all of
+            # them to exist. We only use those that exist.
+            if not roi_path.exists():
+                roi_path = None
+            else:
+                print(f" ROI path: {roi_path}")
+
+        try:
+            dset = WholeSlideImagePatches(
+                wsi_path=wsi_path,
+                patch_path=patch_path,
+                um_px=model_info.config.spacing_um_px,
+                patch_size=model_info.config.patch_size_pixels,
+                transform=transform,
+                roi_path=roi_path,
+            )
+        except Exception:
+            failed_inference.append(wsi_path.stem)
+            continue
+
+        # The worker_init_fn does not seem to be used when num_workers=0
+        # so we call it manually to finish setting up the dataset.
+        if num_workers == 0:
+            dset.worker_init()
+
+        loader = torch.utils.data.DataLoader(
+            dset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers,
+            worker_init_fn=dset.worker_init,
+        )
+
+        # Store the coordinates and model probabilities of each patch in this slide.
+        # This lets us know where the probabilities map to in the slide.
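+        # Coordinates accumulate as (N, 4) arrays and probabilities as
+        # (N, num_classes) arrays ((N,) for single-logit models); both are
+        # concatenated after the loop and written as one CSV row per patch.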
+ slide_coords: List[np.ndarray] = [] + slide_probs: List[np.ndarray] = [] + for batch_imgs, batch_coords in tqdm.tqdm(loader): + assert batch_imgs.shape[0] == batch_coords.shape[0], "length mismatch" + with torch.no_grad(): + logits: torch.Tensor = model(batch_imgs.to(device)).detach().cpu() + # probs has shape (batch_size, num_classes) or (batch_size,) + if len(logits.shape) > 1 and logits.shape[1] > 1: + probs = torch.nn.functional.softmax(logits, dim=1) + else: + probs = torch.sigmoid(logits.squeeze(1)) + slide_coords.append(batch_coords.numpy()) + slide_probs.append(probs.numpy()) + + slide_coords_arr = np.concatenate(slide_coords, axis=0) + slide_df = pd.DataFrame( + dict( + minx=slide_coords_arr[:, 0], + miny=slide_coords_arr[:, 1], + width=slide_coords_arr[:, 2], + height=slide_coords_arr[:, 3], + ) + ) + slide_probs_arr = np.concatenate(slide_probs, axis=0) + # Use 'prob-' prefix for all classes. This should make it clearer that the + # column has probabilities for the class. It also makes it easier for us to + # identify columns associated with probabilities. + prob_colnames = [f"prob_{c}" for c in model_info.config.class_names] + slide_df.loc[:, prob_colnames] = slide_probs_arr + slide_df.to_csv(slide_csv, index=False) + print("-" * 40) + + return failed_patching, failed_inference diff --git a/wsinfer/modellib/transforms.py b/wsinfer/modellib/transforms.py new file mode 100644 index 0000000..eddce9d --- /dev/null +++ b/wsinfer/modellib/transforms.py @@ -0,0 +1,28 @@ +"""PyTorch image classification transform.""" + +from __future__ import annotations + +from typing import List + +from torchvision import transforms +from wsinfer_zoo.client import TransformConfigurationItem + +# The subset of transforms known to the wsinfer config spec. +# This can be expanded in the future as needs arise. +_name_to_tv_cls = { + "Resize": transforms.Resize, + "ToTensor": transforms.ToTensor, + "Normalize": transforms.Normalize, +} + + +def make_compose_from_transform_config( + list_of_transforms: List[TransformConfigurationItem], +) -> transforms.Compose: + """Create a torchvision Compose instance from configuration of transforms.""" + all_t: List = [] + for t in list_of_transforms: + cls = _name_to_tv_cls[t.name] + kwargs = t.arguments or {} + all_t.append(cls(**kwargs)) + return transforms.Compose(all_t) diff --git a/wsinfer/_patchlib/README.md b/wsinfer/patchlib/README.md similarity index 100% rename from wsinfer/_patchlib/README.md rename to wsinfer/patchlib/README.md diff --git a/wsinfer/_patchlib/__init__.py b/wsinfer/patchlib/__init__.py similarity index 96% rename from wsinfer/_patchlib/__init__.py rename to wsinfer/patchlib/__init__.py index ac575c9..215c3f7 100644 --- a/wsinfer/_patchlib/__init__.py +++ b/wsinfer/patchlib/__init__.py @@ -18,3 +18,5 @@ # - add --patch_spacing command line arg to request a patch size at a particular # spacing. The patch coordinates are calculated at the base (highest) resolution. # - format code with black + +from __future__ import annotations diff --git a/wsinfer/_patchlib/create_patches_fp.py b/wsinfer/patchlib/create_patches_fp.py similarity index 90% rename from wsinfer/_patchlib/create_patches_fp.py rename to wsinfer/patchlib/create_patches_fp.py index badb7c0..dc8125b 100644 --- a/wsinfer/_patchlib/create_patches_fp.py +++ b/wsinfer/patchlib/create_patches_fp.py @@ -31,19 +31,25 @@ (at your option) any later version. 
""" -# internal imports -from .wsi_core.WholeSlideImage import WholeSlideImage -from .wsi_core.wsi_utils import StitchCoords -from .wsi_core.batch_process_utils import initialize_df +from __future__ import annotations # other imports import os import pathlib -import numpy as np import time -import pandas as pd from typing import Optional +import numpy as np +import pandas as pd + +from wsinfer.errors import CannotReadSpacing +from wsinfer.patchlib.wsi_core.batch_process_utils import initialize_df + +# internal imports +from wsinfer.patchlib.wsi_core.WholeSlideImage import WholeSlideImage +from wsinfer.patchlib.wsi_core.wsi_utils import StitchCoords +from wsinfer.wsi import get_avg_mpp + _script_path = pathlib.Path(__file__).resolve().parent @@ -95,8 +101,8 @@ def seg_and_patch( patch_save_dir, mask_save_dir, stitch_save_dir, - patch_size=256, - step_size=256, + patch_size: int = 256, + step_size: Optional[int] = None, seg_params={ "seg_level": -1, "sthresh": 8, @@ -156,6 +162,7 @@ def seg_and_patch( stitch_times = 0.0 orig_patch_size = patch_size + orig_step_size = step_size for i in range(total): df.to_csv(os.path.join(save_dir, "process_list_autogen.csv"), index=False) @@ -296,29 +303,35 @@ def seg_and_patch( # Added by Jakub Kaczmarzyk (github kaczmarj) to get patch size for a # particular spacing. The patching happens at the highest resolution, but # we want to extract patches at a particular spacing. + this_step_size = None if patch_spacing is not None: - from PIL import Image - - orig_max = Image.MAX_IMAGE_PIXELS - import large_image - - # Importing large_image changes MAX_IMAGE_PIXELS to None. - Image.MAX_IMAGE_PIXELS = orig_max - del orig_max, Image - - ts = large_image.getTileSource(full_path) - if ts.getMetadata()["mm_x"] is None: + try: + slide_mpp = get_avg_mpp(full_path) + except CannotReadSpacing: print("!" * 40) - print("SKIPPING this slide because I cannot find the spacing!") - print(full_path) + print("SKIPPING this slide because the spacing cannot be read") print("!" * 40) continue - patch_mm = patch_spacing / 1000 # convert micrometer to millimeter. - patch_size = orig_patch_size * patch_mm / ts.getMetadata()["mm_x"] + + patch_size = orig_patch_size * patch_spacing / slide_mpp patch_size = round(patch_size) - del ts - # Use non-overlapping patches by default. 
-            step_size = step_size or patch_size
+            print(
+                "Scaled patch size by the patch spacing (result is patches of"
+                f" {patch_size * slide_mpp} microns)"
+            )
+
+            # We use the variable orig_step_size because step_size is
+            # reassigned below, and we want the step size originally requested
+            # by the user on every iteration of this loop.
+            if orig_step_size is not None:
+                this_step_size = orig_step_size * patch_spacing / slide_mpp
+                this_step_size = round(this_step_size)
+                print(
+                    "Scaled step size by the patch spacing (result is steps of"
+                    f" {this_step_size * slide_mpp} microns)"
+                )
+
+            step_size = this_step_size or patch_size
+            print(f"Using patch size = {patch_size} @ {slide_mpp} MPP")
+            print(f"Using step size = {step_size} @ {slide_mpp} MPP")

         # ----------------------------------------------------------------------

         current_patch_params.update(
@@ -433,10 +446,10 @@ def create_patches(
             seg_params[key] = preset_df.loc[0, key]

         for key in filter_params.keys():
-            filter_params[key] = preset_df.loc[0, key]
+            filter_params[key] = preset_df.loc[0, key]  # type: ignore

         for key in vis_params.keys():
-            vis_params[key] = preset_df.loc[0, key]
+            vis_params[key] = preset_df.loc[0, key]  # type: ignore

         for key in patch_params.keys():
             patch_params[key] = preset_df.loc[0, key]
diff --git a/wsinfer/_patchlib/presets/tcga.csv b/wsinfer/patchlib/presets/tcga.csv
similarity index 100%
rename from wsinfer/_patchlib/presets/tcga.csv
rename to wsinfer/patchlib/presets/tcga.csv
diff --git a/wsinfer/patchlib/utils/__init__.py b/wsinfer/patchlib/utils/__init__.py
new file mode 100644
index 0000000..9d48db4
--- /dev/null
+++ b/wsinfer/patchlib/utils/__init__.py
@@ -0,0 +1 @@
+from __future__ import annotations
diff --git a/wsinfer/_patchlib/utils/file_utils.py b/wsinfer/patchlib/utils/file_utils.py
similarity index 97%
rename from wsinfer/_patchlib/utils/file_utils.py
rename to wsinfer/patchlib/utils/file_utils.py
index a20fd3b..4f2ef8d 100644
--- a/wsinfer/_patchlib/utils/file_utils.py
+++ b/wsinfer/patchlib/utils/file_utils.py
@@ -1,4 +1,7 @@
+from __future__ import annotations
+
 import pickle
+
 import h5py
diff --git a/wsinfer/_patchlib/wsi_core/WholeSlideImage.py b/wsinfer/patchlib/wsi_core/WholeSlideImage.py
similarity index 97%
rename from wsinfer/_patchlib/wsi_core/WholeSlideImage.py
rename to wsinfer/patchlib/wsi_core/WholeSlideImage.py
index e89ab6d..c7703e6 100644
--- a/wsinfer/_patchlib/wsi_core/WholeSlideImage.py
+++ b/wsinfer/patchlib/wsi_core/WholeSlideImage.py
@@ -1,27 +1,28 @@
+from __future__ import annotations
+
 import math
+import multiprocessing as mp
 import os
 from xml.dom import minidom
-import multiprocessing as mp
+
 import cv2
 import numpy as np
-import openslide
 from PIL import Image
-from .wsi_utils import (
-    savePatchIter_bag_hdf5,
-    initialize_hdf5_bag,
-    save_hdf5,
-    isBlackPatch,
-    isWhitePatch,
-)
-from .util_classes import (
-    isInContourV1,
-    isInContourV2,
-    isInContourV3_Easy,
-    isInContourV3_Hard,
-    Contour_Checking_fn,
-)
-from ..utils.file_utils import load_pkl, save_pkl
+from wsinfer.wsi import WSI
+
+from ..utils.file_utils import load_pkl
+from ..utils.file_utils import save_pkl
+from .util_classes import Contour_Checking_fn
+from .util_classes import isInContourV1
+from .util_classes import isInContourV2
+from .util_classes import isInContourV3_Easy
+from .util_classes import isInContourV3_Hard
+from .wsi_utils import initialize_hdf5_bag
+from .wsi_utils import isBlackPatch
+from .wsi_utils import isWhitePatch
+from .wsi_utils import save_hdf5
+from .wsi_utils import savePatchIter_bag_hdf5

 Image.MAX_IMAGE_PIXELS = 933120000
@@ -35,7 +36,7 @@ def __init__(self, path):
         # self.name = 
".".join(path.split("/")[-1].split('.')[:-1]) self.name = os.path.splitext(os.path.basename(path))[0] - self.wsi = openslide.open_slide(path) + self.wsi = WSI(path) self.level_downsamples = self._assertLevelDownsamples() self.level_dim = self.wsi.level_dimensions @@ -341,7 +342,7 @@ def createPatches_bag_hdf5( patch_size=256, step_size=256, save_coord=True, - **kwargs + **kwargs, ): contours = self.contours_tissue @@ -566,7 +567,7 @@ def process_contours( save_path, patch_size, step_size, - **kwargs + **kwargs, ) if len(asset_dict) > 0: if init: diff --git a/wsinfer/patchlib/wsi_core/__init__.py b/wsinfer/patchlib/wsi_core/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/wsinfer/patchlib/wsi_core/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/wsinfer/_patchlib/wsi_core/batch_process_utils.py b/wsinfer/patchlib/wsi_core/batch_process_utils.py similarity index 99% rename from wsinfer/_patchlib/wsi_core/batch_process_utils.py rename to wsinfer/patchlib/wsi_core/batch_process_utils.py index 1c86abe..8358261 100644 --- a/wsinfer/_patchlib/wsi_core/batch_process_utils.py +++ b/wsinfer/patchlib/wsi_core/batch_process_utils.py @@ -1,5 +1,7 @@ -import pandas as pd +from __future__ import annotations + import numpy as np +import pandas as pd def initialize_df( diff --git a/wsinfer/_patchlib/wsi_core/util_classes.py b/wsinfer/patchlib/wsi_core/util_classes.py similarity index 99% rename from wsinfer/_patchlib/wsi_core/util_classes.py rename to wsinfer/patchlib/wsi_core/util_classes.py index 785c57c..ab5d936 100644 --- a/wsinfer/_patchlib/wsi_core/util_classes.py +++ b/wsinfer/patchlib/wsi_core/util_classes.py @@ -1,6 +1,8 @@ +from __future__ import annotations + +import cv2 import numpy as np from PIL import Image -import cv2 class Mosaic_Canvas(object): diff --git a/wsinfer/_patchlib/wsi_core/wsi_utils.py b/wsinfer/patchlib/wsi_core/wsi_utils.py similarity index 99% rename from wsinfer/_patchlib/wsi_core/wsi_utils.py rename to wsinfer/patchlib/wsi_core/wsi_utils.py index 2ca1cfa..d9eccca 100644 --- a/wsinfer/_patchlib/wsi_core/wsi_utils.py +++ b/wsinfer/patchlib/wsi_core/wsi_utils.py @@ -1,9 +1,12 @@ +from __future__ import annotations + +import math +import os + +import cv2 import h5py import numpy as np -import os from PIL import Image -import math -import cv2 from .util_classes import Mosaic_Canvas diff --git a/wsinfer/_modellib/__init__.py b/wsinfer/py.typed similarity index 100% rename from wsinfer/_modellib/__init__.py rename to wsinfer/py.typed diff --git a/wsinfer/schemas/model-config.schema.json b/wsinfer/schemas/model-config.schema.json new file mode 100644 index 0000000..81edefe --- /dev/null +++ b/wsinfer/schemas/model-config.schema.json @@ -0,0 +1,59 @@ +{ + "$schema": "http://json-schema.org/draft-04/schema", + "type": "object", + "properties": { + "num_classes": { + "type": "integer", + "description": "The number of classes the model outputs", + "minimum": 1 + }, + "patch_size_pixels": { + "type": "integer", + "description": "The size of the patch in pixels (eg 350)", + "minimum": 1 + }, + "spacing_um_px": { + "type": "number", + "description": "The spacing of the patch in micrometers per pixel (eg 0.5)", + "minimum": 0 + }, + "class_names": { + "type": "array", + "description": "The names of the classes the model outputs. 
Length must be equal to 'num_classes'.",
+      "items": {
+        "type": "string"
+      },
+      "uniqueItems": true
+    },
+    "transform": {
+      "type": "array",
+      "items": {
+        "type": "object",
+        "properties": {
+          "name": {
+            "type": "string",
+            "enum": [
+              "Resize",
+              "ToTensor",
+              "Normalize"
+            ]
+          },
+          "arguments": {
+            "type": "object"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "name"
+        ]
+      }
+    }
+  },
+  "required": [
+    "num_classes",
+    "patch_size_pixels",
+    "spacing_um_px",
+    "class_names",
+    "transform"
+  ]
+}
diff --git a/wsinfer/wsi.py b/wsinfer/wsi.py
new file mode 100644
index 0000000..a605fa5
--- /dev/null
+++ b/wsinfer/wsi.py
@@ -0,0 +1,200 @@
+from __future__ import annotations
+
+import logging
+from fractions import Fraction
+from pathlib import Path
+from typing import Literal
+from typing import overload
+
+import tifffile
+
+from wsinfer.errors import CannotReadSpacing
+from wsinfer.errors import NoBackendException
+
+logger = logging.getLogger(__name__)
+
+
+try:
+    import openslide
+
+    HAS_OPENSLIDE = True
+    logger.debug("Imported openslide")
+except Exception as err:
+    HAS_OPENSLIDE = False
+    logger.debug(f"Unable to import openslide due to error: {err}")
+
+try:
+    import tiffslide
+
+    HAS_TIFFSLIDE = True
+    logger.debug("Imported tiffslide")
+except Exception as err:
+    HAS_TIFFSLIDE = False
+    logger.debug(f"Unable to import tiffslide due to error: {err}")
+
+
+@overload
+def set_backend(name: Literal["openslide"]) -> type[openslide.OpenSlide]:
+    ...
+
+
+@overload
+def set_backend(name: Literal["tiffslide"]) -> type[tiffslide.TiffSlide]:
+    ...
+
+
+def set_backend(
+    name: Literal["openslide"] | Literal["tiffslide"],
+) -> type[tiffslide.TiffSlide] | type[openslide.OpenSlide]:
+    global WSI
+    if name not in ["openslide", "tiffslide"]:
+        raise ValueError(f"Unknown backend: {name}")
+    logger.info(f"Setting backend to {name}")
+    if name == "openslide":
+        WSI = openslide.OpenSlide
+    elif name == "tiffslide":
+        WSI = tiffslide.TiffSlide
+    else:
+        raise ValueError(f"Unknown backend: {name}")
+    return WSI
+
+
+# Set the slide backend based on the environment.
+WSI: type[openslide.OpenSlide] | type[tiffslide.TiffSlide]
+if HAS_OPENSLIDE:
+    WSI = set_backend("openslide")
+elif HAS_TIFFSLIDE:
+    WSI = set_backend("tiffslide")
+else:
+    raise NoBackendException("No backend found! 
Please install openslide or tiffslide") + + +def _get_mpp_openslide(slide_path: str | Path): + """Read MPP using OpenSlide.""" + logger.debug("Attempting to read MPP using OpenSlide") + slide = openslide.OpenSlide(slide_path) + mppx: float | None = None + mppy: float | None = None + + if ( + openslide.PROPERTY_NAME_MPP_X in slide.properties + and openslide.PROPERTY_NAME_MPP_Y in slide.properties + ): + logger.debug( + "Properties of the OpenSlide object contains keys" + f" {openslide.PROPERTY_NAME_MPP_X} and {openslide.PROPERTY_NAME_MPP_Y}" + ) + mppx = slide.properties[openslide.PROPERTY_NAME_MPP_X] + mppy = slide.properties[openslide.PROPERTY_NAME_MPP_Y] + logger.debug( + f"Value of {openslide.PROPERTY_NAME_MPP_X} is {mppx} and value" + f" of {openslide.PROPERTY_NAME_MPP_Y} is {mppy}" + ) + if mppx is not None and mppy is not None: + try: + logger.debug("Attempting to convert these MPP strings to floats") + mppx = float(mppx) + mppy = float(mppy) + return mppx, mppy + except Exception as err: + logger.debug(f"Exception caught while converting to float: {err}") + else: + logger.debug( + "Properties of the OpenSlide object does not contain keys" + f" {openslide.PROPERTY_NAME_MPP_X} and {openslide.PROPERTY_NAME_MPP_Y}" + ) + raise CannotReadSpacing() + + +def _get_mpp_tiffslide( + slide_path: str | Path, +) -> tuple[float, float]: + """Read MPP using TiffSlide.""" + slide = tiffslide.TiffSlide(slide_path) + mppx: float | None = None + mppy: float | None = None + if ( + tiffslide.PROPERTY_NAME_MPP_X in slide.properties + and tiffslide.PROPERTY_NAME_MPP_Y in slide.properties + ): + mppx = slide.properties[tiffslide.PROPERTY_NAME_MPP_X] + mppy = slide.properties[tiffslide.PROPERTY_NAME_MPP_Y] + if mppx is None or mppy is None: + raise CannotReadSpacing() + else: + try: + mppx = float(mppx) + mppy = float(mppy) + return mppx, mppy + except Exception as err: + raise CannotReadSpacing() from err + raise CannotReadSpacing() + + +# Modified from +# https://github.com/bayer-science-for-a-better-life/tiffslide/blob/8bea5a4c8e1429071ade6d4c40169ce153786d19/tiffslide/tiffslide.py#L712-L745 +def _get_mpp_tifffile(slide_path: str | Path) -> tuple[float, float]: + """Read MPP using Tifffile.""" + with tifffile.TiffFile(slide_path) as tif: + series0 = tif.series[0] + page0 = series0[0] + try: + resolution_unit = page0.tags["ResolutionUnit"].value + x_resolution = Fraction(*page0.tags["XResolution"].value) + y_resolution = Fraction(*page0.tags["YResolution"].value) + except KeyError as err: + raise CannotReadSpacing() from err + + RESUNIT = tifffile.TIFF.RESUNIT + scale = { + RESUNIT.INCH: 25400.0, + RESUNIT.CENTIMETER: 10000.0, + RESUNIT.MILLIMETER: 1000.0, + RESUNIT.MICROMETER: 1.0, + RESUNIT.NONE: None, + }.get(resolution_unit, None) + if scale is not None: + try: + mpp_x = scale / x_resolution + mpp_y = scale / y_resolution + return mpp_x, mpp_y + except ArithmeticError as err: + raise CannotReadSpacing() from err + raise CannotReadSpacing() + + +def get_avg_mpp(slide_path: Path | str) -> float: + """Return the average MPP of a whole slide image. + + The value is in units of micrometers per pixel and is + the average of the X and Y dimensions. + + Raises + ------ + CannotReadSpacing if the spacing cannot be read. + """ + + mppx: float + mppy: float + + if HAS_OPENSLIDE: + try: + mppx, mppy = _get_mpp_openslide(slide_path) + return (mppx + mppy) / 2 + except CannotReadSpacing: + # At this point, we want to continue to other implementations. 
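+            # tiffslide is tried next, then tifffile, which parses the
+            # ResolutionUnit and XResolution/YResolution TIFF tags directly.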
+ pass + if HAS_TIFFSLIDE: + try: + mppx, mppy = _get_mpp_tiffslide(slide_path) + return (mppx + mppy) / 2 + except CannotReadSpacing: + # Our last hope to read the mpp is tifffile. + pass + try: + mppx, mppy = _get_mpp_tifffile(slide_path) + return (mppx + mppy) / 2 + except CannotReadSpacing: + pass + + raise CannotReadSpacing(slide_path)
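
For reference, a minimal sketch of how the pieces of the new wsinfer.wsi module fit together ("slide.svs" is a placeholder path, not a file shipped with this change):

from wsinfer import wsi

# Force a specific backend (otherwise openslide is preferred when installed,
# with tiffslide as the fallback).
wsi.set_backend("tiffslide")

# Average of the X and Y spacing, read via openslide, tiffslide, or tifffile.
mpp = wsi.get_avg_mpp("slide.svs")

# wsi.WSI is the currently selected backend class; both backends expose the
# same read_region(location, level, size) interface.
slide = wsi.WSI("slide.svs")
patch = slide.read_region(location=(0, 0), level=0, size=(256, 256))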