Skip to content

Commit

Permalink
feat: add option to use custom worker on build
Browse files Browse the repository at this point in the history
  • Loading branch information
hiro-o918 committed Oct 28, 2024
1 parent 0424b8d commit 6826d98
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion gokart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gokart.build import build # noqa:F401
from gokart.build import build, WorkerSchedulerFactory # noqa:F401
from gokart.info import make_tree_info, tree_info # noqa:F401
from gokart.pandas_type_config import PandasTypeConfig # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter # noqa:F401
Expand Down
39 changes: 37 additions & 2 deletions gokart/build.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from functools import partial
from logging import getLogger
from typing import Literal, Optional, TypeVar, cast, overload
from typing import Literal, Optional, Protocol, Self, TypeVar, cast, overload

import backoff
import luigi
from luigi import rpc, scheduler

import gokart
from gokart import worker
from gokart.conflict_prevention_lock.task_lock import TaskLockException
from gokart.target import TargetOnKart
from gokart.task import TaskOnKart
Expand Down Expand Up @@ -43,6 +45,30 @@ def __init__(self):
self.flag: bool = False


class WorkerProtocol(Protocol):
"""Protocol for Worker.
This protocol is determined by luigi.worker.Worker.
"""
def add(self, task: TaskOnKart) -> bool: ...

def run(self) -> bool: ...

def __enter__(self) -> Self: ...

def __exit__(self, type, value, traceback) -> None: ...


class WorkerSchedulerFactory:
def create_local_scheduler(self) -> scheduler.Scheduler:
return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False)

def create_remote_scheduler(self, url) -> rpc.RemoteScheduler:
return rpc.RemoteScheduler(url)

def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant=False) -> WorkerProtocol:
return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant)


def _get_output(task: TaskOnKart[T]) -> T:
output = task.output()
# FIXME: currently, nested output is not supported
Expand Down Expand Up @@ -98,6 +124,7 @@ def build(
log_level: int = logging.ERROR,
task_lock_exception_max_tries: int = 10,
task_lock_exception_max_wait_seconds: int = 600,
worker_scheduler_factory: Optional[WorkerSchedulerFactory] = None,
**env_params,
) -> Optional[T]:
"""
Expand All @@ -106,6 +133,7 @@ def build(
"""
if reset_register:
_reset_register()

with LoggerConfig(level=log_level):
task_lock_exception_raised = TaskLockExceptionRaisedFlag()

Expand All @@ -119,7 +147,14 @@ def when_failure(task, exception):
)
def _build_task():
task_lock_exception_raised.flag = False
result = luigi.build([task], local_scheduler=True, detailed_summary=True, log_level=logging.getLevelName(log_level), **env_params)
result = luigi.build(
[task],
local_scheduler=True,
detailed_summary=True,
worker_scheduler_factory=worker_scheduler_factory,
log_level=logging.getLevelName(log_level),
**env_params,
)
if task_lock_exception_raised.flag:
raise HasLockedTaskException()
if result.status == luigi.LuigiStatusCode.FAILED:
Expand Down
7 changes: 7 additions & 0 deletions test/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,12 @@ def test_build_expo_backoff_when_luigi_failed_due_to_locked_task(self):
gokart.build(_FailThreeTimesAndSuccessTask(), reset_register=False)


class TestBuildOnGokartWorker:
def test_build(self):
text = 'test'
output = gokart.build(_DummyTask(param=text), worker_scheduler_factory=gokart.WorkerSchedulerFactory())
assert output == text


if __name__ == '__main__':
unittest.main()

0 comments on commit 6826d98

Please sign in to comment.