diff --git a/gokart/worker.py b/gokart/worker.py
index 2bf847f4..867bd24b 100644
--- a/gokart/worker.py
+++ b/gokart/worker.py
@@ -30,37 +30,36 @@
 import collections
 import collections.abc
+import contextlib
 import datetime
+import functools
 import getpass
 import importlib
+import json
 import logging
 import multiprocessing
 import os
-import signal
-import subprocess
-import sys
-import contextlib
-import functools
-
 import queue as Queue
 import random
+import signal
 import socket
+import subprocess
+import sys
 import threading
 import time
 import traceback
+from typing import Any, Dict, Generator, List, Literal, Optional, Set, Tuple

+import luigi
+import luigi.scheduler
+import luigi.worker
 from luigi import notifications
 from luigi.event import Event
-from luigi.task_register import load_task
-from luigi.scheduler import DISABLED, DONE, FAILED, PENDING, UNKNOWN, Scheduler, RetryPolicy
-from luigi.scheduler import WORKER_STATE_ACTIVE, WORKER_STATE_DISABLED
+from luigi.scheduler import DISABLED, DONE, FAILED, PENDING, UNKNOWN, WORKER_STATE_ACTIVE, WORKER_STATE_DISABLED, RetryPolicy, Scheduler
 from luigi.target import Target
-from luigi.task import Task, Config, DynamicRequirements, flatten
-from luigi.task_register import TaskClassException
+from luigi.task import DynamicRequirements, Task, flatten
+from luigi.task_register import TaskClassException, load_task
 from luigi.task_status import RUNNING
-from luigi.parameter import BoolParameter, FloatParameter, IntParameter, OptionalParameter, Parameter, TimeDeltaParameter
-
-import json

 logger = logging.getLogger('luigi-interface')
@@ -78,16 +77,12 @@
 _WAIT_INTERVAL_EPS = 0.00001


-def _is_external(task):
+def _is_external(task: Task) -> bool:
     return task.run is None or task.run == NotImplemented


-def _get_retry_policy_dict(task):
-    return RetryPolicy(task.retry_count, task.disable_hard_timeout, task.disable_window)._asdict()
-
-
-class TaskException(Exception):
-    pass
+def _get_retry_policy_dict(task: Task) -> Dict[str, Any]:
+    return RetryPolicy(task.retry_count, task.disable_hard_timeout, task.disable_window)._asdict()  # type: ignore


 GetWorkResponse = collections.namedtuple(
@@ -120,16 +115,16 @@ class TaskProcess(multiprocessing.Process):

     def __init__(
         self,
-        task,
-        worker_id,
-        result_queue,
-        status_reporter,
-        use_multiprocessing=False,
-        worker_timeout=0,
-        check_unfulfilled_deps=True,
-        check_complete_on_run=False,
-        task_completion_cache=None,
-    ):
+        task: luigi.Task,
+        worker_id: str,
+        result_queue: multiprocessing.Queue,
+        status_reporter: luigi.worker.TaskStatusReporter,
+        use_multiprocessing: bool = False,
+        worker_timeout: int = 0,
+        check_unfulfilled_deps: bool = True,
+        check_complete_on_run: bool = False,
+        task_completion_cache: Optional[Dict[str, Any]] = None,
+    ) -> None:
         super(TaskProcess, self).__init__()
         self.task = task
         self.worker_id = worker_id
@@ -143,9 +138,9 @@ def __init__(
         self.task_completion_cache = task_completion_cache

         # completeness check using the cache
-        self.check_complete = functools.partial(check_complete_cached, completion_cache=task_completion_cache)
+        self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)

-    def _run_get_new_deps(self):
+    def _run_get_new_deps(self) -> Optional[List[Tuple[str, str, Dict[str, str]]]]:
         task_gen = self.task.run()

         if not isinstance(task_gen, collections.abc.Generator):
@@ -173,7 +168,7 @@
             # get the next generator result
             next_send = requires.paths

-    def run(self):
+    def run(self) -> None:
         logger.info('[pid %s] Worker %s running %s', os.getpid(), self.worker_id, self.task)

         if self.use_multiprocessing:
@@ -182,10 +177,10 @@ def run(self):
             currentTime = time.time()
             random.seed(processID * currentTime)

-        status = FAILED
+        status: Optional[str] = FAILED
         expl = ''
-        missing = []
-        new_deps = []
+        missing: List[str] = []
+        new_deps: Optional[List[Tuple[str, str, Dict[str, str]]]] = []
         try:
             # Verify that all the tasks are fulfilled! For external tasks we
             # don't care about unfulfilled dependencies, because we are just
@@ -226,7 +221,7 @@
                     elif self.check_complete(self.task):
                         status = DONE
                     else:
-                        raise TaskException('Task finished running, but complete() is still returning false.')
+                        raise luigi.worker.TaskException('Task finished running, but complete() is still returning false.')
                 else:
                     status = PENDING

@@ -247,12 +242,12 @@
         finally:
             self.result_queue.put((self.task.task_id, status, expl, missing, new_deps))

-    def _handle_run_exception(self, ex):
+    def _handle_run_exception(self, ex: BaseException) -> str:
         logger.exception('[pid %s] Worker %s failed %s', os.getpid(), self.worker_id, self.task)
         self.task.trigger_event(Event.FAILURE, self.task, ex)
         return self.task.on_failure(ex)

-    def _recursive_terminate(self):
+    def _recursive_terminate(self) -> None:
         import psutil

         try:
@@ -272,7 +267,7 @@
         except psutil.NoSuchProcess:
             return

-    def terminate(self):
+    def terminate(self) -> None:
         """Terminate this process and its subprocesses."""
         # default terminate() doesn't cleanup child processes, it orphans them.
         try:
@@ -289,18 +284,18 @@ def _forward_attributes(self):
             yield self
         finally:
             # reset attributes again
-            for reporter_attr, task_attr in self.forward_reporter_attributes.items():
+            for _, task_attr in self.forward_reporter_attributes.items():
                 setattr(self.task, task_attr, None)


 # This code and the task_process_context config key currently feels a bit ad-hoc.
 # Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897
 class ContextManagedTaskProcess(TaskProcess):
-    def __init__(self, context, *args, **kwargs):
+    def __init__(self, context, *args, **kwargs) -> None:
         super(ContextManagedTaskProcess, self).__init__(*args, **kwargs)
         self.context = context

-    def run(self):
+    def run(self) -> None:
         if self.context:
             logger.debug('Importing module and instantiating ' + self.context)
             module_path, class_name = self.context.rsplit('.', 1)
@@ -313,231 +308,6 @@ def run(self):
             super(ContextManagedTaskProcess, self).run()


-class TaskStatusReporter:
-    """
-    Reports task status information to the scheduler.
-
-    This object must be pickle-able for passing to `TaskProcess` on systems
-    where fork method needs to pickle the process object (e.g. Windows).
- """ - - def __init__(self, scheduler, task_id, worker_id, scheduler_messages): - self._task_id = task_id - self._worker_id = worker_id - self._scheduler = scheduler - self.scheduler_messages = scheduler_messages - - def update_tracking_url(self, tracking_url): - self._scheduler.add_task(task_id=self._task_id, worker=self._worker_id, status=RUNNING, tracking_url=tracking_url) - - def update_status_message(self, message): - self._scheduler.set_task_status_message(self._task_id, message) - - def update_progress_percentage(self, percentage): - self._scheduler.set_task_progress_percentage(self._task_id, percentage) - - def decrease_running_resources(self, decrease_resources): - self._scheduler.decrease_running_task_resources(self._task_id, decrease_resources) - - -class SchedulerMessage: - """ - Message object that is build by the the :py:class:`Worker` when a message from the scheduler is - received and passed to the message queue of a :py:class:`Task`. - """ - - def __init__(self, scheduler, task_id, message_id, content, **payload): - super(SchedulerMessage, self).__init__() - - self._scheduler = scheduler - self._task_id = task_id - self._message_id = message_id - - self.content = content - self.payload = payload - - def __str__(self): - return str(self.content) - - def __eq__(self, other): - return self.content == other - - def respond(self, response): - self._scheduler.add_scheduler_message_response(self._task_id, self._message_id, response) - - -class SingleProcessPool: - """ - Dummy process pool for using a single processor. - - Imitates the api of multiprocessing.Pool using single-processor equivalents. - """ - - def apply_async(self, function, args): - return function(*args) - - def close(self): - pass - - def join(self): - pass - - -class DequeQueue(collections.deque): - """ - deque wrapper implementing the Queue interface. - """ - - def put(self, obj, block=None, timeout=None): - return self.append(obj) - - def get(self, block=None, timeout=None): - try: - return self.pop() - except IndexError: - raise Queue.Empty - - -class AsyncCompletionException(Exception): - """ - Exception indicating that something went wrong with checking complete. - """ - - def __init__(self, trace): - self.trace = trace - - -class TracebackWrapper: - """ - Class to wrap tracebacks so we can know they're not just strings. - """ - - def __init__(self, trace): - self.trace = trace - - -def check_complete_cached(task, completion_cache=None): - # check if cached and complete - cache_key = task.task_id - if completion_cache is not None and completion_cache.get(cache_key): - return True - - # (re-)check the status - is_complete = task.complete() - - # tell the cache when complete - if completion_cache is not None and is_complete: - completion_cache[cache_key] = is_complete - - return is_complete - - -def check_complete(task, out_queue, completion_cache=None): - """ - Checks if task is complete, puts the result to out_queue, optionally using the completion cache. 
- """ - logger.debug('Checking if %s is complete', task) - try: - is_complete = check_complete_cached(task, completion_cache) - except Exception: - is_complete = TracebackWrapper(traceback.format_exc()) - out_queue.put((task, is_complete)) - - -class worker(Config): - # NOTE: `section.config-variable` in the config_path argument is deprecated in favor of `worker.config_variable` - - id = Parameter(default='', description='Override the auto-generated worker_id') - ping_interval = FloatParameter(default=1.0, config_path=dict(section='core', name='worker-ping-interval')) - keep_alive = BoolParameter(default=False, config_path=dict(section='core', name='worker-keep-alive')) - count_uniques = BoolParameter( - default=False, - config_path=dict(section='core', name='worker-count-uniques'), - description='worker-count-uniques means that we will keep a ' 'worker alive only if it has a unique pending task, as ' 'well as having keep-alive true', - ) - count_last_scheduled = BoolParameter( - default=False, description='Keep a worker alive only if there are ' 'pending tasks which it was the last to ' 'schedule.' - ) - wait_interval = FloatParameter(default=1.0, config_path=dict(section='core', name='worker-wait-interval')) - wait_jitter = FloatParameter(default=5.0) - - max_keep_alive_idle_duration = TimeDeltaParameter(default=datetime.timedelta(0)) - - max_reschedules = IntParameter(default=1, config_path=dict(section='core', name='worker-max-reschedules')) - timeout = IntParameter(default=0, config_path=dict(section='core', name='worker-timeout')) - task_limit = IntParameter(default=None, config_path=dict(section='core', name='worker-task-limit')) - retry_external_tasks = BoolParameter( - default=False, - config_path=dict(section='core', name='retry-external-tasks'), - description='If true, incomplete external tasks will be ' 'retested for completion while Luigi is running.', - ) - send_failure_email = BoolParameter(default=True, description='If true, send e-mails directly from the worker' 'on failure') - no_install_shutdown_handler = BoolParameter(default=False, description='If true, the SIGUSR1 shutdown handler will' 'NOT be install on the worker') - check_unfulfilled_deps = BoolParameter(default=True, description='If true, check for completeness of ' 'dependencies before running a task') - check_complete_on_run = BoolParameter( - default=False, - description='If true, only mark tasks as done after running if they are complete. ' - 'Regardless of this setting, the worker will always check if external ' - 'tasks are complete before marking them as done.', - ) - force_multiprocessing = BoolParameter(default=False, description='If true, use multiprocessing also when ' 'running with 1 worker') - task_process_context = OptionalParameter( - default=None, - description='If set to a fully qualified class name, the class will ' - 'be instantiated with a TaskProcess as its constructor parameter and ' - 'applied as a context manager around its run() call, so this can be ' - 'used for obtaining high level customizable monitoring or logging of ' - 'each individual Task run.', - ) - cache_task_completion = BoolParameter( - default=False, - description='If true, cache the response of successful completion checks ' - 'of tasks assigned to a worker. 
This can especially speed up tasks with ' - 'dynamic dependencies but assumes that the completion status does not change ' - 'after it was true the first time.', - ) - - -class KeepAliveThread(threading.Thread): - """ - Periodically tell the scheduler that the worker still lives. - """ - - def __init__(self, scheduler, worker_id, ping_interval, rpc_message_callback): - super(KeepAliveThread, self).__init__() - self._should_stop = threading.Event() - self._scheduler = scheduler - self._worker_id = worker_id - self._ping_interval = ping_interval - self._rpc_message_callback = rpc_message_callback - - def stop(self): - self._should_stop.set() - - def run(self): - while True: - self._should_stop.wait(self._ping_interval) - if self._should_stop.is_set(): - logger.info('Worker %s was stopped. Shutting down Keep-Alive thread' % self._worker_id) - break - with fork_lock: - response = None - try: - response = self._scheduler.ping(worker=self._worker_id) - except BaseException: # httplib.BadStatusLine: - logger.warning('Failed pinging scheduler') - - # handle rpc messages - if response: - for message in response['rpc_messages']: - self._rpc_message_callback(message) - - -def rpc_message_callback(fn): - fn.is_rpc_message_callback = True - return fn - - class Worker: """ Worker object communicates with a scheduler. @@ -548,14 +318,16 @@ class Worker: * asks for stuff to do (pulls it in a loop and runs it) """ - def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant=False, **kwargs): + def __init__( + self, scheduler: Optional[Scheduler] = None, worker_id: Optional[str] = None, worker_processes: int = 1, assistant: bool = False, **kwargs: Any + ) -> None: if scheduler is None: scheduler = Scheduler() self.worker_processes = int(worker_processes) self._worker_info = self._generate_worker_info() - self._config = worker(**kwargs) + self._config = luigi.worker.worker(**kwargs) worker_id = worker_id or self._config.id or self._generate_worker_id(self._worker_info) @@ -568,17 +340,17 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant self._stop_requesting_work = False self.host = socket.gethostname() - self._scheduled_tasks = {} - self._suspended_tasks = {} - self._batch_running_tasks = {} - self._batch_families_sent = set() + self._scheduled_tasks: Dict[str, Task] = {} + self._suspended_tasks: Dict[str, Task] = {} + self._batch_running_tasks: Dict[str, Any] = {} + self._batch_families_sent: Set[str] = set() self._first_task = None self.add_succeeded = True self.run_succeeded = True - self.unfulfilled_counts = collections.defaultdict(int) + self.unfulfilled_counts: Dict[str, int] = collections.defaultdict(int) # note that ``signal.signal(signal.SIGUSR1, fn)`` only works inside the main execution thread, which is why we # provide the ability to conditionally install the hook. 
@@ -590,9 +362,9 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant
                 pass

         # Keep info about what tasks are running (could be in other processes)
-        self._task_result_queue = multiprocessing.Queue()
-        self._running_tasks = {}
-        self._idle_since = None
+        self._task_result_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self._running_tasks: Dict[str, TaskProcess] = {}
+        self._idle_since: Optional[datetime.datetime] = None

         # mp-safe dictionary for caching completation checks across task processes
         self._task_completion_cache = None
@@ -600,8 +372,8 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant
             self._task_completion_cache = multiprocessing.Manager().dict()

         # Stuff for execution_summary
-        self._add_task_history = []
-        self._get_work_response_history = []
+        self._add_task_history: List[Any] = []
+        self._get_work_response_history: List[Any] = []

     def _add_task(self, *args, **kwargs):
         """
@@ -627,16 +399,16 @@ def _add_task(self, *args, **kwargs):

         logger.info('Informed scheduler that task %s has status %s', task_id, status)

-    def __enter__(self):
+    def __enter__(self) -> 'Worker':
         """
         Start the KeepAliveThread.
         """
-        self._keep_alive_thread = KeepAliveThread(self._scheduler, self._id, self._config.ping_interval, self._handle_rpc_message)
+        self._keep_alive_thread = luigi.worker.KeepAliveThread(self._scheduler, self._id, self._config.ping_interval, self._handle_rpc_message)
         self._keep_alive_thread.daemon = True
         self._keep_alive_thread.start()
         return self

-    def __exit__(self, type, value, traceback):
+    def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]:
         """
         Stop the KeepAliveThread and kill still running tasks.
         """
@@ -648,7 +420,7 @@ def __exit__(self, type, value, traceback):
         self._task_result_queue.close()
         return False  # Don't suppress exception

-    def _generate_worker_info(self):
+    def _generate_worker_info(self) -> List[Tuple[str, Any]]:
         # Generate as much info as possible about the worker
         # Some of these calls might not be available on all OS's
         args = [('salt', '%09d' % random.randrange(0, 10_000_000_000)), ('workers', self.worker_processes)]
@@ -672,30 +444,32 @@
             pass
         return args

-    def _generate_worker_id(self, worker_info):
+    def _generate_worker_id(self, worker_info: List[Any]) -> str:
         worker_info_str = ', '.join(['{}={}'.format(k, v) for k, v in worker_info])
         return 'Worker({})'.format(worker_info_str)

-    def _validate_task(self, task):
+    def _validate_task(self, task: Task) -> None:
         if not isinstance(task, Task):
-            raise TaskException('Can not schedule non-task %s' % task)
+            raise luigi.worker.TaskException('Can not schedule non-task %s' % task)

         if not task.initialized():
             # we can't get the repr of it since it's not initialized...
-            raise TaskException('Task of class %s not initialized. Did you override __init__ and forget to call super(...).__init__?' % task.__class__.__name__)
+            raise luigi.worker.TaskException(
+                'Task of class %s not initialized. Did you override __init__ and forget to call super(...).__init__?' % task.__class__.__name__
+            )

-    def _log_complete_error(self, task, tb):
+    def _log_complete_error(self, task: Task, tb: str) -> None:
         log_msg = 'Will not run {task} or any dependencies due to error in complete() method:\n{tb}'.format(task=task, tb=tb)
         logger.warning(log_msg)

-    def _log_dependency_error(self, task, tb):
+    def _log_dependency_error(self, task: Task, tb: str) -> None:
         log_msg = 'Will not run {task} or any dependencies due to error in deps() method:\n{tb}'.format(task=task, tb=tb)
         logger.warning(log_msg)

-    def _log_unexpected_error(self, task):
+    def _log_unexpected_error(self, task: Task) -> None:
         logger.exception('Luigi unexpected framework error while scheduling %s', task)  # needs to be called from within except clause

-    def _announce_scheduling_failure(self, task, expl):
+    def _announce_scheduling_failure(self, task: Task, expl: Any) -> None:
         try:
             self._scheduler.announce_scheduling_failure(
                 worker=self._id,
@@ -710,7 +484,7 @@ def _announce_scheduling_failure(self, task, expl):
             self._email_unexpected_error(task, formatted_traceback)
             raise

-    def _email_complete_error(self, task, formatted_traceback):
+    def _email_complete_error(self, task: Task, formatted_traceback: str) -> None:
         self._announce_scheduling_failure(task, formatted_traceback)
         if self._config.send_failure_email:
             self._email_error(
@@ -720,7 +494,7 @@ def _email_complete_error(self, task, formatted_traceback):
                 headline='Will not run {task} or any dependencies due to error in complete() method',
             )

-    def _email_dependency_error(self, task, formatted_traceback):
+    def _email_dependency_error(self, task: Task, formatted_traceback: str) -> None:
         self._announce_scheduling_failure(task, formatted_traceback)
         if self._config.send_failure_email:
             self._email_error(
@@ -730,7 +504,7 @@ def _email_dependency_error(self, task, formatted_traceback):
                 headline='Will not run {task} or any dependencies due to error in deps() method',
             )

-    def _email_unexpected_error(self, task, formatted_traceback):
+    def _email_unexpected_error(self, task: Task, formatted_traceback: str) -> None:
         # this sends even if failure e-mails are disabled, as they may indicate
         # a more severe failure that may not reach other alerting methods such
         # as scheduler batch notification
@@ -741,7 +515,7 @@ def _email_unexpected_error(self, task, formatted_traceback):
             headline='Luigi framework error',
         )

-    def _email_task_failure(self, task, formatted_traceback):
+    def _email_task_failure(self, task: Task, formatted_traceback: str) -> None:
         if self._config.send_failure_email:
             self._email_error(
                 task,
@@ -750,14 +524,14 @@ def _email_task_failure(self, task, formatted_traceback):
                 headline='A task failed when running. Most likely run() raised an exception.',
             )

-    def _email_error(self, task, formatted_traceback, subject, headline):
+    def _email_error(self, task: Task, formatted_traceback: str, subject: str, headline: str) -> None:
         formatted_subject = subject.format(task=task, host=self.host)
         formatted_headline = headline.format(task=task, host=self.host)
         command = subprocess.list2cmdline(sys.argv)
         message = notifications.format_task_error(formatted_headline, task, command, formatted_traceback)
         notifications.send_error_email(formatted_subject, message, task.owner_email)

-    def _handle_task_load_error(self, exception, task_ids):
+    def _handle_task_load_error(self, exception: Exception, task_ids: List[str]) -> None:
         msg = 'Cannot find task(s) sent by scheduler: {}'.format(','.join(task_ids))
         logger.exception(msg)
         subject = 'Luigi: {}'.format(msg)
@@ -772,7 +546,7 @@ def _handle_task_load_error(self, exception, task_ids):
         )
         notifications.send_error_email(subject, error_message)

-    def add(self, task, multiprocess=False, processes=0):
+    def add(self, task: Task, multiprocess: bool = False, processes: int = 0) -> bool:
         """
         Add a Task for the worker to check and possibly schedule and run.

@@ -782,13 +556,13 @@ def add(self, task, multiprocess=False, processes=0):
             self._first_task = task.task_id
         self.add_succeeded = True
         if multiprocess:
-            queue = multiprocessing.Manager().Queue()
-            pool = multiprocessing.Pool(processes=processes if processes > 0 else None)
+            queue: Any = multiprocessing.Manager().Queue()
+            pool: Any = multiprocessing.Pool(processes=processes if processes > 0 else None)
         else:
-            queue = DequeQueue()
-            pool = SingleProcessPool()
+            queue = luigi.worker.DequeQueue()
+            pool = luigi.worker.SingleProcessPool()
         self._validate_task(task)
-        pool.apply_async(check_complete, [task, queue, self._task_completion_cache])
+        pool.apply_async(luigi.worker.check_complete, [task, queue, self._task_completion_cache])

         # we track queue size ourselves because len(queue) won't work for multiprocessing
         queue_size = 1
@@ -802,9 +576,9 @@ def add(self, task, multiprocess=False, processes=0):
                     if next.task_id not in seen:
                         self._validate_task(next)
                         seen.add(next.task_id)
-                        pool.apply_async(check_complete, [next, queue, self._task_completion_cache])
+                        pool.apply_async(luigi.worker.check_complete, [next, queue, self._task_completion_cache])
                         queue_size += 1
-        except (KeyboardInterrupt, TaskException):
+        except (KeyboardInterrupt, luigi.worker.TaskException):
             raise
         except Exception as ex:
             self.add_succeeded = False
@@ -818,7 +592,7 @@ def add(self, task, multiprocess=False, processes=0):
             pool.join()
         return self.add_succeeded

-    def _add_task_batcher(self, task):
+    def _add_task_batcher(self, task: Task) -> None:
         family = task.task_family
         if family not in self._batch_families_sent:
             task_class = type(task)
@@ -832,7 +606,7 @@ def _add_task_batcher(self, task):
             )
             self._batch_families_sent.add(family)

-    def _add(self, task, is_complete):
+    def _add(self, task: Task, is_complete: bool) -> Generator[Task, None, None]:
         if self._config.task_limit is not None and len(self._scheduled_tasks) >= self._config.task_limit:
             logger.warning('Will not run %s or any dependencies due to exceeded task-limit of %d', task, self._config.task_limit)
             deps = None
@@ -845,7 +619,7 @@ def _add(self, task, is_complete):
             self._check_complete_value(is_complete)
         except KeyboardInterrupt:
             raise
-        except AsyncCompletionException as ex:
+        except luigi.worker.AsyncCompletionException as ex:
             formatted_traceback = ex.trace
         except BaseException:
             formatted_traceback = traceback.format_exc()
@@ -919,23 +693,23 @@ def _add(self, task, is_complete):
             accepts_messages=task.accepts_messages,
         )

-    def _validate_dependency(self, dependency):
+    def _validate_dependency(self, dependency: Task) -> None:
         if isinstance(dependency, Target):
             raise Exception('requires() can not return Target objects. Wrap it in an ExternalTask class')
         elif not isinstance(dependency, Task):
             raise Exception('requires() must return Task objects but {} is a {}'.format(dependency, type(dependency)))

-    def _check_complete_value(self, is_complete):
+    def _check_complete_value(self, is_complete: bool) -> None:
         if is_complete not in (True, False):
-            if isinstance(is_complete, TracebackWrapper):
-                raise AsyncCompletionException(is_complete.trace)
+            if isinstance(is_complete, luigi.worker.TracebackWrapper):
+                raise luigi.worker.AsyncCompletionException(is_complete.trace)
             raise Exception('Return value of Task.complete() must be boolean (was %r)' % is_complete)

-    def _add_worker(self):
+    def _add_worker(self) -> None:
         self._worker_info.append(('first_task', self._first_task))
         self._scheduler.add_worker(self._id, self._worker_info)

-    def _log_remote_tasks(self, get_work_response):
+    def _log_remote_tasks(self, get_work_response: GetWorkResponse) -> None:
         logger.debug('Done')
         logger.debug('There are no more tasks to run at this time')
         if get_work_response.running_tasks:
@@ -948,7 +722,7 @@ def _log_remote_tasks(self, get_work_response):
         if get_work_response.n_pending_last_scheduled:
             logger.debug('There are %i pending tasks last scheduled by this worker', get_work_response.n_pending_last_scheduled)

-    def _get_work_task_id(self, get_work_response):
+    def _get_work_task_id(self, get_work_response: Dict[str, Any]) -> Optional[str]:
         if get_work_response.get('task_id') is not None:
             return get_work_response['task_id']
         elif 'batch_id' in get_work_response:
@@ -976,7 +750,7 @@ def _get_work_task_id(self, get_work_response):
         else:
             return None

-    def _get_work(self):
+    def _get_work(self) -> GetWorkResponse:
         if self._stop_requesting_work:
             return GetWorkResponse(None, 0, 0, 0, 0, WORKER_STATE_DISABLED)

@@ -1028,7 +802,7 @@ def _get_work(self):
                 worker_state=r.get('worker_state', WORKER_STATE_ACTIVE),
             )

-    def _run_task(self, task_id):
+    def _run_task(self, task_id: str) -> None:
         if task_id in self._running_tasks:
             logger.debug('Got already running task id {} from scheduler, taking a break'.format(task_id))
             next(self._sleeper())
@@ -1048,8 +822,8 @@ def _run_task(self, task_id):
             task_process.run()

     def _create_task_process(self, task):
-        message_queue = multiprocessing.Queue() if task.accepts_messages else None
-        reporter = TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue)
+        message_queue: Any = multiprocessing.Queue() if task.accepts_messages else None
+        reporter = luigi.worker.TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue)
         use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1)
         return ContextManagedTaskProcess(
             self._config.task_process_context,
@@ -1064,7 +838,7 @@ def _create_task_process(self, task):
             task_completion_cache=self._task_completion_cache,
         )

-    def _purge_children(self):
+    def _purge_children(self) -> None:
         """
         Find dead children and put a response on the result queue.

@@ -1084,7 +858,7 @@
             logger.info(error_msg)
             self._task_result_queue.put((task_id, FAILED, error_msg, [], []))

-    def _handle_next_task(self):
+    def _handle_next_task(self) -> None:
         """
         We have to catch three ways a task can be "done":

@@ -1152,7 +926,7 @@ def _handle_next_task(self):
                 self.run_succeeded &= (status == DONE) or (len(new_deps) > 0)
                 return

-    def _sleeper(self):
+    def _sleeper(self) -> Generator[None, None, None]:
        # TODO is exponential backoff necessary?
         while True:
             jitter = self._config.wait_jitter
@@ -1161,7 +935,7 @@ def _sleeper(self):
             time.sleep(wait_interval)
             yield

-    def _keep_alive(self, get_work_response):
+    def _keep_alive(self, get_work_response) -> bool:
         """
         Returns true if a worker should stay alive given.

@@ -1191,14 +965,14 @@ def _keep_alive(self, get_work_response):
             logger.debug('[%s] %s until shutdown', self._id, time_to_shutdown)
             return time_to_shutdown > datetime.timedelta(0)

-    def handle_interrupt(self, signum, _):
+    def handle_interrupt(self, signum, _) -> None:
         """
         Stops the assistant from asking for more work on SIGUSR1
         """
         if signum == signal.SIGUSR1:
             self._start_phasing_out()

-    def _start_phasing_out(self):
+    def _start_phasing_out(self) -> None:
         """
         Go into a mode where we dont ask for more work and quit once existing
         tasks are done.
@@ -1206,7 +980,7 @@
         self._config.keep_alive = False
         self._stop_requesting_work = True

-    def run(self):
+    def run(self) -> bool:
         """
         Returns True if all scheduled tasks were executed successfully.
         """
@@ -1251,7 +1025,7 @@

         return self.run_succeeded

-    def _handle_rpc_message(self, message):
+    def _handle_rpc_message(self, message: Dict[str, Any]) -> None:
         logger.info('Worker %s got message %s' % (self._id, message))

         # the message is a dict {'name': <function_name>, 'kwargs': <function_kwargs>}
@@ -1270,19 +1044,19 @@ def _handle_rpc_message(self, message):
             logger.info("Worker %s successfully dispatched rpc message to function '%s'" % tpl)
             func(**kwargs)

-    @rpc_message_callback
-    def set_worker_processes(self, n):
+    @luigi.worker.rpc_message_callback
+    def set_worker_processes(self, n: int) -> None:
         # set the new value
         self.worker_processes = max(1, n)

         # tell the scheduler
         self._scheduler.add_worker(self._id, {'workers': self.worker_processes})

-    @rpc_message_callback
-    def dispatch_scheduler_message(self, task_id, message_id, content, **kwargs):
+    @luigi.worker.rpc_message_callback
+    def dispatch_scheduler_message(self, task_id: str, message_id: str, content: str, **kwargs: Any) -> None:
         task_id = str(task_id)
         if task_id in self._running_tasks:
             task_process = self._running_tasks[task_id]
             if task_process.status_reporter.scheduler_messages:
-                message = SchedulerMessage(self._scheduler, task_id, message_id, content, **kwargs)
+                message = luigi.worker.SchedulerMessage(self._scheduler, task_id, message_id, content, **kwargs)
                 task_process.status_reporter.scheduler_messages.put(message)