diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index fc926008..9e7ca122 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -266,7 +266,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State: return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Any: + def state(self) -> State | None: + if self._state is None: + return None + return self._state + + @property + def state_label(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -312,7 +318,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: if new_state is None: return None - initial_state_label = self._state.LABEL if self._state is not None else None + initial_state_label = self.state_label label = None try: self._transitioning = True diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index abc2b24b..9262f856 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 44d812d1..3b333edb 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -20,6 +20,7 @@ List, Optional, Protocol, + TypeVar, cast, runtime_checkable, ) @@ -535,6 +536,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S value = value.__name__ elif isinstance(value, Savable) and not isinstance(value, type): # persist for a savable obj, call `save` method of obj. + # the rhs branch is for when value is a Savable class, it is true runtime check + # of lhs condition. SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: @@ -544,11 +547,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S return out_state -def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: +def load_auto_persist_params( + obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None +) -> None: for member in obj._auto_persist: setattr(obj, member, _get_value(obj, saved_state, member, load_context)) +T = TypeVar('T', bound=Savable) + + +def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T: + obj = cls.__new__(cls) + + if isinstance(obj, SavableWithAutoPersist): + load_auto_persist_params(obj, saved_state, load_context) + + return obj + + def _get_value( obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None ) -> MethodType | Savable: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 337a3153..1a176b9b 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -22,7 +22,6 @@ import yaml from yaml.loader import Loader -from plumpy.persistence import ensure_object_loader from plumpy.process_comms import KillMessage, MessageType try: @@ -41,6 +40,7 @@ auto_load, auto_persist, auto_save, + ensure_object_loader, ) from .utils import SAVED_STATE_TYPE @@ -98,8 +98,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -171,15 +171,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) - obj.state_machine = load_context.process try: obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: - process = load_context.process - obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN]) + if load_context is not None: + obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN]) + else: + raise return obj @@ -235,12 +235,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) return obj @@ -306,15 +302,12 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) if obj.COMMAND in saved_state: - # FIXME: typing obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + return obj def interrupt(self, reason: Any) -> None: @@ -444,9 +437,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process callback_name = saved_state.get(obj.DONE_CALLBACK, None) @@ -550,8 +541,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -610,8 +600,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -659,8 +648,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 53723493..8b8107d4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -20,6 +20,7 @@ Any, Awaitable, Callable, + ClassVar, Dict, Generator, Hashable, @@ -175,6 +176,7 @@ class Process(StateMachine, metaclass=ProcessStateMachineMeta): _cleanups: Optional[List[Callable[[], None]]] = None __called: bool = False + _auto_persist: ClassVar[set[str]] @classmethod def current(cls) -> Optional['Process']: @@ -294,7 +296,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -303,7 +305,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -527,7 +529,9 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal + if self.state is None: + raise exceptions.InvalidStateError('process is not in state None that is invalid') + return self.state.is_terminal def result(self) -> Any: """ @@ -537,12 +541,12 @@ def result(self) -> Any: If in any other state this will raise an InvalidStateError. :return: The result of the process """ - if isinstance(self._state, process_states.Finished): - return self._state.result - if isinstance(self._state, process_states.Killed): - raise exceptions.KilledError(self._state.msg) - if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + if isinstance(self.state, process_states.Finished): + return self.state.result + if isinstance(self.state, process_states.Killed): + raise exceptions.KilledError(self.state.msg) + if isinstance(self.state, process_states.Excepted): + raise (self.state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -552,7 +556,7 @@ def successful(self) -> bool: Will raise if the process is not in the FINISHED state """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError as exception: raise exceptions.InvalidStateError('process is not in the finished state') from exception @@ -563,25 +567,25 @@ def is_successful(self) -> bool: :return: boolean, True if the process is in `Finished` state with `successful` attribute set to `True` """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError: return False def killed(self) -> bool: """Return whether the process is killed.""" - return self.state == process_states.ProcessState.KILLED + return self.state_label == process_states.ProcessState.KILLED def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" - if isinstance(self._state, process_states.Killed): - return self._state.msg + if isinstance(self.state, process_states.Killed): + return self.state.msg raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" - if isinstance(self._state, process_states.Excepted): - return self._state.exception + if isinstance(self.state, process_states.Excepted): + return self.state.exception return None @@ -591,7 +595,7 @@ def is_excepted(self) -> bool: :return: boolean, True if the process is in ``EXCEPTED`` state. """ - return self.state == process_states.ProcessState.EXCEPTED + return self.state_label == process_states.ProcessState.EXCEPTED def done(self) -> bool: """Return True if the call was successfully killed or finished running. @@ -600,7 +604,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal + return self.has_terminated() # endregion @@ -628,7 +632,7 @@ def callback_excepted( exception: Optional[BaseException], trace: Optional[TracebackType], ) -> None: - if self.state != process_states.ProcessState.EXCEPTED: + if self.state_label != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @contextlib.contextmanager @@ -681,8 +685,8 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if isinstance(self._state, persistence.Savable): - out_state['_state'] = self._state.save() + if isinstance(self.state, persistence.Savable): + out_state['_state'] = self.state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -740,7 +744,7 @@ def on_entering(self, state: state_machine.State) -> None: def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement - state_label = self._state.LABEL + state_label = self.state_label if state_label == process_states.ProcessState.RUNNING: call_with_super_check(self.on_running) elif state_label == process_states.ProcessState.WAITING: @@ -752,21 +756,21 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: elif state_label == process_states.ProcessState.KILLED: call_with_super_check(self.on_killed) - if self._communicator and isinstance(self.state, enum.Enum): + if self._communicator and isinstance(self.state_label, enum.Enum): from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' + subject = f'state_changed.{from_label}.{self.state_label.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) except (ConnectionClosed, ChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' - self.logger.warning(message, self.pid, from_label, self.state.value) + self.logger.warning(message, self.pid, from_label, self.state_label.value) except kiwipy.TimeoutError: message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' - self.logger.warning(message, self.pid, from_label, self.state.value) + self.logger.warning(message, self.pid, from_label, self.state_label.value) def on_exiting(self) -> None: - state = self.state + state = self.state_label if state == process_states.ProcessState.WAITING: call_with_super_check(self.on_exit_waiting) elif state == process_states.ProcessState.RUNNING: @@ -1069,7 +1073,6 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) @@ -1095,9 +1098,9 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._pausing if self._stepping: - if not isinstance(self._state, Interruptable): + if not isinstance(self.state, Interruptable): raise exceptions.InvalidStateError( - f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + f'cannot interrupt {self.state.__class__}, method `interrupt` not implement' ) # Ask the step function to pause by setting this flag and giving the @@ -1106,7 +1109,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable self._set_interrupt_action_from_exception(interrupt_exception) self._pausing = self._interrupt_action # Try to interrupt the state - self._state.interrupt(interrupt_exception) + self.state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) return self._do_pause(msg) @@ -1189,7 +1192,7 @@ def play(self) -> bool: @event(from_states=(process_states.Waiting)) def resume(self, *args: Any) -> None: """Start running the process again.""" - return self._state.resume(*args) # type: ignore + return self.state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: @@ -1207,7 +1210,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] Kill the process :param msg: An optional kill message """ - if self.state == process_states.ProcessState.KILLED: + if self.state_label == process_states.ProcessState.KILLED: # Already killed return True @@ -1219,13 +1222,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] # Already killing return self._killing - if self._stepping and isinstance(self._state, Interruptable): + if self._stepping and isinstance(self.state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg) self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action - self._state.interrupt(interrupt_exception) + self.state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) @@ -1294,14 +1297,14 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused - if not isinstance(self._state, Proceedable): - raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + if not isinstance(self.state, Proceedable): + raise StateMachineError(f'cannot step from {self.state.__class__}, async method `execute` not implemented') try: self._stepping = True next_state = None try: - next_state = await self._run_task(self._state.execute) + next_state = await self._run_task(self.state.execute) except process_states.Interruption as exception: # If the interruption was caused by a call to a Process method then there should # be an interrupt action ready to be executed, so just check if the cookie matches diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 66418861..5caf1882 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -8,7 +8,6 @@ import logging import re from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -31,13 +30,12 @@ from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError -from plumpy.persistence import LoadSaveContext, auto_persist, auto_save, ensure_object_loader, Savable +from plumpy.persistence import LoadSaveContext, Savable, auto_persist, auto_save, ensure_object_loader from plumpy.process_listener import ProcessListener from . import lang, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict - __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] ToContext = dict @@ -224,7 +222,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -233,7 +231,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -373,8 +371,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) @@ -447,7 +444,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -464,8 +461,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -602,7 +598,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -619,8 +615,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -732,8 +727,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 6a61fe00..15a218ce 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -150,22 +150,22 @@ def stop(self): class TestStateMachine(unittest.TestCase): def test_basic(self): cd_player = CdPlayer() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) cd_player.play('Eminem - The Real Slim Shady') - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) time.sleep(1.0) cd_player.pause() - self.assertEqual(cd_player.state, PAUSED) + self.assertEqual(cd_player.state_label, PAUSED) cd_player.play() - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) self.assertEqual(cd_player.play(), False) cd_player.stop() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) def test_invalid_event(self): cd_player = CdPlayer() diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index c6826a24..307bfdb7 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -67,7 +67,7 @@ async def test_play(self, thread_communicator, async_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.WAITING + assert proc.state_label == plumpy.ProcessState.WAITING # if not close the background process will raise exception # make sure proc reach the final state @@ -84,7 +84,7 @@ async def test_kill(self, thread_communicator, async_controller): # Check the outcome assert result - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_status(self, thread_communicator, async_controller): @@ -172,7 +172,7 @@ async def test_play(self, thread_communicator, sync_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.CREATED + assert proc.state_label == plumpy.ProcessState.CREATED @pytest.mark.asyncio async def test_kill(self, thread_communicator, sync_controller): @@ -186,7 +186,7 @@ async def test_kill(self, thread_communicator, sync_controller): # Check the outcome assert result # Occasionally fail - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_kill_all(self, thread_communicator, sync_controller): @@ -199,7 +199,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) - assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) + assert all([proc.state_label == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio async def test_status(self, thread_communicator, sync_controller): diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 4ec4c1a5..7f616433 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,7 +5,7 @@ import yaml import plumpy -from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.persistence import auto_load, auto_persist, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import utils @@ -25,8 +25,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -55,8 +55,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -81,8 +81,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: diff --git a/tests/test_processes.py b/tests/test_processes.py index 8c15cf9a..a62bbd8d 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -19,6 +19,7 @@ # FIXME: any process listener is savable # FIXME: any process control commands are savable + class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): super().__init__() @@ -239,7 +240,7 @@ def test_execute(self): proc.execute() self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertEqual(proc.outputs, {'default': 5}) def test_run_from_class(self): @@ -277,7 +278,7 @@ def test_exception(self): proc = utils.ExceptionProcess() with self.assertRaises(RuntimeError): proc.execute() - self.assertEqual(proc.state, ProcessState.EXCEPTED) + self.assertEqual(proc.state_label, ProcessState.EXCEPTED) def test_run_kill(self): proc = utils.KillProcess() @@ -330,7 +331,7 @@ def test_kill(self): proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_wait_continue(self): proc = utils.WaitForSignalProcess() @@ -344,7 +345,7 @@ def test_wait_continue(self): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_exc_info(self): proc = utils.ExceptionProcess() @@ -368,7 +369,7 @@ def test_wait_pause_play_resume(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause() self.assertTrue(result) @@ -384,7 +385,7 @@ async def async_test(): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) @@ -405,7 +406,7 @@ def test_pause_play_status_messaging(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause(PAUSE_STATUS) self.assertTrue(result) @@ -425,7 +426,7 @@ async def async_test(): loop.run_until_complete(async_test()) self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_kill_in_run(self): class KillProcess(Process): @@ -443,7 +444,7 @@ def run(self, **kwargs): proc.execute() self.assertTrue(proc.after_kill) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused_in_run(self): class PauseProcess(Process): @@ -455,7 +456,7 @@ def run(self, **kwargs): with self.assertRaises(plumpy.KilledError): proc.execute() - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused(self): loop = asyncio.get_event_loop() @@ -479,7 +480,7 @@ async def async_test(): loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_run_multiple(self): # Create and play some processes @@ -555,7 +556,7 @@ def run(self): loop.run_forever() self.assertTrue(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_pause_play_in_process(self): """Test that we can pause and play that by playing within the process""" @@ -573,7 +574,7 @@ def run(self): proc.execute() self.assertFalse(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_process_stack(self): test_case = self @@ -784,7 +785,7 @@ def test_saving_each_step(self): proc = proc_class() saver = utils.ProcessSaver(proc) saver.capture() - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots)) def test_restart(self): @@ -799,7 +800,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it loaded_proc.resume() @@ -822,7 +823,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it twice in succession loaded_proc.resume() @@ -864,7 +865,7 @@ async def async_test(): def test_killed(self): proc = utils.DummyProcess() proc.kill() - self.assertEqual(proc.state, plumpy.ProcessState.KILLED) + self.assertEqual(proc.state_label, plumpy.ProcessState.KILLED) self._check_round_trip(proc) def _check_round_trip(self, proc1): @@ -987,40 +988,40 @@ def run(self): self.out(namespace_nested + '.two', 2) # Run the process in default mode which should not add any outputs and therefore fail - process = DummyDynamicProcess() - process.execute() + proc = DummyDynamicProcess() + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertDictEqual(process.outputs, {}) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertDictEqual(proc.outputs, {}) # Attaching only namespaced ports should fail, because the required port is not added - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) # Attaching only the single required top-level port should be fine - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) # Attaching both the required and namespaced ports should result in a successful termination - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) - process.execute() - - self.assertIsNotNone(process.outputs) - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) + proc.execute() + + self.assertIsNotNone(proc.outputs) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) class TestProcessEvents(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index 88638e01..be8f2a5e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -468,7 +468,7 @@ def run_until_waiting(proc): listener = plumpy.ProcessListener() in_waiting = plumpy.Future() - if proc.state == ProcessState.WAITING: + if proc.state_label == ProcessState.WAITING: in_waiting.set_result(True) else: