From 42d0bb782dab8fee285e6e2e1681add6c94aa595 Mon Sep 17 00:00:00 2001 From: Adrien Delessert Date: Sat, 28 Dec 2024 12:30:04 -0500 Subject: [PATCH 1/3] feat: update for 2.1 models --- sam2/configs/2.1/sam2.1_hiera_b+.yaml | 116 +++++++++++++++ sam2/configs/2.1/sam2.1_hiera_l.yaml | 120 +++++++++++++++ sam2/configs/2.1/sam2.1_hiera_s.yaml | 119 +++++++++++++++ sam2/configs/2.1/sam2.1_hiera_t.yaml | 121 ++++++++++++++++ sam2/modeling/backbones/hieradet.py | 33 ++++- sam2/modeling/backbones/image_encoder.py | 2 +- sam2/modeling/backbones/utils.py | 10 +- sam2/modeling/memory_attention.py | 7 +- sam2/modeling/memory_encoder.py | 2 +- sam2/modeling/position_encoding.py | 44 ++++-- sam2/modeling/sam/prompt_encoder.py | 33 ++++- sam2/modeling/sam/transformer.py | 22 +-- sam2/modeling/sam2_base.py | 177 +++++++++++++++++------ sam2/modeling/sam2_utils.py | 174 ++++++++++++++++++++++ sam2/utils/misc.py | 7 +- 15 files changed, 891 insertions(+), 96 deletions(-) create mode 100644 sam2/configs/2.1/sam2.1_hiera_b+.yaml create mode 100644 sam2/configs/2.1/sam2.1_hiera_l.yaml create mode 100644 sam2/configs/2.1/sam2.1_hiera_s.yaml create mode 100644 sam2/configs/2.1/sam2.1_hiera_t.yaml diff --git a/sam2/configs/2.1/sam2.1_hiera_b+.yaml b/sam2/configs/2.1/sam2.1_hiera_b+.yaml new file mode 100644 index 000000000..d7172f9b0 --- /dev/null +++ b/sam2/configs/2.1/sam2.1_hiera_b+.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid 
on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/2.1/sam2.1_hiera_l.yaml b/sam2/configs/2.1/sam2.1_hiera_l.yaml new file mode 100644 index 000000000..23073ea7a --- /dev/null +++ b/sam2/configs/2.1/sam2.1_hiera_l.yaml @@ -0,0 +1,120 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise 
convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/2.1/sam2.1_hiera_s.yaml b/sam2/configs/2.1/sam2.1_hiera_s.yaml new file mode 100644 index 000000000..fd8d40465 --- /dev/null +++ b/sam2/configs/2.1/sam2.1_hiera_s.yaml @@ -0,0 +1,119 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + 
layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/2.1/sam2.1_hiera_t.yaml b/sam2/configs/2.1/sam2.1_hiera_t.yaml new file mode 100644 index 000000000..e762aec93 --- /dev/null +++ b/sam2/configs/2.1/sam2.1_hiera_t.yaml @@ -0,0 +1,121 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [64, 64] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: 
sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + no_obj_embed_spatial: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: true + proj_tpos_enc_in_obj_ptrs: true + use_signed_tpos_enc_to_obj_ptrs: true + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py index 4c6d3b9fc..19ac77b61 100644 --- a/sam2/modeling/backbones/hieradet.py +++ b/sam2/modeling/backbones/hieradet.py @@ -4,19 +4,22 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import logging from functools import partial from typing import List, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr from sam2.modeling.backbones.utils import ( PatchEmbed, window_partition, window_unpartition, ) -from sam2.modeling.sam2_utils import MLP, DropPath + +from sam2.modeling.sam2_utils import DropPath, MLP def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: @@ -45,11 +48,7 @@ def __init__( self.dim = dim self.dim_out = dim_out - self.num_heads = num_heads - head_dim = dim_out // num_heads - self.scale = head_dim**-0.5 - self.q_pool = q_pool self.qkv = nn.Linear(dim, dim_out * 3) self.proj = nn.Linear(dim_out, dim_out) @@ -196,6 +195,7 @@ def __init__( 16, 20, ), + weights_path=None, return_interm_layers=True, # return feats from every stage ): super().__init__() @@ -265,6 +265,11 @@ def __init__( else [self.blocks[-1].dim_out] ) + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: h, w = hw window_embed = self.pos_embed_window @@ -292,3 +297,21 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: outputs.append(feats) return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/sam2/modeling/backbones/image_encoder.py b/sam2/modeling/backbones/image_encoder.py index 5f92baf47..c3ffefe5c 100644 --- a/sam2/modeling/backbones/image_encoder.py +++ b/sam2/modeling/backbones/image_encoder.py @@ -71,6 +71,7 @@ def __init__( self.position_encoding = position_encoding self.convs = nn.ModuleList() self.backbone_channel_list = backbone_channel_list + self.d_model = d_model for dim in backbone_channel_list: current = nn.Sequential() current.add_module( @@ -99,7 +100,6 @@ def __init__( self.fpn_top_down_levels = list(fpn_top_down_levels) def forward(self, xs: List[torch.Tensor]): - out = [None] * len(self.convs) pos = [None] * len(self.convs) assert len(xs) == len(self.convs) diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py index 32d55c754..930b1b762 100644 --- a/sam2/modeling/backbones/utils.py +++ b/sam2/modeling/backbones/utils.py @@ -32,9 +32,7 @@ def window_partition(x, window_size): Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) + windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) return windows, (Hp, Wp) @@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw): Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view( + x = windows.reshape( B, Hp // window_size, Wp // window_size, window_size, window_size, -1 ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + x = 
x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() + x = x[:, :H, :W, :] return x diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py index 00d708320..0cacb8b56 100644 --- a/sam2/modeling/memory_attention.py +++ b/sam2/modeling/memory_attention.py @@ -7,14 +7,14 @@ from typing import Optional import torch -from torch import Tensor, nn +from torch import nn, Tensor -from sam2.modeling.sam2_utils import get_activation_fn, get_clones from sam2.modeling.sam.transformer import RoPEAttention +from sam2.modeling.sam2_utils import get_activation_fn, get_clones -class MemoryAttentionLayer(nn.Module): +class MemoryAttentionLayer(nn.Module): def __init__( self, activation: str, @@ -87,7 +87,6 @@ def forward( query_pos: Optional[Tensor] = None, num_k_exclude_rope: int = 0, ) -> torch.Tensor: - # Self-Attn, Cross-Attn tgt = self._forward_sa(tgt, query_pos) tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py index e8b2df7aa..f60202dfa 100644 --- a/sam2/modeling/memory_encoder.py +++ b/sam2/modeling/memory_encoder.py @@ -11,7 +11,7 @@ import torch.nn as nn import torch.nn.functional as F -from sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones +from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d class MaskDownSampler(nn.Module): diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py index 4a308c30b..2241d4cf1 100644 --- a/sam2/modeling/position_encoding.py +++ b/sam2/modeling/position_encoding.py @@ -8,6 +8,7 @@ from typing import Any, Optional, Tuple import numpy as np + import torch from torch import nn @@ -15,7 +16,7 @@ class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. + used by the Attention Is All You Need paper, generalized to work on images. 
""" def __init__( @@ -24,6 +25,11 @@ def __init__( temperature: int = 10000, normalize: bool = True, scale: Optional[float] = None, + # Following settings only relevant + # for warmping up cache for compilation + warmup_cache: bool = True, + image_size: int = 1024, + strides: Tuple[int] = (4, 8, 16, 32), ): super().__init__() assert num_pos_feats % 2 == 0, "Expecting even model width" @@ -37,6 +43,12 @@ def __init__( self.scale = scale self.cache = {} + if warmup_cache and torch.cuda.is_available(): + # Warmup cache for cuda, to help with compilation + device = torch.device("cuda") + for stride in strides: + cache_key = (image_size // stride, image_size // stride) + self._pe(1, device, *cache_key) def _encode_xy(self, x, y): # The positions are expected to be normalized @@ -75,19 +87,20 @@ def encode_points(self, x, y, labels): return pos @torch.no_grad() - def forward(self, x: torch.Tensor): - cache_key = (x.shape[-2], x.shape[-1]) + def _pe(self, B, device, *cache_key): + H, W = cache_key if cache_key in self.cache: - return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) + y_embed = ( - torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, H + 1, dtype=torch.float32, device=device) .view(1, -1, 1) - .repeat(x.shape[0], 1, x.shape[-1]) + .repeat(B, 1, W) ) x_embed = ( - torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, W + 1, dtype=torch.float32, device=device) .view(1, 1, -1) - .repeat(x.shape[0], x.shape[-2], 1) + .repeat(B, H, 1) ) if self.normalize: @@ -95,7 +108,7 @@ def forward(self, x: torch.Tensor): y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -110,6 +123,12 @@ def forward(self, x: torch.Tensor): self.cache[cache_key] = pos[0] return pos + @torch.no_grad() + def forward(self, x: torch.Tensor): + B = x.shape[0] + cache_key = (x.shape[-2], x.shape[-1]) + return self._pe(B, x.device, *cache_key) + class PositionEmbeddingRandom(nn.Module): """ @@ -210,6 +229,11 @@ def apply_rotary_enc( # repeat freqs along seq_len dim to match k seq_len if repeat_freqs_k: r = xk_.shape[-2] // xq_.shape[-2] - freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py index e177a71a3..c57876264 100644 --- a/sam2/modeling/sam/prompt_encoder.py +++ b/sam2/modeling/sam/prompt_encoder.py @@ -10,6 +10,7 @@ from torch import nn from sam2.modeling.position_encoding import PositionEmbeddingRandom + from sam2.modeling.sam2_utils import LayerNorm2d @@ -91,12 +92,32 @@ def _embed_points( point_embedding = self.pe_layer.forward_with_coords( points, self.input_image_size ) - 
point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - point_embedding[labels == 2] += self.point_embeddings[2].weight - point_embedding[labels == 3] += self.point_embeddings[3].weight + + point_embedding = torch.where( + (labels == -1).unsqueeze(-1), + torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 0).unsqueeze(-1), + point_embedding + self.point_embeddings[0].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 1).unsqueeze(-1), + point_embedding + self.point_embeddings[1].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 2).unsqueeze(-1), + point_embedding + self.point_embeddings[2].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 3).unsqueeze(-1), + point_embedding + self.point_embeddings[3].weight, + point_embedding, + ) return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py index dbae913c9..f9fe9a3fb 100644 --- a/sam2/modeling/sam/transformer.py +++ b/sam2/modeling/sam/transformer.py @@ -5,20 +5,15 @@ # LICENSE file in the root directory of this source tree. import math -import warnings from functools import partial from typing import Tuple, Type import torch import torch.nn.functional as F -from torch import Tensor, nn +from torch import nn, Tensor from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from sam2.modeling.sam2_utils import MLP -from sam2.utils.misc import get_sdp_backends - -warnings.simplefilter(action="ignore", category=FutureWarning) -# OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() class TwoWayTransformer(nn.Module): @@ -245,9 +240,7 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: dropout_p = self.dropout_p if self.training else 0.0 # Attention - - with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) @@ -265,7 +258,7 @@ def __init__( # whether to repeat q rope to match k length # this is needed for cross-attention to memories rope_k_repeat=False, - feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution **kwargs, ): super().__init__(*args, **kwargs) @@ -274,7 +267,9 @@ def __init__( compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta ) freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) - self.freqs_cis = freqs_cis + self.freqs_cis = ( + freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis + ) self.rope_k_repeat = rope_k_repeat def forward( @@ -307,9 +302,8 @@ def forward( ) dropout_p = self.dropout_p if self.training else 0.0 - - with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + # Attention + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py index 4326f448d..d9f4e515b 100644 --- 
a/sam2/modeling/sam2_base.py +++ b/sam2/modeling/sam2_base.py @@ -7,12 +7,13 @@ import torch import torch.distributed import torch.nn.functional as F + from torch.nn.init import trunc_normal_ -from sam2.modeling.sam2_utils import MLP, get_1d_sine_pe, select_closest_cond_frames from sam2.modeling.sam.mask_decoder import MaskDecoder from sam2.modeling.sam.prompt_encoder import PromptEncoder from sam2.modeling.sam.transformer import TwoWayTransformer +from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames # a large negative value as a placeholder score for missing objects NO_OBJ_SCORE = -1024.0 @@ -58,9 +59,6 @@ def __init__( # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. memory_temporal_stride_for_eval=1, - # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click - # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames - add_all_frames_to_correct_as_cond=False, # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) non_overlap_masks_for_mem_enc=False, # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder @@ -72,6 +70,9 @@ def __init__( # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) only_obj_ptrs_in_the_past_for_eval=False, @@ -87,6 +88,8 @@ def __init__( # hope to make recovery easier if there is a mistake and mitigate accumulation of errors soft_no_obj_ptr: bool = False, use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. 
sam_mask_decoder_extra_args=None, compile_image_encoder: bool = False, @@ -109,12 +112,13 @@ def __init__( if proj_tpos_enc_in_obj_ptrs: assert add_tpos_enc_to_obj_ptrs # these options need to be used together self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval # Part 2: memory attention to condition current frame's visual features # with memories (and obj ptrs) from past frames self.memory_attention = memory_attention - self.hidden_dim = memory_attention.d_model + self.hidden_dim = image_encoder.neck.d_model # Part 3: memory encoder for the previous frame's outputs self.memory_encoder = memory_encoder @@ -169,9 +173,12 @@ def __init__( self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) trunc_normal_(self.no_obj_ptr, std=0.02) self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) self._build_sam_heads() - self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond self.max_cond_frames_in_attn = max_cond_frames_in_attn # Model compilation @@ -193,8 +200,8 @@ def device(self): def forward(self, *args, **kwargs): raise NotImplementedError( - "Please use the corresponding methods in SAM2VideoPredictor for inference." - "See notebooks/video_predictor_example.ipynb for an example." + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." ) def _build_sam_heads(self): @@ -387,8 +394,6 @@ def _forward_sam_heads( if self.pred_obj_scores: # Allow *soft* no obj ptr, unlike for masks if self.soft_no_obj_ptr: - # Only hard possible with gt - assert not self.teacher_force_obj_scores_for_mem lambda_is_obj_appearing = object_score_logits.sigmoid() else: lambda_is_obj_appearing = is_obj_appearing.float() @@ -512,6 +517,7 @@ def _prepare_memory_conditioned_features( return pix_feat num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 # Step 1: condition the visual features of the current frame on previous memories if not is_init_cond_frame: # Retrieve the memories encoded with the maskmem backbone @@ -527,9 +533,9 @@ def _prepare_memory_conditioned_features( t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 - # We also allow taking the memory frame non-consecutively (with r>1), in which case - # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. - r = self.memory_temporal_stride_for_eval + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. 
+ stride = 1 if self.training else self.memory_temporal_stride_for_eval for t_pos in range(1, self.num_maskmem): t_rel = self.num_maskmem - t_pos # how many frames before current frame if t_rel == 1: @@ -545,15 +551,15 @@ def _prepare_memory_conditioned_features( if not track_in_reverse: # first find the nearest frame among every r-th frames before this frame # for r=1, this would be (frame_idx - 2) - prev_frame_idx = ((frame_idx - 2) // r) * r + prev_frame_idx = ((frame_idx - 2) // stride) * stride # then seek further among every r-th frames - prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride else: # first find the nearest frame among every r-th frames after this frame # for r=1, this would be (frame_idx + 2) - prev_frame_idx = -(-(frame_idx + 2) // r) * r + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride # then seek further among every r-th frames - prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) if out is None: # If an unselected conditioning frame is among the last (self.num_maskmem - 1) @@ -566,10 +572,10 @@ def _prepare_memory_conditioned_features( continue # skip padding frames # "maskmem_features" might have been offloaded to CPU in demo use cases, # so we load it back to GPU (it's a no-op if it's already on GPU). - feats = prev["maskmem_features"].to(self.device) + feats = prev["maskmem_features"].to(device, non_blocking=True) to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) # Spatial positional encoding (it might have been offloaded to CPU in eval) - maskmem_enc = prev["maskmem_pos_enc"][-1].to(self.device) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) # Temporal positional encoding maskmem_enc = ( @@ -592,7 +598,14 @@ def _prepare_memory_conditioned_features( ptr_cond_outputs = selected_cond_outputs pos_and_ptrs = [ # Temporal pos encoding contains how far away each pointer is from current frame - (abs(frame_idx - t), out["obj_ptr"]) + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) for t, out in ptr_cond_outputs.items() ] # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame @@ -615,7 +628,9 @@ def _prepare_memory_conditioned_features( if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = torch.tensor(pos_list, device=device) + obj_pos = torch.tensor(pos_list).to( + device=device, non_blocking=True + ) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) @@ -641,7 +656,7 @@ def _prepare_memory_conditioned_features( pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem - # Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder) + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] @@ -665,6 +680,7 @@ def _encode_new_memory( current_vision_feats, feat_sizes, pred_masks_high_res, + object_score_logits, is_mask_from_pts, ): """Encode the current image 
and its prediction into a memory feature.""" @@ -697,10 +713,19 @@ def _encode_new_memory( ) maskmem_features = maskmem_out["vision_features"] maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) return maskmem_features, maskmem_pos_enc - def track_step( + def _track_step( self, frame_idx, is_init_cond_frame, @@ -711,15 +736,8 @@ def track_step( mask_inputs, output_dict, num_frames, - track_in_reverse=False, # tracking in reverse time order (for demo usage) - # Whether to run the memory encoder on the predicted masks. Sometimes we might want - # to skip the memory encoder with `run_mem_encoder=False`. For example, - # in demo we might call `track_step` multiple times for each user click, - # and only encode the memory when the user finalizes their clicks. And in ablation - # settings like SAM training on static images, we don't need the memory encoder. - run_mem_encoder=True, - # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). - prev_sam_mask_logits=None, + track_in_reverse, + prev_sam_mask_logits, ): current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW @@ -740,7 +758,7 @@ def track_step( ) else: # fused the visual feature with previous memory features in the memory bank - pix_feat_with_mem = self._prepare_memory_conditioned_features( + pix_feat = self._prepare_memory_conditioned_features( frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, current_vision_feats=current_vision_feats[-1:], @@ -759,34 +777,32 @@ def track_step( mask_inputs = prev_sam_mask_logits multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) sam_outputs = self._forward_sam_heads( - backbone_features=pix_feat_with_mem, + backbone_features=pix_feat, point_inputs=point_inputs, mask_inputs=mask_inputs, high_res_features=high_res_features, multimask_output=multimask_output, ) - ( - _, - _, - _, - low_res_masks, - high_res_masks, - obj_ptr, - _, - ) = sam_outputs - current_out["pred_masks"] = low_res_masks - current_out["pred_masks_high_res"] = high_res_masks - current_out["obj_ptr"] = obj_ptr + return current_out, sam_outputs, high_res_features, pix_feat - # Finally run the memory encoder on the predicted mask to encode - # it into a new memory feature (that can be used in future frames) + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): if run_mem_encoder and self.num_maskmem > 0: high_res_masks_for_mem_enc = high_res_masks maskmem_features, maskmem_pos_enc = self._encode_new_memory( current_vision_feats=current_vision_feats, feat_sizes=feat_sizes, pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, is_mask_from_pts=(point_inputs is not None), ) current_out["maskmem_features"] = maskmem_features @@ -795,6 +811,71 @@ def track_step( current_out["maskmem_features"] = None current_out["maskmem_pos_enc"] = None + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + 
current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + return current_out def _use_multimask(self, is_init_cond_frame, point_inputs): diff --git a/sam2/modeling/sam2_utils.py b/sam2/modeling/sam2_utils.py index 6d9705963..e16caae3a 100644 --- a/sam2/modeling/sam2_utils.py +++ b/sam2/modeling/sam2_utils.py @@ -6,11 +6,15 @@ import copy +from typing import Tuple +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from sam2.utils.misc import mask_to_box + def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): """ @@ -147,3 +151,173 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = 
torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. 
+ This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index 2f143ee8d..4b7dcbabf 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -15,15 +15,20 @@ from torch.nn.attention import SDPBackend from tqdm import tqdm -VARIANTS: List[str] = ["tiny", "small", "base_plus", "large"] variant_to_config_mapping: Dict[str, str] = { "tiny": "sam2_hiera_t.yaml", "small": "sam2_hiera_s.yaml", "base_plus": "sam2_hiera_b+.yaml", "large": "sam2_hiera_l.yaml", + "2.1/tiny": "2.1/sam2.1_hiera_t.yaml", + "2.1/small": "2.1/sam2.1_hiera_s.yaml", + "2.1/base_plus": "2.1/sam2.1_hiera_b+.yaml", + "2.1/large": "2.1/sam2.1_hiera_l.yaml", } +VARIANTS: List[str] = list(variant_to_config_mapping.keys()) + def get_sdp_backends(dropout_p: float) -> Union[List[SDPBackend], SDPBackend]: backends = [] From e1549c1786dc691e258f6284c9dec7a161e1c26a Mon Sep 17 
00:00:00 2001 From: Adrien Delessert Date: Sun, 29 Dec 2024 13:53:33 -0500 Subject: [PATCH 2/3] feat: add tests for version 2.1 --- .gitignore | 2 ++ sam2/utils/download.py | 53 +++++++++++++++++++++++---------------- tests/conftest.py | 28 +++++++++++++++++++++ tests/test_build_model.py | 17 +++++++++++-- 4 files changed, 76 insertions(+), 24 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..4aefbe308 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +artifacts/ \ No newline at end of file diff --git a/sam2/utils/download.py b/sam2/utils/download.py index 577a0b16e..c1ada6e7a 100644 --- a/sam2/utils/download.py +++ b/sam2/utils/download.py @@ -7,30 +7,39 @@ @pytest.fixture def download_weights(output_directory: str = "artifacts") -> None: - base_url: str = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" - file_names: List[str] = [ - "sam2_hiera_tiny.pt", - "sam2_hiera_small.pt", - "sam2_hiera_base_plus.pt", - "sam2_hiera_large.pt", + version_base_urls: dict = { + "2": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/", + "2.1": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/", + } + + file_suffixes: List[str] = [ + "hiera_tiny.pt", + "hiera_small.pt", + "hiera_base_plus.pt", + "hiera_large.pt", ] if not os.path.exists(output_directory): os.makedirs(output_directory) - for file_name in file_names: - file_path = os.path.join(output_directory, file_name) - if not os.path.exists(file_path): - url = f"{base_url}{file_name}" - command = ["wget", url, "-P", output_directory] - try: - result = subprocess.run( - command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - print(f"Download of {file_name} completed successfully.") - print(result.stdout.decode()) - except subprocess.CalledProcessError as e: - print(f"An error occurred during the download of {file_name}.") - print(e.stderr.decode()) - else: - print(f"{file_name} already exists. Skipping download.") + for version, base_url in version_base_urls.items(): + for suffix in file_suffixes: + file_name = f"sam{version}_{suffix}" + file_path = os.path.join(output_directory, file_name) + if not os.path.exists(file_path): + url = f"{base_url}{file_name}" + command = ["wget", url, "-P", output_directory] + try: + result = subprocess.run( + command, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + print(f"Download of {file_name} completed successfully.") + print(result.stdout.decode()) + except subprocess.CalledProcessError as e: + print(f"An error occurred during the download of {file_name}.") + print(e.stderr.decode()) + else: + print(f"{file_name} already exists. 
Skipping download.") diff --git a/tests/conftest.py b/tests/conftest.py index 44e8d7a7b..d51fe43a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,3 +42,31 @@ def video_predictor(download_weights): variant_to_config_mapping["tiny"], "./artifacts/sam2_hiera_tiny.pt", ) + + +@pytest.fixture +def image_predictor21(download_weights) -> "SAM2ImagePredictor": + model = build_sam2( + variant_to_config_mapping["2.1/tiny"], + "./artifacts/sam2.1_hiera_tiny.pt", + ) + image_predictor = SAM2ImagePredictor(model) + return image_predictor + + +@pytest.fixture +def mask_generator21(download_weights) -> "SAM2AutomaticMaskGenerator": + model = build_sam2( + variant_to_config_mapping["2.1/tiny"], + "./artifacts/sam2.1_hiera_tiny.pt", + ) + mask_generator = SAM2AutomaticMaskGenerator(model) + return mask_generator + + +@pytest.fixture +def video_predictor21(download_weights): + return build_sam2_video_predictor( + variant_to_config_mapping["2.1/tiny"], + "./artifacts/sam2.1_hiera_tiny.pt", + ) diff --git a/tests/test_build_model.py b/tests/test_build_model.py index 02a7b1e1c..cde773189 100644 --- a/tests/test_build_model.py +++ b/tests/test_build_model.py @@ -9,12 +9,25 @@ @pytest.mark.full @pytest.mark.parametrize( "variant", - ["tiny", "small", "base_plus", "large"], + [ + "tiny", + "small", + "base_plus", + "large", + "2.1/tiny", + "2.1/small", + "2.1/base_plus", + "2.1/large", + ], ) def test_build_sam(download_weights, variant: str): + parts = variant.split("/") + base_variant = parts[-1] + version = f"{parts[0]}" if len(parts) > 1 else "2" + model = build_sam2( variant_to_config_mapping[variant], - f"./artifacts/sam2_hiera_{variant}.pt", + f"./artifacts/sam{version}_hiera_{base_variant}.pt", ) assert isinstance(model, torch.nn.Module) From 0c6f3b2fc2ce5aded8e676c8cf31dba626a67bc7 Mon Sep 17 00:00:00 2001 From: Adrien Delessert Date: Sun, 29 Dec 2024 14:18:05 -0500 Subject: [PATCH 3/3] fix: remove accidental pytest dependency --- sam2/utils/__init__.py | 2 -- tests/conftest.py | 2 +- tests/test_build_model.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sam2/utils/__init__.py b/sam2/utils/__init__.py index d4d9f56d9..5277f4615 100644 --- a/sam2/utils/__init__.py +++ b/sam2/utils/__init__.py @@ -3,5 +3,3 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - -from .download import download_weights diff --git a/tests/conftest.py b/tests/conftest.py index d51fe43a4..2821bbbc0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from sam2.build_sam import build_sam2, build_sam2_video_predictor from sam2.sam2_image_predictor import SAM2ImagePredictor -from sam2.utils import download_weights +from sam2.utils.download import download_weights from sam2.utils.misc import variant_to_config_mapping diff --git a/tests/test_build_model.py b/tests/test_build_model.py index cde773189..71c9c9083 100644 --- a/tests/test_build_model.py +++ b/tests/test_build_model.py @@ -2,7 +2,7 @@ import torch from sam2.build_sam import build_sam2 -from sam2.utils import download_weights +from sam2.utils.download import download_weights from sam2.utils.misc import variant_to_config_mapping
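
Usage sketch (appended after the patch series, not part of any diff above): the snippet below shows how the new "2.1/*" keys in `variant_to_config_mapping` are intended to be used together with the 2.1 checkpoints fetched by `sam2/utils/download.py` (e.g. `sam2.1_hiera_tiny.pt` from the 092824 release URL). It assumes the predictor API matches the existing `SAM2ImagePredictor` usage in `tests/conftest.py`; the dummy image and the single point prompt are purely illustrative.

import numpy as np

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.utils.misc import variant_to_config_mapping

# Build a SAM 2.1 "tiny" model from the new config key; the mapping resolves
# to "2.1/sam2.1_hiera_t.yaml" added by this patch series.
model = build_sam2(
    variant_to_config_mapping["2.1/tiny"],
    "./artifacts/sam2.1_hiera_tiny.pt",  # path assumed from the download helper
)
predictor = SAM2ImagePredictor(model)

# Illustrative prompt: one positive click on a placeholder RGB image.
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
predictor.set_image(image)
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[512, 512]]),
    point_labels=np.array([1]),
    multimask_output=True,
)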