Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run simulations in parallel #347

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
8 changes: 5 additions & 3 deletions kernels/alloc/Snakefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util.snake.configs import get_snax_mac_config
from util.snake.paths import get_traces

config = get_snax_mac_config()

Expand Down Expand Up @@ -28,11 +29,12 @@ module default_rules:
config


files = ["func"]


rule all:
input:
"func.x",
shell:
"{config[vltsim]} {input[0]}"
*get_traces(files, config["num_chips"], config["num_harts"], "dasm"),


use rule * from default_rules exclude compile_simple_main as default_*
Expand Down
8 changes: 5 additions & 3 deletions kernels/gemm/Snakefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util.snake.configs import get_snax_gemmx_config
from util.snake.paths import get_traces

config = get_snax_gemmx_config()

Expand Down Expand Up @@ -49,11 +50,12 @@ rule compile_main:
"{config[cc]} {config[cflags]} -c {input[0]}"


files = ["gemm"]


rule all:
input:
"gemm.x",
shell:
"{config[vltsim]} {input[0]}"
*get_traces(files, config["num_chips"], config["num_harts"], "dasm"),


from gendata import create_data_files
Expand Down
8 changes: 5 additions & 3 deletions kernels/rescale/Snakefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util.snake.configs import get_snax_gemmx_config
from util.snake.paths import get_traces

config = get_snax_gemmx_config()
config["snaxoptflags"] = ",".join(
Expand Down Expand Up @@ -44,11 +45,12 @@ rule link_snax_binary:
"{config[ld]} {config[ldflags]} {input} -o {output}"


files = ["rescale"]


rule all:
input:
"rescale.x",
shell:
"{config[vltsim]} {input[0]}"
*get_traces(files, config["num_chips"], config["num_harts"], "dasm"),


from gendata import create_data_files
Expand Down
8 changes: 5 additions & 3 deletions kernels/simple_copy/Snakefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util.snake.configs import get_snax_mac_config
from util.snake.paths import get_traces

config = get_snax_mac_config()
config["snaxoptflags"] = ",".join(
Expand Down Expand Up @@ -29,12 +30,13 @@ module default_rules:
use rule * from default_rules as default_*


files = ["simple_copy"]


# Rules
rule all:
input:
"simple_copy.x",
shell:
"{config[vltsim]} {input[0]}"
*get_traces(files, config["num_chips"], config["num_harts"], "dasm"),


rule compile_main:
Expand Down
11 changes: 5 additions & 6 deletions kernels/streamer_alu/Snakefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util.snake.configs import get_snax_alu_config
from util.snake.paths import get_traces, get_trace_ext

config = get_snax_alu_config()
config["snaxoptflags"] = ",".join(
Expand Down Expand Up @@ -39,14 +40,12 @@ module default_rules:
use rule * from default_rules as default_*


# Rules
files = ["streamer_add", "streamer_add_stream"]


rule all:
input:
"streamer_add.x",
"streamer_add_stream.x",
run:
for item in input:
shell("{config[vltsim]} {item}")
*get_traces(files, config["num_chips"], config["num_harts"], "dasm"),


from gendata import create_data_files
Expand Down
10 changes: 5 additions & 5 deletions kernels/streamer_matmul/Snakefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util.snake.configs import get_snax_gemmx_config
from util.snake.paths import get_traces

config = get_snax_gemmx_config()
config["snaxoptflags"] = ",".join(
Expand Down Expand Up @@ -42,14 +43,13 @@ module default_rules:
use rule * from default_rules as default_*


files = ["quantized_matmul", "tiled_quantized_matmul"]


# Rules
rule all:
input:
"quantized_matmul.x",
"tiled_quantized_matmul.x",
run:
for item in input:
shell("{config[vltsim]} {item}")
*get_traces(files, config["num_chips"], config["num_harts"], "dasm"),


rule generate_quantized_matmul:
Expand Down
21 changes: 12 additions & 9 deletions kernels/tiled_add/Snakefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from util.snake.configs import get_snax_alu_config
from util.snake.paths import get_traces

config = get_snax_alu_config()

Expand Down Expand Up @@ -122,16 +123,18 @@ COMPILER_OPTS = ["accfgboth", "noaccfgopt"]

rule all:
input:
expand(
"{file}_{array_size}_{tile_size}_{compiler_opt}.x",
file=FILES,
array_size=ARRAY_SIZES,
tile_size=TILE_SIZES,
compiler_opt=COMPILER_OPTS,
*get_traces(
expand(
"{file}_{array_size}_{tile_size}_{compiler_opt}",
file=FILES,
array_size=ARRAY_SIZES,
tile_size=TILE_SIZES,
compiler_opt=COMPILER_OPTS,
),
config["num_chips"],
config["num_harts"],
"dasm",
),
run:
for item in input:
shell("{config[vltsim]} {item}")


from gendata import generate_data
Expand Down
6 changes: 6 additions & 0 deletions util/snake/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def get_snax_mac_config():
config = {}
config.update(get_default_paths())
config.update(get_default_flags(snitch_sw_path))
config["num_chips"] = 1
config["num_harts"] = 2
config["vltsim"] = f"{snax_utils_path}/snax-mac-rtl/bin/snitch_cluster.vlt"
config["cflags"].append(
f"-I{snitch_sw_path}/target/snitch_cluster/sw/snax/mac/include"
Expand All @@ -31,6 +33,8 @@ def get_snax_gemmx_config():
config = {}
config.update(get_default_paths())
config.update(get_default_flags(snitch_sw_path))
config["num_chips"] = 1
config["num_harts"] = 2
config["vltsim"] = (
snax_utils_path
+ "/snax-kul-cluster-mixed-narrow-wide-rtl/bin/snitch_cluster.vlt"
Expand All @@ -45,6 +49,8 @@ def get_snax_alu_config():
config = {}
config.update(get_default_paths())
config.update(get_default_flags(snitch_sw_path))
config["num_chips"] = 1
config["num_harts"] = 2
config["vltsim"] = snax_utils_path + "/snax-alu-rtl/bin/snitch_cluster.vlt"
return config

Expand Down
20 changes: 20 additions & 0 deletions util/snake/default_rules.smk
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from util.snake.paths import get_trace_ext


rule snax_opt_mlir:
"""
Apply various transformations snax-opt on mlir files.
Expand Down Expand Up @@ -54,3 +57,20 @@ rule clean:
"""
shell:
"rm -rf *.ll12 *.x *.o *.logs/ logs/ data* *.dasm"


rule simulate:
input:
"{file}.x",
output:
temp(
multiext(
*get_trace_ext(
"{file}", config["num_chips"], config["num_harts"], "dasm"
)
)
JosseVanDelm marked this conversation as resolved.
Show resolved Hide resolved
),
log:
"{file}.vltlog",
shell:
"{config[vltsim]} --prefix-trace={wildcards.file}_ {wildcards.file}.x 2>&1 | tee {log}"
29 changes: 29 additions & 0 deletions util/snake/paths.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections.abc import Generator, Sequence


def get_default_paths():
return {
"cc": "clang",
Expand All @@ -6,3 +9,29 @@ def get_default_paths():
"mlir-translate": "mlir-translate",
"snax-opt": "snax-opt",
}


def get_traces(
JosseVanDelm marked this conversation as resolved.
Show resolved Hide resolved
prefixes: Sequence[str],
num_chips: int = 1,
num_harts: int = 2,
extension: str = "dasm",
) -> Generator[str, None, None]:
JosseVanDelm marked this conversation as resolved.
Show resolved Hide resolved
for prefix in prefixes:
for chip_id in range(num_chips):
for hart_id in range(num_harts):
yield f"{prefix}_trace_chip_{chip_id:02d}_hart_{hart_id:05d}.{extension}"


def get_trace_ext(
prefix: str, num_chips: int = 1, num_harts: int = 2, extension: str = "dasm"
) -> Sequence[str]:
extensions = [
trace[: 26 + len(extension)]
for trace in get_traces([""], num_chips, num_harts, extension)
]
return [prefix, *extensions]


if __name__ == "__main__":
print(get_trace_ext("file", 1, 2, "dasm"))
Loading