Commit

attempt to fix cache class type annotations
pyright still doesn't like it
majosm committed Jan 24, 2025
1 parent e61139e commit 40b2a64
Showing 4 changed files with 56 additions and 26 deletions.
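
The change common to all four files: constructors that previously annotated their cache arguments with per-class type aliases (CodeGenPreprocessor._CacheT, _DistributedInputReplacer._FunctionCacheT, and so on) now use the shared generic cache types directly, parameterized by the expression type being mapped. Below is a minimal, self-contained sketch of that before/after pattern, using toy stand-in classes rather than pytato's actual implementation (the real classes also carry a cache-key type and a ParamSpec for extra mapper arguments, and str stands in for ArrayOrNames):

from __future__ import annotations

from typing import Generic, TypeVar

ExprT = TypeVar("ExprT")      # expression type handled by a mapper
ResultT = TypeVar("ResultT")  # result type produced by a mapper


class CachedMapperCache(Generic[ExprT, ResultT]):
    """Toy stand-in: remembers already-computed results, keyed by expression."""

    def __init__(self) -> None:
        self._expr_to_result: dict[ExprT, ResultT] = {}


class TransformMapperCache(CachedMapperCache[ExprT, ExprT]):
    """Transform-style mappers map an expression to the same expression type."""


class SomeTransformMapper:
    # Before: _cache: SomeTransformMapper._CacheT | None = None
    # After: the shared generic cache type, parameterized by the expression type.
    def __init__(
            self,
            _cache: TransformMapperCache[str] | None = None) -> None:
        self._cache: TransformMapperCache[str] = (
            _cache if _cache is not None else TransformMapperCache())
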
5 changes: 3 additions & 2 deletions pytato/codegen.py
@@ -63,6 +63,7 @@
CachedWalkMapper,
CopyMapper,
SubsetDependencyMapper,
TransformMapperCache,
)
from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin

@@ -140,8 +141,8 @@ def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_cache: CodeGenPreprocessor._CacheT | None = None,
_function_cache: CodeGenPreprocessor._FunctionCacheT | None = None
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
5 changes: 3 additions & 2 deletions pytato/distributed/partition.py
@@ -94,6 +94,7 @@
CachedWalkMapper,
CombineMapper,
CopyMapper,
TransformMapperCache,
_verify_is_array,
)

@@ -239,9 +240,9 @@ def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_cache: _DistributedInputReplacer._CacheT | None = None,
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache:
_DistributedInputReplacer._FunctionCacheT | None = None,
TransformMapperCache[FunctionDefinition] | None = None,
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

63 changes: 45 additions & 18 deletions pytato/transform/__init__.py
@@ -92,8 +92,11 @@

__doc__ = """
.. autoclass:: Mapper
.. autoclass:: MapperCache
.. autoclass:: CachedMapperCache
.. autoclass:: CachedMapper
.. autoclass:: TransformMapperCache
.. autoclass:: TransformMapperWithExtraArgsCache
.. autoclass:: TransformMapper
.. autoclass:: TransformMapperWithExtraArgs
.. autoclass:: CopyMapper
@@ -296,9 +299,10 @@ def __call__(self,
CacheResultT = TypeVar("CacheResultT")


class CachedMapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]):
# FIXME: Just fix CacheKeyT == Hashable?
class MapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]):
"""
Cache for :class:`CachedMapper`.
Cache for mappers.
.. automethod:: __init__
.. method:: get_key
@@ -365,6 +369,10 @@ def clear(self) -> None:
self._expr_key_to_result = {}


class CachedMapperCache(MapperCache[CacheExprT, Hashable, CacheResultT, P]):
pass


class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
"""Mapper class that maps each node in the DAG exactly once. This loses some
information compared to :class:`Mapper` as a node is visited only from
@@ -374,24 +382,21 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
.. automethod:: get_function_definition_cache_key
.. automethod:: clone_for_callee
"""
_CacheT: TypeAlias = CachedMapperCache[
ArrayOrNames, Hashable, _OtherResultT, _OtherP]
_FunctionCacheT: TypeAlias = CachedMapperCache[
FunctionDefinition, Hashable, _OtherFunctionResultT, _OtherP]

def __init__(
self,
_cache: CachedMapper._CacheT[ResultT, P] | None = None,
_cache:
CachedMapperCache[ArrayOrNames, ResultT, P] | None = None,
_function_cache:
CachedMapper._FunctionCacheT[FunctionResultT, P] | None = None
CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None
) -> None:
super().__init__()

self._cache: CachedMapper._CacheT[ResultT, P] = (
self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = (
_cache if _cache is not None
else CachedMapperCache(self.get_cache_key))

self._function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] = (
self._function_cache: CachedMapperCache[
FunctionDefinition, FunctionResultT, P] = (
_function_cache if _function_cache is not None
else CachedMapperCache(self.get_function_definition_cache_key))

@@ -441,6 +446,14 @@ def clone_for_callee(

# {{{ TransformMapper

class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, []]):
pass


class TransformMapperWithExtraArgsCache(CachedMapperCache[CacheExprT, CacheExprT, P]):
pass


class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
"""Base class for mappers that transform :class:`pytato.array.Array`\\ s into
other :class:`pytato.array.Array`\\ s.
@@ -449,8 +462,15 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
arrays (e.g., computing a cache key from them). Does not implement default
mapper methods; for that, see :class:`CopyMapper`.
"""
_CacheT: TypeAlias = CachedMapper._CacheT[ArrayOrNames, []]
_FunctionCacheT: TypeAlias = CachedMapper._FunctionCacheT[FunctionDefinition, []]
_cache: TransformMapperCache[ArrayOrNames]
_function_cache: TransformMapperCache[FunctionDefinition]

def __init__(
self,
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

# }}}

@@ -467,9 +487,16 @@ class TransformMapperWithExtraArgs(
The logic in :class:`TransformMapper` purposely does not take the extra
arguments to keep the cost of its each call frame low.
"""
_CacheT: TypeAlias = CachedMapper._CacheT[ArrayOrNames, _OtherP]
_FunctionCacheT: TypeAlias = CachedMapper._FunctionCacheT[
FunctionDefinition, _OtherP]
_cache: TransformMapperWithExtraArgsCache[ArrayOrNames, P]
_function_cache: TransformMapperWithExtraArgsCache[FunctionDefinition, P]

def __init__(
self,
_cache: TransformMapperWithExtraArgsCache[ArrayOrNames, P] | None = None,
_function_cache:
TransformMapperWithExtraArgsCache[FunctionDefinition, P] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)

# }}}

@@ -1495,8 +1522,8 @@ class CachedMapAndCopyMapper(CopyMapper):
def __init__(
self,
map_fn: Callable[[ArrayOrNames], ArrayOrNames],
_cache: CachedMapAndCopyMapper._CacheT | None = None,
_function_cache: CachedMapAndCopyMapper._FunctionCacheT | None = None
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache: TransformMapperCache[FunctionDefinition] | None = None
) -> None:
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn
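
For reference, the relationships among the cache classes added above, condensed into a standalone sketch: MapperCache is generic over the expression type, the cache-key type, the result type, and a ParamSpec for extra mapper arguments; CachedMapperCache fixes the key type to Hashable (cf. the FIXME comment added next to MapperCache); the two transform variants additionally fix the result type to the expression type. Only the parameterization mirrors the diff; the TypeVars are redeclared locally and all behavior is elided.

from __future__ import annotations

from collections.abc import Hashable
from typing import Generic, ParamSpec, TypeVar

CacheExprT = TypeVar("CacheExprT")      # e.g. ArrayOrNames or FunctionDefinition
CacheKeyT = TypeVar("CacheKeyT")
CacheResultT = TypeVar("CacheResultT")
P = ParamSpec("P")                      # extra arguments passed with the expression


class MapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]):
    """Fully generic base."""


class CachedMapperCache(MapperCache[CacheExprT, Hashable, CacheResultT, P]):
    """Key type fixed to Hashable."""


class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, []]):
    """Result type fixed to the expression type; no extra arguments."""


class TransformMapperWithExtraArgsCache(CachedMapperCache[CacheExprT, CacheExprT, P]):
    """Result type fixed to the expression type; extra arguments kept."""

One plausible source of the remaining pyright complaint mentioned in the commit message (an assumption; the message does not say which diagnostic it is) is the narrowed _cache / _function_cache attribute declarations on TransformMapper and TransformMapperWithExtraArgs: pyright treats mutable instance variables as invariant, so redeclaring them with a narrower type in a subclass can trigger reportIncompatibleVariableOverride.
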
9 changes: 5 additions & 4 deletions pytato/transform/metadata.py
@@ -82,7 +82,7 @@
index_lambda_to_high_level_op,
)
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.transform import ArrayOrNames, CopyMapper, Mapper
from pytato.transform import ArrayOrNames, CopyMapper, Mapper, TransformMapperCache
from pytato.utils import are_shape_components_equal, are_shapes_equal


@@ -92,7 +92,7 @@
if TYPE_CHECKING:
from collections.abc import Collection, Mapping

from pytato.function import NamedCallResult
from pytato.function import FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall


@@ -596,8 +596,9 @@ class AxisTagAttacher(CopyMapper):
def __init__(self,
axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]],
tag_corresponding_redn_descr: bool,
_cache: AxisTagAttacher._CacheT | None = None,
_function_cache: AxisTagAttacher._FunctionCacheT | None = None):
_cache: TransformMapperCache[ArrayOrNames] | None = None,
_function_cache:
TransformMapperCache[FunctionDefinition] | None = None):
super().__init__(_cache=_cache, _function_cache=_function_cache)
self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags
self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr
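
Finally, a hypothetical sketch of why every one of these constructors threads explicit _cache and _function_cache arguments through to super().__init__(): a caller can hand an existing cache to a second mapper instance so that results memoized by one are reused by the other; presumably this is also the plumbing that clone_for_callee (whose body is not shown in this diff) relies on to pass a cache to the mapper it creates for a callee function. Renamer and its string "expressions" below are made-up stand-ins, not pytato classes.

from __future__ import annotations

from collections.abc import Callable
from typing import Generic, TypeVar

ExprT = TypeVar("ExprT")


class TransformMapperCache(Generic[ExprT]):
    """Toy stand-in: memoizes expression -> transformed expression."""

    def __init__(self) -> None:
        self._memo: dict[ExprT, ExprT] = {}

    def get_or_compute(
            self, expr: ExprT, compute: Callable[[ExprT], ExprT]) -> ExprT:
        if expr not in self._memo:
            self._memo[expr] = compute(expr)
        return self._memo[expr]


class Renamer:
    """Toy 'transform mapper' over strings instead of array expressions."""

    def __init__(self, _cache: TransformMapperCache[str] | None = None) -> None:
        self._cache: TransformMapperCache[str] = (
            _cache if _cache is not None else TransformMapperCache())

    def __call__(self, expr: str) -> str:
        return self._cache.get_or_compute(expr, str.upper)

    def clone_with_shared_cache(self) -> Renamer:
        # Hand the existing cache to the clone via the explicit argument.
        return Renamer(_cache=self._cache)


outer = Renamer()
inner = outer.clone_with_shared_cache()
assert outer("x") == "X"
assert inner("x") == "X"  # served from the cache populated by `outer`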
