diff --git a/gokart/task.py b/gokart/task.py index 42c00bf9..2f351d2a 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -290,10 +290,7 @@ def _load(targets): return {k: _load(t) for k, t in targets.items()} return targets.load() - data = _load(self._get_input_targets(target)) - if target is None and isinstance(data, dict) and len(data) == 1: - return list(data.values())[0] - return data + return _load(self._get_input_targets(target)) @overload def load_generator(self, target: Union[None, str, TargetOnKart] = None) -> Generator[Any, None, None]: ... diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index 4ddc22fc..e3946b49 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -267,7 +267,7 @@ def test_load_with_single_dict_target(self): data = task.load() target.load.assert_called_once() - self.assertEqual(data, 1) + self.assertEqual(data, {'target_key': 1}) def test_load_with_keyword(self): task = _DummyTask()