diff --git a/sqlalchemy_mptt/mixins.py b/sqlalchemy_mptt/mixins.py index a9b162b..39d666f 100644 --- a/sqlalchemy_mptt/mixins.py +++ b/sqlalchemy_mptt/mixins.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import backref, relationship, object_session from sqlalchemy.orm.session import Session from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.ext.hybrid import hybrid_method from .events import _get_tree_table @@ -72,7 +73,7 @@ def parent_id(cls): if not pk.name: pk.name = cls.get_pk_name() - return Column("parent_id", Integer, + return Column("parent_id", pk.type, ForeignKey('%s.%s' % (cls.__tablename__, pk.name), ondelete='CASCADE')) @@ -99,6 +100,30 @@ def right(cls): def level(cls): return Column("level", Integer, nullable=False, default=0) + @hybrid_method + def is_ancestor_of(self, other, inclusive=False): + """ class or instance level method which returns True if self is ancestor (closer to root) of other else False. + Optional flag `inclusive` on whether or not to treat self as ancestor of self. + + For example see: + + * :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure` + """ + if inclusive: + return (self.tree_id == other.tree_id) & (self.left <= other.left) & (other.right <= self.right) + return (self.tree_id == other.tree_id) & (self.left < other.left) & (other.right < self.right) + + @hybrid_method + def is_descendant_of(self, other, inclusive=False): + """ class or instance level method which returns True if self is descendant (farther from root) of other else False. + Optional flag `inclusive` on whether or not to treat self as descendant of self. + + For example see: + + * :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure` + """ + return other.is_ancestor_of(self, inclusive) + def move_inside(self, parent_id): """ Moving one node of tree inside another @@ -247,9 +272,7 @@ def _drilldown_query(self, nodes=None): table = self.__class__ if not nodes: nodes = self._base_query_obj() - return nodes.filter(table.tree_id == self.tree_id)\ - .filter(table.left >= self.left)\ - .filter(table.right <= self.right) + return nodes.filter(self.is_ancestor_of(table, inclusive=True)) def drilldown_tree(self, session=None, json=False, json_fields=None): """ This method generate a branch from a tree, begining with current @@ -309,9 +332,7 @@ def path_to_root(self, session=None): """ table = self.__class__ query = self._base_query_obj(session=session) - query = query.filter(table.tree_id == self.tree_id)\ - .filter(table.left <= self.left)\ - .filter(table.right >= self.right) + query = query.filter(table.is_ancestor_of(self, inclusive=True)) return self._base_order(query, order=desc) @classmethod diff --git a/sqlalchemy_mptt/tests/cases/integrity.py b/sqlalchemy_mptt/tests/cases/integrity.py index 1461760..22fa977 100644 --- a/sqlalchemy_mptt/tests/cases/integrity.py +++ b/sqlalchemy_mptt/tests/cases/integrity.py @@ -91,3 +91,36 @@ def test_left_and_right_always_unique_number(self): right = self.session.query(table.right) keys = [x[0] for x in left.union(right)] self.assertEqual(len(keys), len(set(keys))) + + def test_hierarchy_structure(self): + """ Nodes with left < self and right > self are considered ancestors, + while nodes with left > self and right < self are considered descendants + """ + table = self.model + pivot = self.session.query(table).filter(table.right - table.left != 1).filter(table.parent_id != None).first() + + # Exclusive Tests + ancestors = self.session.query(table).filter(table.is_ancestor_of(pivot)).all() + for ancestor in ancestors: + self.assertTrue(ancestor.is_ancestor_of(pivot)) + self.assertNotIn(pivot, ancestors) + + descendants = self.session.query(table).filter(table.is_descendant_of(pivot)).all() + for descendant in descendants: + self.assertTrue(descendant.is_descendant_of(pivot)) + self.assertNotIn(pivot, descendants) + + self.assertEqual(set(), set(ancestors).intersection(set(descendants))) + + # Inclusive Tests - because sometimes inclusivity is nice, like with self joins + ancestors = self.session.query(table).filter(table.is_ancestor_of(pivot, inclusive=True)).all() + for ancestor in ancestors: + self.assertTrue(ancestor.is_ancestor_of(pivot, inclusive=True)) + self.assertIn(pivot, ancestors) + + descendants = self.session.query(table).filter(table.is_descendant_of(pivot, inclusive=True)).all() + for descendant in descendants: + self.assertTrue(descendant.is_descendant_of(pivot, inclusive=True)) + self.assertIn(pivot, descendants) + + self.assertEqual(set([pivot]), set(ancestors).intersection(set(descendants)))