# runner_local_planner_env.py
import warnings
import logging
from typing import Optional, Dict
logging.getLogger("tensorflow").setLevel(logging.ERROR)
logging.getLogger("numpy").setLevel(logging.ERROR)
warnings.filterwarnings('ignore')
import os
import sys
from pathlib import Path
sys.path.append(Path(os.getcwd()).parent.as_posix())
import gym
import ROAR_Gym
from ROAR_Sim.configurations.configuration import Configuration as CarlaConfig
from ROAR.configurations.configuration import Configuration as AgentConfig
from ROAR.agent_module.agent import Agent
from ROAR.agent_module.rl_local_planner_agent import RLLocalPlannerAgent
from stable_baselines.ddpg.policies import LnCnnPolicy, LnMlpPolicy
from stable_baselines import DDPG
from datetime import datetime
from stable_baselines.common.callbacks import CheckpointCallback, EveryNTimesteps, CallbackList
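# stable_baselines (v2) is built on TensorFlow 1.x; the TensorFlow logger is
# silenced above presumably to keep startup output quiet.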
from utilities import find_latest_model
try:
    from ROAR_Gym.envs.roar_env import LoggingCallback
except ImportError:
    from ROAR_Gym.ROAR_Gym.envs.roar_env import LoggingCallback
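

# Train a DDPG-based local planner agent in the ROAR Carla gym environment,
# resuming from the most recent checkpoint found in output_folder_path (if any).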
def main(output_folder_path: Path):
    # Set up the gym-carla environment
    agent_config = AgentConfig.parse_file(Path("configurations/agent_configuration.json"))
    carla_config = CarlaConfig.parse_file(Path("configurations/carla_configuration.json"))
    params = {
        "agent_config": agent_config,
        "carla_config": carla_config,
        "ego_agent_class": RLLocalPlannerAgent,
        "max_collision": 5,
        "rl_pid_model_file_path": Path(os.getcwd()).parent / "ROAR_Sim" /
                                  "data" / "weights" / "rl_pid_model.zip"
    }
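
    # NOTE: 'roar-local-planner-v0' is assumed to be registered with gym by the
    # ROAR_Gym import at the top of this file; the params dict is forwarded to
    # the environment's constructor.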
    env = gym.make('roar-local-planner-v0', params=params)
    env.reset()

    model_params: dict = {
        "verbose": 1,
        "render": True,
        "tensorboard_log": (output_folder_path / "tensorboard").as_posix()
    }
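
    # The same keyword arguments are used whether a fresh model is created or an
    # existing checkpoint is reloaded below.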
    latest_model_path = find_latest_model(Path(output_folder_path))
    if latest_model_path is None:
        model = DDPG(LnCnnPolicy, env=env, **model_params)  # full tensorboard log can take up space quickly
    else:
        model = DDPG.load(latest_model_path, env=env, **model_params)
    model.render = True
    model.tensorboard_log = (output_folder_path / "tensorboard").as_posix()
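
    # Callbacks: periodically checkpoint the model and log training progress.
    # Note that checkpoint_callback is both placed in the CallbackList directly and
    # wrapped in EveryNTimesteps, so its internal call counter is advanced by both paths.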
    logging_callback = LoggingCallback(model=model)
    checkpoint_callback = CheckpointCallback(save_freq=1000, verbose=2,
                                             save_path=(output_folder_path / "logs").as_posix())
    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_callback)
    callbacks = CallbackList([checkpoint_callback, event_callback, logging_callback])

    model = model.learn(total_timesteps=int(1e10), callback=callbacks, reset_num_timesteps=False)
    # str(datetime.now()) contains spaces and colons, which are not valid in Windows file names
    model.save(f"pid_ddpg_{datetime.now()}")


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                        datefmt="%H:%M:%S", level=logging.INFO)
    logging.getLogger("Controller").setLevel(logging.ERROR)
    logging.getLogger("SimplePathFollowingLocalPlanner").setLevel(logging.ERROR)
    main(output_folder_path=Path(os.getcwd()) / "output" / "local_planner")