Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option include-subclasses to inheritance-diagram. #8159

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ Features added
* #12507: Add the :ref:`collapsible <collapsible-admonitions>` option
to admonition directives.
Patch by Chris Sewell.
* #8191, #8159: Add :rst:dir:`inheritance-diagram:include-subclasses` option to
the :rst:dir:`inheritance-diagram` directive.
Patch by Walter Dörwald.

Bugs fixed
----------
Expand Down
18 changes: 18 additions & 0 deletions doc/usage/extensions/inheritance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ It adds this directive:
.. versionchanged:: 1.7
Added ``top-classes`` option to limit the scope of inheritance graphs.

.. rst:directive:option:: include-subclasses
:type: no value

.. versionadded:: 8.2

If given, any subclass of the classes will be added to the diagram too.

Given the Python module from above, you can specify
your inheritance diagram like this:

.. code-block:: rst

.. inheritance-diagram:: dummy.test.A
:include-subclasses:

This will include the classes A, B, C, D, E and F in the inheritance diagram
but no other classes in the module ``dummy.test``.


Examples
--------
Expand Down
46 changes: 31 additions & 15 deletions sphinx/ext/inheritance_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class E(B): pass
from sphinx.util.docutils import SphinxDirective

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Collection, Iterable, Iterator, Sequence, Set
from typing import Any, ClassVar, Final

from docutils.nodes import Node
Expand Down Expand Up @@ -106,7 +106,7 @@ def try_import(objname: str) -> Any:
return None


def import_classes(name: str, currmodule: str) -> Any:
def import_classes(name: str, currmodule: str) -> list[type[Any]]:
"""Import a class using its fully-qualified *name*."""
target = None

Expand Down Expand Up @@ -156,37 +156,45 @@ def __init__(
private_bases: bool = False,
parts: int = 0,
aliases: dict[str, str] | None = None,
top_classes: Sequence[Any] = (),
top_classes: Set[str] = frozenset(),
include_subclasses: bool = False,
) -> None:
"""*class_names* is a list of child classes to show bases from.

If *show_builtins* is True, then Python builtins will be shown
in the graph.
"""
self.class_names = class_names
classes = self._import_classes(class_names, currmodule)
classes: Collection[type[Any]] = self._import_classes(class_names, currmodule)
if include_subclasses:
classes_set = {*classes}
for cls in tuple(classes_set):
classes_set.update(_subclasses(cls))
classes = classes_set
self.class_info = self._class_info(
classes, show_builtins, private_bases, parts, aliases, top_classes
)
if not self.class_info:
msg = 'No classes found for inheritance diagram'
raise InheritanceException(msg)

def _import_classes(self, class_names: list[str], currmodule: str) -> list[Any]:
def _import_classes(
self, class_names: list[str], currmodule: str
) -> Sequence[type[Any]]:
"""Import a list of classes."""
classes: list[Any] = []
classes: list[type[Any]] = []
for name in class_names:
classes.extend(import_classes(name, currmodule))
return classes

def _class_info(
self,
classes: list[Any],
classes: Collection[type[Any]],
show_builtins: bool,
private_bases: bool,
parts: int,
aliases: dict[str, str] | None,
top_classes: Sequence[Any],
top_classes: Set[str],
) -> list[tuple[str, str, Sequence[str], str | None]]:
"""Return name and bases for all classes that are ancestors of
*classes*.
Expand All @@ -205,7 +213,7 @@ def _class_info(
"""
all_classes = {}

def recurse(cls: Any) -> None:
def recurse(cls: type[Any]) -> None:
if not show_builtins and cls in PY_BUILTINS:
return
if not private_bases and cls.__name__.startswith('_'):
Expand Down Expand Up @@ -248,7 +256,7 @@ def recurse(cls: Any) -> None:
]

def class_name(
self, cls: Any, parts: int = 0, aliases: dict[str, str] | None = None
self, cls: type[Any], parts: int = 0, aliases: dict[str, str] | None = None
) -> str:
"""Given a class object, return a fully-qualified name.

Expand Down Expand Up @@ -377,6 +385,7 @@ class InheritanceDiagram(SphinxDirective):
'private-bases': directives.flag,
'caption': directives.unchanged,
'top-classes': directives.unchanged_required,
'include-subclasses': directives.flag,
}

def run(self) -> list[Node]:
Expand All @@ -387,11 +396,11 @@ def run(self) -> list[Node]:
# Store the original content for use as a hash
node['parts'] = self.options.get('parts', 0)
node['content'] = ', '.join(class_names)
node['top-classes'] = []
for cls in self.options.get('top-classes', '').split(','):
cls = cls.strip()
if cls:
node['top-classes'].append(cls)
node['top-classes'] = frozenset({
cls_stripped
for cls in self.options.get('top-classes', '').split(',')
if (cls_stripped := cls.strip())
})

# Create a graph starting with the list of classes
try:
Expand All @@ -402,6 +411,7 @@ def run(self) -> list[Node]:
private_bases='private-bases' in self.options,
aliases=self.config.inheritance_alias,
top_classes=node['top-classes'],
include_subclasses='include-subclasses' in self.options,
)
except InheritanceException as err:
return [node.document.reporter.warning(err, line=self.lineno)]
Expand All @@ -428,6 +438,12 @@ def run(self) -> list[Node]:
return [figure]


def _subclasses(cls: type[Any]) -> Iterator[type[Any]]:
yield cls
for sub_cls in cls.__subclasses__():
yield from _subclasses(sub_cls)


def get_graph_hash(node: inheritance_diagram) -> str:
encoded = (node['content'] + str(node['parts'])).encode()
return hashlib.md5(encoded, usedforsecurity=False).hexdigest()[-10:]
Expand Down