diff --git a/CHANGES.rst b/CHANGES.rst index 1261d44b6fe..11ba6cbdef4 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -91,6 +91,9 @@ Features added * #12507: Add the :ref:`collapsible ` option to admonition directives. Patch by Chris Sewell. +* #8191, #8159: Add :rst:dir:`inheritance-diagram:include-subclasses` option to + the :rst:dir:`inheritance-diagram` directive. + Patch by Walter Dörwald. Bugs fixed ---------- diff --git a/doc/usage/extensions/inheritance.rst b/doc/usage/extensions/inheritance.rst index d6eee6879c7..ed977f86564 100644 --- a/doc/usage/extensions/inheritance.rst +++ b/doc/usage/extensions/inheritance.rst @@ -100,6 +100,24 @@ It adds this directive: .. versionchanged:: 1.7 Added ``top-classes`` option to limit the scope of inheritance graphs. + .. rst:directive:option:: include-subclasses + :type: no value + + .. versionadded:: 8.2 + + If given, any subclass of the classes will be added to the diagram too. + + Given the Python module from above, you can specify + your inheritance diagram like this: + + .. code-block:: rst + + .. inheritance-diagram:: dummy.test.A + :include-subclasses: + + This will include the classes A, B, C, D, E and F in the inheritance diagram + but no other classes in the module ``dummy.test``. + Examples -------- diff --git a/sphinx/ext/inheritance_diagram.py b/sphinx/ext/inheritance_diagram.py index 1834950742b..2966202322f 100644 --- a/sphinx/ext/inheritance_diagram.py +++ b/sphinx/ext/inheritance_diagram.py @@ -52,7 +52,7 @@ class E(B): pass from sphinx.util.docutils import SphinxDirective if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Collection, Iterable, Iterator, Sequence, Set from typing import Any, ClassVar, Final from docutils.nodes import Node @@ -106,7 +106,7 @@ def try_import(objname: str) -> Any: return None -def import_classes(name: str, currmodule: str) -> Any: +def import_classes(name: str, currmodule: str) -> list[type[Any]]: """Import a class using its fully-qualified *name*.""" target = None @@ -156,7 +156,8 @@ def __init__( private_bases: bool = False, parts: int = 0, aliases: dict[str, str] | None = None, - top_classes: Sequence[Any] = (), + top_classes: Set[str] = frozenset(), + include_subclasses: bool = False, ) -> None: """*class_names* is a list of child classes to show bases from. @@ -164,7 +165,12 @@ def __init__( in the graph. """ self.class_names = class_names - classes = self._import_classes(class_names, currmodule) + classes: Collection[type[Any]] = self._import_classes(class_names, currmodule) + if include_subclasses: + classes_set = {*classes} + for cls in tuple(classes_set): + classes_set.update(_subclasses(cls)) + classes = classes_set self.class_info = self._class_info( classes, show_builtins, private_bases, parts, aliases, top_classes ) @@ -172,21 +178,23 @@ def __init__( msg = 'No classes found for inheritance diagram' raise InheritanceException(msg) - def _import_classes(self, class_names: list[str], currmodule: str) -> list[Any]: + def _import_classes( + self, class_names: list[str], currmodule: str + ) -> Sequence[type[Any]]: """Import a list of classes.""" - classes: list[Any] = [] + classes: list[type[Any]] = [] for name in class_names: classes.extend(import_classes(name, currmodule)) return classes def _class_info( self, - classes: list[Any], + classes: Collection[type[Any]], show_builtins: bool, private_bases: bool, parts: int, aliases: dict[str, str] | None, - top_classes: Sequence[Any], + top_classes: Set[str], ) -> list[tuple[str, str, Sequence[str], str | None]]: """Return name and bases for all classes that are ancestors of *classes*. @@ -205,7 +213,7 @@ def _class_info( """ all_classes = {} - def recurse(cls: Any) -> None: + def recurse(cls: type[Any]) -> None: if not show_builtins and cls in PY_BUILTINS: return if not private_bases and cls.__name__.startswith('_'): @@ -248,7 +256,7 @@ def recurse(cls: Any) -> None: ] def class_name( - self, cls: Any, parts: int = 0, aliases: dict[str, str] | None = None + self, cls: type[Any], parts: int = 0, aliases: dict[str, str] | None = None ) -> str: """Given a class object, return a fully-qualified name. @@ -377,6 +385,7 @@ class InheritanceDiagram(SphinxDirective): 'private-bases': directives.flag, 'caption': directives.unchanged, 'top-classes': directives.unchanged_required, + 'include-subclasses': directives.flag, } def run(self) -> list[Node]: @@ -387,11 +396,11 @@ def run(self) -> list[Node]: # Store the original content for use as a hash node['parts'] = self.options.get('parts', 0) node['content'] = ', '.join(class_names) - node['top-classes'] = [] - for cls in self.options.get('top-classes', '').split(','): - cls = cls.strip() - if cls: - node['top-classes'].append(cls) + node['top-classes'] = frozenset({ + cls_stripped + for cls in self.options.get('top-classes', '').split(',') + if (cls_stripped := cls.strip()) + }) # Create a graph starting with the list of classes try: @@ -402,6 +411,7 @@ def run(self) -> list[Node]: private_bases='private-bases' in self.options, aliases=self.config.inheritance_alias, top_classes=node['top-classes'], + include_subclasses='include-subclasses' in self.options, ) except InheritanceException as err: return [node.document.reporter.warning(err, line=self.lineno)] @@ -428,6 +438,12 @@ def run(self) -> list[Node]: return [figure] +def _subclasses(cls: type[Any]) -> Iterator[type[Any]]: + yield cls + for sub_cls in cls.__subclasses__(): + yield from _subclasses(sub_cls) + + def get_graph_hash(node: inheritance_diagram) -> str: encoded = (node['content'] + str(node['parts'])).encode() return hashlib.md5(encoded, usedforsecurity=False).hexdigest()[-10:]