diff --git a/pytato/codegen.py b/pytato/codegen.py index 820918ab9..c177416ec 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -63,6 +63,7 @@ CachedWalkMapper, CopyMapper, SubsetDependencyMapper, + TransformMapperCache, ) from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin @@ -140,8 +141,8 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _cache: CodeGenPreprocessor._CacheT | None = None, - _function_cache: CodeGenPreprocessor._FunctionCacheT | None = None + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index d043e1c12..3a16adafe 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -94,6 +94,7 @@ CachedWalkMapper, CombineMapper, CopyMapper, + TransformMapperCache, _verify_is_array, ) @@ -239,9 +240,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _cache: _DistributedInputReplacer._CacheT | None = None, + _cache: TransformMapperCache[ArrayOrNames] | None = None, _function_cache: - _DistributedInputReplacer._FunctionCacheT | None = None, + TransformMapperCache[FunctionDefinition] | None = None, ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index d96d1e0fa..8e2645bdd 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -92,8 +92,11 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: MapperCache .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper +.. autoclass:: TransformMapperCache +.. autoclass:: TransformMapperWithExtraArgsCache .. autoclass:: TransformMapper .. autoclass:: TransformMapperWithExtraArgs .. autoclass:: CopyMapper @@ -296,9 +299,10 @@ def __call__(self, CacheResultT = TypeVar("CacheResultT") -class CachedMapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]): +# FIXME: Just fix CacheKeyT == Hashable? +class MapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]): """ - Cache for :class:`CachedMapper`. + Cache for mappers. .. automethod:: __init__ .. method:: get_key @@ -365,6 +369,10 @@ def clear(self) -> None: self._expr_key_to_result = {} +class CachedMapperCache(MapperCache[CacheExprT, Hashable, CacheResultT, P]): + pass + + class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """Mapper class that maps each node in the DAG exactly once. This loses some information compared to :class:`Mapper` as a node is visited only from @@ -374,24 +382,21 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): .. automethod:: get_function_definition_cache_key .. automethod:: clone_for_callee """ - _CacheT: TypeAlias = CachedMapperCache[ - ArrayOrNames, Hashable, _OtherResultT, _OtherP] - _FunctionCacheT: TypeAlias = CachedMapperCache[ - FunctionDefinition, Hashable, _OtherFunctionResultT, _OtherP] - def __init__( self, - _cache: CachedMapper._CacheT[ResultT, P] | None = None, + _cache: + CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: - CachedMapper._FunctionCacheT[FunctionResultT, P] | None = None + CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: CachedMapper._CacheT[ResultT, P] = ( + self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None else CachedMapperCache(self.get_cache_key)) - self._function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] = ( + self._function_cache: CachedMapperCache[ + FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None else CachedMapperCache(self.get_function_definition_cache_key)) @@ -441,6 +446,14 @@ def clone_for_callee( # {{{ TransformMapper +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, []]): + pass + + +class TransformMapperWithExtraArgsCache(CachedMapperCache[CacheExprT, CacheExprT, P]): + pass + + class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): """Base class for mappers that transform :class:`pytato.array.Array`\\ s into other :class:`pytato.array.Array`\\ s. @@ -449,8 +462,15 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. """ - _CacheT: TypeAlias = CachedMapper._CacheT[ArrayOrNames, []] - _FunctionCacheT: TypeAlias = CachedMapper._FunctionCacheT[FunctionDefinition, []] + _cache: TransformMapperCache[ArrayOrNames] + _function_cache: TransformMapperCache[FunctionDefinition] + + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) # }}} @@ -467,9 +487,16 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. """ - _CacheT: TypeAlias = CachedMapper._CacheT[ArrayOrNames, _OtherP] - _FunctionCacheT: TypeAlias = CachedMapper._FunctionCacheT[ - FunctionDefinition, _OtherP] + _cache: TransformMapperWithExtraArgsCache[ArrayOrNames, P] + _function_cache: TransformMapperWithExtraArgsCache[FunctionDefinition, P] + + def __init__( + self, + _cache: TransformMapperWithExtraArgsCache[ArrayOrNames, P] | None = None, + _function_cache: + TransformMapperWithExtraArgsCache[FunctionDefinition, P] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) # }}} @@ -1495,8 +1522,8 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _cache: CachedMapAndCopyMapper._CacheT | None = None, - _function_cache: CachedMapAndCopyMapper._FunctionCacheT | None = None + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 3cc177445..e96cd1ee4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -82,7 +82,7 @@ index_lambda_to_high_level_op, ) from pytato.scalar_expr import SCALAR_CLASSES -from pytato.transform import ArrayOrNames, CopyMapper, Mapper +from pytato.transform import ArrayOrNames, CopyMapper, Mapper, TransformMapperCache from pytato.utils import are_shape_components_equal, are_shapes_equal @@ -92,7 +92,7 @@ if TYPE_CHECKING: from collections.abc import Collection, Mapping - from pytato.function import NamedCallResult + from pytato.function import FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall @@ -596,8 +596,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], tag_corresponding_redn_descr: bool, - _cache: AxisTagAttacher._CacheT | None = None, - _function_cache: AxisTagAttacher._FunctionCacheT | None = None): + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: + TransformMapperCache[FunctionDefinition] | None = None): super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr