Hi, it seems like the documented return shapes for the following functions might be off:
retrace_ops.retrace(...)
retrace_ops.retrace_core(...)
retrace_ops._general_off_policy_corrected_multistep_target(...)
The first two are documented to return shape [B] and the third shape [T, B, num_actions], while all three appear to return [T, B].
import numpy as np
import tensorflow as tf
from trfl import retrace_ops, indexing_ops

### Example input data: ###
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops_test.py#L41
lambda_ = 0.9
qs = [
    [[2.2, 3.2, 4.2], [5.2, 6.2, 7.2]],
    [[7.2, 6.2, 5.2], [4.2, 3.2, 2.2]],
    [[3.2, 5.2, 7.2], [4.2, 6.2, 9.2]],
    [[2.2, 8.2, 4.2], [9.2, 1.2, 8.2]]
]
targnet_qs = [
    [[2., 3., 4.], [5., 6., 7.]],
    [[7., 6., 5.], [4., 3., 2.]],
    [[3., 5., 7.], [4., 6., 9.]],
    [[2., 8., 4.], [9., 1., 8.]]
]
actions = [
    [2, 0],
    [1, 2],
    [0, 1],
    [2, 0]
]
rewards = [
    [1.9, 2.9],
    [3.9, 4.9],
    [5.9, 6.9],
    [np.nan, np.nan]  # nan marks entries we should never use.
]
pcontinues = [
    [0.8, 0.9],
    [0.7, 0.8],
    [0.6, 0.5],
    [np.nan, np.nan]
]
target_policy_probs = [
    [[np.nan] * 3, [np.nan] * 3],
    [[0.41, 0.28, 0.31], [0.19, 0.77, 0.04]],
    [[0.22, 0.44, 0.34], [0.14, 0.25, 0.61]],
    [[0.16, 0.72, 0.12], [0.33, 0.30, 0.37]]
]
behaviour_policy_probs = [
    [np.nan, np.nan],
    [0.85, 0.86],
    [0.87, 0.88],
    [0.89, 0.84]
]

### Retrace Test: ###
retrace = retrace_ops.retrace(
    lambda_, qs, targnet_qs, actions, rewards, pcontinues,
    target_policy_probs, behaviour_policy_probs)

# qs: shape [(T+1), B, num_actions]
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops.py#L85
T = len(qs) - 1    # sequence length
B = len(qs[0])     # batch dimension
N = len(qs[0][0])  # number of actions

# loss: documented shape [B]
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops.py#L121
tf.debugging.assert_equal(retrace.loss.shape, [T, B])  # succeeds

### Multi-step target Test: ###
timesteps = tf.shape(qs)[0]  # Batch size is qs_shape[1].
timestep_indices_tm1 = tf.range(0, timesteps - 1)
timestep_indices_t = tf.range(1, timesteps)
target_policy_t = tf.gather(target_policy_probs, timestep_indices_t)
behaviour_policy_t = tf.gather(behaviour_policy_probs, timestep_indices_t)
a_t = tf.gather(actions, timestep_indices_t)
r_t = tf.gather(rewards, timestep_indices_tm1)
pcont_t = tf.gather(pcontinues, timestep_indices_tm1)
targnet_q_t = tf.gather(targnet_qs, timestep_indices_t)
c_t = retrace_ops._retrace_weights(
    indexing_ops.batched_index(target_policy_t, a_t),
    behaviour_policy_t) * lambda_

target = retrace_ops._general_off_policy_corrected_multistep_target(
    r_t, pcont_t, target_policy_t, c_t, targnet_q_t, a_t)

# target: documented shape [T, B, N]
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops.py#L241
tf.debugging.assert_equal(target.shape, [T, B])  # succeeds
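For what it's worth, the [T, B] shape also matches my reading of the Retrace target: every quantity that feeds the recursion (the expected target-network value under the target policy and the Q-value of the action actually taken) is already reduced or indexed over the action dimension, so there is only one scalar per (timestep, batch) pair left. A rough numpy-only shape sketch below, reusing the inputs above; this just illustrates the ingredients' shapes and is not trfl's code path.

# Rough numpy shape sketch (not the trfl implementation): as I understand it,
# the target recursion is built from E_pi Q(x_t, .) and Q(x_t, a_t), both of
# which are [T, B] once the action dimension is reduced / indexed away.
q_t_np = np.array(targnet_qs[1:])            # [T, B, num_actions]
pi_t_np = np.array(target_policy_probs[1:])  # [T, B, num_actions]
a_t_np = np.array(actions[1:])               # [T, B]
exp_q_t = (pi_t_np * q_t_np).sum(axis=-1)    # E_pi Q(x_t, .)  -> [T, B]
q_t_a = np.take_along_axis(
    q_t_np, a_t_np[..., None], axis=-1).squeeze(-1)  # Q(x_t, a_t) -> [T, B]
print(exp_q_t.shape, q_t_a.shape)            # (3, 2) (3, 2), i.e. [T, B] each

So the documented [T, B, num_actions] for _general_off_policy_corrected_multistep_target (and [B] for retrace/retrace_core) looks like a docstring issue rather than a behaviour issue.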