diff --git a/src/_griffe/extensions/base.py b/src/_griffe/extensions/base.py index 1903efe0..cde62e23 100644 --- a/src/_griffe/extensions/base.py +++ b/src/_griffe/extensions/base.py @@ -8,7 +8,7 @@ from importlib.util import module_from_spec, spec_from_file_location from inspect import isclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Self, Type, Union from _griffe.agents.nodes.ast import ast_children, ast_kind from _griffe.exceptions import ExtensionNotLoadedError @@ -28,6 +28,8 @@ class Extension: """Base class for Griffe extensions.""" + sub_extensions: tuple[SubExtension, ...] = () + def visit(self, node: ast.AST) -> None: """Visit a node. @@ -277,6 +279,17 @@ def on_wildcard_expansion( """ +class SubExtension: + """Base class for Griffe sub-extensions.""" + + namespace: str + + def __init_subclass__(cls) -> None: + if not hasattr(cls, "namespace") or not cls.namespace: + cls.namespace = cls.__name__ + + + LoadableExtensionType = Union[str, Dict[str, Any], Extension, Type[Extension]] """All the types that can be passed to `load_extensions`.""" @@ -312,6 +325,17 @@ def call(self, event: str, **kwargs: Any) -> None: for extension in self._extensions: getattr(extension, event)(**kwargs) + def subcall(self, namespace: str, event: str, **kwargs: Any) -> None: + """Call the sub-extension hook for the given event. + + Parameters: + event: The triggered event. + **kwargs: Arguments passed to the hook. + """ + for extension in self._extensions: + for sub_extension in extension.sub_extensions: + if sub_extension.namespace == namespace: + getattr(sub_extension, event)(**kwargs) builtin_extensions: set[str] = { "dataclasses",