diff --git a/gokart/task.py b/gokart/task.py
index ac5d7752..d260f1c1 100644
--- a/gokart/task.py
+++ b/gokart/task.py
@@ -16,6 +16,7 @@
 from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
 from gokart.redis_lock import make_redis_params
 from gokart.target import TargetOnKart
+from gokart.task_complete_check import task_complete_check_wrapper
 
 logger = getLogger(__name__)
 
@@ -76,6 +77,9 @@ class TaskOnKart(luigi.Task):
         description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \
          Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
         significant=False)
+    complete_check_at_run: bool = ExplicitBoolParameter(default=False,
+                                                        description='Check if output file exists at run. If exists, run() will be skipped.',
+                                                        significant=False)
 
     def __init__(self, *args, **kwargs):
         self._add_configuration(kwargs, 'TaskOnKart')
@@ -86,6 +90,10 @@ def __init__(self, *args, **kwargs):
         self._rerun_state = self.rerun
         self._lock_at_dump = True
 
+        # Replace run() on this instance only, so that run() is skipped when complete() reports the task as done.
+        if self.complete_check_at_run:
+            self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete)
+
     def output(self):
         return self.make_target()
 
diff --git a/gokart/task_complete_check.py b/gokart/task_complete_check.py
new file mode 100644
index 00000000..4ab8a2b5
--- /dev/null
+++ b/gokart/task_complete_check.py
@@ -0,0 +1,22 @@
+import functools
+from logging import getLogger
+from typing import Callable
+
+logger = getLogger(__name__)
+
+
+def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable):
+    """Return a callable that skips ``run_func`` when ``complete_check_func()`` is truthy.
+
+    Otherwise the call is forwarded to ``run_func`` and its result is returned.
+    """
+
+    # Keep the wrapped callable's name and docstring on the replacement.
+    @functools.wraps(run_func)
+    def wrapper(*args, **kwargs):
+        if complete_check_func():
+            logger.warning(f'{run_func.__name__} is skipped because the task is already completed.')
+            return
+        return run_func(*args, **kwargs)
+
+    return wrapper
diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py
index 75b462ba..232187b4 100644
--- a/test/test_task_on_kart.py
+++ b/test/test_task_on_kart.py
@@ -561,5 +561,66 @@ def test_serialize_and_deserialize_default_values(self):
         self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())
 
 
+class _DummyTaskWithNonCompleted(gokart.TaskOnKart):
+
+    def dump(self, obj):
+        # override dump() to do nothing.
+        pass
+
+    def run(self):
+        self.dump('hello')
+
+    def complete(self):
+        return False
+
+
+class _DummyTaskWithCompleted(gokart.TaskOnKart):
+
+    def dump(self, obj):
+        # override dump() to do nothing.
+        pass
+
+    def run(self):
+        self.dump('hello')
+
+    def complete(self):
+        return True
+
+
+class TestCompleteCheckAtRun(unittest.TestCase):
+
+    def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self):
+        task = _DummyTaskWithNonCompleted(complete_check_at_run=False)
+        task.dump = MagicMock()
+        task.run()
+
+        # since run() is called, dump() should be called.
+        task.dump.assert_called_once()
+
+    def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self):
+        task = _DummyTaskWithCompleted(complete_check_at_run=False)
+        task.dump = MagicMock()
+        task.run()
+
+        # even if the task is completed, since run() is called, dump() should be called.
+        task.dump.assert_called_once()
+
+    def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self):
+        task = _DummyTaskWithNonCompleted(complete_check_at_run=True)
+        task.dump = MagicMock()
+        task.run()
+
+        # since task is not completed, when run() is called, dump() should be called.
+        task.dump.assert_called_once()
+
+    def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self):
+        task = _DummyTaskWithCompleted(complete_check_at_run=True)
+        task.dump = MagicMock()
+        task.run()
+
+        # since task is completed, even when run() is called, dump() should not be called.
+        task.dump.assert_not_called()
+
+
 if __name__ == '__main__':
     unittest.main()