From de5c0102684711226a27762d137f5224455524b8 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Mon, 23 Jan 2023 17:11:31 +0400 Subject: [PATCH 01/27] data loading alpha version --- .gitignore | 2 + evaluate.py | 5 +- rl_interface/action.py | 25 +- rl_interface/environment.py | 42 +-- tiramisu_programs/cpp_file.py | 15 +- tiramisu_programs/optimization.py | 7 + tiramisu_programs/schedule.py | 4 +- tiramisu_programs/schedule_controller.py | 116 ++++++-- tiramisu_programs/schedule_utils.py | 334 ++++++++++++++++------- tiramisu_programs/tiramisu_program.py | 72 ++--- train_ppo.py | 8 +- utils/global_ray_variables.py | 30 +- utils/rl_autoscheduler_config.py | 5 + 13 files changed, 461 insertions(+), 204 deletions(-) diff --git a/.gitignore b/.gitignore index 990428e..48ff7ab 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ scripts/env.sh Dataset* .vscode .idea +dataset +cpps \ No newline at end of file diff --git a/evaluate.py b/evaluate.py index 76db7c0..4972df4 100644 --- a/evaluate.py +++ b/evaluate.py @@ -27,7 +27,8 @@ def get_arguments(): # @hydra.main(config_path="config", config_name="config") def main(config: RLAutoSchedulerConfig, checkpoint=None): - if checkpoint is None: return + if checkpoint is None: + return configure_env_variables(config) best_checkpoint = os.path.join(config.ray.base_path, checkpoint) with ray.init(num_cpus=config.ray.ray_num_cpus): @@ -87,7 +88,7 @@ def main(config: RLAutoSchedulerConfig, checkpoint=None): except: print("error", action, observation, reward, done) continue - result["schedule_str"] = env.schedule_object.schedule_str + result["sched_str"] = env.schedule_object.sched_str result["speedup"] = env.schedule_controller.speedup results.append(result) with open("results.json", "w+") as file: diff --git a/rl_interface/action.py b/rl_interface/action.py index 302cd50..1dce1b4 100644 --- a/rl_interface/action.py +++ b/rl_interface/action.py @@ -1,6 +1,3 @@ -import random - - class Action: """ " Action class to store and 
standardize the action for the environment. @@ -212,8 +209,8 @@ def parameter(self, comp=None, prog=None): first_it = 6 second_it = 7 - first_fact = 32 #random.choice([32, 64, 128]) - second_fact = 32 #random.choice([32, 64, 128]) + first_fact = 32 # random.choice([32, 64, 128]) + second_fact = 32 # random.choice([32, 64, 128]) # #print("after choosing first and second params and factors") # calculate the loop extent to see if we should create new iterators or not @@ -222,7 +219,7 @@ def parameter(self, comp=None, prog=None): self.it_dict[first_comp][first_it]["upper_bound"] - self.it_dict[first_comp][first_it]["lower_bound"]) # #print("\n first loop extent is ", loop_extent_1) - #print("first factor is", first_fact) + # print("first factor is", first_fact) if loop_extent_1 == first_fact: tiling_flag_1 = False print("Tiling flag 1 false, loopextent == factor") @@ -235,10 +232,10 @@ def parameter(self, comp=None, prog=None): self.it_dict[first_comp][second_it]["upper_bound"] - self.it_dict[first_comp][second_it]["lower_bound"]) # print("\n second loop extent is ", loop_extent_2) - #print("second factor is", second_fact) + # print("second factor is", second_fact) if loop_extent_2 == second_fact: tiling_flag_2 = False - #print("tiling flag 2 false, loopextent == factor") + # print("tiling flag 2 false, loopextent == factor") elif loop_extent_2 < second_fact: print("exceeeption, loop extent 2 smaller than factor") from tiramisu_programs.schedule import LoopExtentException @@ -283,15 +280,15 @@ def parameter(self, comp=None, prog=None): second_it = 6 third_it = 7 - first_fact = 32 #random.choice([32, 64, 128]) - second_fact = 32 #random.choice([32, 64, 128]) - third_fact = 32 #random.choice([32, 64, 128]) + first_fact = 32 # random.choice([32, 64, 128]) + second_fact = 32 # random.choice([32, 64, 128]) + third_fact = 32 # random.choice([32, 64, 128]) # calculate the loop extent to see if we should create new iterators or not loop_extent_1 = abs( 
self.it_dict[first_comp][first_it]["upper_bound"] - self.it_dict[first_comp][first_it]["lower_bound"]) # #print("\n first loop extent is ", loop_extent_1) - #print("first factor is", first_fact) + # print("first factor is", first_fact) if loop_extent_1 == first_fact: tiling_flag_1 = False print("tiling flag 1 false, loopextent == factor") @@ -304,7 +301,7 @@ def parameter(self, comp=None, prog=None): self.it_dict[first_comp][second_it]["upper_bound"] - self.it_dict[first_comp][second_it]["lower_bound"]) # print("\n second loop extent is ", loop_extent_2) - #print("second factor is", second_fact) + # print("second factor is", second_fact) if loop_extent_2 == second_fact: tiling_flag_2 = False print("tiling flag 2 false, loopextent == factor") @@ -317,7 +314,7 @@ def parameter(self, comp=None, prog=None): self.it_dict[first_comp][third_it]["upper_bound"] - self.it_dict[first_comp][third_it]["lower_bound"]) # print("\n third loop extent is ", loop_extent_3) - #print("third factor is", third_fact) + # print("third factor is", third_fact) if loop_extent_3 == third_fact: tiling_flag_3 = False print("tiling flag 3 false, loopextent == factor") diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 071cf70..3ed95b8 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -1,5 +1,6 @@ # np.set_printoptions(threshold=sys.maxsize) import copy +import logging import random import sys import time @@ -45,11 +46,12 @@ def __init__(self, config, shared_variable_actor): self.previous_cpp_file = None self.shared_variable_actor = shared_variable_actor - print("Loading data from {} \n".format( - config.environment.dataset_path)) # FIX that here + if config.environment.use_dataset: + self.dataset_path = config.environment.json_dataset['cpp_root'] self.id = ray.get(self.shared_variable_actor.increment.remote()) + logging.info("worker getting its list of programs") # List of function names self.progs_list = ray.get( 
self.shared_variable_actor.get_progs_list.remote(self.id)) @@ -57,9 +59,9 @@ def __init__(self, config, shared_variable_actor): # Dict of programs with their annotations, schedules, exectution times and traces self.progs_dict = ray.get( self.shared_variable_actor.get_progs_dict.remote()) - print("Loaded the dataset!") + logging.info("Loaded the dataset!") - # Dict of function with a Dict containing schedules in STR format and their execution time + # Dict of function with a Dict containing schedules in STR format and their execution time TODO is this used? self.scheds = tiramisu_programs.schedule_utils.ScheduleUtils.get_schedules_str( list(self.progs_dict.keys()), self.progs_dict) # to use it to get the execution time @@ -107,22 +109,23 @@ def reset(self, file=None): tiramisu_programs.cpp_file.CPP_File.clean_cpp_file( self.dataset_path, self.previous_cpp_file) # Choose a random program (function) - random_prog_index = random.randint(0, len(self.progs_list) - 1) + function_name = random.choice(self.progs_list) # Copy the function's files to the dataset copy created file = tiramisu_programs.cpp_file.CPP_File.get_cpp_file( - self.dataset_path, self.progs_list[random_prog_index]) + self.dataset_path, function_name) # Set up the function files to be deleted on the next iteration - self.previous_cpp_file = self.progs_list[random_prog_index] + self.previous_cpp_file = function_name # Load the tiramisu program from the file self.prog = tiramisu_programs.tiramisu_program.TiramisuProgram( - self.config, file) + self.config, file, progs_dict=self.progs_dict) print(f"Trying with program {self.prog.name}") - self.schedule_object = tiramisu_programs.schedule.Schedule(self.prog) + self.schedule_object = tiramisu_programs.schedule.Schedule( + self.prog) self.schedule_controller = tiramisu_programs.schedule_controller.ScheduleController( schedule=self.schedule_object, @@ -131,7 +134,8 @@ def reset(self, file=None): config=self.config) # Load the legality check list. 
Starts empty - lc_data = ray.get(self.shared_variable_actor.get_lc_data.remote()) + lc_data = ray.get( + self.shared_variable_actor.get_lc_data.remote()) self.schedule_controller.load_legality_data(lc_data) @@ -152,10 +156,13 @@ def reset(self, file=None): self.prog.initial_execution_time = self.progs_dict[ self.prog.name]["initial_execution_time"] else: - self.prog.initial_execution_time = 1.0 - self.progs_dict[self.prog.name] = {} - self.progs_dict[self.prog.name]["initial_execution_time"] = self.prog.initial_execution_time - self.progs_dict[self.prog.name]["program_annotation"] = self.schedule_object.annotations + if not self.config.environment.use_dataset: + self.prog.initial_execution_time = 1.0 + self.progs_dict[self.prog.name] = {} + self.progs_dict[self.prog.name]["initial_execution_time"] = self.prog.initial_execution_time + + if not self.config.environment.use_dataset: + self.progs_dict[self.prog.name]["program_annotation"] = self.schedule_object.annotations except: print("RESET_ERROR_STDERR", @@ -177,7 +184,7 @@ def step(self, raw_action): """ action_name = rl_interface.Action.ACTIONS_ARRAY[raw_action] print("\n ----> {} [ {} ] \n".format( - action_name, self.schedule_object.schedule_str)) + action_name, self.schedule_object.sched_str)) info = {} applied_exception = False reward = 0.0 @@ -216,6 +223,7 @@ def step(self, raw_action): "depth": self.depth, "error": "ended with error in the step function", } + self.obs = copy.deepcopy(self.schedule_object.get_representation()) if (self.schedule_controller.depth == self.schedule_object.MAX_DEPTH) or (self.steps >= 20): @@ -230,12 +238,12 @@ def step(self, raw_action): self.schedule_controller.get_legality_data())) if "schedule" in self.progs_dict[self.prog.name]: self.schedule_object.schedule_dict["speedup"] = speedup - self.schedule_object.schedule_dict["schedule_str"] = self.schedule_object.schedule_str + self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str 
self.progs_dict[self.prog.name]["schedules_list"].append( self.schedule_object.schedule_dict) else: self.schedule_object.schedule_dict["speedup"] = speedup - self.schedule_object.schedule_dict["schedule_str"] = self.schedule_object.schedule_str + self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str self.progs_dict[self.prog.name]["schedules_list"] = [ self.schedule_object.schedule_dict] reward_object = rl_interface.Reward(speedup) diff --git a/tiramisu_programs/cpp_file.py b/tiramisu_programs/cpp_file.py index 57f8a19..0167208 100644 --- a/tiramisu_programs/cpp_file.py +++ b/tiramisu_programs/cpp_file.py @@ -8,6 +8,7 @@ from datetime import datetime import re import torch +import ray from tiramisu_programs.schedule_utils import TimeOutException @@ -162,8 +163,16 @@ def get_cpp_file(cls, Dataset_path, func_name): os.system("rm -r {}".format(target_path)) # print("directory removed") + with open(original_path, 'r') as f: + original_str = f.read() + + original_str = original_str.replace( + f'#include "{func_name}_wrapper.h"', '') + os.mkdir(target_path) - os.system("cp -r {} {}".format(original_path, target_path)) + with open(f"{target_path}/{file_name}", 'w') as f: + f.write(original_str) + # os.system("cp -r {} {}".format(original_path, target_path)) return target_path + "/" + file_name @classmethod @@ -177,9 +186,9 @@ def clean_cpp_file(cls, dataset_path, func_name): Returns: str: The new copied function path. 
""" - target_path = f"{dataset_path}{func_name}" + target_path = "{}/Dataset_copies/{}".format(".", func_name) - if os.path.isdir(dataset_path) and os.path.isdir(target_path): + if os.path.isdir("./Dataset_copies") and os.path.isdir(target_path): os.system("rm -r {}".format(target_path)) return True else: diff --git a/tiramisu_programs/optimization.py b/tiramisu_programs/optimization.py index e5406e4..2163b55 100644 --- a/tiramisu_programs/optimization.py +++ b/tiramisu_programs/optimization.py @@ -1,6 +1,7 @@ class OptimizationCommand: """Represents a Tirtamisu transformation and maps to Tiramisu code. """ + def __init__(self, optim_type, params_list, comps): assert optim_type in [ "Interchange", @@ -16,6 +17,12 @@ def __init__(self, optim_type, params_list, comps): self.comps = comps self.tiramisu_optim_str = self.get_tiramisu_optim_str() + def __str__(self) -> str: + return f"{self.type} of {self.params_list}" + + def __repr__(self): + return f'OptimizationCommand(type={self.type}, params_list={self.params_list}, comps={self.comps})' + def get_tiramisu_optim_str(self): """Convert the optimization command into Tiramisu code. 
diff --git a/tiramisu_programs/schedule.py b/tiramisu_programs/schedule.py index 02f736c..0503459 100644 --- a/tiramisu_programs/schedule.py +++ b/tiramisu_programs/schedule.py @@ -15,8 +15,8 @@ class Schedule: MAX_COMPS = 5 def __init__(self, program): - self.depth = 0 - self.schedule_str = "" + # self.depth = 0 + self.sched_str = "" self.is_interchaged = False self.is_tiled = False self.is_unrolled = False diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 6311289..09bbf38 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -61,6 +61,7 @@ def apply_action(self, action): else: comp = list(self.schedule_object.it_dict.keys())[0] action_params = action.parameter(comp, self.schedule_object.prog) + ray.util.pdb.set_trace() if action.id in range(28): # Interchange if not self.schedule_object.is_interchaged: @@ -73,6 +74,15 @@ def apply_action(self, action): self.schedule_object.comps) self.schedule.append(optim1) + tmp_sched_str = optimlist_to_str(self.schedule) + print(tmp_sched_str) + + # check if we can find the schedule in the dataset load the legality check + if self.config.environment.use_dataset: + for sched_json in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_sched_str == sched_json['sched_str']: + saved_legality = 1 if sched_json['legality_check'] else None + if self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality @@ -84,12 +94,14 @@ def apply_action(self, action): print("X: The action produced an error.") self.pop_schedule(action=action) raise LCException + if lc_check == 0: print("X: Illegal action") self.pop_schedule(action=action) info = {"illegal_action": True} done = False return self.schedule_object.repr, 1.0, done, info + self.schedule_object.apply_interchange(action_params) 
print("O: Interchange applied") self.schedule_object.is_interchaged = True @@ -116,6 +128,15 @@ def apply_action(self, action): self.schedule.append(optim2) + tmp_sched_str = optimlist_to_str(self.schedule) + print(tmp_sched_str) + + # check if we can find the schedule in the dataset load the legality check + if self.config.environment.use_dataset: + for sched_json in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_sched_str == sched_json['sched_str']: + saved_legality = 1 if sched_json['legality_check'] else None + if self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality @@ -140,8 +161,8 @@ def apply_action(self, action): done = True exit = True - self.schedule_object.schedule_str = ScheduleUtils.sched_str( - self.schedule_object.schedule_str, action.id, + self.schedule_object.sched_str = ScheduleUtils.sched_str( + self.schedule_object.sched_str, action.id, action_params, self.schedule_object.comp_indic_dict) else: print("X: Tiling already applied exception") @@ -169,6 +190,16 @@ def apply_action(self, action): optim3 = OptimizationCommand("Unrolling", params, self.non_skewed_comps) self.schedule.append(optim3) + + tmp_sched_str = optimlist_to_str(self.schedule) + print(tmp_sched_str) + + # check if we can find the schedule in the dataset load the legality check + if self.config.environment.use_dataset: + for sched_json in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_sched_str == sched_json['sched_str']: + saved_legality = 1 if sched_json['legality_check'] else None + start_time = time.time() lc_check = self.schedule_object.prog.check_legality_of_schedule( self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality @@ -189,8 +220,6 @@ def apply_action(self, action): self.schedule_object.apply_unrolling(action_params) print("O: Unrolling 
applied") - for i in range(41, 44): - self.schedule_object.repr["action_mask"][i] = 0 self.schedule_object.is_unrolled = True else: lc_check = 0 @@ -279,6 +308,15 @@ def apply_action(self, action): optim5 = OptimizationCommand("Parallelization", params, self.schedule_object.comps) self.schedule.append(optim5) + + tmp_sched_str = optimlist_to_str(self.schedule) + + # check if we can find the schedule in the dataset load the legality check + if self.config.environment.use_dataset: + for sched_json in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_sched_str == sched_json['sched_str']: + saved_legality = 1 if sched_json['legality_check'] else None + start_time = time.time() if self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( @@ -315,6 +353,17 @@ def apply_action(self, action): optim6 = OptimizationCommand("Reversal", params, self.schedule_object.comps) self.schedule.append(optim6) + + tmp_sched_str = optimlist_to_str(self.schedule) + print(tmp_sched_str) + ray.util.pdb.set_trace() + # check if we can find the schedule in the dataset load the legality check + if self.config.environment.use_dataset: + for sched_json in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_sched_str == sched_json['sched_str']: + saved_legality = 1 if sched_json['legality_check'] else None + + ray.util.pdb.set_trace() start_time = time.time() if self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( @@ -355,6 +404,15 @@ def apply_action(self, action): self.schedule.append(optim7) + tmp_sched_str = optimlist_to_str(self.schedule) + print(tmp_sched_str) + + # check if we can find the schedule in the dataset load the legality check + if self.config.environment.use_dataset: + for sched_json in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_sched_str == sched_json['sched_str']: + saved_legality = 1 if sched_json['legality_check'] 
else None + start_time = time.time() if self.schedule_object.is_unrolled: @@ -392,8 +450,8 @@ def apply_action(self, action): if (not exit and lc_check != 0) and not (action.id in range( 41, 44) and self.schedule_object.is_skewed): - self.schedule_object.schedule_str = ScheduleUtils.sched_str( - self.schedule_object.schedule_str, action.id, action_params, + self.schedule_object.sched_str = ScheduleUtils.sched_str( + self.schedule_object.sched_str, action.id, action_params, self.schedule_object.comp_indic_dict) if not action.id in range(41, 44): self.schedule_object.it_dict = ScheduleUtils.update_iterators( @@ -467,7 +525,7 @@ def test_additional_actions(self, training=True): unroll_optimisation.params_list[comp][0]) + "," + str( unroll_factor) + ",C" + str( self.schedule_object.comp_indic_dict[comp]) + ")" - self.schedule_object.schedule_str = self.schedule_object.schedule_str.replace( + self.schedule_object.sched_str = self.schedule_object.sched_str.replace( unrolling_str, "") + new_unrolling_str self.schedule.remove(unroll_optimisation) self.schedule.append(new_unrolling_optim) @@ -500,8 +558,8 @@ def test_additional_actions(self, training=True): try: - self.schedule_object.schedule_str = ScheduleUtils.sched_str( - self.schedule_object.schedule_str, action.id, + self.schedule_object.sched_str = ScheduleUtils.sched_str( + self.schedule_object.sched_str, action.id, action_params, self.schedule_object.comp_indic_dict) parallelized_exec_time = self.get_exec_time() parallelization_str = 'P(L' + str( @@ -509,7 +567,7 @@ def test_additional_actions(self, training=True): except: print("X: Illegal action") self.schedule.remove(optim5) - self.schedule_object.schedule_str = self.schedule_object.schedule_str.replace( + self.schedule_object.sched_str = self.schedule_object.sched_str.replace( parallelization_str, "") if parallelized_exec_time < exec_time and parallelized_exec_time != 0: @@ -522,8 +580,8 @@ def test_additional_actions(self, training=True): else: 
self.schedule.remove(optim5) self.new_scheds[self.schedule_object.prog.name].pop( - self.schedule_object.schedule_str) - self.schedule_object.schedule_str = self.schedule_object.schedule_str.replace( + self.schedule_object.sched_str) + self.schedule_object.sched_str = self.schedule_object.sched_str.replace( parallelization_str, "") self.schedule_object.schedule_dict[first_comp][ "parallelized_dim"] = None @@ -540,7 +598,7 @@ def test_additional_actions(self, training=True): if exec_time != 0: print("\nThe final schedule is ", - self.schedule_object.schedule_str) + self.schedule_object.sched_str) self.speedup = ( self.schedule_object.prog.initial_execution_time / exec_time) print("The speedup is: ", self.speedup) @@ -551,8 +609,8 @@ def test_additional_actions(self, training=True): def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, initial_exec_time): self.schedule_list_model.append({ - "schedule_str": - self.schedule_object.schedule_str, + "sched_str": + self.schedule_object.sched_str, "schedule_dict": self.schedule_object.schedule_dict }) @@ -589,17 +647,17 @@ def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, def get_exec_time(self): prog_name = self.schedule_object.prog.name execution_time = 0 - if self.schedule_object.schedule_str != "" and self.schedule != []: + if self.schedule_object.sched_str != "" and self.schedule != []: if prog_name in self.scheds.keys(): - if self.schedule_object.schedule_str in self.scheds[prog_name]: + if self.schedule_object.sched_str in self.scheds[prog_name]: execution_time = self.scheds[prog_name][ - self.schedule_object.schedule_str][0] + self.schedule_object.sched_str][0] else: if prog_name in self.new_scheds.keys( - ) and self.schedule_object.schedule_str in self.new_scheds[ + ) and self.schedule_object.sched_str in self.new_scheds[ prog_name].keys(): execution_time = self.new_scheds[prog_name][ - self.schedule_object.schedule_str][1] + self.schedule_object.sched_str][1] else: curr_sched 
= copy.deepcopy(self.schedule) self.new_scheds[prog_name] = {} @@ -607,21 +665,21 @@ def get_exec_time(self): self.schedule, 'sched_eval', self.nb_executions, self.schedule_object.prog.initial_execution_time) self.new_scheds[prog_name][ - self.schedule_object.schedule_str] = ( + self.schedule_object.sched_str] = ( curr_sched, execution_time, 0) else: if prog_name in self.new_scheds.keys(): - if self.schedule_object.schedule_str in self.new_scheds[ + if self.schedule_object.sched_str in self.new_scheds[ prog_name].keys(): execution_time = self.new_scheds[prog_name][ - self.schedule_object.schedule_str][1] + self.schedule_object.sched_str][1] else: curr_sched = copy.deepcopy(self.schedule) execution_time = self.measurement_env( self.schedule, 'sched_eval', self.nb_executions, self.schedule_object.prog.initial_execution_time) self.new_scheds[prog_name][ - self.schedule_object.schedule_str] = ( + self.schedule_object.sched_str] = ( curr_sched, execution_time, 0) else: curr_sched = copy.deepcopy(self.schedule) @@ -632,15 +690,15 @@ def get_exec_time(self): self.schedule_object.prog.initial_execution_time) sched_time = time.time() - start_time self.new_scheds[prog_name][ - self.schedule_object.schedule_str] = (curr_sched, - execution_time, - 0) + self.schedule_object.sched_str] = (curr_sched, + execution_time, + 0) else: execution_time = self.schedule_object.prog.initial_execution_time return execution_time def save_legality_data(self, action, lc_check): - key = f"{self.schedule_object.prog.name}@{self.schedule_object.schedule_str}@{action}" + key = f"{self.schedule_object.prog.name}@{self.schedule_object.sched_str}@{action}" self.lc_data.append( [ key, @@ -649,7 +707,7 @@ def save_legality_data(self, action, lc_check): ) def get_legality(self, action): - key = f"{self.schedule_object.prog.name}@{self.schedule_object.schedule_str}@{action}" + key = f"{self.schedule_object.prog.name}@{self.schedule_object.sched_str}@{action}" values = [v for (k, v) in self.lc_data if k == 
key] return values[0] if len(values) else None diff --git a/tiramisu_programs/schedule_utils.py b/tiramisu_programs/schedule_utils.py index 201e324..38d9f1b 100644 --- a/tiramisu_programs/schedule_utils.py +++ b/tiramisu_programs/schedule_utils.py @@ -1,9 +1,13 @@ import json import re +from typing import List import numpy as np import ray +from tiramisu_programs.optimization import OptimizationCommand + + class LargeAccessMatices(Exception): pass @@ -130,100 +134,142 @@ def isl_to_write_matrix(cls, isl_map): return matrix @classmethod - def sched_json_to_sched_str(cls, sched_json, prog_it): - orig_loop_nest = [] - orig_loop_nest.append(list(prog_it.keys())[0]) - child_list = prog_it[list(prog_it.keys())[0]]['child_iterators'] - while len(child_list) > 0: - child_loop = prog_it[child_list[0]] - orig_loop_nest.append(child_list[0]) - child_list = child_loop['child_iterators'] - + def sched_json_to_sched_str(cls, sched_json, program_json): comp_name = [ - n for n in sched_json.keys() if - not n in ['unfuse_iterators', 'tree_structure', 'execution_times'] - ][0] - schedule = sched_json[comp_name] - transf_loop_nest = orig_loop_nest - sched_str = '' - - if schedule['interchange_dims']: - first_dim_index = transf_loop_nest.index( - schedule['interchange_dims'][0]) - second_dim_index = transf_loop_nest.index( - schedule['interchange_dims'][1]) - sched_str += 'I(L' + str(first_dim_index) + ',L' + str( - second_dim_index) + ')' - transf_loop_nest[first_dim_index], transf_loop_nest[ - second_dim_index] = transf_loop_nest[ - second_dim_index], transf_loop_nest[first_dim_index] - if schedule['skewing']['skewed_dims']: - first_dim_index = transf_loop_nest.index( - schedule['skewing']['skewed_dims'][0]) - second_dim_index = transf_loop_nest.index( - schedule['skewing']['skewed_dims'][1]) - first_factor = schedule['skewing']['skewing_factors'][0] - second_factor = schedule['skewing']['skewing_factors'][1] - sched_str += 'S(L' + str(first_dim_index) + ',L' + str( - 
second_dim_index) + ',' + str(first_factor) + ',' + str( - second_factor) + ')' - if schedule['parallelized_dim']: - dim_index = transf_loop_nest.index(schedule['parallelized_dim']) - sched_str += 'P(L' + str(dim_index) + ')' - if schedule['tiling']['tiling_dims']: - if schedule['tiling']['tiling_depth'] == 2: - first_dim = schedule['tiling']['tiling_dims'][0] - second_dim = schedule['tiling']['tiling_dims'][1] - - first_dim_index = transf_loop_nest.index(first_dim) - second_dim_index = transf_loop_nest.index(second_dim) - first_factor = schedule['tiling']['tiling_factors'][0] - second_factor = schedule['tiling']['tiling_factors'][1] - sched_str += 'T2(L' + str(first_dim_index) + ',L' + str( - second_dim_index) + ',' + str(first_factor) + ',' + str( - second_factor) + ')' - i = transf_loop_nest.index(first_dim) - transf_loop_nest[ - i:i + 1] = first_dim + '_outer', second_dim + '_outer' - i = transf_loop_nest.index(second_dim) - transf_loop_nest[ - i:i + 1] = first_dim + '_inner', second_dim + '_inner' - else: - first_dim = schedule['tiling']['tiling_dims'][0] - second_dim = schedule['tiling']['tiling_dims'][1] - third_dim = schedule['tiling']['tiling_dims'][2] - first_dim_index = transf_loop_nest.index(first_dim) - second_dim_index = transf_loop_nest.index(second_dim) - third_dim_index = transf_loop_nest.index(third_dim) - first_factor = schedule['tiling']['tiling_factors'][0] - second_factor = schedule['tiling']['tiling_factors'][1] - third_factor = schedule['tiling']['tiling_factors'][2] - sched_str += 'T3(L' + str(first_dim_index) + ',L' + str( - second_dim_index) + ',L' + str( - third_dim_index) + ',' + str(first_factor) + ',' + str( - second_factor) + ',' + str(third_factor) + ')' - i = transf_loop_nest.index(first_dim) - transf_loop_nest[ - i:i + - 1] = first_dim + '_outer', second_dim + '_outer', third_dim + '_outer' - i = transf_loop_nest.index(second_dim) - transf_loop_nest[ - i:i + - 1] = first_dim + '_inner', second_dim + '_inner', third_dim + '_inner' 
- transf_loop_nest.remove(third_dim) - if schedule['unrolling_factor']: - dim_index = len(transf_loop_nest) - 1 - dim_name = transf_loop_nest[-1] - sched_str += 'U(L' + str(dim_index) + ',' + str( - schedule['unrolling_factor']) + ')' - transf_loop_nest[dim_index:dim_index + - 1] = dim_name + '_Uouter', dim_name + '_Uinner' - if schedule["reversed_dim"]: - dim_index = transf_loop_nest.index(schedule["reversed_dim"]) - sched_str += 'R(L' + str(dim_index) + ')' - + n + for n in sched_json.keys() + if not n in ["unfuse_iterators", "tree_structure", "execution_times", "fusions", "sched_str"] + ] + sched_str = "" + + if ("fusions" in sched_json and sched_json["fusions"]): + for fusion in sched_json["fusions"]: + sched_str += "F(" + for name in comp_name: + if name in fusion: + sched_str += name + "," + + sched_str = sched_str[:-1] + sched_str += ")" + + for name in comp_name: + transf_loop_nest = cls.get_original_iterators(program_json) + schedule = sched_json[name] + sched_str += '{' + name + '}:' + + for transformation in schedule["transformations_list"]: + + if (transformation[0] == 1): + sched_str += "I(L" + str(transformation[1]) + \ + ",L" + str(transformation[2]) + ")" + + elif (transformation[0] == 2): + sched_str += f"R(L{str(transformation[3])})" + elif (transformation[0] == 3): + sched_str += "S(L" + str(transformation[4]) + ",L" + str( + transformation[5]) + "," + str(transformation[6]) + "," + str(transformation[7]) + ")" + + if schedule["parallelized_dim"]: + + dim_index = transf_loop_nest.index( + schedule["parallelized_dim"]) + sched_str += "P(L" + str(dim_index) + ")" + + if schedule["tiling"]: + if schedule["tiling"]["tiling_depth"] == 2: + first_dim = schedule["tiling"]["tiling_dims"][0] + second_dim = schedule["tiling"]["tiling_dims"][1] + first_dim_index = transf_loop_nest.index(first_dim) + second_dim_index = transf_loop_nest.index(second_dim) + first_factor = schedule["tiling"]["tiling_factors"][0] + second_factor = 
schedule["tiling"]["tiling_factors"][1] + sched_str += ( + "T2(L" + + str(first_dim_index) + + ",L" + + str(second_dim_index) + + "," + + str(first_factor) + + "," + + str(second_factor) + + ")" + ) + i = transf_loop_nest.index(first_dim) + transf_loop_nest[i: i + 1] = first_dim + \ + "_outer", second_dim + "_outer" + i = transf_loop_nest.index(second_dim) + transf_loop_nest[i: i + 1] = first_dim + \ + "_inner", second_dim + "_inner" + else: + first_dim = schedule["tiling"]["tiling_dims"][0] + second_dim = schedule["tiling"]["tiling_dims"][1] + third_dim = schedule["tiling"]["tiling_dims"][2] + first_dim_index = transf_loop_nest.index(first_dim) + second_dim_index = transf_loop_nest.index(second_dim) + third_dim_index = transf_loop_nest.index(third_dim) + first_factor = schedule["tiling"]["tiling_factors"][0] + second_factor = schedule["tiling"]["tiling_factors"][1] + third_factor = schedule["tiling"]["tiling_factors"][2] + sched_str += ( + "T3(L" + + str(first_dim_index) + + ",L" + + str(second_dim_index) + + ",L" + + str(third_dim_index) + + "," + + str(first_factor) + + "," + + str(second_factor) + + "," + + str(third_factor) + + ")" + ) + i = transf_loop_nest.index(first_dim) + transf_loop_nest[i: i + 1] = ( + first_dim + "_outer", + second_dim + "_outer", + third_dim + "_outer", + ) + i = transf_loop_nest.index(second_dim) + transf_loop_nest[i: i + 1] = ( + first_dim + "_inner", + second_dim + "_inner", + third_dim + "_inner", + ) + transf_loop_nest.remove(third_dim) + + if schedule["unrolling_factor"]: + dim_index = len(transf_loop_nest) - 1 + dim_name = transf_loop_nest[-1] + sched_str += "U(L" + str(dim_index) + "," + \ + schedule["unrolling_factor"] + ")" + transf_loop_nest[dim_index: dim_index + 1] = ( + dim_name + "_Uouter", + dim_name + "_Uinner", + ) return sched_str + @classmethod + def get_original_iterators(cls, program_json): + iterators = program_json['iterators'] + to_explore = [] + result = [] + to_explore.append(list(iterators.keys())[0]) + 
while (to_explore): + it_name = to_explore.pop(0) + iterator = iterators[it_name] + result.append(it_name) + for element in iterator["child_iterators"]: + to_explore.append(element) + + return result + + @classmethod + def list_optimizations_to_sched_str(cls, schedule: List[OptimizationCommand]): + + pass + @classmethod def get_schedules_str(cls, programs_list, programs_dict): if programs_dict != {}: @@ -231,17 +277,14 @@ def get_schedules_str(cls, programs_list, programs_dict): functions_set = {} for fun in programs_list: - if 'schedules_list' in programs_dict[fun].keys(): schedules = programs_dict[fun]['schedules_list'] schedules_set = {} for schedule in schedules: - - comp = list(schedule.keys())[0] - schedule_str = schedule[comp]["schedule_str"] - schedules_set[schedule_str] = schedule[comp]["execution_times"] + schedule_str = schedule["sched_str"] + schedules_set[schedule_str] = schedule["execution_times"] functions_set[fun] = schedules_set @@ -477,7 +520,7 @@ def get_orig_tree_struct(cls, program_json, root_iterator): for child_iterator in program_json['iterators'][root_iterator][ 'child_iterators']: tree_struct['child_list'].append( - self.get_orig_tree_struct(program_json, child_iterator)) + cls.get_orig_tree_struct(program_json, child_iterator)) return tree_struct @classmethod @@ -845,3 +888,100 @@ def update_iterators(cls, id, it_list, action_params, added_iterators, it_list = dict(sorted(it_list.items())) return it_list + + +def optimlist_to_str(optim_list): + """Converts a list of OptimizationCommand to a string. 
+ """ + + comp_names = list(set([ + comp for optim in optim_list for comp in optim.comps + ])) + + comp_names.sort() + + sched_str = "" + + # Add fusions first + fusions = [optim for optim in optim_list if optim.type == "Fusion"] + for fusion in fusions: + sched_str += "F(" + for name in fusion.comps: + sched_str += name + "," + + sched_str = sched_str[:-1] + sched_str += ")" + + # Iterate over the comps and add their transformations + for name in comp_names: + sched_str += '{' + name + '}:' + + for transformation in optim_list: + # Skip the transformation if it doesn't include the comp + if name not in transformation.comps: + continue + + if transformation.type == "Interchange": + sched_str += "I(L" + str(transformation.params_list[0]) + \ + ",L" + str(transformation.params_list[1]) + ")" + + elif transformation.type == "Reversal": + sched_str += f"R(L{str(transformation.params_list[0])})" + + elif transformation.type == "Skewing": + sched_str += "S(L" + str(transformation.params_list[0]) + ",L" + str( + transformation.params_list[1]) + "," + str(transformation.params_list[2]) + "," + str( + transformation.params_list[3]) + ")" + + elif transformation.type == "Parallelization": + sched_str += "P(L" + str(transformation.params_list[0]) + ")" + + elif transformation.type == "Tiling": + # T2 + if len(transformation.params_list) == 4: + first_dim_index = transformation.params_list[0] + second_dim_index = transformation.params_list[1] + first_factor = transformation.params_list[2] + second_factor = transformation.params_list[3] + sched_str += ( + "T2(L" + + str(first_dim_index) + + ",L" + + str(second_dim_index) + + "," + + str(first_factor) + + "," + + str(second_factor) + + ")" + ) + # T3 + else: + first_dim_index = transformation.params_list[0] + second_dim_index = transformation.params_list[1] + third_dim_index = transformation.params_list[2] + first_factor = transformation.params_list[3] + second_factor = transformation.params_list[4] + third_factor = 
transformation.params_list[5] + sched_str += ( + "T3(L" + + str(first_dim_index) + + ",L" + + str(second_dim_index) + + ",L" + + str(third_dim_index) + + "," + + str(first_factor) + + "," + + str(second_factor) + + "," + + str(third_factor) + + ")" + ) + + elif transformation.type == "Unrolling": + dim_index = transformation.params_list[name][0] + unrolling_factor = transformation.params_list[name][1] + sched_str += "U(L" + str(dim_index) + "," + \ + str(unrolling_factor) + ")" + + return sched_str diff --git a/tiramisu_programs/tiramisu_program.py b/tiramisu_programs/tiramisu_program.py index 0edbeb1..3acd2c8 100644 --- a/tiramisu_programs/tiramisu_program.py +++ b/tiramisu_programs/tiramisu_program.py @@ -96,7 +96,7 @@ class TiramisuProgram(): return 0; }''' - def __init__(self, config, file_path): + def __init__(self, config, file_path, progs_dict=None): self.config = config self.file_path = file_path with open(file_path, 'r') as f: @@ -111,6 +111,9 @@ def __init__(self, config, file_path): self.name = re.findall(r'tiramisu::init\(\"(\w+)\"\);', self.original_str)[0] + self.original_str = self.original_str.replace( + f'#include "{self.name}_wrapper.h"', '') + self.comp_name = re.findall(r'computation (\w+)\(', self.original_str) self.code_gen_line = re.findall(r'tiramisu::codegen\({.+;', @@ -126,37 +129,44 @@ def __init__(self, config, file_path): self.original_str)[0] self.buffer_sizes.append(re.findall(r'\d+', sizes_vect)) - self.program_annotations = '' + self.program_annotations = None self.wrapper_is_compiled = False self.initial_execution_time = 1.0 + self.json_representation = None + if config.environment.use_dataset: + self.json_representation = progs_dict[self.name] def get_program_annotations(self): - if not self.program_annotations == '': + if self.program_annotations is not None: return self.program_annotations - # create a cpp file to get the annotations - get_json_lines = ''' - auto ast = 
tiramisu::auto_scheduler::syntax_tree(tiramisu::global::get_implicit_function()); - std::string program_json = tiramisu::auto_scheduler::evaluate_by_learning_model::get_program_json(ast); - std::ofstream out("''' + self.func_folder + self.name + '''_program_annotations.json"); - out << program_json; - out.close(); - ''' - get_json_prog = self.original_str.replace(self.code_gen_line, - get_json_lines) - output_file = self.func_folder + self.name + '_get_prog_annot.cpp' + if self.config.environment.use_dataset: + self.program_annotations = self.json_representation['program_annotation'] + else: + # create a cpp file to get the annotations + get_json_lines = ''' + auto ast = tiramisu::auto_scheduler::syntax_tree(tiramisu::global::get_implicit_function()); + std::string program_json = tiramisu::auto_scheduler::evaluate_by_learning_model::get_program_json(ast); + std::ofstream out("''' + self.func_folder + self.name + '''_program_annotations.json"); + out << program_json; + out.close(); + ''' + get_json_prog = self.original_str.replace(self.code_gen_line, + get_json_lines) + output_file = self.func_folder + self.name + '_get_prog_annot.cpp' - with open(output_file, 'w') as f: - f.write(get_json_prog) + with open(output_file, 'w') as f: + f.write(get_json_prog) - # compile the cpp file and run to generate annotations in json file - tiramisu_programs.CPP_File.compile_and_run_tiramisu_code( - self.config, output_file, 'Generating program annotations') + # compile the cpp file and run to generate annotations in json file + tiramisu_programs.CPP_File.compile_and_run_tiramisu_code( + self.config, output_file, 'Generating program annotations') + + # Read the json file and return the annotations + with open(self.func_folder + self.name + '_program_annotations.json', + 'r') as f: + self.program_annotations = json.loads(f.read()) - # Read the json file and return the annotations - with open(self.func_folder + self.name + '_program_annotations.json', - 'r') as f: - 
self.program_annotations = json.loads(f.read()) return self.program_annotations def check_legality_of_schedule( @@ -164,7 +174,7 @@ def check_legality_of_schedule( optims_list, comps=None, first_comp=None - ): + ): legality_check_lines = ''' prepare_schedules_for_legality_checks(); perform_full_dependency_analysis(); @@ -183,7 +193,7 @@ def check_legality_of_schedule( is_legal &= loop_parallelization_is_legal(''' + str( optim.params_list[0]) + ''', {&''' + first_comp + '''}); ''' - legality_check_lines += optim.tiramisu_optim_str + '\n' + legality_check_lines += optim.tiramisu_optim_str + '\n' elif optim.type == 'Tiling': legality_check_lines += optim.tiramisu_optim_str + '\n' elif optim.type == 'Fusion': @@ -195,7 +205,7 @@ def check_legality_of_schedule( optim.params_list[comp] [0]) + ''', {&''' + comp + '''}); ''' - legality_check_lines += optim.tiramisu_optim_str + '\n' + legality_check_lines += optim.tiramisu_optim_str + '\n' legality_check_lines += ''' is_legal &= check_legality_of_function(); @@ -219,7 +229,7 @@ def check_legality_of_schedule( return lc_result - def call_solver(self, comp, params): + def call_solver(self, comp, params): lc_file = self.func_folder + self.name + '_legality_check.cpp' if os.path.isfile(lc_file): with open(lc_file, 'r') as f: @@ -361,8 +371,7 @@ def get_measurements(self, cmd_type, nb_executions, initial_exec_time): return return self.read_measurements_file() - def write_wrapper_code( - self): + def write_wrapper_code(self): buffers_init_lines = '' for i, buffer_name in enumerate(self.IO_buffer_names): @@ -385,7 +394,8 @@ def write_wrapper_code( with open(output_file, 'w') as f: f.write(wrapper_cpp_code) - wrapper_h_code = self.wrapper_h_template.replace('$func_name$', self.name) + wrapper_h_code = self.wrapper_h_template.replace( + '$func_name$', self.name) wrapper_h_code = wrapper_h_code.replace( '$func_params$', ','.join( ['halide_buffer_t *' + name for name in self.IO_buffer_names])) @@ -419,5 +429,3 @@ def 
read_solver_result_file(self): def reset_solver_result_file(self): with open(self.func_folder + "solver_result.txt", 'w') as f: f.write('-1') - - diff --git a/train_ppo.py b/train_ppo.py index 4cc46ec..4d44a7f 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -1,3 +1,4 @@ +import logging import os # import hydra import argparse @@ -18,6 +19,9 @@ def get_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--num-workers", default=-1, type=int) + parser.add_argument("--use-dataset", default=False, type=bool) + parser.add_argument("--log-level", default="INFO", # TODO change back to WARN + type=str, choices=list(logging._nameToLevel.keys())) return parser.parse_args() @@ -27,7 +31,7 @@ def main(config: RLAutoSchedulerConfig): progs_list_registery = GlobalVarActor.remote( config.environment.programs_file, config.environment.dataset_path, - num_workers=config.ray.num_workers) + num_workers=config.ray.num_workers, use_dataset=config.environment.use_dataset, json_dataset=config.environment.json_dataset) shared_variable_actor = Actor.remote(progs_list_registery) register_env( @@ -73,6 +77,8 @@ def main(config: RLAutoSchedulerConfig): args = get_arguments() if args.num_workers != -1: config.ray.num_workers = args.num_workers + config.environment.use_dataset = args.use_dataset + logging.basicConfig(level=logging._nameToLevel[args.log_level]) if args.num_workers == 1: with ray.init(): main(config) diff --git a/utils/global_ray_variables.py b/utils/global_ray_variables.py index ff11b56..56dbd05 100644 --- a/utils/global_ray_variables.py +++ b/utils/global_ray_variables.py @@ -1,3 +1,6 @@ +import bz2 +import logging +import pickle from typing import List import ray import json @@ -7,13 +10,17 @@ @ray.remote class GlobalVarActor: - def __init__(self, programs_file, dataset_path, num_workers=7): + def __init__(self, programs_file, dataset_path, num_workers=7, use_dataset=False, json_dataset=None): self.index = -1 self.num_workers = num_workers - self.progs_list = 
self.get_dataset(dataset_path) + self.progs_list = [] self.programs_file = programs_file self.progs_dict = dict() self.lc_data = [] + self.json_dataset = json_dataset + + self.get_dataset( + dataset_path, use_dataset, json_dataset_path=json_dataset["path"]) # if os.path.isfile(programs_file): # try: # with open(programs_file) as f: @@ -36,11 +43,20 @@ def __init__(self, programs_file, dataset_path, num_workers=7): with open("lc_data.json", "w+") as f: f.write(json.dumps(self.lc_data)) - def get_dataset(self, path): - os.getcwd() - print("***************************", os.getcwd()) - prog_list = os.listdir(path) - return prog_list + # Load the dataset of programs + def get_dataset(self, path, use_dataset=False, json_dataset_path=None): + if use_dataset: + logging.info(f"reading dataset from json at:{json_dataset_path}") + with bz2.BZ2File(json_dataset_path, 'rb') as f: + self.progs_dict = pickle.load(f) + self.progs_list = list(self.progs_dict.keys()) + logging.info( + f"[Done] reading dataset from json at:{json_dataset_path}") + + else: + os.getcwd() + logging.info(f"reading dataset from ls at: {os.getcwd()}") + self.progs_list = os.listdir(path) def set_progs_list(self, v): self.progs_list = v diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 146c7fa..484542d 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -22,6 +22,11 @@ class EnvironmentConfig: dataset_path: str = "../../Dataset_multi/" programs_file: str = "./multicomp.json" clean_files: bool = True + json_dataset: dict = field(default_factory=lambda: { + "path": None, + "cpp_root": None + }) + use_dataset: bool = False @dataclass From 50658fc80152a7bec94c8ab3b18e37875b6539c2 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 24 Jan 2023 09:52:02 +0400 Subject: [PATCH 02/27] removed berakpoints and fixed code style to match previous code --- tiramisu_programs/schedule_controller.py | 16 +- tiramisu_programs/schedule_utils.py | 177 
++++++++++++----------- 2 files changed, 96 insertions(+), 97 deletions(-) diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 09bbf38..9215c86 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -61,7 +61,6 @@ def apply_action(self, action): else: comp = list(self.schedule_object.it_dict.keys())[0] action_params = action.parameter(comp, self.schedule_object.prog) - ray.util.pdb.set_trace() if action.id in range(28): # Interchange if not self.schedule_object.is_interchaged: @@ -74,7 +73,7 @@ def apply_action(self, action): self.schedule_object.comps) self.schedule.append(optim1) - tmp_sched_str = optimlist_to_str(self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check @@ -128,7 +127,7 @@ def apply_action(self, action): self.schedule.append(optim2) - tmp_sched_str = optimlist_to_str(self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check @@ -191,7 +190,8 @@ def apply_action(self, action): self.non_skewed_comps) self.schedule.append(optim3) - tmp_sched_str = optimlist_to_str(self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str( + self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check @@ -309,7 +309,7 @@ def apply_action(self, action): self.schedule_object.comps) self.schedule.append(optim5) - tmp_sched_str = optimlist_to_str(self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset: @@ -354,16 +354,14 @@ def apply_action(self, action): self.schedule_object.comps) self.schedule.append(optim6) - tmp_sched_str = 
optimlist_to_str(self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - ray.util.pdb.set_trace() # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset: for sched_json in self.schedule_object.prog.json_representation['schedules_list']: if tmp_sched_str == sched_json['sched_str']: saved_legality = 1 if sched_json['legality_check'] else None - ray.util.pdb.set_trace() start_time = time.time() if self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( @@ -404,7 +402,7 @@ def apply_action(self, action): self.schedule.append(optim7) - tmp_sched_str = optimlist_to_str(self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check diff --git a/tiramisu_programs/schedule_utils.py b/tiramisu_programs/schedule_utils.py index 38d9f1b..68c3486 100644 --- a/tiramisu_programs/schedule_utils.py +++ b/tiramisu_programs/schedule_utils.py @@ -889,99 +889,100 @@ def update_iterators(cls, id, it_list, action_params, added_iterators, return it_list + @classmethod + def optimlist_to_str(cls, optim_list): + """Converts a list of OptimizationCommand to a string. + """ -def optimlist_to_str(optim_list): - """Converts a list of OptimizationCommand to a string. 
- """ - - comp_names = list(set([ - comp for optim in optim_list for comp in optim.comps - ])) - - comp_names.sort() - - sched_str = "" - - # Add fusions first - fusions = [optim for optim in optim_list if optim.type == "Fusion"] - for fusion in fusions: - sched_str += "F(" - for name in fusion.comps: - sched_str += name + "," - - sched_str = sched_str[:-1] - sched_str += ")" - - # Iterate over the comps and add their transformations - for name in comp_names: - sched_str += '{' + name + '}:' - - for transformation in optim_list: - # Skip the transformation if it doesn't include the comp - if name not in transformation.comps: - continue + comp_names = list(set([ + comp for optim in optim_list for comp in optim.comps + ])) - if transformation.type == "Interchange": - sched_str += "I(L" + str(transformation.params_list[0]) + \ - ",L" + str(transformation.params_list[1]) + ")" + comp_names.sort() - elif transformation.type == "Reversal": - sched_str += f"R(L{str(transformation.params_list[0])})" + sched_str = "" - elif transformation.type == "Skewing": - sched_str += "S(L" + str(transformation.params_list[0]) + ",L" + str( - transformation.params_list[1]) + "," + str(transformation.params_list[2]) + "," + str( - transformation.params_list[3]) + ")" + # Add fusions first + fusions = [optim for optim in optim_list if optim.type == "Fusion"] + for fusion in fusions: + sched_str += "F(" + for name in fusion.comps: + sched_str += name + "," - elif transformation.type == "Parallelization": - sched_str += "P(L" + str(transformation.params_list[0]) + ")" + sched_str = sched_str[:-1] + sched_str += ")" - elif transformation.type == "Tiling": - # T2 - if len(transformation.params_list) == 4: - first_dim_index = transformation.params_list[0] - second_dim_index = transformation.params_list[1] - first_factor = transformation.params_list[2] - second_factor = transformation.params_list[3] - sched_str += ( - "T2(L" - + str(first_dim_index) - + ",L" - + str(second_dim_index) - + "," - 
+ str(first_factor) - + "," - + str(second_factor) - + ")" - ) - # T3 - else: - first_dim_index = transformation.params_list[0] - second_dim_index = transformation.params_list[1] - third_dim_index = transformation.params_list[2] - first_factor = transformation.params_list[3] - second_factor = transformation.params_list[4] - third_factor = transformation.params_list[5] - sched_str += ( - "T3(L" - + str(first_dim_index) - + ",L" - + str(second_dim_index) - + ",L" - + str(third_dim_index) - + "," - + str(first_factor) - + "," - + str(second_factor) - + "," - + str(third_factor) - + ")" - ) + # Iterate over the comps and add their transformations + for name in comp_names: + sched_str += '{' + name + '}:' - elif transformation.type == "Unrolling": - dim_index = transformation.params_list[name][0] - unrolling_factor = transformation.params_list[name][1] - sched_str += "U(L" + str(dim_index) + "," + \ - str(unrolling_factor) + ")" + for transformation in optim_list: + # Skip the transformation if it doesn't include the comp + if name not in transformation.comps: + continue + + if transformation.type == "Interchange": + sched_str += "I(L" + str(transformation.params_list[0]) + \ + ",L" + str(transformation.params_list[1]) + ")" + + elif transformation.type == "Reversal": + sched_str += f"R(L{str(transformation.params_list[0])})" + + elif transformation.type == "Skewing": + sched_str += "S(L" + str(transformation.params_list[0]) + ",L" + str( + transformation.params_list[1]) + "," + str(transformation.params_list[2]) + "," + str( + transformation.params_list[3]) + ")" + + elif transformation.type == "Parallelization": + sched_str += "P(L" + \ + str(transformation.params_list[0]) + ")" + + elif transformation.type == "Tiling": + # T2 + if len(transformation.params_list) == 4: + first_dim_index = transformation.params_list[0] + second_dim_index = transformation.params_list[1] + first_factor = transformation.params_list[2] + second_factor = transformation.params_list[3] + 
sched_str += ( + "T2(L" + + str(first_dim_index) + + ",L" + + str(second_dim_index) + + "," + + str(first_factor) + + "," + + str(second_factor) + + ")" + ) + # T3 + else: + first_dim_index = transformation.params_list[0] + second_dim_index = transformation.params_list[1] + third_dim_index = transformation.params_list[2] + first_factor = transformation.params_list[3] + second_factor = transformation.params_list[4] + third_factor = transformation.params_list[5] + sched_str += ( + "T3(L" + + str(first_dim_index) + + ",L" + + str(second_dim_index) + + ",L" + + str(third_dim_index) + + "," + + str(first_factor) + + "," + + str(second_factor) + + "," + + str(third_factor) + + ")" + ) + + elif transformation.type == "Unrolling": + dim_index = transformation.params_list[name][0] + unrolling_factor = transformation.params_list[name][1] + sched_str += "U(L" + str(dim_index) + "," + \ + str(unrolling_factor) + ")" - return sched_str + return sched_str From 6985d62ef76f23b47a0b42c18de11c4aa6a0d0f5 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 24 Jan 2023 10:43:54 +0400 Subject: [PATCH 03/27] added support for the legality check of skewing --- tiramisu_programs/schedule_controller.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 9215c86..3ab22f9 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -264,6 +264,16 @@ def apply_action(self, action): self.schedule.append(optim4) + tmp_sched_str = ScheduleUtils.optimlist_to_str( + self.schedule) + print(tmp_sched_str) + + # check if we can find the schedule in the dataset load the legality check + if self.config.environment.use_dataset: + for sched_json in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_sched_str == sched_json['sched_str']: + saved_legality = 1 if sched_json['legality_check'] else None + start_time = time.time() if 
self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( @@ -310,6 +320,7 @@ def apply_action(self, action): self.schedule.append(optim5) tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) + print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset: @@ -451,6 +462,7 @@ def apply_action(self, action): self.schedule_object.sched_str = ScheduleUtils.sched_str( self.schedule_object.sched_str, action.id, action_params, self.schedule_object.comp_indic_dict) + ray.util.pdb.set_trace() if not action.id in range(41, 44): self.schedule_object.it_dict = ScheduleUtils.update_iterators( action.id, self.schedule_object.it_dict, action_params, From 8bdd2c2c711ca0d5f323ab4852c07f05a5b7fc32 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 24 Jan 2023 10:44:23 +0400 Subject: [PATCH 04/27] removed breakpoint --- tiramisu_programs/schedule_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 3ab22f9..4ea628a 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -462,7 +462,7 @@ def apply_action(self, action): self.schedule_object.sched_str = ScheduleUtils.sched_str( self.schedule_object.sched_str, action.id, action_params, self.schedule_object.comp_indic_dict) - ray.util.pdb.set_trace() + # ray.util.pdb.set_trace() if not action.id in range(41, 44): self.schedule_object.it_dict = ScheduleUtils.update_iterators( action.id, self.schedule_object.it_dict, action_params, From c84c8911a29db594cee284d325507aeac9142b02 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Wed, 25 Jan 2023 10:01:37 +0400 Subject: [PATCH 05/27] working on resuming learning --- requirements.txt | 7 ++- train_ppo.py | 85 +++++++++++++++++++++----------- utils/rl_autoscheduler_config.py | 9 ++-- 3 files changed, 67 
insertions(+), 34 deletions(-) diff --git a/requirements.txt b/requirements.txt index 174d708..dd7682d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ PyYAML == 6.0 -gym == 0.21.0 +gym==0.21.0 +gymnasium==0.27.1 numpy == 1.23.1 -ray[rllib] == 1.13.0 +ray[rllib] == 2.2.0 sympy == 1.10.1 torch == 1.12.0 tqdm == 4.64.0 +pandas == 1.5.3 +tensorflow_probability == 0.19.0 \ No newline at end of file diff --git a/train_ppo.py b/train_ppo.py index 4cc46ec..d817167 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -3,7 +3,7 @@ import argparse import ray # from hydra.core.config_store import ConfigStore -from ray import tune +from ray import tune, air from ray.rllib.models.catalog import ModelCatalog from ray.tune.registry import register_env @@ -18,6 +18,8 @@ def get_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--num-workers", default=-1, type=int) + parser.add_argument('--resume-training', + action=argparse.BooleanOptionalAction) return parser.parse_args() @@ -36,35 +38,60 @@ def main(config: RLAutoSchedulerConfig): ) ModelCatalog.register_custom_model("tiramisu_model_v1", TiramisuModelMult) - - analysis = tune.run( - "PPO", - local_dir=local_dir, - name=config.ray.name, - stop={"training_iteration": config.ray.training_iteration}, - max_failures=0, - checkpoint_freq=config.ray.checkpoint_freq, - verbose=0, - config={ - "env": "Tiramisu_env_v1", - "num_workers": config.ray.num_workers, - "placement_strategy": "SPREAD", - "batch_mode": "complete_episodes", - "train_batch_size": max(config.ray.num_workers * 200, config.training.train_batch_size), - "sgd_minibatch_size": config.training.sgd_minibatch_size, - "lr": config.training.lr, - "num_sgd_iter": config.training.num_sgd_iter, - "framework": "torch", - "_disable_preprocessor_api": True, - "model": { - "custom_model": "tiramisu_model_v1", - "custom_model_config": { - "layer_sizes": list(config.model.layer_sizes), - "drops": list(config.model.drops), - }, + config_dict = { + 
"env": "Tiramisu_env_v1", + "num_workers": config.ray.num_workers, + "placement_strategy": "SPREAD", + "batch_mode": "complete_episodes", + "train_batch_size": max(config.ray.num_workers * 200, config.training.train_batch_size), + "sgd_minibatch_size": config.training.sgd_minibatch_size, + "lr": config.training.lr, + "num_sgd_iter": config.training.num_sgd_iter, + "framework": "torch", + "_disable_preprocessor_api": True, + "model": { + "custom_model": "tiramisu_model_v1", + "custom_model_config": { + "layer_sizes": list(config.model.layer_sizes), + "drops": list(config.model.drops), }, }, - ) + } + + if config.ray.resume_training: + print(f"Resuming training from: {local_dir}/{config.ray.name}") + tuner = tune.Tuner.restore( + path=f"{local_dir}/{config.ray.name}" + ) + else: + tuner = tune.Tuner( + "PPO", + param_space=config_dict, + run_config=air.RunConfig( + local_dir=local_dir, + stop={"training_iteration": config.ray.training_iteration}, + name=config.ray.name, + verbose=0, + failure_config=air.FailureConfig( + max_failures=0 + ), + checkpoint_config=air.CheckpointConfig( + checkpoint_frequency=config.ray.checkpoint_freq, + ) + ), + ) + results = tuner.fit() + + # analysis = tune.run( + # "PPO", + # local_dir=local_dir, + # name=config.ray.name, + # stop={"training_iteration": config.ray.training_iteration}, + # max_failures=0, + # checkpoint_freq=config.ray.checkpoint_freq, + # verbose=0, + # config=config, + # ) if __name__ == "__main__": @@ -73,6 +100,8 @@ def main(config: RLAutoSchedulerConfig): args = get_arguments() if args.num_workers != -1: config.ray.num_workers = args.num_workers + if args.resume_training: + config.ray.resume_training = True if args.num_workers == 1: with ray.init(): main(config) diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 3a568ea..e1abbcc 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -15,6 +15,7 @@ class RayConfig: base_path: str = 
"/data/scratch/hbenyamina/github/rl_autoscheduler" name: str = "Training_multi_enhanced" log_directory: str = "ray_results" + resume_training: bool = False @dataclass @@ -30,15 +31,15 @@ class TiramisuConfig: env_type: Literal["model", "cpu"] = "cpu" model_checkpoint: str = "/data/scratch/hbenyamina/model_published_nn_finale.pt" compile_tiramisu_cmd: str = 'printf "Compiling ${FILE_PATH}\n" >> ${FUNC_DIR}log.txt;\ - ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ - ${CXX} -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' + c++ -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lz -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ + c++ -Wl,--no-as-needed -ldl -g -fno-rtti -lz -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl ' run_tiramisu_cmd: str = 'printf "Running ${FILE_PATH}.out\n">> ${FUNC_DIR}log.txt;\ ./${FILE_PATH}.out>> ${FUNC_DIR}log.txt;' compile_wrapper_cmd = 'cd ${FUNC_DIR};\ - ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ - ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread 
-std=c++11 -O3 -o ${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' + g++ -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ + g++ -std=c++11 -fno-rtti -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/3rdParty/isl/include/ -I${TIRAMISU_ROOT}/benchmarks -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib/ -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -o ${FUNC_NAME}_wrapper -ltiramisu -lHalide -ldl -lpthread -lz -lm -Wl,-rpath,${TIRAMISU_ROOT}/build ./${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -ltiramisu -lHalide -ldl -lpthread -lz -lm' @dataclass From 1643fc2ba8285ad08692e6d5167f14e60f09cf09 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Wed, 25 Jan 2023 14:07:00 +0400 Subject: [PATCH 06/27] refactored code from later commit --- rl_interface/environment.py | 27 ++++++++++++------------ tiramisu_programs/schedule_controller.py | 23 ++++++++++---------- train_ppo.py | 5 +++-- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 3ed95b8..e3ee31a 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -142,27 +142,26 @@ def reset(self, file=None): # Get the gym representation from the annotations self.obs = self.schedule_object.get_representation() - if self.config.tiramisu.env_type == "cpu": - if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys(): + if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys(): + if self.config.tiramisu.env_type == "cpu": print("Getting the initial exe time by execution") self.prog.initial_execution_time = self.schedule_controller.measurement_env( [], 'initial_exec', 
self.nb_executions, self.prog.initial_execution_time) - self.progs_dict[self.prog.name] = {} - self.progs_dict[self.prog.name]["initial_execution_time"] = self.prog.initial_execution_time + elif self.config.tiramisu.env_type == "model": + self.prog.initial_execution_time = 1.0 + self.progs_dict[self.prog.name] = {} + self.progs_dict[self.prog.name]["program_annotation"] = self.schedule_object.annotations + self.progs_dict[self.prog.name]["initial_execution_time"] = self.prog.initial_execution_time - else: - print("The initial execution time exists") + else: + print("The initial execution time exists") + # Add something about whether the execution time was created using RL + if self.config.tiramisu.env_type == "cpu": self.prog.initial_execution_time = self.progs_dict[ self.prog.name]["initial_execution_time"] - else: - if not self.config.environment.use_dataset: + elif self.config.tiramisu.env_type == "model": self.prog.initial_execution_time = 1.0 - self.progs_dict[self.prog.name] = {} - self.progs_dict[self.prog.name]["initial_execution_time"] = self.prog.initial_execution_time - - if not self.config.environment.use_dataset: - self.progs_dict[self.prog.name]["program_annotation"] = self.schedule_object.annotations except: print("RESET_ERROR_STDERR", @@ -236,7 +235,7 @@ def step(self, raw_action): speedup = 1.0 ray.get(self.shared_variable_actor.update_lc_data.remote( self.schedule_controller.get_legality_data())) - if "schedule" in self.progs_dict[self.prog.name]: + if "schedules_list" in self.progs_dict[self.prog.name]: self.schedule_object.schedule_dict["speedup"] = speedup self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str self.progs_dict[self.prog.name]["schedules_list"].append( diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 4ea628a..2fda3ef 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -457,17 +457,18 @@ def 
apply_action(self, action): done = True exit = True - if (not exit and lc_check != 0) and not (action.id in range( - 41, 44) and self.schedule_object.is_skewed): - self.schedule_object.sched_str = ScheduleUtils.sched_str( - self.schedule_object.sched_str, action.id, action_params, - self.schedule_object.comp_indic_dict) - # ray.util.pdb.set_trace() - if not action.id in range(41, 44): - self.schedule_object.it_dict = ScheduleUtils.update_iterators( - action.id, self.schedule_object.it_dict, action_params, - self.schedule_object.added_iterators, - self.schedule_object.comp_indic_dict) + if (not exit and lc_check != 0): + # Changed the sched_str to be updated after all successfull application of actions + self.schedule_object.sched_str = tmp_sched_str + if not (action.id in range(41, 44) and self.schedule_object.is_skewed): + # self.schedule_object.sched_str = ScheduleUtils.sched_str( + # self.schedule_object.sched_str, action.id, action_params, + # self.schedule_object.comp_indic_dict) + if not action.id in range(41, 44): + self.schedule_object.it_dict = ScheduleUtils.update_iterators( + action.id, self.schedule_object.it_dict, action_params, + self.schedule_object.added_iterators, + self.schedule_object.comp_indic_dict) self.depth += 1 return self.schedule_object.repr, 1.0, done, info diff --git a/train_ppo.py b/train_ppo.py index 4d44a7f..7afc97b 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -19,7 +19,7 @@ def get_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--num-workers", default=-1, type=int) - parser.add_argument("--use-dataset", default=False, type=bool) + parser.add_argument("--use-dataset", action=argparse.BooleanOptionalAction) parser.add_argument("--log-level", default="INFO", # TODO change back to WARN type=str, choices=list(logging._nameToLevel.keys())) return parser.parse_args() @@ -77,7 +77,8 @@ def main(config: RLAutoSchedulerConfig): args = get_arguments() if args.num_workers != -1: config.ray.num_workers = args.num_workers 
- config.environment.use_dataset = args.use_dataset + if args.use_dataset: + config.environment.use_dataset = args.use_dataset logging.basicConfig(level=logging._nameToLevel[args.log_level]) if args.num_workers == 1: with ray.init(): From 6c8b39be6b1a065ca4e18f1ad0ab170c70ca02b6 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Thu, 26 Jan 2023 09:26:45 +0400 Subject: [PATCH 07/27] fixed loading exec time from dataset and cleaned some code --- rl_interface/environment.py | 7 ++- tiramisu_programs/schedule_controller.py | 63 ++++++++---------------- 2 files changed, 27 insertions(+), 43 deletions(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index e3ee31a..1bb1756 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -2,6 +2,7 @@ import copy import logging import random +from socket import gethostname import sys import time import traceback @@ -102,6 +103,7 @@ def reset(self, file=None): print("\n----------Resetting the environment-----------\n") self.episode_total_time = time.time() + self.hostname = gethostname() while True: try: # Clean files of the previous function ran @@ -142,7 +144,10 @@ def reset(self, file=None): # Get the gym representation from the annotations self.obs = self.schedule_object.get_representation() - if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys(): + validInitTime = self.prog.json_representation and self.prog.json_representation['node_name'].startswith( + self.hostname[:2]) + + if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys() or not validInitTime: if self.config.tiramisu.env_type == "cpu": print("Getting the initial exe time by execution") self.prog.initial_execution_time = self.schedule_controller.measurement_env( diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 2fda3ef..a160b2c 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -1,4 
+1,5 @@ import copy +from socket import gethostname import sys import time import traceback @@ -32,7 +33,6 @@ def __init__(self, self.nb_executions = nb_executions self.speedup = 1.0 self.steps = 0 - self.new_scheds = {} self.search_time = time.time() self.config = config if self.config.tiramisu.env_type == "cpu": @@ -590,12 +590,13 @@ def test_additional_actions(self, training=True): else: self.schedule.remove(optim5) - self.new_scheds[self.schedule_object.prog.name].pop( - self.schedule_object.sched_str) + self.schedule_object.sched_str = self.schedule_object.sched_str.replace( parallelization_str, "") + self.schedule_object.schedule_dict[first_comp][ "parallelized_dim"] = None + print("X: Parallelization improves the performance") except: @@ -656,54 +657,32 @@ def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, return stat["predicted_execution_time"] def get_exec_time(self): + hostname = gethostname() prog_name = self.schedule_object.prog.name execution_time = 0 + + # Using dataset and the machine used to generate the data is the same as the current machine + validExecTimes = self.schedule_object.prog.json_representation and self.schedule_object.prog.json_representation['node_name'].startswith( + hostname[:2]) + if self.schedule_object.sched_str != "" and self.schedule != []: - if prog_name in self.scheds.keys(): - if self.schedule_object.sched_str in self.scheds[prog_name]: + # Using dataset and the machine used to generate the data is the same as the current machine + if validExecTimes: + # Look for the schedule + for tmp_schedule in self.schedule_object.prog.json_representation['schedules_list']: + if tmp_schedule['sched_str'] == self.schedule_object.sched_str: + execution_time = min(tmp_schedule['execution_times']) + break + # not using the dataset + else: + # if the program is in the list of programs ran and the schedule has been discovered + if prog_name in self.scheds.keys() and self.schedule_object.sched_str in self.scheds[prog_name]: 
execution_time = self.scheds[prog_name][ self.schedule_object.sched_str][0] else: - if prog_name in self.new_scheds.keys( - ) and self.schedule_object.sched_str in self.new_scheds[ - prog_name].keys(): - execution_time = self.new_scheds[prog_name][ - self.schedule_object.sched_str][1] - else: - curr_sched = copy.deepcopy(self.schedule) - self.new_scheds[prog_name] = {} - execution_time = self.measurement_env( - self.schedule, 'sched_eval', self.nb_executions, - self.schedule_object.prog.initial_execution_time) - self.new_scheds[prog_name][ - self.schedule_object.sched_str] = ( - curr_sched, execution_time, 0) - else: - if prog_name in self.new_scheds.keys(): - if self.schedule_object.sched_str in self.new_scheds[ - prog_name].keys(): - execution_time = self.new_scheds[prog_name][ - self.schedule_object.sched_str][1] - else: - curr_sched = copy.deepcopy(self.schedule) - execution_time = self.measurement_env( - self.schedule, 'sched_eval', self.nb_executions, - self.schedule_object.prog.initial_execution_time) - self.new_scheds[prog_name][ - self.schedule_object.sched_str] = ( - curr_sched, execution_time, 0) - else: - curr_sched = copy.deepcopy(self.schedule) - self.new_scheds[prog_name] = {} - start_time = time.time() execution_time = self.measurement_env( self.schedule, 'sched_eval', self.nb_executions, self.schedule_object.prog.initial_execution_time) - sched_time = time.time() - start_time - self.new_scheds[prog_name][ - self.schedule_object.sched_str] = (curr_sched, - execution_time, - 0) else: execution_time = self.schedule_object.prog.initial_execution_time return execution_time From 617b692c8b947c052b8c34d1317152f46903075e Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Thu, 26 Jan 2023 09:34:42 +0400 Subject: [PATCH 08/27] refactored checking the hostname in function for future changes --- rl_interface/environment.py | 7 +++---- tiramisu_programs/schedule_controller.py | 6 ++---- tiramisu_programs/schedule_utils.py | 6 ++++++ 3 files changed, 11 
insertions(+), 8 deletions(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 1bb1756..ae9b81b 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -2,7 +2,6 @@ import copy import logging import random -from socket import gethostname import sys import time import traceback @@ -13,6 +12,7 @@ import tiramisu_programs import rl_interface +from tiramisu_programs.schedule_utils import ScheduleUtils from utils.environment_variables import configure_env_variables np.seterr(invalid="raise") @@ -103,7 +103,6 @@ def reset(self, file=None): print("\n----------Resetting the environment-----------\n") self.episode_total_time = time.time() - self.hostname = gethostname() while True: try: # Clean files of the previous function ran @@ -144,8 +143,8 @@ def reset(self, file=None): # Get the gym representation from the annotations self.obs = self.schedule_object.get_representation() - validInitTime = self.prog.json_representation and self.prog.json_representation['node_name'].startswith( - self.hostname[:2]) + validInitTime = self.prog.json_representation and ScheduleUtils.is_same_machine_as_dataset( + self.prog) if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys() or not validInitTime: if self.config.tiramisu.env_type == "cpu": diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index a160b2c..ca89102 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -1,5 +1,4 @@ import copy -from socket import gethostname import sys import time import traceback @@ -657,13 +656,12 @@ def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, return stat["predicted_execution_time"] def get_exec_time(self): - hostname = gethostname() prog_name = self.schedule_object.prog.name execution_time = 0 # Using dataset and the machine used to generate the data is the same as the current machine - validExecTimes = 
self.schedule_object.prog.json_representation and self.schedule_object.prog.json_representation['node_name'].startswith( - hostname[:2]) + validExecTimes = self.schedule_object.prog.json_representation and ScheduleUtils.is_same_machine_as_dataset( + self.schedule_object.prog) if self.schedule_object.sched_str != "" and self.schedule != []: # Using dataset and the machine used to generate the data is the same as the current machine diff --git a/tiramisu_programs/schedule_utils.py b/tiramisu_programs/schedule_utils.py index 68c3486..f1d22e9 100644 --- a/tiramisu_programs/schedule_utils.py +++ b/tiramisu_programs/schedule_utils.py @@ -1,5 +1,6 @@ import json import re +from socket import gethostname from typing import List import numpy as np @@ -986,3 +987,8 @@ def optimlist_to_str(cls, optim_list): str(unrolling_factor) + ")" return sched_str + + @classmethod + def is_same_machine_as_dataset(cls, prog): + hostname = gethostname() + return prog.json_representation['node_name'].startswith(hostname[:2]) From 00ff9b6050ab9ec4f1053a5e40f4790cbebb57ce Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Thu, 26 Jan 2023 10:01:31 +0400 Subject: [PATCH 09/27] fixed invalidating exec time on init time --- rl_interface/environment.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index ae9b81b..64f5e33 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -143,10 +143,11 @@ def reset(self, file=None): # Get the gym representation from the annotations self.obs = self.schedule_object.get_representation() + # Check if the current machine is the one used to generate the data for this program validInitTime = self.prog.json_representation and ScheduleUtils.is_same_machine_as_dataset( self.prog) - if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys() or not validInitTime: + if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys(): if 
self.config.tiramisu.env_type == "cpu": print("Getting the initial exe time by execution") self.prog.initial_execution_time = self.schedule_controller.measurement_env( @@ -167,6 +168,13 @@ def reset(self, file=None): elif self.config.tiramisu.env_type == "model": self.prog.initial_execution_time = 1.0 + if self.config.environment.use_dataset and self.config.tiramisu.env_type == "cpu" and not validInitTime: + print("Inittial execution time invalidated") + print("\t -> Getting the initial exe time by execution") + self.prog.initial_execution_time = self.schedule_controller.measurement_env( + [], 'initial_exec', self.nb_executions, + self.prog.initial_execution_time) + except: print("RESET_ERROR_STDERR", traceback.format_exc(), file=sys.stderr) From 3422a620fb17db98f3923778a1eb125b7c230dc4 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Fri, 27 Jan 2023 16:24:40 +0400 Subject: [PATCH 10/27] load new version of dataset and save it --- rl_interface/environment.py | 45 ++++---- tiramisu_programs/schedule_controller.py | 138 +++++++++++++++-------- train_ppo.py | 10 ++ utils/global_ray_variables.py | 22 ++-- utils/rl_autoscheduler_config.py | 3 +- 5 files changed, 135 insertions(+), 83 deletions(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 64f5e33..fd7b7e9 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -143,10 +143,6 @@ def reset(self, file=None): # Get the gym representation from the annotations self.obs = self.schedule_object.get_representation() - # Check if the current machine is the one used to generate the data for this program - validInitTime = self.prog.json_representation and ScheduleUtils.is_same_machine_as_dataset( - self.prog) - if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys(): if self.config.tiramisu.env_type == "cpu": print("Getting the initial exe time by execution") @@ -168,13 +164,6 @@ def reset(self, file=None): elif self.config.tiramisu.env_type == "model": 
self.prog.initial_execution_time = 1.0 - if self.config.environment.use_dataset and self.config.tiramisu.env_type == "cpu" and not validInitTime: - print("Inittial execution time invalidated") - print("\t -> Getting the initial exe time by execution") - self.prog.initial_execution_time = self.schedule_controller.measurement_env( - [], 'initial_exec', self.nb_executions, - self.prog.initial_execution_time) - except: print("RESET_ERROR_STDERR", traceback.format_exc(), file=sys.stderr) @@ -245,25 +234,29 @@ def step(self, raw_action): speedup = self.schedule_controller.get_final_score() except: speedup = 1.0 - ray.get(self.shared_variable_actor.update_lc_data.remote( - self.schedule_controller.get_legality_data())) - if "schedules_list" in self.progs_dict[self.prog.name]: - self.schedule_object.schedule_dict["speedup"] = speedup - self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str - self.progs_dict[self.prog.name]["schedules_list"].append( - self.schedule_object.schedule_dict) - else: - self.schedule_object.schedule_dict["speedup"] = speedup - self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str - self.progs_dict[self.prog.name]["schedules_list"] = [ - self.schedule_object.schedule_dict] + # Update shared progs_dict with explored schedules' legality checks + ray.get(self.shared_variable_actor.update_progs_dict.remote( + self.prog.name, self.prog.json_representation)) + + if not self.config.environment.use_dataset: + if "schedules_list" in self.progs_dict[self.prog.name]: + self.schedule_object.schedule_dict["speedup"] = speedup + self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str + self.progs_dict[self.prog.name]["schedules_list"].append( + self.schedule_object.schedule_dict) + else: + self.schedule_object.schedule_dict["speedup"] = speedup + self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str + self.progs_dict[self.prog.name]["schedules_list"] = [ + 
self.schedule_object.schedule_dict] reward_object = rl_interface.Reward(speedup) reward = reward_object.reward print(f"Received a reward: {reward}") # Saving data if self.total_steps % self.SAVING_FREQUENCY: - ray.get(self.shared_variable_actor.write_lc_data.remote()) - rl_interface.utils.EnvironmentUtils.write_json_dataset( - f"worker_{self.id}.json", self.progs_dict) + ray.get(self.shared_variable_actor.write_progs_dict.remote()) + + # rl_interface.utils.EnvironmentUtils.write_json_dataset( + # f"worker_{self.id}.json", self.progs_dict) return self.obs, reward, done, info diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index ca89102..868f46d 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -1,4 +1,5 @@ import copy +import logging import sys import time import traceback @@ -75,11 +76,14 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset: - for sched_json in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_sched_str == sched_json['sched_str']: - saved_legality = 1 if sched_json['legality_check'] else None + if self.config.environment.use_dataset and is_schedule_saved: + print( + "Loading legality check from saved schedule") + saved_legality = self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] if self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( @@ -129,11 +133,14 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) + is_schedule_saved = tmp_sched_str in 
self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset: - for sched_json in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_sched_str == sched_json['sched_str']: - saved_legality = 1 if sched_json['legality_check'] else None + if self.config.environment.use_dataset and is_schedule_saved: + print( + "Loading legality check from saved schedule") + saved_legality = self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] if self.schedule_object.is_unrolled: lc_check = self.schedule_object.prog.check_legality_of_schedule( @@ -141,6 +148,12 @@ def apply_action(self, action): else: lc_check = self.schedule_object.prog.check_legality_of_schedule( self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + + # Save legality check + if self.config.environment.use_dataset and saved_legality is None: + self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] = lc_check + if lc_check == -1: print("X: This action produces an error") self.pop_schedule(action=action) @@ -193,11 +206,14 @@ def apply_action(self, action): self.schedule) print(tmp_sched_str) + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset: - for sched_json in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_sched_str == sched_json['sched_str']: - saved_legality = 1 if sched_json['legality_check'] else None + if self.config.environment.use_dataset and is_schedule_saved: + print( + "Loading legality check from saved schedule") + saved_legality = self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() 
lc_check = self.schedule_object.prog.check_legality_of_schedule( @@ -205,6 +221,11 @@ def apply_action(self, action): l_time = time.time() - start_time self.lc_total_time += l_time + # Save legality check + if self.config.environment.use_dataset and saved_legality is None: + self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] = lc_check + if lc_check == -1: print("X: This action produces an error") self.pop_schedule(action=action) @@ -267,11 +288,14 @@ def apply_action(self, action): self.schedule) print(tmp_sched_str) + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset: - for sched_json in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_sched_str == sched_json['sched_str']: - saved_legality = 1 if sched_json['legality_check'] else None + if self.config.environment.use_dataset and is_schedule_saved: + print( + "Loading legality check from saved schedule") + saved_legality = self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() if self.schedule_object.is_unrolled: @@ -284,6 +308,11 @@ def apply_action(self, action): l_time = time.time() - start_time self.lc_total_time += l_time + # Save legality check + if self.config.environment.use_dataset and saved_legality is None: + self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] = lc_check + if lc_check == -1: print("X: This action produces an error") self.pop_schedule(action=action) @@ -321,11 +350,14 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'] # check if we can find the schedule in the dataset 
load the legality check - if self.config.environment.use_dataset: - for sched_json in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_sched_str == sched_json['sched_str']: - saved_legality = 1 if sched_json['legality_check'] else None + if self.config.environment.use_dataset and is_schedule_saved: + print( + "Loading legality check from saved schedule") + saved_legality = self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() if self.schedule_object.is_unrolled: @@ -335,6 +367,11 @@ def apply_action(self, action): lc_check = self.schedule_object.prog.check_legality_of_schedule( self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + # Save legality check + if self.config.environment.use_dataset and saved_legality is None: + self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] = lc_check + l_time = time.time() - start_time self.lc_total_time += l_time if lc_check == -1: @@ -366,11 +403,15 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) + + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset: - for sched_json in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_sched_str == sched_json['sched_str']: - saved_legality = 1 if sched_json['legality_check'] else None + if self.config.environment.use_dataset and is_schedule_saved: + print( + "Loading legality check from saved schedule") + saved_legality = self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() if self.schedule_object.is_unrolled: @@ -381,6 +422,12 @@ def apply_action(self, action): self.schedule, first_comp=first_comp) if 
saved_legality is None else saved_legality l_time = time.time() - start_time self.lc_total_time += l_time + + # Save legality check + if self.config.environment.use_dataset and saved_legality is None: + self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] = lc_check + if lc_check == -1: print("X: This action produces am error") self.pop_schedule(action=action) @@ -415,11 +462,14 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset: - for sched_json in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_sched_str == sched_json['sched_str']: - saved_legality = 1 if sched_json['legality_check'] else None + if self.config.environment.use_dataset and is_schedule_saved: + print( + "Loading legality check from saved schedule") + saved_legality = self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() @@ -433,6 +483,11 @@ def apply_action(self, action): l_time = time.time() - start_time self.lc_total_time += l_time + # Save legality check + if self.config.environment.use_dataset and saved_legality is None: + self.schedule_object.prog.json_representation[ + 'schedules_legality_dict'][tmp_sched_str] = lc_check + if lc_check == -1: print("X: This action produces an error") self.pop_schedule(action=action) @@ -659,28 +714,15 @@ def get_exec_time(self): prog_name = self.schedule_object.prog.name execution_time = 0 - # Using dataset and the machine used to generate the data is the same as the current machine - validExecTimes = self.schedule_object.prog.json_representation and ScheduleUtils.is_same_machine_as_dataset( - self.schedule_object.prog) - if 
self.schedule_object.sched_str != "" and self.schedule != []: - # Using dataset and the machine used to generate the data is the same as the current machine - if validExecTimes: - # Look for the schedule - for tmp_schedule in self.schedule_object.prog.json_representation['schedules_list']: - if tmp_schedule['sched_str'] == self.schedule_object.sched_str: - execution_time = min(tmp_schedule['execution_times']) - break - # not using the dataset + # if the program is in the list of programs ran and the schedule has been discovered + if prog_name in self.scheds.keys() and self.schedule_object.sched_str in self.scheds[prog_name]: + execution_time = self.scheds[prog_name][ + self.schedule_object.sched_str][0] else: - # if the program is in the list of programs ran and the schedule has been discovered - if prog_name in self.scheds.keys() and self.schedule_object.sched_str in self.scheds[prog_name]: - execution_time = self.scheds[prog_name][ - self.schedule_object.sched_str][0] - else: - execution_time = self.measurement_env( - self.schedule, 'sched_eval', self.nb_executions, - self.schedule_object.prog.initial_execution_time) + execution_time = self.measurement_env( + self.schedule, 'sched_eval', self.nb_executions, + self.schedule_object.prog.initial_execution_time) else: execution_time = self.schedule_object.prog.initial_execution_time return execution_time diff --git a/train_ppo.py b/train_ppo.py index 7afc97b..ad8eb70 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -75,11 +75,21 @@ def main(config: RLAutoSchedulerConfig): parsed_yaml_dict = parse_yaml_file(read_yaml_file("config.yaml")) config = dict_to_config(parsed_yaml_dict) args = get_arguments() + if args.num_workers != -1: config.ray.num_workers = args.num_workers + if args.use_dataset: config.environment.use_dataset = args.use_dataset + + if config.tiramisu.env_type == 'cpu': + logging.warning( + "DATASET LEARNINING IS INCOMPATIBLE WITH CPU LEARNING. 
SWITCHING TO MODEL") + # Force model usage if using dataset + config.tiramisu.env_type = "model" + logging.basicConfig(level=logging._nameToLevel[args.log_level]) + logging.getLogger().setLevel(logging._nameToLevel[args.log_level]) if args.num_workers == 1: with ray.init(): main(config) diff --git a/utils/global_ray_variables.py b/utils/global_ray_variables.py index 56dbd05..6720dba 100644 --- a/utils/global_ray_variables.py +++ b/utils/global_ray_variables.py @@ -82,14 +82,20 @@ def write_lc_data(self): json.dump(self.lc_data, f) return True - def update_progs_dict(self, v): - self.progs_dict.update(v) + def update_progs_dict(self, function_name, json_legality_annotations): + self.progs_dict[function_name] = json_legality_annotations return True - def write_progs_dict(self): - print("Saving progs_dict to disk") - with open(self.programs_file, "w") as f: - json.dump(self.progs_dict, f) + def write_progs_dict(self, format="pkl"): + logging.info("Saving the legality_annotations_dict to disk") + + if format == "pkl": + with bz2.BZ2File(self.json_dataset['path_to_save_sataset'], 'wb') as f: + pickle.dump(self.progs_dict, f, + protocol=pickle.HIGHEST_PROTOCOL) + else: + with open(self.programs_file, "w") as f: + json.dump(self.progs_dict, f) return True def get_progs_dict(self): @@ -124,8 +130,8 @@ def get_progs_dict(self): def write_progs_dict(self): return ray.get(self.data_registry.write_progs_dict.remote()) - def update_progs_dict(self, v): - return ray.get(self.data_registry.update_progs_dict.remote(v)) + def update_progs_dict(self, function_name, json_legality_annotations): + return ray.get(self.data_registry.update_progs_dict.remote(function_name, json_legality_annotations)) def increment(self): return ray.get(self.data_registry.increment.remote()) diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 484542d..f5599f0 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -24,7 +24,8 @@ class 
EnvironmentConfig: clean_files: bool = True json_dataset: dict = field(default_factory=lambda: { "path": None, - "cpp_root": None + "cpp_root": None, + "path_to_save_sataset": None }) use_dataset: bool = False From 0300f248f5567a8b18fe95d3e133df7f5c33a240 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Fri, 3 Feb 2023 12:14:39 +0400 Subject: [PATCH 11/27] dataset learning with one function at a time --- rl_interface/environment.py | 99 +++++----------- tiramisu_programs/schedule_controller.py | 82 +++++--------- tiramisu_programs/schedule_utils.py | 2 +- tiramisu_programs/tiramisu_program.py | 10 +- train_ppo.py | 29 ++--- utils/dataset_utilities.py | 47 ++++++-- utils/global_ray_variables.py | 137 ----------------------- utils/rl_autoscheduler_config.py | 2 +- 8 files changed, 116 insertions(+), 292 deletions(-) delete mode 100644 utils/global_ray_variables.py diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 64f5e33..48f8c1e 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -22,9 +22,9 @@ class TiramisuScheduleEnvironment(gym.Env): ''' The reinforcement learning environment used by the GYM. 
''' - SAVING_FREQUENCY = 500 + SAVING_FREQUENCY = 50 - def __init__(self, config, shared_variable_actor): + def __init__(self, config, dataset_actor): print("Configuring the environment variables") configure_env_variables(config) @@ -38,34 +38,17 @@ def __init__(self, config, shared_variable_actor): self.progs_annot = {} self.programs_file = config.environment.programs_file self.measurement_env = None - self.dataset_path = config.environment.dataset_path + self.cpps_path = config.environment.dataset_path self.depth = 0 self.nb_executions = 5 self.episode_total_time = 0 self.prog_ind = 0 self.steps = 0 self.previous_cpp_file = None - self.shared_variable_actor = shared_variable_actor + self.dataset_actor = dataset_actor if config.environment.use_dataset: - self.dataset_path = config.environment.json_dataset['cpp_root'] - - self.id = ray.get(self.shared_variable_actor.increment.remote()) - - logging.info("worker getting its list of programs") - # List of function names - self.progs_list = ray.get( - self.shared_variable_actor.get_progs_list.remote(self.id)) - - # Dict of programs with their annotations, schedules, exectution times and traces - self.progs_dict = ray.get( - self.shared_variable_actor.get_progs_dict.remote()) - logging.info("Loaded the dataset!") - - # Dict of function with a Dict containing schedules in STR format and their execution time TODO is this used? 
- self.scheds = tiramisu_programs.schedule_utils.ScheduleUtils.get_schedules_str( - list(self.progs_dict.keys()), - self.progs_dict) # to use it to get the execution time + self.cpps_path = config.environment.json_dataset['cpps_path'] self.action_space = gym.spaces.Discrete(62) @@ -108,20 +91,22 @@ def reset(self, file=None): # Clean files of the previous function ran if self.config.environment.clean_files and self.previous_cpp_file: tiramisu_programs.cpp_file.CPP_File.clean_cpp_file( - self.dataset_path, self.previous_cpp_file) - # Choose a random program (function) - function_name = random.choice(self.progs_list) + self.cpps_path, self.previous_cpp_file) + + # get the next function + (function_name, function_dict) = ray.get( + self.dataset_actor.get_next_function.remote()) # Copy the function's files to the dataset copy created file = tiramisu_programs.cpp_file.CPP_File.get_cpp_file( - self.dataset_path, function_name) + self.cpps_path, function_name) # Set up the function files to be deleted on the next iteration self.previous_cpp_file = function_name # Load the tiramisu program from the file self.prog = tiramisu_programs.tiramisu_program.TiramisuProgram( - self.config, file, progs_dict=self.progs_dict) + self.config, file, function_dict) print(f"Trying with program {self.prog.name}") @@ -131,38 +116,18 @@ def reset(self, file=None): self.schedule_controller = tiramisu_programs.schedule_controller.ScheduleController( schedule=self.schedule_object, nb_executions=self.nb_executions, - scheds=self.scheds, config=self.config) - # Load the legality check list. 
Starts empty - lc_data = ray.get( - self.shared_variable_actor.get_lc_data.remote()) - - self.schedule_controller.load_legality_data(lc_data) - # Get the gym representation from the annotations self.obs = self.schedule_object.get_representation() - if self.progs_dict == {} or self.prog.name not in self.progs_dict.keys(): - if self.config.tiramisu.env_type == "cpu": - print("Getting the initial exe time by execution") - self.prog.initial_execution_time = self.schedule_controller.measurement_env( - [], 'initial_exec', self.nb_executions, - self.prog.initial_execution_time) - elif self.config.tiramisu.env_type == "model": - self.prog.initial_execution_time = 1.0 - self.progs_dict[self.prog.name] = {} - self.progs_dict[self.prog.name]["program_annotation"] = self.schedule_object.annotations - self.progs_dict[self.prog.name]["initial_execution_time"] = self.prog.initial_execution_time - - else: - print("The initial execution time exists") - # Add something about whether the execution time was created using RL - if self.config.tiramisu.env_type == "cpu": - self.prog.initial_execution_time = self.progs_dict[ - self.prog.name]["initial_execution_time"] - elif self.config.tiramisu.env_type == "model": - self.prog.initial_execution_time = 1.0 + if self.config.tiramisu.env_type == "cpu": + print("Getting the initial exe time by execution") + self.prog.initial_execution_time = self.schedule_controller.measurement_env( + [], 'initial_exec', self.nb_executions, + self.prog.initial_execution_time) + elif self.config.tiramisu.env_type == "model": + self.prog.initial_execution_time = 1.0 except: print("RESET_ERROR_STDERR", @@ -191,6 +156,7 @@ def step(self, raw_action): speedup = 1.0 self.steps += 1 self.total_steps += 1 + print(f"step:{self.total_steps}") try: action = rl_interface.Action(raw_action, @@ -234,29 +200,16 @@ def step(self, raw_action): speedup = self.schedule_controller.get_final_score() except: speedup = 1.0 - # Update shared progs_dict with explored schedules' 
legality checks - ray.get(self.shared_variable_actor.update_progs_dict.remote( - self.prog.name, self.prog.json_representation)) + # Update dataset with explored legality checks + self.dataset_actor.update_dataset.remote( + self.prog.name, self.prog.function_dict) - if not self.config.environment.use_dataset: - if "schedules_list" in self.progs_dict[self.prog.name]: - self.schedule_object.schedule_dict["speedup"] = speedup - self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str - self.progs_dict[self.prog.name]["schedules_list"].append( - self.schedule_object.schedule_dict) - else: - self.schedule_object.schedule_dict["speedup"] = speedup - self.schedule_object.schedule_dict["sched_str"] = self.schedule_object.sched_str - self.progs_dict[self.prog.name]["schedules_list"] = [ - self.schedule_object.schedule_dict] reward_object = rl_interface.Reward(speedup) reward = reward_object.reward print(f"Received a reward: {reward}") # Saving data - if self.total_steps % self.SAVING_FREQUENCY: - ray.get(self.shared_variable_actor.write_progs_dict.remote()) - - # rl_interface.utils.EnvironmentUtils.write_json_dataset( - # f"worker_{self.id}.json", self.progs_dict) + if not self.total_steps % self.SAVING_FREQUENCY: + self.dataset_actor.save_dataset_to_disk.remote( + self.config.environment.json_dataset['path_to_save_dataset'], format="pkl") return self.obs, reward, done, info diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 868f46d..906ac5a 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -24,12 +24,10 @@ class ScheduleController: def __init__(self, schedule: Schedule = None, nb_executions=5, - scheds=None, config=None): self.depth = 0 self.schedule = [] self.schedule_object = schedule - self.scheds = scheds self.nb_executions = nb_executions self.speedup = 1.0 self.steps = 0 @@ -40,7 +38,6 @@ def __init__(self, else: self.measurement_env = 
self.get_exec_time_by_model self.lc_total_time = 0 - self.lc_data = [] self.schedule_list_model = [] self.model = Model_Recursive_LSTM_v2() self.model.load_state_dict( @@ -54,7 +51,7 @@ def apply_action(self, action): info = {} self.steps += 1 first_comp = self.schedule_object.comps[0] - saved_legality = self.get_legality(action=action) + saved_legality = None if not action.id in range(44, 46): # If the action is skewing action_params = action.parameter() @@ -76,13 +73,13 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset and is_schedule_saved: print( "Loading legality check from saved schedule") - saved_legality = self.schedule_object.prog.json_representation[ + saved_legality = self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] if self.schedule_object.is_unrolled: @@ -92,6 +89,11 @@ def apply_action(self, action): lc_check = self.schedule_object.prog.check_legality_of_schedule( self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + # Save legality check + if self.config.environment.use_dataset and saved_legality is None: + self.schedule_object.prog.function_dict[ + 'schedules_legality_dict'][tmp_sched_str] = lc_check + if lc_check == -1: print("X: The action produced an error.") self.pop_schedule(action=action) @@ -133,13 +135,13 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ 'schedules_legality_dict'] 
# check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset and is_schedule_saved: print( "Loading legality check from saved schedule") - saved_legality = self.schedule_object.prog.json_representation[ + saved_legality = self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] if self.schedule_object.is_unrolled: @@ -151,7 +153,7 @@ def apply_action(self, action): # Save legality check if self.config.environment.use_dataset and saved_legality is None: - self.schedule_object.prog.json_representation[ + self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check if lc_check == -1: @@ -206,13 +208,13 @@ def apply_action(self, action): self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset and is_schedule_saved: print( "Loading legality check from saved schedule") - saved_legality = self.schedule_object.prog.json_representation[ + saved_legality = self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() @@ -223,7 +225,7 @@ def apply_action(self, action): # Save legality check if self.config.environment.use_dataset and saved_legality is None: - self.schedule_object.prog.json_representation[ + self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check if lc_check == -1: @@ -288,13 +290,13 @@ def apply_action(self, action): self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ 'schedules_legality_dict'] # check if we can find the schedule in the 
dataset load the legality check if self.config.environment.use_dataset and is_schedule_saved: print( "Loading legality check from saved schedule") - saved_legality = self.schedule_object.prog.json_representation[ + saved_legality = self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() @@ -310,7 +312,7 @@ def apply_action(self, action): # Save legality check if self.config.environment.use_dataset and saved_legality is None: - self.schedule_object.prog.json_representation[ + self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check if lc_check == -1: @@ -350,13 +352,13 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset and is_schedule_saved: print( "Loading legality check from saved schedule") - saved_legality = self.schedule_object.prog.json_representation[ + saved_legality = self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() @@ -369,7 +371,7 @@ def apply_action(self, action): # Save legality check if self.config.environment.use_dataset and saved_legality is None: - self.schedule_object.prog.json_representation[ + self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check l_time = time.time() - start_time @@ -404,13 +406,13 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ 
'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset and is_schedule_saved: print( "Loading legality check from saved schedule") - saved_legality = self.schedule_object.prog.json_representation[ + saved_legality = self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() @@ -425,7 +427,7 @@ def apply_action(self, action): # Save legality check if self.config.environment.use_dataset and saved_legality is None: - self.schedule_object.prog.json_representation[ + self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check if lc_check == -1: @@ -462,13 +464,13 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.json_representation[ + is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check if self.config.environment.use_dataset and is_schedule_saved: print( "Loading legality check from saved schedule") - saved_legality = self.schedule_object.prog.json_representation[ + saved_legality = self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] start_time = time.time() @@ -485,7 +487,7 @@ def apply_action(self, action): # Save legality check if self.config.environment.use_dataset and saved_legality is None: - self.schedule_object.prog.json_representation[ + self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check if lc_check == -1: @@ -711,38 +713,12 @@ def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, return stat["predicted_execution_time"] def get_exec_time(self): - prog_name = self.schedule_object.prog.name execution_time = 0 if self.schedule_object.sched_str != "" and 
self.schedule != []: - # if the program is in the list of programs ran and the schedule has been discovered - if prog_name in self.scheds.keys() and self.schedule_object.sched_str in self.scheds[prog_name]: - execution_time = self.scheds[prog_name][ - self.schedule_object.sched_str][0] - else: - execution_time = self.measurement_env( - self.schedule, 'sched_eval', self.nb_executions, - self.schedule_object.prog.initial_execution_time) + execution_time = self.measurement_env( + self.schedule, 'sched_eval', self.nb_executions, + self.schedule_object.prog.initial_execution_time) else: execution_time = self.schedule_object.prog.initial_execution_time return execution_time - - def save_legality_data(self, action, lc_check): - key = f"{self.schedule_object.prog.name}@{self.schedule_object.sched_str}@{action}" - self.lc_data.append( - [ - key, - lc_check - ] - ) - - def get_legality(self, action): - key = f"{self.schedule_object.prog.name}@{self.schedule_object.sched_str}@{action}" - values = [v for (k, v) in self.lc_data if k == key] - return values[0] if len(values) else None - - def get_legality_data(self): - return self.lc_data - - def load_legality_data(self, lc_data: List) -> None: - self.lc_data = lc_data diff --git a/tiramisu_programs/schedule_utils.py b/tiramisu_programs/schedule_utils.py index f1d22e9..7805af3 100644 --- a/tiramisu_programs/schedule_utils.py +++ b/tiramisu_programs/schedule_utils.py @@ -991,4 +991,4 @@ def optimlist_to_str(cls, optim_list): @classmethod def is_same_machine_as_dataset(cls, prog): hostname = gethostname() - return prog.json_representation['node_name'].startswith(hostname[:2]) + return prog.function_dict['node_name'].startswith(hostname[:2]) diff --git a/tiramisu_programs/tiramisu_program.py b/tiramisu_programs/tiramisu_program.py index 3acd2c8..d70057e 100644 --- a/tiramisu_programs/tiramisu_program.py +++ b/tiramisu_programs/tiramisu_program.py @@ -96,7 +96,7 @@ class TiramisuProgram(): return 0; }''' - def __init__(self, config, 
file_path, progs_dict=None): + def __init__(self, config, file_path, function_dict=None): self.config = config self.file_path = file_path with open(file_path, 'r') as f: @@ -132,16 +132,14 @@ def __init__(self, config, file_path, progs_dict=None): self.program_annotations = None self.wrapper_is_compiled = False self.initial_execution_time = 1.0 - self.json_representation = None - if config.environment.use_dataset: - self.json_representation = progs_dict[self.name] + self.function_dict = function_dict def get_program_annotations(self): if self.program_annotations is not None: return self.program_annotations - if self.config.environment.use_dataset: - self.program_annotations = self.json_representation['program_annotation'] + if self.function_dict: + self.program_annotations = self.function_dict['program_annotation'] else: # create a cpp file to get the annotations get_json_lines = ''' diff --git a/train_ppo.py b/train_ppo.py index ad8eb70..6f9f95b 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -10,7 +10,7 @@ from rl_interface.environment import TiramisuScheduleEnvironment from rl_interface.model import TiramisuModelMult -from utils.global_ray_variables import Actor, GlobalVarActor +from utils.dataset_utilities import DatasetAgent from utils.rl_autoscheduler_config import (RLAutoSchedulerConfig, dict_to_config, parse_yaml_file, read_yaml_file) @@ -28,15 +28,15 @@ def get_arguments(): # @hydra.main(config_path="config", config_name="config") def main(config: RLAutoSchedulerConfig): local_dir = os.path.join(config.ray.base_path, config.ray.log_directory) - progs_list_registery = GlobalVarActor.remote( - config.environment.programs_file, - config.environment.dataset_path, - num_workers=config.ray.num_workers, use_dataset=config.environment.use_dataset, json_dataset=config.environment.json_dataset) - shared_variable_actor = Actor.remote(progs_list_registery) + + dataset_path = config.environment.json_dataset[ + 'path'] if config.environment.use_dataset else 
config.environment.dataset_path + dataset_actor = DatasetAgent.remote( + dataset_path=dataset_path, use_dataset=config.environment.use_dataset) register_env( "Tiramisu_env_v1", - lambda a: TiramisuScheduleEnvironment(config, shared_variable_actor), + lambda a: TiramisuScheduleEnvironment(config, dataset_actor), ) ModelCatalog.register_custom_model("tiramisu_model_v1", TiramisuModelMult) @@ -53,6 +53,7 @@ def main(config: RLAutoSchedulerConfig): "env": "Tiramisu_env_v1", "num_workers": config.ray.num_workers, "placement_strategy": "SPREAD", + # "log_level": logging.INFO, "batch_mode": "complete_episodes", "train_batch_size": max(config.ray.num_workers * 200, config.training.train_batch_size), "sgd_minibatch_size": config.training.sgd_minibatch_size, @@ -89,10 +90,10 @@ def main(config: RLAutoSchedulerConfig): config.tiramisu.env_type = "model" logging.basicConfig(level=logging._nameToLevel[args.log_level]) - logging.getLogger().setLevel(logging._nameToLevel[args.log_level]) - if args.num_workers == 1: - with ray.init(): - main(config) - else: - with ray.init(address="auto"): - main(config) + # logging.getLogger().setLevel(logging._nameToLevel[args.log_level]) + # if args.num_workers == 1: + with ray.init(): + main(config) + # else: + # with ray.init(address="auto"): + # main(config) diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index 9f5e60f..dc47657 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -1,20 +1,53 @@ +import bz2 import json +import logging import os +import pickle import random - +import numpy as np import ray.data +@ray.remote class DatasetAgent: - def __init__(self, dataset_path, shuffle=True): + def __init__(self, dataset_path, use_dataset=False, shuffle=False): + self.dataset_path = dataset_path + self.use_dataset = use_dataset + self.dataset = {} + self.function_names = [] self.shuffle = shuffle - if os.path.isfile(dataset_path): - self.dataset = ray.data.read_json(dataset_path) - 
self.function_names = self.dataset.keys() + if use_dataset: + logging.info(f"reading dataset from json at:{dataset_path}") + with bz2.BZ2File(dataset_path, 'rb') as f: + self.dataset = pickle.load(f) + self.function_names = list(self.dataset.keys()) + logging.info( + f"[Done] reading dataset from json at:{dataset_path}") + + else: + os.getcwd() + logging.info(f"reading data from ls at: {os.getcwd()}") + self.function_names = os.listdir(dataset_path) if self.shuffle: random.shuffle(self.function_names) def get_next_function(self): - for function in self.function_names: - yield function, self.dataset[function] + function_name = np.random.choice(self.function_names) + return function_name, self.dataset[function_name] + + def update_dataset(self, function_name, function_dict): + self.dataset[function_name] = function_dict + + def save_dataset_to_disk(self, path, format): + logging.info("[Start] Save the legality_annotations_dict to disk") + + if format == "pkl": + with bz2.BZ2File(path, 'wb') as f: + pickle.dump(self.dataset, f, + protocol=pickle.HIGHEST_PROTOCOL) + else: + with open(path, "w") as f: + json.dump(self.dataset, f) + logging.info("[Done] Save the legality_annotations_dict to disk") + return True diff --git a/utils/global_ray_variables.py b/utils/global_ray_variables.py deleted file mode 100644 index 6720dba..0000000 --- a/utils/global_ray_variables.py +++ /dev/null @@ -1,137 +0,0 @@ -import bz2 -import logging -import pickle -from typing import List -import ray -import json -import os - - -@ray.remote -class GlobalVarActor: - - def __init__(self, programs_file, dataset_path, num_workers=7, use_dataset=False, json_dataset=None): - self.index = -1 - self.num_workers = num_workers - self.progs_list = [] - self.programs_file = programs_file - self.progs_dict = dict() - self.lc_data = [] - self.json_dataset = json_dataset - - self.get_dataset( - dataset_path, use_dataset, json_dataset_path=json_dataset["path"]) - # if os.path.isfile(programs_file): - # try: - 
# with open(programs_file) as f: - # self.progs_dict = json.load(f) - # except: - # self.progs_dict = dict() - # else: - # self.progs_dict = dict() - # with open(programs_file,"w+") as f: - # f.write(json.dumps(self.progs_dict)) - - if os.path.isfile("lc_data.json"): - try: - with open("lc_data.json") as f: - self.lc_data = json.load(f) - except: - self.lc_data = [] - else: - self.lc_data = [] - with open("lc_data.json", "w+") as f: - f.write(json.dumps(self.lc_data)) - - # Load the dataset of programs - def get_dataset(self, path, use_dataset=False, json_dataset_path=None): - if use_dataset: - logging.info(f"reading dataset from json at:{json_dataset_path}") - with bz2.BZ2File(json_dataset_path, 'rb') as f: - self.progs_dict = pickle.load(f) - self.progs_list = list(self.progs_dict.keys()) - logging.info( - f"[Done] reading dataset from json at:{json_dataset_path}") - - else: - os.getcwd() - logging.info(f"reading dataset from ls at: {os.getcwd()}") - self.progs_list = os.listdir(path) - - def set_progs_list(self, v): - self.progs_list = v - return True - - # Get programs of a worker by its id - def get_progs_list(self, id): - return [ - item for (index, item) in enumerate(self.progs_list) - if (index % self.num_workers) == (id % self.num_workers) - ] - - def update_lc_data(self, v: List): - self.lc_data.extend(v) - return True - - def get_lc_data(self) -> List: - return self.lc_data - - def write_lc_data(self): - print("Saving lc_data to disk") - with open("lc_data.json", "w") as f: - json.dump(self.lc_data, f) - return True - - def update_progs_dict(self, function_name, json_legality_annotations): - self.progs_dict[function_name] = json_legality_annotations - return True - - def write_progs_dict(self, format="pkl"): - logging.info("Saving the legality_annotations_dict to disk") - - if format == "pkl": - with bz2.BZ2File(self.json_dataset['path_to_save_sataset'], 'wb') as f: - pickle.dump(self.progs_dict, f, - protocol=pickle.HIGHEST_PROTOCOL) - else: - with 
open(self.programs_file, "w") as f: - json.dump(self.progs_dict, f) - return True - - def get_progs_dict(self): - return self.progs_dict - - def increment(self): - self.index += 1 - return self.index - - -@ray.remote -class Actor: - - def __init__(self, data_registry): - self.data_registry = data_registry - - def get_progs_list(self, id): - return ray.get(self.data_registry.get_progs_list.remote(id)) - - def update_lc_data(self, v: List): - return ray.get(self.data_registry.update_lc_data.remote(v)) - - def get_lc_data(self) -> List: - return ray.get(self.data_registry.get_lc_data.remote()) - - def write_lc_data(self): - return ray.get(self.data_registry.write_lc_data.remote()) - - def get_progs_dict(self): - return ray.get(self.data_registry.get_progs_dict.remote()) - - def write_progs_dict(self): - return ray.get(self.data_registry.write_progs_dict.remote()) - - def update_progs_dict(self, function_name, json_legality_annotations): - return ray.get(self.data_registry.update_progs_dict.remote(function_name, json_legality_annotations)) - - def increment(self): - return ray.get(self.data_registry.increment.remote()) diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index f5599f0..727c466 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -24,7 +24,7 @@ class EnvironmentConfig: clean_files: bool = True json_dataset: dict = field(default_factory=lambda: { "path": None, - "cpp_root": None, + "cpps_path": None, "path_to_save_sataset": None }) use_dataset: bool = False From 0455369725562e8ee94cbd60155bac0113e325c4 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Fri, 3 Feb 2023 14:15:58 +0400 Subject: [PATCH 12/27] removed import bug --- rl_interface/environment.py | 2 +- utils/__init__.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 48f8c1e..36696cb 100644 --- a/rl_interface/environment.py +++ 
b/rl_interface/environment.py @@ -22,7 +22,7 @@ class TiramisuScheduleEnvironment(gym.Env): ''' The reinforcement learning environment used by the GYM. ''' - SAVING_FREQUENCY = 50 + SAVING_FREQUENCY = 500 def __init__(self, config, dataset_actor): print("Configuring the environment variables") diff --git a/utils/__init__.py b/utils/__init__.py index 0fc0bf6..29d52c7 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,3 +1,2 @@ from .environment_variables import * from .rl_autoscheduler_config import * -from .global_ray_variables import * From d5efeca82bab4183e7f160431c01cb9602894596 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Fri, 3 Feb 2023 16:21:29 +0400 Subject: [PATCH 13/27] data saviung multiple formats --- rl_interface/environment.py | 6 --- train_ppo.py | 2 +- utils/dataset_utilities.py | 73 +++++++++++++++++++++++--------- utils/rl_autoscheduler_config.py | 5 ++- 4 files changed, 58 insertions(+), 28 deletions(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 36696cb..ecb4264 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -156,7 +156,6 @@ def step(self, raw_action): speedup = 1.0 self.steps += 1 self.total_steps += 1 - print(f"step:{self.total_steps}") try: action = rl_interface.Action(raw_action, @@ -207,9 +206,4 @@ def step(self, raw_action): reward_object = rl_interface.Reward(speedup) reward = reward_object.reward print(f"Received a reward: {reward}") - - # Saving data - if not self.total_steps % self.SAVING_FREQUENCY: - self.dataset_actor.save_dataset_to_disk.remote( - self.config.environment.json_dataset['path_to_save_dataset'], format="pkl") return self.obs, reward, done, info diff --git a/train_ppo.py b/train_ppo.py index 6f9f95b..0ed125b 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -32,7 +32,7 @@ def main(config: RLAutoSchedulerConfig): dataset_path = config.environment.json_dataset[ 'path'] if config.environment.use_dataset else config.environment.dataset_path 
dataset_actor = DatasetAgent.remote( - dataset_path=dataset_path, use_dataset=config.environment.use_dataset) + dataset_path=dataset_path, use_dataset=config.environment.use_dataset, path_to_save_dataset=config.environment.json_dataset['path_to_save_dataset'], dataset_format=config.environment.json_dataset['dataset_format']) register_env( "Tiramisu_env_v1", diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index dc47657..87895d2 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -1,32 +1,54 @@ import bz2 import json -import logging import os import pickle import random import numpy as np -import ray.data + +SAVING_FREQUENCY = 10 + + +class DataSetFormat(): + PICKLE = "PICKLE" + JSON = "JSON" + BZ2 = "BZ2" @ray.remote class DatasetAgent: - def __init__(self, dataset_path, use_dataset=False, shuffle=False): + def __init__(self, dataset_path, path_to_save_dataset, dataset_format, use_dataset=False, shuffle=False): self.dataset_path = dataset_path + self.path_to_save_dataset = path_to_save_dataset + self.dataset_format = dataset_format self.use_dataset = use_dataset + self.shuffle = shuffle self.dataset = {} self.function_names = [] - self.shuffle = shuffle + self.nbr_updates = 0 + if use_dataset: - logging.info(f"reading dataset from json at:{dataset_path}") - with bz2.BZ2File(dataset_path, 'rb') as f: - self.dataset = pickle.load(f) - self.function_names = list(self.dataset.keys()) - logging.info( + print(f"reading dataset from json at:{dataset_path}") + match dataset_format: + case DataSetFormat.PICKLE: + with open(dataset_path, 'rb') as f: + self.dataset = pickle.load(f) + self.function_names = list(self.dataset.keys()) + case DataSetFormat.JSON: + with open(dataset_path, 'rb') as f: + self.dataset = json.load(f) + self.function_names = list(self.dataset.keys()) + case DataSetFormat.BZ2: + with bz2.BZ2File(dataset_path, 'rb') as f: + self.dataset = pickle.load(f) + self.function_names = list(self.dataset.keys()) + case _: 
+ raise ValueError("Format specified not supported") + print( f"[Done] reading dataset from json at:{dataset_path}") else: os.getcwd() - logging.info(f"reading data from ls at: {os.getcwd()}") + print(f"reading data from ls at: {os.getcwd()}") self.function_names = os.listdir(dataset_path) if self.shuffle: @@ -38,16 +60,27 @@ def get_next_function(self): def update_dataset(self, function_name, function_dict): self.dataset[function_name] = function_dict + self.nbr_updates += 1 + print(f"# updates: {self.nbr_updates}") + if self.nbr_updates % SAVING_FREQUENCY == 0: + self.save_dataset_to_disk() - def save_dataset_to_disk(self, path, format): - logging.info("[Start] Save the legality_annotations_dict to disk") + def save_dataset_to_disk(self): + print("[Start] Save the legality_annotations_dict to disk") - if format == "pkl": - with bz2.BZ2File(path, 'wb') as f: - pickle.dump(self.dataset, f, - protocol=pickle.HIGHEST_PROTOCOL) - else: - with open(path, "w") as f: - json.dump(self.dataset, f) - logging.info("[Done] Save the legality_annotations_dict to disk") + match self.dataset_format: + case DataSetFormat.PICKLE: + with open(f"{self.path_to_save_dataset}.pkl", "wb") as f: + pickle.dump(self.dataset, f, + protocol=pickle.HIGHEST_PROTOCOL) + case DataSetFormat.JSON: + with open(f"{self.path_to_save_dataset}.json", "w") as f: + json.dump(self.dataset, f) + case DataSetFormat.BZ2: + with bz2.BZ2File(f"{self.path_to_save_dataset}.bz2.pkl", 'wb') as f: + pickle.dump(self.dataset, f, + protocol=pickle.HIGHEST_PROTOCOL) + case _: + raise ValueError("Format specified not supported") + print("[Done] Save the legality_annotations_dict to disk") return True diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 727c466..b02957a 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -3,6 +3,8 @@ import yaml +from utils.dataset_utilities import DataSetFormat + USE_WANDB = False @@ -25,7 +27,8 @@ class 
EnvironmentConfig: json_dataset: dict = field(default_factory=lambda: { "path": None, "cpps_path": None, - "path_to_save_sataset": None + "path_to_save_sataset": None, + "dataset_format": DataSetFormat.PICKLE }) use_dataset: bool = False From 5e1f766d0ae21d3c2937bd9916dbf5bbc1c2f037 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Fri, 3 Feb 2023 16:22:03 +0400 Subject: [PATCH 14/27] data saving multiple formats --- utils/dataset_utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index 87895d2..4f9404e 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -4,6 +4,7 @@ import pickle import random import numpy as np +import ray SAVING_FREQUENCY = 10 From c4c162f5566e51da744b1dc853314ddbed9fb191 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Mon, 6 Feb 2023 12:14:44 +0400 Subject: [PATCH 15/27] added solver results to the dataset, fixed circular imports and remove import * --- rl_interface/__init__.py | 5 --- rl_interface/action.py | 51 ++++++++++++++++++++---- rl_interface/environment.py | 29 ++++++++------ tiramisu_programs/__init__.py | 18 --------- tiramisu_programs/schedule_controller.py | 44 +++++++++----------- tiramisu_programs/tiramisu_program.py | 30 +++++++------- train_ppo.py | 13 +++--- utils/__init__.py | 2 - utils/dataset_utilities.py | 7 +++- utils/rl_autoscheduler_config.py | 8 ++-- 10 files changed, 109 insertions(+), 98 deletions(-) diff --git a/rl_interface/__init__.py b/rl_interface/__init__.py index 9aa9ca9..e69de29 100644 --- a/rl_interface/__init__.py +++ b/rl_interface/__init__.py @@ -1,5 +0,0 @@ -from .action import * -from .environment import * -from .model import * -from .reward import * -from .utils import * \ No newline at end of file diff --git a/rl_interface/action.py b/rl_interface/action.py index 1dce1b4..30cf233 100644 --- a/rl_interface/action.py +++ b/rl_interface/action.py @@ -1,3 +1,9 @@ +from tiramisu_programs.optimization import
OptimizationCommand +from tiramisu_programs.schedule_utils import ScheduleUtils +from tiramisu_programs.tiramisu_program import TiramisuProgram +import ray + + class Action: """ " Action class to store and standardize the action for the environment. @@ -101,7 +107,7 @@ def __init__(self, id_, it_dict, common_it): self.it_dict = it_dict self.common_it = common_it - def parameter(self, comp=None, prog=None): + def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[OptimizationCommand] = None): """" Property method to return the parameter related to the action selected. Returns: @@ -373,6 +379,7 @@ def parameter(self, comp=None, prog=None): return params elif self.id == 44: # SKEWING01 + solver_res = None first_it = 0 second_it = 1 @@ -381,10 +388,25 @@ def parameter(self, comp=None, prog=None): "second_dim_index": second_it } - # print("before calling solver") + # Load saved results if they exist + if prog.config.environment.use_dataset: + tmp_sched_str = ScheduleUtils.optimlist_to_str(schedule) + + # Check if schedule is saved + if tmp_sched_str in prog.function_dict[ + 'schedules_solver_results_dict']: + print( + f"Loading solver results from saved schedule: {tmp_sched_str}") + solver_res = prog.function_dict[ + 'schedules_solver_results_dict'][tmp_sched_str] - solver_res = prog.call_solver(comp, skew_params) - # print("afetr calling solver") + if solver_res is None: + solver_res = prog.call_solver(comp, skew_params) + + # Save the new solver results + if prog.config.environment.use_dataset: + prog.function_dict[ + 'schedules_solver_results_dict'][tmp_sched_str] = solver_res if solver_res == None or solver_res == "-1": return { @@ -403,6 +425,7 @@ def parameter(self, comp=None, prog=None): } elif self.id == 45: # SKEWING12 + solver_res = None first_it = 1 second_it = 2 @@ -411,10 +434,24 @@ def parameter(self, comp=None, prog=None): "second_dim_index": second_it } - # print("before calling solver") + # Load saved results if they exist + if 
prog.config.environment.use_dataset: + tmp_sched_str = ScheduleUtils.optimlist_to_str(schedule) + + # Check if schedule is saved + if tmp_sched_str in prog.function_dict[ + 'schedules_solver_results_dict']: + print( + f"Loading solver results from saved schedule: {tmp_sched_str}") + solver_res = prog.function_dict[ + 'schedules_solver_results_dict'][tmp_sched_str] + + if solver_res is None: + solver_res = prog.call_solver(comp, skew_params) - solver_res = prog.call_solver(comp, skew_params) - # print("afetr calling solver") + # Save the new solver results + if prog.config.environment.use_dataset: + prog.function_dict['schedules_solver_results_dict'][tmp_sched_str] = solver_res if solver_res == None or solver_res == "-1": return { diff --git a/rl_interface/environment.py b/rl_interface/environment.py index ecb4264..270b6be 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -9,11 +9,14 @@ import gym import numpy as np import ray -import tiramisu_programs - -import rl_interface -from tiramisu_programs.schedule_utils import ScheduleUtils +from rl_interface.action import Action +from rl_interface.reward import Reward +from tiramisu_programs.cpp_file import CPP_File +from tiramisu_programs.tiramisu_program import TiramisuProgram +from tiramisu_programs.schedule import Schedule +from tiramisu_programs.schedule_controller import ScheduleController from utils.environment_variables import configure_env_variables +from utils.rl_autoscheduler_config import RLAutoSchedulerConfig np.seterr(invalid="raise") @@ -24,7 +27,7 @@ class TiramisuScheduleEnvironment(gym.Env): ''' SAVING_FREQUENCY = 500 - def __init__(self, config, dataset_actor): + def __init__(self, config: RLAutoSchedulerConfig, dataset_actor): print("Configuring the environment variables") configure_env_variables(config) @@ -90,7 +93,7 @@ def reset(self, file=None): try: # Clean files of the previous function ran if self.config.environment.clean_files and self.previous_cpp_file: - 
tiramisu_programs.cpp_file.CPP_File.clean_cpp_file( + CPP_File.clean_cpp_file( self.cpps_path, self.previous_cpp_file) # get the next function @@ -98,22 +101,22 @@ def reset(self, file=None): self.dataset_actor.get_next_function.remote()) # Copy the function's files to the dataset copy created - file = tiramisu_programs.cpp_file.CPP_File.get_cpp_file( + file = CPP_File.get_cpp_file( self.cpps_path, function_name) # Set up the function files to be deleted on the next iteration self.previous_cpp_file = function_name # Load the tiramisu program from the file - self.prog = tiramisu_programs.tiramisu_program.TiramisuProgram( + self.prog = TiramisuProgram( self.config, file, function_dict) print(f"Trying with program {self.prog.name}") - self.schedule_object = tiramisu_programs.schedule.Schedule( + self.schedule_object = Schedule( self.prog) - self.schedule_controller = tiramisu_programs.schedule_controller.ScheduleController( + self.schedule_controller = ScheduleController( schedule=self.schedule_object, nb_executions=self.nb_executions, config=self.config) @@ -147,7 +150,7 @@ def step(self, raw_action): Apply a transformation on a program. If the action raw_action is legal, it is applied. If not, it is ignored and not added to the schedule. Returns: The current state after eventually applying the transformation, and the reward that the agent received for taking the action. 
""" - action_name = rl_interface.Action.ACTIONS_ARRAY[raw_action] + action_name = Action.ACTIONS_ARRAY[raw_action] print("\n ----> {} [ {} ] \n".format( action_name, self.schedule_object.sched_str)) info = {} @@ -158,7 +161,7 @@ def step(self, raw_action): self.total_steps += 1 try: - action = rl_interface.Action(raw_action, + action = Action(raw_action, self.schedule_object.it_dict, self.schedule_object.common_it) _, speedup, done, info = self.schedule_controller.apply_action( @@ -203,7 +206,7 @@ def step(self, raw_action): self.dataset_actor.update_dataset.remote( self.prog.name, self.prog.function_dict) - reward_object = rl_interface.Reward(speedup) + reward_object = Reward(speedup) reward = reward_object.reward print(f"Received a reward: {reward}") return self.obs, reward, done, info diff --git a/tiramisu_programs/__init__.py b/tiramisu_programs/__init__.py index 7015e42..e69de29 100644 --- a/tiramisu_programs/__init__.py +++ b/tiramisu_programs/__init__.py @@ -1,18 +0,0 @@ -from .cpp_file import * -from .optimization import * -from .schedule import * -from .schedule_utils import * -from .schedule_controller import * -from .tiramisu_program import * -from .surrogate_model_utils.json_to_tensor import * -from .surrogate_model_utils.modeling import * - -__all__ = [ - "CPP_File", "OptimizationCommand", "ScheduleController", - "LargeAccessMatices", "NbAccessException", "LoopsDepthException", - "TimeOutException", "LoopExtentException", "RepresentationLengthException", - "NumpyEncoder", "LCException", "SkewParamsException", "IsTiledException", - "IsInterchangedException", "IsSkewedException", "IsUnrolledException", - "IsParallelizedException", "IsReversedException", "SkewUnrollException", - "ScheduleUtils", "Schedule", "TiramisuProgram", "InternalExecException" -] \ No newline at end of file diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 906ac5a..38f215f 100644 --- a/tiramisu_programs/schedule_controller.py +++ 
b/tiramisu_programs/schedule_controller.py @@ -1,16 +1,14 @@ -import copy -import logging + import sys import time import traceback -from typing import List import torch from rl_interface.action import Action from tiramisu_programs.optimization import OptimizationCommand from tiramisu_programs.schedule import Schedule -from tiramisu_programs.schedule_utils import * +from tiramisu_programs.schedule_utils import ScheduleUtils, IsInterchangedException, IsParallelizedException, IsReversedException, IsSkewedException, IsTiledException, IsUnrolledException, SkewParamsException, SkewUnrollException, LCException from tiramisu_programs.surrogate_model_utils.json_to_tensor import \ get_schedule_representation from tiramisu_programs.surrogate_model_utils.modeling import \ @@ -57,7 +55,8 @@ def apply_action(self, action): action_params = action.parameter() else: comp = list(self.schedule_object.it_dict.keys())[0] - action_params = action.parameter(comp, self.schedule_object.prog) + action_params = action.parameter( + comp, self.schedule_object.prog, self.schedule) if action.id in range(28): # Interchange if not self.schedule_object.is_interchaged: @@ -73,10 +72,9 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and is_schedule_saved: + if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ + 'schedules_legality_dict']: print( "Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ @@ -135,10 +133,9 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ - 
'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and is_schedule_saved: + if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ + 'schedules_legality_dict']: print( "Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ @@ -208,10 +205,9 @@ def apply_action(self, action): self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and is_schedule_saved: + if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ + 'schedules_legality_dict']: print( "Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ @@ -290,10 +286,9 @@ def apply_action(self, action): self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and is_schedule_saved: + if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ + 'schedules_legality_dict']: print( "Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ @@ -352,10 +347,9 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and is_schedule_saved: + if self.config.environment.use_dataset 
and tmp_sched_str in self.schedule_object.prog.function_dict[ + 'schedules_legality_dict']: print( "Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ @@ -406,10 +400,9 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and is_schedule_saved: + if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ + 'schedules_legality_dict']: print( "Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ @@ -464,10 +457,9 @@ def apply_action(self, action): tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) - is_schedule_saved = tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'] # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and is_schedule_saved: + if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ + 'schedules_legality_dict']: print( "Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ diff --git a/tiramisu_programs/tiramisu_program.py b/tiramisu_programs/tiramisu_program.py index d70057e..60eea09 100644 --- a/tiramisu_programs/tiramisu_program.py +++ b/tiramisu_programs/tiramisu_program.py @@ -7,7 +7,9 @@ import ray -import tiramisu_programs +from .cpp_file import CPP_File +from .schedule import TimeOutException +from utils.rl_autoscheduler_config import RLAutoSchedulerConfig class InternalExecException(Exception): @@ -96,7 +98,7 @@ class TiramisuProgram(): return 0; }''' - def __init__(self, config, file_path, 
function_dict=None): + def __init__(self, config: RLAutoSchedulerConfig, file_path, function_dict=None): self.config = config self.file_path = file_path with open(file_path, 'r') as f: @@ -157,7 +159,7 @@ def get_program_annotations(self): f.write(get_json_prog) # compile the cpp file and run to generate annotations in json file - tiramisu_programs.CPP_File.compile_and_run_tiramisu_code( + CPP_File.compile_and_run_tiramisu_code( self.config, output_file, 'Generating program annotations') # Read the json file and return the annotations @@ -221,7 +223,7 @@ def check_legality_of_schedule( self.reset_legality_check_result_file() log_message = 'Checking legality for: ' + ' '.join( [o.tiramisu_optim_str for o in optims_list]) - tiramisu_programs.CPP_File.compile_and_run_tiramisu_code( + CPP_File.compile_and_run_tiramisu_code( self.config, output_file, log_message) lc_result = self.read_legality_check_result_file() @@ -278,7 +280,7 @@ def call_solver(self, comp, params): log_message = 'Solver results for: computation {}'.format( comp) + ' '.join([p for p in params]) - if tiramisu_programs.CPP_File.compile_and_run_tiramisu_code( + if CPP_File.compile_and_run_tiramisu_code( self.config, output_file, log_message): solver_result = self.read_solver_result_file() if len(solver_result) == 0: @@ -319,7 +321,7 @@ def evaluate_schedule(self, log_message = 'Applying schedule: ' + ' '.join( [o.tiramisu_optim_str for o in optims_list]) start_time = time.time() - if (tiramisu_programs.CPP_File.compile_and_run_tiramisu_code( + if (CPP_File.compile_and_run_tiramisu_code( self.config, output_file, log_message)): try: execution_times = self.get_measurements( @@ -328,7 +330,7 @@ def evaluate_schedule(self, return min(execution_times) else: return 0 - except tiramisu_programs.schedule.TimeOutException: + except TimeOutException: print("time out exception") return 10 * nb_executions * (initial_exec_time if initial_exec_time else 1.0) @@ -344,8 +346,8 @@ def get_measurements(self, cmd_type, 
nb_executions, initial_exec_time): if not self.wrapper_is_compiled: self.write_wrapper_code() log_message_cmd = 'printf "Compiling wrapper\n">> ${FUNC_DIR}log.txt' - tiramisu_programs.CPP_File.launch_cmd(log_message_cmd, '') - failed = tiramisu_programs.CPP_File.launch_cmd( + CPP_File.launch_cmd(log_message_cmd, '') + failed = CPP_File.launch_cmd( self.config.tiramisu.compile_wrapper_cmd, self.file_path) if failed: print('Failed compiling wrapper') @@ -357,12 +359,12 @@ def get_measurements(self, cmd_type, nb_executions, initial_exec_time): run_wrapper_cmd = 'cd ${FUNC_DIR};\ ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ ./${FUNC_NAME}_wrapper ' + str(nb_executions) - tiramisu_programs.CPP_File.launch_cmd(log_message_cmd, '') + CPP_File.launch_cmd(log_message_cmd, '') s_time = time.time() - failed = tiramisu_programs.CPP_File.launch_cmd(run_wrapper_cmd, - self.file_path, - cmd_type, nb_executions, - initial_exec_time) + failed = CPP_File.launch_cmd(run_wrapper_cmd, + self.file_path, + cmd_type, nb_executions, + initial_exec_time) if failed: print('Failed running wrapper') diff --git a/train_ppo.py b/train_ppo.py index 0ed125b..a2165ae 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -90,10 +90,9 @@ def main(config: RLAutoSchedulerConfig): config.tiramisu.env_type = "model" logging.basicConfig(level=logging._nameToLevel[args.log_level]) - # logging.getLogger().setLevel(logging._nameToLevel[args.log_level]) - # if args.num_workers == 1: - with ray.init(): - main(config) - # else: - # with ray.init(address="auto"): - # main(config) + if args.num_workers == 1: + with ray.init(): + main(config) + else: + with ray.init(address="auto"): + main(config) diff --git a/utils/__init__.py b/utils/__init__.py index 29d52c7..e69de29 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,2 +0,0 @@ -from .environment_variables import * -from .rl_autoscheduler_config import * diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index 4f9404e..b2ab848 
100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -6,7 +6,7 @@ import numpy as np import ray -SAVING_FREQUENCY = 10 +SAVING_FREQUENCY = 100 class DataSetFormat(): @@ -57,7 +57,10 @@ def __init__(self, dataset_path, path_to_save_dataset, dataset_format, use_datas def get_next_function(self): function_name = np.random.choice(self.function_names) - return function_name, self.dataset[function_name] + if self.use_dataset: + return function_name, self.dataset[function_name] + else: + return function_name, None def update_dataset(self, function_name, function_dict): self.dataset[function_name] = function_dict diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index b02957a..7dce5e3 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -39,15 +39,15 @@ class TiramisuConfig: env_type: Literal["model", "cpu"] = "cpu" model_checkpoint: str = "/data/scratch/hbenyamina/model_published_nn_finale.pt" compile_tiramisu_cmd: str = 'printf "Compiling ${FILE_PATH}\n" >> ${FUNC_DIR}log.txt;\ - c++ -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lz -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ - c++ -Wl,--no-as-needed -ldl -g -fno-rtti -lz -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ + ${CXX} -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out 
-L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' run_tiramisu_cmd: str = 'printf "Running ${FILE_PATH}.out\n">> ${FUNC_DIR}log.txt;\ ./${FILE_PATH}.out>> ${FUNC_DIR}log.txt;' compile_wrapper_cmd = 'cd ${FUNC_DIR};\ - g++ -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ - g++ -std=c++11 -fno-rtti -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/3rdParty/isl/include/ -I${TIRAMISU_ROOT}/benchmarks -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib/ -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -o ${FUNC_NAME}_wrapper -ltiramisu -lHalide -ldl -lpthread -lz -lm -Wl,-rpath,${TIRAMISU_ROOT}/build ./${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -ltiramisu -lHalide -ldl -lpthread -lz -lm' + ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O3 -o ${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' @dataclass From 6b78258865f0effab64c24a1e59af62be8423537 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Mon, 6 Feb 2023 15:36:03 +0400 Subject: [PATCH 16/27] changed saving frequency to increase performance --- utils/dataset_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index b2ab848..cd38ff2 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -6,7 
+6,7 @@ import numpy as np import ray -SAVING_FREQUENCY = 100 +SAVING_FREQUENCY = 1000 class DataSetFormat(): From 09d2c6ec8ea3c41f4e20c08b35741e8ca84719f5 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Mon, 6 Feb 2023 15:59:18 +0400 Subject: [PATCH 17/27] Added comments explaining dataset config part --- config.yaml.template | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/config.yaml.template b/config.yaml.template index dc630e9..ffa5355 100644 --- a/config.yaml.template +++ b/config.yaml.template @@ -10,6 +10,15 @@ environment: dataset_path: "./Dataset_multi/" programs_file: "./multicomp.json" clean_files: True + json_dataset: + # Path to the dataset file + path: "./Dataset_pickle/full_legality_annotations_solver.pkl" + # Path to the dataset cpp files + cpps_path: "./Dataset_multi" + # Path where to save the updated dataset (without the extension, it will be inferred by the dataset format) + path_to_save_dataset: "./Dataset_pickle/full_legality_annotations_solver_updated" + # Supported formats are available in the dataset_utilities module + dataset_format: "PICKLE" tiramisu: From 2f245d63c3e97efa2bcc7c9d39c9363c641134a8 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 7 Feb 2023 11:39:29 +0400 Subject: [PATCH 18/27] changed back the commands to no lz to avoid conflicts --- utils/rl_autoscheduler_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 7b7257c..103ef74 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -40,15 +40,15 @@ class TiramisuConfig: env_type: Literal["model", "cpu"] = "cpu" model_checkpoint: str = "/data/scratch/hbenyamina/model_published_nn_finale.pt" compile_tiramisu_cmd: str = 'printf "Compiling ${FILE_PATH}\n" >> ${FUNC_DIR}log.txt;\ - c++ -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g
-fno-rtti -lz -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ - c++ -Wl,--no-as-needed -ldl -g -fno-rtti -lz -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl ' + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ + ${CXX} -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' run_tiramisu_cmd: str = 'printf "Running ${FILE_PATH}.out\n">> ${FUNC_DIR}log.txt;\ ./${FILE_PATH}.out>> ${FUNC_DIR}log.txt;' compile_wrapper_cmd = 'cd ${FUNC_DIR};\ - g++ -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ - g++ -std=c++11 -fno-rtti -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/3rdParty/isl/include/ -I${TIRAMISU_ROOT}/benchmarks -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib/ -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -o ${FUNC_NAME}_wrapper -ltiramisu -lHalide -ldl -lpthread -lz -lm -Wl,-rpath,${TIRAMISU_ROOT}/build ./${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -ltiramisu -lHalide -ldl -lpthread -lz -lm' + ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O3 -o ${FUNC_NAME}_wrapper 
${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' @dataclass From 8968bb352bdaea1c3c5dbb247425b9c1638fff03 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 7 Feb 2023 14:30:10 +0400 Subject: [PATCH 19/27] fixed call of clean cpp --- rl_interface/environment.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 547652a..78de81a 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -91,8 +91,7 @@ def reset(self, file=None): try: # Clean files of the previous function ran if self.config.environment.clean_files and self.previous_cpp_file: - CPP_File.clean_cpp_file( - self.cpps_path, self.previous_cpp_file) + CPP_File.clean_cpp_file(self.previous_cpp_file) # get the next function (function_name, function_dict) = ray.get( From 28774d3f128919143453e552fec0999c6ff6fad8 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Mon, 13 Feb 2023 15:23:55 +0400 Subject: [PATCH 20/27] added saving the dataset to the disk when not using dataset --- rl_interface/action.py | 10 +++--- rl_interface/environment.py | 3 ++ tiramisu_programs/schedule_controller.py | 14 ++++---- tiramisu_programs/tiramisu_program.py | 4 ++- train_ppo.py | 45 +++++++++++++++--------- utils/dataset_utilities.py | 9 +++-- utils/rl_autoscheduler_config.py | 1 + 7 files changed, 53 insertions(+), 33 deletions(-) diff --git a/rl_interface/action.py b/rl_interface/action.py index 30cf233..3bc4f49 100644 --- a/rl_interface/action.py +++ b/rl_interface/action.py @@ -388,10 +388,11 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti "second_dim_index": second_it } + # Get schedule id + tmp_sched_str = 
ScheduleUtils.optimlist_to_str(schedule) + # Load saved results if they exist if prog.config.environment.use_dataset: - tmp_sched_str = ScheduleUtils.optimlist_to_str(schedule) - # Check if schedule is saved if tmp_sched_str in prog.function_dict[ 'schedules_solver_results_dict']: @@ -404,9 +405,8 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti solver_res = prog.call_solver(comp, skew_params) # Save the new solver results - if prog.config.environment.use_dataset: - prog.function_dict[ - 'schedules_solver_results_dict'][tmp_sched_str] = solver_res + prog.function_dict[ + 'schedules_solver_results_dict'][tmp_sched_str] = solver_res if solver_res == None or solver_res == "-1": return { diff --git a/rl_interface/environment.py b/rl_interface/environment.py index 78de81a..d181f04 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -1,5 +1,6 @@ # np.set_printoptions(threshold=sys.maxsize) import copy +import logging import sys import time import traceback @@ -26,6 +27,8 @@ class TiramisuScheduleEnvironment(gym.Env): SAVING_FREQUENCY = 500 def __init__(self, config: RLAutoSchedulerConfig, dataset_actor): + logging.basicConfig(level=config.ray.log_level) + print("Configuring the environment variables") configure_env_variables(config) diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index 38f215f..da32b64 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -88,7 +88,7 @@ def apply_action(self, action): self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality # Save legality check - if self.config.environment.use_dataset and saved_legality is None: + if saved_legality is None: self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check @@ -149,7 +149,7 @@ def apply_action(self, action): self.schedule, first_comp=first_comp) if saved_legality is None else 
saved_legality # Save legality check - if self.config.environment.use_dataset and saved_legality is None: + if saved_legality is None: self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check @@ -220,7 +220,7 @@ def apply_action(self, action): self.lc_total_time += l_time # Save legality check - if self.config.environment.use_dataset and saved_legality is None: + if saved_legality is None: self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check @@ -306,7 +306,7 @@ def apply_action(self, action): self.lc_total_time += l_time # Save legality check - if self.config.environment.use_dataset and saved_legality is None: + if saved_legality is None: self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check @@ -364,7 +364,7 @@ def apply_action(self, action): self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality # Save legality check - if self.config.environment.use_dataset and saved_legality is None: + if saved_legality is None: self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check @@ -419,7 +419,7 @@ def apply_action(self, action): self.lc_total_time += l_time # Save legality check - if self.config.environment.use_dataset and saved_legality is None: + if saved_legality is None: self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check @@ -478,7 +478,7 @@ def apply_action(self, action): self.lc_total_time += l_time # Save legality check - if self.config.environment.use_dataset and saved_legality is None: + if saved_legality is None: self.schedule_object.prog.function_dict[ 'schedules_legality_dict'][tmp_sched_str] = lc_check diff --git a/tiramisu_programs/tiramisu_program.py b/tiramisu_programs/tiramisu_program.py index f8a3739..7ed835b 100644 --- a/tiramisu_programs/tiramisu_program.py +++ b/tiramisu_programs/tiramisu_program.py @@ -137,7 +137,7 @@ def 
get_program_annotations(self): if self.program_annotations is not None: return self.program_annotations - if self.function_dict: + if self.function_dict['program_annotation'] is not None: self.program_annotations = self.function_dict['program_annotation'] else: # create a cpp file to get the annotations @@ -164,6 +164,8 @@ def get_program_annotations(self): 'r') as f: self.program_annotations = json.loads(f.read()) + self.function_dict['program_annotation'] = self.program_annotations + return self.program_annotations def check_legality_of_schedule( diff --git a/train_ppo.py b/train_ppo.py index 86a837c..4fc78e9 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -1,29 +1,36 @@ -import logging -import os # import hydra import argparse +import logging +import os + import ray + # from hydra.core.config_store import ConfigStore -from ray import tune, air +from ray import air, tune from ray.rllib.models.catalog import ModelCatalog from ray.tune.registry import register_env from rl_interface.environment import TiramisuScheduleEnvironment from rl_interface.model import TiramisuModelMult from utils.dataset_utilities import DatasetAgent -from utils.rl_autoscheduler_config import (RLAutoSchedulerConfig, - dict_to_config, parse_yaml_file, - read_yaml_file) +from utils.rl_autoscheduler_config import ( + RLAutoSchedulerConfig, + dict_to_config, + parse_yaml_file, + read_yaml_file, +) def get_arguments(): parser = argparse.ArgumentParser() - parser.add_argument("--num-workers", default=-1, type=int) + parser.add_argument("--num-workers", default=-1, type=int, + help="Number of workers to use for training") parser.add_argument('--resume-training', - action=argparse.BooleanOptionalAction) - parser.add_argument("--use-dataset", action=argparse.BooleanOptionalAction) + action=argparse.BooleanOptionalAction, help="Resume training from a saved checkpoint") + parser.add_argument("--use-dataset", action=argparse.BooleanOptionalAction, + help="Use the dataset (path specified in config) to 
train") parser.add_argument("--log-level", default="INFO", # TODO change back to WARN - type=str, choices=list(logging._nameToLevel.keys())) + type=str, choices=list(logging._nameToLevel.keys()), help="Log levels") return parser.parse_args() @@ -94,9 +101,13 @@ def main(config: RLAutoSchedulerConfig): if args.num_workers != -1: config.ray.num_workers = args.num_workers + if args.resume_training: config.ray.resume_training = True + if args.log_level: + config.ray.log_level = args.log_level + if args.use_dataset: config.environment.use_dataset = args.use_dataset @@ -106,10 +117,10 @@ def main(config: RLAutoSchedulerConfig): # Force model usage if using dataset config.tiramisu.env_type = "model" - logging.basicConfig(level=logging._nameToLevel[args.log_level]) - if args.num_workers == 1: - with ray.init(): - main(config) - else: - with ray.init(address="auto"): - main(config) + logging.basicConfig(level=args.log_level) + # if args.num_workers == 1: + with ray.init(): + main(config) + # else: + # with ray.init(address="auto"): + # main(config) diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index cd38ff2..5fa2328 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -48,8 +48,7 @@ def __init__(self, dataset_path, path_to_save_dataset, dataset_format, use_datas f"[Done] reading dataset from json at:{dataset_path}") else: - os.getcwd() - print(f"reading data from ls at: {os.getcwd()}") + print(f"reading data from ls at: {dataset_path}") self.function_names = os.listdir(dataset_path) if self.shuffle: @@ -60,7 +59,11 @@ def get_next_function(self): if self.use_dataset: return function_name, self.dataset[function_name] else: - return function_name, None + return function_name, { + 'program_annotation': None, + 'schedules_legality_dict': {}, + 'schedules_solver_results_dict': {} + } def update_dataset(self, function_name, function_dict): self.dataset[function_name] = function_dict diff --git a/utils/rl_autoscheduler_config.py 
b/utils/rl_autoscheduler_config.py index 103ef74..0645472 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -18,6 +18,7 @@ class RayConfig: name: str = "Training_multi_enhanced" log_directory: str = "ray_results" resume_training: bool = False + log_level: str = "WARN" @dataclass From 6d7f7b97dc82f73303d291236bdce0a7c5d3ff30 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 14 Feb 2023 11:10:25 +0400 Subject: [PATCH 21/27] reverted ray init to multiple workers --- train_ppo.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/train_ppo.py b/train_ppo.py index 4fc78e9..ca41387 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -118,9 +118,9 @@ def main(config: RLAutoSchedulerConfig): config.tiramisu.env_type = "model" logging.basicConfig(level=args.log_level) - # if args.num_workers == 1: - with ray.init(): - main(config) - # else: - # with ray.init(address="auto"): - # main(config) + if args.num_workers == 1: + with ray.init(): + main(config) + else: + with ray.init(address="auto"): + main(config) From acfc53bf18837a9e1ffeaacbc02873c2c6e89cc9 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Wed, 22 Feb 2023 10:26:00 +0400 Subject: [PATCH 22/27] added 2 checkpoints for dataset and model.eval --- tiramisu_programs/schedule_controller.py | 2 ++ train_ppo.py | 9 +++++++-- utils/dataset_utilities.py | 15 ++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index da32b64..f3737d1 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -40,6 +40,7 @@ def __init__(self, self.model = Model_Recursive_LSTM_v2() self.model.load_state_dict( torch.load(config.tiramisu.model_checkpoint, map_location="cpu")) + self.model.eval() def apply_action(self, action): @@ -688,6 +689,7 @@ def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, 
max_depth=self.schedule_object.MAX_DEPTH - 1) tree_tensors = (self.schedule_object.templates["prog_tree"], computations_tensor, loops_tensor) + with torch.no_grad(): predicted_speedup = self.model( tree_tensors, diff --git a/train_ppo.py b/train_ppo.py index ca41387..178b566 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -36,6 +36,7 @@ def get_arguments(): # @hydra.main(config_path="config", config_name="config") def main(config: RLAutoSchedulerConfig): + logging.basicConfig(level=config.ray.log_level) local_dir = os.path.join(config.ray.base_path, config.ray.log_directory) dataset_path = config.environment.json_dataset[ @@ -49,6 +50,11 @@ def main(config: RLAutoSchedulerConfig): ) ModelCatalog.register_custom_model("tiramisu_model_v1", TiramisuModelMult) + + # Use all available CPUs as workers (-1 for the head) + if config.ray.num_workers == -1: + config.ray.num_workers = int(ray.available_resources()['CPU'])-1 + logging.INFO(f"{'='*20} # Used CPU:{config.ray.num_workers}") config_dict = { "env": "Tiramisu_env_v1", "num_workers": config.ray.num_workers, @@ -99,8 +105,7 @@ def main(config: RLAutoSchedulerConfig): config = dict_to_config(parsed_yaml_dict) args = get_arguments() - if args.num_workers != -1: - config.ray.num_workers = args.num_workers + config.ray.num_workers = args.num_workers if args.resume_training: config.ray.resume_training = True diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index 5fa2328..8ec7df7 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -26,6 +26,7 @@ def __init__(self, dataset_path, path_to_save_dataset, dataset_format, use_datas self.dataset = {} self.function_names = [] self.nbr_updates = 0 + self.dataset_name = dataset_path.split('/')[-1].split('.')[0] if use_dataset: print(f"reading dataset from json at:{dataset_path}") @@ -70,21 +71,25 @@ def update_dataset(self, function_name, function_dict): self.nbr_updates += 1 print(f"# updates: {self.nbr_updates}") if self.nbr_updates % 
SAVING_FREQUENCY == 0: - self.save_dataset_to_disk() + if self.nbr_updates % (2*SAVING_FREQUENCY): + self.save_dataset_to_disk(version=2) + else: + self.save_dataset_to_disk(version=1) - def save_dataset_to_disk(self): + def save_dataset_to_disk(self, version=1): print("[Start] Save the legality_annotations_dict to disk") + updated_dataset_name = f"{self.path_to_save_dataset}/{self.dataset_name}_updated_{version}" match self.dataset_format: case DataSetFormat.PICKLE: - with open(f"{self.path_to_save_dataset}.pkl", "wb") as f: + with open(f"{updated_dataset_name}.pkl", "wb") as f: pickle.dump(self.dataset, f, protocol=pickle.HIGHEST_PROTOCOL) case DataSetFormat.JSON: - with open(f"{self.path_to_save_dataset}.json", "w") as f: + with open(f"{updated_dataset_name}.json", "w") as f: json.dump(self.dataset, f) case DataSetFormat.BZ2: - with bz2.BZ2File(f"{self.path_to_save_dataset}.bz2.pkl", 'wb') as f: + with bz2.BZ2File(f"{updated_dataset_name}.bz2.pkl", 'wb') as f: pickle.dump(self.dataset, f, protocol=pickle.HIGHEST_PROTOCOL) case _: From 1f45759afb53550573d736bf5a2fc53acc01c770 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Thu, 23 Feb 2023 13:13:40 +0400 Subject: [PATCH 23/27] fixed INFO to info bug in calling logging --- train_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_ppo.py b/train_ppo.py index 178b566..6cec909 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -54,7 +54,7 @@ def main(config: RLAutoSchedulerConfig): # Use all available CPUs as workers (-1 for the head) if config.ray.num_workers == -1: config.ray.num_workers = int(ray.available_resources()['CPU'])-1 - logging.INFO(f"{'='*20} # Used CPU:{config.ray.num_workers}") + logging.info(f"==================== # Used CPU:{config.ray.num_workers}") config_dict = { "env": "Tiramisu_env_v1", "num_workers": config.ray.num_workers, From c6ac8cc8309dcc02fce58e521113411b5dee8345 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Wed, 15 Mar 2023 10:30:00 +0400 Subject: [PATCH 
24/27] format with black --- .gitignore | 3 +- evaluate.py | 22 +- rl_interface/action.py | 189 +- rl_interface/environment.py | 124 +- rl_interface/model.py | 133 +- rl_interface/reward.py | 17 +- rl_interface/utils.py | 11 +- tiramisu_programs/cpp_file.py | 63 +- tiramisu_programs/optimization.py | 49 +- tiramisu_programs/schedule.py | 1788 +++++++++++------ tiramisu_programs/schedule_controller.py | 583 ++++-- tiramisu_programs/schedule_utils.py | 988 +++++---- .../surrogate_model_utils/__init__.py | 2 +- .../surrogate_model_utils/json_to_tensor.py | 47 +- .../surrogate_model_utils/modeling.py | 17 +- tiramisu_programs/tiramisu_program.py | 400 ++-- train_ppo.py | 68 +- utils/dataset_utilities.py | 44 +- utils/environment_variables.py | 3 +- utils/rl_autoscheduler_config.py | 24 +- 20 files changed, 2784 insertions(+), 1791 deletions(-) diff --git a/.gitignore b/.gitignore index 6454256..aaea323 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ Dataset* .vscode .idea dataset -cpps \ No newline at end of file +cpps +dataset_* \ No newline at end of file diff --git a/evaluate.py b/evaluate.py index 4972df4..2c6b7df 100644 --- a/evaluate.py +++ b/evaluate.py @@ -1,9 +1,11 @@ import ray.rllib.agents.ppo as ppo import os import json + # import hydra import argparse import ray + # from hydra.core.config_store import ConfigStore from ray import tune from ray.rllib.models.catalog import ModelCatalog @@ -13,9 +15,12 @@ from rl_interface.model import TiramisuModelMult from utils.environment_variables import configure_env_variables from utils.global_ray_variables import Actor, GlobalVarActor -from utils.rl_autoscheduler_config import (RLAutoSchedulerConfig, - dict_to_config, parse_yaml_file, - read_yaml_file) +from utils.rl_autoscheduler_config import ( + RLAutoSchedulerConfig, + dict_to_config, + parse_yaml_file, + read_yaml_file, +) def get_arguments(): @@ -35,16 +40,15 @@ def main(config: RLAutoSchedulerConfig, checkpoint=None): progs_list_registery = 
GlobalVarActor.remote( config.environment.programs_file, config.environment.dataset_path, - num_workers=config.ray.num_workers) + num_workers=config.ray.num_workers, + ) shared_variable_actor = Actor.remote(progs_list_registery) register_env( "Tiramisu_env_v1", - lambda a: TiramisuScheduleEnvironment(config, shared_variable_actor - ), + lambda a: TiramisuScheduleEnvironment(config, shared_variable_actor), ) - ModelCatalog.register_custom_model("tiramisu_model_v1", - TiramisuModelMult) + ModelCatalog.register_custom_model("tiramisu_model_v1", TiramisuModelMult) agent = ppo.PPOTrainer( env="Tiramisu_env_v1", @@ -63,7 +67,7 @@ def main(config: RLAutoSchedulerConfig, checkpoint=None): "custom_model_config": { "layer_sizes": list(config.model.layer_sizes), "drops": list(config.model.drops), - } + }, }, }, ) diff --git a/rl_interface/action.py b/rl_interface/action.py index 3bc4f49..c680c92 100644 --- a/rl_interface/action.py +++ b/rl_interface/action.py @@ -8,6 +8,7 @@ class Action: """ " Action class to store and standardize the action for the environment. 
""" + INTERCHANGE01 = 0 INTERCHANGE02 = 1 INTERCHANGE03 = 2 @@ -78,21 +79,68 @@ class Action: EXIT = 61 ACTIONS_ARRAY = [ - 'INTERCHANGE01', 'INTERCHANGE02', 'INTERCHANGE03', 'INTERCHANGE04', - 'INTERCHANGE05', 'INTERCHANGE06', 'INTERCHANGE07', 'INTERCHANGE12', - 'INTERCHANGE13', 'INTERCHANGE14', 'INTERCHANGE15', 'INTERCHANGE16', - 'INTERCHANGE17', 'INTERCHANGE23', 'INTERCHANGE24', 'INTERCHANGE25', - 'INTERCHANGE26', 'INTERCHANGE27', 'INTERCHANGE34', 'INTERCHANGE35', - 'INTERCHANGE36', 'INTERCHANGE37', 'INTERCHANGE45', 'INTERCHANGE46', - 'INTERCHANGE47', 'INTERCHANGE56', 'INTERCHANGE57', 'INTERCHANGE67', - 'TILING2D01', 'TILING2D12', 'TILING2D23', 'TILING2D34', 'TILING2D45', - 'TILING2D56', 'TILING2D67', 'TILING3D012', 'TILING3D123', - 'TILING3D234', 'TILING3D345', 'TILING3D456', 'TILING3D567', - 'UNROLLING4', 'UNROLLING8', 'UNROLLING16', 'SKEWING01', 'SKEWING01', - 'PARALLELIZATION0', 'PARALLELIZATION1', 'REVERSAL0', 'REVERSAL1', - 'REVERSAL2', 'REVERSAL3', 'REVERSAL4', 'REVERSAL5', 'REVERSAL6', - 'REVERSAL7', 'FUSION0', 'FUSION1', 'FUSION2', 'FUSION3', 'FUSION4', - 'EXIT' + "INTERCHANGE01", + "INTERCHANGE02", + "INTERCHANGE03", + "INTERCHANGE04", + "INTERCHANGE05", + "INTERCHANGE06", + "INTERCHANGE07", + "INTERCHANGE12", + "INTERCHANGE13", + "INTERCHANGE14", + "INTERCHANGE15", + "INTERCHANGE16", + "INTERCHANGE17", + "INTERCHANGE23", + "INTERCHANGE24", + "INTERCHANGE25", + "INTERCHANGE26", + "INTERCHANGE27", + "INTERCHANGE34", + "INTERCHANGE35", + "INTERCHANGE36", + "INTERCHANGE37", + "INTERCHANGE45", + "INTERCHANGE46", + "INTERCHANGE47", + "INTERCHANGE56", + "INTERCHANGE57", + "INTERCHANGE67", + "TILING2D01", + "TILING2D12", + "TILING2D23", + "TILING2D34", + "TILING2D45", + "TILING2D56", + "TILING2D67", + "TILING3D012", + "TILING3D123", + "TILING3D234", + "TILING3D345", + "TILING3D456", + "TILING3D567", + "UNROLLING4", + "UNROLLING8", + "UNROLLING16", + "SKEWING01", + "SKEWING01", + "PARALLELIZATION0", + "PARALLELIZATION1", + "REVERSAL0", + "REVERSAL1", + 
"REVERSAL2", + "REVERSAL3", + "REVERSAL4", + "REVERSAL5", + "REVERSAL6", + "REVERSAL7", + "FUSION0", + "FUSION1", + "FUSION2", + "FUSION3", + "FUSION4", + "EXIT", ] def __init__(self, id_, it_dict, common_it): @@ -107,8 +155,13 @@ def __init__(self, id_, it_dict, common_it): self.it_dict = it_dict self.common_it = common_it - def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[OptimizationCommand] = None): - """" + def parameter( + self, + comp=None, + prog: TiramisuProgram = None, + schedule: list[OptimizationCommand] = None, + ): + """ " Property method to return the parameter related to the action selected. Returns: The parameter related to this action_id @@ -222,8 +275,9 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti # calculate the loop extent to see if we should create new iterators or not # since it's applicable on the common on the common iterators, we retrieve the information from the first computation loop_extent_1 = abs( - self.it_dict[first_comp][first_it]["upper_bound"] - - self.it_dict[first_comp][first_it]["lower_bound"]) + self.it_dict[first_comp][first_it]["upper_bound"] + - self.it_dict[first_comp][first_it]["lower_bound"] + ) # #print("\n first loop extent is ", loop_extent_1) # print("first factor is", first_fact) if loop_extent_1 == first_fact: @@ -232,11 +286,13 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti elif loop_extent_1 < first_fact: print("Exception, loop extent 1 smaller than factor") from tiramisu_programs.schedule import LoopExtentException + raise LoopExtentException loop_extent_2 = abs( - self.it_dict[first_comp][second_it]["upper_bound"] - - self.it_dict[first_comp][second_it]["lower_bound"]) + self.it_dict[first_comp][second_it]["upper_bound"] + - self.it_dict[first_comp][second_it]["lower_bound"] + ) # print("\n second loop extent is ", loop_extent_2) # print("second factor is", second_fact) if loop_extent_2 == second_fact: @@ 
-245,6 +301,7 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti elif loop_extent_2 < second_fact: print("exceeeption, loop extent 2 smaller than factor") from tiramisu_programs.schedule import LoopExtentException + raise LoopExtentException return { @@ -291,8 +348,9 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti third_fact = 32 # random.choice([32, 64, 128]) # calculate the loop extent to see if we should create new iterators or not loop_extent_1 = abs( - self.it_dict[first_comp][first_it]["upper_bound"] - - self.it_dict[first_comp][first_it]["lower_bound"]) + self.it_dict[first_comp][first_it]["upper_bound"] + - self.it_dict[first_comp][first_it]["lower_bound"] + ) # #print("\n first loop extent is ", loop_extent_1) # print("first factor is", first_fact) if loop_extent_1 == first_fact: @@ -301,11 +359,13 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti elif loop_extent_1 < first_fact: print("exceeeption, loop extent 1 smaller than factor") from tiramisu_programs.schedule import LoopExtentException + raise LoopExtentException loop_extent_2 = abs( - self.it_dict[first_comp][second_it]["upper_bound"] - - self.it_dict[first_comp][second_it]["lower_bound"]) + self.it_dict[first_comp][second_it]["upper_bound"] + - self.it_dict[first_comp][second_it]["lower_bound"] + ) # print("\n second loop extent is ", loop_extent_2) # print("second factor is", second_fact) if loop_extent_2 == second_fact: @@ -314,11 +374,13 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti elif loop_extent_2 < second_fact: print("exceeeption, loop extent 2 smaller than factor") from tiramisu_programs.schedule import LoopExtentException + raise LoopExtentException loop_extent_3 = abs( - self.it_dict[first_comp][third_it]["upper_bound"] - - self.it_dict[first_comp][third_it]["lower_bound"]) + self.it_dict[first_comp][third_it]["upper_bound"] + - 
self.it_dict[first_comp][third_it]["lower_bound"] + ) # print("\n third loop extent is ", loop_extent_3) # print("third factor is", third_fact) if loop_extent_3 == third_fact: @@ -327,6 +389,7 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti elif loop_extent_3 < third_fact: print("exceeeption, loop extent 3 smaller than factor") from tiramisu_programs.schedule import LoopExtentException + raise LoopExtentException return { @@ -347,10 +410,7 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti for comp in self.it_dict: it = len(self.it_dict[comp].keys()) - 1 unrolling_fact = 4 - params[comp] = { - "dim_index": it, - "unrolling_factor": unrolling_fact - } + params[comp] = {"dim_index": it, "unrolling_factor": unrolling_fact} return params @@ -359,10 +419,7 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti for comp in self.it_dict: it = len(self.it_dict[comp].keys()) - 1 unrolling_fact = 8 - params[comp] = { - "dim_index": it, - "unrolling_factor": unrolling_fact - } + params[comp] = {"dim_index": it, "unrolling_factor": unrolling_fact} return params @@ -371,10 +428,7 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti for comp in self.it_dict: it = len(self.it_dict[comp].keys()) - 1 unrolling_fact = 16 - params[comp] = { - "dim_index": it, - "unrolling_factor": unrolling_fact - } + params[comp] = {"dim_index": it, "unrolling_factor": unrolling_fact} return params @@ -383,10 +437,7 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti first_it = 0 second_it = 1 - skew_params = { - "first_dim_index": first_it, - "second_dim_index": second_it - } + skew_params = {"first_dim_index": first_it, "second_dim_index": second_it} # Get schedule id tmp_sched_str = ScheduleUtils.optimlist_to_str(schedule) @@ -394,19 +445,21 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti # Load saved 
results if they exist if prog.config.environment.use_dataset: # Check if schedule is saved - if tmp_sched_str in prog.function_dict[ - 'schedules_solver_results_dict']: + if tmp_sched_str in prog.function_dict["schedules_solver_results_dict"]: print( - f"Loading solver results from saved schedule: {tmp_sched_str}") - solver_res = prog.function_dict[ - 'schedules_solver_results_dict'][tmp_sched_str] + f"Loading solver results from saved schedule: {tmp_sched_str}" + ) + solver_res = prog.function_dict["schedules_solver_results_dict"][ + tmp_sched_str + ] if solver_res is None: solver_res = prog.call_solver(comp, skew_params) # Save the new solver results - prog.function_dict[ - 'schedules_solver_results_dict'][tmp_sched_str] = solver_res + prog.function_dict["schedules_solver_results_dict"][ + tmp_sched_str + ] = solver_res if solver_res == None or solver_res == "-1": return { @@ -429,29 +482,29 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti first_it = 1 second_it = 2 - skew_params = { - "first_dim_index": first_it, - "second_dim_index": second_it - } + skew_params = {"first_dim_index": first_it, "second_dim_index": second_it} # Load saved results if they exist if prog.config.environment.use_dataset: tmp_sched_str = ScheduleUtils.optimlist_to_str(schedule) # Check if schedule is saved - if tmp_sched_str in prog.function_dict[ - 'schedules_solver_results_dict']: + if tmp_sched_str in prog.function_dict["schedules_solver_results_dict"]: print( - f"Loading solver results from saved schedule: {tmp_sched_str}") - solver_res = prog.function_dict[ - 'schedules_solver_results_dict'][tmp_sched_str] + f"Loading solver results from saved schedule: {tmp_sched_str}" + ) + solver_res = prog.function_dict["schedules_solver_results_dict"][ + tmp_sched_str + ] if solver_res is None: solver_res = prog.call_solver(comp, skew_params) # Save the new solver results if prog.config.environment.use_dataset: - 
prog.function_dict['schedules_solver_results_dict'][tmp_sched_str] = solver_res + prog.function_dict["schedules_solver_results_dict"][ + tmp_sched_str + ] = solver_res if solver_res == None or solver_res == "-1": return { @@ -508,22 +561,14 @@ def parameter(self, comp=None, prog: TiramisuProgram = None, schedule: list[Opti fuse_comps = list(self.it_dict.keys()) if self.id == 57: # FUSION1 level = 1 - fuse_comps = [ - comp for comp in self.it_dict if 1 in self.it_dict[comp] - ] + fuse_comps = [comp for comp in self.it_dict if 1 in self.it_dict[comp]] if self.id == 58: # FUSION2 level = 2 - fuse_comps = [ - comp for comp in self.it_dict if 2 in self.it_dict[comp] - ] + fuse_comps = [comp for comp in self.it_dict if 2 in self.it_dict[comp]] if self.id == 59: # FUSION3 level = 3 - fuse_comps = [ - comp for comp in self.it_dict if 3 in self.it_dict[comp] - ] + fuse_comps = [comp for comp in self.it_dict if 3 in self.it_dict[comp]] if self.id == 60: # FUSION4 level = 4 - fuse_comps = [ - comp for comp in self.it_dict if 4 in self.it_dict[comp] - ] + fuse_comps = [comp for comp in self.it_dict if 4 in self.it_dict[comp]] return {"dim_index": level, "fuse_comps": fuse_comps} diff --git a/rl_interface/environment.py b/rl_interface/environment.py index d181f04..ad3f491 100644 --- a/rl_interface/environment.py +++ b/rl_interface/environment.py @@ -21,9 +21,10 @@ class TiramisuScheduleEnvironment(gym.Env): - ''' - The reinforcement learning environment used by the GYM. - ''' + """ + The reinforcement learning environment used by the GYM. 
+ """ + SAVING_FREQUENCY = 500 def __init__(self, config: RLAutoSchedulerConfig, dataset_actor): @@ -52,33 +53,34 @@ def __init__(self, config: RLAutoSchedulerConfig, dataset_actor): self.dataset_actor = dataset_actor if config.environment.use_dataset: - self.cpps_path = config.environment.json_dataset['cpps_path'] + self.cpps_path = config.environment.json_dataset["cpps_path"] self.action_space = gym.spaces.Discrete(62) - self.observation_space = gym.spaces.Dict({ - # Computation representation (5 is the MAX computations) - "representation": - gym.spaces.Box(low=-np.inf, high=np.inf, shape=(5, 1052)), - # Mask to hide actions from being taken 62 masks for 62 actions - "action_mask": - gym.spaces.Box(low=0, high=1, shape=(62, )), - # Representation of loops - "loops_representation": - gym.spaces.Box(low=-np.inf, high=np.inf, shape=(15, 26)), - # Loop indices of loops instead in loop i - "child_list": - gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12, 11)), - # Whether loop i has computations or not - "has_comps": - gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12, )), - # Computation indices of all computations inside of a loop (12 loops,5 max computations) - "computations_indices": - gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12, 5)), - # float representation of the padded string format of the program tree - "prog_tree": - gym.spaces.Box(low=-np.inf, high=np.inf, shape=(5000,)) - }) + self.observation_space = gym.spaces.Dict( + { + # Computation representation (5 is the MAX computations) + "representation": gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(5, 1052) + ), + # Mask to hide actions from being taken 62 masks for 62 actions + "action_mask": gym.spaces.Box(low=0, high=1, shape=(62,)), + # Representation of loops + "loops_representation": gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(15, 26) + ), + # Loop indices of loops instead in loop i + "child_list": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12, 11)), + # Whether loop i has 
computations or not + "has_comps": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12,)), + # Computation indices of all computations inside of a loop (12 loops,5 max computations) + "computations_indices": gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(12, 5) + ), + # float representation of the padded string format of the program tree + "prog_tree": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(5000,)), + } + ) def reset(self, file=None): """ @@ -98,45 +100,47 @@ def reset(self, file=None): # get the next function (function_name, function_dict) = ray.get( - self.dataset_actor.get_next_function.remote()) + self.dataset_actor.get_next_function.remote() + ) # Copy the function's files to the dataset copy created - file = CPP_File.get_cpp_file( - self.cpps_path, function_name) + file = CPP_File.get_cpp_file(self.cpps_path, function_name) # Set up the function files to be deleted on the next iteration self.previous_cpp_file = function_name # Load the tiramisu program from the file - self.prog = TiramisuProgram( - self.config, file, function_dict) + self.prog = TiramisuProgram(self.config, file, function_dict) print(f"Trying with program {self.prog.name}") - self.schedule_object = Schedule( - self.prog) + self.schedule_object = Schedule(self.prog) self.schedule_controller = ScheduleController( schedule=self.schedule_object, nb_executions=self.nb_executions, - config=self.config) + config=self.config, + ) # Get the gym representation from the annotations self.obs = self.schedule_object.get_representation() if self.config.tiramisu.env_type == "cpu": print("Getting the initial exe time by execution") - self.prog.initial_execution_time = self.schedule_controller.measurement_env( - [], 'initial_exec', self.nb_executions, - self.prog.initial_execution_time) + self.prog.initial_execution_time = ( + self.schedule_controller.measurement_env( + [], + "initial_exec", + self.nb_executions, + self.prog.initial_execution_time, + ) + ) elif self.config.tiramisu.env_type == 
"model": self.prog.initial_execution_time = 1.0 except: - print("RESET_ERROR_STDERR", - traceback.format_exc(), file=sys.stderr) - print("RESET_ERROR_STDOUT", - traceback.format_exc(), file=sys.stdout) + print("RESET_ERROR_STDERR", traceback.format_exc(), file=sys.stderr) + print("RESET_ERROR_STDOUT", traceback.format_exc(), file=sys.stdout) continue self.steps = 0 @@ -151,8 +155,9 @@ def step(self, raw_action): Returns: The current state after eventually applying the transformation, and the reward that the agent received for taking the action. """ action_name = Action.ACTIONS_ARRAY[raw_action] - print("\n ----> {} [ {} ] \n".format( - action_name, self.schedule_object.sched_str)) + print( + "\n ----> {} [ {} ] \n".format(action_name, self.schedule_object.sched_str) + ) info = {} applied_exception = False reward = 0.0 @@ -161,23 +166,20 @@ def step(self, raw_action): self.total_steps += 1 try: - action = Action(raw_action, - self.schedule_object.it_dict, - self.schedule_object.common_it) - _, speedup, done, info = self.schedule_controller.apply_action( - action) + action = Action( + raw_action, self.schedule_object.it_dict, self.schedule_object.common_it + ) + _, speedup, done, info = self.schedule_controller.apply_action(action) print("Obtained speedup: ", speedup) except Exception as e: self.schedule_object.repr["action_mask"][action.id] = 0 - print("STEP_ERROR_STDERR: ", - traceback.format_exc(), - file=sys.stderr, - end=" ") - print("STEP_ERROR_STDOUT: ", - traceback.format_exc(), - file=sys.stdout, - end=" ") + print( + "STEP_ERROR_STDERR: ", traceback.format_exc(), file=sys.stderr, end=" " + ) + print( + "STEP_ERROR_STDOUT: ", traceback.format_exc(), file=sys.stdout, end=" " + ) if applied_exception: print("Already Applied exception") info = {"more than one time": True} @@ -193,8 +195,9 @@ def step(self, raw_action): } self.obs = copy.deepcopy(self.schedule_object.get_representation()) - if (self.schedule_controller.depth - == self.schedule_object.MAX_DEPTH) 
or (self.steps >= 20): + if (self.schedule_controller.depth == self.schedule_object.MAX_DEPTH) or ( + self.steps >= 20 + ): done = True if done: print("\n ************** End of an episode ************") @@ -204,7 +207,8 @@ def step(self, raw_action): speedup = 1.0 # Update dataset with explored legality checks self.dataset_actor.update_dataset.remote( - self.prog.name, self.prog.function_dict) + self.prog.name, self.prog.function_dict + ) reward_object = Reward(speedup) reward = reward_object.reward diff --git a/rl_interface/model.py b/rl_interface/model.py index 2630a00..ff547ea 100644 --- a/rl_interface/model.py +++ b/rl_interface/model.py @@ -14,8 +14,8 @@ from ray.rllib.utils.framework import try_import_torch from tiramisu_programs.surrogate_model_utils.json_to_tensor import get_tree_footprint -train_device_name = 'cpu' # choose training/storing device, either 'cuda:X' or 'cpu' -store_device_name = 'cpu' +train_device_name = "cpu" # choose training/storing device, either 'cuda:X' or 'cpu' +store_device_name = "cpu" store_device = torch.device(store_device_name) @@ -27,46 +27,44 @@ class TiramisuModelMult(TorchModelV2, nn.Module): - - def __init__(self, obs_space, action_space, num_outputs, model_config, name,**kwargs): + def __init__( + self, obs_space, action_space, num_outputs, model_config, name, **kwargs + ): print("in model init") TorchModelV2.__init__( - self, obs_space, action_space, num_outputs, model_config, name,**kwargs) + self, obs_space, action_space, num_outputs, model_config, name, **kwargs + ) nn.Module.__init__(self) - + shared_layer_sizes = model_config["custom_model_config"]["layer_sizes"] - - - embedding_size= shared_layer_sizes[-1] + embedding_size = shared_layer_sizes[-1] - num_outputs=action_space.n + num_outputs = action_space.n - #Computation Embedding Layer + # Computation Embedding Layer prev_layer_size = obs_space["representation"].shape[1] - comp_embd_layers=[] - cpt=0 + comp_embd_layers = [] + cpt = 0 for size in 
shared_layer_sizes: comp_embd_layers.extend( - [nn.Linear( - prev_layer_size, - size - ), - nn.Dropout(0.02), - nn.ELU() - ] + [nn.Linear(prev_layer_size, size), nn.Dropout(0.02), nn.ELU()] ) prev_layer_size = size - cpt+=1 - + cpt += 1 + self._comp_embd_layers = nn.Sequential(*comp_embd_layers) - #Recursive Loop Embedding Layer - self.comps_lstm = nn.LSTM(shared_layer_sizes[-1], embedding_size, batch_first=True) - self.nodes_lstm = nn.LSTM(shared_layer_sizes[-1], embedding_size, batch_first=True) - + # Recursive Loop Embedding Layer + self.comps_lstm = nn.LSTM( + shared_layer_sizes[-1], embedding_size, batch_first=True + ) + self.nodes_lstm = nn.LSTM( + shared_layer_sizes[-1], embedding_size, batch_first=True + ) + self.no_comps_tensor = nn.Parameter( nn.init.xavier_uniform_(torch.zeros(1, embedding_size)) ) @@ -74,42 +72,30 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name,**kw nn.init.xavier_uniform_(torch.zeros(1, embedding_size)) ) - prev_layer_size=embedding_size*2+26 + prev_layer_size = embedding_size * 2 + 26 hidden_layer_sizes = shared_layer_sizes[-2:] - rec_loop_embd_layers=[] + rec_loop_embd_layers = [] for size in hidden_layer_sizes: rec_loop_embd_layers.extend( - [nn.Linear( - prev_layer_size, - size - ), - nn.Dropout(0.02), - nn.ELU() - ] + [nn.Linear(prev_layer_size, size), nn.Dropout(0.02), nn.ELU()] ) prev_layer_size = size - + self._rec_loop_embd_layers = nn.Sequential(*rec_loop_embd_layers) - #Prediction Layer - predict_layers=[] - + # Prediction Layer + predict_layers = [] + for size in hidden_layer_sizes: predict_layers.extend( - [nn.Linear( - prev_layer_size, - size - ), - nn.Dropout(0.02), - nn.ELU() - ] + [nn.Linear(prev_layer_size, size), nn.Dropout(0.02), nn.ELU()] ) prev_layer_size = size - + self._prediction_layers = nn.Sequential(*predict_layers) - #Outputs - #1 Policy + # Outputs + # 1 Policy self._logits = SlimFC( in_size=prev_layer_size, out_size=num_outputs, @@ -117,7 +103,7 @@ def __init__(self, 
obs_space, action_space, num_outputs, model_config, name,**kw activation_fn=None, ) - #2 Value + # 2 Value self._value_branch = SlimFC( in_size=prev_layer_size, out_size=1, @@ -125,42 +111,45 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name,**kw activation_fn=None, ) - @override(TorchModelV2) def forward(self, input_dict, state, seq_len): obs = input_dict["obs_flat"]["representation"] - - #computation embedding layer - comps_embeddings=self._comp_embd_layers(obs) - - #recursive loop embedding layer - loops_tensor=input_dict["obs_flat"]["loops_representation"] - child_list=input_dict["obs_flat"]["child_list"][:][:][0][0] - has_comps=input_dict["obs_flat"]["has_comps"][0].tolist() + + # computation embedding layer + comps_embeddings = self._comp_embd_layers(obs) + + # recursive loop embedding layer + loops_tensor = input_dict["obs_flat"]["loops_representation"] + child_list = input_dict["obs_flat"]["child_list"][:][:][0][0] + has_comps = input_dict["obs_flat"]["has_comps"][0].tolist() try: prog_tree_tensor = np.array(input_dict["obs_flat"]["prog_tree"]) - prog_tree_string = "".join(list(prog_tree_tensor[0].view('U1'))).strip("_") + prog_tree_string = "".join(list(prog_tree_tensor[0].view("U1"))).strip("_") prog_tree = json.loads(prog_tree_string) # tree_footprint = get_tree_footprint(prog_tree) except: pass # prog_tree_string = # FIX THIS --> Compare tree footprint - computations_indices=input_dict["obs_flat"]["computations_indices"][:][:][0][0] - + computations_indices = input_dict["obs_flat"]["computations_indices"][:][:][0][ + 0 + ] try: - loop_index=0 - prog_embedding=self.get_hidden_state(prog_tree,comps_embeddings,loops_tensor) + loop_index = 0 + prog_embedding = self.get_hidden_state( + prog_tree, comps_embeddings, loops_tensor + ) except: - print("Actor Critic",traceback.format_exc()) - prog_tree = {"child_list":[]} - prog_embedding = torch.zeros((loops_tensor.shape[0],180)) - + print("Actor Critic", traceback.format_exc()) + 
prog_tree = {"child_list": []} + prog_embedding = torch.zeros((loops_tensor.shape[0], 180)) + # prediction layer - self._features=self._prediction_layers(prog_embedding.view(prog_embedding.shape[0],-1)) + self._features = self._prediction_layers( + prog_embedding.view(prog_embedding.shape[0], -1) + ) logits = self._logits(self._features) - logits=logits-BIG_NUMBER*(1-input_dict["obs_flat"]["action_mask"]) - + logits = logits - BIG_NUMBER * (1 - input_dict["obs_flat"]["action_mask"]) self._value = self._value_branch(self._features) @@ -198,4 +187,4 @@ def get_hidden_state(self, node, comps_embeddings, loops_tensor): ) x = torch.cat((nodes_h_n, comps_h_n, selected_loop_tensor), 2) x = self._rec_loop_embd_layers(x) - return x \ No newline at end of file + return x diff --git a/rl_interface/reward.py b/rl_interface/reward.py index a28fce5..1bffa9c 100644 --- a/rl_interface/reward.py +++ b/rl_interface/reward.py @@ -1,17 +1,18 @@ from builtins import property import math + class Reward: - def __init__(self,reward): - self._reward=reward - + def __init__(self, reward): + self._reward = reward + @property def reward(self): return self.log_reward() - + @reward.setter - def reward(self,value): + def reward(self, value): self._reward = value - - def log_reward(self,base=4): - return math.log(abs(self._reward),base) \ No newline at end of file + + def log_reward(self, base=4): + return math.log(abs(self._reward), base) diff --git a/rl_interface/utils.py b/rl_interface/utils.py index dad5f2e..1ab3017 100644 --- a/rl_interface/utils.py +++ b/rl_interface/utils.py @@ -1,16 +1,15 @@ - import json import os from tiramisu_programs.schedule_utils import NumpyEncoder -class EnvironmentUtils: +class EnvironmentUtils: @classmethod - def write_json_dataset(cls,filename,data): + def write_json_dataset(cls, filename, data): if not os.path.isdir("./Dataset/"): os.mkdir("./Dataset/") dataset_file = os.path.join("./Dataset/", filename) - with open(dataset_file,"w+") as f: - 
f.write(json.dumps(data,cls=NumpyEncoder)) - return True \ No newline at end of file + with open(dataset_file, "w+") as f: + f.write(json.dumps(data, cls=NumpyEncoder)) + return True diff --git a/tiramisu_programs/cpp_file.py b/tiramisu_programs/cpp_file.py index fee0364..8a4433a 100644 --- a/tiramisu_programs/cpp_file.py +++ b/tiramisu_programs/cpp_file.py @@ -8,12 +8,8 @@ class CPP_File(object): - @classmethod - def compile_and_run_tiramisu_code(cls, - config, - file_path, - log_message="No message"): + def compile_and_run_tiramisu_code(cls, config, file_path, log_message="No message"): """Compiles and runs a C++ file. Args: @@ -27,13 +23,15 @@ def compile_and_run_tiramisu_code(cls, # print("inside compile and run") # Format the path to the cpp file to compile. - os.environ["FUNC_DIR"] = ("/".join(Path(file_path).parts[:-1]) if len( - Path(file_path).parts) > 1 else ".") + "/" + os.environ["FUNC_DIR"] = ( + "/".join(Path(file_path).parts[:-1]) + if len(Path(file_path).parts) > 1 + else "." + ) + "/" os.environ["FILE_PATH"] = file_path # Compile the C++ file. - failed = cls.launch_cmd(config.tiramisu.compile_tiramisu_cmd, - file_path) + failed = cls.launch_cmd(config.tiramisu.compile_tiramisu_cmd, file_path) # Print the program that failed to compile. if failed: @@ -43,20 +41,21 @@ def compile_and_run_tiramisu_code(cls, return False else: # Run the compiled program. - failed = cls.launch_cmd(config.tiramisu.run_tiramisu_cmd, - file_path) + failed = cls.launch_cmd(config.tiramisu.run_tiramisu_cmd, file_path) if failed: print(f"Error occured while running {file_path}") return False return True @classmethod - def launch_cmd(cls, - step_cmd, - file_path, - cmd_type=None, - nb_executions=None, - initial_exec_time=None): + def launch_cmd( + cls, + step_cmd, + file_path, + cmd_type=None, + nb_executions=None, + initial_exec_time=None, + ): """Execute a command on the shell. 
Args: @@ -109,8 +108,8 @@ def launch_cmd(cls, except Exception as e: print( - f"\n# {str(datetime.now())} ---> Error running {step_cmd} \n" + - e.stderr.decode("UTF-8"), + f"\n# {str(datetime.now())} ---> Error running {step_cmd} \n" + + e.stderr.decode("UTF-8"), file=sys.stderr, flush=True, ) @@ -126,12 +125,19 @@ def launch_cmd(cls, ) failed = True if failed: - func_folder = ("/".join(Path(file_path).parts[:-1]) - if len(Path(file_path).parts) > 1 else ".") + "/" + func_folder = ( + "/".join(Path(file_path).parts[:-1]) + if len(Path(file_path).parts) > 1 + else "." + ) + "/" with open(func_folder + "error.txt", "a") as f: - f.write("\nError running " + step_cmd + - "\n---------------------------\n" + - out.stderr.decode("UTF-8") + "\n") + f.write( + "\nError running " + + step_cmd + + "\n---------------------------\n" + + out.stderr.decode("UTF-8") + + "\n" + ) return failed @classmethod @@ -157,20 +163,19 @@ def get_cpp_file(cls, Dataset_path, func_name): os.system("rm -r {}".format(target_path)) # print("directory removed") - with open(original_path, 'r') as f: + with open(original_path, "r") as f: original_str = f.read() - original_str = original_str.replace( - f'#include "{func_name}_wrapper.h"', '') + original_str = original_str.replace(f'#include "{func_name}_wrapper.h"', "") os.mkdir(target_path) - with open(f"{target_path}/{file_name}", 'w') as f: + with open(f"{target_path}/{file_name}", "w") as f: f.write(original_str) # os.system("cp -r {} {}".format(original_path, target_path)) return target_path + "/" + file_name @classmethod - def clean_cpp_file(cls, func_name): + def clean_cpp_file(cls, func_name): """Clean the files of the function to run from the existing dataset copy. 
Args: diff --git a/tiramisu_programs/optimization.py b/tiramisu_programs/optimization.py index 7326b71..bc28974 100644 --- a/tiramisu_programs/optimization.py +++ b/tiramisu_programs/optimization.py @@ -1,6 +1,5 @@ class OptimizationCommand: - """Represents a Tirtamisu transformation and maps to Tiramisu code. - """ + """Represents a Tirtamisu transformation and maps to Tiramisu code.""" def __init__(self, optim_type, params_list, comps): assert optim_type in [ @@ -11,7 +10,9 @@ def __init__(self, optim_type, params_list, comps): "Unrolling", "Reversal", "Fusion", - ], ("Unknown transformation: " + optim_type) + ], ( + "Unknown transformation: " + optim_type + ) self.type = optim_type self.params_list = params_list self.comps = comps @@ -21,39 +22,42 @@ def __str__(self) -> str: return f"{self.type} of {self.params_list}" def __repr__(self): - return f'OptimizationCommand(type={self.type}, params_list={self.params_list}, comps={self.comps})' + return f"OptimizationCommand(type={self.type}, params_list={self.params_list}, comps={self.comps})" def get_tiramisu_optim_str(self): """Convert the optimization command into Tiramisu code. Returns: - str: The tiramisu snippet that represents the optimization command. + str: The tiramisu snippet that represents the optimization command. 
""" if self.type == "Interchange": assert len(self.params_list) == 2 - interchange_str = (".interchange(" + - ",".join([str(p) - for p in self.params_list]) + ");") + interchange_str = ( + ".interchange(" + ",".join([str(p) for p in self.params_list]) + ");" + ) optim_str = "" for comp in self.comps: optim_str += "\n\t{}".format(comp) + interchange_str return optim_str elif self.type == "Skewing": assert len(self.params_list) == 4 - skewing_str = ".skew(" + ",".join( - [str(p) for p in self.params_list]) + ");" + skewing_str = ".skew(" + ",".join([str(p) for p in self.params_list]) + ");" optim_str = "" for comp in self.comps: optim_str += "\n\t{}".format(comp) + skewing_str return optim_str elif self.type == "Parallelization": assert len(self.params_list) == 1 - return ("\t" + self.comps[0] + ".tag_parallel_level(" + - str(self.params_list[0]) + ");") + return ( + "\t" + + self.comps[0] + + ".tag_parallel_level(" + + str(self.params_list[0]) + + ");" + ) elif self.type == "Tiling": assert len(self.params_list) == 4 or len(self.params_list) == 6 - tiling_str = ".tile(" + ",".join( - [str(p) for p in self.params_list]) + ");" + tiling_str = ".tile(" + ",".join([str(p) for p in self.params_list]) + ");" optim_str = "" for comp in self.comps: optim_str += "\n\t{}".format(comp) + tiling_str @@ -62,8 +66,10 @@ def get_tiramisu_optim_str(self): optim_str = "" for comp in self.comps: unrolling_str = ( - ".unroll(" + - ",".join([str(p) for p in self.params_list[comp]]) + ");") + ".unroll(" + + ",".join([str(p) for p in self.params_list[comp]]) + + ");" + ) optim_str += "\n\t{}".format(comp) + unrolling_str return optim_str elif self.type == "Reversal": @@ -77,8 +83,13 @@ def get_tiramisu_optim_str(self): optim_str = "" prev_comp = self.comps[0] for comp in self.comps[1:]: - optim_str += ("\n\t{}".format(prev_comp) + ".then(" + - str(comp) + "," + str(self.params_list[0]) + - ");") + optim_str += ( + "\n\t{}".format(prev_comp) + + ".then(" + + str(comp) + + "," + + 
str(self.params_list[0]) + + ");" + ) prev_comp = comp return optim_str diff --git a/tiramisu_programs/schedule.py b/tiramisu_programs/schedule.py index 0503459..c51ac6a 100644 --- a/tiramisu_programs/schedule.py +++ b/tiramisu_programs/schedule.py @@ -4,7 +4,9 @@ from tiramisu_programs.schedule_utils import * from tiramisu_programs.surrogate_model_utils.json_to_tensor import ( - get_sched_rep, get_tree_structure) + get_sched_rep, + get_tree_structure, +) global_dioph_sols_dict = dict() EPSILON = 1e-6 @@ -38,9 +40,11 @@ def get_representation(self): return self.repr # Get the schedule representation. - (self.prog_rep, - self.comps_placeholders, - self.comp_indic_dict) = ScheduleUtils.get_representation(self.annotations) + ( + self.prog_rep, + self.comps_placeholders, + self.comp_indic_dict, + ) = ScheduleUtils.get_representation(self.annotations) # Check that all computations are of length 1052 for comp_rep in self.prog_rep: @@ -52,7 +56,8 @@ def get_representation(self): self.comps_it = [] for comp in self.comps: self.comps_it.append( - self.annotations["computations"][comp]["iterators"]) + self.annotations["computations"][comp]["iterators"] + ) self.common_it = self.comps_it[0] for comp_it in self.comps_it[1:]: self.common_it = [it for it in comp_it if it in self.common_it] @@ -62,8 +67,9 @@ def get_representation(self): raise IndexError else: # A single comp program - self.common_it = self.annotations["computations"][ - self.comps[0]]["iterators"] + self.common_it = self.annotations["computations"][self.comps[0]][ + "iterators" + ] # Set up the schedule representation part. 
self.schedule_dict = dict() @@ -71,32 +77,29 @@ def get_representation(self): # For every comp initial its schedule dict for comp in self.comps: - dim = len(self.annotations['computations'][comp]['iterators']) + dim = len(self.annotations["computations"][comp]["iterators"]) self.schedule_dict[comp] = dict() self.schedule_dict[comp]["dim"] = dim # Product of the list of matrices - self.schedule_dict[comp]["transformation_matrix"] = np.eye( - dim, dim) + self.schedule_dict[comp]["transformation_matrix"] = np.eye(dim, dim) # Transformation matrices - self.schedule_dict[comp]["transformation_matrices"] = [ - np.eye(dim, dim) - ] - self.schedule_dict[comp]['parallelized_dim'] = None - self.schedule_dict[comp]['unrolling_factor'] = None - self.schedule_dict[comp]['tiling'] = None - self.schedule_dict['tree_structure'] = get_tree_structure( - self.annotations) + self.schedule_dict[comp]["transformation_matrices"] = [np.eye(dim, dim)] + self.schedule_dict[comp]["parallelized_dim"] = None + self.schedule_dict[comp]["unrolling_factor"] = None + self.schedule_dict[comp]["tiling"] = None + self.schedule_dict["tree_structure"] = get_tree_structure(self.annotations) self.templates = dict() - (self.templates["prog_tree"], - self.templates["comps_repr_templates_list"], - self.templates["loops_repr_templates_list"], - self.templates["comps_placeholders_indices_dict"], - self.templates["loops_placeholders_indices_dict"]) = get_sched_rep( - self.annotations, - self.schedule_dict, - max_depth=self.MAX_DEPTH - 1) + ( + self.templates["prog_tree"], + self.templates["comps_repr_templates_list"], + self.templates["loops_repr_templates_list"], + self.templates["comps_placeholders_indices_dict"], + self.templates["loops_placeholders_indices_dict"], + ) = get_sched_rep( + self.annotations, self.schedule_dict, max_depth=self.MAX_DEPTH - 1 + ) self.schedule_dict["fusions"] = [] self.placeholders = self.comps_placeholders @@ -104,37 +107,40 @@ def get_representation(self): self.repr = {} 
self.repr["representation"] = np.empty((0, 1052), np.float32) self.repr["loops_representation"] = np.empty((0, 26), np.float32) - self.repr['child_list'] = np.empty((0, 11), np.float32) - self.repr['has_comps'] = np.empty((0, 12), np.float32) + self.repr["child_list"] = np.empty((0, 11), np.float32) + self.repr["has_comps"] = np.empty((0, 12), np.float32) self.repr["prog_tree"] = np.empty((0, 5000), np.float32) - self.repr['computations_indices'] = np.empty((0, 5), np.float32) + self.repr["computations_indices"] = np.empty((0, 5), np.float32) # Initialize the representation vectors for i in range(5): if i >= len(self.prog_rep): self.repr["representation"] = np.vstack( - [self.repr["representation"], - np.zeros(1052)]) + [self.repr["representation"], np.zeros(1052)] + ) else: - self.repr["representation"] = np.vstack([ - self.repr["representation"], - np.array([self.prog_rep[i]], dtype=np.float32) - ]) + self.repr["representation"] = np.vstack( + [ + self.repr["representation"], + np.array([self.prog_rep[i]], dtype=np.float32), + ] + ) # Dict of the iterators of each computation self.it_dict = {} for comp in self.comps: comp_it_dict = {} - iterators = list( - self.annotations["computations"][comp]["iterators"]) + iterators = list(self.annotations["computations"][comp]["iterators"]) for i in range(len(iterators)): comp_it_dict[i] = {} - comp_it_dict[i]['iterator'] = iterators[i] - comp_it_dict[i]['lower_bound'] = self.annotations['iterators'][ - iterators[i]]['lower_bound'] - comp_it_dict[i]['upper_bound'] = self.annotations['iterators'][ - iterators[i]]['upper_bound'] + comp_it_dict[i]["iterator"] = iterators[i] + comp_it_dict[i]["lower_bound"] = self.annotations["iterators"][ + iterators[i] + ]["lower_bound"] + comp_it_dict[i]["upper_bound"] = self.annotations["iterators"][ + iterators[i] + ]["upper_bound"] self.it_dict[comp] = comp_it_dict @@ -142,22 +148,22 @@ def get_representation(self): iterators = list(self.annotations["iterators"].keys()) for i in 
range(len(iterators)): loop_repr = [] - loop_repr.append( - self.annotations['iterators'][iterators[i]]['lower_bound']) - loop_repr.append( - self.annotations['iterators'][iterators[i]]['upper_bound']) + loop_repr.append(self.annotations["iterators"][iterators[i]]["lower_bound"]) + loop_repr.append(self.annotations["iterators"][iterators[i]]["upper_bound"]) loop_repr.extend([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) # Add log of the same representation to enable feature multiplication in models loop_log_rep = list(np.log1p(loop_repr)) loop_repr.extend(loop_log_rep) self.repr["loops_representation"] = np.vstack( - [self.repr["loops_representation"], - np.array([loop_repr])]) + [self.repr["loops_representation"], np.array([loop_repr])] + ) # Indices of the nested iterators childs_indexes = [ - iterators.index(child) for child in - self.annotations['iterators'][iterators[i]]['child_iterators'] + iterators.index(child) + for child in self.annotations["iterators"][iterators[i]][ + "child_iterators" + ] ] # Max of loops is 12. 
Pad to the MAX with -1 if len(childs_indexes) != 11: @@ -165,136 +171,439 @@ def get_representation(self): childs_indexes.append(-1) self.repr["child_list"] = np.vstack( - [self.repr["child_list"], - np.array([childs_indexes])]) + [self.repr["child_list"], np.array([childs_indexes])] + ) # Set the iterator's boolean in the has_comps vector - if self.annotations['iterators'][iterators[i]]['computations_list'] != []: - self.repr['has_comps'] = np.append(self.repr['has_comps'], 1) + if self.annotations["iterators"][iterators[i]]["computations_list"] != []: + self.repr["has_comps"] = np.append(self.repr["has_comps"], 1) else: - self.repr['has_comps'] = np.append(self.repr['has_comps'], 0) + self.repr["has_comps"] = np.append(self.repr["has_comps"], 0) # Set the iterator's computations_indices embedding - computations_list = list(self.annotations['computations'].keys()) + computations_list = list(self.annotations["computations"].keys()) loop_comps = [ computations_list.index(comp) - for comp in self.annotations['iterators'][iterators[i]] - ['computations_list'] + for comp in self.annotations["iterators"][iterators[i]][ + "computations_list" + ] ] if len(loop_comps) != 5: for j in range(5 - len(loop_comps)): loop_comps.append(-1) self.repr["computations_indices"] = np.vstack( - [self.repr["computations_indices"], - np.array([loop_comps])]) + [self.repr["computations_indices"], np.array([loop_comps])] + ) # Pad the loops representation to the MAX loops for i in range(15 - len(self.annotations["iterators"])): loop_repr = np.full(26, -1) self.repr["loops_representation"] = np.vstack( - [self.repr["loops_representation"], loop_repr]) + [self.repr["loops_representation"], loop_repr] + ) # Pad the child_list, has_comps, and computations_indices to the MAX nested iterators for i in range(12 - len(self.annotations["iterators"])): self.repr["child_list"] = np.vstack( - [self.repr["child_list"], - np.full(11, -1)]) - self.repr['has_comps'] = np.append(self.repr['has_comps'], 0) + 
[self.repr["child_list"], np.full(11, -1)] + ) + self.repr["has_comps"] = np.append(self.repr["has_comps"], 0) self.repr["computations_indices"] = np.vstack( - [self.repr["computations_indices"], - np.full(5, -1)]) + [self.repr["computations_indices"], np.full(5, -1)] + ) # Disable some actions based on the number of iterators available if len(self.common_it) == 5: - self.repr["action_mask"] = np.array([ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 - ], - dtype=np.float32) + self.repr["action_mask"] = np.array( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + dtype=np.float32, + ) else: if len(self.common_it) == 4: - self.repr["action_mask"] = np.array([ - 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, - 1, 1 - ], - dtype=np.float32) + self.repr["action_mask"] = np.array( + [ + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + ], + dtype=np.float32, + ) else: if len(self.common_it) == 3: - self.repr["action_mask"] = np.array([ - 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1 - ], - dtype=np.float32) + 
self.repr["action_mask"] = np.array( + [ + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + ], + dtype=np.float32, + ) else: if len(self.common_it) == 2: - self.repr["action_mask"] = np.array([ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1 - ], - dtype=np.float32) + self.repr["action_mask"] = np.array( + [ + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 0, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + ], + dtype=np.float32, + ) else: if len(self.common_it) == 1: self.repr["action_mask"] = np.array( [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, ], - dtype=np.float32) + dtype=np.float32, + ) if len(self.comps) == 1: - np.put(self.repr["action_mask"], [56, 57, 58, 59, 60], - [0, 0, 0, 0, 0]) + np.put(self.repr["action_mask"], [56, 57, 58, 59, 60], [0, 0, 0, 0, 0]) # Set up the prog_tree representation in float to bypass gym max_size = 5000 string = json.dumps(self.templates["prog_tree"]) padded_string = 
string + (max_size - len(string)) * "_" - self.repr["prog_tree"] = np.array( - list(padded_string), "U1").view(np.float32) + self.repr["prog_tree"] = np.array(list(padded_string), "U1").view(np.float32) return self.repr def apply_interchange(self, params): for comp in self.comps: - l_code = "L" + self.it_dict[comp][ - params["first_dim_index"]]['iterator'] + l_code = "L" + self.it_dict[comp][params["first_dim_index"]]["iterator"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + "Interchanged"]] = 1 - l_code = "L" + self.it_dict[comp][ - params["second_dim_index"]]['iterator'] + self.placeholders[comp][l_code + "Interchanged"] + ] = 1 + l_code = "L" + self.it_dict[comp][params["second_dim_index"]]["iterator"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + "Interchanged"]] = 1 + self.placeholders[comp][l_code + "Interchanged"] + ] = 1 iterators = list(self.annotations["iterators"].keys()) - if self.it_dict[comp][ - params["first_dim_index"]]['iterator'] in iterators: + if self.it_dict[comp][params["first_dim_index"]]["iterator"] in iterators: loop_1 = iterators.index( - self.it_dict[comp][params["first_dim_index"]]['iterator']) - elif self.it_dict[comp][ - params["first_dim_index"]]['iterator'] in self.added_iterators: - loop_1 = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][params["first_dim_index"]]['iterator']) + self.it_dict[comp][params["first_dim_index"]]["iterator"] + ) + elif ( + self.it_dict[comp][params["first_dim_index"]]["iterator"] + in self.added_iterators + ): + loop_1 = len(self.annotations["iterators"]) + self.added_iterators.index( + self.it_dict[comp][params["first_dim_index"]]["iterator"] + ) self.repr["loops_representation"][loop_1][2] = 1 - if self.it_dict[comp][ - params["second_dim_index"]]['iterator'] in iterators: + if self.it_dict[comp][params["second_dim_index"]]["iterator"] in iterators: loop_2 = 
iterators.index( - self.it_dict[comp][params["second_dim_index"]]['iterator']) - elif self.it_dict[comp][params["second_dim_index"]][ - 'iterator'] in self.added_iterators: - loop_2 = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][params["second_dim_index"]]['iterator']) + self.it_dict[comp][params["second_dim_index"]]["iterator"] + ) + elif ( + self.it_dict[comp][params["second_dim_index"]]["iterator"] + in self.added_iterators + ): + loop_2 = len(self.annotations["iterators"]) + self.added_iterators.index( + self.it_dict[comp][params["second_dim_index"]]["iterator"] + ) self.repr["loops_representation"][loop_2][2] = 1 for i in range(28): @@ -312,10 +621,11 @@ def apply_interchange(self, params): interchange_matrix[first_iter_index, second_iter_index] = 1 interchange_matrix[second_iter_index, first_iter_index] = 1 self.schedule_dict[comp]["transformation_matrices"].append( - interchange_matrix) - self.schedule_dict[comp][ - "transformation_matrix"] = interchange_matrix @ self.schedule_dict[ - comp]["transformation_matrix"] + interchange_matrix + ) + self.schedule_dict[comp]["transformation_matrix"] = ( + interchange_matrix @ self.schedule_dict[comp]["transformation_matrix"] + ) def apply_tiling(self, params): for comp in self.comps: @@ -323,48 +633,60 @@ def apply_tiling(self, params): first_dim_index = params["first_dim_index"] second_dim_index = params["second_dim_index"] - self.schedule_dict[comp]['tiling'] = { - 'tiling_depth': - params["tiling_depth"], - 'tiling_dims': [ - self.it_dict[comp][first_dim_index]['iterator'], - self.it_dict[comp][second_dim_index]['iterator'] - ], - 'tiling_factors': - [params["first_factor"], params["second_factor"]] + self.schedule_dict[comp]["tiling"] = { + "tiling_depth": params["tiling_depth"], + "tiling_dims": [ + self.it_dict[comp][first_dim_index]["iterator"], + self.it_dict[comp][second_dim_index]["iterator"], + ], + "tiling_factors": [params["first_factor"], 
params["second_factor"]], } - l_code = "L" + self.it_dict[comp][first_dim_index]['iterator'] + l_code = "L" + self.it_dict[comp][first_dim_index]["iterator"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + "Tiled"]] = 1 + self.placeholders[comp][l_code + "Tiled"] + ] = 1 self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + - "TileFactor"]] = params["first_factor"] + self.placeholders[comp][l_code + "TileFactor"] + ] = params["first_factor"] if params["tiling_loop_1"]: - new_upper_bound_1 = self.repr["representation"][ - self.comp_indic_dict[comp]][first_dim_index * 20 + - 1] / params["first_factor"] + new_upper_bound_1 = ( + self.repr["representation"][self.comp_indic_dict[comp]][ + first_dim_index * 20 + 1 + ] + / params["first_factor"] + ) self.repr["representation"][self.comp_indic_dict[comp]][ - first_dim_index * 20 + 1] = new_upper_bound_1 + first_dim_index * 20 + 1 + ] = new_upper_bound_1 new_inner_upper_bound_1 = params["first_factor"] self.repr["representation"][self.comp_indic_dict[comp]][ - first_dim_index * 20 + 10] = new_inner_upper_bound_1 + first_dim_index * 20 + 10 + ] = new_inner_upper_bound_1 loop_added = "{}_1".format( - self.it_dict[comp][first_dim_index]['iterator']) + self.it_dict[comp][first_dim_index]["iterator"] + ) self.added_iterators.append(loop_added) - loop_index = len(self.annotations['iterators'] - ) + self.added_iterators.index(loop_added) + loop_index = len( + self.annotations["iterators"] + ) + self.added_iterators.index(loop_added) loop_repr = [] - if self.repr["representation"][comp_index][ - self.placeholders[comp][l_code + "Reversed"]] == 1: + if ( + self.repr["representation"][comp_index][ + self.placeholders[comp][l_code + "Reversed"] + ] + == 1 + ): lower_bound = self.repr["representation"][comp_index][ - second_dim_index * 20 + 1] + second_dim_index * 20 + 1 + ] else: lower_bound = self.repr["representation"][comp_index][ - second_dim_index * 
20] + second_dim_index * 20 + ] loop_repr.extend([lower_bound, params["first_factor"]]) loop_repr.extend([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) @@ -372,39 +694,53 @@ def apply_tiling(self, params): loop_repr.extend(loop_log_rep) self.repr["loops_representation"][loop_index] = loop_repr - l_code = "L" + self.it_dict[comp][second_dim_index]['iterator'] + l_code = "L" + self.it_dict[comp][second_dim_index]["iterator"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + "Tiled"]] = 1 + self.placeholders[comp][l_code + "Tiled"] + ] = 1 self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][ - l_code + "TileFactor"]] = params["second_factor"] + self.placeholders[comp][l_code + "TileFactor"] + ] = params["second_factor"] if params["tiling_loop_2"]: - new_upper_bound_2 = self.repr["representation"][ - self.comp_indic_dict[comp]][second_dim_index * 20 + - 1] / params["second_factor"] + new_upper_bound_2 = ( + self.repr["representation"][self.comp_indic_dict[comp]][ + second_dim_index * 20 + 1 + ] + / params["second_factor"] + ) self.repr["representation"][self.comp_indic_dict[comp]][ - second_dim_index * 20 + 1] = new_upper_bound_2 + second_dim_index * 20 + 1 + ] = new_upper_bound_2 new_inner_upper_bound_2 = params["second_factor"] self.repr["representation"][self.comp_indic_dict[comp]][ - second_dim_index * 20 + 10] = new_inner_upper_bound_2 + second_dim_index * 20 + 10 + ] = new_inner_upper_bound_2 loop_added = "{}_1".format( - self.it_dict[comp][second_dim_index]['iterator']) + self.it_dict[comp][second_dim_index]["iterator"] + ) self.added_iterators.append(loop_added) - loop_index = len(self.annotations['iterators'] - ) + self.added_iterators.index(loop_added) + loop_index = len( + self.annotations["iterators"] + ) + self.added_iterators.index(loop_added) loop_repr = [] - if self.repr["representation"][comp_index][ - self.placeholders[comp][l_code + "Reversed"]] == 1: + if ( + 
self.repr["representation"][comp_index][ + self.placeholders[comp][l_code + "Reversed"] + ] + == 1 + ): lower_bound = self.repr["representation"][comp_index][ - second_dim_index * 20 + 1] + second_dim_index * 20 + 1 + ] else: lower_bound = self.repr["representation"][comp_index][ - second_dim_index * 20] + second_dim_index * 20 + ] loop_repr.extend([lower_bound, params["second_factor"]]) loop_repr.extend([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) @@ -414,51 +750,65 @@ def apply_tiling(self, params): if params["tiling_depth"] == 3: third_dim_index = params["third_dim_index"] - self.schedule_dict[comp]['tiling'] = { - 'tiling_depth': - params["tiling_depth"], - 'tiling_dims': [ - self.it_dict[comp][first_dim_index]['iterator'], - self.it_dict[comp][second_dim_index]['iterator'], - self.it_dict[comp][third_dim_index]['iterator'] - ], - 'tiling_factors': [ - params["first_factor"], params["second_factor"], - params["third_factor"] - ] + self.schedule_dict[comp]["tiling"] = { + "tiling_depth": params["tiling_depth"], + "tiling_dims": [ + self.it_dict[comp][first_dim_index]["iterator"], + self.it_dict[comp][second_dim_index]["iterator"], + self.it_dict[comp][third_dim_index]["iterator"], + ], + "tiling_factors": [ + params["first_factor"], + params["second_factor"], + params["third_factor"], + ], } - l_code = "L" + self.it_dict[comp][third_dim_index]['iterator'] + l_code = "L" + self.it_dict[comp][third_dim_index]["iterator"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + "Tiled"]] = 1 + self.placeholders[comp][l_code + "Tiled"] + ] = 1 self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][ - l_code + "TileFactor"]] = params["third_factor"] + self.placeholders[comp][l_code + "TileFactor"] + ] = params["third_factor"] if params["tiling_loop_3"]: - new_upper_bound_3 = self.repr["representation"][ - self.comp_indic_dict[comp]][third_dim_index * 20 + - 1] / params["third_factor"] + new_upper_bound_3 = ( + 
self.repr["representation"][self.comp_indic_dict[comp]][ + third_dim_index * 20 + 1 + ] + / params["third_factor"] + ) self.repr["representation"][self.comp_indic_dict[comp]][ - third_dim_index * 20 + 1] = new_upper_bound_3 + third_dim_index * 20 + 1 + ] = new_upper_bound_3 new_inner_upper_bound_3 = params["third_factor"] self.repr["representation"][self.comp_indic_dict[comp]][ - third_dim_index * 20 + 10] = new_inner_upper_bound_3 + third_dim_index * 20 + 10 + ] = new_inner_upper_bound_3 loop_added = "{}_1".format( - self.it_dict[comp][third_dim_index]['iterator']) + self.it_dict[comp][third_dim_index]["iterator"] + ) self.added_iterators.append(loop_added) - loop_index = len(self.annotations['iterators'] - ) + self.added_iterators.index(loop_added) + loop_index = len( + self.annotations["iterators"] + ) + self.added_iterators.index(loop_added) loop_repr = [] - if self.repr["representation"][comp_index][ - self.placeholders[comp][l_code + "Reversed"]] == 1: + if ( + self.repr["representation"][comp_index][ + self.placeholders[comp][l_code + "Reversed"] + ] + == 1 + ): lower_bound = self.repr["representation"][comp_index][ - third_dim_index * 20 + 1] + third_dim_index * 20 + 1 + ] else: lower_bound = self.repr["representation"][comp_index][ - third_dim_index * 20] + third_dim_index * 20 + ] loop_repr.extend([lower_bound, params["third_factor"]]) @@ -469,364 +819,495 @@ def apply_tiling(self, params): iterators = list(self.annotations["iterators"].keys()) - if self.it_dict[comp][first_dim_index]['iterator'] in iterators: - loop_1 = iterators.index( - self.it_dict[comp][first_dim_index]['iterator']) - elif self.it_dict[comp][first_dim_index][ - 'iterator'] in self.added_iterators: - loop_1 = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][first_dim_index]['iterator']) + if self.it_dict[comp][first_dim_index]["iterator"] in iterators: + loop_1 = iterators.index(self.it_dict[comp][first_dim_index]["iterator"]) + elif 
self.it_dict[comp][first_dim_index]["iterator"] in self.added_iterators: + loop_1 = len(self.annotations["iterators"]) + self.added_iterators.index( + self.it_dict[comp][first_dim_index]["iterator"] + ) self.repr["loops_representation"][loop_1][3] = 1 - self.repr["loops_representation"][loop_1][4] = params['first_factor'] + self.repr["loops_representation"][loop_1][4] = params["first_factor"] - if self.it_dict[comp][second_dim_index]['iterator'] in iterators: - loop_2 = iterators.index( - self.it_dict[comp][second_dim_index]['iterator']) - elif self.it_dict[comp][second_dim_index][ - 'iterator'] in self.added_iterators: - loop_2 = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][second_dim_index]['iterator']) + if self.it_dict[comp][second_dim_index]["iterator"] in iterators: + loop_2 = iterators.index(self.it_dict[comp][second_dim_index]["iterator"]) + elif self.it_dict[comp][second_dim_index]["iterator"] in self.added_iterators: + loop_2 = len(self.annotations["iterators"]) + self.added_iterators.index( + self.it_dict[comp][second_dim_index]["iterator"] + ) self.repr["loops_representation"][loop_2][3] = 1 - self.repr["loops_representation"][loop_2][4] = params['second_factor'] + self.repr["loops_representation"][loop_2][4] = params["second_factor"] if params["tiling_depth"] == 3: - if self.it_dict[comp][third_dim_index]['iterator'] in iterators: + if self.it_dict[comp][third_dim_index]["iterator"] in iterators: loop_3 = iterators.index( - self.it_dict[comp][third_dim_index]['iterator']) - elif self.it_dict[comp][third_dim_index][ - 'iterator'] in self.added_iterators: + self.it_dict[comp][third_dim_index]["iterator"] + ) + elif ( + self.it_dict[comp][third_dim_index]["iterator"] in self.added_iterators + ): loop_3 = len( - self.annotations['iterators'] + self.annotations["iterators"] ) + self.added_iterators.index( - self.it_dict[comp][third_dim_index]['iterator']) + self.it_dict[comp][third_dim_index]["iterator"] + ) 
self.repr["loops_representation"][loop_3][3] = 1 - self.repr["loops_representation"][loop_3][4] = params[ - 'third_factor'] + self.repr["loops_representation"][loop_3][4] = params["third_factor"] if self.is_interchaged == False: if len(self.common_it) == 5: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE05, - rl_interface.action.Action.INTERCHANGE06, - rl_interface.action.Action.INTERCHANGE07, - rl_interface.action.Action.INTERCHANGE15, - rl_interface.action.Action.INTERCHANGE16, - rl_interface.action.Action.INTERCHANGE17, - rl_interface.action.Action.INTERCHANGE25, - rl_interface.action.Action.INTERCHANGE26, - rl_interface.action.Action.INTERCHANGE27, - rl_interface.action.Action.INTERCHANGE35, - rl_interface.action.Action.INTERCHANGE36, - rl_interface.action.Action.INTERCHANGE37, - rl_interface.action.Action.INTERCHANGE45, - rl_interface.action.Action.INTERCHANGE46, - rl_interface.action.Action.INTERCHANGE47, - rl_interface.action.Action.INTERCHANGE56, - rl_interface.action.Action.INTERCHANGE57, - rl_interface.action.Action.INTERCHANGE67 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE05, - rl_interface.action.Action.INTERCHANGE06, - rl_interface.action.Action.INTERCHANGE15, - rl_interface.action.Action.INTERCHANGE16, - rl_interface.action.Action.INTERCHANGE25, - rl_interface.action.Action.INTERCHANGE26, - rl_interface.action.Action.INTERCHANGE35, - rl_interface.action.Action.INTERCHANGE36, - rl_interface.action.Action.INTERCHANGE45, - rl_interface.action.Action.INTERCHANGE46, - rl_interface.action.Action.INTERCHANGE56 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: - self.repr["action_mask"][[ 
- rl_interface.action.Action.INTERCHANGE05, - rl_interface.action.Action.INTERCHANGE15, - rl_interface.action.Action.INTERCHANGE25, - rl_interface.action.Action.INTERCHANGE35, - rl_interface.action.Action.INTERCHANGE45 - ]] = 1 + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE05, + rl_interface.action.Action.INTERCHANGE06, + rl_interface.action.Action.INTERCHANGE07, + rl_interface.action.Action.INTERCHANGE15, + rl_interface.action.Action.INTERCHANGE16, + rl_interface.action.Action.INTERCHANGE17, + rl_interface.action.Action.INTERCHANGE25, + rl_interface.action.Action.INTERCHANGE26, + rl_interface.action.Action.INTERCHANGE27, + rl_interface.action.Action.INTERCHANGE35, + rl_interface.action.Action.INTERCHANGE36, + rl_interface.action.Action.INTERCHANGE37, + rl_interface.action.Action.INTERCHANGE45, + rl_interface.action.Action.INTERCHANGE46, + rl_interface.action.Action.INTERCHANGE47, + rl_interface.action.Action.INTERCHANGE56, + rl_interface.action.Action.INTERCHANGE57, + rl_interface.action.Action.INTERCHANGE67, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE05, + rl_interface.action.Action.INTERCHANGE06, + rl_interface.action.Action.INTERCHANGE15, + rl_interface.action.Action.INTERCHANGE16, + rl_interface.action.Action.INTERCHANGE25, + rl_interface.action.Action.INTERCHANGE26, + rl_interface.action.Action.INTERCHANGE35, + rl_interface.action.Action.INTERCHANGE36, + rl_interface.action.Action.INTERCHANGE45, + rl_interface.action.Action.INTERCHANGE46, + rl_interface.action.Action.INTERCHANGE56, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + 
self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE05, + rl_interface.action.Action.INTERCHANGE15, + rl_interface.action.Action.INTERCHANGE25, + rl_interface.action.Action.INTERCHANGE35, + rl_interface.action.Action.INTERCHANGE45, + ] + ] = 1 if len(self.common_it) == 4: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE04, - rl_interface.action.Action.INTERCHANGE05, - rl_interface.action.Action.INTERCHANGE06, - rl_interface.action.Action.INTERCHANGE14, - rl_interface.action.Action.INTERCHANGE15, - rl_interface.action.Action.INTERCHANGE16, - rl_interface.action.Action.INTERCHANGE24, - rl_interface.action.Action.INTERCHANGE25, - rl_interface.action.Action.INTERCHANGE26, - rl_interface.action.Action.INTERCHANGE34, - rl_interface.action.Action.INTERCHANGE35, - rl_interface.action.Action.INTERCHANGE36, - rl_interface.action.Action.INTERCHANGE45, - rl_interface.action.Action.INTERCHANGE46, - rl_interface.action.Action.INTERCHANGE56 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE04, - rl_interface.action.Action.INTERCHANGE05, - rl_interface.action.Action.INTERCHANGE14, - rl_interface.action.Action.INTERCHANGE15, - rl_interface.action.Action.INTERCHANGE24, - rl_interface.action.Action.INTERCHANGE25, - rl_interface.action.Action.INTERCHANGE34, - rl_interface.action.Action.INTERCHANGE35, - rl_interface.action.Action.INTERCHANGE45 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE04, - rl_interface.action.Action.INTERCHANGE14, - rl_interface.action.Action.INTERCHANGE24, - rl_interface.action.Action.INTERCHANGE34 - ]] = 1 + 
if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE04, + rl_interface.action.Action.INTERCHANGE05, + rl_interface.action.Action.INTERCHANGE06, + rl_interface.action.Action.INTERCHANGE14, + rl_interface.action.Action.INTERCHANGE15, + rl_interface.action.Action.INTERCHANGE16, + rl_interface.action.Action.INTERCHANGE24, + rl_interface.action.Action.INTERCHANGE25, + rl_interface.action.Action.INTERCHANGE26, + rl_interface.action.Action.INTERCHANGE34, + rl_interface.action.Action.INTERCHANGE35, + rl_interface.action.Action.INTERCHANGE36, + rl_interface.action.Action.INTERCHANGE45, + rl_interface.action.Action.INTERCHANGE46, + rl_interface.action.Action.INTERCHANGE56, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE04, + rl_interface.action.Action.INTERCHANGE05, + rl_interface.action.Action.INTERCHANGE14, + rl_interface.action.Action.INTERCHANGE15, + rl_interface.action.Action.INTERCHANGE24, + rl_interface.action.Action.INTERCHANGE25, + rl_interface.action.Action.INTERCHANGE34, + rl_interface.action.Action.INTERCHANGE35, + rl_interface.action.Action.INTERCHANGE45, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE04, + rl_interface.action.Action.INTERCHANGE14, + rl_interface.action.Action.INTERCHANGE24, + rl_interface.action.Action.INTERCHANGE34, + ] + ] = 1 if len(self.common_it) == 3: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE03, - rl_interface.action.Action.INTERCHANGE04, - 
rl_interface.action.Action.INTERCHANGE05, - rl_interface.action.Action.INTERCHANGE13, - rl_interface.action.Action.INTERCHANGE14, - rl_interface.action.Action.INTERCHANGE15, - rl_interface.action.Action.INTERCHANGE23, - rl_interface.action.Action.INTERCHANGE24, - rl_interface.action.Action.INTERCHANGE25, - rl_interface.action.Action.INTERCHANGE34, - rl_interface.action.Action.INTERCHANGE35, - rl_interface.action.Action.INTERCHANGE45 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE03, - rl_interface.action.Action.INTERCHANGE04, - rl_interface.action.Action.INTERCHANGE13, - rl_interface.action.Action.INTERCHANGE14, - rl_interface.action.Action.INTERCHANGE23, - rl_interface.action.Action.INTERCHANGE24, - rl_interface.action.Action.INTERCHANGE34 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE03, - rl_interface.action.Action.INTERCHANGE13, - rl_interface.action.Action.INTERCHANGE23 - ]] = 1 + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE03, + rl_interface.action.Action.INTERCHANGE04, + rl_interface.action.Action.INTERCHANGE05, + rl_interface.action.Action.INTERCHANGE13, + rl_interface.action.Action.INTERCHANGE14, + rl_interface.action.Action.INTERCHANGE15, + rl_interface.action.Action.INTERCHANGE23, + rl_interface.action.Action.INTERCHANGE24, + rl_interface.action.Action.INTERCHANGE25, + rl_interface.action.Action.INTERCHANGE34, + rl_interface.action.Action.INTERCHANGE35, + rl_interface.action.Action.INTERCHANGE45, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + 
and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE03, + rl_interface.action.Action.INTERCHANGE04, + rl_interface.action.Action.INTERCHANGE13, + rl_interface.action.Action.INTERCHANGE14, + rl_interface.action.Action.INTERCHANGE23, + rl_interface.action.Action.INTERCHANGE24, + rl_interface.action.Action.INTERCHANGE34, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE03, + rl_interface.action.Action.INTERCHANGE13, + rl_interface.action.Action.INTERCHANGE23, + ] + ] = 1 if len(self.common_it) == 2: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE02, - rl_interface.action.Action.INTERCHANGE03, - rl_interface.action.Action.INTERCHANGE04, - rl_interface.action.Action.INTERCHANGE12, - rl_interface.action.Action.INTERCHANGE13, - rl_interface.action.Action.INTERCHANGE14, - rl_interface.action.Action.INTERCHANGE23, - rl_interface.action.Action.INTERCHANGE24, - rl_interface.action.Action.INTERCHANGE34 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE02, - rl_interface.action.Action.INTERCHANGE03, - rl_interface.action.Action.INTERCHANGE12, - rl_interface.action.Action.INTERCHANGE13, - rl_interface.action.Action.INTERCHANGE23 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE02, - rl_interface.action.Action.INTERCHANGE12 - ]] = 1 + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and 
params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE02, + rl_interface.action.Action.INTERCHANGE03, + rl_interface.action.Action.INTERCHANGE04, + rl_interface.action.Action.INTERCHANGE12, + rl_interface.action.Action.INTERCHANGE13, + rl_interface.action.Action.INTERCHANGE14, + rl_interface.action.Action.INTERCHANGE23, + rl_interface.action.Action.INTERCHANGE24, + rl_interface.action.Action.INTERCHANGE34, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE02, + rl_interface.action.Action.INTERCHANGE03, + rl_interface.action.Action.INTERCHANGE12, + rl_interface.action.Action.INTERCHANGE13, + rl_interface.action.Action.INTERCHANGE23, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE02, + rl_interface.action.Action.INTERCHANGE12, + ] + ] = 1 if len(self.common_it) == 1: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE01, - rl_interface.action.Action.INTERCHANGE02, - rl_interface.action.Action.INTERCHANGE03, - rl_interface.action.Action.INTERCHANGE12, - rl_interface.action.Action.INTERCHANGE13, - rl_interface.action.Action.INTERCHANGE23 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE01, - rl_interface.action.Action.INTERCHANGE02, - rl_interface.action.Action.INTERCHANGE12, - rl_interface.action.Action.INTERCHANGE13 - ]] = 1 - elif params["tiling_loop_1"] 
or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.INTERCHANGE01 - ]] = 1 + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE01, + rl_interface.action.Action.INTERCHANGE02, + rl_interface.action.Action.INTERCHANGE03, + rl_interface.action.Action.INTERCHANGE12, + rl_interface.action.Action.INTERCHANGE13, + rl_interface.action.Action.INTERCHANGE23, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.INTERCHANGE01, + rl_interface.action.Action.INTERCHANGE02, + rl_interface.action.Action.INTERCHANGE12, + rl_interface.action.Action.INTERCHANGE13, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [rl_interface.action.Action.INTERCHANGE01] + ] = 1 if self.is_reversed == False: if len(self.common_it) == 5: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL5, - rl_interface.action.Action.REVERSAL6, - rl_interface.action.Action.REVERSAL7 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL5, - rl_interface.action.Action.REVERSAL6 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.REVERSAL5, + 
rl_interface.action.Action.REVERSAL6, + rl_interface.action.Action.REVERSAL7, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): self.repr["action_mask"][ - rl_interface.action.Action.REVERSAL5] = 1 + [ + rl_interface.action.Action.REVERSAL5, + rl_interface.action.Action.REVERSAL6, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + self.repr["action_mask"][ + rl_interface.action.Action.REVERSAL5 + ] = 1 elif len(self.common_it) == 4: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL4, - rl_interface.action.Action.REVERSAL5, - rl_interface.action.Action.REVERSAL6 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL4, - rl_interface.action.Action.REVERSAL5 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.REVERSAL4, + rl_interface.action.Action.REVERSAL5, + rl_interface.action.Action.REVERSAL6, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.REVERSAL4, + rl_interface.action.Action.REVERSAL5, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): self.repr["action_mask"][ - 
rl_interface.action.Action.REVERSAL4] = 1 + rl_interface.action.Action.REVERSAL4 + ] = 1 elif len(self.common_it) == 3: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL3, - rl_interface.action.Action.REVERSAL4, - rl_interface.action.Action.REVERSAL5 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL3, - rl_interface.action.Action.REVERSAL4 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.REVERSAL3, + rl_interface.action.Action.REVERSAL4, + rl_interface.action.Action.REVERSAL5, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.REVERSAL3, + rl_interface.action.Action.REVERSAL4, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): self.repr["action_mask"][ - rl_interface.action.Action.REVERSAL3] = 1 + rl_interface.action.Action.REVERSAL3 + ] = 1 elif len(self.common_it) == 2: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL2, - rl_interface.action.Action.REVERSAL3, - rl_interface.action.Action.REVERSAL4 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - 
"tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL2, - rl_interface.action.Action.REVERSAL3 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.REVERSAL2, + rl_interface.action.Action.REVERSAL3, + rl_interface.action.Action.REVERSAL4, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): self.repr["action_mask"][ - rl_interface.action.Action.REVERSAL2] = 1 + [ + rl_interface.action.Action.REVERSAL2, + rl_interface.action.Action.REVERSAL3, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + self.repr["action_mask"][ + rl_interface.action.Action.REVERSAL2 + ] = 1 elif len(self.common_it) == 1: - if params["tiling_loop_1"] and params[ - "tiling_loop_2"] and params["tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL1, - rl_interface.action.Action.REVERSAL2, - rl_interface.action.Action.REVERSAL3 - ]] = 1 - elif params["tiling_loop_1"] and params[ - "tiling_loop_2"] or params[ - "tiling_loop_2"] and params[ - "tiling_loop_3"] or params[ - "tiling_loop_1"] and params[ - "tiling_loop_3"]: - self.repr["action_mask"][[ - rl_interface.action.Action.REVERSAL1, - rl_interface.action.Action.REVERSAL2 - ]] = 1 - elif params["tiling_loop_1"] or params[ - "tiling_loop_2"] or params["tiling_loop_3"]: + if ( + params["tiling_loop_1"] + and params["tiling_loop_2"] + and params["tiling_loop_3"] + ): + self.repr["action_mask"][ + [ + rl_interface.action.Action.REVERSAL1, + rl_interface.action.Action.REVERSAL2, + rl_interface.action.Action.REVERSAL3, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + and 
params["tiling_loop_2"] + or params["tiling_loop_2"] + and params["tiling_loop_3"] + or params["tiling_loop_1"] + and params["tiling_loop_3"] + ): self.repr["action_mask"][ - rl_interface.action.Action.REVERSAL1] = 1 + [ + rl_interface.action.Action.REVERSAL1, + rl_interface.action.Action.REVERSAL2, + ] + ] = 1 + elif ( + params["tiling_loop_1"] + or params["tiling_loop_2"] + or params["tiling_loop_3"] + ): + self.repr["action_mask"][ + rl_interface.action.Action.REVERSAL1 + ] = 1 for i in range(28, 41): self.repr["action_mask"][i] = 0 @@ -839,34 +1320,41 @@ def apply_unrolling(self, params): for comp in self.comps: self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp]["Unrolled"]] = 1 + self.placeholders[comp]["Unrolled"] + ] = 1 self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp] - ["UnrollFactor"]] = params[comp]["unrolling_factor"] + self.placeholders[comp]["UnrollFactor"] + ] = params[comp]["unrolling_factor"] - l_code = "L" + self.it_dict[comp][params[comp] - ["dim_index"]]['iterator'] - index_upper_bound = self.placeholders[comp][l_code + - 'Interchanged'] - 1 + l_code = "L" + self.it_dict[comp][params[comp]["dim_index"]]["iterator"] + index_upper_bound = self.placeholders[comp][l_code + "Interchanged"] - 1 self.repr["representation"][self.comp_indic_dict[comp]][ - index_upper_bound] = self.repr["representation"][ - self.comp_indic_dict[comp]][index_upper_bound] / params[ - comp]["unrolling_factor"] + index_upper_bound + ] = ( + self.repr["representation"][self.comp_indic_dict[comp]][ + index_upper_bound + ] + / params[comp]["unrolling_factor"] + ) iterators = list(self.annotations["iterators"].keys()) - if self.it_dict[comp][params[comp] - ["dim_index"]]['iterator'] in iterators: + if self.it_dict[comp][params[comp]["dim_index"]]["iterator"] in iterators: loop_index = iterators.index( - self.it_dict[comp][params[comp]["dim_index"]]['iterator']) - elif self.it_dict[comp][params[comp]["dim_index"]][ 
- 'iterator'] in self.added_iterators: + self.it_dict[comp][params[comp]["dim_index"]]["iterator"] + ) + elif ( + self.it_dict[comp][params[comp]["dim_index"]]["iterator"] + in self.added_iterators + ): loop_index = len( - self.annotations['iterators'] + self.annotations["iterators"] ) + self.added_iterators.index( - self.it_dict[comp][params[comp]["dim_index"]]['iterator']) + self.it_dict[comp][params[comp]["dim_index"]]["iterator"] + ) self.repr["loops_representation"][loop_index][5] = 1 self.repr["loops_representation"][loop_index][6] = params[comp][ - 'unrolling_factor'] + "unrolling_factor" + ] for i in range(41, 44): self.repr["action_mask"][i] = 0 @@ -876,7 +1364,8 @@ def apply_unrolling(self, params): try: for comp in self.comps: self.schedule_dict[comp]["unrolling_factor"] = params[comp][ - "unrolling_factor"] + "unrolling_factor" + ] except Exception: print("ERROR_MODEL", traceback.format_exc()) @@ -885,78 +1374,89 @@ def apply_skewing(self, params): dim_2 = params["second_dim_index"] for comp in self.comps: - l1_code = "L" + self.it_dict[comp][dim_1]['iterator'] - l2_code = "L" + self.it_dict[comp][dim_2]['iterator'] - - index1_upper_bound = self.placeholders[comp][l1_code + - 'Interchanged'] - 1 - index1_lower_bound = self.placeholders[comp][l1_code + - 'Interchanged'] - 2 - index2_upper_bound = self.placeholders[comp][l2_code + - 'Interchanged'] - 1 - index2_lower_bound = self.placeholders[comp][l2_code + - 'Interchanged'] - 2 - - l1_lower_bound = self.repr["representation"][ - self.comp_indic_dict[comp]][index1_lower_bound] - l1_upper_bound = self.repr["representation"][ - self.comp_indic_dict[comp]][index1_upper_bound] - l2_lower_bound = self.repr["representation"][ - self.comp_indic_dict[comp]][index2_lower_bound] - l2_upper_bound = self.repr["representation"][ - self.comp_indic_dict[comp]][index2_upper_bound] + l1_code = "L" + self.it_dict[comp][dim_1]["iterator"] + l2_code = "L" + self.it_dict[comp][dim_2]["iterator"] + + index1_upper_bound = 
self.placeholders[comp][l1_code + "Interchanged"] - 1 + index1_lower_bound = self.placeholders[comp][l1_code + "Interchanged"] - 2 + index2_upper_bound = self.placeholders[comp][l2_code + "Interchanged"] - 1 + index2_lower_bound = self.placeholders[comp][l2_code + "Interchanged"] - 2 + + l1_lower_bound = self.repr["representation"][self.comp_indic_dict[comp]][ + index1_lower_bound + ] + l1_upper_bound = self.repr["representation"][self.comp_indic_dict[comp]][ + index1_upper_bound + ] + l2_lower_bound = self.repr["representation"][self.comp_indic_dict[comp]][ + index2_lower_bound + ] + l2_upper_bound = self.repr["representation"][self.comp_indic_dict[comp]][ + index2_upper_bound + ] l1_extent = l1_upper_bound - l1_lower_bound l2_extent = l2_upper_bound - l2_lower_bound skew_factor = params["first_factor"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l1_code + "Skewed"]] = 1 + self.placeholders[comp][l1_code + "Skewed"] + ] = 1 + self.repr["representation"][self.comp_indic_dict[comp]][ + self.placeholders[comp][l1_code + "SkewFactor"] + ] = skew_factor self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l1_code + "SkewFactor"]] = skew_factor - self.repr["representation"][ - self.comp_indic_dict[comp]][index1_lower_bound] = abs( - params["first_factor"]) * l1_lower_bound + index1_lower_bound + ] = (abs(params["first_factor"]) * l1_lower_bound) self.repr["representation"][self.comp_indic_dict[comp]][ - index1_upper_bound] = l1_lower_bound + abs( - params["first_factor"]) * l1_extent + abs( - params["second_factor"]) * l2_extent + index1_upper_bound + ] = ( + l1_lower_bound + + abs(params["first_factor"]) * l1_extent + + abs(params["second_factor"]) * l2_extent + ) skew_factor = params["second_factor"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l2_code + "Skewed"]] = 1 + self.placeholders[comp][l2_code + "Skewed"] + ] = 1 
self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l2_code + "SkewFactor"]] = skew_factor - self.repr["representation"][ - self.comp_indic_dict[comp]][index2_lower_bound] = 0 + self.placeholders[comp][l2_code + "SkewFactor"] + ] = skew_factor self.repr["representation"][self.comp_indic_dict[comp]][ - index2_upper_bound] = (l2_extent) + 1 + index2_lower_bound + ] = 0 + self.repr["representation"][self.comp_indic_dict[comp]][ + index2_upper_bound + ] = (l2_extent) + 1 iterators = list(self.annotations["iterators"].keys()) - if self.it_dict[comp][dim_1]['iterator'] in iterators: - loop_1 = iterators.index(self.it_dict[comp][dim_1]['iterator']) - elif self.it_dict[comp][dim_1]['iterator'] in self.added_iterators: - loop_1 = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][dim_1]['iterator']) + if self.it_dict[comp][dim_1]["iterator"] in iterators: + loop_1 = iterators.index(self.it_dict[comp][dim_1]["iterator"]) + elif self.it_dict[comp][dim_1]["iterator"] in self.added_iterators: + loop_1 = len(self.annotations["iterators"]) + self.added_iterators.index( + self.it_dict[comp][dim_1]["iterator"] + ) self.repr["loops_representation"][loop_1][7] = 1 - self.repr["loops_representation"][loop_1][8] = params['first_factor'] - - self.repr["loops_representation"][loop_1][9] = self.repr[ - "representation"][0][index1_upper_bound] - self.repr[ - "representation"][0][index1_lower_bound] - - if self.it_dict[comp][dim_2]['iterator'] in iterators: - loop_2 = iterators.index(self.it_dict[comp][dim_2]['iterator']) - elif self.it_dict[comp][dim_2]['iterator'] in self.added_iterators: - loop_2 = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][dim_2]['iterator']) + self.repr["loops_representation"][loop_1][8] = params["first_factor"] + + self.repr["loops_representation"][loop_1][9] = ( + self.repr["representation"][0][index1_upper_bound] + - 
self.repr["representation"][0][index1_lower_bound] + ) + + if self.it_dict[comp][dim_2]["iterator"] in iterators: + loop_2 = iterators.index(self.it_dict[comp][dim_2]["iterator"]) + elif self.it_dict[comp][dim_2]["iterator"] in self.added_iterators: + loop_2 = len(self.annotations["iterators"]) + self.added_iterators.index( + self.it_dict[comp][dim_2]["iterator"] + ) self.repr["loops_representation"][loop_2][7] = 1 - self.repr["loops_representation"][loop_2][8] = params['second_factor'] - self.repr["loops_representation"][loop_2][9] = self.repr[ - "representation"][0][index2_upper_bound] - self.repr[ - "representation"][0][index2_lower_bound] + self.repr["loops_representation"][loop_2][8] = params["second_factor"] + self.repr["loops_representation"][loop_2][9] = ( + self.repr["representation"][0][index2_upper_bound] + - self.repr["representation"][0][index2_lower_bound] + ) self.repr["action_mask"][44] = 0 self.repr["action_mask"][45] = 0 @@ -974,37 +1474,42 @@ def apply_skewing(self, params): a, b = global_dioph_sols_dict[(first_factor, second_factor)] else: a, b = ScheduleUtils.linear_diophantine_default( - first_factor, second_factor) + first_factor, second_factor + ) skewing_matrix[first_iter_index, first_iter_index] = first_factor skewing_matrix[first_iter_index, second_iter_index] = second_factor skewing_matrix[second_iter_index, first_iter_index] = a skewing_matrix[second_iter_index, second_iter_index] = b - self.schedule_dict[comp]["transformation_matrices"].append( - skewing_matrix) - self.schedule_dict[comp][ - "transformation_matrix"] = skewing_matrix @ self.schedule_dict[ - comp]["transformation_matrix"] + self.schedule_dict[comp]["transformation_matrices"].append(skewing_matrix) + self.schedule_dict[comp]["transformation_matrix"] = ( + skewing_matrix @ self.schedule_dict[comp]["transformation_matrix"] + ) def apply_parallelization(self, params): first_comp = list(self.it_dict.keys())[0] - iterator = 
self.it_dict[first_comp][params["dim_index"]]['iterator'] + iterator = self.it_dict[first_comp][params["dim_index"]]["iterator"] self.schedule_dict[first_comp]["parallelized_dim"] = iterator l_code = "L" + iterator - self.repr["representation"][0][self.placeholders[first_comp][ - l_code + "Parallelized"]] = 1 + self.repr["representation"][0][ + self.placeholders[first_comp][l_code + "Parallelized"] + ] = 1 iterators = list(self.annotations["iterators"].keys()) - if self.it_dict[first_comp][ - params["dim_index"]]['iterator'] in iterators: + if self.it_dict[first_comp][params["dim_index"]]["iterator"] in iterators: loop_index = iterators.index( - self.it_dict[first_comp][params["dim_index"]]['iterator']) - elif self.it_dict[first_comp][ - params["dim_index"]]['iterator'] in self.added_iterators: + self.it_dict[first_comp][params["dim_index"]]["iterator"] + ) + elif ( + self.it_dict[first_comp][params["dim_index"]]["iterator"] + in self.added_iterators + ): loop_index = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[first_comp][params["dim_index"]]['iterator']) + self.annotations["iterators"] + ) + self.added_iterators.index( + self.it_dict[first_comp][params["dim_index"]]["iterator"] + ) self.repr["loops_representation"][loop_index][10] = 1 self.repr["action_mask"][46] = 0 @@ -1014,33 +1519,40 @@ def apply_parallelization(self, params): def apply_reversal(self, params): for comp in self.comps: - l_code = "L" + self.it_dict[comp][params["dim_index"]]['iterator'] + l_code = "L" + self.it_dict[comp][params["dim_index"]]["iterator"] - index_upper_bound = self.placeholders[comp][l_code + - 'Interchanged'] - 1 - index_lower_bound = self.placeholders[comp][l_code + - 'Interchanged'] - 2 + index_upper_bound = self.placeholders[comp][l_code + "Interchanged"] - 1 + index_lower_bound = self.placeholders[comp][l_code + "Interchanged"] - 2 self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + "Reversed"]] 
= 1 + self.placeholders[comp][l_code + "Reversed"] + ] = 1 - tmp = self.repr["representation"][ - self.comp_indic_dict[comp]][index_lower_bound] + tmp = self.repr["representation"][self.comp_indic_dict[comp]][ + index_lower_bound + ] + self.repr["representation"][self.comp_indic_dict[comp]][ + index_lower_bound + ] = self.repr["representation"][self.comp_indic_dict[comp]][ + index_upper_bound + ] self.repr["representation"][self.comp_indic_dict[comp]][ - index_lower_bound] = self.repr["representation"][ - self.comp_indic_dict[comp]][index_upper_bound] - self.repr["representation"][ - self.comp_indic_dict[comp]][index_upper_bound] = tmp + index_upper_bound + ] = tmp iterators = list(self.annotations["iterators"].keys()) - if self.it_dict[comp][params["dim_index"]]['iterator'] in iterators: + if self.it_dict[comp][params["dim_index"]]["iterator"] in iterators: loop_index = iterators.index( - self.it_dict[comp][params["dim_index"]]['iterator']) - elif self.it_dict[comp][ - params["dim_index"]]['iterator'] in self.added_iterators: + self.it_dict[comp][params["dim_index"]]["iterator"] + ) + elif ( + self.it_dict[comp][params["dim_index"]]["iterator"] in self.added_iterators + ): loop_index = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][params["dim_index"]]['iterator']) + self.annotations["iterators"] + ) + self.added_iterators.index( + self.it_dict[comp][params["dim_index"]]["iterator"] + ) self.repr["loops_representation"][loop_index][11] = 1 for i in range(48, 56): @@ -1053,31 +1565,35 @@ def apply_reversal(self, params): reversal_matrix = np.eye(dim, dim) dim_index = params["dim_index"] reversal_matrix[dim_index, dim_index] = -1 - self.schedule_dict[comp]["transformation_matrices"].append( - reversal_matrix) - self.schedule_dict[comp][ - "transformation_matrix"] = reversal_matrix @ self.schedule_dict[ - comp]["transformation_matrix"] + self.schedule_dict[comp]["transformation_matrices"].append(reversal_matrix) + 
self.schedule_dict[comp]["transformation_matrix"] = ( + reversal_matrix @ self.schedule_dict[comp]["transformation_matrix"] + ) def apply_fusion(self, params): fusion = [] for comp in params["fuse_comps"]: fusion.append(comp) - l_code = "L" + self.it_dict[comp][params["dim_index"]]['iterator'] + l_code = "L" + self.it_dict[comp][params["dim_index"]]["iterator"] self.repr["representation"][self.comp_indic_dict[comp]][ - self.placeholders[comp][l_code + "Fused"]] = 1 + self.placeholders[comp][l_code + "Fused"] + ] = 1 fusion.append(params["dim_index"]) self.schedule_dict["fusions"].append(fusion) iterators = list(self.annotations["iterators"].keys()) - if self.it_dict[comp][params["dim_index"]]['iterator'] in iterators: + if self.it_dict[comp][params["dim_index"]]["iterator"] in iterators: loop_index = iterators.index( - self.it_dict[comp][params["dim_index"]]['iterator']) - elif self.it_dict[comp][ - params["dim_index"]]['iterator'] in self.added_iterators: + self.it_dict[comp][params["dim_index"]]["iterator"] + ) + elif ( + self.it_dict[comp][params["dim_index"]]["iterator"] in self.added_iterators + ): loop_index = len( - self.annotations['iterators']) + self.added_iterators.index( - self.it_dict[comp][params["dim_index"]]['iterator']) + self.annotations["iterators"] + ) + self.added_iterators.index( + self.it_dict[comp][params["dim_index"]]["iterator"] + ) self.repr["loops_representation"][loop_index][12] = 1 for i in range(56, 61): diff --git a/tiramisu_programs/schedule_controller.py b/tiramisu_programs/schedule_controller.py index f3737d1..05bbc21 100644 --- a/tiramisu_programs/schedule_controller.py +++ b/tiramisu_programs/schedule_controller.py @@ -1,4 +1,3 @@ - import sys import time import traceback @@ -8,21 +7,28 @@ from tiramisu_programs.optimization import OptimizationCommand from tiramisu_programs.schedule import Schedule -from tiramisu_programs.schedule_utils import ScheduleUtils, IsInterchangedException, IsParallelizedException, IsReversedException, 
IsSkewedException, IsTiledException, IsUnrolledException, SkewParamsException, SkewUnrollException, LCException -from tiramisu_programs.surrogate_model_utils.json_to_tensor import \ - get_schedule_representation -from tiramisu_programs.surrogate_model_utils.modeling import \ - Model_Recursive_LSTM_v2 +from tiramisu_programs.schedule_utils import ( + ScheduleUtils, + IsInterchangedException, + IsParallelizedException, + IsReversedException, + IsSkewedException, + IsTiledException, + IsUnrolledException, + SkewParamsException, + SkewUnrollException, + LCException, +) +from tiramisu_programs.surrogate_model_utils.json_to_tensor import ( + get_schedule_representation, +) +from tiramisu_programs.surrogate_model_utils.modeling import Model_Recursive_LSTM_v2 global_dioph_sols_dict = dict() class ScheduleController: - - def __init__(self, - schedule: Schedule = None, - nb_executions=5, - config=None): + def __init__(self, schedule: Schedule = None, nb_executions=5, config=None): self.depth = 0 self.schedule = [] self.schedule_object = schedule @@ -39,7 +45,8 @@ def __init__(self, self.schedule_list_model = [] self.model = Model_Recursive_LSTM_v2() self.model.load_state_dict( - torch.load(config.tiramisu.model_checkpoint, map_location="cpu")) + torch.load(config.tiramisu.model_checkpoint, map_location="cpu") + ) self.model.eval() def apply_action(self, action): @@ -57,41 +64,59 @@ def apply_action(self, action): else: comp = list(self.schedule_object.it_dict.keys())[0] action_params = action.parameter( - comp, self.schedule_object.prog, self.schedule) + comp, self.schedule_object.prog, self.schedule + ) if action.id in range(28): # Interchange if not self.schedule_object.is_interchaged: params = [ int(action_params["first_dim_index"]), - int(action_params["second_dim_index"]) + int(action_params["second_dim_index"]), ] - optim1 = OptimizationCommand("Interchange", params, - self.schedule_object.comps) + optim1 = OptimizationCommand( + "Interchange", params, 
self.schedule_object.comps + ) self.schedule.append(optim1) tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict']: - print( - "Loading legality check from saved schedule") + if ( + self.config.environment.use_dataset + and tmp_sched_str + in self.schedule_object.prog.function_dict[ + "schedules_legality_dict" + ] + ): + print("Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] + "schedules_legality_dict" + ][tmp_sched_str] if self.schedule_object.is_unrolled: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, self.non_skewed_comps, first_comp + ) + if saved_legality is None + else saved_legality + ) else: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, first_comp=first_comp + ) + if saved_legality is None + else saved_legality + ) # Save legality check if saved_legality is None: - self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] = lc_check + self.schedule_object.prog.function_dict["schedules_legality_dict"][ + tmp_sched_str + ] = lc_check if lc_check == -1: print("X: The action produced an error.") @@ -117,7 +142,7 @@ def apply_action(self, action): if not self.schedule_object.is_tiled: params = [ int(action_params["first_dim_index"]), - int(action_params["second_dim_index"]) + int(action_params["second_dim_index"]), 
] params.append(action_params["first_factor"]) params.append(action_params["second_factor"]) @@ -126,8 +151,9 @@ def apply_action(self, action): params.insert(2, action_params["third_dim_index"]) params.append(action_params["third_factor"]) - optim2 = OptimizationCommand("Tiling", params, - self.schedule_object.comps) + optim2 = OptimizationCommand( + "Tiling", params, self.schedule_object.comps + ) self.schedule.append(optim2) @@ -135,24 +161,40 @@ def apply_action(self, action): print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict']: - print( - "Loading legality check from saved schedule") + if ( + self.config.environment.use_dataset + and tmp_sched_str + in self.schedule_object.prog.function_dict[ + "schedules_legality_dict" + ] + ): + print("Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] + "schedules_legality_dict" + ][tmp_sched_str] if self.schedule_object.is_unrolled: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, self.non_skewed_comps, first_comp + ) + if saved_legality is None + else saved_legality + ) else: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, first_comp=first_comp + ) + if saved_legality is None + else saved_legality + ) # Save legality check if saved_legality is None: - self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] = lc_check + 
self.schedule_object.prog.function_dict["schedules_legality_dict"][ + tmp_sched_str + ] = lc_check if lc_check == -1: print("X: This action produces an error") @@ -173,8 +215,11 @@ def apply_action(self, action): done = True exit = True self.schedule_object.sched_str = ScheduleUtils.sched_str( - self.schedule_object.sched_str, action.id, - action_params, self.schedule_object.comp_indic_dict) + self.schedule_object.sched_str, + action.id, + action_params, + self.schedule_object.comp_indic_dict, + ) else: print("X: Tiling already applied exception") raise IsTiledException @@ -184,46 +229,63 @@ def apply_action(self, action): if not self.schedule_object.is_unrolled: self.non_skewed_comps = [] for comp in self.schedule_object.comps: - it_skewed = "L" + self.schedule_object.it_dict[comp][ - action_params[comp] - ["dim_index"]]["iterator"] + "Skewed" - if self.schedule_object.repr["representation"][ - self.schedule_object.comp_indic_dict[comp]][ - self.schedule_object.placeholders[comp] - [it_skewed]] != 1: + it_skewed = ( + "L" + + self.schedule_object.it_dict[comp][ + action_params[comp]["dim_index"] + ]["iterator"] + + "Skewed" + ) + if ( + self.schedule_object.repr["representation"][ + self.schedule_object.comp_indic_dict[comp] + ][self.schedule_object.placeholders[comp][it_skewed]] + != 1 + ): self.non_skewed_comps.append(comp) for comp in self.non_skewed_comps: params[comp] = [ int(action_params[comp]["dim_index"]), - int(action_params[comp]["unrolling_factor"]) + int(action_params[comp]["unrolling_factor"]), ] if self.non_skewed_comps != []: - optim3 = OptimizationCommand("Unrolling", params, - self.non_skewed_comps) + optim3 = OptimizationCommand( + "Unrolling", params, self.non_skewed_comps + ) self.schedule.append(optim3) - tmp_sched_str = ScheduleUtils.optimlist_to_str( - self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check - if 
self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict']: - print( - "Loading legality check from saved schedule") + if ( + self.config.environment.use_dataset + and tmp_sched_str + in self.schedule_object.prog.function_dict[ + "schedules_legality_dict" + ] + ): + print("Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] + "schedules_legality_dict" + ][tmp_sched_str] start_time = time.time() - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, self.non_skewed_comps, first_comp + ) + if saved_legality is None + else saved_legality + ) l_time = time.time() - start_time self.lc_total_time += l_time # Save legality check if saved_legality is None: self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] = lc_check + "schedules_legality_dict" + ][tmp_sched_str] = lc_check if lc_check == -1: print("X: This action produces an error") @@ -243,7 +305,8 @@ def apply_action(self, action): else: lc_check = 0 info[ - 'error'] = "trying to apply unrolling after skewing in one of the computations" + "error" + ] = "trying to apply unrolling after skewing in one of the computations" else: print("X: Unrolling is already applied") @@ -253,63 +316,82 @@ def apply_action(self, action): if not self.schedule_object.is_skewed: - if (action_params["first_factor"] != None - and action_params["second_factor"] != None): + if ( + action_params["first_factor"] != None + and action_params["second_factor"] != None + ): non_inner_comps = [] for comp in self.schedule_object.comps: - if (action_params["first_dim_index"] != - len(self.schedule_object.it_dict[comp]) - 1 - and action_params["second_dim_index"] != - 
len(self.schedule_object.it_dict[comp]) - 1 - ) or ( - (action_params["first_dim_index"] - == len(self.schedule_object.it_dict[comp]) - 1 - or action_params["second_dim_index"] - == len(self.schedule_object.it_dict[comp]) - 1 - and not self.schedule_object.is_unrolled)): + if ( + action_params["first_dim_index"] + != len(self.schedule_object.it_dict[comp]) - 1 + and action_params["second_dim_index"] + != len(self.schedule_object.it_dict[comp]) - 1 + ) or ( + ( + action_params["first_dim_index"] + == len(self.schedule_object.it_dict[comp]) - 1 + or action_params["second_dim_index"] + == len(self.schedule_object.it_dict[comp]) - 1 + and not self.schedule_object.is_unrolled + ) + ): non_inner_comps.append(comp) if non_inner_comps != []: params = [ int(action_params["first_dim_index"]), - int(action_params["second_dim_index"]) + int(action_params["second_dim_index"]), ] params.append(action_params["first_factor"]) params.append(action_params["second_factor"]) - optim4 = OptimizationCommand("Skewing", params, - non_inner_comps) + optim4 = OptimizationCommand("Skewing", params, non_inner_comps) self.schedule.append(optim4) - tmp_sched_str = ScheduleUtils.optimlist_to_str( - self.schedule) + tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict']: - print( - "Loading legality check from saved schedule") + if ( + self.config.environment.use_dataset + and tmp_sched_str + in self.schedule_object.prog.function_dict[ + "schedules_legality_dict" + ] + ): + print("Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] + "schedules_legality_dict" + ][tmp_sched_str] start_time = time.time() if self.schedule_object.is_unrolled: - lc_check = 
self.schedule_object.prog.check_legality_of_schedule( - self.schedule, self.non_skewed_comps, - first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, self.non_skewed_comps, first_comp + ) + if saved_legality is None + else saved_legality + ) else: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, first_comp=first_comp + ) + if saved_legality is None + else saved_legality + ) l_time = time.time() - start_time self.lc_total_time += l_time # Save legality check if saved_legality is None: self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] = lc_check + "schedules_legality_dict" + ][tmp_sched_str] = lc_check if lc_check == -1: print("X: This action produces an error") @@ -341,33 +423,50 @@ def apply_action(self, action): if not self.schedule_object.is_parallelized: params = [int(action_params["dim_index"])] - optim5 = OptimizationCommand("Parallelization", params, - self.schedule_object.comps) + optim5 = OptimizationCommand( + "Parallelization", params, self.schedule_object.comps + ) self.schedule.append(optim5) tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict']: - print( - "Loading legality check from saved schedule") + if ( + self.config.environment.use_dataset + and tmp_sched_str + in self.schedule_object.prog.function_dict[ + "schedules_legality_dict" + ] + ): + print("Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] + 
"schedules_legality_dict" + ][tmp_sched_str] start_time = time.time() if self.schedule_object.is_unrolled: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, self.non_skewed_comps, first_comp + ) + if saved_legality is None + else saved_legality + ) else: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, first_comp=first_comp + ) + if saved_legality is None + else saved_legality + ) # Save legality check if saved_legality is None: - self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] = lc_check + self.schedule_object.prog.function_dict["schedules_legality_dict"][ + tmp_sched_str + ] = lc_check l_time = time.time() - start_time self.lc_total_time += l_time @@ -394,35 +493,52 @@ def apply_action(self, action): if not self.schedule_object.is_reversed: params = [int(action_params["dim_index"])] - optim6 = OptimizationCommand("Reversal", params, - self.schedule_object.comps) + optim6 = OptimizationCommand( + "Reversal", params, self.schedule_object.comps + ) self.schedule.append(optim6) tmp_sched_str = ScheduleUtils.optimlist_to_str(self.schedule) print(tmp_sched_str) # check if we can find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict']: - print( - "Loading legality check from saved schedule") + if ( + self.config.environment.use_dataset + and tmp_sched_str + in self.schedule_object.prog.function_dict[ + "schedules_legality_dict" + ] + ): + print("Loading legality check from saved schedule") saved_legality = 
self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] + "schedules_legality_dict" + ][tmp_sched_str] start_time = time.time() if self.schedule_object.is_unrolled: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, self.non_skewed_comps, first_comp=first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, self.non_skewed_comps, first_comp=first_comp + ) + if saved_legality is None + else saved_legality + ) else: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, first_comp=first_comp + ) + if saved_legality is None + else saved_legality + ) l_time = time.time() - start_time self.lc_total_time += l_time # Save legality check if saved_legality is None: - self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] = lc_check + self.schedule_object.prog.function_dict["schedules_legality_dict"][ + tmp_sched_str + ] = lc_check if lc_check == -1: print("X: This action produces am error") @@ -444,14 +560,15 @@ def apply_action(self, action): raise IsReversedException if action.id in range(56, 61): # Fusion - params = [ - int(action_params["dim_index"]), action_params["fuse_comps"] - ] - if action_params["fuse_comps"] != [] and len( - action_params["fuse_comps"]) != 1: + params = [int(action_params["dim_index"]), action_params["fuse_comps"]] + if ( + action_params["fuse_comps"] != [] + and len(action_params["fuse_comps"]) != 1 + ): - optim7 = OptimizationCommand("Fusion", params, - action_params["fuse_comps"]) + optim7 = OptimizationCommand( + "Fusion", params, action_params["fuse_comps"] + ) self.schedule.append(optim7) @@ -459,29 +576,45 @@ def apply_action(self, action): print(tmp_sched_str) # check if we can 
find the schedule in the dataset load the legality check - if self.config.environment.use_dataset and tmp_sched_str in self.schedule_object.prog.function_dict[ - 'schedules_legality_dict']: - print( - "Loading legality check from saved schedule") + if ( + self.config.environment.use_dataset + and tmp_sched_str + in self.schedule_object.prog.function_dict[ + "schedules_legality_dict" + ] + ): + print("Loading legality check from saved schedule") saved_legality = self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] + "schedules_legality_dict" + ][tmp_sched_str] start_time = time.time() if self.schedule_object.is_unrolled: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, self.non_skewed_comps, first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, self.non_skewed_comps, first_comp + ) + if saved_legality is None + else saved_legality + ) else: - lc_check = self.schedule_object.prog.check_legality_of_schedule( - self.schedule, first_comp=first_comp) if saved_legality is None else saved_legality + lc_check = ( + self.schedule_object.prog.check_legality_of_schedule( + self.schedule, first_comp=first_comp + ) + if saved_legality is None + else saved_legality + ) l_time = time.time() - start_time self.lc_total_time += l_time # Save legality check if saved_legality is None: - self.schedule_object.prog.function_dict[ - 'schedules_legality_dict'][tmp_sched_str] = lc_check + self.schedule_object.prog.function_dict["schedules_legality_dict"][ + tmp_sched_str + ] = lc_check if lc_check == -1: print("X: This action produces an error") @@ -506,7 +639,7 @@ def apply_action(self, action): done = True exit = True - if (not exit and lc_check != 0): + if not exit and lc_check != 0: # Changed the sched_str to be updated after all successfull application of actions self.schedule_object.sched_str = tmp_sched_str if not (action.id in 
range(41, 44) and self.schedule_object.is_skewed): @@ -515,9 +648,12 @@ def apply_action(self, action): # self.schedule_object.comp_indic_dict) if not action.id in range(41, 44): self.schedule_object.it_dict = ScheduleUtils.update_iterators( - action.id, self.schedule_object.it_dict, action_params, + action.id, + self.schedule_object.it_dict, + action_params, self.schedule_object.added_iterators, - self.schedule_object.comp_indic_dict) + self.schedule_object.comp_indic_dict, + ) self.depth += 1 return self.schedule_object.repr, 1.0, done, info @@ -537,15 +673,15 @@ def get_final_score(self): exec_time = self.get_exec_time() speedup = 1.0 if exec_time != 0: - speedup = (self.schedule_object.prog.initial_execution_time / - exec_time) + speedup = self.schedule_object.prog.initial_execution_time / exec_time return speedup def test_additional_actions(self, training=True): info = dict() if training: print( - "This operation alters the training and, therefore, it won't be executed") + "This operation alters the training and, therefore, it won't be executed" + ) try: exec_time = 0 exec_time = self.get_exec_time() @@ -563,30 +699,43 @@ def test_additional_actions(self, training=True): unroll_factor = unroll_optimisation.params_list[comp][1] new_unrolling_params[comp] = { "dim_index": len(self.schedule_object.it_dict[comp]) - 1, - "unrolling_factor": unroll_factor + "unrolling_factor": unroll_factor, } new_unrolling_optim_params[comp] = [ - len(self.schedule_object.it_dict[comp] - ) - 1, unroll_factor + len(self.schedule_object.it_dict[comp]) - 1, + unroll_factor, ] new_unrolling_optim = OptimizationCommand( - "Unrolling", new_unrolling_optim_params, self.non_skewed_comps) + "Unrolling", new_unrolling_optim_params, self.non_skewed_comps + ) new_unrolling_str = "" unrolling_str = "" for comp in self.non_skewed_comps: unroll_factor = unroll_optimisation.params_list[comp][1] - new_unrolling_str += "U(L" + str( - len(self.schedule_object.it_dict[comp]) - - 1) + "," + 
str(unroll_factor) + ",C" + str( - self.schedule_object.comp_indic_dict[comp]) + ")" - unrolling_str += "U(L" + str( - unroll_optimisation.params_list[comp][0]) + "," + str( - unroll_factor) + ",C" + str( - self.schedule_object.comp_indic_dict[comp]) + ")" - self.schedule_object.sched_str = self.schedule_object.sched_str.replace( - unrolling_str, "") + new_unrolling_str + new_unrolling_str += ( + "U(L" + + str(len(self.schedule_object.it_dict[comp]) - 1) + + "," + + str(unroll_factor) + + ",C" + + str(self.schedule_object.comp_indic_dict[comp]) + + ")" + ) + unrolling_str += ( + "U(L" + + str(unroll_optimisation.params_list[comp][0]) + + "," + + str(unroll_factor) + + ",C" + + str(self.schedule_object.comp_indic_dict[comp]) + + ")" + ) + self.schedule_object.sched_str = ( + self.schedule_object.sched_str.replace(unrolling_str, "") + + new_unrolling_str + ) self.schedule.remove(unroll_optimisation) self.schedule.append(new_unrolling_optim) self.schedule_object.apply_unrolling(new_unrolling_params) @@ -599,82 +748,97 @@ def test_additional_actions(self, training=True): if not self.schedule_object.is_parallelized: print("Testing if parallelization improves the performance...") - action = Action(Action.PARALLELIZATION0, - self.schedule_object.it_dict, - self.schedule_object.common_it) + action = Action( + Action.PARALLELIZATION0, + self.schedule_object.it_dict, + self.schedule_object.common_it, + ) action_params = action.parameter() params = [int(action_params["dim_index"])] - optim5 = OptimizationCommand("Parallelization", params, - self.schedule_object.comps) + optim5 = OptimizationCommand( + "Parallelization", params, self.schedule_object.comps + ) first_comp = list(self.schedule_object.it_dict.keys())[0] iterator = self.schedule_object.it_dict[first_comp][ - action_params["dim_index"]]['iterator'] + action_params["dim_index"] + ]["iterator"] self.schedule_object.schedule_dict[first_comp][ - "parallelized_dim"] = iterator + "parallelized_dim" + ] = iterator 
self.schedule.append(optim5) try: self.schedule_object.sched_str = ScheduleUtils.sched_str( - self.schedule_object.sched_str, action.id, - action_params, self.schedule_object.comp_indic_dict) + self.schedule_object.sched_str, + action.id, + action_params, + self.schedule_object.comp_indic_dict, + ) parallelized_exec_time = self.get_exec_time() - parallelization_str = 'P(L' + str( - action_params["dim_index"]) + ')' + parallelization_str = ( + "P(L" + str(action_params["dim_index"]) + ")" + ) except: print("X: Illegal action") self.schedule.remove(optim5) - self.schedule_object.sched_str = self.schedule_object.sched_str.replace( - parallelization_str, "") - - if parallelized_exec_time < exec_time and parallelized_exec_time != 0: + self.schedule_object.sched_str = ( + self.schedule_object.sched_str.replace( + parallelization_str, "" + ) + ) + + if ( + parallelized_exec_time < exec_time + and parallelized_exec_time != 0 + ): exec_time = parallelized_exec_time - self.schedule_object.apply_parallelization( - action_params) + self.schedule_object.apply_parallelization(action_params) print("O: Parallelization improves the performance.") else: self.schedule.remove(optim5) - self.schedule_object.sched_str = self.schedule_object.sched_str.replace( - parallelization_str, "") + self.schedule_object.sched_str = ( + self.schedule_object.sched_str.replace( + parallelization_str, "" + ) + ) self.schedule_object.schedule_dict[first_comp][ - "parallelized_dim"] = None + "parallelized_dim" + ] = None print("X: Parallelization improves the performance") except: print("X: Error while measuring performance") - print(f"failed to save schedule", - traceback.format_exc(), - flush=True) + print(f"failed to save schedule", traceback.format_exc(), flush=True) info = {"Internal execution error": True} return self.schedule_object.repr, self.speedup, True, info if exec_time != 0: - print("\nThe final schedule is ", - self.schedule_object.sched_str) - self.speedup = ( - 
self.schedule_object.prog.initial_execution_time / exec_time) + print("\nThe final schedule is ", self.schedule_object.sched_str) + self.speedup = self.schedule_object.prog.initial_execution_time / exec_time print("The speedup is: ", self.speedup) start_time = time.time() info["depth"] = self.depth return self.schedule_object.repr, self.speedup, True, info - def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, - initial_exec_time): - self.schedule_list_model.append({ - "sched_str": - self.schedule_object.sched_str, - "schedule_dict": - self.schedule_object.schedule_dict - }) + def get_exec_time_by_model( + self, optims_list, cmd_type, nb_executions, initial_exec_time + ): + self.schedule_list_model.append( + { + "sched_str": self.schedule_object.sched_str, + "schedule_dict": self.schedule_object.schedule_dict, + } + ) stat = dict() try: computations_tensor, loops_tensor = get_schedule_representation( @@ -682,24 +846,28 @@ def get_exec_time_by_model(self, optims_list, cmd_type, nb_executions, self.schedule_object.schedule_dict, self.schedule_object.templates["comps_repr_templates_list"], self.schedule_object.templates["loops_repr_templates_list"], - self.schedule_object. - templates["comps_placeholders_indices_dict"], - self.schedule_object. 
- templates["loops_placeholders_indices_dict"], - max_depth=self.schedule_object.MAX_DEPTH - 1) - tree_tensors = (self.schedule_object.templates["prog_tree"], - computations_tensor, loops_tensor) - + self.schedule_object.templates["comps_placeholders_indices_dict"], + self.schedule_object.templates["loops_placeholders_indices_dict"], + max_depth=self.schedule_object.MAX_DEPTH - 1, + ) + tree_tensors = ( + self.schedule_object.templates["prog_tree"], + computations_tensor, + loops_tensor, + ) + with torch.no_grad(): predicted_speedup = self.model( - tree_tensors, - num_matrices=self.schedule_object.MAX_DEPTH - 1).item() + tree_tensors, num_matrices=self.schedule_object.MAX_DEPTH - 1 + ).item() stat[ - "initial_execution_time"] = self.schedule_object.prog.initial_execution_time + "initial_execution_time" + ] = self.schedule_object.prog.initial_execution_time stat["predicted_speedup"] = predicted_speedup print(f"The predicted speedup is {predicted_speedup}") - stat[ - "predicted_execution_time"] = self.schedule_object.prog.initial_execution_time / predicted_speedup + stat["predicted_execution_time"] = ( + self.schedule_object.prog.initial_execution_time / predicted_speedup + ) except Exception: print("ERROR_MODEL", traceback.format_exc()) print(sys.exc_info()[2]) @@ -711,8 +879,11 @@ def get_exec_time(self): if self.schedule_object.sched_str != "" and self.schedule != []: execution_time = self.measurement_env( - self.schedule, 'sched_eval', self.nb_executions, - self.schedule_object.prog.initial_execution_time) + self.schedule, + "sched_eval", + self.nb_executions, + self.schedule_object.prog.initial_execution_time, + ) else: execution_time = self.schedule_object.prog.initial_execution_time return execution_time diff --git a/tiramisu_programs/schedule_utils.py b/tiramisu_programs/schedule_utils.py index 7805af3..8aa345e 100644 --- a/tiramisu_programs/schedule_utils.py +++ b/tiramisu_programs/schedule_utils.py @@ -34,7 +34,6 @@ class 
RepresentationLengthException(Exception): class NumpyEncoder(json.JSONEncoder): - def default(self, obj): if isinstance(obj, np.ndarray): return obj.flatten().tolist() @@ -78,7 +77,6 @@ class SkewUnrollException(Exception): class ScheduleUtils: - @classmethod def linear_diophantine_default(cls, f_i, f_j): found = False @@ -112,20 +110,20 @@ def pad_access_matrix(cls, access_matrix, max_depth): access_matrix = np.c_[np.ones(access_matrix.shape[0]), access_matrix] access_matrix = np.r_[[np.ones(access_matrix.shape[1])], access_matrix] padded_access_matrix = np.zeros((max_depth + 1, max_depth + 2)) - padded_access_matrix[:access_matrix.shape[0], :access_matrix.shape[1] - - 1] = access_matrix[:, :-1] - padded_access_matrix[:access_matrix.shape[0], -1] = access_matrix[:, - -1] + padded_access_matrix[ + : access_matrix.shape[0], : access_matrix.shape[1] - 1 + ] = access_matrix[:, :-1] + padded_access_matrix[: access_matrix.shape[0], -1] = access_matrix[:, -1] return padded_access_matrix @classmethod def isl_to_write_matrix(cls, isl_map): - comp_iterators_str = re.findall(r'\[(.*)\]\s*->', isl_map)[0] - buffer_iterators_str = re.findall(r'->\s*\w*\[(.*)\]', isl_map)[0] + comp_iterators_str = re.findall(r"\[(.*)\]\s*->", isl_map)[0] + buffer_iterators_str = re.findall(r"->\s*\w*\[(.*)\]", isl_map)[0] buffer_iterators_str = re.sub(r"\w+'\s=", "", buffer_iterators_str) - comp_iter_names = re.findall(r'(?:\s*(\w+))+', comp_iterators_str) - buf_iter_names = re.findall(r'(?:\s*(\w+))+', buffer_iterators_str) + comp_iter_names = re.findall(r"(?:\s*(\w+))+", comp_iterators_str) + buf_iter_names = re.findall(r"(?:\s*(\w+))+", buffer_iterators_str) matrix = np.zeros([len(buf_iter_names), len(comp_iter_names) + 1]) for i, buf_iter in enumerate(buf_iter_names): for j, comp_iter in enumerate(comp_iter_names): @@ -139,11 +137,18 @@ def sched_json_to_sched_str(cls, sched_json, program_json): comp_name = [ n for n in sched_json.keys() - if not n in ["unfuse_iterators", "tree_structure", 
"execution_times", "fusions", "sched_str"] + if not n + in [ + "unfuse_iterators", + "tree_structure", + "execution_times", + "fusions", + "sched_str", + ] ] sched_str = "" - if ("fusions" in sched_json and sched_json["fusions"]): + if "fusions" in sched_json and sched_json["fusions"]: for fusion in sched_json["fusions"]: sched_str += "F(" for name in comp_name: @@ -156,24 +161,37 @@ def sched_json_to_sched_str(cls, sched_json, program_json): for name in comp_name: transf_loop_nest = cls.get_original_iterators(program_json) schedule = sched_json[name] - sched_str += '{' + name + '}:' + sched_str += "{" + name + "}:" for transformation in schedule["transformations_list"]: - if (transformation[0] == 1): - sched_str += "I(L" + str(transformation[1]) + \ - ",L" + str(transformation[2]) + ")" + if transformation[0] == 1: + sched_str += ( + "I(L" + + str(transformation[1]) + + ",L" + + str(transformation[2]) + + ")" + ) - elif (transformation[0] == 2): + elif transformation[0] == 2: sched_str += f"R(L{str(transformation[3])})" - elif (transformation[0] == 3): - sched_str += "S(L" + str(transformation[4]) + ",L" + str( - transformation[5]) + "," + str(transformation[6]) + "," + str(transformation[7]) + ")" + elif transformation[0] == 3: + sched_str += ( + "S(L" + + str(transformation[4]) + + ",L" + + str(transformation[5]) + + "," + + str(transformation[6]) + + "," + + str(transformation[7]) + + ")" + ) if schedule["parallelized_dim"]: - dim_index = transf_loop_nest.index( - schedule["parallelized_dim"]) + dim_index = transf_loop_nest.index(schedule["parallelized_dim"]) sched_str += "P(L" + str(dim_index) + ")" if schedule["tiling"]: @@ -196,11 +214,15 @@ def sched_json_to_sched_str(cls, sched_json, program_json): + ")" ) i = transf_loop_nest.index(first_dim) - transf_loop_nest[i: i + 1] = first_dim + \ - "_outer", second_dim + "_outer" + transf_loop_nest[i : i + 1] = ( + first_dim + "_outer", + second_dim + "_outer", + ) i = transf_loop_nest.index(second_dim) - 
transf_loop_nest[i: i + 1] = first_dim + \ - "_inner", second_dim + "_inner" + transf_loop_nest[i : i + 1] = ( + first_dim + "_inner", + second_dim + "_inner", + ) else: first_dim = schedule["tiling"]["tiling_dims"][0] second_dim = schedule["tiling"]["tiling_dims"][1] @@ -227,13 +249,13 @@ def sched_json_to_sched_str(cls, sched_json, program_json): + ")" ) i = transf_loop_nest.index(first_dim) - transf_loop_nest[i: i + 1] = ( + transf_loop_nest[i : i + 1] = ( first_dim + "_outer", second_dim + "_outer", third_dim + "_outer", ) i = transf_loop_nest.index(second_dim) - transf_loop_nest[i: i + 1] = ( + transf_loop_nest[i : i + 1] = ( first_dim + "_inner", second_dim + "_inner", third_dim + "_inner", @@ -243,9 +265,10 @@ def sched_json_to_sched_str(cls, sched_json, program_json): if schedule["unrolling_factor"]: dim_index = len(transf_loop_nest) - 1 dim_name = transf_loop_nest[-1] - sched_str += "U(L" + str(dim_index) + "," + \ - schedule["unrolling_factor"] + ")" - transf_loop_nest[dim_index: dim_index + 1] = ( + sched_str += ( + "U(L" + str(dim_index) + "," + schedule["unrolling_factor"] + ")" + ) + transf_loop_nest[dim_index : dim_index + 1] = ( dim_name + "_Uouter", dim_name + "_Uinner", ) @@ -253,11 +276,11 @@ def sched_json_to_sched_str(cls, sched_json, program_json): @classmethod def get_original_iterators(cls, program_json): - iterators = program_json['iterators'] + iterators = program_json["iterators"] to_explore = [] result = [] to_explore.append(list(iterators.keys())[0]) - while (to_explore): + while to_explore: it_name = to_explore.pop(0) iterator = iterators[it_name] result.append(it_name) @@ -278,8 +301,8 @@ def get_schedules_str(cls, programs_list, programs_dict): functions_set = {} for fun in programs_list: - if 'schedules_list' in programs_dict[fun].keys(): - schedules = programs_dict[fun]['schedules_list'] + if "schedules_list" in programs_dict[fun].keys(): + schedules = programs_dict[fun]["schedules_list"] schedules_set = {} @@ -300,10 +323,11 @@ 
def get_representation(cls, program_annot): max_accesses = 21 program_representation = [] indices_dict = dict() - computations_dict = program_annot['computations'] + computations_dict = program_annot["computations"] ordered_comp_list = sorted( list(computations_dict.keys()), - key=lambda x: computations_dict[x]['absolute_order']) + key=lambda x: computations_dict[x]["absolute_order"], + ) placeholders_comp = {} @@ -312,61 +336,77 @@ def get_representation(cls, program_annot): comp_representation = [] iterators_repr = [] - for iter_i, iterator_name in enumerate(comp_dict['iterators']): - iterator_dict = program_annot['iterators'][iterator_name] - iterators_repr.extend([ - iterator_dict['lower_bound'], iterator_dict['upper_bound'] - ]) - - l_code = 'L' + iterator_name - iterators_repr.extend([ - l_code + 'Interchanged', l_code + 'Skewed', - l_code + 'SkewFactor', l_code + 'Parallelized', - l_code + 'Tiled', l_code + 'TileFactor', - l_code + 'Reversed', l_code + 'Fused', 0, 0, - l_code + "_1" + 'Interchanged', l_code + "_1" + 'Skewed', - l_code + "_1" + 'SkewFactor', - l_code + "_1" + 'Parallelized', l_code + "_1" + 'Tiled', - l_code + "_1" + 'TileFactor', l_code + "_1" + 'Reversed', - l_code + "_1" + 'Fused' - ]) + for iter_i, iterator_name in enumerate(comp_dict["iterators"]): + iterator_dict = program_annot["iterators"][iterator_name] + iterators_repr.extend( + [iterator_dict["lower_bound"], iterator_dict["upper_bound"]] + ) + + l_code = "L" + iterator_name + iterators_repr.extend( + [ + l_code + "Interchanged", + l_code + "Skewed", + l_code + "SkewFactor", + l_code + "Parallelized", + l_code + "Tiled", + l_code + "TileFactor", + l_code + "Reversed", + l_code + "Fused", + 0, + 0, + l_code + "_1" + "Interchanged", + l_code + "_1" + "Skewed", + l_code + "_1" + "SkewFactor", + l_code + "_1" + "Parallelized", + l_code + "_1" + "Tiled", + l_code + "_1" + "TileFactor", + l_code + "_1" + "Reversed", + l_code + "_1" + "Fused", + ] + ) iterator_repr_size = int( - 
len(iterators_repr) / (2 * len(comp_dict['iterators']))) - iterators_repr.extend([0] * iterator_repr_size * 2 * - (max_depth - len(comp_dict['iterators']))) + len(iterators_repr) / (2 * len(comp_dict["iterators"])) + ) + iterators_repr.extend( + [0] * iterator_repr_size * 2 * (max_depth - len(comp_dict["iterators"])) + ) - iterators_repr.extend(['Unrolled', 'UnrollFactor']) + iterators_repr.extend(["Unrolled", "UnrollFactor"]) comp_representation.extend(iterators_repr) padded_write_matrix = cls.pad_access_matrix( - cls.isl_to_write_matrix(comp_dict['write_access_relation']), - max_depth) - write_access_repr = [comp_dict['write_buffer_id'] + 1 - ] + padded_write_matrix.flatten().tolist() + cls.isl_to_write_matrix(comp_dict["write_access_relation"]), max_depth + ) + write_access_repr = [ + comp_dict["write_buffer_id"] + 1 + ] + padded_write_matrix.flatten().tolist() comp_representation.extend(write_access_repr) read_accesses_repr = [] - for read_access_dict in comp_dict['accesses']: + for read_access_dict in comp_dict["accesses"]: read_access_matrix = cls.pad_access_matrix( - read_access_dict['access_matrix'], max_depth) - read_access_repr = [read_access_dict['buffer_id'] + 1 - ] + read_access_matrix.flatten().tolist() + read_access_dict["access_matrix"], max_depth + ) + read_access_repr = [ + read_access_dict["buffer_id"] + 1 + ] + read_access_matrix.flatten().tolist() read_accesses_repr.extend(read_access_repr) access_repr_len = (max_depth + 1) * (max_depth + 2) + 1 read_accesses_repr.extend( - [0] * access_repr_len * - (max_accesses - len(comp_dict['accesses']))) + [0] * access_repr_len * (max_accesses - len(comp_dict["accesses"])) + ) comp_representation.extend(read_accesses_repr) - comp_representation.append(comp_dict['number_of_additions']) - comp_representation.append(comp_dict['number_of_subtraction']) - comp_representation.append(comp_dict['number_of_multiplication']) - comp_representation.append(comp_dict['number_of_division']) + 
comp_representation.append(comp_dict["number_of_additions"]) + comp_representation.append(comp_dict["number_of_subtraction"]) + comp_representation.append(comp_dict["number_of_multiplication"]) + comp_representation.append(comp_dict["number_of_division"]) placeholders_indices_dict = {} for i, element in enumerate(comp_representation): @@ -387,70 +427,88 @@ def get_representation_template(cls, program_annot): min_accesses = 1 max_depth = 5 - comp_name = list(program_annot['computations'].keys())[0] - comp_dict = program_annot['computations'][comp_name] + comp_name = list(program_annot["computations"].keys())[0] + comp_dict = program_annot["computations"][comp_name] - if len(comp_dict['accesses']) > max_accesses: + if len(comp_dict["accesses"]) > max_accesses: raise NbAccessException - if len(comp_dict['accesses']) < min_accesses: + if len(comp_dict["accesses"]) < min_accesses: raise NbAccessException - if len(comp_dict['iterators']) > max_depth: + if len(comp_dict["iterators"]) > max_depth: raise LoopsDepthException comp_repr_template = [] iterators_repr = [] - for iter_i, iterator_name in enumerate(comp_dict['iterators']): - iterator_dict = program_annot['iterators'][iterator_name] + for iter_i, iterator_name in enumerate(comp_dict["iterators"]): + iterator_dict = program_annot["iterators"][iterator_name] + iterators_repr.extend( + [iterator_dict["lower_bound"], iterator_dict["upper_bound"]] + ) + + l_code = "L" + iterator_name iterators_repr.extend( - [iterator_dict['lower_bound'], iterator_dict['upper_bound']]) - - l_code = 'L' + iterator_name - iterators_repr.extend([ - l_code + 'Interchanged', l_code + 'Skewed', - l_code + 'SkewFactor', l_code + 'Parallelized', - l_code + 'Tiled', l_code + 'TileFactor', l_code + 'Reversed', - 0, 0, l_code + "_1" + 'Interchanged', l_code + "_1" + 'Skewed', - l_code + "_1" + 'SkewFactor', l_code + "_1" + 'Parallelized', - l_code + "_1" + 'Tiled', l_code + 'TileFactor', - l_code + "_1" + 'Reversed' - ]) + [ + l_code + 
"Interchanged", + l_code + "Skewed", + l_code + "SkewFactor", + l_code + "Parallelized", + l_code + "Tiled", + l_code + "TileFactor", + l_code + "Reversed", + 0, + 0, + l_code + "_1" + "Interchanged", + l_code + "_1" + "Skewed", + l_code + "_1" + "SkewFactor", + l_code + "_1" + "Parallelized", + l_code + "_1" + "Tiled", + l_code + "TileFactor", + l_code + "_1" + "Reversed", + ] + ) iterator_repr_size = int( - len(iterators_repr) / (2 * len(comp_dict['iterators']))) - iterators_repr.extend([0] * iterator_repr_size * 2 * - (max_depth - len(comp_dict['iterators']))) + len(iterators_repr) / (2 * len(comp_dict["iterators"])) + ) + iterators_repr.extend( + [0] * iterator_repr_size * 2 * (max_depth - len(comp_dict["iterators"])) + ) - iterators_repr.extend(['Unrolled', 'UnrollFactor']) + iterators_repr.extend(["Unrolled", "UnrollFactor"]) comp_repr_template.extend(iterators_repr) padded_write_matrix = cls.pad_access_matrix( - cls.isl_to_write_matrix(comp_dict['write_access_relation']), - max_depth) - write_access_repr = [comp_dict['write_buffer_id'] + 1 - ] + padded_write_matrix.flatten().tolist() + cls.isl_to_write_matrix(comp_dict["write_access_relation"]), max_depth + ) + write_access_repr = [ + comp_dict["write_buffer_id"] + 1 + ] + padded_write_matrix.flatten().tolist() comp_repr_template.extend(write_access_repr) read_accesses_repr = [] - for read_access_dict in comp_dict['accesses']: + for read_access_dict in comp_dict["accesses"]: read_access_matrix = cls.pad_access_matrix( - read_access_dict['access_matrix'], max_depth) - read_access_repr = [read_access_dict['buffer_id'] + 1 - ] + read_access_matrix.flatten().tolist() + read_access_dict["access_matrix"], max_depth + ) + read_access_repr = [ + read_access_dict["buffer_id"] + 1 + ] + read_access_matrix.flatten().tolist() read_accesses_repr.extend(read_access_repr) access_repr_len = (max_depth + 1) * (max_depth + 2) + 1 - read_accesses_repr.extend([0] * access_repr_len * - (max_accesses - 
len(comp_dict['accesses']))) + read_accesses_repr.extend( + [0] * access_repr_len * (max_accesses - len(comp_dict["accesses"])) + ) comp_repr_template.extend(read_accesses_repr) - comp_repr_template.append(comp_dict['number_of_additions']) - comp_repr_template.append(comp_dict['number_of_subtraction']) - comp_repr_template.append(comp_dict['number_of_multiplication']) - comp_repr_template.append(comp_dict['number_of_division']) + comp_repr_template.append(comp_dict["number_of_additions"]) + comp_repr_template.append(comp_dict["number_of_subtraction"]) + comp_repr_template.append(comp_dict["number_of_multiplication"]) + comp_repr_template.append(comp_dict["number_of_division"]) placeholders_indices_dict = {} for i, element in enumerate(comp_repr_template): @@ -463,76 +521,107 @@ def get_representation_template(cls, program_annot): @classmethod def sched_str(cls, sched_str, id, params, comp_indic): if id in range(28): - sched_str += 'I(L' + str(params["first_dim_index"]) + ',L' + str( - params['second_dim_index']) + ')' + sched_str += ( + "I(L" + + str(params["first_dim_index"]) + + ",L" + + str(params["second_dim_index"]) + + ")" + ) else: if id in range(28, 41): if params["tiling_depth"] == 2: - sched_str += 'T2(L' + str( - params["first_dim_index"]) + ',L' + str( - params['second_dim_index']) + ',' + str( - params["first_factor"]) + ',' + str( - params["second_factor"]) + ')' + sched_str += ( + "T2(L" + + str(params["first_dim_index"]) + + ",L" + + str(params["second_dim_index"]) + + "," + + str(params["first_factor"]) + + "," + + str(params["second_factor"]) + + ")" + ) else: - sched_str += 'T3(L' + str( - params["first_dim_index"]) + ',L' + str( - params['second_dim_index']) + ',L' + str( - params["third_dim_index"]) + ',' + str( - params["first_factor"]) + ',' + str( - params["second_factor"]) + ',' + str( - params["third_factor"]) + ')' + sched_str += ( + "T3(L" + + str(params["first_dim_index"]) + + ",L" + + str(params["second_dim_index"]) + + ",L" + + 
str(params["third_dim_index"]) + + "," + + str(params["first_factor"]) + + "," + + str(params["second_factor"]) + + "," + + str(params["third_factor"]) + + ")" + ) else: if id in range(41, 44): for comp in params: - sched_str += 'U(L' + str( - params[comp]["dim_index"]) + ',' + str( - params[comp]['unrolling_factor']) + ",C" + str( - comp_indic[comp]) + ')' + sched_str += ( + "U(L" + + str(params[comp]["dim_index"]) + + "," + + str(params[comp]["unrolling_factor"]) + + ",C" + + str(comp_indic[comp]) + + ")" + ) else: if id in range(44, 46): - sched_str += 'S(L' + str( - params["first_dim_index"]) + ',L' + str( - params['second_dim_index']) + ',' + str( - params["first_factor"]) + ',' + str( - params["second_factor"]) + ')' + sched_str += ( + "S(L" + + str(params["first_dim_index"]) + + ",L" + + str(params["second_dim_index"]) + + "," + + str(params["first_factor"]) + + "," + + str(params["second_factor"]) + + ")" + ) else: if id in range(46, 48): - sched_str += 'P(L' + str(params["dim_index"]) + ')' + sched_str += "P(L" + str(params["dim_index"]) + ")" else: if id in range(48, 56): - sched_str += 'R(L' + str( - params["dim_index"]) + ')' + sched_str += "R(L" + str(params["dim_index"]) + ")" else: if id in range(56, 61): - sched_str += 'F(L' + str( - params["dim_index"]) + ')' + sched_str += "F(L" + str(params["dim_index"]) + ")" return sched_str @classmethod def get_orig_tree_struct(cls, program_json, root_iterator): tree_struct = { - 'loop_name': - root_iterator, - 'computations_list': - program_json['iterators'][root_iterator]['computations_list'][:], - 'child_list': [] + "loop_name": root_iterator, + "computations_list": program_json["iterators"][root_iterator][ + "computations_list" + ][:], + "child_list": [], } - for child_iterator in program_json['iterators'][root_iterator][ - 'child_iterators']: - tree_struct['child_list'].append( - cls.get_orig_tree_struct(program_json, child_iterator)) + for child_iterator in program_json["iterators"][root_iterator][ + 
"child_iterators" + ]: + tree_struct["child_list"].append( + cls.get_orig_tree_struct(program_json, child_iterator) + ) return tree_struct @classmethod - def update_iterators(cls, id, it_list, action_params, added_iterators, - comp_indic_dict): + def update_iterators( + cls, id, it_list, action_params, added_iterators, comp_indic_dict + ): for comp in it_list: if id in range(28): tmp = it_list[comp][action_params["first_dim_index"]] - it_list[comp][ - action_params["first_dim_index"]] = it_list[comp].pop( - action_params["second_dim_index"]) + it_list[comp][action_params["first_dim_index"]] = it_list[comp].pop( + action_params["second_dim_index"] + ) it_list[comp][action_params["second_dim_index"]] = tmp if id in range(28, 41): @@ -545,11 +634,15 @@ def update_iterators(cls, id, it_list, action_params, added_iterators, if action_params["tiling_depth"] == 2: while i > depth_2: - if action_params["tiling_loop_1"] and action_params[ - "tiling_loop_2"]: + if ( + action_params["tiling_loop_1"] + and action_params["tiling_loop_2"] + ): it_list[comp][i + 2] = it_list[comp][i] - elif action_params["tiling_loop_1"] or action_params[ - "tiling_loop_2"]: + elif ( + action_params["tiling_loop_1"] + or action_params["tiling_loop_2"] + ): it_list[comp][i + 1] = it_list[comp][i] i -= 1 @@ -558,15 +651,17 @@ def update_iterators(cls, id, it_list, action_params, added_iterators, depth_3 = action_params["third_dim_index"] while i > depth_3: - if action_params["tiling_loop_1"] and action_params[ - "tiling_loop_2"] and action_params[ - "tiling_loop_3"]: + if ( + action_params["tiling_loop_1"] + and action_params["tiling_loop_2"] + and action_params["tiling_loop_3"] + ): it_list[comp][i + 3] = it_list[comp][i] else: booleans = [ action_params["tiling_loop_1"], action_params["tiling_loop_2"], - action_params["tiling_loop_3"] + action_params["tiling_loop_3"], ] if booleans.count(True) == 2: it_list[comp][i + 2] = it_list[comp][i] @@ -575,283 +670,332 @@ def update_iterators(cls, id, 
it_list, action_params, added_iterators, i -= 1 if action_params["tiling_depth"] == 2: - if action_params["tiling_loop_1"] and action_params[ - "tiling_loop_2"]: - - it_list[comp][depth_1][ - 'upper_bound'] = it_list[comp][depth_1][ - 'upper_bound'] / action_params["first_factor"] + if ( + action_params["tiling_loop_1"] + and action_params["tiling_loop_2"] + ): + + it_list[comp][depth_1]["upper_bound"] = ( + it_list[comp][depth_1]["upper_bound"] + / action_params["first_factor"] + ) it_list[comp][depth_1 + 2] = {} - it_list[comp][depth_1 + 2]['iterator'] = "{}_1".format( - it_list[comp][depth_1]['iterator']) - it_list[comp][depth_1 + 2]['lower_bound'] = it_list[ - comp][depth_1]['lower_bound'] - it_list[comp][ - depth_1 + - 2]['upper_bound'] = action_params["first_factor"] - - added_iterators.append(it_list[comp][depth_1 + - 2]['iterator']) - - it_list[comp][depth_2][ - 'upper_bound'] = it_list[comp][depth_2][ - 'upper_bound'] / action_params["second_factor"] + it_list[comp][depth_1 + 2]["iterator"] = "{}_1".format( + it_list[comp][depth_1]["iterator"] + ) + it_list[comp][depth_1 + 2]["lower_bound"] = it_list[comp][ + depth_1 + ]["lower_bound"] + it_list[comp][depth_1 + 2]["upper_bound"] = action_params[ + "first_factor" + ] + + added_iterators.append(it_list[comp][depth_1 + 2]["iterator"]) + + it_list[comp][depth_2]["upper_bound"] = ( + it_list[comp][depth_2]["upper_bound"] + / action_params["second_factor"] + ) it_list[comp][depth_2 + 2] = {} - it_list[comp][depth_2 + 2]['iterator'] = "{}_1".format( - it_list[comp][depth_2]['iterator']) - it_list[comp][depth_2 + 2]['lower_bound'] = it_list[ - comp][depth_2]['lower_bound'] - it_list[comp][ - depth_2 + - 2]['upper_bound'] = action_params["second_factor"] + it_list[comp][depth_2 + 2]["iterator"] = "{}_1".format( + it_list[comp][depth_2]["iterator"] + ) + it_list[comp][depth_2 + 2]["lower_bound"] = it_list[comp][ + depth_2 + ]["lower_bound"] + it_list[comp][depth_2 + 2]["upper_bound"] = action_params[ + 
"second_factor" + ] - added_iterators.append(it_list[comp][depth_2 + - 2]['iterator']) + added_iterators.append(it_list[comp][depth_2 + 2]["iterator"]) else: if action_params["tiling_loop_1"]: - it_list[comp][depth_1]['upper_bound'] = it_list[ - comp][depth_1]['upper_bound'] / action_params[ - "first_factor"] + it_list[comp][depth_1]["upper_bound"] = ( + it_list[comp][depth_1]["upper_bound"] + / action_params["first_factor"] + ) it_list[comp][depth_1 + 2] = {} - it_list[comp][ - depth_1 + 2]['iterator'] = "{}_1".format( - it_list[comp][depth_1]['iterator']) - it_list[comp][depth_1 + - 2]['lower_bound'] = it_list[comp][ - depth_1]['lower_bound'] - it_list[comp][depth_1 + 2][ - 'upper_bound'] = action_params["first_factor"] + it_list[comp][depth_1 + 2]["iterator"] = "{}_1".format( + it_list[comp][depth_1]["iterator"] + ) + it_list[comp][depth_1 + 2]["lower_bound"] = it_list[comp][ + depth_1 + ]["lower_bound"] + it_list[comp][depth_1 + 2]["upper_bound"] = action_params[ + "first_factor" + ] added_iterators.append( - it_list[comp][depth_1 + 2]['iterator']) + it_list[comp][depth_1 + 2]["iterator"] + ) elif action_params["tiling_loop_2"]: - it_list[comp][depth_2]['upper_bound'] = it_list[ - comp][depth_2]['upper_bound'] / action_params[ - "second_factor"] + it_list[comp][depth_2]["upper_bound"] = ( + it_list[comp][depth_2]["upper_bound"] + / action_params["second_factor"] + ) it_list[comp][depth_2 + 1] = {} - it_list[comp][ - depth_2 + 1]['iterator'] = "{}_1".format( - it_list[comp][depth_2]['iterator']) - it_list[comp][depth_2 + - 1]['lower_bound'] = it_list[comp][ - depth_2]['lower_bound'] - it_list[comp][depth_2 + 1][ - 'upper_bound'] = action_params["second_factor"] + it_list[comp][depth_2 + 1]["iterator"] = "{}_1".format( + it_list[comp][depth_2]["iterator"] + ) + it_list[comp][depth_2 + 1]["lower_bound"] = it_list[comp][ + depth_2 + ]["lower_bound"] + it_list[comp][depth_2 + 1]["upper_bound"] = action_params[ + "second_factor" + ] added_iterators.append( - 
it_list[comp][depth_2 + 1]['iterator']) + it_list[comp][depth_2 + 1]["iterator"] + ) elif action_params["tiling_depth"] == 3: - if action_params["tiling_loop_1"] and action_params[ - "tiling_loop_2"] and action_params["tiling_loop_3"]: + if ( + action_params["tiling_loop_1"] + and action_params["tiling_loop_2"] + and action_params["tiling_loop_3"] + ): - it_list[comp][depth_1][ - 'upper_bound'] = it_list[comp][depth_1][ - 'upper_bound'] / action_params["first_factor"] + it_list[comp][depth_1]["upper_bound"] = ( + it_list[comp][depth_1]["upper_bound"] + / action_params["first_factor"] + ) it_list[comp][depth_1 + 3] = {} - it_list[comp][depth_1 + 3]['iterator'] = "{}_1".format( - it_list[comp][depth_1]['iterator']) - it_list[comp][depth_1 + 3]['lower_bound'] = it_list[ - comp][depth_1]['lower_bound'] - it_list[comp][ - depth_1 + - 3]['upper_bound'] = action_params["first_factor"] - - added_iterators.append(it_list[comp][depth_1 + - 3]['iterator']) - - it_list[comp][depth_2][ - 'upper_bound'] = it_list[comp][depth_2][ - 'upper_bound'] / action_params["second_factor"] + it_list[comp][depth_1 + 3]["iterator"] = "{}_1".format( + it_list[comp][depth_1]["iterator"] + ) + it_list[comp][depth_1 + 3]["lower_bound"] = it_list[comp][ + depth_1 + ]["lower_bound"] + it_list[comp][depth_1 + 3]["upper_bound"] = action_params[ + "first_factor" + ] + + added_iterators.append(it_list[comp][depth_1 + 3]["iterator"]) + + it_list[comp][depth_2]["upper_bound"] = ( + it_list[comp][depth_2]["upper_bound"] + / action_params["second_factor"] + ) it_list[comp][depth_2 + 3] = {} - it_list[comp][depth_2 + 3]['iterator'] = "{}_1".format( - it_list[comp][depth_2]['iterator']) - it_list[comp][depth_2 + 3]['lower_bound'] = it_list[ - comp][depth_2]['lower_bound'] - it_list[comp][ - depth_2 + - 3]['upper_bound'] = action_params["second_factor"] - - added_iterators.append(it_list[comp][depth_2 + - 3]['iterator']) - - it_list[comp][depth_3][ - 'upper_bound'] = it_list[comp][depth_3][ - 'upper_bound'] / 
action_params["third_factor"] + it_list[comp][depth_2 + 3]["iterator"] = "{}_1".format( + it_list[comp][depth_2]["iterator"] + ) + it_list[comp][depth_2 + 3]["lower_bound"] = it_list[comp][ + depth_2 + ]["lower_bound"] + it_list[comp][depth_2 + 3]["upper_bound"] = action_params[ + "second_factor" + ] + + added_iterators.append(it_list[comp][depth_2 + 3]["iterator"]) + + it_list[comp][depth_3]["upper_bound"] = ( + it_list[comp][depth_3]["upper_bound"] + / action_params["third_factor"] + ) it_list[comp][depth_3 + 3] = {} - it_list[comp][depth_3 + 3]['iterator'] = "{}_1".format( - it_list[comp][depth_3]['iterator']) - it_list[comp][depth_3 + 3]['lower_bound'] = it_list[ - comp][depth_3]['lower_bound'] - it_list[comp][ - depth_3 + - 3]['upper_bound'] = action_params["third_factor"] - - added_iterators.append(it_list[comp][depth_3 + - 3]['iterator']) - - elif action_params["tiling_loop_1"] and action_params[ - "tiling_loop_2"]: - - it_list[comp][depth_1][ - 'upper_bound'] = it_list[comp][depth_1][ - 'upper_bound'] / action_params["first_factor"] + it_list[comp][depth_3 + 3]["iterator"] = "{}_1".format( + it_list[comp][depth_3]["iterator"] + ) + it_list[comp][depth_3 + 3]["lower_bound"] = it_list[comp][ + depth_3 + ]["lower_bound"] + it_list[comp][depth_3 + 3]["upper_bound"] = action_params[ + "third_factor" + ] + + added_iterators.append(it_list[comp][depth_3 + 3]["iterator"]) + + elif ( + action_params["tiling_loop_1"] + and action_params["tiling_loop_2"] + ): + + it_list[comp][depth_1]["upper_bound"] = ( + it_list[comp][depth_1]["upper_bound"] + / action_params["first_factor"] + ) it_list[comp][depth_1 + 3] = {} - it_list[comp][depth_1 + 3]['iterator'] = "{}_1".format( - it_list[comp][depth_1]['iterator']) - it_list[comp][depth_1 + 3]['lower_bound'] = it_list[ - comp][depth_1]['lower_bound'] - it_list[comp][ - depth_1 + - 3]['upper_bound'] = action_params["first_factor"] - - added_iterators.append(it_list[comp][depth_1 + - 3]['iterator']) - - it_list[comp][depth_2][ - 
'upper_bound'] = it_list[comp][depth_2][ - 'upper_bound'] / action_params["second_factor"] + it_list[comp][depth_1 + 3]["iterator"] = "{}_1".format( + it_list[comp][depth_1]["iterator"] + ) + it_list[comp][depth_1 + 3]["lower_bound"] = it_list[comp][ + depth_1 + ]["lower_bound"] + it_list[comp][depth_1 + 3]["upper_bound"] = action_params[ + "first_factor" + ] + + added_iterators.append(it_list[comp][depth_1 + 3]["iterator"]) + + it_list[comp][depth_2]["upper_bound"] = ( + it_list[comp][depth_2]["upper_bound"] + / action_params["second_factor"] + ) it_list[comp][depth_2 + 3] = {} - it_list[comp][depth_2 + 3]['iterator'] = "{}_1".format( - it_list[comp][depth_2]['iterator']) - it_list[comp][depth_2 + 3]['lower_bound'] = it_list[ - comp][depth_2]['lower_bound'] - it_list[comp][ - depth_2 + - 3]['upper_bound'] = action_params["second_factor"] - - added_iterators.append(it_list[comp][depth_2 + - 3]['iterator']) - - elif action_params["tiling_loop_2"] and action_params[ - "tiling_loop_3"]: - - it_list[comp][depth_2][ - 'upper_bound'] = it_list[comp][depth_2][ - 'upper_bound'] / action_params["second_factor"] + it_list[comp][depth_2 + 3]["iterator"] = "{}_1".format( + it_list[comp][depth_2]["iterator"] + ) + it_list[comp][depth_2 + 3]["lower_bound"] = it_list[comp][ + depth_2 + ]["lower_bound"] + it_list[comp][depth_2 + 3]["upper_bound"] = action_params[ + "second_factor" + ] + + added_iterators.append(it_list[comp][depth_2 + 3]["iterator"]) + + elif ( + action_params["tiling_loop_2"] + and action_params["tiling_loop_3"] + ): + + it_list[comp][depth_2]["upper_bound"] = ( + it_list[comp][depth_2]["upper_bound"] + / action_params["second_factor"] + ) it_list[comp][depth_2 + 2] = {} - it_list[comp][depth_2 + 2]['iterator'] = "{}_1".format( - it_list[comp][depth_2]['iterator']) - it_list[comp][depth_2 + 2]['lower_bound'] = it_list[ - comp][depth_2]['lower_bound'] - it_list[comp][ - depth_2 + - 2]['upper_bound'] = action_params["second_factor"] - - 
added_iterators.append(it_list[comp][depth_2 + - 2]['iterator']) - - it_list[comp][depth_3][ - 'upper_bound'] = it_list[comp][depth_3][ - 'upper_bound'] / action_params["third_factor"] + it_list[comp][depth_2 + 2]["iterator"] = "{}_1".format( + it_list[comp][depth_2]["iterator"] + ) + it_list[comp][depth_2 + 2]["lower_bound"] = it_list[comp][ + depth_2 + ]["lower_bound"] + it_list[comp][depth_2 + 2]["upper_bound"] = action_params[ + "second_factor" + ] + + added_iterators.append(it_list[comp][depth_2 + 2]["iterator"]) + + it_list[comp][depth_3]["upper_bound"] = ( + it_list[comp][depth_3]["upper_bound"] + / action_params["third_factor"] + ) it_list[comp][depth_3 + 2] = {} - it_list[comp][depth_3 + 2]['iterator'] = "{}_1".format( - it_list[comp][depth_3]['iterator']) - it_list[comp][depth_3 + 2]['lower_bound'] = it_list[ - comp][depth_3]['lower_bound'] - it_list[comp][ - depth_3 + - 2]['upper_bound'] = action_params["third_factor"] - - added_iterators.append(it_list[comp][depth_3 + - 2]['iterator']) - - elif action_params["tiling_loop_1"] and action_params[ - "tiling_loop_3"]: - - it_list[comp][depth_1][ - 'upper_bound'] = it_list[comp][depth_1][ - 'upper_bound'] / action_params["first_factor"] + it_list[comp][depth_3 + 2]["iterator"] = "{}_1".format( + it_list[comp][depth_3]["iterator"] + ) + it_list[comp][depth_3 + 2]["lower_bound"] = it_list[comp][ + depth_3 + ]["lower_bound"] + it_list[comp][depth_3 + 2]["upper_bound"] = action_params[ + "third_factor" + ] + + added_iterators.append(it_list[comp][depth_3 + 2]["iterator"]) + + elif ( + action_params["tiling_loop_1"] + and action_params["tiling_loop_3"] + ): + + it_list[comp][depth_1]["upper_bound"] = ( + it_list[comp][depth_1]["upper_bound"] + / action_params["first_factor"] + ) it_list[comp][depth_1 + 3] = {} - it_list[comp][depth_1 + 3]['iterator'] = "{}_1".format( - it_list[comp][depth_1]['iterator']) - it_list[comp][depth_1 + 3]['lower_bound'] = it_list[ - comp][depth_1]['lower_bound'] - it_list[comp][ - 
depth_1 + - 3]['upper_bound'] = action_params["first_factor"] - - added_iterators.append(it_list[comp][depth_1 + - 3]['iterator']) - - it_list[comp][depth_3][ - 'upper_bound'] = it_list[comp][depth_3][ - 'upper_bound'] / action_params["third_factor"] + it_list[comp][depth_1 + 3]["iterator"] = "{}_1".format( + it_list[comp][depth_1]["iterator"] + ) + it_list[comp][depth_1 + 3]["lower_bound"] = it_list[comp][ + depth_1 + ]["lower_bound"] + it_list[comp][depth_1 + 3]["upper_bound"] = action_params[ + "first_factor" + ] + + added_iterators.append(it_list[comp][depth_1 + 3]["iterator"]) + + it_list[comp][depth_3]["upper_bound"] = ( + it_list[comp][depth_3]["upper_bound"] + / action_params["third_factor"] + ) it_list[comp][depth_3 + 2] = {} - it_list[comp][depth_3 + 2]['iterator'] = "{}_1".format( - it_list[comp][depth_3]['iterator']) - it_list[comp][depth_3 + 2]['lower_bound'] = it_list[ - comp][depth_3]['lower_bound'] - it_list[comp][ - depth_3 + - 2]['upper_bound'] = action_params["third_factor"] - - added_iterators.append(it_list[comp][depth_3 + - 2]['iterator']) + it_list[comp][depth_3 + 2]["iterator"] = "{}_1".format( + it_list[comp][depth_3]["iterator"] + ) + it_list[comp][depth_3 + 2]["lower_bound"] = it_list[comp][ + depth_3 + ]["lower_bound"] + it_list[comp][depth_3 + 2]["upper_bound"] = action_params[ + "third_factor" + ] + + added_iterators.append(it_list[comp][depth_3 + 2]["iterator"]) else: if action_params["tiling_loop_1"]: - it_list[comp][depth_1]['upper_bound'] = it_list[ - comp][depth_1]['upper_bound'] / action_params[ - "first_factor"] + it_list[comp][depth_1]["upper_bound"] = ( + it_list[comp][depth_1]["upper_bound"] + / action_params["first_factor"] + ) it_list[comp][depth_1 + 3] = {} - it_list[comp][ - depth_1 + 3]['iterator'] = "{}_1".format( - it_list[comp][depth_1]['iterator']) - it_list[comp][depth_1 + - 3]['lower_bound'] = it_list[comp][ - depth_1]['lower_bound'] - it_list[comp][depth_1 + 3][ - 'upper_bound'] = action_params["first_factor"] + 
it_list[comp][depth_1 + 3]["iterator"] = "{}_1".format( + it_list[comp][depth_1]["iterator"] + ) + it_list[comp][depth_1 + 3]["lower_bound"] = it_list[comp][ + depth_1 + ]["lower_bound"] + it_list[comp][depth_1 + 3]["upper_bound"] = action_params[ + "first_factor" + ] added_iterators.append( - it_list[comp][depth_1 + 3]['iterator']) + it_list[comp][depth_1 + 3]["iterator"] + ) elif action_params["tiling_loop_2"]: - it_list[comp][depth_2]['upper_bound'] = it_list[ - comp][depth_2]['upper_bound'] / action_params[ - "second_factor"] + it_list[comp][depth_2]["upper_bound"] = ( + it_list[comp][depth_2]["upper_bound"] + / action_params["second_factor"] + ) it_list[comp][depth_2 + 2] = {} - it_list[comp][ - depth_2 + 2]['iterator'] = "{}_1".format( - it_list[comp][depth_2]['iterator']) - it_list[comp][depth_2 + - 2]['lower_bound'] = it_list[comp][ - depth_2]['lower_bound'] - it_list[comp][depth_2 + 2][ - 'upper_bound'] = action_params["second_factor"] + it_list[comp][depth_2 + 2]["iterator"] = "{}_1".format( + it_list[comp][depth_2]["iterator"] + ) + it_list[comp][depth_2 + 2]["lower_bound"] = it_list[comp][ + depth_2 + ]["lower_bound"] + it_list[comp][depth_2 + 2]["upper_bound"] = action_params[ + "second_factor" + ] added_iterators.append( - it_list[comp][depth_2 + 2]['iterator']) + it_list[comp][depth_2 + 2]["iterator"] + ) elif action_params["tiling_loop_3"]: - it_list[comp][depth_3]['upper_bound'] = it_list[ - comp][depth_3]['upper_bound'] / action_params[ - "third_factor"] + it_list[comp][depth_3]["upper_bound"] = ( + it_list[comp][depth_3]["upper_bound"] + / action_params["third_factor"] + ) it_list[comp][depth_3 + 1] = {} - it_list[comp][ - depth_3 + 1]['iterator'] = "{}_1".format( - it_list[comp][depth_3]['iterator']) - it_list[comp][depth_3 + - 1]['lower_bound'] = it_list[comp][ - depth_3]['lower_bound'] - it_list[comp][depth_3 + 1][ - 'upper_bound'] = action_params["third_factor"] + it_list[comp][depth_3 + 1]["iterator"] = "{}_1".format( + 
it_list[comp][depth_3]["iterator"] + ) + it_list[comp][depth_3 + 1]["lower_bound"] = it_list[comp][ + depth_3 + ]["lower_bound"] + it_list[comp][depth_3 + 1]["upper_bound"] = action_params[ + "third_factor" + ] added_iterators.append( - it_list[comp][depth_3 + 1]['iterator']) + it_list[comp][depth_3 + 1]["iterator"] + ) elif id in range(41, 44): - it_list[comp][action_params["dim_index"]][ - 'upper_bound'] = it_list[comp][action_params["dim_index"]][ - 'upper_bound'] / action_params['unrolling_factor'] + it_list[comp][action_params["dim_index"]]["upper_bound"] = ( + it_list[comp][action_params["dim_index"]]["upper_bound"] + / action_params["unrolling_factor"] + ) elif id in range(44, 46): depth_1 = action_params["first_dim_index"] @@ -866,13 +1010,15 @@ def update_iterators(cls, id, it_list, action_params, added_iterators, l2_extent = abs(l2_upper_bound - l2_lower_bound) l2_lower_bound = 0 - l1_lower_bound = abs( - action_params["first_factor"]) * l1_lower_bound - l1_upper_bound = l1_lower_bound + abs( - action_params["first_factor"]) * l1_extent + abs( - action_params["second_factor"]) * l2_extent - l2_upper_bound = ((l1_extent * l2_extent) / - (l1_upper_bound - l1_lower_bound)) + 1 + l1_lower_bound = abs(action_params["first_factor"]) * l1_lower_bound + l1_upper_bound = ( + l1_lower_bound + + abs(action_params["first_factor"]) * l1_extent + + abs(action_params["second_factor"]) * l2_extent + ) + l2_upper_bound = ( + (l1_extent * l2_extent) / (l1_upper_bound - l1_lower_bound) + ) + 1 it_list[comp][depth_1]["lower_bound"] = l1_lower_bound it_list[comp][depth_1]["upper_bound"] = l1_upper_bound @@ -880,11 +1026,11 @@ def update_iterators(cls, id, it_list, action_params, added_iterators, it_list[comp][depth_2]["upper_bound"] = l2_upper_bound elif id in range(48, 56): - tmp = it_list[comp][action_params["dim_index"]]['lower_bound'] - it_list[comp][ - action_params["dim_index"]]['lower_bound'] = it_list[comp][ - action_params["dim_index"]]['upper_bound'] - 
it_list[comp][action_params["dim_index"]]['upper_bound'] = tmp + tmp = it_list[comp][action_params["dim_index"]]["lower_bound"] + it_list[comp][action_params["dim_index"]]["lower_bound"] = it_list[ + comp + ][action_params["dim_index"]]["upper_bound"] + it_list[comp][action_params["dim_index"]]["upper_bound"] = tmp it_list = dict(sorted(it_list.items())) @@ -892,12 +1038,9 @@ def update_iterators(cls, id, it_list, action_params, added_iterators, @classmethod def optimlist_to_str(cls, optim_list): - """Converts a list of OptimizationCommand to a string. - """ + """Converts a list of OptimizationCommand to a string.""" - comp_names = list(set([ - comp for optim in optim_list for comp in optim.comps - ])) + comp_names = list(set([comp for optim in optim_list for comp in optim.comps])) comp_names.sort() @@ -915,7 +1058,7 @@ def optimlist_to_str(cls, optim_list): # Iterate over the comps and add their transformations for name in comp_names: - sched_str += '{' + name + '}:' + sched_str += "{" + name + "}:" for transformation in optim_list: # Skip the transformation if it doesn't include the comp @@ -923,20 +1066,32 @@ def optimlist_to_str(cls, optim_list): continue if transformation.type == "Interchange": - sched_str += "I(L" + str(transformation.params_list[0]) + \ - ",L" + str(transformation.params_list[1]) + ")" + sched_str += ( + "I(L" + + str(transformation.params_list[0]) + + ",L" + + str(transformation.params_list[1]) + + ")" + ) elif transformation.type == "Reversal": sched_str += f"R(L{str(transformation.params_list[0])})" elif transformation.type == "Skewing": - sched_str += "S(L" + str(transformation.params_list[0]) + ",L" + str( - transformation.params_list[1]) + "," + str(transformation.params_list[2]) + "," + str( - transformation.params_list[3]) + ")" + sched_str += ( + "S(L" + + str(transformation.params_list[0]) + + ",L" + + str(transformation.params_list[1]) + + "," + + str(transformation.params_list[2]) + + "," + + str(transformation.params_list[3]) + 
+ ")" + ) elif transformation.type == "Parallelization": - sched_str += "P(L" + \ - str(transformation.params_list[0]) + ")" + sched_str += "P(L" + str(transformation.params_list[0]) + ")" elif transformation.type == "Tiling": # T2 @@ -983,12 +1138,13 @@ def optimlist_to_str(cls, optim_list): elif transformation.type == "Unrolling": dim_index = transformation.params_list[name][0] unrolling_factor = transformation.params_list[name][1] - sched_str += "U(L" + str(dim_index) + "," + \ - str(unrolling_factor) + ")" + sched_str += ( + "U(L" + str(dim_index) + "," + str(unrolling_factor) + ")" + ) return sched_str @classmethod def is_same_machine_as_dataset(cls, prog): hostname = gethostname() - return prog.function_dict['node_name'].startswith(hostname[:2]) + return prog.function_dict["node_name"].startswith(hostname[:2]) diff --git a/tiramisu_programs/surrogate_model_utils/__init__.py b/tiramisu_programs/surrogate_model_utils/__init__.py index 64eaac4..617af58 100644 --- a/tiramisu_programs/surrogate_model_utils/__init__.py +++ b/tiramisu_programs/surrogate_model_utils/__init__.py @@ -1 +1 @@ -from .modeling import * \ No newline at end of file +from .modeling import * diff --git a/tiramisu_programs/surrogate_model_utils/json_to_tensor.py b/tiramisu_programs/surrogate_model_utils/json_to_tensor.py index d719bb4..6c78640 100644 --- a/tiramisu_programs/surrogate_model_utils/json_to_tensor.py +++ b/tiramisu_programs/surrogate_model_utils/json_to_tensor.py @@ -5,7 +5,6 @@ import torch - device = "cpu" train_device = torch.device("cpu") @@ -186,9 +185,8 @@ def update_tree_atributes(node): node["loop_index"] = loops_indices_dict[node["loop_name"]] if node["computations_list"] != []: node["computations_indices"] = [ - comps_indices_dict[comp_name] - for comp_name in node["computations_list"] - ] + comps_indices_dict[comp_name] for comp_name in node["computations_list"] + ] node["has_comps"] = True else: node["has_comps"] = False @@ -443,7 +441,7 @@ def 
get_schedule_representation( iterator = fusion[-1] for loop in fusion[:-1]: fused_loop = computations_dict[loop]["iterators"][iterator] - # fused_loop2 = computations_dict[fusion[1]]["iterators"][fusion[2]] + # fused_loop2 = computations_dict[fusion[1]]["iterators"][fusion[2]] loop_schedules_dict[fused_loop]["fused"] = 1 # loop_schedules_dict[fused_loop2]["fused"] = 1 for loop_name in program_json["iterators"]: @@ -505,9 +503,11 @@ def get_padded_transformation_matrix( if "transformation_matrices" in comp_schedule_dict: if comp_schedule_dict["transformation_matrices"] != []: if ("transformation_matrix" in comp_schedule_dict) and ( - comp_schedule_dict["transformation_matrix"] is not None + comp_schedule_dict["transformation_matrix"] is not None ): - final_transformation_matrix = comp_schedule_dict["transformation_matrix"].reshape(nb_iterators, nb_iterators) + final_transformation_matrix = comp_schedule_dict[ + "transformation_matrix" + ].reshape(nb_iterators, nb_iterators) else: final_transformation_matrix = identity.copy() final_mat = final_transformation_matrix @@ -570,7 +570,9 @@ def get_padded_transformation_matrix( ).reshape(1, -1) comparison_matrix = identity.copy() for mat in comp_schedule_dict["transformation_matrices"][::-1]: - comparison_matrix = comparison_matrix @ mat.reshape(nb_iterators, nb_iterators) + comparison_matrix = comparison_matrix @ mat.reshape( + nb_iterators, nb_iterators + ) assert (comparison_matrix == final_transformation_matrix).all() else: interchange_matrix = identity.copy() @@ -647,29 +649,35 @@ def get_padded_transformation_matrix( print(padding_ranges) return padded_mat + def nest_iterators(root_iterator, iterators): - if root_iterator['child_iterators'] == []: - return {'loop_name': root_iterator["loop_name"], - 'computations_list': root_iterator['computations_list'], - 'child_list': []} + if root_iterator["child_iterators"] == []: + return { + "loop_name": root_iterator["loop_name"], + "computations_list": 
root_iterator["computations_list"], + "child_list": [], + } subtrees = [] - for loop_name in root_iterator['child_iterators']: + for loop_name in root_iterator["child_iterators"]: child_iterator = iterators[loop_name] child_iterator["loop_name"] = loop_name sub_tree = nest_iterators(child_iterator, iterators) subtrees.append(sub_tree) - return {'loop_name': root_iterator["loop_name"], - 'computations_list': root_iterator['computations_list'], - 'child_list': subtrees} + return { + "loop_name": root_iterator["loop_name"], + "computations_list": root_iterator["computations_list"], + "child_list": subtrees, + } + def get_tree_structure(prog_dict): iterators = prog_dict["iterators"] mentionned = [] for loop, content in iterators.items(): - mentionned.extend(content['child_iterators']) + mentionned.extend(content["child_iterators"]) - possible_root = [loop for loop in iterators if loop not in mentionned] + possible_root = [loop for loop in iterators if loop not in mentionned] assert len(possible_root) == 1 root_loop_name = possible_root[0] @@ -677,6 +685,7 @@ def get_tree_structure(prog_dict): root_iterator["loop_name"] = root_loop_name return nest_iterators(root_iterator, iterators) + def get_tree_footprint(tree): footprint = "" if tree["has_comps"]: @@ -687,4 +696,4 @@ def get_tree_footprint(tree): for child in tree["child_list"]: footprint += get_tree_footprint(child) footprint += "" - return footprint \ No newline at end of file + return footprint diff --git a/tiramisu_programs/surrogate_model_utils/modeling.py b/tiramisu_programs/surrogate_model_utils/modeling.py index 11dee31..8038e64 100644 --- a/tiramisu_programs/surrogate_model_utils/modeling.py +++ b/tiramisu_programs/surrogate_model_utils/modeling.py @@ -6,7 +6,7 @@ def seperate_vector( X: torch.Tensor, num_matrices: int = 4, pad: bool = True, pad_amount: int = 5 ) -> torch.Tensor: batch_size, _ = X.shape - tags_num = (6*num_matrices+3) + tags_num = 6 * num_matrices + 3 first_part = X[:, :tags_num] 
second_part = X[:, tags_num : tags_num + 36 * num_matrices] third_part = X[:, tags_num + 36 * num_matrices :] @@ -82,7 +82,11 @@ def __init__( nn.init.zeros_(self.concat_layers[i].weight) self.concat_dropouts.append(nn.Dropout(drops[i])) self.predict = nn.Linear(regression_layer_sizes[-1], output_size, bias=True) - self.encode_vectors = nn.Linear(transformation_matrix_dimension**2, transformation_matrix_dimension**2, bias=True) + self.encode_vectors = nn.Linear( + transformation_matrix_dimension**2, + transformation_matrix_dimension**2, + bias=True, + ) nn.init.xavier_uniform_(self.predict.weight) # nn.init.zeros_(self.predict.weight) self.ELU = nn.ELU() @@ -113,8 +117,6 @@ def __init__( else: self.embedding_generator = lambda x: x - - def get_hidden_state(self, node, comps_embeddings, loops_tensor): nodes_list = [] for n in node["child_list"]: @@ -129,7 +131,9 @@ def get_hidden_state(self, node, comps_embeddings, loops_tensor): ) if node["has_comps"]: selected_comps_tensor = torch.index_select( - comps_embeddings, 1, torch.tensor(node["computations_indices"]).to(self.train_device) + comps_embeddings, + 1, + torch.tensor(node["computations_indices"]).to(self.train_device), ) lstm_out, (comps_h_n, comps_c_n) = self.comps_lstm(selected_comps_tensor) comps_h_n = comps_h_n.permute(1, 0, 2) @@ -161,7 +165,7 @@ def forward(self, tree_tensors, num_matrices=6): # vectors = vectors[~mask,:,:] # final_matrix = final_matrix[~mask,:,:] # batch_size, num_comps, _ = final_matrix.shape - vectors = self.encode_vectors(vectors) + vectors = self.encode_vectors(vectors) # self.ELU(vectors) # print(vectors.shape) lstm_out, (prog_embedding, comps_c_n) = self.comps_embed(vectors) @@ -200,4 +204,3 @@ def forward(self, tree_tensors, num_matrices=6): x = self.regression_dropouts[i](self.ELU(x)) out = self.predict(x) return self.ELU(out[:, 0, 0]) - diff --git a/tiramisu_programs/tiramisu_program.py b/tiramisu_programs/tiramisu_program.py index 7ed835b..dd3d4c9 100644 --- 
a/tiramisu_programs/tiramisu_program.py +++ b/tiramisu_programs/tiramisu_program.py @@ -13,8 +13,8 @@ class InternalExecException(Exception): pass -class TiramisuProgram(): - wrapper_h_template = '''#include +class TiramisuProgram: + wrapper_h_template = """#include #include #include #include @@ -59,9 +59,9 @@ class TiramisuProgram(): int $func_name$($func_params$); #ifdef __cplusplus } // extern "C" -#endif''' +#endif""" - wrapper_cpp_template = '''#include "Halide.h" + wrapper_cpp_template = """#include "Halide.h" #include "$func_name$_wrapper.h" #include "tiramisu/utils.h" #include @@ -93,40 +93,46 @@ class TiramisuProgram(): out.close(); return 0; -}''' +}""" def __init__(self, config: RLAutoSchedulerConfig, file_path, function_dict=None): self.config = config self.file_path = file_path - with open(file_path, 'r') as f: + with open(file_path, "r") as f: self.original_str = f.read() - self.func_folder = ('/'.join(Path(file_path).parts[:-1]) - if len(Path(file_path).parts) > 1 else '.') + '/' + self.func_folder = ( + "/".join(Path(file_path).parts[:-1]) + if len(Path(file_path).parts) > 1 + else "." 
+ ) + "/" - self.body = re.findall(r'(tiramisu::init(?s:.)+)tiramisu::codegen', - self.original_str)[0] + self.body = re.findall( + r"(tiramisu::init(?s:.)+)tiramisu::codegen", self.original_str + )[0] - self.name = re.findall(r'tiramisu::init\(\"(\w+)\"\);', - self.original_str)[0] + self.name = re.findall(r"tiramisu::init\(\"(\w+)\"\);", self.original_str)[0] self.original_str = self.original_str.replace( - f'#include "{self.name}_wrapper.h"', '') + f'#include "{self.name}_wrapper.h"', "" + ) - self.comp_name = re.findall(r'computation (\w+)\(', self.original_str) + self.comp_name = re.findall(r"computation (\w+)\(", self.original_str) - self.code_gen_line = re.findall(r'tiramisu::codegen\({.+;', - self.original_str)[0] + self.code_gen_line = re.findall(r"tiramisu::codegen\({.+;", self.original_str)[ + 0 + ] - buffers_vect = re.findall(r'{(.+)}', self.code_gen_line)[0] + buffers_vect = re.findall(r"{(.+)}", self.code_gen_line)[0] - self.IO_buffer_names = re.findall(r'\w+', buffers_vect) + self.IO_buffer_names = re.findall(r"\w+", buffers_vect) self.buffer_sizes = [] for buf_name in self.IO_buffer_names: - sizes_vect = re.findall(r'buffer ' + buf_name + '.*{(.*)}', - self.original_str)[0] - self.buffer_sizes.append(re.findall(r'\d+', sizes_vect)) + sizes_vect = re.findall( + r"buffer " + buf_name + ".*{(.*)}", self.original_str + )[0] + self.buffer_sizes.append(re.findall(r"\d+", sizes_vect)) self.program_annotations = None self.wrapper_is_compiled = False @@ -137,128 +143,156 @@ def get_program_annotations(self): if self.program_annotations is not None: return self.program_annotations - if self.function_dict['program_annotation'] is not None: - self.program_annotations = self.function_dict['program_annotation'] + if self.function_dict["program_annotation"] is not None: + self.program_annotations = self.function_dict["program_annotation"] else: # create a cpp file to get the annotations - get_json_lines = ''' + get_json_lines = ( + ''' auto ast = 
tiramisu::auto_scheduler::syntax_tree(tiramisu::global::get_implicit_function()); std::string program_json = tiramisu::auto_scheduler::evaluate_by_learning_model::get_program_json(ast); - std::ofstream out("''' + self.func_folder + self.name + '''_program_annotations.json"); + std::ofstream out("''' + + self.func_folder + + self.name + + """_program_annotations.json"); out << program_json; out.close(); - ''' - get_json_prog = self.original_str.replace(self.code_gen_line, - get_json_lines) - output_file = self.func_folder + self.name + '_get_prog_annot.cpp' - - with open(output_file, 'w') as f: + """ + ) + get_json_prog = self.original_str.replace( + self.code_gen_line, get_json_lines + ) + output_file = self.func_folder + self.name + "_get_prog_annot.cpp" + + with open(output_file, "w") as f: f.write(get_json_prog) # compile the cpp file and run to generate annotations in json file CPP_File.compile_and_run_tiramisu_code( - self.config, output_file, 'Generating program annotations') + self.config, output_file, "Generating program annotations" + ) # Read the json file and return the annotations - with open(self.func_folder + self.name + '_program_annotations.json', - 'r') as f: + with open( + self.func_folder + self.name + "_program_annotations.json", "r" + ) as f: self.program_annotations = json.loads(f.read()) - self.function_dict['program_annotation'] = self.program_annotations + self.function_dict["program_annotation"] = self.program_annotations return self.program_annotations - def check_legality_of_schedule( - self, - optims_list, - comps=None, - first_comp=None - ): - legality_check_lines = ''' + def check_legality_of_schedule(self, optims_list, comps=None, first_comp=None): + legality_check_lines = """ prepare_schedules_for_legality_checks(); perform_full_dependency_analysis(); bool is_legal=true; -''' +""" for optim in optims_list: - if optim.type == 'Interchange': - legality_check_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Reversal': - 
legality_check_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Skewing': - legality_check_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Parallelization': - legality_check_lines += ''' - is_legal &= loop_parallelization_is_legal(''' + str( - optim.params_list[0]) + ''', {&''' + first_comp + '''}); -''' - legality_check_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Tiling': - legality_check_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Fusion': - legality_check_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Unrolling': - legality_check_lines += ''' - is_legal &= loop_unrolling_is_legal(''' + str( - optim.params_list[comps[0]] - [0]) + ''', {''' + ", ".join([f"&{comp}" for comp in comps]) + '''});''' - legality_check_lines += optim.tiramisu_optim_str + '\n' - - legality_check_lines += ''' + if optim.type == "Interchange": + legality_check_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Reversal": + legality_check_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Skewing": + legality_check_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Parallelization": + legality_check_lines += ( + """ + is_legal &= loop_parallelization_is_legal(""" + + str(optim.params_list[0]) + + """, {&""" + + first_comp + + """}); +""" + ) + legality_check_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Tiling": + legality_check_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Fusion": + legality_check_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Unrolling": + legality_check_lines += ( + """ + is_legal &= loop_unrolling_is_legal(""" + + str(optim.params_list[comps[0]][0]) + + """, {""" + + ", ".join([f"&{comp}" for comp in comps]) + + """});""" + ) + legality_check_lines += optim.tiramisu_optim_str + "\n" + + legality_check_lines += ( + ''' is_legal &= check_legality_of_function(); - std::ofstream out("''' + self.func_folder + 
'''legality_check_result.txt"); + std::ofstream out("''' + + self.func_folder + + """legality_check_result.txt"); out << is_legal; out.close(); - ''' + """ + ) - LC_code = self.original_str.replace(self.code_gen_line, - legality_check_lines) - output_file = self.func_folder + self.name + '_legality_check.cpp' - with open(output_file, 'w') as f: + LC_code = self.original_str.replace(self.code_gen_line, legality_check_lines) + output_file = self.func_folder + self.name + "_legality_check.cpp" + with open(output_file, "w") as f: f.write(LC_code) self.reset_legality_check_result_file() - log_message = 'Checking legality for: ' + ' '.join( - [o.tiramisu_optim_str for o in optims_list]) - CPP_File.compile_and_run_tiramisu_code( - self.config, output_file, log_message) + log_message = "Checking legality for: " + " ".join( + [o.tiramisu_optim_str for o in optims_list] + ) + CPP_File.compile_and_run_tiramisu_code(self.config, output_file, log_message) lc_result = self.read_legality_check_result_file() return lc_result def call_solver(self, comp, params): - lc_file = self.func_folder + self.name + '_legality_check.cpp' + lc_file = self.func_folder + self.name + "_legality_check.cpp" if os.path.isfile(lc_file): - with open(lc_file, 'r') as f: + with open(lc_file, "r") as f: original_str = f.read() - to_replace = re.findall(r'(std::ofstream out(?s:.)+)return', - original_str)[0] + to_replace = re.findall(r"(std::ofstream out(?s:.)+)return", original_str)[ + 0 + ] header = "function * fct = tiramisu::global::get_implicit_function();\n" original_str = original_str.replace( - "is_legal &= check_legality_of_function()", "") + "is_legal &= check_legality_of_function()", "" + ) original_str = original_str.replace("bool is_legal=true;", "") original_str = re.sub( - r'is_legal &= loop_parallelization_is_legal.*\n', "", original_str) + r"is_legal &= loop_parallelization_is_legal.*\n", "", original_str + ) original_str = re.sub( - r'is_legal &= loop_unrolling_is_legal.*\n', "", 
original_str) + r"is_legal &= loop_unrolling_is_legal.*\n", "", original_str + ) else: original_str = self.original_str to_replace = self.code_gen_line - header = ''' + header = """ perform_full_dependency_analysis(); prepare_schedules_for_legality_checks(); function * fct = tiramisu::global::get_implicit_function(); - ''' - - solver_lines = header + "auto auto_skewing_result = fct->skewing_local_solver({&" + comp + "}},{},{},1);\n".format( - params["first_dim_index"], params["second_dim_index"]) - - solver_lines += ''' - std::ofstream out("''' + self.func_folder + '''solver_result.txt"); + """ + + solver_lines = ( + header + + "auto auto_skewing_result = fct->skewing_local_solver({&" + + comp + + "}},{},{},1);\n".format( + params["first_dim_index"], params["second_dim_index"] + ) + ) + + solver_lines += ( + ''' + std::ofstream out("''' + + self.func_folder + + """solver_result.txt"); std::vector> outer1, outer2,outer3; tie( outer1, outer2, outer3 )= auto_skewing_result; if (outer1.size()>0){ @@ -274,19 +308,22 @@ def call_solver(self, comp, params): out << outer3.front().second << std::endl; } - ''' + """ + ) solver_code = original_str.replace(to_replace, solver_lines) - output_file = self.func_folder + self.name + '_solver.cpp' + output_file = self.func_folder + self.name + "_solver.cpp" - with open(output_file, 'w') as f: + with open(output_file, "w") as f: f.write(solver_code) self.reset_solver_result_file() - log_message = 'Solver results for: computation {}'.format( - comp) + ' '.join([p for p in params]) + log_message = "Solver results for: computation {}".format(comp) + " ".join( + [p for p in params] + ) if CPP_File.compile_and_run_tiramisu_code( - self.config, output_file, log_message): + self.config, output_file, log_message + ): solver_result = self.read_solver_result_file() if len(solver_result) == 0: return None @@ -295,142 +332,157 @@ def call_solver(self, comp, params): else: raise InternalExecException - def evaluate_schedule(self, - optims_list, - 
cmd_type, - nb_executions, - initial_exec_time=None): + def evaluate_schedule( + self, optims_list, cmd_type, nb_executions, initial_exec_time=None + ): - optim_lines = '' + optim_lines = "" for optim in optims_list: - if optim.type == 'Interchange': - optim_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Skewing': - optim_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Parallelization': - optim_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Tiling': - optim_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Unrolling': - optim_lines += optim.tiramisu_optim_str + '\n' - elif optim.type == 'Reversal': - optim_lines += optim.tiramisu_optim_str + '\n' + if optim.type == "Interchange": + optim_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Skewing": + optim_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Parallelization": + optim_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Tiling": + optim_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Unrolling": + optim_lines += optim.tiramisu_optim_str + "\n" + elif optim.type == "Reversal": + optim_lines += optim.tiramisu_optim_str + "\n" codegen_code = self.original_str.replace( self.code_gen_line, - optim_lines + '\n' + self.code_gen_line.replace( - self.name, self.func_folder + self.name)) - output_file = self.func_folder + self.name + '_schedule_codegen.cpp' - with open(output_file, 'w') as f: + optim_lines + + "\n" + + self.code_gen_line.replace(self.name, self.func_folder + self.name), + ) + output_file = self.func_folder + self.name + "_schedule_codegen.cpp" + with open(output_file, "w") as f: f.write(codegen_code) - log_message = 'Applying schedule: ' + ' '.join( - [o.tiramisu_optim_str for o in optims_list]) + log_message = "Applying schedule: " + " ".join( + [o.tiramisu_optim_str for o in optims_list] + ) start_time = time.time() - if (CPP_File.compile_and_run_tiramisu_code( - self.config, 
output_file, log_message)): + if CPP_File.compile_and_run_tiramisu_code( + self.config, output_file, log_message + ): try: execution_times = self.get_measurements( - cmd_type, nb_executions, initial_exec_time) + cmd_type, nb_executions, initial_exec_time + ) if len(execution_times) != 0: return min(execution_times) else: return 0 except TimeOutException: print("time out exception") - return 10 * nb_executions * (initial_exec_time - if initial_exec_time else 1.0) + return ( + 10 + * nb_executions + * (initial_exec_time if initial_exec_time else 1.0) + ) else: raise InternalExecException def get_measurements(self, cmd_type, nb_executions, initial_exec_time): - os.environ['FUNC_DIR'] = ('/'.join(Path(self.file_path).parts[:-1]) - if len(Path(self.file_path).parts) > 1 else - '.') + '/' - os.environ['FILE_PATH'] = self.file_path - os.environ['FUNC_NAME'] = self.name + os.environ["FUNC_DIR"] = ( + "/".join(Path(self.file_path).parts[:-1]) + if len(Path(self.file_path).parts) > 1 + else "." + ) + "/" + os.environ["FILE_PATH"] = self.file_path + os.environ["FUNC_NAME"] = self.name if not self.wrapper_is_compiled: self.write_wrapper_code() log_message_cmd = 'printf "Compiling wrapper\n">> ${FUNC_DIR}log.txt' - CPP_File.launch_cmd(log_message_cmd, '') + CPP_File.launch_cmd(log_message_cmd, "") failed = CPP_File.launch_cmd( - self.config.tiramisu.compile_wrapper_cmd, self.file_path) + self.config.tiramisu.compile_wrapper_cmd, self.file_path + ) if failed: - print('Failed compiling wrapper') + print("Failed compiling wrapper") return self.wrapper_is_compiled = True self.reset_measurements_file() - log_message_cmd = 'printf "Running wrapper nb_exec = ' + str( - nb_executions) + '\n">> ${FUNC_DIR}log.txt' - run_wrapper_cmd = 'cd ${FUNC_DIR};\ + log_message_cmd = ( + 'printf "Running wrapper nb_exec = ' + + str(nb_executions) + + '\n">> ${FUNC_DIR}log.txt' + ) + run_wrapper_cmd = ( + "cd ${FUNC_DIR};\ ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ - ./${FUNC_NAME}_wrapper 
' + str(nb_executions) - CPP_File.launch_cmd(log_message_cmd, '') + ./${FUNC_NAME}_wrapper " + + str(nb_executions) + ) + CPP_File.launch_cmd(log_message_cmd, "") s_time = time.time() - failed = CPP_File.launch_cmd(run_wrapper_cmd, - self.file_path, - cmd_type, nb_executions, - initial_exec_time) + failed = CPP_File.launch_cmd( + run_wrapper_cmd, self.file_path, cmd_type, nb_executions, initial_exec_time + ) if failed: - print('Failed running wrapper') + print("Failed running wrapper") return return self.read_measurements_file() def write_wrapper_code(self): - buffers_init_lines = '' + buffers_init_lines = "" for i, buffer_name in enumerate(self.IO_buffer_names): - buffers_init_lines += f''' + buffers_init_lines += f""" double *c_{buffer_name} = (double*)malloc({'*'.join(self.buffer_sizes[i][::-1])}* sizeof(double)); parallel_init_buffer(c_{buffer_name}, {'*'.join(self.buffer_sizes[i][::-1])}, (double){str(random.randint(1,10))}); Halide::Buffer {buffer_name}(c_{buffer_name}, {','.join(self.buffer_sizes[i][::-1])}); - ''' - wrapper_cpp_code = self.wrapper_cpp_template.replace( - '$func_name$', self.name) - wrapper_cpp_code = wrapper_cpp_code.replace('$buffers_init$', - buffers_init_lines) - wrapper_cpp_code = wrapper_cpp_code.replace('$func_folder_path$', - self.func_folder) + """ + wrapper_cpp_code = self.wrapper_cpp_template.replace("$func_name$", self.name) + wrapper_cpp_code = wrapper_cpp_code.replace( + "$buffers_init$", buffers_init_lines + ) + wrapper_cpp_code = wrapper_cpp_code.replace( + "$func_folder_path$", self.func_folder + ) wrapper_cpp_code = wrapper_cpp_code.replace( - '$func_params$', - ','.join([name + '.raw_buffer()' - for name in self.IO_buffer_names])) - output_file = self.func_folder + self.name + '_wrapper.cpp' - with open(output_file, 'w') as f: + "$func_params$", + ",".join([name + ".raw_buffer()" for name in self.IO_buffer_names]), + ) + output_file = self.func_folder + self.name + "_wrapper.cpp" + with open(output_file, "w") as f: 
f.write(wrapper_cpp_code) - wrapper_h_code = self.wrapper_h_template.replace( - '$func_name$', self.name) + wrapper_h_code = self.wrapper_h_template.replace("$func_name$", self.name) wrapper_h_code = wrapper_h_code.replace( - '$func_params$', ','.join( - ['halide_buffer_t *' + name for name in self.IO_buffer_names])) - output_file = self.func_folder + self.name + '_wrapper.h' - with open(output_file, 'w') as f: + "$func_params$", + ",".join(["halide_buffer_t *" + name for name in self.IO_buffer_names]), + ) + output_file = self.func_folder + self.name + "_wrapper.h" + with open(output_file, "w") as f: f.write(wrapper_h_code) def read_legality_check_result_file(self): - with open(self.func_folder + "legality_check_result.txt", 'r') as f: + with open(self.func_folder + "legality_check_result.txt", "r") as f: res = int(f.read()) return res def reset_legality_check_result_file(self): - with open(self.func_folder + "legality_check_result.txt", 'w') as f: - f.write('-1') + with open(self.func_folder + "legality_check_result.txt", "w") as f: + f.write("-1") def read_measurements_file(self): - with open(self.func_folder + "measurements_file.txt", 'r') as f: + with open(self.func_folder + "measurements_file.txt", "r") as f: res = [float(i) for i in f.read().split()] return res def reset_measurements_file(self): - with open(self.func_folder + "measurements_file.txt", 'w') as f: - f.write('-1') + with open(self.func_folder + "measurements_file.txt", "w") as f: + f.write("-1") def read_solver_result_file(self): - with open(self.func_folder + "solver_result.txt", 'r') as f: + with open(self.func_folder + "solver_result.txt", "r") as f: res = f.readlines() return res def reset_solver_result_file(self): - with open(self.func_folder + "solver_result.txt", 'w') as f: - f.write('-1') + with open(self.func_folder + "solver_result.txt", "w") as f: + f.write("-1") diff --git a/train_ppo.py b/train_ppo.py index 6cec909..56d23dd 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -23,14 
+23,29 @@ def get_arguments(): parser = argparse.ArgumentParser() - parser.add_argument("--num-workers", default=-1, type=int, - help="Number of workers to use for training") - parser.add_argument('--resume-training', - action=argparse.BooleanOptionalAction, help="Resume training from a saved checkpoint") - parser.add_argument("--use-dataset", action=argparse.BooleanOptionalAction, - help="Use the dataset (path specified in config) to train") - parser.add_argument("--log-level", default="INFO", # TODO change back to WARN - type=str, choices=list(logging._nameToLevel.keys()), help="Log levels") + parser.add_argument( + "--num-workers", + default=-1, + type=int, + help="Number of workers to use for training", + ) + parser.add_argument( + "--resume-training", + action=argparse.BooleanOptionalAction, + help="Resume training from a saved checkpoint", + ) + parser.add_argument( + "--use-dataset", + action=argparse.BooleanOptionalAction, + help="Use the dataset (path specified in config) to train", + ) + parser.add_argument( + "--log-level", + default="INFO", # TODO change back to WARN + type=str, + choices=list(logging._nameToLevel.keys()), + help="Log levels", + ) return parser.parse_args() @@ -39,28 +54,36 @@ def main(config: RLAutoSchedulerConfig): logging.basicConfig(level=config.ray.log_level) local_dir = os.path.join(config.ray.base_path, config.ray.log_directory) - dataset_path = config.environment.json_dataset[ - 'path'] if config.environment.use_dataset else config.environment.dataset_path + dataset_path = ( + config.environment.json_dataset["path"] + if config.environment.use_dataset + else config.environment.dataset_path + ) dataset_actor = DatasetAgent.remote( - dataset_path=dataset_path, use_dataset=config.environment.use_dataset, path_to_save_dataset=config.environment.json_dataset['path_to_save_dataset'], dataset_format=config.environment.json_dataset['dataset_format']) + dataset_path=dataset_path, + use_dataset=config.environment.use_dataset, + 
path_to_save_dataset=config.environment.json_dataset["path_to_save_dataset"], + dataset_format=config.environment.json_dataset["dataset_format"], + ) register_env( "Tiramisu_env_v1", lambda a: TiramisuScheduleEnvironment(config, dataset_actor), ) - ModelCatalog.register_custom_model("tiramisu_model_v1", - TiramisuModelMult) + ModelCatalog.register_custom_model("tiramisu_model_v1", TiramisuModelMult) # Use all available CPUs as workers (-1 for the head) if config.ray.num_workers == -1: - config.ray.num_workers = int(ray.available_resources()['CPU'])-1 + config.ray.num_workers = int(ray.available_resources()["CPU"]) - 1 logging.info(f"==================== # Used CPU:{config.ray.num_workers}") config_dict = { "env": "Tiramisu_env_v1", "num_workers": config.ray.num_workers, "placement_strategy": "SPREAD", "batch_mode": "complete_episodes", - "train_batch_size": max(config.ray.num_workers * 200, config.training.train_batch_size), + "train_batch_size": max( + config.ray.num_workers * 200, config.training.train_batch_size + ), "sgd_minibatch_size": config.training.sgd_minibatch_size, "lr": config.training.lr, "num_sgd_iter": config.training.num_sgd_iter, @@ -77,9 +100,7 @@ def main(config: RLAutoSchedulerConfig): if config.ray.resume_training: print(f"Resuming training from: {local_dir}/{config.ray.name}") - tuner = tune.Tuner.restore( - path=f"{local_dir}/{config.ray.name}" - ) + tuner = tune.Tuner.restore(path=f"{local_dir}/{config.ray.name}") else: tuner = tune.Tuner( "PPO", @@ -89,12 +110,10 @@ def main(config: RLAutoSchedulerConfig): stop={"training_iteration": config.ray.training_iteration}, name=config.ray.name, verbose=0, - failure_config=air.FailureConfig( - max_failures=0 - ), + failure_config=air.FailureConfig(max_failures=0), checkpoint_config=air.CheckpointConfig( checkpoint_frequency=config.ray.checkpoint_freq, - ) + ), ), ) results = tuner.fit() @@ -116,9 +135,10 @@ def main(config: RLAutoSchedulerConfig): if args.use_dataset: config.environment.use_dataset 
= args.use_dataset - if config.tiramisu.env_type == 'cpu': + if config.tiramisu.env_type == "cpu": logging.warning( - "DATASET LEARNINING IS INCOMPATIBLE WITH CPU LEARNING. SWITCHING TO MODEL") + "DATASET LEARNINING IS INCOMPATIBLE WITH CPU LEARNING. SWITCHING TO MODEL" + ) # Force model usage if using dataset config.tiramisu.env_type = "model" diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index 8ec7df7..40623c0 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -6,10 +6,10 @@ import numpy as np import ray -SAVING_FREQUENCY = 1000 +SAVING_FREQUENCY = 10000 -class DataSetFormat(): +class DataSetFormat: PICKLE = "PICKLE" JSON = "JSON" BZ2 = "BZ2" @@ -17,7 +17,14 @@ class DataSetFormat(): @ray.remote class DatasetAgent: - def __init__(self, dataset_path, path_to_save_dataset, dataset_format, use_dataset=False, shuffle=False): + def __init__( + self, + dataset_path, + path_to_save_dataset, + dataset_format, + use_dataset=False, + shuffle=False, + ): self.dataset_path = dataset_path self.path_to_save_dataset = path_to_save_dataset self.dataset_format = dataset_format @@ -26,27 +33,26 @@ def __init__(self, dataset_path, path_to_save_dataset, dataset_format, use_datas self.dataset = {} self.function_names = [] self.nbr_updates = 0 - self.dataset_name = dataset_path.split('/')[-1].split('.')[0] + self.dataset_name = dataset_path.split("/")[-1].split(".")[0] if use_dataset: print(f"reading dataset from json at:{dataset_path}") match dataset_format: case DataSetFormat.PICKLE: - with open(dataset_path, 'rb') as f: + with open(dataset_path, "rb") as f: self.dataset = pickle.load(f) self.function_names = list(self.dataset.keys()) case DataSetFormat.JSON: - with open(dataset_path, 'rb') as f: + with open(dataset_path, "rb") as f: self.dataset = json.load(f) self.function_names = list(self.dataset.keys()) case DataSetFormat.BZ2: - with bz2.BZ2File(dataset_path, 'rb') as f: + with bz2.BZ2File(dataset_path, "rb") as f: self.dataset = 
pickle.load(f) self.function_names = list(self.dataset.keys()) case _: raise ValueError("Format specified not supported") - print( - f"[Done] reading dataset from json at:{dataset_path}") + print(f"[Done] reading dataset from json at:{dataset_path}") else: print(f"reading data from ls at: {dataset_path}") @@ -61,9 +67,9 @@ def get_next_function(self): return function_name, self.dataset[function_name] else: return function_name, { - 'program_annotation': None, - 'schedules_legality_dict': {}, - 'schedules_solver_results_dict': {} + "program_annotation": None, + "schedules_legality_dict": {}, + "schedules_solver_results_dict": {}, } def update_dataset(self, function_name, function_dict): @@ -71,7 +77,7 @@ def update_dataset(self, function_name, function_dict): self.nbr_updates += 1 print(f"# updates: {self.nbr_updates}") if self.nbr_updates % SAVING_FREQUENCY == 0: - if self.nbr_updates % (2*SAVING_FREQUENCY): + if self.nbr_updates % (2 * SAVING_FREQUENCY): self.save_dataset_to_disk(version=2) else: self.save_dataset_to_disk(version=1) @@ -79,19 +85,19 @@ def update_dataset(self, function_name, function_dict): def save_dataset_to_disk(self, version=1): print("[Start] Save the legality_annotations_dict to disk") - updated_dataset_name = f"{self.path_to_save_dataset}/{self.dataset_name}_updated_{version}" + updated_dataset_name = ( + f"{self.path_to_save_dataset}/{self.dataset_name}_updated_{version}" + ) match self.dataset_format: case DataSetFormat.PICKLE: with open(f"{updated_dataset_name}.pkl", "wb") as f: - pickle.dump(self.dataset, f, - protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump(self.dataset, f, protocol=pickle.HIGHEST_PROTOCOL) case DataSetFormat.JSON: with open(f"{updated_dataset_name}.json", "w") as f: json.dump(self.dataset, f) case DataSetFormat.BZ2: - with bz2.BZ2File(f"{updated_dataset_name}.bz2.pkl", 'wb') as f: - pickle.dump(self.dataset, f, - protocol=pickle.HIGHEST_PROTOCOL) + with bz2.BZ2File(f"{updated_dataset_name}.bz2.pkl", "wb") as f: + 
pickle.dump(self.dataset, f, protocol=pickle.HIGHEST_PROTOCOL) case _: raise ValueError("Format specified not supported") print("[Done] Save the legality_annotations_dict to disk") diff --git a/utils/environment_variables.py b/utils/environment_variables.py index 040b4a7..1d0d817 100644 --- a/utils/environment_variables.py +++ b/utils/environment_variables.py @@ -1,5 +1,6 @@ import os + def configure_env_variables(config): # Put the path to your tiramisu installation here os.environ["TIRAMISU_ROOT"] = config.tiramisu.tiramisu_path @@ -7,4 +8,4 @@ def configure_env_variables(config): # The two environment variables below are set to 1 to avoid a Docker container error os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1" os.environ["RAY_ALLOW_SLOW_STORAGE"] = "1" - os.environ['TUNE_RESULT_DIR'] = os.getcwd() \ No newline at end of file + os.environ["TUNE_RESULT_DIR"] = os.getcwd() diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 0645472..70f179e 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -26,12 +26,14 @@ class EnvironmentConfig: dataset_path: str = "../../Dataset_multi/" programs_file: str = "./multicomp.json" clean_files: bool = True - json_dataset: dict = field(default_factory=lambda: { - "path": None, - "cpps_path": None, - "path_to_save_sataset": None, - "dataset_format": DataSetFormat.PICKLE - }) + json_dataset: dict = field( + default_factory=lambda: { + "path": None, + "cpps_path": None, + "path_to_save_sataset": None, + "dataset_format": DataSetFormat.PICKLE, + } + ) use_dataset: bool = False @@ -47,9 +49,9 @@ class TiramisuConfig: run_tiramisu_cmd: str = 'printf "Running ${FILE_PATH}.out\n">> ${FUNC_DIR}log.txt;\ ./${FILE_PATH}.out>> ${FUNC_DIR}log.txt;' - compile_wrapper_cmd = 'cd ${FUNC_DIR};\ + compile_wrapper_cmd = "cd ${FUNC_DIR};\ ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ - ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include 
-I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O3 -o ${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O3 -o ${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl" @dataclass @@ -62,10 +64,8 @@ class TrainingConfig: @dataclass class ModelConfig: - layer_sizes: List[int] = field( - default_factory=lambda: [600, 350, 200, 180]) - drops: List[float] = field( - default_factory=lambda: [0.225, 0.225, 0.225, 0.225]) + layer_sizes: List[int] = field(default_factory=lambda: [600, 350, 200, 180]) + drops: List[float] = field(default_factory=lambda: [0.225, 0.225, 0.225, 0.225]) @dataclass From 7ce563ae7166607ac00e564e1eb604c5d1259a17 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Thu, 16 Mar 2023 15:10:43 +0400 Subject: [PATCH 25/27] added sequential data --- utils/dataset_utilities.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index 40623c0..5a8d6f6 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -34,6 +34,8 @@ def __init__( self.function_names = [] self.nbr_updates = 0 self.dataset_name = dataset_path.split("/")[-1].split(".")[0] + self.current_function = 0 + 
self.dataset_size = 0 if use_dataset: print(f"reading dataset from json at:{dataset_path}") @@ -61,8 +63,22 @@ def __init__( if self.shuffle: random.shuffle(self.function_names) - def get_next_function(self): - function_name = np.random.choice(self.function_names) + self.dataset_size = len(self.function_names) + + def get_next_function(self, random=False): + if random: + function_name = np.random.choice(self.function_names) + else: + function_name = self.function_names[ + self.current_function % self.dataset_size + ] + self.current_function += 1 + + print( + f"Selected function with index: {self.current_function}, name: {function_name}" + ) + print(self.function_names[:10]) + if self.use_dataset: return function_name, self.dataset[function_name] else: From b0ca8df79ce16aeb4aeb866099e75f74dcf4431b Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 28 Mar 2023 04:27:50 -0400 Subject: [PATCH 26/27] added seed, comments and types --- utils/rl_autoscheduler_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 70f179e..51344eb 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -43,15 +43,15 @@ class TiramisuConfig: env_type: Literal["model", "cpu"] = "cpu" model_checkpoint: str = "/data/scratch/hbenyamina/model_published_nn_finale.pt" compile_tiramisu_cmd: str = 'printf "Compiling ${FILE_PATH}\n" >> ${FUNC_DIR}log.txt;\ - ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ - ${CXX} -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib 
-Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/install/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++17 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ + ${CXX} -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++17 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64 -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' run_tiramisu_cmd: str = 'printf "Running ${FILE_PATH}.out\n">> ${FUNC_DIR}log.txt;\ ./${FILE_PATH}.out>> ${FUNC_DIR}log.txt;' compile_wrapper_cmd = "cd ${FUNC_DIR};\ ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ - ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O3 -o ${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl" + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/install/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++17 -O3 -o ${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64 -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib 
-Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl" @dataclass From cfc4024dac755e546f118107f88cde9a75d8c699 Mon Sep 17 00:00:00 2001 From: Smail KOURTA Date: Tue, 28 Mar 2023 04:30:48 -0400 Subject: [PATCH 27/27] added seed, comments and types and reverted tiramisu config --- utils/dataset_utilities.py | 62 +++++++++++++++++++++++++++----- utils/rl_autoscheduler_config.py | 6 ++-- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/utils/dataset_utilities.py b/utils/dataset_utilities.py index 5a8d6f6..a2dc277 100644 --- a/utils/dataset_utilities.py +++ b/utils/dataset_utilities.py @@ -3,12 +3,15 @@ import os import pickle import random +from typing import Tuple import numpy as np import ray +# Frequency at which the dataset is saved to disk SAVING_FREQUENCY = 10000 +# Enum for the dataset format class DataSetFormat: PICKLE = "PICKLE" JSON = "JSON" @@ -17,13 +20,31 @@ class DataSetFormat: @ray.remote class DatasetAgent: + """ + DatasetAgent is a class that is used to read the dataset and update it. + It is used to read the dataset from disk and update it with the new functions. + It is also used to save the dataset to disk. + + There are currently two modes of operation: + 1. use_dataset = True: In this mode, the dataset is read a pickle file on disk and the functions are returned from the dataset. + 2. use_dataset = False: In this mode, the dataset is not used and the functions are returned as placeholders. The list of functions is read from the disk using `ls`. 
+ + :param dataset_path: path to the dataset + :param path_to_save_dataset: path to save the dataset + :param dataset_format: format of the dataset (PICKLE, JSON, BZ2) + :param use_dataset: whether to use the dataset or not + :param shuffle: whether to shuffle the dataset or not + :param seed: seed for the random number generator + """ + def __init__( self, - dataset_path, - path_to_save_dataset, - dataset_format, + dataset_path: str, + path_to_save_dataset: str, + dataset_format: DataSetFormat.BZ2 | DataSetFormat.JSON | DataSetFormat.PICKLE, use_dataset=False, shuffle=False, + seed=None, ): self.dataset_path = dataset_path self.path_to_save_dataset = path_to_save_dataset @@ -36,6 +57,7 @@ def __init__( self.dataset_name = dataset_path.split("/")[-1].split(".")[0] self.current_function = 0 self.dataset_size = 0 + self.seed = seed if use_dataset: print(f"reading dataset from json at:{dataset_path}") @@ -60,14 +82,20 @@ def __init__( print(f"reading data from ls at: {dataset_path}") self.function_names = os.listdir(dataset_path) + # Shuffle the dataset (can be used with random sampling turned off to get a random order) if self.shuffle: + # Set the seed if specified (for reproducibility) + if self.seed is not None: + random.seed(self.seed) random.shuffle(self.function_names) self.dataset_size = len(self.function_names) - def get_next_function(self, random=False): + def get_next_function(self, random=False) -> Tuple[str, dict]: + # Choose a random function if random: function_name = np.random.choice(self.function_names) + # Choose the next function sequentially else: function_name = self.function_names[ self.current_function % self.dataset_size @@ -77,10 +105,12 @@ def get_next_function(self, random=False): print( f"Selected function with index: {self.current_function}, name: {function_name}" ) - print(self.function_names[:10]) + # If we are using the dataset, return the function from the dataset if self.use_dataset: return function_name, self.dataset[function_name] + 
+ # If we are not using the dataset, return placeholders else: return function_name, { "program_annotation": None, @@ -88,17 +118,31 @@ def get_next_function(self, random=False): "schedules_solver_results_dict": {}, } - def update_dataset(self, function_name, function_dict): + # Update the dataset with the new function + def update_dataset(self, function_name: str, function_dict: dict) -> bool: + """ + Update the dataset with the new function + :param function_name: name of the function + :param function_dict: dictionary containing the function information + :return: True if the dataset was saved successfully + """ self.dataset[function_name] = function_dict self.nbr_updates += 1 print(f"# updates: {self.nbr_updates}") if self.nbr_updates % SAVING_FREQUENCY == 0: if self.nbr_updates % (2 * SAVING_FREQUENCY): - self.save_dataset_to_disk(version=2) + return self.save_dataset_to_disk(version=2) else: - self.save_dataset_to_disk(version=1) + return self.save_dataset_to_disk(version=1) + return False - def save_dataset_to_disk(self, version=1): + # Save the dataset to disk + def save_dataset_to_disk(self, version=1) -> bool: + """ + Save the dataset to disk + :param version: version of the dataset to save (1 or 2) + :return: True if the dataset was saved successfully + """ print("[Start] Save the legality_annotations_dict to disk") updated_dataset_name = ( diff --git a/utils/rl_autoscheduler_config.py b/utils/rl_autoscheduler_config.py index 51344eb..70f179e 100644 --- a/utils/rl_autoscheduler_config.py +++ b/utils/rl_autoscheduler_config.py @@ -43,15 +43,15 @@ class TiramisuConfig: env_type: Literal["model", "cpu"] = "cpu" model_checkpoint: str = "/data/scratch/hbenyamina/model_published_nn_finale.pt" compile_tiramisu_cmd: str = 'printf "Compiling ${FILE_PATH}\n" >> ${FUNC_DIR}log.txt;\ - ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/install/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++17 
-O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ - ${CXX} -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++17 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64 -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 -o ${FILE_PATH}.o -c ${FILE_PATH};\ + ${CXX} -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O0 ${FILE_PATH}.o -o ./${FILE_PATH}.out -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl' run_tiramisu_cmd: str = 'printf "Running ${FILE_PATH}.out\n">> ${FUNC_DIR}log.txt;\ ./${FILE_PATH}.out>> ${FUNC_DIR}log.txt;' compile_wrapper_cmd = "cd ${FUNC_DIR};\ ${GXX} -shared -o ${FUNC_NAME}.o.so ${FUNC_NAME}.o;\ - ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/install/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++17 -O3 -o ${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64 -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/install/lib64:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl" + ${CXX} -I${TIRAMISU_ROOT}/3rdParty/Halide/include -I${TIRAMISU_ROOT}/include -I${TIRAMISU_ROOT}/3rdParty/isl/include -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -std=c++11 -O3 -o 
${FUNC_NAME}_wrapper ${FUNC_NAME}_wrapper.cpp ./${FUNC_NAME}.o.so -L${TIRAMISU_ROOT}/build -L${TIRAMISU_ROOT}/3rdParty/Halide/lib -L${TIRAMISU_ROOT}/3rdParty/isl/build/lib -Wl,-rpath,${TIRAMISU_ROOT}/build:${TIRAMISU_ROOT}/3rdParty/Halide/lib:${TIRAMISU_ROOT}/3rdParty/isl/build/lib -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl" @dataclass