against_agent.py
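Evaluation script for a trained PPO agent: it registers a `PokeBattleEnv` backed by a `ShowdownSimulator` with self-play disabled, restores a checkpoint via `agent.restore()`, and lets the restored policy choose every action for a configurable number of battles. A hypothetical usage note follows the listing.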
#!/usr/bin/env python3
from argparse import ArgumentParser

import ray
from ray.rllib import ppo
from ray.tune.registry import register_env, get_registry

from pokebattle_rl_env import PokeBattleEnv
from pokebattle_rl_env.showdown_simulator import ShowdownSimulator

parser = ArgumentParser()
parser.add_argument('-l', '--load', type=str, help='The directory to load a trained model from')
parser.add_argument('-b', '--battles', type=int, default=1000, help='Number of battles to test the model on')
args = parser.parse_args()

# Create a single environment instance backed by a Showdown simulator
# (self_play=False: the agent does not battle a copy of itself) and
# register it under a fixed name so RLlib can look it up.
env = PokeBattleEnv(ShowdownSimulator(self_play=False))
env_creator_name = "PokeBattleEnv-v0"
register_env(env_creator_name, lambda config: env)

ray.init()

# PPO configuration for evaluation; a single worker is sufficient.
config = ppo.DEFAULT_CONFIG.copy()
config['num_workers'] = 1
config['timesteps_per_batch'] = 200
config['horizon'] = 500
config['min_steps_per_task'] = 1

# Build the agent and restore the trained weights from the given directory.
agent = ppo.PPOAgent(config=config, env=env_creator_name, registry=get_registry())
agent.restore(args.load)

# Play the requested number of battles, letting the trained policy pick
# every action until each battle ends.
for battle in range(args.battles):
    observation = env.reset()
    env.render()
    done = False
    while not done:
        action = agent.compute_action(observation)
        observation, _, done, _ = env.step(action)
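A hypothetical invocation, assuming `<checkpoint_dir>` is a placeholder for a checkpoint directory produced during training (not a path from the repository): `./against_agent.py --load <checkpoint_dir> --battles 500`.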