diff --git a/docs/example.rst b/docs/example.rst index 4cba0c5..df22c0c 100644 --- a/docs/example.rst +++ b/docs/example.rst @@ -33,40 +33,18 @@ Events registered automatically, but you can do it manually: .. code-block:: python - from sqlalchemy.orm import mapper - - from sqlalchemy_mptt.events import TreesManager - from .models import MyUser - - tree_manager = TreesManager(MyUser) - tree_manager.register_mapper(mapper) # register events before_insert, - # before_update and before_delete + from sqlalchemy_mptt import tree_manager + tree_manager.register_events() # register events before_insert, + # before_update and before_delete Or remove events if it required: .. code-block:: python - from sqlalchemy.orm import mapper - - from sqlalchemy_mptt.events import TreesManager - from sqlalchemy_mptt.mixins import BaseNestedSets - - tree_manager = TreesManager(BaseNestedSets) - tree_manager.register_mapper(mapper, # remove events before_insert, - remove=True) # before_update and before_delete - -Or remove for your custom model: - -.. code-block:: python - - from sqlalchemy.orm import mapper - - from sqlalchemy_mptt.events import TreesManager - from .models import MyUser + from sqlalchemy_mptt import tree_manager - tree_manager = TreesManager(MyUser) - tree_manager.register_mapper(mapper, # remove events before_insert, - remove=True) # before_update and before_delete + tree_manager.register_events(remove=True) # remove events before_insert, + # before_update and before_delete Data structure ~~~~~~~~~~~~~~ diff --git a/sqlalchemy_mptt/__init__.py b/sqlalchemy_mptt/__init__.py index aea0e90..e8b3a50 100644 --- a/sqlalchemy_mptt/__init__.py +++ b/sqlalchemy_mptt/__init__.py @@ -5,8 +5,6 @@ # Copyright (c) 2014 uralbash # # Distributed under terms of the MIT license. -from sqlalchemy.orm import mapper - from .events import TreesManager from .mixins import BaseNestedSets @@ -14,5 +12,5 @@ __all__ = ['BaseNestedSets', 'mptt_sessionmaker'] tree_manager = TreesManager(BaseNestedSets) -tree_manager.register_mapper(mapper) +tree_manager.register_events() mptt_sessionmaker = tree_manager.register_factory diff --git a/sqlalchemy_mptt/events.py b/sqlalchemy_mptt/events.py index 09dca36..e80603d 100644 --- a/sqlalchemy_mptt/events.py +++ b/sqlalchemy_mptt/events.py @@ -354,16 +354,16 @@ def __init__(self, base_class): self.classes = set() self.instances = _WeakDefaultDict() - def register_mapper(self, mapper, remove=False): + def register_events(self, remove=False): for e, h in ( ('before_insert', self.before_insert), ('before_update', self.before_update), ('before_delete', self.before_delete), ): - if remove: - if event.contains(self.base_class, e, h): - event.remove(self.base_class, e, h) - else: + is_event_exist = event.contains(self.base_class, e, h) + if remove and is_event_exist: + event.remove(self.base_class, e, h) + elif not is_event_exist: event.listen(self.base_class, e, h, propagate=True) return self diff --git a/sqlalchemy_mptt/tests/__init__.py b/sqlalchemy_mptt/tests/__init__.py index 5c17701..a6aaeca 100644 --- a/sqlalchemy_mptt/tests/__init__.py +++ b/sqlalchemy_mptt/tests/__init__.py @@ -86,16 +86,6 @@ def stop_query_counter(self): self.catch_queries) def setUp(self): - - # register events - from .test_mixins import Tree2 - from sqlalchemy.orm import mapper - from sqlalchemy_mptt.events import TreesManager - - tree_manager = TreesManager(Tree2) - tree_manager.register_mapper(mapper) - - # sqla settings self.engine = create_engine('sqlite:///:memory:') Session = mptt_sessionmaker(sessionmaker(bind=self.engine)) self.session = Session() diff --git a/sqlalchemy_mptt/tests/test_events.py b/sqlalchemy_mptt/tests/test_events.py index fbb6b3b..5664354 100644 --- a/sqlalchemy_mptt/tests/test_events.py +++ b/sqlalchemy_mptt/tests/test_events.py @@ -13,9 +13,7 @@ import unittest from sqlalchemy import Column, Boolean, Integer -from sqlalchemy.orm import mapper from sqlalchemy.event import contains -from sqlalchemy_mptt.events import TreesManager from sqlalchemy.ext.declarative import declarative_base from . import TreeTestingMixin @@ -56,12 +54,11 @@ class TestTreeWithCustomId(TreeTestingMixin, unittest.TestCase): model = TreeWithCustomId -class Events(object): +class Events(unittest.TestCase): def test_register(self): - from sqlalchemy_mptt import BaseNestedSets - tree_manager = TreesManager(BaseNestedSets) - tree_manager.register_mapper(mapper) + from sqlalchemy_mptt import tree_manager + tree_manager.register_events() self.assertTrue(contains(BaseNestedSets, 'before_insert', tree_manager.before_insert)) self.assertTrue(contains(BaseNestedSets, 'before_update', @@ -70,24 +67,24 @@ def test_register(self): tree_manager.before_delete)) def test_register_and_remove(self): - from sqlalchemy_mptt import BaseNestedSets - tree_manager = TreesManager(BaseNestedSets) - tree_manager.register_mapper(mapper) - tree_manager.register_mapper(mapper, remove=True) - self.assertFalse(contains(BaseNestedSets, 'before_insert', + from sqlalchemy_mptt import tree_manager + tree_manager.register_events() + tree_manager.register_events(remove=True) + self.assertFalse(contains(Tree, 'before_insert', tree_manager.before_insert)) - self.assertFalse(contains(BaseNestedSets, 'before_update', + self.assertFalse(contains(Tree, 'before_update', tree_manager.before_update)) - self.assertFalse(contains(BaseNestedSets, 'before_delete', + self.assertFalse(contains(Tree, 'before_delete', tree_manager.before_delete)) + tree_manager.register_events() def test_remove(self): - from sqlalchemy_mptt import BaseNestedSets - tree_manager = TreesManager(BaseNestedSets) - tree_manager.register_mapper(mapper, remove=True) - self.assertFalse(contains(BaseNestedSets, 'before_insert', + from sqlalchemy_mptt import tree_manager + tree_manager.register_events(remove=True) + self.assertFalse(contains(Tree, 'before_insert', tree_manager.before_insert)) - self.assertFalse(contains(BaseNestedSets, 'before_update', + self.assertFalse(contains(Tree, 'before_update', tree_manager.before_update)) - self.assertFalse(contains(BaseNestedSets, 'before_delete', + self.assertFalse(contains(Tree, 'before_delete', tree_manager.before_delete)) + tree_manager.register_events()