Use distributed dataset for NCF evaluation. (#9009)
PiperOrigin-RevId: 323948101

Co-authored-by: A. Unique TensorFlower <gardener@tensorflow.org>
gagika and tensorflower-gardener authored Jul 30, 2020
1 parent 3729e19 commit a84f1b9
Showing 2 changed files with 20 additions and 4 deletions.
official/recommendation/ncf_input_pipeline.py (17 additions, 3 deletions)
@@ -34,7 +34,8 @@
 def create_dataset_from_tf_record_files(input_file_pattern,
                                         pre_batch_size,
                                         batch_size,
-                                        is_training=True):
+                                        is_training=True,
+                                        rebatch=False):
   """Creates dataset from (tf)records files for training/evaluation."""
 
   files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
@@ -62,6 +63,13 @@ def make_dataset(files_dataset, shard_index):
       map_fn,
       cycle_length=NUM_SHARDS,
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+  if rebatch:
+    # A workaround for TPU Pod evaluation dataset.
+    # TODO (b/162341937) remove once it's fixed.
+    dataset = dataset.unbatch()
+    dataset = dataset.batch(pre_batch_size)
+
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
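
The rebatch branch above flattens the pre-batched eval records and re-batches them to pre_batch_size, so every batch has a uniform size before the distribution strategy splits it across TPU Pod replicas. A minimal runnable sketch of the same unbatch-then-batch pattern (the toy dataset and printed values are illustrative; only pre_batch_size comes from the diff):

import tensorflow as tf

# Toy stand-in for pre-batched records: batches of uneven size,
# as can happen at file boundaries.
dataset = tf.data.Dataset.range(10).batch(4)  # batch sizes 4, 4, 2

pre_batch_size = 5  # target uniform batch size, as in the diff

# The workaround from the diff: flatten, then re-batch uniformly.
dataset = dataset.unbatch().batch(pre_batch_size)

for batch in dataset:
  print(batch.numpy())  # [0 1 2 3 4], then [5 6 7 8 9]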

@@ -162,12 +170,18 @@ def create_ncf_input_data(params,
       params["train_dataset_path"],
       input_meta_data["train_prebatch_size"],
       params["batch_size"],
-      is_training=True)
+      is_training=True,
+      rebatch=False)
 
+  # Re-batch evaluation dataset for TPU Pods.
+  # TODO (b/162341937) remove once it's fixed.
+  eval_rebatch = (params["use_tpu"] and strategy.num_replicas_in_sync > 8)
   eval_dataset = create_dataset_from_tf_record_files(
       params["eval_dataset_path"],
       input_meta_data["eval_prebatch_size"],
       params["eval_batch_size"],
-      is_training=False)
+      is_training=False,
+      rebatch=eval_rebatch)
 
   num_train_steps = int(input_meta_data["num_train_steps"])
   num_eval_steps = int(input_meta_data["num_eval_steps"])
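
The eval_rebatch heuristic above treats more than 8 replicas as a TPU Pod, since a single TPU device exposes 8 cores; only then is the re-batching workaround applied. A hedged sketch of the same check against a generic strategy (MirroredStrategy stands in here for TPUStrategy, and the use_tpu flag is hard-coded for illustration):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # stand-in for TPUStrategy
use_tpu = False  # set from params["use_tpu"] in the real code

# Same heuristic as the diff: >8 replicas in sync implies a Pod slice.
eval_rebatch = use_tpu and strategy.num_replicas_in_sync > 8
print(strategy.num_replicas_in_sync, eval_rebatch)
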
official/recommendation/ncf_keras_main.py (3 additions, 1 deletion)
@@ -235,6 +235,7 @@ def run_ncf(_):
 
   params = ncf_common.parse_flags(FLAGS)
   params["distribute_strategy"] = strategy
+  params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")
 
   if params["use_tpu"] and not params["keras_use_ctl"]:
     logging.error("Custom training loop must be used when using TPUStrategy.")
@@ -491,7 +492,8 @@ def step_fn(features):
     logging.info("Done training epoch %s, epoch loss=%.3f", epoch + 1,
                  train_loss)
 
-    eval_input_iterator = iter(eval_input_dataset)
+    eval_input_iterator = iter(
+        strategy.experimental_distribute_dataset(eval_input_dataset))
 
     hr_sum = 0.0
     hr_count = 0.0
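
The change above is the core of the commit: the eval dataset is wrapped with strategy.experimental_distribute_dataset before the iterator is built, so each global batch is split across replicas inside the custom evaluation loop. A minimal sketch of that pattern (toy dataset and a trivial step_fn; strategy.run and experimental_distribute_dataset are real TF 2.x APIs, everything else is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Toy eval dataset with a global batch size; the real code builds it
# via the NCF input pipeline.
eval_input_dataset = tf.data.Dataset.range(8).batch(4)

@tf.function
def eval_step(iterator):
  def step_fn(features):
    return tf.reduce_sum(features)  # stand-in for the HR metric update
  return strategy.run(step_fn, args=(next(iterator),))

eval_input_iterator = iter(
    strategy.experimental_distribute_dataset(eval_input_dataset))
for _ in range(2):
  print(eval_step(eval_input_iterator))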
