diff --git a/compiler/accelerators/snax.py b/compiler/accelerators/snax.py index dd87e8e2..a1bca434 100644 --- a/compiler/accelerators/snax.py +++ b/compiler/accelerators/snax.py @@ -279,7 +279,7 @@ def get_streamer_launch_dict(self, base_addr) -> tuple[int, dict[str, int]]: @staticmethod @abstractmethod - def get_template(op: stream.StreamingRegionOp) -> Template: + def get_template(op: stream.StreamingRegionOpBase) -> Template: """ Get the template for this acelerator to schedule a given stream.streaming_region operation. diff --git a/compiler/accelerators/snax_alu.py b/compiler/accelerators/snax_alu.py index 2be43986..d10e5506 100644 --- a/compiler/accelerators/snax_alu.py +++ b/compiler/accelerators/snax_alu.py @@ -194,7 +194,7 @@ def generate_acc_op(self) -> accfg.AcceleratorOp: return op @staticmethod - def get_template(op: stream.StreamingRegionOp): + def get_template(op: stream.StreamingRegionOpBase): template = [AffineMap.from_callable(lambda x, y: (4 * x + y,))] * 3 template_bounds = (None, 4) return Template(TemplatePattern(template_bounds, tp) for tp in template) diff --git a/compiler/accelerators/snax_gemm.py b/compiler/accelerators/snax_gemm.py index fe301a7d..2f9351f1 100644 --- a/compiler/accelerators/snax_gemm.py +++ b/compiler/accelerators/snax_gemm.py @@ -166,7 +166,7 @@ def lower_acc_await(acc_op: accfg.AcceleratorOp) -> Sequence[Operation]: ] @staticmethod - def get_template(op: stream.StreamingRegionOp) -> Template: + def get_template(op: stream.StreamingRegionOpBase) -> Template: M, N, K, m, n, k = (AffineDimExpr(i) for i in range(6)) template = [ AffineMap(6, 0, (M * 8 + m, K * 8 + k)), diff --git a/compiler/accelerators/snax_gemmx.py b/compiler/accelerators/snax_gemmx.py index 934a1846..4103893a 100644 --- a/compiler/accelerators/snax_gemmx.py +++ b/compiler/accelerators/snax_gemmx.py @@ -271,7 +271,7 @@ def _generate_setup_vals( ] @staticmethod - def get_template(op: stream.StreamingRegionOp) -> Template: + def get_template(op: stream.StreamingRegionOpBase) -> Template: assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp) if isinstance(generic_op.body.block.first_op, kernel.QMacOp): # matmul diff --git a/compiler/transforms/convert_stream_to_snax_stream.py b/compiler/transforms/convert_stream_to_snax_stream.py index e5dbb8db..0bf2943f 100644 --- a/compiler/transforms/convert_stream_to_snax_stream.py +++ b/compiler/transforms/convert_stream_to_snax_stream.py @@ -2,7 +2,7 @@ from xdsl.context import MLContext from xdsl.dialects import arith, builtin, memref -from xdsl.dialects.builtin import MemRefType +from xdsl.dialects.builtin import AffineMapAttr, ArrayAttr, MemRefType from xdsl.ir import Operation from xdsl.ir.affine import AffineMap from xdsl.passes import ModulePass @@ -17,12 +17,31 @@ from compiler.accelerators.registry import AcceleratorRegistry from compiler.accelerators.snax import SNAXStreamer from compiler.dialects import snax_stream, stream -from compiler.dialects.snax import StreamerConfigurationAttr from compiler.ir.stream import Schedule, SchedulePattern, scheduler +from compiler.ir.stream.access_pattern import Template + + +def get_accelerator_info(op: stream.StreamingRegionOpBase) -> Template: + assert op.accelerator is not None + + # Go and fetch the accelerator op + accelerator_str = op.accelerator.data + acc_op = find_accelerator_op(op, accelerator_str) + + if not acc_op: + raise RuntimeError("AcceleratorOp not found!") + + # get template and template_bounds + accelerator_type = AcceleratorRegistry().get_acc_info(acc_op) + assert issubclass(accelerator_type, SNAXStreamer) + + template = accelerator_type.get_template(op) + + return template @dataclass -class MemrefStreamToSnaxPattern(RewritePattern): +class AutoflowScheduler(RewritePattern): """ A pass to convert streaming region operations to snax stream. @@ -42,27 +61,7 @@ class MemrefStreamToSnaxPattern(RewritePattern): def match_and_rewrite( self, op: stream.StreamingRegionOp, rewriter: PatternRewriter ): - # Handle only stream ops dispatched to an accelerator: - if op.accelerator is None: - return - - # Go and fetch the accelerator op - accelerator_str = op.accelerator.data - acc_op = find_accelerator_op(op, accelerator_str) - - if not acc_op: - raise RuntimeError("AcceleratorOp not found!") - - if "streamer_config" not in acc_op.attributes: - raise RuntimeError("Streamer interface not found for given accelerator op") - streamer_config = acc_op.attributes["streamer_config"] - assert isinstance(streamer_config, StreamerConfigurationAttr) - - # get template and template_bounds - accelerator_type = AcceleratorRegistry().get_acc_info(acc_op) - assert issubclass(accelerator_type, SNAXStreamer) - - template = accelerator_type.get_template(op) + template = get_accelerator_info(op) # Make sure the operands are memrefs for memref_operand in op.operands: @@ -77,6 +76,31 @@ def match_and_rewrite( ) schedule = scheduler(template, schedule) + schedule_op = stream.ScheduleOp( + op.inputs, + op.outputs, + ArrayAttr([AffineMapAttr(s.pattern) for s in schedule]), + rewriter.move_region_contents_to_new_regions(op.body), + schedule[0].bounds, + [[]], + op.accelerator, + op.result_types, + ) + + rewriter.replace_matched_op(schedule_op) + + +@dataclass +class LayoutResolution(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: stream.ScheduleOp, rewriter: PatternRewriter): + template = get_accelerator_info(op) + + bounds = [x.value.data for x in op.bounds.data] + schedule = Schedule( + SchedulePattern(bounds, pattern.data) for pattern in op.patterns + ) + # We are now ready to convert the stream access patterns into snax stride patterns # construct the strided patterns for SNAX Streamers @@ -159,7 +183,8 @@ def generate_one_list(n: int, i: int): # TODO: what is still required is a better system for the unused operands # of snax_gemmx / other accelerators. this now fills in empty/zero patterns for the unused operands. - if acc_op.name_prop.root_reference.data == "snax_gemmx": + assert op.accelerator + if op.accelerator.data == "snax_gemmx": empty_pattern = snax_stream.StridePattern( upper_bounds=[0] * 3, temporal_strides=[0] * 3, spatial_strides=[0] ) @@ -243,7 +268,7 @@ def generate_one_list(n: int, i: int): memref.ExtractAlignedPointerAsIndexOp.get(op.inputs[-1]) ) - if accelerator_str == "snax_gemmx": + if op.accelerator.data == "snax_gemmx": # make last spatial stride patterns 2d snax_stride_patterns[-2] = snax_stream.StridePattern( upper_bounds=snax_stride_patterns[-2].upper_bounds, @@ -261,7 +286,7 @@ def generate_one_list(n: int, i: int): inputs=new_inputs, outputs=new_outputs, stride_patterns=snax_stride_patterns, - accelerator=accelerator_str, + accelerator=op.accelerator.data, body=rewriter.move_region_contents_to_new_regions(op.body), ) @@ -273,4 +298,5 @@ class ConvertStreamToSnaxStream(ModulePass): name = "convert-stream-to-snax-stream" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - PatternRewriteWalker(MemrefStreamToSnaxPattern()).rewrite_module(op) + PatternRewriteWalker(AutoflowScheduler()).rewrite_module(op) + PatternRewriteWalker(LayoutResolution()).rewrite_module(op)