Skip to content

Commit

Permalink
dynamic magic (#64)
Browse files Browse the repository at this point in the history
* add dialect implementation

* re-enable python tests

* add parser

* include element size in transformation

* generate bound ops

* successful textual compilation

* fix the kernel too

* automatic test generation
  • Loading branch information
jorendumoulin committed Jan 23, 2024
1 parent 3c5b8f4 commit 9e87be3
Show file tree
Hide file tree
Showing 12 changed files with 632 additions and 92 deletions.
22 changes: 0 additions & 22 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
name: python-tests
name: python-tests

on:
push:
Expand All @@ -13,20 +12,13 @@ on:
required: true
default: 'main'

permissions:
contents: read
permissions:
contents: read

jobs:
python-tests:
jobs:
python-tests:

runs-on: ubuntu-latest
container:
image: ghcr.io/kuleuven-micas/snax-mlir:${{ github.event.inputs.container-tag || 'main' }}
runs-on: ubuntu-latest
container:
image: ghcr.io/kuleuven-micas/snax-mlir:${{ github.event.inputs.container-tag || 'main' }}

Expand All @@ -44,18 +36,4 @@ jobs:
shell: bash
run: |
/opt/python3.11/bin/python3 -m pytest .
steps:
- uses: actions/checkout@v3
- name: Install pytest
shell: bash
run: |
/opt/python3.11/bin/python3 -m pip install pytest
- name: Reinstall pip modules from requirements
shell: bash
run: |
/opt/python3.11/bin/python3 -m pip install -r requirements.txt
- name: Test with pytest
shell: bash
run: |
/opt/python3.11/bin/python3 -m pytest .
10 changes: 5 additions & 5 deletions compiler/ir/tsl/tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def largest_common_contiguous_block(
"""
result: list[Stride] = []

# does not work on illegal workloads
if self.self_overlaps():
return result
if other.self_overlaps():
return result
# # does not work on illegal workloads
# if self.self_overlaps():
# return result
# if other.self_overlaps():
# return result

# find largest contiguous block
current_stride = starting_stride
Expand Down
1 change: 0 additions & 1 deletion compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from compiler.dialects.snax import Snax
from compiler.dialects.tsl import TSL
from compiler.dialects.tsl import TSL
from compiler.transforms.clear_memory_space import ClearMemorySpace
from compiler.transforms.dispatch_elementwise_mult import DispatchElementWiseMult
from compiler.transforms.dispatch_regions import DispatchRegions
Expand Down
246 changes: 210 additions & 36 deletions compiler/transforms/snax_copy_to_dma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import reduce

from xdsl.dialects import arith, builtin, func, scf
from xdsl.dialects.arith import Addi, Constant, Muli
from xdsl.dialects.arith import Addi, Constant, DivUI, Muli
from xdsl.dialects.builtin import IndexType, IntegerType, NoneAttr
from xdsl.dialects.memref import CopyOp, Dim, ExtractAlignedPointerAsIndexOp, MemRefType
from xdsl.ir import Block, MLContext, Region
Expand Down Expand Up @@ -57,14 +59,16 @@ def match_and_rewrite(self, op: CopyOp, rewriter: PatternRewriter):
ops_to_insert.append(total_size_op)

# step 2: calculate element size to get total size in bytes
# element_type = op.source.type.get_element_type()
# element_size = IntegerType.get_bit_width(element_type) // 8
# total_size_op = Muli(
# total_size_op.result,
# Constant.from_int_and_width(element_size, IndexType()).result,
# IndexType(),
# )
# ops_to_insert.append(total_size_op)
element_type: IntegerType = op.source.type.get_element_type()
element_size = element_type.width.data // 8
element_size_op = Constant.from_int_and_width(element_size, IndexType())
total_size_op = Muli(
total_size_op.result,
element_size_op.result,
IndexType(),
)
ops_to_insert.append(element_size_op)
ops_to_insert.append(total_size_op)

# step 3: extract source and destination pointers
source_ptr_op = ExtractAlignedPointerAsIndexOp.get(op.source)
Expand Down Expand Up @@ -124,21 +128,192 @@ def match_and_rewrite(self, op: CopyOp, rewriter: PatternRewriter):
tsl_source = op.source.type.layout.data
tsl_dest = op.destination.type.layout.data

lcb = tsl_source.largest_common_contiguous_block(tsl_dest)
# lcb is completely static
lcb = tsl_source.largest_common_contiguous_block(
tsl_dest, op.source.type.element_type.width.data // 8
)

# step 3: generate a sorted list of remaing strides;
# all strides excluded from the contiguous block must be generated
# using for loops / multi-dimensional dma transfers
remaining_strides = [stride for stride in tsl_source if stride[2] not in lcb]
# sort the remaining strides by their bound, largest first
# # sort the remaining strides by their bound, largest first, dynamic last

# map the list to
# [
# {
# "dim": dim,
# "depth": depth,
# "stride_src": stride
# "stride_dst": stride
# "bound_op": None
# "step_src_op": None
# "step_dst_op": None
# }
# ]
remaining_strides = {}

# first, get the all the bound ops for strides not in lcb

for dim in range(tsl_source.dimension()):
for depth in range(tsl_source.tstrides[dim].depth()):
stride = tsl_source.get_stride(dim, depth)
if stride in lcb:
continue
if stride.bound is None:
dim_index_op = Constant.from_int_and_width(dim, IndexType())
dim_op = dim_op = Dim.from_source_and_index(
op.source, dim_index_op.result
)
# to calculate the bound, we must divide the size of the matrix by
# the product of all lower tile sizes, which must be known
tilebounds = [
stride.bound
for _, stride in tsl_source.tstrides[dim]
if stride.bound
]
product_tilebounds = reduce(lambda x, y: x * y, tilebounds, 1)
div_op = Constant.from_int_and_width(
product_tilebounds, IndexType()
)
bound_op = DivUI(dim_op.result, div_op.result, IndexType())
ops_to_insert.extend([dim_index_op, dim_op, div_op, bound_op])
else:
bound_op = Constant.from_int_and_width(stride.bound, IndexType())
ops_to_insert.append(bound_op)
remaining_strides[(dim, depth)] = {
"dim": dim,
"depth": depth,
"stride_src": tsl_source.get_stride(dim, depth),
"stride_dst": tsl_dest.get_stride(dim, depth),
"bound_op": bound_op,
}

# second, get the step ops for strides not in lcb
# the strategy here is as follows

# a first step is to find the largest current stride
# as a starting point for further dense calculations

max_key = None
max_value = 0
for key, value in remaining_strides.items():
if value["stride_src"].step and value["stride_src"].step > max_value:
max_key = key
max_value = value["stride_src"].step

max_stride_op = Constant.from_int_and_width(
remaining_strides[max_key]["stride_src"].step, IndexType()
)
starting_bound = Muli(
remaining_strides[max_key]["bound_op"],
max_stride_op,
IndexType(),
)
ops_to_insert.append(max_stride_op)

for dim in range(tsl_source.dimension()):
for depth in range(tsl_source.tstrides[dim].depth()):
if (dim, depth) not in remaining_strides.keys():
continue
stride = remaining_strides[(dim, depth)]
if stride["stride_src"].step is not None:
# create const op
stride_op = Constant.from_int_and_width(
stride["stride_src"].step, IndexType()
)
ops_to_insert.append(stride_op)
stride["step_src_op"] = stride_op
else:
# assign starting bound and increment value
stride_op = starting_bound
# const_bound_op = Constant(stride["bound_op"], IndexType())
starting_bound = Muli(
starting_bound.result,
stride["bound_op"],
IndexType(),
)
ops_to_insert.append(stride_op)
stride["step_src_op"] = stride_op

# now do the same for the destination strides

max_key = None
max_value = 0
for key, value in remaining_strides.items():
if value["stride_dst"].step and value["stride_dst"].step > max_value:
max_key = key
max_value = value["stride_dst"].step

max_stride_op = Constant.from_int_and_width(
remaining_strides[max_key]["stride_dst"].step, IndexType()
)
starting_bound = Muli(
remaining_strides[max_key]["bound_op"],
max_stride_op,
IndexType(),
)
ops_to_insert.append(max_stride_op)

for dim in range(tsl_dest.dimension()):
for depth in range(tsl_dest.tstrides[dim].depth()):
if (dim, depth) not in remaining_strides.keys():
continue
stride = remaining_strides[(dim, depth)]
if stride["stride_dst"].step is not None:
# create const op
stride_op = Constant.from_int_and_width(
stride["stride_dst"].step, IndexType()
)
ops_to_insert.append(stride_op)
stride["step_dst_op"] = stride_op
else:
# assign starting bound and increment value
stride_op = starting_bound
starting_bound = Muli(
starting_bound.result,
stride["bound_op"],
IndexType(),
)
ops_to_insert.append(stride_op)
stride["step_dst_op"] = stride_op

pass
# remaining_strides = [
# {
# "dim": stride[0],
# "depth": stride[1],
# "stride": stride[2],
# "bound_op": None,
# "step_src_op": None,
# "step_dst_op": None,
# }
# for stride in remaining_strides
# ]
# remaining_strides = [

# stride + (tsl_dest.get_stride(stride[0], stride[1]), )
# for stride in remaining_strides
# ]

# # order values of remaining stride by their bound
# remaining_strides = sorted(
# remaining_strides,
# key=lambda stride: stride["bound_op"].result.value
# if stride["bound_op"]
# else 0,
# reverse=True,
# )

remaining_strides = sorted(
remaining_strides, key=lambda stride: stride[2].bound, reverse=True
remaining_strides.values(),
key=lambda x: x["stride_src"].bound if x["stride_src"].bound else 0,
reverse=True,
)
# map the list to [(src_stride, dest_stride), ...]
remaining_strides = [
(stride[2], tsl_dest.get_stride(stride[0], stride[1]))
for stride in remaining_strides
]

# generate ops for stride bounds
# for dim in range(tsl_source.dimension()):
# for depth in range(tsl_source):

# step 4: generate variables for 2D transfer

Expand All @@ -148,13 +323,13 @@ def match_and_rewrite(self, op: CopyOp, rewriter: PatternRewriter):
else:
dma_loop = remaining_strides.pop(0)

dma_size = Constant.from_int_and_width(lcb[-1].bound, IndexType())
dma_stride_src = Constant.from_int_and_width(dma_loop[0].stride, IndexType())
dma_stride_dst = Constant.from_int_and_width(dma_loop[1].stride, IndexType())
dma_stride_bound = Constant.from_int_and_width(dma_loop[0].bound, IndexType())
ops_to_insert.extend(
[dma_size, dma_stride_src, dma_stride_dst, dma_stride_bound]
dma_size = Constant.from_int_and_width(
lcb[-1].bound * lcb[-1].step, IndexType()
)
dma_stride_src = dma_loop["step_src_op"]
dma_stride_dst = dma_loop["step_dst_op"]
dma_stride_bound = dma_loop["bound_op"]
ops_to_insert.extend([dma_size])

# step 5: if there are no remaining strides, insert simple 2D dma transfer
if len(remaining_strides) == 0:
Expand Down Expand Up @@ -194,12 +369,9 @@ def match_and_rewrite(self, op: CopyOp, rewriter: PatternRewriter):
# step 6.1: create the list of loop bounds
lower = arith.Constant.from_int_and_width(0, builtin.IndexType())
step = arith.Constant.from_int_and_width(1, builtin.IndexType())
upper = [
arith.Constant.from_int_and_width(stride[0].bound, builtin.IndexType())
for stride in remaining_strides
]
upper = [stride["bound_op"] for stride in remaining_strides]

ops_to_insert.extend([lower, step, *upper])
ops_to_insert.extend([lower, step])

# step 6.2: create nested for loop (looping from inner to outer)
# most inner for loop has empty region
Expand All @@ -224,20 +396,22 @@ def match_and_rewrite(self, op: CopyOp, rewriter: PatternRewriter):
ops_to_insert_for_loop = []

# source indexing operations:
stride_src = Constant.from_int_and_width(
remaining_strides[i][0].stride, IndexType()
)
# stride_src = Constant.from_int_and_width(
# remaining_strides[i][0].step, IndexType()
# )
stride_src = remaining_strides[i]["step_src_op"]
increment_src = Muli(for_loop.body.block.args[0], stride_src, IndexType())
pointer_src = Addi(pointer_src, increment_src, IndexType())
ops_to_insert_for_loop.extend([stride_src, increment_src, pointer_src])
ops_to_insert_for_loop.extend([increment_src, pointer_src])

# destination indexing operations:
stride_dst = Constant.from_int_and_width(
remaining_strides[i][1].stride, IndexType()
)
# stride_dst = Constant.from_int_and_width(
# remaining_strides[i][1].step, IndexType()
# )
stride_dst = remaining_strides[i]["step_dst_op"]
increment_dst = Muli(for_loop.body.block.args[0], stride_dst, IndexType())
pointer_dst = Addi(pointer_dst, increment_dst, IndexType())
ops_to_insert_for_loop.extend([stride_dst, increment_dst, pointer_dst])
ops_to_insert_for_loop.extend([increment_dst, pointer_dst])

# insert the ops in the for loop body
for_loop.body.block.insert_ops_before(
Expand Down
Loading

0 comments on commit 9e87be3

Please sign in to comment.