diff --git a/.github/workflows/run-benchmarks.yml b/.github/workflows/run-benchmarks.yml new file mode 100644 index 00000000..47a25eb3 --- /dev/null +++ b/.github/workflows/run-benchmarks.yml @@ -0,0 +1,28 @@ +name: Run benchmarks + +on: + push: + branches: + - main + +jobs: + run-benchmarks: + runs-on: ubuntu-latest + container: + image: ghcr.io/kuleuven-micas/snax:v0.1.14 + steps: + - uses: actions/checkout@v3 + - name: Install snax compiler + run: python3 -m pip install -e . + - name: Reinstall pip modules from requirements + run: python3 -m pip install -r requirements.txt + - name: Run benchmarks + run: python3 genbenchmark.py + working-directory: benchmarks/${{ matrix.kernel }} + - uses: actions/upload-artifact@v4 + with: + name: output_report + path: benchmarks/${{ matrix.kernel }}/output_report.txt + strategy: + matrix: + kernel: [dense_matmul] diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore new file mode 100644 index 00000000..845e8ac9 --- /dev/null +++ b/benchmarks/.gitignore @@ -0,0 +1,3 @@ +results +generated* +output_report.txt diff --git a/benchmarks/dense_matmul/Makefile b/benchmarks/dense_matmul/Makefile new file mode 100644 index 00000000..859cdc87 --- /dev/null +++ b/benchmarks/dense_matmul/Makefile @@ -0,0 +1,44 @@ +# Courtesy of Federico Ficarelli + +.DEFAULT_GOAL := all + +include ../../runtime/snax-gemmx.rules +include ../../runtime/Makefile.rules + +TESTS += generated.x + +SNAXOPTFLAGS = -p convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,set-memory-space,set-memory-layout,realize-memref-casts,insert-sync-barrier,dispatch-regions{nb_cores=3},guarded-linalg-to-memref-stream,schedule-memref-linalg,stream-snaxify,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space + +GEN_DATA_OPTS += --m=${SIZE_M} +GEN_DATA_OPTS += --n=${SIZE_N} +GEN_DATA_OPTS += --k=${SIZE_K} + + +CFLAGS += -std=gnu11 +CFLAGS += -Wall -Wextra +# Needed for perfetto script +CFLAGS += -g +ifdef NO_CHECK +CFLAGS += -DNO_CHECK +endif + +data.c data.h: + $(PYTHON) gendata.py ${GEN_DATA_OPTS} + +%.x: %.o main.o data.o + $(LD) $(LDFLAGS) $^ -o $@ + +sim_%: % + rm -fr ./logs/ + $(VLTSIM) $< + +RUN = $(addprefix run_, $(TESTS)) +$(RUN): run_%: sim_% + mv logs $(subst sim_,,$<).logs + +all: $(TESTS) + +allrun: $(RUN) + +clean: + rm -fr *.ll12 *.x *.o *.logs/ logs/ data.h data.c diff --git a/benchmarks/dense_matmul/genbenchmark.py b/benchmarks/dense_matmul/genbenchmark.py new file mode 100644 index 00000000..d72e06c8 --- /dev/null +++ b/benchmarks/dense_matmul/genbenchmark.py @@ -0,0 +1,135 @@ +import itertools +import json +import pathlib +from io import StringIO + +from xdsl.builder import ImplicitBuilder +from xdsl.dialects import arith, builtin, func, linalg +from xdsl.dialects.builtin import i8, i32 +from xdsl.ir import Block, Region +from xdsl.printer import Printer + +from util.snax_benchmark import SNAXBenchmark + + +def create_matrix_multiply(k, m, n): + """ + Generate IR in the form of: + ``` + builtin.module { + func.func @streamer_matmul(%arg0 : memref<16x16xi8>, %arg1 : memref<16x16xi8, + strided<[1, 16]>>, %arg2 : memref<16x16xi32>) { + %0 = arith.constant 0 : i32 + linalg.quantized_matmul ins(%arg0, %arg1, %0, %0 : memref<16x16xi8>, + memref<16x16xi8, strided<[1, 16]>>, i32, i32) + outs(%arg2 : memref<16x16xi32>) + func.return + } + } + ``` + """ + + def get_2d_memref_type(typ, dim_one, dim_two, transpose=False): + layout = ( + builtin.StridedLayoutAttr([1, dim_one]) if transpose else builtin.NoneAttr() + ) + return builtin.MemRefType(typ, [dim_one, dim_two], layout=layout) + + input_types = [ + get_2d_memref_type(i8, k, m), + get_2d_memref_type(i8, m, n, transpose=True), + get_2d_memref_type(i32, k, n), + ] + + b = Block(arg_types=(input_types)) + + with ImplicitBuilder(b) as (arg0, arg1, arg2): + c0 = arith.Constant.from_int_and_width(0, 32) + linalg.QuantizedMatmulOp([arg0, arg1, c0.result, c0.result], [arg2]) + func.Return() + + region = Region(b) + + function = func.FuncOp.from_region("streamer_matmul", input_types, [], region) + + module = builtin.ModuleOp([function]) + + return module + + +def write_module_to_file(module, file): + output = StringIO() + printer = Printer(stream=output) + printer.print(module) + with open(file, "w") as output_file: + output_file.write(output.getvalue()) + + +def generate_tiled_benchmark(m, n, k) -> SNAXBenchmark: + module = create_matrix_multiply(k, m, n) + write_module_to_file(module, "generated.mlir") + binary = "generated.x" + bm = SNAXBenchmark( + kernel=f"tiled_matmul_generated_{k}x{n}x{m}", + binary=binary, + src_dir=str(pathlib.Path.cwd()), + export_dir=str(pathlib.Path.cwd()), + ) + return bm + + +if __name__ == "__main__": + """Runs the gendata.py script with specified arguments.""" + selected_dims = [32, 48, 64] + + sizes = list(itertools.product(selected_dims, repeat=3)) + + output_report: dict[str, dict] = {} + + for size in sizes: + k, m, n = size + folder = f"test_generated_{k}x{m}x{m}" + bm = generate_tiled_benchmark(k, m, n) + bm.clean() + bm.build( + build_opts=[ + "NO_CHECK=1", + f"SIZE_M={m}", + f"SIZE_N={n}", + f"SIZE_K={k}", + ] + ) + bm.run() + bm.trace() + bm.process_traces(folder) + bm.copy_binary(folder) + bm.copy_logs(folder) + + # add to output report + trace = bm.log_dir.joinpath(bm.input_file.format(hart="00000")) + with open(trace) as file: + data = json.load(file) + cycles = data[1]["cycles"] + ideal = round((k / 8 + 1) * (m / 8) * (n / 8)) + utilization = ideal / cycles + output_report[bm.benchmark] = { + "cycles": cycles, + "ideal": ideal, + "utilization": utilization, + } + + with open("output_report.txt", "w") as file: + file.write("benchmark\tcycles\tideal\tutilization\t\n") + avg_utilization = 0 + avg_n = 0 + for benchmark in output_report: + file.write(f"{benchmark}\t") + file.write(f"{output_report[benchmark]['cycles']}\t") + file.write(f"{output_report[benchmark]['ideal']}\t") + file.write(f"{output_report[benchmark]['utilization']}\t") + file.write("\n") + avg_utilization += output_report[benchmark]["utilization"] + avg_n += 1 + file.write("--------------------------\n") + file.write("average\t\t\t") + file.write(f"{avg_utilization/avg_n}\t\n") diff --git a/benchmarks/dense_matmul/gendata.py b/benchmarks/dense_matmul/gendata.py new file mode 100755 index 00000000..33b9d826 --- /dev/null +++ b/benchmarks/dense_matmul/gendata.py @@ -0,0 +1,67 @@ +# simple script to generate inputs and expected outputs for simple_matmult + +import argparse + +import numpy as np + +from util.gendata import create_data, create_header + + +def create_test_data(n, m, k): + print(f"Creating test data with n={n}, m={m}, k={k}") + # Reset random seed for reproducible behavior + + np.random.seed(0) + + A_size = [n, k] + B_size = [k, m] + + # C = A.B + low_bound = -128 + high_bound = 127 + + A = np.random.randint(low_bound, high_bound, size=A_size, dtype=np.dtype("int8")) + B = np.random.randint(low_bound, high_bound, size=B_size, dtype=np.dtype("int8")) + + # Make sure the product is possible! + assert A.shape[1] == B.shape[0] + C_golden = np.matmul(A.astype(np.dtype("int32")), B.astype(np.dtype("int32"))) + C = np.zeros(C_golden.shape, np.dtype("int32")) + + # Perform layout transformations before writing to memory + + # only thing necessary: transform B from row-major to column-major + B_new_layout = np.transpose(B) + + # C are just all zeros, so layout not important + sizes = { + "N_size": A.shape[0], + "K_size": A.shape[1], + "M_size": B.shape[1], + } + variables = { + "A": A, + "B": B_new_layout, + "C_golden": C_golden, + "C": C, + } + + create_header("data.h", sizes, variables) + create_data("data.c", variables) + + +if __name__ == "__main__": + # Set up the argument parser + parser = argparse.ArgumentParser( + description="Generate test data with specified parameters." + ) + # Adding arguments + parser.add_argument("--n", type=int, default=16, help="Value for n (default: 16)") + parser.add_argument("--m", type=int, default=16, help="Value for m (default: 16)") + parser.add_argument("--k", type=int, default=16, help="Value for k (default: 16)") + + # Parse the arguments + args = parser.parse_args() + + # Call the function with parsed arguments + create_test_data(n=args.n, m=args.m, k=args.k) diff --git a/benchmarks/dense_matmul/main.c b/benchmarks/dense_matmul/main.c new file mode 100644 index 00000000..e6005c3d --- /dev/null +++ b/benchmarks/dense_matmul/main.c @@ -0,0 +1,85 @@ +#include "stdint.h" + +#include "data.h" +#include "memref.h" +#include "snax_rt.h" + +/* + * These libraries are included from github.com/KULeuven-MICAS/snitch_cluster + * Interested users, might want to look at: + * + * /sw/snRuntime/api + * /target/snitch_cluster/sw/runtime/rtl/src + * /target/snitch_cluster/sw/runtime/common + * */ +#include + +// Kernel provided via external definition +void _mlir_ciface_streamer_matmul(TwoDMemrefI8_t *a, TwoDMemrefI8_t *b, + TwoDMemrefI32_t *c); + +int main() { + { + + // Create memref objects for data stored in L3 + TwoDMemrefI8_t memrefA; + memrefA.data = &A; + memrefA.aligned_data = memrefA.data; + memrefA.shape[0] = N_size; + memrefA.shape[1] = K_size; + memrefA.stride[0] = K_size; + memrefA.stride[1] = 1; + memrefA.offset = 0; + + TwoDMemrefI8_t memrefB; + memrefB.data = &B; + memrefB.aligned_data = memrefB.data; + memrefB.shape[0] = K_size; + memrefB.shape[1] = M_size; + memrefB.stride[0] = 1; + memrefB.stride[1] = K_size; + memrefB.offset = 0; + + TwoDMemrefI32_t memrefC; + memrefC.data = &C; + memrefC.aligned_data = memrefC.data; + memrefC.shape[0] = N_size; + memrefC.shape[1] = M_size; + memrefC.stride[0] = M_size; + memrefC.stride[1] = 1; + memrefC.offset = 0; + + _mlir_ciface_streamer_matmul(&memrefA, &memrefB, &memrefC); + + snrt_cluster_hw_barrier(); + + // Correctness check - + // from this point on only core 0 is required to be alive. + int thiscore = snrt_cluster_core_idx(); + if (thiscore != 0) + return 0; + +#ifdef NO_CHECK + // No correctness check = + // Always finish as if nothing happened + return 0; +#endif + int nerr = 0; + + for (int i = 0; i < M_size * N_size; i++) { + { + int32_t error = memrefC.aligned_data[i] - C_golden[i]; + // printf("%d) %d -> %d\n", i, (int32_t)memrefC.aligned_data[i], + // (int32_t)C_golden[i]); + if (error != 0) + nerr += 1; + } + } + + // insert mcycle to show fault in trace + if (nerr != 0) + snrt_mcycle(); + + return nerr; + } +}