-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_acmpc_dynamical_system_args.py
113 lines (95 loc) · 3.97 KB
/
test_acmpc_dynamical_system_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import env
from argparse import ArgumentParser
import gymnasium as gym
import numpy as np
from wrapper import RelativeRedundant
from stable_baselines3 import PPO
from system import DynamicalSystem
# Forwarded to the environment as `window_size` in main().
# Presumably the side length of the render window in pixels — TODO confirm.
WINDOW_SIZE = 512
def str_2_bool(v):
    """Convert a textual yes/no flag into a ``bool``.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string. String
            matching ignores case and surrounding whitespace.

    Returns:
        ``True`` for "yes", "true", "t", "y" or "1"; ``False`` for "no",
        "false", "f", "n" or "0".

    Raises:
        ValueError: If ``v`` is neither a bool nor one of the recognised
            strings (this includes non-string values such as ``None``).
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Calling ``.lower()`` on a non-string would raise AttributeError,
        # which argparse does not convert into a usage error; report the
        # same ValueError as any other unrecognised value instead.
        raise ValueError("Boolean value expected.")

    token = v.strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise ValueError("Boolean value expected.")
def main(args):
    """Replay a saved PPO policy in the rendered ``DynamicalSystem-v0`` environment.

    Builds a ``DynamicalSystem`` from the parsed options, creates the
    Gymnasium environment with ``render_mode="human"``, wraps it in
    ``RelativeRedundant``, loads the PPO checkpoint named by
    ``args.model_name`` and then plays episodes back to back.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``size``, ``device``, ``model_name``, the four
            ``*_noise_level`` values, ``dt``, ``random_force_probability``,
            ``random_force_magnitude``, ``friction_coefficient``,
            ``wind_gust_x``/``wind_gust_y``, the four
            ``wind_gust_region_*`` bounds and ``distance_threshold``.

    Returns:
        Does not return normally: the episode loop runs until an exception
        (for example ``KeyboardInterrupt``) propagates out. The environment
        is closed on the way out.
    """
    # System parameters
    system = DynamicalSystem(
        dt=args.dt,
        size=args.size,
        random_force_probability=args.random_force_probability,
        random_force_magnitude=args.random_force_magnitude,
        friction_coefficient=args.friction_coefficient,
        wind_gust=[args.wind_gust_x, args.wind_gust_y],
        wind_gust_region=[
            [args.wind_gust_region_x_min, args.wind_gust_region_x_max],
            [args.wind_gust_region_y_min, args.wind_gust_region_y_max],
        ],
        device=args.device,
    )

    # Create environment. The local is deliberately not called ``env`` so it
    # does not shadow the ``env`` module imported at the top of the file.
    simulation = gym.make(
        "DynamicalSystem-v0",
        render_mode="human",
        size=args.size,
        window_size=WINDOW_SIZE,
        distance_threshold=args.distance_threshold,
        system=system,
        agent_location_noise_level=args.agent_location_noise_level,
        agent_velocity_noise_level=args.agent_velocity_noise_level,
        target_location_noise_level=args.target_location_noise_level,
        target_velocity_noise_level=args.target_velocity_noise_level,
    )
    simulation = RelativeRedundant(simulation)

    try:
        # Create model
        model = PPO.load(args.model_name, device=args.device)

        while True:
            obs, _ = simulation.reset()
            episode_over = False
            while not episode_over:
                # Add a leading batch axis for predict() and strip it again
                # from the returned action before stepping.
                action, _ = model.predict(obs[np.newaxis], deterministic=True)
                obs, _, terminated, truncated, _ = simulation.step(
                    action.squeeze(0)
                )
                # An episode ends on either flag; ignoring ``truncated``
                # would keep stepping an environment that must be reset.
                episode_over = terminated or truncated
                simulation.render()
    finally:
        # Release the render window even when the loop is interrupted.
        simulation.close()
if __name__ == "__main__":
    # Command-line options as (flag, type, default), listed in the order
    # they are registered with the parser.
    option_table = (
        ("--size", int, 10),
        ("--device", str, "cpu"),
        ("--model_name", str, "model"),
        ("--agent_location_noise_level", float, 0.0),
        ("--agent_velocity_noise_level", float, 0.0),
        ("--target_location_noise_level", float, 0.0),
        ("--target_velocity_noise_level", float, 0.0),
        ("--dt", float, 0.1),
        ("--random_force_probability", float, 0.001),
        ("--random_force_magnitude", float, 10.0),
        ("--friction_coefficient", float, 0.25),
        ("--wind_gust_x", float, 0.5),
        ("--wind_gust_y", float, -0.5),
        ("--wind_gust_region_x_min", float, 0.2),
        ("--wind_gust_region_x_max", float, 0.8),
        ("--wind_gust_region_y_min", float, 0.2),
        ("--wind_gust_region_y_max", float, 0.8),
        ("--distance_threshold", float, 1.0),
    )

    argprs = ArgumentParser()
    for option_flag, option_type, option_default in option_table:
        argprs.add_argument(option_flag, type=option_type, default=option_default)

    # Parse the process arguments and hand them to the rollout entry point.
    args = argprs.parse_args()
    main(args)