Skip to content

Commit

Permalink
should remove compute output shape of basic layer
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Mar 21, 2024
1 parent bb4cb13 commit eec9c41
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions videoswin/blocks/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,24 @@ def build(self, input_shape):

self.built = True

def compute_output_shape(self, input_shape):
if self.downsample is not None:
# TODO: remove tensorflow dependencies.
# GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501
# output_shape = tf.TensorShape(
# [
# input_shape[0],
# self.depth_pad,
# self.height_pad // 2,
# self.width_pad // 2,
# 2 * self.input_dim,
# ]
# )

output_shape = (input_shape[0], ) + (self.depth_pad, ) + (self.height_pad // 2, ) + (self.width_pad // 2, ) + (self.input_dim, )
return output_shape

return input_shape
# def compute_output_shape(self, input_shape):
# if self.downsample is not None:
# # TODO: remove tensorflow dependencies.
# # GitHub issue: https://github.com/keras-team/keras/issues/19259 # noqa: E501
# # output_shape = tf.TensorShape(
# # [
# # input_shape[0],
# # self.depth_pad,
# # self.height_pad // 2,
# # self.width_pad // 2,
# # 2 * self.input_dim,
# # ]
# # )

# output_shape = (input_shape[0], ) + (self.depth_pad, ) + (self.height_pad // 2, ) + (self.width_pad // 2, ) + (self.input_dim, )
# return output_shape

# return input_shape

def call(self, x, training=None):
input_shape = ops.shape(x)
Expand Down

0 comments on commit eec9c41

Please sign in to comment.