Update to latest SAM2 #76

Open · wants to merge 77 commits into base: main

Commits (77)
0c5f8c5
Initial commit
haithamkhedr Jul 29, 2024
658aaba
Use `weights_only` for loading
kit1980 Jul 29, 2024
662fd3d
Fix typo in README: "Aything" corrected to "Anything"
CharlesCNorton Jul 30, 2024
b3011f0
Merge pull request #5 from CharlesCNorton/patch-1
ronghanghu Jul 30, 2024
de05a2e
Correct typo in sav_dataset README.md
CharlesCNorton Jul 30, 2024
82b026c
Merge pull request #7 from CharlesCNorton/patch-2
ronghanghu Jul 30, 2024
f882beb
fix: correct spelling
CharlesCNorton Jul 30, 2024
c812718
Fix typo in comment: "evalaution" to "evaluation"
CharlesCNorton Jul 30, 2024
e62ec49
Fix: Hyphenate to "model-in-the-loop"
CharlesCNorton Jul 30, 2024
32750fa
Merge pull request #30 from CharlesCNorton/patch-4
ronghanghu Jul 30, 2024
cd270ed
Merge pull request #29 from CharlesCNorton/patch-3
ronghanghu Jul 30, 2024
86827e2
Merge pull request #32 from CharlesCNorton/patch-5
haithamkhedr Jul 30, 2024
fa2796b
Change git repo url from SSH to HTTPS
DanBrown47 Jul 31, 2024
0e78a11
Merge pull request #61 from DanBrown47/main
ronghanghu Aug 1, 2024
de4db16
Update README.md
haithamkhedr Aug 2, 2024
59550d4
Update README.md
haithamkhedr Aug 2, 2024
d1fc9a0
Merge pull request #116 from facebookresearch/arXiv-paper
haithamkhedr Aug 2, 2024
b744a3c
[doc] add `INSTALL.md` as an installation FAQ page
ronghanghu Aug 2, 2024
57bc94b
Merge pull request #119 from facebookresearch/ronghanghu/installation…
ronghanghu Aug 2, 2024
b72a8a9
First draft
NielsRogge Aug 3, 2024
17b7450
Use classmethod
NielsRogge Aug 3, 2024
3af4e82
Add model_id_to_filenames
NielsRogge Aug 3, 2024
0c28c63
Do not load config from the hub
NielsRogge Aug 3, 2024
6aeee34
Make huggingface_hub soft dependency
NielsRogge Aug 5, 2024
cb48213
Update links
NielsRogge Aug 5, 2024
e93be7f
Update README
NielsRogge Aug 5, 2024
841cc1f
Update docstring
NielsRogge Aug 5, 2024
acd3939
Add workflow
haithamkhedr Aug 5, 2024
3b0fd9e
Update workflow
haithamkhedr Aug 5, 2024
5e3d6ca
Merge pull request #1 from haithamkhedr/CI
haithamkhedr Aug 5, 2024
0230c5f
Merge pull request #152 from haithamkhedr/main
haithamkhedr Aug 5, 2024
c3393d8
Include original code snippet
NielsRogge Aug 5, 2024
e9503c9
Move HF to separate section
NielsRogge Aug 5, 2024
fbf7e3a
Add link
NielsRogge Aug 5, 2024
e815f70
Address comment
NielsRogge Aug 6, 2024
a36edf1
Clean up
NielsRogge Aug 6, 2024
6f7e700
Make it optional to build CUDA extension for SAM 2; also fallback to …
ronghanghu Aug 6, 2024
27a167c
Update README
NielsRogge Aug 6, 2024
0bac418
Update INSTALL.md (#156)
jhj0517 Aug 6, 2024
8f15c62
Format using ufmt
NielsRogge Aug 6, 2024
511199d
Updated INSTALL.md with CUDA_HOME-related troubleshooting (#140)
AmmoniumX Aug 6, 2024
322aa3e
Revert code snippet
NielsRogge Aug 6, 2024
43c385c
Update docstrings
NielsRogge Aug 6, 2024
6ec8560
Update hieradet.py
arun477 Aug 7, 2024
9b58611
Address comment
NielsRogge Aug 7, 2024
6ba4c65
Merge pull request #128 from NielsRogge/add_hf
haithamkhedr Aug 7, 2024
086daf0
Merge branch 'main' into patch-1
arun477 Aug 7, 2024
6ecb5ff
Add interface for box prompt in SAM 2 video predictor (#174)
ronghanghu Aug 7, 2024
6186d15
also catch errors during installation in case `CUDAExtension` cannot …
ronghanghu Aug 7, 2024
102ddb8
Merge branch 'main' into patch-1
arun477 Aug 8, 2024
d421e0b
add Colab support to the notebooks; pack config files in `sam2_config…
ronghanghu Aug 8, 2024
46945a2
Update hieradet.py
arun477 Aug 9, 2024
8f607e2
Merge branch 'main' into patch-1
arun477 Aug 9, 2024
778e112
Merge pull request #167 from arun477/patch-1
chayryali Aug 9, 2024
1034ee2
better support for non-CUDA devices (CPU, MPS) (#192)
ronghanghu Aug 12, 2024
dce7b54
improving warning message and adding further tips for installation (#…
ronghanghu Aug 12, 2024
1191677
Fix HF image predictor
haithamkhedr Aug 12, 2024
fd5125b
accept kwargs in auto_mask_generator
haithamkhedr Aug 13, 2024
0db838b
Merge pull request #205 from facebookresearch/haitham/fix_hf_image_pr…
haithamkhedr Aug 13, 2024
7e1596c
open `README.md` with unicode (to support Hugging Face emoji); fix va…
ronghanghu Aug 14, 2024
0f6515a
Merge branch 'main' into patch-1
kit1980 Aug 26, 2024
aa9b872
SAM2.1
haithamkhedr Sep 28, 2024
3a7889d
Merge pull request #335 from facebookresearch/sam2.1
chayryali Sep 29, 2024
429a2c7
minor update README.md
ronghanghu Sep 29, 2024
05d9e57
[docs] add a release note and new installation instructions for SAM 2…
ronghanghu Sep 30, 2024
98fcb16
Update links after renaming the repo from `segment-anything-2` to `sa…
ronghanghu Oct 1, 2024
52198ea
Merge pull request #2 from kit1980/patch-1
haithamkhedr Oct 1, 2024
8bf0920
Add MANIFEST.in (#353)
haithamkhedr Oct 3, 2024
e225218
[demo] add GPU to resources (#355)
ronghanghu Oct 3, 2024
29267c8
[doc] Check and raise an error if the user is running Python from the…
ronghanghu Oct 5, 2024
ff9704f
[sam2][demo][1/x] Fix file upload
raedle Oct 8, 2024
c98aa6b
Merge pull request #364 from facebookresearch/pr364
raedle Oct 8, 2024
c2ec8e1
remove unused paths (#384)
haithamkhedr Oct 14, 2024
393ae33
SAM 2 Update 12/11/2024 -- full model compilation for a major VOS spe…
ronghanghu Dec 11, 2024
722d1d1
patch for the case of `offload_state_to_cpu=True` in the new `SAM2Vid…
ronghanghu Dec 12, 2024
2b90b9f
remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and erro…
ronghanghu Dec 16, 2024
8b56c25
update to latest sam2
rentainhe Dec 21, 2024
17 changes: 17 additions & 0 deletions .github/workflows/check_fmt.yml
@@ -0,0 +1,17 @@
name: SAM2/fmt
on:
  pull_request:
    branches:
      - main
jobs:
  ufmt_check:
    runs-on: ubuntu-latest
    steps:
      - name: Check formatting
        uses: omnilib/ufmt@action-v1
        with:
          path: sam2 tools
          version: "2.0.0b2"
          python-version: "3.10"
          black-version: "24.2.0"
          usort-version: "1.0.2"
3 changes: 2 additions & 1 deletion .gitignore
@@ -144,4 +144,5 @@ dmypy.json
 *.pth
 outputs/
 
-.idea/
+.idea/
+demo/backend/checkpoints/*.pt
6 changes: 3 additions & 3 deletions INSTALL.md
@@ -2,7 +2,7 @@

### Requirements

-- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
+- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
@@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar

This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version while at run time it links against another. This can happen when you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to keep only a single PyTorch and CUDA version.

-In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
+In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.

-We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
+We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
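
As a quick diagnostic, a minimal sketch (standard PyTorch APIs only) prints the versions your environment actually loads at run time; these should match the versions SAM 2 was compiled against:

```python
import torch

# A mismatch between these runtime versions and the versions present at
# build time is a common cause of the undefined-symbol error above.
print("PyTorch:", torch.__version__)
print("CUDA (bundled with PyTorch):", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
```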
</details>

<details>
4 changes: 2 additions & 2 deletions LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

-Copyright 2023 - present, IDEA Research.
+Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -198,4 +198,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
-limitations under the License.
+limitations under the License.
27 changes: 27 additions & 0 deletions RELEASE_NOTES.md
@@ -0,0 +1,27 @@
## SAM 2 release notes

### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`); see the sketch after this list.
  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can turn on this option via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature.** (Earlier PyTorch versions might run into compilation errors in some cases.) We have therefore updated the minimum PyTorch version to 2.5.1 in the installation scripts.
- We have also updated the implementation of the `SAM2VideoPredictor` class for SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which now allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * **We handle the inference of each object independently** (as if opening a separate session for each object) while sharing their backbone features.
  * This change relaxes the prompting assumption for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame received clicks for only a subset of objects, the remaining (non-prompted) objects were assumed to be non-existent in this frame (i.e., in such frames, the user was telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we make no assumptions about the remaining objects (each object is handled independently and is not affected by how other objects are prompted). As a result, **new objects can now be added after tracking starts**, which was previously not allowed.
  * We believe the new version gives a more natural inference behavior and have therefore made it the default. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results using `tools/vos_inference.py` should remain the same after this change.
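
A minimal sketch of turning this on (the config and checkpoint paths below are the ones used in `sam2/benchmark.py`; adjust them to your setup):

```python
import torch

from sam2.build_sam import build_sam2_video_predictor

# vos_optimized=True swaps in the SAM2VideoPredictorVOS class and
# compiles the full model with torch.compile (requires PyTorch >= 2.5.1).
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "checkpoints/sam2.1_hiera_base_plus.pt",
    device=torch.device("cuda"),
    vos_optimized=True,
)
```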

### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released

- A new suite of improved model checkpoints (denoted as **SAM 2.1**) has been released. See [Model Description](#model-description) for details.
  * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation); a command sketch follows this list.
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
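
For reference, a sketch of those reinstall steps (run inside your clone of the repo; `pip install -e .` is one common install option, assuming a pip-based setup):

```bash
pip uninstall SAM-2    # remove the previously installed version
git pull               # pull the latest model code
pip install -e .       # reinstall the repo
```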

### 07/29/2024 -- SAM 2 is released

- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
* SAM 2 code: https://github.com/facebookresearch/sam2
* SAM 2 demo: https://sam2.metademolab.com/
* SAM 2 paper: https://arxiv.org/abs/2408.00714
2 changes: 1 addition & 1 deletion backend.Dockerfile
@@ -1,4 +1,4 @@
-ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
+ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
ARG MODEL_SIZE=base_plus

FROM ${BASE_IMAGE}
2 changes: 1 addition & 1 deletion demo/README.md
@@ -105,7 +105,7 @@ cd demo/backend/server/
```bash
PYTORCH_ENABLE_MPS_FALLBACK=1 \
APP_ROOT="$(pwd)/../../../" \
-APP_URL=http://localhost:7263 \
+API_URL=http://localhost:7263 \
MODEL_SIZE=base_plus \
DATA_PATH="$(pwd)/../../data" \
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [build-system]
 requires = [
     "setuptools>=61.0",
-    "torch>=2.3.1",
+    "torch>=2.5.1",
 ]
 build-backend = "setuptools.build_meta"
92 changes: 92 additions & 0 deletions sam2/benchmark.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor

# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)


# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
    p
    for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)


# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()

# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            start = time.time()
            # Start tracking
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in predictor.propagate_in_video(inference_state):
                pass

            end = time.time()
            total += end - start
            count += 1
            if i == warm_up - 1:
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0

print("FPS: ", count * num_frames / total)
7 changes: 7 additions & 0 deletions sam2/build_sam.py
@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    vos_optimized=False,
     **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
     ]
+    if vos_optimized:
+        hydra_overrides = [
+            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+            "++model.compile_image_encoder=True",  # Let sam2_base handle this
+        ]
+
     if apply_postprocessing:
         hydra_overrides_extra = hydra_overrides_extra.copy()
         hydra_overrides_extra += [
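
As a usage sketch (config and checkpoint paths follow `sam2/benchmark.py`; treat them as assumptions for your checkout), the new flag simply swaps which predictor class Hydra instantiates:

```python
from sam2.build_sam import build_sam2_video_predictor

cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
ckpt = "checkpoints/sam2.1_hiera_base_plus.pt"

base = build_sam2_video_predictor(cfg, ckpt)                      # SAM2VideoPredictor
fast = build_sam2_video_predictor(cfg, ckpt, vos_optimized=True)  # SAM2VideoPredictorVOS
print(type(base).__name__, type(fast).__name__)
```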
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_l.yaml
@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_s.yaml
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2.1/sam2.1_hiera_t.yaml
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

@@ -97,7 +97,7 @@ trainer:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_b+.yaml
@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_l.yaml
@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_s.yaml
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
4 changes: 2 additions & 2 deletions sam2/configs/sam2/sam2_hiera_t.yaml
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1