From cc687de0b8689a86a0f525d9b24cf0bf64ec1383 Mon Sep 17 00:00:00 2001 From: Sergey Panfilov Date: Tue, 7 Apr 2020 08:59:23 +0200 Subject: [PATCH] Fix overriding of schema class --- marshmallow_objects/models.py | 2 +- tests/test_models.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/marshmallow_objects/models.py b/marshmallow_objects/models.py index 5fc2358..65d80a3 100644 --- a/marshmallow_objects/models.py +++ b/marshmallow_objects/models.py @@ -50,7 +50,7 @@ def __new__(mcs, name, parents, dct): parent_schemas = [] if parents: for parent in parents: - if issubclass(parent, Model): + if issubclass(parent, Model) and parent != Model: parent_schemas.append(parent.__schema_class__) parent_schemas = parent_schemas or [cls.__schema_class__ or marshmallow.Schema] schema_class = type(name + "Schema", tuple(parent_schemas), schema_fields) diff --git a/tests/test_models.py b/tests/test_models.py index 049325e..8e02c5d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -55,6 +55,15 @@ class MultiInheritance(A, B, C): pass +class CustomSchema(marshmallow.Schema): + def custom_method(self): + pass + + +class D(marshmallow.Model): + __schema_class__ = CustomSchema + + def serialize_context_field(obj, context=None): return obj.test_field == context["value"] @@ -113,6 +122,9 @@ def test_handle_error(self): id(MultiInheritance.handle_error), id(MultiInheritance.__schema_class__.handle_error), ) + def test_schema_class_override(self): + self.assertTrue(issubclass(D.__schema_class__, CustomSchema), D.__schema_class__.__bases__) + class TestModel(unittest.TestCase): def test_tag_field(self):