Skip to content

Commit

Permalink
rename stream/autoflow related stuff to dart (decoupled acceleration runtime tools) (#344)
Browse files Browse the repository at this point in the history

* rename to dart

* formatting

* more renaming

* fix filechecks

* update snakefile passes
  • Loading branch information
jorendumoulin authored Jan 22, 2025
1 parent 25fc420 commit 3f9ee73
Show file tree
Hide file tree
Showing 37 changed files with 266 additions and 273 deletions.
2 changes: 1 addition & 1 deletion benchmarks/dense_matmul/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ else
REMOVE_MEMREF_COPY=
endif

SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout{gemm_layout=${LAYOUT}},realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=2},convert-stream-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space
SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-dart,dart-fuse-operations,snax-bufferize,alloc-to-global,set-memory-space,set-memory-layout{gemm_layout=${LAYOUT}},realize-memref-casts,${REMOVE_MEMREF_COPY}insert-sync-barrier,dispatch-regions{nb_cores=2},convert-dart-to-snax-stream,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space

GEN_DATA_OPTS += --m=${SIZE_M}
GEN_DATA_OPTS += --n=${SIZE_N}
Expand Down
4 changes: 2 additions & 2 deletions compiler/accelerators/snax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from compiler.accelerators.streamers import StreamerConfiguration
from compiler.accelerators.streamers.streamers import StreamerFlag, StreamerOpts
from compiler.dialects import accfg
from compiler.dialects.dart import StreamingRegionOpBase
from compiler.dialects.snax_stream import StreamerConfigurationAttr, StreamingRegionOp
from compiler.dialects.stream import StreamingRegionOpBase
from compiler.ir.stream import Template
from compiler.ir.dart.access_pattern import Template

c0_attr = builtin.IntegerAttr(0, builtin.IndexType())

Expand Down
6 changes: 3 additions & 3 deletions compiler/accelerators/snax_alu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
StreamerConfiguration,
StreamerType,
)
from compiler.dialects import accfg, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern
from compiler.dialects import accfg, dart, snax_stream
from compiler.ir.dart.access_pattern import Template, TemplatePattern

default_streamer = StreamerConfiguration(
[
Expand Down Expand Up @@ -194,7 +194,7 @@ def generate_acc_op(self) -> accfg.AcceleratorOp:
return op

@staticmethod
def get_template(op: stream.StreamingRegionOpBase):
def get_template(op: dart.StreamingRegionOpBase):
template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3
template_bounds = (None, 4)
return Template(TemplatePattern(template_bounds, tp) for tp in template)
6 changes: 3 additions & 3 deletions compiler/accelerators/snax_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
StreamerConfiguration,
StreamerType,
)
from compiler.dialects import accfg, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern
from compiler.dialects import accfg, dart, snax_stream
from compiler.ir.dart.access_pattern import Template, TemplatePattern

default_streamer = StreamerConfiguration(
[
Expand Down Expand Up @@ -166,7 +166,7 @@ def lower_acc_await(acc_op: accfg.AcceleratorOp) -> Sequence[Operation]:
]

@staticmethod
def get_template(op: stream.StreamingRegionOpBase) -> Template:
def get_template(op: dart.StreamingRegionOpBase) -> Template:
M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6))
template = [
AffineMap(6, 0, (M * 8 + m, K * 8 + k)),
Expand Down
14 changes: 7 additions & 7 deletions compiler/accelerators/snax_gemmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
StreamerType,
)
from compiler.accelerators.streamers.streamers import StreamerOpts
from compiler.dialects import accfg, kernel, snax_stream, stream
from compiler.ir.stream import Template, TemplatePattern
from compiler.dialects import accfg, dart, kernel, snax_stream
from compiler.ir.dart.access_pattern import Template, TemplatePattern
from compiler.util.pack_bitlist import pack_bitlist

default_streamer = StreamerConfiguration(
Expand Down Expand Up @@ -179,7 +179,7 @@ def _generate_setup_vals(

ops_to_add: list[Operation] = []

assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
assert isinstance(generic_op := op.body.block.first_op, dart.GenericOp)

if isinstance(qmac := generic_op.body.block.first_op, kernel.QMacOp):
# gemm
Expand Down Expand Up @@ -271,8 +271,8 @@ def _generate_setup_vals(
]

@staticmethod
def get_template(op: stream.StreamingRegionOpBase) -> Template:
assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)
def get_template(op: dart.StreamingRegionOpBase) -> Template:
assert isinstance(generic_op := op.body.block.first_op, dart.GenericOp)
if isinstance(generic_op.body.block.first_op, kernel.QMacOp):
# matmul
M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6))
Expand All @@ -283,7 +283,7 @@ def get_template(op: stream.StreamingRegionOpBase) -> Template:
]
template_bounds = (None, None, None, 8, 8, 8)

if isinstance(generic_op.next_op, stream.GenericOp):
if isinstance(generic_op.next_op, dart.GenericOp):
generic_op = generic_op.next_op
if isinstance(generic_op.body.block.first_op, kernel.AddOp):
# gemm, add c pattern that is equal to output pattern
Expand All @@ -299,7 +299,7 @@ def get_template(op: stream.StreamingRegionOpBase) -> Template:
]
template_bounds = (None, None, 8, 8)

if not isinstance(generic_op.next_op, stream.YieldOp):
if not isinstance(generic_op.next_op, dart.YieldOp):
raise RuntimeError("unsupported kernel")

return Template(TemplatePattern(template_bounds, tp) for tp in template)
23 changes: 12 additions & 11 deletions compiler/dialects/stream.py → compiler/dialects/dart.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
)

"""
Custom `stream` dialect, to simplify things in a more principled approach, including:
Custom `dart` dialect, heavily inspired by xDSL `stream` dialect, to simplify
things in a (hopefully) more principled approach, including:
- inherent support for tensors
- streams with value semantics
- no specified static bounds in access patterns: they are just affine maps
Expand All @@ -61,7 +62,7 @@ class StreamType(
Streams can only be read from, there is no distinction between readable/writeable streams.
"""

name = "stream.stream"
name = "dart.stream"

element_type: ParameterDef[_StreamTypeElement]

Expand Down Expand Up @@ -117,13 +118,13 @@ def __init__(


@irdl_op_definition
class StreamingRegionOp(StreamingRegionOpBase):
class OperationOp(StreamingRegionOpBase):
"""
A streaming region op that represents an unscheduled operation,
with streams mapping the iteration space to the operand indexing space.
"""

name = "stream.streaming_region"
name = "dart.operation"

def get_pattern_bounds_to_shapes_map(self) -> AffineMap:
"""
Expand Down Expand Up @@ -171,7 +172,7 @@ class ScheduleOp(StreamingRegionOpBase):
transformation took place.
"""

name = "stream.schedule"
name = "dart.schedule"

# The bounds of the iteration space of the schedule
bounds = prop_def(ParameterDef[ArrayAttr[IntegerAttr[IndexType]]])
Expand Down Expand Up @@ -219,7 +220,7 @@ class AccessPatternOp(StreamingRegionOpBase):
layout resolution, with streams mapping the iteration space to memory.
"""

name = "stream.access_pattern"
name = "dart.access_pattern"

# The bounds of the iteration space of the schedule
bounds = prop_def(ParameterDef[ArrayAttr[IntegerAttr[IndexType]]])
Expand Down Expand Up @@ -251,7 +252,7 @@ def __init__(

@irdl_op_definition
class YieldOp(AbstractYieldOperation[Attribute]):
name = "stream.yield"
name = "dart.yield"

traits = traits_def(IsTerminator())

Expand All @@ -263,7 +264,7 @@ class GenericOp(IRDLOperation):
Indexing maps / iterators are not relevant, so they are not included.
"""

name = "stream.generic"
name = "dart.generic"

# inputs can be streams or integers
inputs = var_operand_def()
Expand Down Expand Up @@ -294,10 +295,10 @@ def __init__(
)


Stream = Dialect(
"stream",
Dart = Dialect(
"dart",
[
StreamingRegionOp,
OperationOp,
ScheduleOp,
AccessPatternOp,
GenericOp,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self, TypeVar, overload
from xdsl.ir.affine import AffineDimExpr, AffineMap

from compiler.ir.autoflow.affine_transform import AffineTransform
from compiler.ir.dart.affine_transform import AffineTransform


@dataclass(frozen=True)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from compiler.ir.stream import Schedule, Template
from compiler.ir.dart.access_pattern import Schedule, Template


def scheduler(template: Template, schedule: Schedule) -> Schedule:
Expand Down
2 changes: 0 additions & 2 deletions compiler/ir/stream/__init__.py

This file was deleted.

18 changes: 10 additions & 8 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from xdsl.xdsl_opt_main import xDSLOptMain

from compiler.dialects.accfg import ACCFG
from compiler.dialects.dart import Dart
from compiler.dialects.kernel import Kernel
from compiler.dialects.snax import Snax
from compiler.dialects.snax_stream import SnaxStream
from compiler.dialects.stream import Stream
from compiler.dialects.test.debug import Debug
from compiler.dialects.tsl import TSL
from compiler.transforms.accfg_config_overlap import AccfgConfigOverlapPass
Expand All @@ -18,20 +18,20 @@
from compiler.transforms.backend.postprocess_mlir import PostprocessPass
from compiler.transforms.clear_memory_space import ClearMemorySpace
from compiler.transforms.convert_accfg_to_csr import ConvertAccfgToCsrPass
from compiler.transforms.convert_dart_to_snax_stream import ConvertDartToSnaxStream
from compiler.transforms.convert_kernel_to_linalg import ConvertKernelToLinalg
from compiler.transforms.convert_linalg_to_accfg import (
ConvertLinalgToAccPass,
TraceStatesPass,
)
from compiler.transforms.convert_linalg_to_kernel import ConvertLinalgToKernel
from compiler.transforms.convert_linalg_to_stream import ConvertLinalgToStream
from compiler.transforms.convert_stream_to_snax_stream import ConvertStreamToSnaxStream
from compiler.transforms.convert_tosa_to_kernel import ConvertTosaToKernelPass
from compiler.transforms.dart.convert_linalg_to_dart import ConvertLinalgToDart
from compiler.transforms.dart.dart_fuse_operations import DartFuseOperationsPass
from compiler.transforms.dispatch_kernels import DispatchKernels
from compiler.transforms.dispatch_regions import DispatchRegions
from compiler.transforms.frontend.preprocess_mlir import PreprocessPass
from compiler.transforms.frontend.preprocess_mlperf_tiny import PreprocessMLPerfTiny
from compiler.transforms.fuse_streaming_regions import FuseStreamingRegions
from compiler.transforms.insert_accfg_op import InsertAccOp
from compiler.transforms.insert_sync_barrier import InsertSyncBarrier
from compiler.transforms.memref_to_snax import MemrefToSNAX
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
self.ctx.load_dialect(ACCFG)
self.ctx.load_dialect(SnaxStream)
self.ctx.load_dialect(Debug)
self.ctx.load_dialect(Stream)
self.ctx.load_dialect(Dart)
super().register_pass(DispatchKernels.name, lambda: DispatchKernels)
super().register_pass(SetMemorySpace.name, lambda: SetMemorySpace)
super().register_pass(SetMemoryLayout.name, lambda: SetMemoryLayout)
Expand All @@ -104,7 +104,7 @@ def __init__(
AccfgConfigOverlapPass.name, lambda: AccfgConfigOverlapPass
)
super().register_pass(
ConvertStreamToSnaxStream.name, lambda: ConvertStreamToSnaxStream
ConvertDartToSnaxStream.name, lambda: ConvertDartToSnaxStream
)
super().register_pass(ReuseMemrefAllocs.name, lambda: ReuseMemrefAllocs)
super().register_pass(RemoveMemrefCopyPass.name, lambda: RemoveMemrefCopyPass)
Expand All @@ -120,9 +120,11 @@ def __init__(
super().register_pass(DebugToFuncPass.name, lambda: DebugToFuncPass)
super().register_pass(PreprocessMLPerfTiny.name, lambda: PreprocessMLPerfTiny)
super().register_pass(AddMcycleAroundLaunch.name, lambda: AddMcycleAroundLaunch)
super().register_pass(ConvertLinalgToStream.name, lambda: ConvertLinalgToStream)
super().register_pass(ConvertLinalgToDart.name, lambda: ConvertLinalgToDart)
super().register_pass(SnaxBufferize.name, lambda: SnaxBufferize)
super().register_pass(FuseStreamingRegions.name, lambda: FuseStreamingRegions)
super().register_pass(
DartFuseOperationsPass.name, lambda: DartFuseOperationsPass
)
super().register_pass(AllocToGlobalPass.name, lambda: AllocToGlobalPass)
super().register_pass(PreprocessPass.name, lambda: PreprocessPass)
super().register_pass(PostprocessPass.name, lambda: PostprocessPass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from compiler.accelerators.registry import AcceleratorRegistry
from compiler.accelerators.snax import SNAXStreamer
from compiler.accelerators.util import find_accelerator_op
from compiler.dialects import snax_stream, stream
from compiler.ir.autoflow.affine_transform import AffineTransform
from compiler.ir.stream import Schedule, SchedulePattern, scheduler
from compiler.ir.stream.access_pattern import Template
from compiler.dialects import dart, snax_stream
from compiler.ir.dart.access_pattern import Schedule, SchedulePattern, Template
from compiler.ir.dart.affine_transform import AffineTransform
from compiler.ir.dart.scheduler import scheduler


def get_accelerator_info(op: stream.StreamingRegionOpBase) -> Template:
def get_accelerator_info(op: dart.StreamingRegionOpBase) -> Template:
assert op.accelerator is not None

# Go and fetch the accelerator op
Expand Down Expand Up @@ -51,9 +51,7 @@ class AutoflowScheduler(RewritePattern):
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: stream.StreamingRegionOp, rewriter: PatternRewriter
):
def match_and_rewrite(self, op: dart.OperationOp, rewriter: PatternRewriter):
template = get_accelerator_info(op)

# Make sure the operands are memrefs
Expand All @@ -69,7 +67,7 @@ def match_and_rewrite(
)
schedule = scheduler(template, schedule)

schedule_op = stream.ScheduleOp(
schedule_op = dart.ScheduleOp(
op.inputs,
op.outputs,
ArrayAttr([AffineMapAttr(s.pattern.to_affine_map()) for s in schedule]),
Expand All @@ -92,7 +90,7 @@ class LayoutResolution(RewritePattern):
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stream.ScheduleOp, rewriter: PatternRewriter):
def match_and_rewrite(self, op: dart.ScheduleOp, rewriter: PatternRewriter):
bounds = [x.value.data for x in op.bounds.data]
schedule = Schedule(
SchedulePattern(bounds, pattern.data) for pattern in op.patterns
Expand Down Expand Up @@ -147,7 +145,7 @@ def generate_one_list(n: int, i: int):

new_patterns = ArrayAttr([AffineMapAttr(map) for map in access_patterns])

access_pattern_op = stream.AccessPatternOp(
access_pattern_op = dart.AccessPatternOp(
new_inputs,
new_outputs,
new_patterns,
Expand All @@ -170,7 +168,7 @@ class ConvertStreamToSnaxStreamPattern(RewritePattern):
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stream.AccessPatternOp, rewriter: PatternRewriter):
def match_and_rewrite(self, op: dart.AccessPatternOp, rewriter: PatternRewriter):
template = get_accelerator_info(op)

snax_stride_patterns: list[snax_stream.StridePattern] = []
Expand Down Expand Up @@ -324,8 +322,8 @@ def match_and_rewrite(self, op: stream.AccessPatternOp, rewriter: PatternRewrite


@dataclass(frozen=True)
class ConvertStreamToSnaxStream(ModulePass):
name = "convert-stream-to-snax-stream"
class ConvertDartToSnaxStream(ModulePass):
name = "convert-dart-to-snax-stream"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(AutoflowScheduler()).rewrite_module(op)
Expand Down
Empty file.
Loading

0 comments on commit 3f9ee73

Please sign in to comment.