Skip to content

Commit

Permalink
🔥 move model definition to class method
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Mar 21, 2024
1 parent f47bc2f commit 9a8006a
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 136 deletions.
1 change: 0 additions & 1 deletion test/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def test_pooling_arg_call(self, pooling):

@pytest.mark.large # Saving is slow, so mark these large.
def test_saved_model(self):
self.skipTest("Skipping saving test for now.")
model = VideoSwinT(
input_shape=(8, 224, 224, 3),
include_rescaling=False,
Expand Down
350 changes: 215 additions & 135 deletions videoswin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .utils import parse_model_inputs


@keras.utils.register_keras_serializable(package="swin.transformer.3d")
@keras.utils.register_keras_serializable(package="swin.transformer.backbone.3d")
class VideoSwinBackbone(keras.Model):
"""A Video Swin Transformer backbone model.
Expand Down Expand Up @@ -114,15 +114,7 @@ def __init__(
x = input_spec

if include_rescaling:
# Use common rescaling strategy across keras_cv
x = keras.layers.Rescaling(1.0 / 255.0)(x)

# VideoSwin scales inputs based on the ImageNet mean/stddev.
# Officially, Video Swin takes tensors in the [0, 255] range
# and uses mean=[123.675, 116.28, 103.53] and
# std=[58.395, 57.12, 57.375] for normalization.
# So, if include_rescaling is set to True, the following
# normalization should be added to match the official scores.
x = layers.Normalization(
mean=[0.485, 0.456, 0.406],
variance=[0.229**2, 0.224**2, 0.225**2],
Expand Down Expand Up @@ -202,132 +194,220 @@ def get_config(self):
return config


def VideoSwinT(
input_shape=(32, 224, 224, 3),
num_classes=400,
pooling="avg",
activation="softmax",
embed_size=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
include_rescaling=False,
include_top=True,
):

if pooling == "avg":
pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
elif pooling == "max":
pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
else:
raise ValueError(f'`pooling` must be one of "avg", "max". Received: {pooling}.')

backbone = VideoSwinBackbone(
input_shape=input_shape,
embed_dim=embed_size,
depths=depths,
num_heads=num_heads,
include_rescaling=include_rescaling,
)
@keras.utils.register_keras_serializable(package="swin.transformer.tiny.3d")
class VideoSwinT(keras.Model):
def __init__(
self,
input_shape=(32, 224, 224, 3),
num_classes=400,
pooling="avg",
activation="softmax",
embed_size=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
include_rescaling=False,
include_top=True,
**kwargs,
):

if not include_top:
return backbone

inputs = backbone.input
x = backbone(inputs)
x = pooling_layer(x)
outputs = keras.layers.Dense(
num_classes,
activation=activation,
name="predictions",
dtype="float32",
)(x)
model = keras.Model(inputs, outputs)
return model


def VideoSwinS(
input_shape=(32, 224, 224, 3),
num_classes=400,
pooling="avg",
activation="softmax",
embed_size=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
include_rescaling=False,
include_top=True,
):

if pooling == "avg":
pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
elif pooling == "max":
pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
else:
raise ValueError(f'`pooling` must be one of "avg", "max". Received: {pooling}.')

backbone = VideoSwinBackbone(
input_shape=input_shape,
embed_dim=embed_size,
depths=depths,
num_heads=num_heads,
include_rescaling=include_rescaling,
)
if pooling == "avg":
pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
elif pooling == "max":
pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
else:
raise ValueError(
f'`pooling` must be one of "avg", "max". Received: {pooling}.'
)

if not include_top:
return backbone

pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
inputs = backbone.input
x = backbone(inputs)
x = pooling_layer(x)
outputs = keras.layers.Dense(
num_classes,
activation=activation,
name="predictions",
dtype="float32",
)(x)
model = keras.Model(inputs, outputs)
return model


def VideoSwinB(
input_shape=(32, 224, 224, 3),
num_classes=400,
pooling="avg",
activation="softmax",
embed_size=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
include_rescaling=False,
include_top=True,
):

if pooling == "avg":
pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
elif pooling == "max":
pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
else:
raise ValueError(f'`pooling` must be one of "avg", "max". Received: {pooling}.')

backbone = VideoSwinBackbone(
input_shape=input_shape,
embed_dim=embed_size,
depths=depths,
num_heads=num_heads,
include_rescaling=include_rescaling,
)
backbone = VideoSwinBackbone(
input_shape=input_shape,
embed_dim=embed_size,
depths=depths,
num_heads=num_heads,
include_rescaling=include_rescaling,
)

if not include_top:
return backbone

pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
inputs = backbone.input
x = backbone(inputs)
x = pooling_layer(x)
outputs = keras.layers.Dense(
num_classes,
activation=activation,
name="predictions",
dtype="float32",
)(x)
model = keras.Model(inputs, outputs)
return model
if not include_top:
return backbone

inputs = backbone.input
x = backbone(inputs)
x = pooling_layer(x)
outputs = keras.layers.Dense(
num_classes,
activation=activation,
name="predictions",
dtype="float32",
)(x)
super().__init__(inputs=inputs, outputs=outputs, **kwargs)
self.num_classes = num_classes
self.pooling = pooling
self.activation = activation
self.embed_size = embed_size
self.depths = depths
self.num_heads = num_heads
self.include_rescaling = include_rescaling
self.include_top = include_top

def get_config(self):
config = super().get_config()
config.update(
{
"input_shape": self.input_shape[1:],
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"embed_size": self.embed_size,
"depths": self.depths,
"num_heads": self.num_heads,
"include_rescaling": self.include_rescaling,
"include_top": self.include_top,
}
)
return config


@keras.utils.register_keras_serializable(package="swin.transformer.small.3d")
class VideoSwinS(keras.Model):
    """Video Swin Transformer model, "small" configuration.

    Builds a functional `keras.Model` around a `VideoSwinBackbone`. With
    `include_top=True` the backbone output is pooled and fed to a `Dense`
    classification layer; with `include_top=False` the model outputs the
    backbone features directly.

    Args:
        input_shape: Input shape forwarded to `VideoSwinBackbone`.
        num_classes: Number of units of the final `Dense` layer.
        pooling: Either `"avg"` (`GlobalAveragePooling3D`) or `"max"`
            (`GlobalMaxPooling3D`); applied before the `Dense` layer.
        activation: Activation of the final `Dense` layer.
        embed_size: Forwarded to the backbone as `embed_dim`.
        depths: Forwarded to the backbone as `depths`.
        num_heads: Forwarded to the backbone as `num_heads`.
        include_rescaling: Forwarded to the backbone.
        include_top: Whether to append the pooling and `Dense` layers.
        **kwargs: Forwarded to `keras.Model.__init__`.

    Raises:
        ValueError: If `pooling` is neither `"avg"` nor `"max"`.
    """

    def __init__(
        self,
        input_shape=(32, 224, 224, 3),
        num_classes=400,
        pooling="avg",
        activation="softmax",
        embed_size=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        include_rescaling=False,
        include_top=True,
        **kwargs,
    ):
        # Validate `pooling` up front, regardless of `include_top`.
        if pooling == "avg":
            pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
        elif pooling == "max":
            pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
        else:
            raise ValueError(
                f'`pooling` must be one of "avg", "max". Received: {pooling}.'
            )

        # Work on private copies so the shared default lists in the
        # signature can never be aliased or mutated through an instance.
        depths = list(depths)
        num_heads = list(num_heads)

        backbone = VideoSwinBackbone(
            input_shape=input_shape,
            embed_dim=embed_size,
            depths=depths,
            num_heads=num_heads,
            include_rescaling=include_rescaling,
        )

        inputs = backbone.input
        outputs = backbone(inputs)

        # `__init__` must return None, so the headless variant cannot hand
        # back the backbone object; it exposes the backbone features as the
        # model output instead.
        if include_top:
            outputs = pooling_layer(outputs)
            outputs = keras.layers.Dense(
                num_classes,
                activation=activation,
                name="predictions",
                dtype="float32",
            )(outputs)

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # Stored after `super().__init__` so `get_config` can report them.
        self.num_classes = num_classes
        self.pooling = pooling
        self.activation = activation
        self.embed_size = embed_size
        self.depths = depths
        self.num_heads = num_heads
        self.include_rescaling = include_rescaling
        self.include_top = include_top

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                # Drops the leading axis of the model input shape
                # (presumably the batch dimension -- confirm).
                "input_shape": self.input_shape[1:],
                "num_classes": self.num_classes,
                "pooling": self.pooling,
                "activation": self.activation,
                "embed_size": self.embed_size,
                "depths": self.depths,
                "num_heads": self.num_heads,
                "include_rescaling": self.include_rescaling,
                "include_top": self.include_top,
            }
        )
        return config


@keras.utils.register_keras_serializable(package="swin.transformer.base.3d")
class VideoSwinB(keras.Model):
    """Video Swin Transformer model, "base" configuration.

    Builds a functional `keras.Model` around a `VideoSwinBackbone`. With
    `include_top=True` the backbone output is pooled and fed to a `Dense`
    classification layer; with `include_top=False` the model outputs the
    backbone features directly.

    Args:
        input_shape: Input shape forwarded to `VideoSwinBackbone`.
        num_classes: Number of units of the final `Dense` layer.
        pooling: Either `"avg"` (`GlobalAveragePooling3D`) or `"max"`
            (`GlobalMaxPooling3D`); applied before the `Dense` layer.
        activation: Activation of the final `Dense` layer.
        embed_size: Forwarded to the backbone as `embed_dim`.
        depths: Forwarded to the backbone as `depths`.
        num_heads: Forwarded to the backbone as `num_heads`.
        include_rescaling: Forwarded to the backbone.
        include_top: Whether to append the pooling and `Dense` layers.
        **kwargs: Forwarded to `keras.Model.__init__`.

    Raises:
        ValueError: If `pooling` is neither `"avg"` nor `"max"`.
    """

    def __init__(
        self,
        input_shape=(32, 224, 224, 3),
        num_classes=400,
        pooling="avg",
        activation="softmax",
        embed_size=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        include_rescaling=False,
        include_top=True,
        **kwargs,
    ):
        # Validate `pooling` up front, regardless of `include_top`.
        if pooling == "avg":
            pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
        elif pooling == "max":
            pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
        else:
            raise ValueError(
                f'`pooling` must be one of "avg", "max". Received: {pooling}.'
            )

        # Work on private copies so the shared default lists in the
        # signature can never be aliased or mutated through an instance.
        depths = list(depths)
        num_heads = list(num_heads)

        backbone = VideoSwinBackbone(
            input_shape=input_shape,
            embed_dim=embed_size,
            depths=depths,
            num_heads=num_heads,
            include_rescaling=include_rescaling,
        )

        inputs = backbone.input
        outputs = backbone(inputs)

        # `__init__` must return None, so the headless variant cannot hand
        # back the backbone object; it exposes the backbone features as the
        # model output instead.
        if include_top:
            outputs = pooling_layer(outputs)
            outputs = keras.layers.Dense(
                num_classes,
                activation=activation,
                name="predictions",
                dtype="float32",
            )(outputs)

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # Stored after `super().__init__` so `get_config` can report them.
        self.num_classes = num_classes
        self.pooling = pooling
        self.activation = activation
        self.embed_size = embed_size
        self.depths = depths
        self.num_heads = num_heads
        self.include_rescaling = include_rescaling
        self.include_top = include_top

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                # Drops the leading axis of the model input shape
                # (presumably the batch dimension -- confirm).
                "input_shape": self.input_shape[1:],
                "num_classes": self.num_classes,
                "pooling": self.pooling,
                "activation": self.activation,
                "embed_size": self.embed_size,
                "depths": self.depths,
                "num_heads": self.num_heads,
                "include_rescaling": self.include_rescaling,
                "include_top": self.include_top,
            }
        )
        return config

0 comments on commit 9a8006a

Please sign in to comment.