Commit

moduler
innat committed Mar 24, 2024
1 parent c590ba8 commit eea9283
Showing 6 changed files with 236 additions and 145 deletions.
2 changes: 1 addition & 1 deletion test/test_layers.py
@@ -1,5 +1,5 @@
-from base import TestCase
 import keras
+from base import TestCase
 from keras import ops
 
 from videoswin.layers import (
4 changes: 3 additions & 1 deletion videoswin/backbone.py
@@ -133,7 +133,9 @@ def __init__(
     attn_drop_rate=attn_drop_rate,
     drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
     norm_layer=norm_layer,
-    downsampling_layer=(VideoSwinPatchMerging if (i < num_layers - 1) else None),
+    downsampling_layer=(
+        VideoSwinPatchMerging if (i < num_layers - 1) else None
+    ),
     name=f"videoswin_basic_layer_{i + 1}",
 )
 x = layer(x)
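The change above only re-wraps the downsampling_layer argument across multiple lines; the selection logic is unchanged: every stage except the last is given a patch-merging layer. A rough, self-contained sketch of that pattern (stand-in names, not the repository's classes):

# Illustrative sketch only: stand-in names, not the repository's actual code.
num_layers = 4

def pick_downsampler(stage_index, num_layers):
    # Every stage except the last downsamples; the final stage gets None.
    return "patch_merging" if stage_index < num_layers - 1 else None

print([pick_downsampler(i, num_layers) for i in range(num_layers)])
# ['patch_merging', 'patch_merging', 'patch_merging', None]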
5 changes: 2 additions & 3 deletions videoswin/blocks/basic.py
@@ -132,11 +132,11 @@ def call(self, x, training=None):
             x = self.downsample(x)
 
         return x
 
     def compute_output_shape(self, input_shape):
         if self.downsampling_layer is not None:
             output_shape = self.downsample.compute_output_shape(input_shape)
-            return output_shape
+            return output_shape
 
         return input_shape

@@ -157,4 +157,3 @@ def get_config(self):
             }
         )
         return config
-
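The compute_output_shape hunk above keeps the existing behaviour: the block defers shape inference to its downsampling sub-layer when one is present and otherwise returns the input shape unchanged. A minimal sketch of that delegation pattern, using plain Python stand-ins rather than the repository's Keras layers:

# Illustrative stand-ins for the delegation pattern; not the repository's layers.
class HalvingDownsample:
    # Pretend downsampler: halves height/width, doubles channels.
    def compute_output_shape(self, input_shape):
        batch, depth, height, width, channels = input_shape
        return (batch, depth, height // 2, width // 2, 2 * channels)

class Block:
    def __init__(self, downsample=None):
        self.downsample = downsample

    def compute_output_shape(self, input_shape):
        # Delegate when a downsampler exists, otherwise pass the shape through.
        if self.downsample is not None:
            return self.downsample.compute_output_shape(input_shape)
        return input_shape

print(Block(HalvingDownsample()).compute_output_shape((1, 8, 56, 56, 96)))  # (1, 8, 28, 28, 192)
print(Block().compute_output_shape((1, 8, 56, 56, 96)))  # (1, 8, 56, 56, 96)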
2 changes: 1 addition & 1 deletion videoswin/layers/patch_embed.py
@@ -60,7 +60,7 @@ def call(self, x):
             x = self.norm(x)
 
         return x
 
     def get_config(self):
         config = super().get_config()
         config.update(
4 changes: 1 addition & 3 deletions videoswin/layers/patch_merging.py
@@ -59,9 +59,7 @@ def call(self, x):

     def compute_output_shape(self, input_shape):
         batch_size, depth, height, width, _ = input_shape
-        return (
-            batch_size, depth, height // 2, width // 2, 2 * self.input_dim
-        )
+        return (batch_size, depth, height // 2, width // 2, 2 * self.input_dim)
 
     def get_config(self):
         config = super().get_config()
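The reformatted return above is behaviour-preserving; it spells out the patch-merging output shape: depth is kept, height and width are halved, and the channel dimension is doubled. A quick arithmetic check of that rule in plain Python (the input_dim value below is only an example):

# Pure-Python check of the shape rule shown above; input_dim is an example value.
def merged_shape(input_shape, input_dim):
    batch_size, depth, height, width, _ = input_shape
    return (batch_size, depth, height // 2, width // 2, 2 * input_dim)

print(merged_shape((2, 8, 56, 56, 96), input_dim=96))
# (2, 8, 28, 28, 192): half the spatial resolution, twice the channels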