Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache task.to_str_params() to speedup Task initialization #421

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
13 changes: 13 additions & 0 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(self, *args, **kwargs):
super(TaskOnKart, self).__init__(*args, **kwargs)
self._rerun_state = self.rerun
self._lock_at_dump = True
self._str_params_cache = None

if self.complete_check_at_run:
self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete) # type: ignore
Expand Down Expand Up @@ -373,6 +374,18 @@ def make_unique_id(self) -> str:
self.task_unique_id = unique_id
return unique_id

def to_str_params(self, only_significant=False, only_public=False):
if only_significant and (not only_public):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this condition is necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kitagry This is intended to cache only when called from init.

There are two possible solutions:

  • only cache default parameters: simple implementation, but limited to this specific use
  • cache all patterns of parameters: general-purpose, but a slightly more complicated implementation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to name the condition for your intention. For example

init_condition = only_significant and (not only_public)
if not init_condition:
    return super().to_str_params(only_significant, only_public)

Copy link
Member Author

@Hi-king Hi-king Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kitagry Thanks! I've introduced _called_with_default_args :)

# Cache to_str_params to avoid slow task creation for deep task trees.
# e.g. gokart.build(RecursiveTask(dep=RecursiveTask(dep=RecursiveTask(dep=HelloWorldTask())))) needs O(n^2) to_str_params calls for n nested RecursiveTask instances
if self._str_params_cache is not None:
return self._str_params_cache
else:
self._str_params_cache = super().to_str_params(only_significant, only_public)
return self._str_params_cache
else:
return super().to_str_params(only_significant, only_public)

def _make_hash_id(self) -> str:
def _to_str_params(task):
if isinstance(task, TaskOnKart):
Expand Down
11 changes: 11 additions & 0 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,17 @@ def test_serialize_and_deserialize_default_values(self):
deserialized: gokart.TaskOnKart = luigi.task_register.load_task(None, task.get_task_family(), task.to_str_params())
self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())

def test_to_str_params_changes_on_values_and_flags(self):
    """Check that ``to_str_params`` is stable per task and reacts to values and flags.

    Repeated calls on the same task must agree, tasks built from equal
    parameter values must agree, tasks built from different values must
    differ, and changing ``only_significant`` must change the result.

    ``TaskOnKart.to_str_params`` only uses its cache when called with
    ``only_significant=True`` and ``only_public=False``, so that flag
    combination is exercised explicitly in addition to the default call.
    """

    class _DummyTaskWithParams(gokart.TaskOnKart):
        task_namespace = __name__
        param: str = luigi.Parameter()

    t1 = _DummyTaskWithParams(param='a')
    same_value = _DummyTaskWithParams(param='a')
    other_value = _DummyTaskWithParams(param='b')

    # Default flags: this call does not take the cached branch.
    self.assertEqual(t1.to_str_params(), t1.to_str_params())  # repeated call is stable
    self.assertEqual(t1.to_str_params(), same_value.to_str_params())  # same value
    self.assertNotEqual(t1.to_str_params(), other_value.to_str_params())  # different value

    # only_significant=True with only_public=False: the cached branch.
    self.assertEqual(t1.to_str_params(only_significant=True), t1.to_str_params(only_significant=True))  # cache
    self.assertEqual(t1.to_str_params(only_significant=True), same_value.to_str_params(only_significant=True))  # same value
    self.assertNotEqual(t1.to_str_params(only_significant=True), other_value.to_str_params(only_significant=True))  # different value

    # Changing the flag must change the result, even after the cache is populated.
    self.assertNotEqual(t1.to_str_params(), t1.to_str_params(only_significant=True))

def test_should_lock_run_when_set(self):
class _DummyTaskWithLock(gokart.TaskOnKart):
def run(self):
Expand Down
Loading