diff --git a/test/test_classifier.py b/test/test_classifier.py index 8078feb..829b8d4 100644 --- a/test/test_classifier.py +++ b/test/test_classifier.py @@ -41,6 +41,9 @@ def test_valid_call_non_square_shape(self): ) @pytest.mark.large # Fit is slow, so mark these large. def test_classifier_fit(self, jit_compile): + if jit_compile and keras.backend.backend() == "torch": + self.skipTest("TODO: Torch Backend `jit_compile` fails on GPU.") + self.supports_jit = False model = VideoSwinT( input_shape=(8, 224, 224, 3), include_rescaling=True, diff --git a/videoswin/blocks/swin_transformer.py b/videoswin/blocks/swin_transformer.py index 34755b7..b5e6ced 100644 --- a/videoswin/blocks/swin_transformer.py +++ b/videoswin/blocks/swin_transformer.py @@ -126,6 +126,9 @@ def build(self, input_shape): [pad_l, self.pad_r], [0, 0], ] + self.do_pad = any( + value > 0 for value in (self.pad_d1, self.pad_r, self.pad_b) + ) self.built = True def first_forward(self, x, mask_matrix, training): @@ -196,11 +199,8 @@ def first_forward(self, x, mask_matrix, training): x = shifted_x # pad if required - do_pad = ops.logical_or( - ops.greater(self.pad_d1, 0), - ops.logical_or(ops.greater(self.pad_r, 0), ops.greater(self.pad_b, 0)), - ) - x = ops.cond(do_pad, lambda: x[:, :depth, :height, :width, :], lambda: x) + if self.do_pad: + return x[:, :depth, :height, :width, :] return x