Skip to content

Commit

Permalink
update padding condition
Browse files Browse the repository at this point in the history
  • Loading branch information
innat authored Mar 26, 2024
1 parent d6dfb19 commit 32e24e1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
3 changes: 3 additions & 0 deletions test/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def test_valid_call_non_square_shape(self):
)
@pytest.mark.large # Fit is slow, so mark these large.
def test_classifier_fit(self, jit_compile):
if jit_compile and keras.backend.backend() == "torch":
self.skipTest("TODO: Torch Backend `jit_compile` fails on GPU.")
self.supports_jit = False
model = VideoSwinT(
input_shape=(8, 224, 224, 3),
include_rescaling=True,
Expand Down
10 changes: 5 additions & 5 deletions videoswin/blocks/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def build(self, input_shape):
[pad_l, self.pad_r],
[0, 0],
]
self.do_pad = any(
value > 0 for value in (self.pad_d1, self.pad_r, self.pad_b)
)
self.built = True

def first_forward(self, x, mask_matrix, training):
Expand Down Expand Up @@ -196,11 +199,8 @@ def first_forward(self, x, mask_matrix, training):
x = shifted_x

# pad if required
do_pad = ops.logical_or(
ops.greater(self.pad_d1, 0),
ops.logical_or(ops.greater(self.pad_r, 0), ops.greater(self.pad_b, 0)),
)
x = ops.cond(do_pad, lambda: x[:, :depth, :height, :width, :], lambda: x)
if self.do_pad:
return x[:, :depth, :height, :width, :]

return x

Expand Down

0 comments on commit 32e24e1

Please sign in to comment.