diff --git a/setup.py b/setup.py index 4fe1fcc..32f9753 100644 --- a/setup.py +++ b/setup.py @@ -4,8 +4,8 @@ name="videoswin", packages=find_packages(exclude=["notebooks", "assets"]), version="1.0.0", - license="MIT", - description="Video Swin Transformerss in Keras", + license="Apache License 2.0", + description="Video Swin Transformerss in Keras 3", long_description=open("README.md").read(), long_description_content_type="text/markdown", author="Mohammed Innat", @@ -15,16 +15,31 @@ install_requires=[ "opencv-python>=4.1.2", "keras-nightly", + "tensorflow-datasets", ], - setup_requires=[ - "pytest-runner", - ], - tests_require=["pytest"], + extras_require={ + "tests": [ + "flake8", + "isort", + "black[jupyter]", + "pytest", + "pycocotools", + ], + "examples": ["tensorflow_datasets", "matplotlib"], + }, + python_requires=">=3.9", classifiers=[ - "Development Status :: 1 - Planning", - "Intended Audience :: Developers", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Operating System :: Unix", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Software Development", ], ) diff --git a/test/test_backbone.py b/test/test_backbone.py new file mode 100644 index 0000000..cde755c --- /dev/null +++ b/test/test_backbone.py @@ -0,0 +1,67 @@ +import os + +import keras +import numpy as np +import pytest +import tensorflow as tf +from base import TestCase +from keras import ops + +from videoswin.model import VideoSwinBackbone + + +class TestVideoSwinSBackbone(TestCase): + + @pytest.mark.large + def test_call(self): + model = VideoSwinBackbone(include_rescaling=True, input_shape=(8, 256, 256, 3)) + x = np.ones((1, 8, 256, 256, 3)) + x_out = ops.convert_to_numpy(model(x)) + num_parameters = sum(np.prod(tuple(x.shape)) for x in model.trainable_variables) + self.assertEqual(x_out.shape, (1, 4, 8, 8, 768)) + self.assertEqual(num_parameters, 27_663_894) + + @pytest.mark.extra_large + def teat_save(self): + # saving test + model = VideoSwinBackbone(include_rescaling=False) + x = np.ones((1, 32, 224, 224, 3)) + x_out = ops.convert_to_numpy(model(x)) + path = os.path.join(self.get_temp_dir(), "model.keras") + model.save(path) + loaded_model = keras.saving.load_model(path) + x_out_loaded = ops.convert_to_numpy(loaded_model(x)) + self.assertAllClose(x_out, x_out_loaded) + + @pytest.mark.extra_large + def test_fit(self): + model = VideoSwinBackbone(include_rescaling=False) + x = np.ones((1, 32, 224, 224, 3)) + y = np.zeros((1, 16, 7, 7, 768)) + model.compile(optimizer="adam", loss="mse", metrics=["mse"]) + model.fit(x, y, epochs=1) + + @pytest.mark.extra_large + def test_can_run_in_mixed_precision(self): + keras.mixed_precision.set_global_policy("mixed_float16") + model = VideoSwinBackbone(include_rescaling=False, input_shape=(8, 224, 224, 3)) + x = np.ones((1, 8, 224, 224, 3)) + y = np.zeros((1, 4, 7, 7, 768)) + model.compile(optimizer="adam", loss="mse", metrics=["mse"]) + model.fit(x, y, epochs=1) + + @pytest.mark.extra_large + def test_can_run_on_gray_video(self): + model = VideoSwinBackbone( + include_rescaling=False, + input_shape=(96, 96, 96, 1), + window_size=[6, 6, 6], + ) + x = np.ones((1, 96, 96, 96, 1)) + y = np.zeros((1, 48, 3, 3, 768)) + model.compile(optimizer="adam", loss="mse", metrics=["mse"]) + model.fit(x, y, epochs=1) + + +if __name__ == "__main__": + tf.test.main() diff --git a/test/test_classifier.py b/test/test_classifier.py new file mode 100644 index 0000000..db14661 --- /dev/null +++ b/test/test_classifier.py @@ -0,0 +1,81 @@ +import os + +import keras +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized +from base import TestCase +from keras import ops + +from videoswin.model import VideoSwinT + + +class VideoClassifierTest(TestCase): + def setUp(self): + self.input_batch = np.ones(shape=(10, 8, 224, 224, 3)) + self.dataset = tf.data.Dataset.from_tensor_slices( + (self.input_batch, tf.one_hot(tf.ones((10,), dtype="int32"), 10)) + ).batch(4) + + def test_valid_call(self): + model = VideoSwinT( + input_shape=(8, 224, 224, 3), + include_rescaling=False, + num_classes=10, + ) + model(self.input_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + @pytest.mark.large # Fit is slow, so mark these large. + def test_classifier_fit(self, jit_compile): + model = VideoSwinT( + input_shape=(8, 224, 224, 3), + include_rescaling=True, + num_classes=10, + ) + model.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"], + jit_compile=jit_compile, + ) + model.fit(self.dataset) + + @parameterized.named_parameters(("avg_pooling", "avg"), ("max_pooling", "max")) + def test_pooling_arg_call(self, pooling): + model = VideoSwinT( + input_shape=(8, 224, 224, 3), + include_rescaling=True, + num_classes=10, + pooling=pooling, + ) + model(self.input_batch) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_saved_model(self): + model = VideoSwinT( + input_shape=(8, 224, 224, 3), + include_rescaling=False, + num_classes=10, + ) + model_output = model(self.input_batch) + save_path = os.path.join(self.get_temp_dir(), "video_classifier.keras") + model.save(save_path) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, VideoSwinT) + + # Check that output matches. + restored_output = restored_model(self.input_batch) + self.assertAllClose( + ops.convert_to_numpy(model_output), + ops.convert_to_numpy(restored_output), + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/test/test_model.py b/test/test_model.py deleted file mode 100644 index fdc10b7..0000000 --- a/test/test_model.py +++ /dev/null @@ -1,138 +0,0 @@ -import os - -import keras -import numpy as np -import pytest -import tensorflow as tf -from absl.testing import parameterized -from base import TestCase -from keras import ops - -from videoswin.model import VideoSwinBackbone, VideoSwinT - - -class TestVideoSwinSBackbone(TestCase): - - @pytest.mark.large - def test_call(self): - model = VideoSwinBackbone(include_rescaling=True, input_shape=(8, 256, 256, 3)) - x = np.ones((1, 8, 256, 256, 3)) - x_out = ops.convert_to_numpy(model(x)) - num_parameters = sum(np.prod(tuple(x.shape)) for x in model.trainable_variables) - self.assertEqual(x_out.shape, (1, 4, 8, 8, 768)) - self.assertEqual(num_parameters, 27_663_894) - - @pytest.mark.extra_large - def teat_save(self): - # saving test - model = VideoSwinBackbone(include_rescaling=False) - x = np.ones((1, 32, 224, 224, 3)) - x_out = ops.convert_to_numpy(model(x)) - path = os.path.join(self.get_temp_dir(), "model.keras") - model.save(path) - loaded_model = keras.saving.load_model(path) - x_out_loaded = ops.convert_to_numpy(loaded_model(x)) - self.assertAllClose(x_out, x_out_loaded) - - @pytest.mark.extra_large - def test_fit(self): - model = VideoSwinBackbone(include_rescaling=False) - x = np.ones((1, 32, 224, 224, 3)) - y = np.zeros((1, 16, 7, 7, 768)) - model.compile(optimizer="adam", loss="mse", metrics=["mse"]) - model.fit(x, y, epochs=1) - - @pytest.mark.extra_large - def test_can_run_in_mixed_precision(self): - keras.mixed_precision.set_global_policy("mixed_float16") - model = VideoSwinBackbone(include_rescaling=False, input_shape=(8, 224, 224, 3)) - x = np.ones((1, 8, 224, 224, 3)) - y = np.zeros((1, 4, 7, 7, 768)) - model.compile(optimizer="adam", loss="mse", metrics=["mse"]) - model.fit(x, y, epochs=1) - - @pytest.mark.extra_large - def test_can_run_on_gray_video(self): - model = VideoSwinBackbone( - include_rescaling=False, - input_shape=(96, 96, 96, 1), - window_size=[6, 6, 6], - ) - x = np.ones((1, 96, 96, 96, 1)) - y = np.zeros((1, 48, 3, 3, 768)) - model.compile(optimizer="adam", loss="mse", metrics=["mse"]) - model.fit(x, y, epochs=1) - - -class VideoClassifierTest(TestCase): - def setUp(self): - self.input_batch = np.ones(shape=(10, 8, 224, 224, 3)) - self.dataset = tf.data.Dataset.from_tensor_slices( - (self.input_batch, tf.one_hot(tf.ones((10,), dtype="int32"), 10)) - ).batch(4) - - def test_valid_call(self): - model = VideoSwinT( - backbone=VideoSwinBackbone( - input_shape=(8, 224, 224, 3), include_rescaling=False - ), - num_classes=10, - ) - model(self.input_batch) - - @parameterized.named_parameters( - ("jit_compile_false", False), ("jit_compile_true", True) - ) - @pytest.mark.large # Fit is slow, so mark these large. - def test_classifier_fit(self, jit_compile): - model = VideoSwinT( - backbone=VideoSwinBackbone( - input_shape=(8, 224, 224, 3), include_rescaling=True - ), - num_classes=10, - ) - model.compile( - loss="categorical_crossentropy", - optimizer="adam", - metrics=["accuracy"], - jit_compile=jit_compile, - ) - model.fit(self.dataset) - - @parameterized.named_parameters(("avg_pooling", "avg"), ("max_pooling", "max")) - def test_pooling_arg_call(self, pooling): - model = VideoSwinT( - backbone=VideoSwinBackbone( - input_shape=(8, 224, 224, 3), include_rescaling=True - ), - num_classes=10, - pooling=pooling, - ) - model(self.input_batch) - - @pytest.mark.large # Saving is slow, so mark these large. - def test_saved_model(self): - model = VideoSwinT( - backbone=VideoSwinBackbone( - input_shape=(8, 224, 224, 3), include_rescaling=False - ), - num_classes=10, - ) - model_output = model(self.input_batch) - save_path = os.path.join(self.get_temp_dir(), "video_classifier.keras") - model.save(save_path) - restored_model = keras.models.load_model(save_path) - - # Check we got the real object back. - self.assertIsInstance(restored_model, VideoSwinT) - - # Check that output matches. - restored_output = restored_model(self.input_batch) - self.assertAllClose( - ops.convert_to_numpy(model_output), - ops.convert_to_numpy(restored_output), - ) - - -if __name__ == "__main__": - tf.test.main() diff --git a/videoswin/model.py b/videoswin/model.py index bf9de6d..6d4221d 100644 --- a/videoswin/model.py +++ b/videoswin/model.py @@ -205,6 +205,7 @@ def get_config(self): def VideoSwinT( input_shape=(32, 224, 224, 3), num_classes=400, + pooling="avg", activation="softmax", embed_size=96, depths=[2, 2, 6, 2], @@ -212,6 +213,14 @@ def VideoSwinT( include_rescaling=False, include_top=True, ): + + if pooling == "avg": + pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool") + elif pooling == "max": + pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool") + else: + raise ValueError(f'`pooling` must be one of "avg", "max". Received: {pooling}.') + backbone = VideoSwinBackbone( input_shape=input_shape, embed_dim=embed_size, @@ -223,7 +232,6 @@ def VideoSwinT( if not include_top: return backbone - pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool") inputs = backbone.input x = backbone(inputs) x = pooling_layer(x) @@ -240,6 +248,7 @@ def VideoSwinT( def VideoSwinS( input_shape=(32, 224, 224, 3), num_classes=400, + pooling="avg", activation="softmax", embed_size=96, depths=[2, 2, 18, 2], @@ -247,6 +256,14 @@ def VideoSwinS( include_rescaling=False, include_top=True, ): + + if pooling == "avg": + pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool") + elif pooling == "max": + pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool") + else: + raise ValueError(f'`pooling` must be one of "avg", "max". Received: {pooling}.') + backbone = VideoSwinBackbone( input_shape=input_shape, embed_dim=embed_size, @@ -275,6 +292,7 @@ def VideoSwinS( def VideoSwinB( input_shape=(32, 224, 224, 3), num_classes=400, + pooling="avg", activation="softmax", embed_size=128, depths=[2, 2, 18, 2], @@ -283,6 +301,13 @@ def VideoSwinB( include_top=True, ): + if pooling == "avg": + pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool") + elif pooling == "max": + pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool") + else: + raise ValueError(f'`pooling` must be one of "avg", "max". Received: {pooling}.') + backbone = VideoSwinBackbone( input_shape=input_shape, embed_dim=embed_size,