Skip to content

Commit

Permalink
Bypass printing memstats when not available on pathways
Browse files Browse the repository at this point in the history
  • Loading branch information
hengtaoguo committed Jan 9, 2025
1 parent 07789f1 commit 7d84731
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,12 +963,13 @@ def train_loop(config, state=None):
# pytype: disable=attribute-error
compiled = p_train_step.lower(state, example_batch, nextrng).compile()
compiled_stats = compiled.memory_analysis()
max_logging.log(
f"Output size: {compiled_stats.output_size_in_bytes}, "
f"temp size: {compiled_stats.temp_size_in_bytes}, "
f"argument size: {compiled_stats.argument_size_in_bytes}, "
f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
)
if compiled_stats is not None:
max_logging.log(
f"Output size: {compiled_stats.output_size_in_bytes}, "
f"temp size: {compiled_stats.temp_size_in_bytes}, "
f"argument size: {compiled_stats.argument_size_in_bytes}, "
f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
)
return state


Expand Down

0 comments on commit 7d84731

Please sign in to comment.