Skip to content

Commit

Permalink
feat: add option to use custom worker on build
Browse files Browse the repository at this point in the history
  • Loading branch information
hiro-o918 committed Oct 29, 2024
1 parent 3e864a0 commit 5c63943
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
2 changes: 1 addition & 1 deletion gokart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gokart.build import build # noqa:F401
from gokart.build import WorkerSchedulerFactory, build # noqa:F401
from gokart.info import make_tree_info, tree_info # noqa:F401
from gokart.pandas_type_config import PandasTypeConfig # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter # noqa:F401
Expand Down
38 changes: 36 additions & 2 deletions gokart/build.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from functools import partial
from logging import getLogger
from typing import Literal, Optional, TypeVar, cast, overload
from typing import Literal, Optional, Protocol, TypeVar, cast, overload

import backoff
import luigi
from luigi import rpc, scheduler

import gokart
from gokart import worker
from gokart.conflict_prevention_lock.task_lock import TaskLockException
from gokart.target import TargetOnKart
from gokart.task import TaskOnKart
Expand Down Expand Up @@ -43,6 +45,31 @@ def __init__(self):
self.flag: bool = False


class WorkerProtocol(Protocol):
"""Protocol for Worker.
This protocol is determined by luigi.worker.Worker.
"""

def add(self, task: TaskOnKart) -> bool: ...

def run(self) -> bool: ...

def __enter__(self) -> 'WorkerProtocol': ...

def __exit__(self, type, value, traceback) -> Literal[False]: ...


class WorkerSchedulerFactory:
def create_local_scheduler(self) -> scheduler.Scheduler:
return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False)

def create_remote_scheduler(self, url) -> rpc.RemoteScheduler:
return rpc.RemoteScheduler(url)

def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant=False) -> WorkerProtocol:
return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant)


def _get_output(task: TaskOnKart[T]) -> T:
output = task.output()
# FIXME: currently, nested output is not supported
Expand Down Expand Up @@ -106,6 +133,7 @@ def build(
"""
if reset_register:
_reset_register()

with LoggerConfig(level=log_level):
task_lock_exception_raised = TaskLockExceptionRaisedFlag()

Expand All @@ -119,7 +147,13 @@ def when_failure(task, exception):
)
def _build_task():
task_lock_exception_raised.flag = False
result = luigi.build([task], local_scheduler=True, detailed_summary=True, log_level=logging.getLevelName(log_level), **env_params)
result = luigi.build(
[task],
local_scheduler=True,
detailed_summary=True,
log_level=logging.getLevelName(log_level),
**env_params,
)
if task_lock_exception_raised.flag:
raise HasLockedTaskException()
if result.status == luigi.LuigiStatusCode.FAILED:
Expand Down

0 comments on commit 5c63943

Please sign in to comment.