Skip to content

Commit

Permalink
Merge pull request #50 from Klupamos/master
Browse files Browse the repository at this point in the history
QOL changes
  • Loading branch information
uralbash committed Nov 12, 2015
2 parents 1ffd8ea + 8468fbd commit de982db
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 7 deletions.
35 changes: 28 additions & 7 deletions sqlalchemy_mptt/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlalchemy.orm import backref, relationship, object_session
from sqlalchemy.orm.session import Session
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_method

from .events import _get_tree_table

Expand Down Expand Up @@ -72,7 +73,7 @@ def parent_id(cls):
if not pk.name:
pk.name = cls.get_pk_name()

return Column("parent_id", Integer,
return Column("parent_id", pk.type,
ForeignKey('%s.%s' % (cls.__tablename__, pk.name),
ondelete='CASCADE'))

Expand All @@ -99,6 +100,30 @@ def right(cls):
def level(cls):
return Column("level", Integer, nullable=False, default=0)

@hybrid_method
def is_ancestor_of(self, other, inclusive=False):
""" class or instance level method which returns True if self is ancestor (closer to root) of other else False.
Optional flag `inclusive` on whether or not to treat self as ancestor of self.
For example see:
* :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure`
"""
if inclusive:
return (self.tree_id == other.tree_id) & (self.left <= other.left) & (other.right <= self.right)
return (self.tree_id == other.tree_id) & (self.left < other.left) & (other.right < self.right)

@hybrid_method
def is_descendant_of(self, other, inclusive=False):
""" class or instance level method which returns True if self is descendant (farther from root) of other else False.
Optional flag `inclusive` on whether or not to treat self as descendant of self.
For example see:
* :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure`
"""
return other.is_ancestor_of(self, inclusive)

def move_inside(self, parent_id):
""" Moving one node of tree inside another
Expand Down Expand Up @@ -247,9 +272,7 @@ def _drilldown_query(self, nodes=None):
table = self.__class__
if not nodes:
nodes = self._base_query_obj()
return nodes.filter(table.tree_id == self.tree_id)\
.filter(table.left >= self.left)\
.filter(table.right <= self.right)
return nodes.filter(self.is_ancestor_of(table, inclusive=True))

def drilldown_tree(self, session=None, json=False, json_fields=None):
""" This method generate a branch from a tree, begining with current
Expand Down Expand Up @@ -309,9 +332,7 @@ def path_to_root(self, session=None):
"""
table = self.__class__
query = self._base_query_obj(session=session)
query = query.filter(table.tree_id == self.tree_id)\
.filter(table.left <= self.left)\
.filter(table.right >= self.right)
query = query.filter(table.is_ancestor_of(self, inclusive=True))
return self._base_order(query, order=desc)

@classmethod
Expand Down
33 changes: 33 additions & 0 deletions sqlalchemy_mptt/tests/cases/integrity.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,36 @@ def test_left_and_right_always_unique_number(self):
right = self.session.query(table.right)
keys = [x[0] for x in left.union(right)]
self.assertEqual(len(keys), len(set(keys)))

def test_hierarchy_structure(self):
""" Nodes with left < self and right > self are considered ancestors,
while nodes with left > self and right < self are considered descendants
"""
table = self.model
pivot = self.session.query(table).filter(table.right - table.left != 1).filter(table.parent_id != None).first()

# Exclusive Tests
ancestors = self.session.query(table).filter(table.is_ancestor_of(pivot)).all()
for ancestor in ancestors:
self.assertTrue(ancestor.is_ancestor_of(pivot))
self.assertNotIn(pivot, ancestors)

descendants = self.session.query(table).filter(table.is_descendant_of(pivot)).all()
for descendant in descendants:
self.assertTrue(descendant.is_descendant_of(pivot))
self.assertNotIn(pivot, descendants)

self.assertEqual(set(), set(ancestors).intersection(set(descendants)))

# Inclusive Tests - because sometimes inclusivity is nice, like with self joins
ancestors = self.session.query(table).filter(table.is_ancestor_of(pivot, inclusive=True)).all()
for ancestor in ancestors:
self.assertTrue(ancestor.is_ancestor_of(pivot, inclusive=True))
self.assertIn(pivot, ancestors)

descendants = self.session.query(table).filter(table.is_descendant_of(pivot, inclusive=True)).all()
for descendant in descendants:
self.assertTrue(descendant.is_descendant_of(pivot, inclusive=True))
self.assertIn(pivot, descendants)

self.assertEqual(set([pivot]), set(ancestors).intersection(set(descendants)))

0 comments on commit de982db

Please sign in to comment.