Skip to content

Commit

Permalink
🐛 bug fix for aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Mar 21, 2024
1 parent 2972b17 commit 1cb6e92
Showing 1 changed file with 34 additions and 41 deletions.
75 changes: 34 additions & 41 deletions videoswin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from keras import layers

from videoswin.blocks import VideoSwinBasicLayer
from videoswin.layers import (VideoSwinPatchingAndEmbedding,
VideoSwinPatchMerging)
from videoswin.layers import VideoSwinPatchingAndEmbedding, VideoSwinPatchMerging

from .utils import parse_model_inputs

Expand Down Expand Up @@ -252,17 +251,17 @@ def __init__(

def get_config(self):
config = {
"input_shape": self.input_shape[1:],
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"embed_size": self.embed_size,
"depths": self.depths,
"num_heads": self.num_heads,
"include_rescaling": self.include_rescaling,
"include_top": self.include_top,
}
"input_shape": self.input_shape[1:],
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"embed_size": self.embed_size,
"depths": self.depths,
"num_heads": self.num_heads,
"include_rescaling": self.include_rescaling,
"include_top": self.include_top,
}

return config


Expand Down Expand Up @@ -322,20 +321,17 @@ def __init__(
self.include_top = include_top

def get_config(self):
config = super().get_config()
config.update(
{
"input_shape": self.input_shape[1:],
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"embed_size": self.embed_size,
"depths": self.depths,
"num_heads": self.num_heads,
"include_rescaling": self.include_rescaling,
"include_top": self.include_top,
}
)
config = {
"input_shape": self.input_shape[1:],
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"embed_size": self.embed_size,
"depths": self.depths,
"num_heads": self.num_heads,
"include_rescaling": self.include_rescaling,
"include_top": self.include_top,
}
return config


Expand Down Expand Up @@ -395,18 +391,15 @@ def __init__(
self.include_top = include_top

def get_config(self):
config = super().get_config()
config.update(
{
"input_shape": self.input_shape[1:],
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"embed_size": self.embed_size,
"depths": self.depths,
"num_heads": self.num_heads,
"include_rescaling": self.include_rescaling,
"include_top": self.include_top,
}
)
config = {
"input_shape": self.input_shape[1:],
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"embed_size": self.embed_size,
"depths": self.depths,
"num_heads": self.num_heads,
"include_rescaling": self.include_rescaling,
"include_top": self.include_top,
}
return config

0 comments on commit 1cb6e92

Please sign in to comment.