Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataset learning #9

Open
wants to merge 31 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
de5c010
data loading alpha version
skourta Jan 23, 2023
50658fc
removed breakpoints and fixed code style to match previous code
skourta Jan 24, 2023
6985d62
added support for the legality check of skewing
skourta Jan 24, 2023
8bdd2c2
removed breakpoint
skourta Jan 24, 2023
c84c891
working on resuming learning
skourta Jan 25, 2023
1643fc2
refactored code from later commit
skourta Jan 25, 2023
6c8b39b
fixed loading exec time from dataset and cleaned some code
skourta Jan 26, 2023
617b692
refactored checking the hostname in function for future changes
skourta Jan 26, 2023
00ff9b6
fixed invalidating exec time on init time
skourta Jan 26, 2023
3422a62
load new version of dataset and save it
skourta Jan 27, 2023
0300f24
dataset learning with one function at a time
skourta Feb 3, 2023
0455369
removed import bug
skourta Feb 3, 2023
d5efeca
data saving in multiple formats
skourta Feb 3, 2023
5e1f766
data saving in multiple formats
skourta Feb 3, 2023
c4c162f
added solver results to the dataset, fixed circular imports and remov…
skourta Feb 6, 2023
6b78258
changed saving frequency to increase performance
skourta Feb 6, 2023
c7a51a2
Fixing conflicts
skourta Feb 6, 2023
09d2c6e
Added comments explaining dataset config part
skourta Feb 6, 2023
28debb3
fixed conflicts
skourta Feb 7, 2023
a443a0a
Merge pull request #10 from Tiramisu-Compiler/load_model
skourta Feb 7, 2023
2f245d6
changed back the commands to no lz to avoid conflicts
skourta Feb 7, 2023
8968bb3
fixed call of clean cpp
skourta Feb 7, 2023
28774d3
added saving the dataset to the disk when not using dataset
skourta Feb 13, 2023
6d7f7b9
reverted ray init to multiple workers
skourta Feb 14, 2023
acfc53b
added 2 checkpoints for dataset and model.eval
skourta Feb 22, 2023
1f45759
fixed INFO to info bug in calling logging
skourta Feb 23, 2023
c6ac8cc
format with black
skourta Mar 15, 2023
7ce563a
added sequential data
skourta Mar 16, 2023
e5e4b8a
Merge pull request #12 from Tiramisu-Compiler/dataset_learning_sequen…
skourta Mar 22, 2023
b0ca8df
added seed, comments and types
skourta Mar 28, 2023
cfc4024
added seed, comments and types and reverted tiramisu config
skourta Mar 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ scripts/env.sh
Dataset*
.vscode
.idea
dataset
cpps
dataset_*
9 changes: 9 additions & 0 deletions config.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ environment:
dataset_path: "./Dataset_multi/"
programs_file: "./multicomp.json"
clean_files: True
json_dataset:
# Path to the dataset file
path: "./Dataset_pickle/full_legality_annotations_solver.pkl"
# Path to the dataset cpp files
cpps_path: "./Dataset_multi"
# Path where to save the updated dataset (without the extension, it will be inferred by the dataset format)
path_to_save_dataset: "./Dataset_pickle/full_legality_annotations_solver_updated"
# Supported formats are available in the dataset_utilities module
dataset_format: "PICKLE"


tiramisu:
Expand Down
27 changes: 16 additions & 11 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import ray.rllib.agents.ppo as ppo
import os
import json

# import hydra
import argparse
import ray

# from hydra.core.config_store import ConfigStore
from ray import tune
from ray.rllib.models.catalog import ModelCatalog
Expand All @@ -13,9 +15,12 @@
from rl_interface.model import TiramisuModelMult
from utils.environment_variables import configure_env_variables
from utils.global_ray_variables import Actor, GlobalVarActor
from utils.rl_autoscheduler_config import (RLAutoSchedulerConfig,
dict_to_config, parse_yaml_file,
read_yaml_file)
from utils.rl_autoscheduler_config import (
RLAutoSchedulerConfig,
dict_to_config,
parse_yaml_file,
read_yaml_file,
)


def get_arguments():
Expand All @@ -27,23 +32,23 @@ def get_arguments():

# @hydra.main(config_path="config", config_name="config")
def main(config: RLAutoSchedulerConfig, checkpoint=None):
if checkpoint is None: return
if checkpoint is None:
return
configure_env_variables(config)
best_checkpoint = os.path.join(config.ray.base_path, checkpoint)
with ray.init(num_cpus=config.ray.ray_num_cpus):
progs_list_registery = GlobalVarActor.remote(
config.environment.programs_file,
config.environment.dataset_path,
num_workers=config.ray.num_workers)
num_workers=config.ray.num_workers,
)
shared_variable_actor = Actor.remote(progs_list_registery)

register_env(
"Tiramisu_env_v1",
lambda a: TiramisuScheduleEnvironment(config, shared_variable_actor
),
lambda a: TiramisuScheduleEnvironment(config, shared_variable_actor),
)
ModelCatalog.register_custom_model("tiramisu_model_v1",
TiramisuModelMult)
ModelCatalog.register_custom_model("tiramisu_model_v1", TiramisuModelMult)

agent = ppo.PPOTrainer(
env="Tiramisu_env_v1",
Expand All @@ -62,7 +67,7 @@ def main(config: RLAutoSchedulerConfig, checkpoint=None):
"custom_model_config": {
"layer_sizes": list(config.model.layer_sizes),
"drops": list(config.model.drops),
}
},
},
},
)
Expand All @@ -87,7 +92,7 @@ def main(config: RLAutoSchedulerConfig, checkpoint=None):
except:
print("error", action, observation, reward, done)
continue
result["schedule_str"] = env.schedule_object.schedule_str
result["sched_str"] = env.schedule_object.sched_str
result["speedup"] = env.schedule_controller.speedup
results.append(result)
with open("results.json", "w+") as file:
Expand Down
7 changes: 5 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
PyYAML == 6.0
gym == 0.21.0
gym==0.21.0
gymnasium==0.27.1
numpy == 1.23.1
ray[rllib] == 1.13.0
ray[rllib] == 2.2.0
sympy == 1.10.1
torch == 1.12.0
tqdm == 4.64.0
pandas == 1.5.3
tensorflow_probability == 0.19.0
5 changes: 0 additions & 5 deletions rl_interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
from .action import *
from .environment import *
from .model import *
from .reward import *
from .utils import *
Loading