Skip to content

Commit

Permalink
🎨 can it take non square input
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Mar 22, 2024
1 parent eec9c41 commit 4bf9d9b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 25 deletions.
12 changes: 6 additions & 6 deletions videoswin/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def __init__(
"Depth, height and width of the video must be specified"
" in `input_shape`."
)
if input_spec.shape[-3] != input_spec.shape[-2]:
raise ValueError(
"Input video must be square i.e. the height must"
" be equal to the width in the `input_shape`"
" tuple/tensor."
)
# if input_spec.shape[-3] != input_spec.shape[-2]:
# raise ValueError(
# "Input video must be square i.e. the height must"
# " be equal to the width in the `input_shape`"
# " tuple/tensor."
# )

x = input_spec

Expand Down
19 changes: 0 additions & 19 deletions videoswin/blocks/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,6 @@ def build(self, input_shape):

self.built = True

# def compute_output_shape(self, input_shape):
# if self.downsample is not None:
# # TODO: remove tensorflow dependencies.
# # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501
# # output_shape = tf.TensorShape(
# # [
# # input_shape[0],
# # self.depth_pad,
# # self.height_pad // 2,
# # self.width_pad // 2,
# # 2 * self.input_dim,
# # ]
# # )

# output_shape = (input_shape[0], ) + (self.depth_pad, ) + (self.height_pad // 2, ) + (self.width_pad // 2, ) + (self.input_dim, )
# return output_shape

# return input_shape

def call(self, x, training=None):
input_shape = ops.shape(x)
batch_size, depth, height, width, channel = (
Expand Down

0 comments on commit 4bf9d9b

Please sign in to comment.