
Commit

update to latest sam2
rentainhe committed Dec 21, 2024
2 parents dd4c514 + 2b90b9f commit 8b56c25
Showing 28 changed files with 1,796 additions and 434 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/check_fmt.yml
@@ -0,0 +1,17 @@
name: SAM2/fmt
on:
  pull_request:
    branches:
      - main
jobs:
  ufmt_check:
    runs-on: ubuntu-latest
    steps:
      - name: Check formatting
        uses: omnilib/ufmt@action-v1
        with:
          path: sam2 tools
          version: "2.0.0b2"
          python-version: "3.10"
          black-version: "24.2.0"
          usort-version: "1.0.2"
3 changes: 2 additions & 1 deletion .gitignore
@@ -144,4 +144,5 @@ dmypy.json
*.pth
outputs/

.idea/
.idea/
demo/backend/checkpoints/*.pt
6 changes: 3 additions & 3 deletions INSTALL.md
@@ -2,7 +2,7 @@

### Requirements

- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
@@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar

This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version of the library, while at run time it links against another. This might be because you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to keep only a single PyTorch and CUDA version.
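
A quick sanity check (a minimal sketch, assuming PyTorch is already installed) is to print the versions that are actually linked at run time and compare them with what you expect:

```python
import torch

print("torch:", torch.__version__)                # e.g. "2.5.1+cu121"
print("CUDA (built with):", torch.version.cuda)   # CUDA version PyTorch was compiled against
print("cuDNN:", torch.backends.cudnn.version())
print("CUDA available:", torch.cuda.is_available())
```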

In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.

We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
</details>

<details>
4 changes: 2 additions & 2 deletions LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023 - present, IDEA Research.
Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -198,4 +198,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
27 changes: 27 additions & 0 deletions RELEASE_NOTES.md
@@ -0,0 +1,27 @@
## SAM 2 release notes

### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`); see the usage sketch after this list.
* Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
* In the VOS prediction script `tools/vos_inference.py`, you can turn on this option via the `--use_vos_optimized_video_predictor` flag.
* Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
* **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We have also updated the implementation of the `SAM2VideoPredictor` class for SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which now allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
* Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
* This change relaxes the prompting assumptions for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame received clicks for only a subset of objects, the remaining (non-prompted) objects were assumed to be non-existent in that frame (i.e., in such frames, the user was effectively telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we make no assumptions about the remaining (non-prompted) objects: each object is handled independently and is not affected by how other objects are prompted. As a result, **we now allow adding new objects after tracking starts**, which was previously a restriction on usage.
* We believe the new version provides a more natural inference behavior and have therefore made it the default. The previous implementation of `SAM2VideoPredictor` is kept in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results produced with `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
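
Below is a minimal usage sketch of the new flag and the per-object prompting behavior, assuming the default config and checkpoint paths used elsewhere in this repo (e.g. in `sam2/benchmark.py`); the second click location is purely illustrative.

```python
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# Build the video predictor with full-model compilation (requires PyTorch >= 2.5.1).
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",    # model config (assumed path)
    "checkpoints/sam2.1_hiera_base_plus.pt",  # checkpoint (assumed path)
    device="cuda",
    vos_optimized=True,  # uses SAM2VideoPredictorVOS and compiles the full model
)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(video_path="notebooks/videos/bedroom")

    # Prompt one object on frame 0; non-prompted objects are no longer assumed absent.
    predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[210, 350]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )

    # Propagate the masks through the video.
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        pass  # consume per-frame masks here

    # With independent per-object inference, a new object can be added even after
    # tracking has started (previously not allowed); this click is hypothetical.
    predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,
        obj_id=2,
        points=np.array([[300, 200]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )
```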

### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released

- A new suite of improved model checkpoints (denoted as **SAM 2.1**) has been released. See [Model Description](#model-description) for details.
* To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.

### 07/29/2024 -- SAM 2 is released

- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
* SAM 2 code: https://github.com/facebookresearch/sam2
* SAM 2 demo: https://sam2.metademolab.com/
* SAM 2 paper: https://arxiv.org/abs/2408.00714
2 changes: 1 addition & 1 deletion backend.Dockerfile
@@ -1,4 +1,4 @@
ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
ARG MODEL_SIZE=base_plus

FROM ${BASE_IMAGE}
2 changes: 1 addition & 1 deletion demo/README.md
@@ -105,7 +105,7 @@ cd demo/backend/server/
```bash
PYTORCH_ENABLE_MPS_FALLBACK=1 \
APP_ROOT="$(pwd)/../../../" \
APP_URL=http://localhost:7263 \
API_URL=http://localhost:7263 \
MODEL_SIZE=base_plus \
DATA_PATH="$(pwd)/../../data" \
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[build-system]
requires = [
"setuptools>=61.0",
"torch>=2.3.1",
"torch>=2.5.1",
]
build-backend = "setuptools.build_meta"
92 changes: 92 additions & 0 deletions sam2/benchmark.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor

# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)


# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
    p
    for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)


# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()

# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            start = time.time()
            # Start tracking
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in predictor.propagate_in_video(inference_state):
                pass

            end = time.time()
            total += end - start
            count += 1
            if i == warm_up - 1:
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0

print("FPS: ", count * num_frames / total)
7 changes: 7 additions & 0 deletions sam2/build_sam.py
@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
vos_optimized=False,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
]
if vos_optimized:
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
"++model.compile_image_encoder=True", # Let sam2_base handle this
]

if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
@@ -36,7 +36,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -47,7 +47,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_l.yaml
@@ -40,7 +40,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -51,7 +51,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_s.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_t.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
@@ -97,7 +97,7 @@ trainer:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_b+.yaml
@@ -36,7 +36,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -47,7 +47,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_l.yaml
@@ -40,7 +40,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -51,7 +51,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_s.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_t.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1