From 31f4c412b2e7b488f2d3f7a2778fe1aaace0de9a Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Sat, 3 Aug 2024 17:52:36 +0100 Subject: [PATCH] feat: add example notebooks (#8) * feat: add nb * docs: update README * docs: fix link in example nb --- .github/README.md | 24 +- .pre-commit-config.yaml | 6 + ...mpted_segmentation_with_wandb_tables.ipynb | 618 ++++++++++++++++++ 3 files changed, 641 insertions(+), 7 deletions(-) create mode 100644 examples/notebooks/samv2_prompted_segmentation_with_wandb_tables.ipynb diff --git a/.github/README.md b/.github/README.md index c5e38e139..b1f14833a 100644 --- a/.github/README.md +++ b/.github/README.md @@ -1,13 +1,29 @@ +Open In Colab [![Build and Tests](https://github.com/SauravMaheshkar/samv2/actions/workflows/ci.yml/badge.svg)](https://github.com/SauravMaheshkar/samv2/actions/workflows/ci.yml) +CPU **compatible** fork of the official SAMv2 implementation. + +## Features 🚀 + +* CPU compatible +* ships with config files +* Run image and video inference on CPUs +* [Example notebooks](../examples/notebooks/) showcasing inference using weights and biases. + ## Installation -This is a CPU compatible fork of the official SAMv2 implementation. You can download it from [pypi](https://pypi.org/) using `pip` as follows: +You can download it from [pypi](https://pypi.org/) using `pip` as follows: ```bash pip install samv2 ``` +or from the repository: + +```bash +pip install git+https://github.com/SauravMaheshkar/samv2.git +``` + ## Usage After downloading the official weights, you can use the `load_model()` helper method to instantiate a model. @@ -22,12 +38,6 @@ model = load_model( ) ``` -## Features 🚀 - -* CPU compatible -* ships with config files -* Run image and video inference on CPUs - ## Citation ```bibtex diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aea0f5b91..32c074285 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,3 +16,9 @@ repos: rev: 5.13.2 hooks: - id: isort + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.8.5 + hooks: + - id: nbqa-black + - id: nbqa-isort + args: [ "--profile=black" ] diff --git a/examples/notebooks/samv2_prompted_segmentation_with_wandb_tables.ipynb b/examples/notebooks/samv2_prompted_segmentation_with_wandb_tables.ipynb new file mode 100644 index 000000000..8e3edb8f6 --- /dev/null +++ b/examples/notebooks/samv2_prompted_segmentation_with_wandb_tables.ipynb @@ -0,0 +1,618 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": "\"Open", + "id": "c5d558e73f170b5b" + }, + { + "cell_type": "markdown", + "source": [ + "## 📦 Packages and Basic Setup\n", + "---" + ], + "metadata": { + "id": "zhSRDDTS2qhe" + }, + "id": "zhSRDDTS2qhe" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": { + "collapsed": true, + "id": "initial_id" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install git+https://github.com/SauravMaheshkar/samv2.git wandb\n", + "!wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt\n", + "\n", + "url = \"https://github.com/SauravMaheshkar/SauravMaheshkar/blob/main/assets/text2img/llama_spiderman_coffee.png?raw=true\"" + ] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "\n", + "import wandb\n", + "from google.colab import userdata\n", + "\n", + "os.environ[\"WANDB_API_KEY\"] = userdata.get(\"W&B\")\n", + "\n", + "run = wandb.init(project=\"samv2\", entity=\"sauravmaheshkar\") # @param {type: \"string\"}\n", + "\n", + "columns = [\"image\", \"mask\", \"score\"]\n", + "wandb_table = wandb.Table(columns=columns)" + ], + "metadata": { + "id": "vY20gV4j9frm" + }, + "id": "vY20gV4j9frm", + "execution_count": null, + "outputs": [] + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-02T17:27:59.088404Z", + "start_time": "2024-08-02T17:27:58.330784Z" + }, + "id": "54c473e4017e8278" + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import requests\n", + "from PIL import Image\n", + "\n", + "image = Image.open(requests.get(url, stream=True).raw)\n", + "image = np.array(image.convert(\"RGB\"))" + ], + "id": "54c473e4017e8278", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "source": [ + "plt.imshow(image)" + ], + "metadata": { + "id": "R46X0MrpMKj4" + }, + "id": "R46X0MrpMKj4", + "execution_count": null, + "outputs": [] + }, + { + "metadata": { + "id": "c92dee6bc462cdc2" + }, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from sam2.build_sam import build_sam2\n", + "from sam2.sam2_image_predictor import SAM2ImagePredictor\n", + "from sam2.utils.misc import variant_to_config_mapping\n", + "from sam2.utils.visualization import show_masks\n", + "\n", + "model = build_sam2(\n", + " variant_to_config_mapping[\"tiny\"],\n", + " \"/content/sam2_hiera_tiny.pt\",\n", + ")\n", + "image_predictor = SAM2ImagePredictor(model)\n", + "image_predictor.set_image(image)" + ], + "id": "c92dee6bc462cdc2" + }, + { + "cell_type": "markdown", + "source": [ + "## Perform Segmentation with a single point" + ], + "metadata": { + "id": "j2jdXvBwSL6Z" + }, + "id": "j2jdXvBwSL6Z" + }, + { + "cell_type": "code", + "source": [ + "input_point = np.array([[300, 600]])\n", + "input_label = np.array([1])" + ], + "metadata": { + "id": "6AmRDG8-u_vk" + }, + "id": "6AmRDG8-u_vk", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "masks, scores, logits = image_predictor.predict(\n", + " point_coords=input_point,\n", + " point_labels=input_label,\n", + " box=None,\n", + " multimask_output=True,\n", + ")" + ], + "metadata": { + "id": "li3A3-lwKx2G" + }, + "id": "li3A3-lwKx2G", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask = show_masks(image, masks, scores)" + ], + "metadata": { + "id": "E566SeotyscN" + }, + "id": "E566SeotyscN", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sorted_ind = np.argsort(scores)[::-1]\n", + "print(f\"Top Score: {scores[sorted_ind[0]]}\")" + ], + "metadata": { + "id": "ykn5HCHA-mqw" + }, + "id": "ykn5HCHA-mqw", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask" + ], + "metadata": { + "id": "I0I-A9jO-ite" + }, + "id": "I0I-A9jO-ite", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "wandb_table.add_data(\n", + " wandb.Image(image), wandb.Image(output_mask), scores[sorted_ind[0]]\n", + ")" + ], + "metadata": { + "id": "Aa8ZPK77BAY8" + }, + "id": "Aa8ZPK77BAY8", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Perform Segmentation using multiple points" + ], + "metadata": { + "id": "4CyUCMRbSPzx" + }, + "id": "4CyUCMRbSPzx" + }, + { + "cell_type": "code", + "source": [ + "multi_point_coords = np.array([[300, 600], [700, 700]])\n", + "multi_point_labels = np.array([1, 1])" + ], + "metadata": { + "id": "WOgIC5EvRwEi" + }, + "id": "WOgIC5EvRwEi", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "masks, scores, _ = image_predictor.predict(\n", + " point_coords=multi_point_coords,\n", + " point_labels=multi_point_labels,\n", + " box=None,\n", + " multimask_output=False,\n", + ")" + ], + "metadata": { + "id": "b14bg796N-37" + }, + "id": "b14bg796N-37", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask = show_masks(image, masks, scores)" + ], + "metadata": { + "id": "ng9pGA-gSWtt" + }, + "id": "ng9pGA-gSWtt", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sorted_ind = np.argsort(scores)[::-1]\n", + "print(f\"Top Score: {scores[sorted_ind[0]]}\")" + ], + "metadata": { + "id": "p4DrtxUD_GcU" + }, + "id": "p4DrtxUD_GcU", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask" + ], + "metadata": { + "id": "BixsLaXv_G6j" + }, + "id": "BixsLaXv_G6j", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "wandb_table.add_data(\n", + " wandb.Image(image), wandb.Image(output_mask), scores[sorted_ind[0]]\n", + ")" + ], + "metadata": { + "id": "4Dk-YI1uBPR_" + }, + "id": "4Dk-YI1uBPR_", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Perform Segmentation using a single bounding box" + ], + "metadata": { + "id": "VNg52-FySjRj" + }, + "id": "VNg52-FySjRj" + }, + { + "cell_type": "code", + "source": [ + "single_box_coords = np.array([656, 655, 798, 816])" + ], + "metadata": { + "id": "S_5ElJvySkiU" + }, + "id": "S_5ElJvySkiU", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "masks, scores, _ = image_predictor.predict(\n", + " point_coords=None,\n", + " point_labels=None,\n", + " box=single_box_coords,\n", + " multimask_output=False,\n", + ")" + ], + "metadata": { + "id": "JKp9me1eR__y" + }, + "id": "JKp9me1eR__y", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask = show_masks(image, masks, scores=None, display_image=False)" + ], + "metadata": { + "id": "KDOcDGhhSEej" + }, + "id": "KDOcDGhhSEej", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sorted_ind = np.argsort(scores)[::-1]\n", + "print(f\"Top Score: {scores[sorted_ind[0]]}\")" + ], + "metadata": { + "id": "HMknWBHq_PUf" + }, + "id": "HMknWBHq_PUf", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask" + ], + "metadata": { + "id": "GO-Z_Aud_Pwq" + }, + "id": "GO-Z_Aud_Pwq", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "wandb_table.add_data(\n", + " wandb.Image(image), wandb.Image(output_mask), scores[sorted_ind[0]]\n", + ")" + ], + "metadata": { + "id": "GNWVKKOfBSyI" + }, + "id": "GNWVKKOfBSyI", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Perform Segmentation using multiple bounding boxes" + ], + "metadata": { + "id": "tVcf4zEj1XvI" + }, + "id": "tVcf4zEj1XvI" + }, + { + "cell_type": "code", + "source": [ + "multi_box_coords = np.array([[656, 655, 798, 816], [263, 518, 408, 653]])" + ], + "metadata": { + "id": "h1oYz2ta1L9V" + }, + "id": "h1oYz2ta1L9V", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "masks, scores, _ = image_predictor.predict(\n", + " point_coords=None,\n", + " point_labels=None,\n", + " box=multi_box_coords,\n", + " multimask_output=False,\n", + ")" + ], + "metadata": { + "id": "XgumUWOU1SlB" + }, + "id": "XgumUWOU1SlB", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask = show_masks(\n", + " image, masks, scores=None, only_best=False, display_image=False\n", + ")" + ], + "metadata": { + "id": "sdtCtZhQ1UiV" + }, + "id": "sdtCtZhQ1UiV", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sorted_ind = np.argsort(scores)[::-1]\n", + "print(f\"Top Score: {scores[sorted_ind[0]][0][0]}\")" + ], + "metadata": { + "id": "luAPytTQ_VuU" + }, + "id": "luAPytTQ_VuU", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask" + ], + "metadata": { + "id": "qtHF8WIW_WOS" + }, + "id": "qtHF8WIW_WOS", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "wandb_table.add_data(\n", + " wandb.Image(image), wandb.Image(output_mask), scores[sorted_ind[0]][0][0]\n", + ")" + ], + "metadata": { + "id": "iE7VvAA4BUhB" + }, + "id": "iE7VvAA4BUhB", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Perform Segmentation using a collection of boxes and points" + ], + "metadata": { + "id": "A67LVGps17XT" + }, + "id": "A67LVGps17XT" + }, + { + "cell_type": "code", + "source": [ + "box = np.array([263, 518, 408, 653])\n", + "point = np.array([[300, 600]])\n", + "label = np.array([1])" + ], + "metadata": { + "id": "d_U1vchO18Vb" + }, + "id": "d_U1vchO18Vb", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "masks, scores, _ = image_predictor.predict(\n", + " point_coords=point,\n", + " point_labels=label,\n", + " box=box,\n", + " multimask_output=False,\n", + ")" + ], + "metadata": { + "id": "2p7ZHsKu2LqO" + }, + "id": "2p7ZHsKu2LqO", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask = show_masks(image, masks, scores=None, display_image=False)" + ], + "metadata": { + "id": "IL_1x0Um2RaE" + }, + "id": "IL_1x0Um2RaE", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sorted_ind = np.argsort(scores)[::-1]\n", + "print(f\"Top Score: {scores[sorted_ind[0]]}\")" + ], + "metadata": { + "id": "BgjT2u2r_bas" + }, + "id": "BgjT2u2r_bas", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "output_mask" + ], + "metadata": { + "id": "2EjtuMa2_bvS" + }, + "id": "2EjtuMa2_bvS", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "wandb_table.add_data(\n", + " wandb.Image(image), wandb.Image(output_mask), scores[sorted_ind[0]]\n", + ")" + ], + "metadata": { + "id": "BFfRLf_CBZGF" + }, + "id": "BFfRLf_CBZGF", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "run.log({\"samv2_prompt_segmentation\": wandb_table})\n", + "\n", + "wandb.finish()" + ], + "metadata": { + "id": "2PdN7pwRBZqF" + }, + "id": "2PdN7pwRBZqF", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + }, + "colab": { + "provenance": [], + "collapsed_sections": [ + "4CyUCMRbSPzx", + "VNg52-FySjRj" + ], + "include_colab_link": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}