From a6c8ade851742ba9686522d08c61b471e5a8f815 Mon Sep 17 00:00:00 2001 From: Keisuke OGAKI Date: Wed, 1 Jan 2025 21:01:30 +0900 Subject: [PATCH] Add tests for task.to_str_params() --- gokart/task.py | 2 +- test/test_task_on_kart.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/gokart/task.py b/gokart/task.py index 43562f64..0aba0e8f 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -375,7 +375,7 @@ def make_unique_id(self) -> str: return unique_id def to_str_params(self, only_significant=False, only_public=False): - if only_significant == True and only_public == False: + if only_significant and (not only_public): # cache to_str_params to avoid too slow task creation of deep task tree # e.g. gokart.build(RecursiveTask(dep=RecursiveTask(dep=RecursiveTask(dep=HelloWorldTask())))) takes O(n*n) to_str_params calls if self._str_params_cache is not None: diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index e3946b49..78a918f3 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -582,6 +582,17 @@ def test_serialize_and_deserialize_default_values(self): deserialized: gokart.TaskOnKart = luigi.task_register.load_task(None, task.get_task_family(), task.to_str_params()) self.assertDictEqual(task.to_str_params(), deserialized.to_str_params()) + def test_to_str_params_changes_on_values_and_flags(self): + class _DummyTaskWithParams(gokart.TaskOnKart): + task_namespace = __name__ + param: str = luigi.Parameter() + + t1 = _DummyTaskWithParams(param='a') + self.assertEqual(t1.to_str_params(), t1.to_str_params()) # cache + self.assertEqual(t1.to_str_params(), _DummyTaskWithParams(param='a').to_str_params()) # same value + self.assertNotEqual(t1.to_str_params(), _DummyTaskWithParams(param='b').to_str_params()) # different value + self.assertNotEqual(t1.to_str_params(), t1.to_str_params(only_significant=True)) + def test_should_lock_run_when_set(self): class _DummyTaskWithLock(gokart.TaskOnKart): def run(self):