Skip to content

Commit

Permalink
[update] Update 3D U-Net to export bioimage.io models
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanHCenalmor committed Dec 12, 2024
1 parent 14ae7e8 commit a11a50e
Show file tree
Hide file tree
Showing 4 changed files with 2,722 additions and 524 deletions.
4 changes: 2 additions & 2 deletions Colab_notebooks/Latest_Notebook_versions.csv
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Detectron 2D,1.15.1
Diffusion_Model,1.12
fnet (2D),1.14.1
fnet (3D),1.13.1
U-Net (2D),2.2.2
U-Net (3D),2.2.1
U-Net (2D),2.3.1
U-Net (3D),2.3.1
U-Net (2D) multilabel,2.1.3
Kaibu,1.13.2
MaskRCNN,1.14.1
Expand Down
95 changes: 67 additions & 28 deletions Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,14 @@
"\n",
" file.close()\n",
"\n",
"# Function to check source and target file paths include base path\n",
"def check_base_path(base_path, path_to_data):\n",
" if base_path not in path_to_data:\n",
" if path_to_data[0] == '/':\n",
" path_to_data = path_to_data[1:]\n",
" path_to_data = os.path.join(base_path, path_to_data)\n",
" return path_to_data\n",
"\n",
"import sys\n",
"before = [str(m) for m in sys.modules]\n",
"\n",
Expand Down Expand Up @@ -1328,11 +1336,21 @@
"# ------------- Initial user input ------------\n",
"#@markdown ###Path to training images:\n",
"Training_source = '' #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"Training_source = check_base_path(base_path, Training_source)\n",
"\n",
"Training_target = '' #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"Training_target = check_base_path(base_path, Training_target)\n",
"\n",
"model_name = '' #@param {type:\"string\"}\n",
"model_path = '' #@param {type:\"string\"}\n",
"\n",
"full_model_path = os.path.join(model_path, model_name)\n",
"\n",
"# Check that the base_path is on the path and otherwise add it\n",
"full_model_path = check_base_path(base_path, full_model_path)\n",
"\n",
"#@markdown ###Training parameters:\n",
"#@markdown Number of epochs\n",
"number_of_epochs = 200#@param {type:\"number\"}\n",
Expand All @@ -1354,7 +1372,6 @@
"\n",
"# ------------- Initialising folder, variables and failsafes ------------\n",
"# Create the folders where to save the model and the QC\n",
"full_model_path = os.path.join(model_path, model_name)\n",
"if os.path.exists(full_model_path):\n",
" print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)\n",
"\n",
Expand Down Expand Up @@ -1595,6 +1612,8 @@
"\n",
"#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n",
"pretrained_model_path = \"\" #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"pretrained_model_path = check_base_path(base_path, pretrained_model_path)\n",
"\n",
"#@markdown ###If you chose \"BioImage Model Zoo\", please provide the path or the URL to the model:\n",
"bioimageio_model_id = \"\" #@param {type:\"string\"}\n",
Expand Down Expand Up @@ -1905,25 +1924,28 @@
"#@markdown ###If not, please provide the path to the model folder:\n",
"\n",
"QC_model_folder = \"\" #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"\n",
"# In case the user chooses to use the trained model, load previous values\n",
"if (Use_the_current_trained_model):\n",
" QC_model_name = model_name\n",
" QC_model_path = model_path\n",
" QC_model_folder = os.path.join(model_path, model_name)\n",
"\n",
"# Check that the base_path is on the path and otherwise add it\n",
"QC_model_folder = check_base_path(base_path, QC_model_folder)\n",
"\n",
"#Here we define the loaded model name and path\n",
"QC_model_name = os.path.basename(QC_model_folder)\n",
"QC_model_path = os.path.dirname(QC_model_folder)\n",
"\n",
"full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n",
"\n",
"if (Use_the_current_trained_model):\n",
" print(\"Using current trained network\")\n",
" QC_model_name = model_name\n",
" QC_model_path = model_path\n",
"\n",
"\n",
"full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n",
"if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n",
" print(\"The \"+QC_model_name+\" network will be evaluated\")\n",
"else:\n",
" print(R+'!! WARNING: The chosen model does not exist !!'+W)\n",
" print('Please make sure you provide a valid model path and model name before proceeding further.')\n",
"\n"
" print('Please make sure you provide a valid model path and model name before proceeding further.')"
]
},
{
Expand Down Expand Up @@ -2025,18 +2047,20 @@
"# ------------- User input ------------\n",
"#@markdown ##Choose the folders that contain your Quality Control dataset\n",
"Source_QC_folder = \"\" #@param{type:\"string\"}\n",
"Target_QC_folder = \"\" #@param{type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"Source_QC_folder = check_base_path(base_path, Source_QC_folder)\n",
"\n",
"Target_QC_folder = \"\" #@param{type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"Target_QC_folder = check_base_path(base_path, Target_QC_folder)\n",
"\n",
"# ------------- Initialise folders ------------\n",
"# Create a quality control/Prediction Folder\n",
"prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n",
"if os.path.exists(prediction_QC_folder):\n",
" shutil.rmtree(prediction_QC_folder)\n",
"\n",
"os.makedirs(prediction_QC_folder)\n",
"\n",
"\n",
"# ------------- Prepare the model and run predictions ------------\n",
"\n",
"# Load the model\n",
Expand Down Expand Up @@ -2190,7 +2214,7 @@
"# ------------- User input ------------\n",
"# information about the model\n",
"#@markdown ##Insert the information to document your model:\n",
"Trained_model_name = \"\" #@param {type:\"string\"}\n",
"Trained_model_name = \"\" #@param {type:\"string\"}\n",
"Trained_model_description = \"\" #@param {type:\"string\"}\n",
"\n",
"#@markdown ###Author(s) - insert information separated by commas:\n",
Expand Down Expand Up @@ -2244,7 +2268,9 @@
"#@markdown ##Do you want to choose the example image?\n",
"default_example_image = True #@param {type:\"boolean\"}\n",
"#@markdown ###If not, please input:\n",
"fileID = \"\" #@param {type:\"string\"}\n",
"fileID = \"\" #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"fileID = check_base_path(base_path, fileID)\n",
"\n",
"# Check the example image\n",
"if default_example_image:\n",
Expand Down Expand Up @@ -2514,12 +2540,15 @@
},
"outputs": [],
"source": [
"\n",
"\n",
"# ------------- Initial user input ------------\n",
"#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n",
"Data_folder = '' #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"Data_folder = check_base_path(base_path, Data_folder)\n",
"\n",
"Results_folder = '' #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"Results_folder = check_base_path(base_path, Results_folder)\n",
"\n",
"# Create the results folder if needed\n",
"os.makedirs(Results_folder, exist_ok=True)\n",
Expand All @@ -2530,6 +2559,8 @@
"#@markdown ###If not, please provide the path to the model folder:\n",
"\n",
"Prediction_model_folder = \"\" #@param {type:\"string\"}\n",
"# Check that the base_path is on the path and otherwise add it\n",
"Prediction_model_folder = check_base_path(base_path, Prediction_model_folder)\n",
"\n",
"#Here we find the loaded model name and parent path\n",
"Prediction_model_name = os.path.basename(Prediction_model_folder)\n",
Expand All @@ -2554,8 +2585,6 @@
"\n",
"# Load the model and prepare generator\n",
"\n",
"\n",
"\n",
"unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'),\n",
" custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n",
"Input_size = unet.input_shape[1:3]\n",
Expand All @@ -2579,7 +2608,6 @@
"# ------------- For display ------------\n",
"print('--------------------------------------------------------------')\n",
"\n",
"\n",
"def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n",
"\n",
" plt.figure(figsize=(18,6))\n",
Expand All @@ -2603,9 +2631,7 @@
" plt.imshow(img_Mask, cmap='gray')\n",
" plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n",
"\n",
"\n",
"interact(show_prediction_mask, continuous_update=False);\n",
"\n"
"interact(show_prediction_mask, continuous_update=False);"
]
},
{
Expand All @@ -2627,15 +2653,14 @@
},
"outputs": [],
"source": [
"\n",
"# @markdown #Play this cell to save results as masks with the chosen threshold\n",
"\n",
"threshold = 120#@param {type:\"number\"}\n",
"\n",
"saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n",
"print('-------------------')\n",
"print('Masks were saved in: '+Results_folder)\n",
"\n",
"\n"
"print('-------------------')\n",
"print(f\"Masks were saved in: {Results_folder}\")"
]
},
{
Expand All @@ -2660,7 +2685,7 @@
"\n",
"---\n",
"\n",
"<font size = 4>**v2.2.2**: \n",
"<font size = 4>**v2.3.1**: \n",
"\n",
"* Updated Bioimage.IO model export to latest version (core-0.6.9, spec-0.5.3.2)\n",
"* Fixed model importation from Bioimage.IO\n",
Expand All @@ -2681,6 +2706,20 @@
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {
Expand Down
Loading

0 comments on commit a11a50e

Please sign in to comment.