-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain.py
116 lines (95 loc) · 3.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from typing import List
import logging
import os
import sys
from datetime import datetime
import peract_config
import hydra
from omegaconf import DictConfig, OmegaConf, ListConfig
import run_seed_fn
from helpers.observation_utils import create_obs_config
import torch.multiprocessing as mp
def _find_start_seed(cwd: str, requested_seed: int) -> int:
    """Pick the first seed index to train.

    Args:
        cwd: Directory scanned for existing ``seed<N>`` entries.
        requested_seed: ``cfg.framework.start_seed``. A value ``>= 0`` is used
            verbatim; ``-1`` continues numbering one past the largest existing
            ``seed<N>`` entry; any other value (or ``-1`` with no such entry)
            falls back to 0.

    Returns:
        The seed to start training from.
    """
    if requested_seed >= 0:
        return requested_seed
    if requested_seed == -1:
        prefix = "seed"
        # Only exact ``seed<digits>`` names count, so unrelated entries that
        # merely contain "seed" (e.g. "seed_old") cannot break the int() parse.
        existing = [
            int(name[len(prefix):])
            for name in os.listdir(cwd)
            if name.startswith(prefix) and name[len(prefix):].isdecimal()
        ]
        if existing:
            return max(existing) + 1
    return 0


def _latest_checkpoint(weights_folder: str):
    """Return the largest integer-named entry in ``weights_folder``.

    Entry names are interpreted as training-iteration counts (they are compared
    against ``cfg.framework.training_iterations`` by the caller). Entries whose
    name is not a plain non-negative integer are ignored instead of crashing
    ``int()``.

    Returns:
        The largest such integer, or ``None`` when the folder does not exist
        or holds no integer-named entries.
    """
    if not os.path.isdir(weights_folder):
        return None
    steps = [int(name) for name in os.listdir(weights_folder) if name.isdecimal()]
    return max(steps) if steps else None


@hydra.main(config_name="config", config_path="conf")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: prepare the run directory and launch training.

    Resolves the starting seed and writes the resolved config plus a
    training.log header into ``seed<N>/`` under the current working directory.
    It exits early if a checkpoint at or beyond
    ``cfg.framework.training_iterations`` already exists. Otherwise it runs one
    ``torch.multiprocessing.spawn`` job of ``cfg.ddp.num_devices`` processes per
    seed, one seed after another.

    Args:
        cfg: Composed Hydra config (config ``config`` from ``conf/``).

    Raises:
        ValueError: if any configured camera name contains ``"rgb"``.
    """
    cfg_yaml = OmegaConf.to_yaml(cfg)
    logging.info("\n" + cfg_yaml)

    peract_config.on_config(cfg)

    # Accept a single camera name as well as a list of camera names.
    if not isinstance(cfg.rlbench.cameras, ListConfig):
        cfg.rlbench.cameras = [cfg.rlbench.cameras]

    # Sanity check: "rgb" must not appear in a camera name.
    # NOTE(review): presumably because observation keys are derived as
    # "<camera>_rgb" downstream — confirm in create_obs_config.
    # Raised explicitly (not assert) so the check survives `python -O`.
    for camera_name in cfg.rlbench.cameras:
        if "rgb" in camera_name:
            raise ValueError(
                f"Camera name {camera_name!r} must not contain 'rgb'."
            )

    obs_config = create_obs_config(
        cfg.rlbench.cameras, cfg.rlbench.camera_resolution, cfg.method.name
    )

    # Run artefacts are placed under the current working directory.
    # NOTE(review): presumably the Hydra run directory; whether Hydra chdirs
    # into it depends on the Hydra version / hydra.job.chdir — confirm.
    cwd = os.getcwd()
    logging.info("CWD:%s", cwd)

    start_seed = _find_start_seed(cwd, cfg.framework.start_seed)

    seed_folder = os.path.join(cwd, "seed%d" % start_seed)
    os.makedirs(seed_folder, exist_ok=True)

    start_time = datetime.now()
    with open(os.path.join(seed_folder, "config.yaml"), "w", encoding="utf-8") as f:
        f.write(cfg_yaml)

    # If an existing checkpoint already reaches the requested number of
    # training iterations, there is nothing left to do.
    found_weight = _latest_checkpoint(os.path.join(seed_folder, "weights"))
    latest_weight = 0 if found_weight is None else found_weight
    if (
        found_weight is not None
        and found_weight >= cfg.framework.training_iterations
    ):
        logging.info(
            "Agent was already trained for %d iterations. Exiting.", latest_weight
        )
        sys.exit(0)

    # Text-mode files translate "\n" to the platform newline on their own;
    # writing os.linesep here would produce "\r\r\n" on Windows.
    log_path = os.path.join(seed_folder, "training.log")
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(
            f"# Starting training from weights: {latest_weight} "
            f"to {cfg.framework.training_iterations}\n"
        )
        f.write(f"# Training started on: {start_time.isoformat()}\n")

    # Train each seed sequentially; each seed gets its own multi-process job.
    # NOTE(review): seed_folder, config.yaml, the checkpoint check and
    # training.log above only cover start_seed, even when
    # cfg.framework.seeds > 1 — confirm this is intended.
    for seed in range(start_seed, start_seed + cfg.framework.seeds):
        logging.info("Starting seed %d.", seed)

        world_size = cfg.ddp.num_devices
        # spawn() calls run_seed(rank, cfg, obs_config, seed, world_size) in
        # each of the world_size processes and blocks until all of them exit.
        mp.spawn(
            run_seed_fn.run_seed,
            args=(cfg, obs_config, seed, world_size),
            nprocs=world_size,
            join=True,
        )

    end_time = datetime.now()
    duration = end_time - start_time
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"# Training finished on: {end_time.isoformat()}\n")
        f.write(f"# Took {duration.total_seconds()}\n")
        # Blank line separating consecutive runs appended to the same log.
        f.write("\n")
if __name__ == "__main__":
    # Script entry: run peract_config's one-time init hook first, then hand
    # control to the Hydra-decorated main(), which composes its config
    # ("config" under "conf/") and passes it in as `cfg`.
    peract_config.on_init()
    main()