Commit 760aabc

Fixed wrappers.vector.RecordEpisodeStatistics episode length computation from new autoreset api #1018

TimSchneider42 committed Apr 16, 2024
1 parent 94a7909 commit 760aabc

Showing 2 changed files with 26 additions and 10 deletions.
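
Background: with the new autoreset API, a sub-environment that ends an episode
is reset on the next step() call, and that reset call contributes no real step
or reward to the new episode. The wrapper therefore remembers which
sub-environments finished on the previous step (prev_dones) and clears their
accumulators at the start of the following step, instead of immediately after
the episode ends. A minimal standalone sketch of this bookkeeping (n_envs and
on_step are illustrative placeholders, not part of the wrapper's API):

    import numpy as np

    n_envs = 4
    episode_returns = np.zeros(n_envs)
    episode_lengths = np.zeros(n_envs, dtype=int)
    prev_dones = np.zeros(n_envs, dtype=bool)

    def on_step(rewards, terminations, truncations):
        global prev_dones
        # Envs flagged last step are performing their autoreset step now:
        # restart their statistics and ignore this step's reward.
        episode_returns[prev_dones] = 0.0
        episode_lengths[prev_dones] = 0
        # All other envs are mid-episode and accumulate normally.
        episode_returns[~prev_dones] += rewards[~prev_dones]
        episode_lengths[~prev_dones] += 1
        prev_dones = np.logical_or(terminations, truncations)
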
gymnasium/wrappers/vector/common.py
15 changes: 8 additions & 7 deletions
@@ -83,6 +83,7 @@ def __init__(
         self.episode_start_times: np.ndarray = np.zeros(())
         self.episode_returns: np.ndarray = np.zeros(())
         self.episode_lengths: np.ndarray = np.zeros(())
+        self.prev_dones: np.ndarray = np.zeros((), dtype=bool)
 
         self.time_queue = deque(maxlen=buffer_length)
         self.return_queue = deque(maxlen=buffer_length)
@@ -99,6 +100,7 @@ def reset(
         self.episode_start_times = np.full(self.num_envs, time.perf_counter())
         self.episode_returns = np.zeros(self.num_envs)
         self.episode_lengths = np.zeros(self.num_envs)
+        self.prev_dones = np.zeros(self.num_envs, dtype=bool)
 
         return obs, info
 
@@ -118,10 +120,13 @@ def step(
             infos, dict
         ), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order."
 
-        self.episode_returns += rewards
-        self.episode_lengths += 1
+        self.episode_returns[self.prev_dones] = 0
+        self.episode_lengths[self.prev_dones] = 0
+        self.episode_start_times[self.prev_dones] = time.perf_counter()
+        self.episode_returns[~self.prev_dones] += rewards[~self.prev_dones]
+        self.episode_lengths[~self.prev_dones] += 1
 
-        dones = np.logical_or(terminations, truncations)
+        self.prev_dones = dones = np.logical_or(terminations, truncations)
         num_dones = np.sum(dones)
 
         if num_dones:
@@ -147,10 +152,6 @@ def step(
                 self.return_queue.extend(self.episode_returns[i])
                 self.length_queue.extend(self.episode_lengths[i])
 
-            self.episode_lengths[dones] = 0
-            self.episode_returns[dones] = 0
-            self.episode_start_times[dones] = time.perf_counter()
-
         return (
             observations,
             rewards,
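
For reference, this is how the fixed wrapper is typically driven: after a
sub-environment finishes, info["episode"]["r"] / ["l"] / ["t"] carry that
episode's return, length, and elapsed time at the indices flagged by
info["_episode"]. A minimal usage sketch, assuming the CartPole-v1 setup the
test below also uses (env id, num_envs, and step count are arbitrary):

    import gymnasium as gym
    from gymnasium.wrappers.vector import RecordEpisodeStatistics

    envs = RecordEpisodeStatistics(gym.make_vec("CartPole-v1", num_envs=3))
    obs, info = envs.reset(seed=0)
    for _ in range(200):
        obs, rewards, terms, truncs, info = envs.step(envs.action_space.sample())
        if "episode" in info:
            just_done = info["_episode"]  # boolean mask of finished sub-envs
            print(info["episode"]["r"][just_done], info["episode"]["l"][just_done])
    envs.close()
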
tests/wrappers/vector/test_record_episode_statistics.py
21 changes: 18 additions & 3 deletions
@@ -1,3 +1,4 @@
+import numpy as np
 import pytest
 
 import gymnasium as gym
@@ -35,7 +36,10 @@ def test_record_episode_statistics(num_envs, env_id="CartPole-v1", num_steps=100
     assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
     assert data_equivalence(wrapper_vector_info, vector_wrapper_info)
 
-    for _ in range(num_steps):
+    # We keep step 0 empty, as the initial reset does not produce a done or reward signal
+    dones = np.zeros((num_steps + 1, num_envs), dtype=bool)
+    rewards = np.zeros((num_steps + 1, num_envs), dtype=float)
+    for t in range(1, num_steps + 1):
         action = wrapper_vector_env.action_space.sample()
         (
             wrapper_vector_obs,
@@ -57,14 +61,25 @@ def test_record_episode_statistics(num_envs, env_id="CartPole-v1", num_steps=100
         data_equivalence(wrapper_vector_terminated, vector_wrapper_terminated)
         data_equivalence(wrapper_vector_truncated, vector_wrapper_truncated)
 
-        if "episode" in wrapper_vector_info:
-            assert "episode" in vector_wrapper_info
+        dones[t] = wrapper_vector_terminated | wrapper_vector_truncated
+        rewards[t] = wrapper_vector_reward
+
+        assert np.all(wrapper_vector_info.get("_episode", np.zeros(num_envs, dtype=bool)) == dones[t])
+
+        if "episode" in wrapper_vector_info:
             wrapper_vector_time = wrapper_vector_info["episode"].pop("t")
             vector_wrapper_time = vector_wrapper_info["episode"].pop("t")
             assert wrapper_vector_time.shape == vector_wrapper_time.shape
             assert wrapper_vector_time.dtype == vector_wrapper_time.dtype
 
+            current_episode_mask = np.concatenate([
+                ~np.flip(np.maximum.accumulate(np.flip(dones[:t], axis=0)), axis=0),
+                np.ones((1, num_envs), dtype=bool)], axis=0)
+            current_episode_length = np.sum(current_episode_mask, axis=0) - 1
+            current_episode_reward = np.sum((rewards[:t + 1] * current_episode_mask)[:t + 1], axis=0)
+            assert np.all(wrapper_vector_info["episode"]["l"][dones[t]] == current_episode_length[dones[t]])
+            assert np.all(wrapper_vector_info["episode"]["r"][dones[t]] == current_episode_reward[dones[t]])
 
         data_equivalence(wrapper_vector_info, vector_wrapper_info)
 
     wrapper_vector_env.close()
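
The current_episode_mask expression above is the subtle part of this test:
flipping dones[:t] along the time axis, taking a running maximum, and flipping
back marks every step up to and including each env's most recent done; negating
that leaves exactly the steps after the last done, i.e. the episode still in
progress, and the appended row of ones adds the current step t. A tiny
standalone demonstration for a single env (values picked arbitrarily for
illustration):

    import numpy as np

    # One env, five recorded steps; the episode ended at step 1.
    dones = np.array([[False], [True], [False], [False], [False]])

    up_to_last_done = np.flip(np.maximum.accumulate(np.flip(dones, axis=0)), axis=0)
    print((~up_to_last_done).ravel())  # [False False  True  True  True]

The trailing "- 1" in current_episode_length then discounts the one masked row
that is not a real environment step: step 0's initial reset, or the autoreset
step that follows a done under the new API.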