Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 520673564
  • Loading branch information
Jiayu Ye authored and tensorflower-gardener committed Mar 30, 2023
1 parent 209a259 commit 9d3aaa0
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 9 deletions.
27 changes: 20 additions & 7 deletions orbit/actions/export_saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,13 @@ class ExportFileManager:
customized naming and cleanup strategies.
"""

def __init__(self,
base_name: str,
max_to_keep: int = 5,
next_id_fn: Optional[Callable[[], int]] = None):
def __init__(
self,
base_name: str,
max_to_keep: int = 5,
next_id_fn: Optional[Callable[[], int]] = None,
subdirectory: Optional[str] = None,
):
"""Initializes the instance.
Args:
Expand All @@ -77,10 +80,14 @@ def __init__(self,
If not supplied, a default ID based on an incrementing counter is used.
One common alternative maybe be to use the current global step count,
for instance passing `next_id_fn=global_step.numpy`.
subdirectory: An optional subdirectory to concat after the
{base_name}-{id}. Then the file manager will manage
{base_name}-{id}/{subdirectory} files.
"""
self._base_name = os.path.normpath(base_name)
self._max_to_keep = max_to_keep
self._next_id_fn = next_id_fn or _CounterIdFn(self._base_name)
self._subdirectory = subdirectory or ''

@property
def managed_files(self):
Expand All @@ -91,7 +98,10 @@ def managed_files(self):
`ExportFileManager` instance, sorted in increasing integer order of the
IDs returned by `next_id_fn`.
"""
return _find_managed_files(self._base_name)
files = _find_managed_files(self._base_name)
return [
os.path.normpath(os.path.join(f, self._subdirectory)) for f in files
]

def clean_up(self):
"""Cleans up old files matching `{base_name}-*`.
Expand All @@ -101,12 +111,15 @@ def clean_up(self):
if self._max_to_keep < 0:
return

for filename in self.managed_files[:-self._max_to_keep]:
# Note that the base folder will remain intact, only the folder with suffix
# is deleted.
for filename in self.managed_files[: -self._max_to_keep]:
tf.io.gfile.rmtree(filename)

def next_name(self) -> str:
"""Returns a new file name based on `base_name` and `next_id_fn()`."""
return f'{self._base_name}-{self._next_id_fn()}'
base_path = f'{self._base_name}-{self._next_id_fn()}'
return os.path.normpath(os.path.join(base_path, self._subdirectory))


class ExportSavedModel:
Expand Down
52 changes: 50 additions & 2 deletions orbit/actions/export_saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_export_file_manager_default_ids(self):
directory = self.create_tempdir()
base_name = os.path.join(directory.full_path, 'basename')
manager = actions.ExportFileManager(base_name, max_to_keep=3)
self.assertLen(tf.io.gfile.listdir(directory.full_path), 0)
self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
directory.create_file(manager.next_name())
manager.clean_up() # Shouldn't do anything...
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
Expand Down Expand Up @@ -79,7 +79,7 @@ def next_id():

manager = actions.ExportFileManager(
base_name, max_to_keep=2, next_id_fn=next_id)
self.assertLen(tf.io.gfile.listdir(directory.full_path), 0)
self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
id_num = 30
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
Expand All @@ -105,6 +105,54 @@ def next_id():
_id_sorted_file_base_names(directory.full_path),
['basename-200', 'basename-1000'])

def test_export_file_manager_with_suffix(self):
directory = self.create_tempdir()
base_name = os.path.join(directory.full_path, 'basename')

id_num = 0

def next_id():
return id_num

subdirectory = 'sub'

manager = actions.ExportFileManager(
base_name, max_to_keep=2, next_id_fn=next_id, subdirectory=subdirectory
)
self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
id_num = 30
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
manager.clean_up() # Shouldn't do anything...
self.assertEqual(
_id_sorted_file_base_names(directory.full_path), ['basename-30']
)
id_num = 200
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
manager.clean_up() # Shouldn't do anything...
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-30', 'basename-200'],
)
id_num = 1000
directory.create_file(manager.next_name())
self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-30', 'basename-200', 'basename-1000'],
)
manager.clean_up() # Should delete file with lowest ID.
self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
# Note that the base folder is intact, only the suffix folder is deleted.
self.assertEqual(
_id_sorted_file_base_names(directory.full_path),
['basename-30', 'basename-200', 'basename-1000'],
)

step_folder = os.path.join(directory.full_path, 'basename-1000')
self.assertIn(subdirectory, tf.io.gfile.listdir(step_folder))

def test_export_file_manager_managed_files(self):
directory = self.create_tempdir()
directory.create_file('basename-5')
Expand Down

0 comments on commit 9d3aaa0

Please sign in to comment.