Skip to content

Commit

Permalink
Merge pull request #31 from GaretJax/left-right
Browse files Browse the repository at this point in the history
Expire left/right attributes of modified parents.
  • Loading branch information
uralbash committed Aug 15, 2014
2 parents 517e76e + dad09d9 commit 7f2f5f2
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 30 deletions.
1 change: 1 addition & 0 deletions docs/sqlalchemy_mptt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Base events
.. autofunction:: mptt_before_insert
.. autofunction:: mptt_before_delete
.. autofunction:: mptt_before_update
.. autoclass:: TreesManager

Hidden method
~~~~~~~~~~~~~
Expand Down
8 changes: 8 additions & 0 deletions sqlalchemy_mptt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
# Copyright © 2014 uralbash <[email protected]>
#
# Distributed under terms of the MIT license.

from sqlalchemy.orm import mapper
from .mixins import BaseNestedSets
from .events import TreesManager

__version__ = "0.0.8"
__mixins__ = [BaseNestedSets]
__all__ = ['BaseNestedSets', 'mptt_sessionmaker']

tree_manager = TreesManager(BaseNestedSets)
tree_manager.register_mapper(mapper)
mptt_sessionmaker = tree_manager.register_factory
108 changes: 107 additions & 1 deletion sqlalchemy_mptt/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
"""
SQLAlchemy events extension
"""
from sqlalchemy import and_, case, select
import weakref

from sqlalchemy import and_, case, select, event, inspection
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.sql import func


Expand Down Expand Up @@ -306,3 +309,106 @@ def mptt_before_update(mapper, connection, instance):
tree_id=tree_id
)
)


class _WeakDictBasedSet(weakref.WeakKeyDictionary, object):
# In absence of a default weakset implementation, provide our own dict
# based solution.

def add(self, obj):
self[obj] = None

def discard(self, obj):
super(_WeakDictBasedSet, self).pop(obj, None)

def pop(self):
return self.popitem()[0]


class TreesManager(object):
"""
Manages events dispatching for all subclasses of a given class.
"""
def __init__(self, base_class):
self.base_class = base_class
self.classes = set()
self.instances = _WeakDictBasedSet()

def register_mapper(self, mapper):
for e, h in (
('before_insert', self.before_insert),
('before_update', self.before_update),
('before_delete', self.before_delete),
):
event.listen(self.base_class, e, h, propagate=True)
return self

def register_factory(self, sessionmaker):
"""
Registers this TreesManager instance to respond on
`after_flush_postexec` events on the given session or session factory.
This method returns the original argument, so that it can be used by
wrapping an already exisiting instance:
.. code-block:: python
:linenos:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, mapper
from sqlalchemy_mptt.mixins import BaseNestedSets
engine = create_engine('...')
trees_manager = TreesManager(BaseNestedSets)
trees_manager.register_mapper(mapper)
Session = tree_manager.register_factory(
sessionmaker(bind=engine)
)
A reference to this method, bound to a default instance of this class
and already registered to a mapper, is importable directly from
`sqlalchemy_mptt`:
.. code-block:: python
:linenos:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy_mptt import mptt_sessionmaker
engine = create_engine('...')
Session = mptt_sessionmaker(sessionmaker(bind=engine))
"""
event.listen(sessionmaker, 'after_flush_postexec',
self.after_flush_postexec)
return sessionmaker

def before_insert(self, mapper, connection, instance):
self.instances.add(instance)
mptt_before_insert(mapper, connection, instance)

def before_update(self, mapper, connection, instance):
self.instances.add(instance)
mptt_before_update(mapper, connection, instance)

def before_delete(self, mapper, connection, instance):
self.instances.discard(instance)
mptt_before_delete(mapper, connection, instance)

def after_flush_postexec(self, session, context):
"""
Event listener to recursively expire `left` and `right` attributes the
parents of all modified instances part of this flush.
"""
while self.instances:
instance = self.instances.pop()
parent = self.get_parent_value(instance)
while parent != NO_VALUE and parent is not None:
self.instances.discard(parent)
session.expire(parent, ['left', 'right'])
parent = self.get_parent_value(parent)

@staticmethod
def get_parent_value(instance):
return inspection.inspect(instance).attrs.parent.loaded_value
31 changes: 4 additions & 27 deletions sqlalchemy_mptt/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
"""
SQLAlchemy nested sets mixin
"""
import warnings

from sqlalchemy import Column, event, ForeignKey, Index, Integer
from sqlalchemy import Column, ForeignKey, Index, Integer
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, relationship, mapper
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm.session import Session

from .events import mptt_before_delete, mptt_before_insert, mptt_before_update
from .events import _get_tree_table


Expand Down Expand Up @@ -122,9 +122,7 @@ def register_tree(cls):
MyMPTTmodel.register_tree()
"""
event.listen(cls, "before_insert", mptt_before_insert)
event.listen(cls, "before_update", mptt_before_update)
event.listen(cls, "before_delete", mptt_before_delete)
warnings.warn('Trees are registered automatically', DeprecationWarning)

def move_inside(self, parent_id):
""" Moving one node of tree inside another
Expand Down Expand Up @@ -285,24 +283,3 @@ def rebuild(cls, session, tree_id=None):
trees = trees.filter_by(tree_id=tree_id)
for tree in trees:
cls.rebuild_tree(session, tree.tree_id)


class TreesManager(object):
def __init__(self, base_class):
self.base_class = base_class
self.classes = set()

def register(self, mapper):
event.listen(mapper, 'instrument_class', self.class_instrumented)
event.listen(mapper, 'after_configured', self.after_configured)

def class_instrumented(self, mapper, cls):
if issubclass(cls, self.base_class):
self.classes.add(cls)

def after_configured(self):
while self.classes:
self.classes.pop().register_tree()


TreesManager(BaseNestedSets).register(mapper)
File renamed without changes.
63 changes: 61 additions & 2 deletions sqlalchemy_mptt/tests/tree_testing_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy_mptt import mptt_sessionmaker


def add_fixture(model, fixtures, session):
Expand Down Expand Up @@ -102,7 +103,7 @@ class TreeTestingMixin(object):

def setUp(self):
self.engine = create_engine('sqlite:///:memory:')
Session = sessionmaker(bind=self.engine)
Session = mptt_sessionmaker(sessionmaker(bind=self.engine))
self.session = Session()
self.base.metadata.create_all(self.engine)
add_mptt_tree(self.session, self.model)
Expand All @@ -113,6 +114,60 @@ def setUp(self):
def tearDown(self):
self.base.metadata.drop_all(self.engine)

def test_explicit_registration(self):
# TODO: assertWarns was added to python > 3.2
# with self.assertWarns(DeprecationWarning):
self.model.register_tree()

def test_tree_orm_initialize(self):
t0 = self.model(ppk=30)
t1 = self.model(ppk=31, parent=t0)
t2 = self.model(ppk=32, parent=t1)
t3 = self.model(ppk=33, parent=t1)

self.session.add(t0)
self.session.flush()

self.assertEqual(t0.left, 1)
self.assertEqual(t0.right, 8)

self.assertEqual(t1.left, 2)
self.assertEqual(t1.right, 7)

self.assertEqual(t2.left, 3)
self.assertEqual(t2.right, 4)

self.assertEqual(t3.left, 5)
self.assertEqual(t3.right, 6)

t0 = self.model(ppk=40)
t1 = self.model(ppk=41, parent=t0)
t2 = self.model(ppk=42, parent=t1)
t3 = self.model(ppk=43, parent=t2)
t4 = self.model(ppk=44, parent=t3)
t5 = self.model(ppk=45, parent=t4)

self.session.add(t3)
self.session.flush()

self.assertEqual(t0.left, 1)
self.assertEqual(t0.right, 12)

self.assertEqual(t1.left, 2)
self.assertEqual(t1.right, 11)

self.assertEqual(t2.left, 3)
self.assertEqual(t2.right, 10)

self.assertEqual(t3.left, 4)
self.assertEqual(t3.right, 9)

self.assertEqual(t4.left, 5)
self.assertEqual(t4.right, 8)

self.assertEqual(t5.left, 6)
self.assertEqual(t5.right, 7)

def test_tree_initialize(self):
""" Initial state of the trees
Expand Down Expand Up @@ -1254,7 +1309,11 @@ def test_rebuild(self):
4 14(9)15 18(11)19
"""

self.session.query(self.model).update({'lft': 0, 'rgt': 0, 'level': 0})
self.session.query(self.model).update({
self.model.left: 0,
self.model.right: 0,
self.model.level: 0
})
self.model.rebuild(self.session, 1)
# id lft rgt lvl parent tree
self.assertEqual(self.result.all(),
Expand Down

0 comments on commit 7f2f5f2

Please sign in to comment.