Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache task.to_str_params() to speed up Task initialization #421

Merged
merged 11 commits into from
Jan 12, 2025
6 changes: 6 additions & 0 deletions gokart/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import hashlib
import inspect
import os
Expand Down Expand Up @@ -118,6 +119,11 @@ def __init__(self, *args, **kwargs):
self._rerun_state = self.rerun
self._lock_at_dump = True

# Cache to_str_params to avoid slow task creation in a deep task tree.
# For example, gokart.build(RecursiveTask(dep=RecursiveTask(dep=RecursiveTask(dep=HelloWorldTask())))) results in O(n^2) calls to to_str_params.
# However, @lru_cache cannot be used as a decorator because luigi.Task employs metaclass tricks.
self.to_str_params = functools.lru_cache(maxsize=None)(self.to_str_params) # type: ignore[method-assign]

if self.complete_check_at_run:
self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete) # type: ignore

Expand Down
11 changes: 11 additions & 0 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,17 @@ def test_serialize_and_deserialize_default_values(self):
deserialized: gokart.TaskOnKart = luigi.task_register.load_task(None, task.get_task_family(), task.to_str_params())
self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())

def test_to_str_params_changes_on_values_and_flags(self):
    """Check how to_str_params() reacts to parameter values and flags.

    The output must be repeatable for a single task, equal for tasks built
    with the same parameter value, different when the parameter value
    differs, and different when ``only_significant=True`` is passed.
    """

    class _DummyTaskWithParams(gokart.TaskOnKart):
        task_namespace = __name__
        param: str = luigi.Parameter()

    task_a = _DummyTaskWithParams(param='a')

    # Two consecutive calls on the same instance must agree
    # (the repeated call is the one expected to exercise the cache).
    first_call = task_a.to_str_params()
    repeated_call = task_a.to_str_params()
    self.assertEqual(first_call, repeated_call)

    # Building the task again with the same value yields equal params.
    params_a = task_a.to_str_params()
    params_same_value = _DummyTaskWithParams(param='a').to_str_params()
    self.assertEqual(params_a, params_same_value)

    # A different parameter value must show up in the result.
    params_a = task_a.to_str_params()
    params_other_value = _DummyTaskWithParams(param='b').to_str_params()
    self.assertNotEqual(params_a, params_other_value)

    # The only_significant flag must change the result as well.
    params_default_flags = task_a.to_str_params()
    params_significant_only = task_a.to_str_params(only_significant=True)
    self.assertNotEqual(params_default_flags, params_significant_only)

def test_should_lock_run_when_set(self):
class _DummyTaskWithLock(gokart.TaskOnKart):
def run(self):
Expand Down
Loading