diff --git a/dopamine/agents/dqn/dqn_agent.py b/dopamine/agents/dqn/dqn_agent.py
index 2c2c9808..795d2ed1 100644
--- a/dopamine/agents/dqn/dqn_agent.py
+++ b/dopamine/agents/dqn/dqn_agent.py
@@ -181,6 +181,8 @@ def __init__(self,
     self.eval_mode = eval_mode
     self.training_steps = 0
     self.optimizer = optimizer
+    if os.environ.get('TF_ENABLE_AUTO_MIXED_PRECISION', default='0') == '1':
+      self.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
     self.summary_writer = summary_writer
     self.summary_writing_frequency = summary_writing_frequency
     self.allow_partial_reload = allow_partial_reload
diff --git a/dopamine/agents/rainbow/rainbow_agent.py b/dopamine/agents/rainbow/rainbow_agent.py
index 7e3e4ebb..fb7c734d 100644
--- a/dopamine/agents/rainbow/rainbow_agent.py
+++ b/dopamine/agents/rainbow/rainbow_agent.py
@@ -38,7 +38,7 @@ from __future__ import print_function
 import collections
-
+import os
 from dopamine.agents.dqn import dqn_agent
@@ -127,6 +127,8 @@ def __init__(self,
     self._replay_scheme = replay_scheme
     # TODO(b/110897128): Make agent optimizer attribute private.
     self.optimizer = optimizer
+    if os.environ.get('TF_ENABLE_AUTO_MIXED_PRECISION', default='0') == '1':
+      self.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
     dqn_agent.DQNAgent.__init__(
         self,