From 7d84731a2886104751f8bcb6c158990a45d2cc14 Mon Sep 17 00:00:00 2001 From: Hengtao Guo Date: Thu, 9 Jan 2025 20:44:25 +0000 Subject: [PATCH] Bypass printing memstats when not available on pathways --- MaxText/train.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/MaxText/train.py b/MaxText/train.py index 8c3ed1331..711a5e82d 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -963,12 +963,13 @@ def train_loop(config, state=None): # pytype: disable=attribute-error compiled = p_train_step.lower(state, example_batch, nextrng).compile() compiled_stats = compiled.memory_analysis() - max_logging.log( - f"Output size: {compiled_stats.output_size_in_bytes}, " - f"temp size: {compiled_stats.temp_size_in_bytes}, " - f"argument size: {compiled_stats.argument_size_in_bytes}, " - f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes." - ) + if compiled_stats is not None: + max_logging.log( + f"Output size: {compiled_stats.output_size_in_bytes}, " + f"temp size: {compiled_stats.temp_size_in_bytes}, " + f"argument size: {compiled_stats.argument_size_in_bytes}, " + f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes." + ) return state