diff --git a/docs/config.rst b/docs/config.rst index 419bc2462..84226f00c 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -215,6 +215,18 @@ State Persistence The number of save calls to buffer before writing the state. Defaults to 1, which is no buffering. + **db_store_method** + The method to use for saving state information to a SQL database. Only used if store_type is sql. + + Valid options are: + **json** - uses the `simplejson` module. + + **msgpack** - uses the `msgpack` module, from the msgpack-python package (tested with version 0.3.0). + + **pickle** - uses the `cPickle` module. be careful with this one, as pickle is Turing complete. + + **yaml** - uses the `yaml` module, from the PyYaml package (tested with version 3.10). + Example:: @@ -223,6 +235,7 @@ Example:: name: local_sqlite connection_details: "sqlite:///dest_state.db" buffer_size: 1 # No buffer + db_store_method: json .. _action_runners: diff --git a/docs/man_tronview.rst b/docs/man_tronview.rst index 1957c7f6e..cf8367741 100644 --- a/docs/man_tronview.rst +++ b/docs/man_tronview.rst @@ -69,6 +69,9 @@ Options ``-s, --save`` Save server and color options to client config file (~/.tron) +``--namespace`` + Only show jobs and services from the specified namespace + States ---------- diff --git a/tests/core/job_test.py b/tests/core/job_test.py index 61d1c2cf6..0c6c53516 100644 --- a/tests/core/job_test.py +++ b/tests/core/job_test.py @@ -103,12 +103,17 @@ def test_restore_state(self): job_runs = [mock.Mock(), mock.Mock()] state_data = ({'enabled': False, 'run_ids': [1, 2]}, run_data) - with mock.patch.object(self.job.job_runs, 'restore_state', return_value=job_runs): + with contextlib.nested( + mock.patch.object(self.job.job_runs, 'restore_state', return_value=job_runs), + mock.patch.object(self.job.job_runs, 'get_run_numbers', return_value=state_data[0]['run_ids']) + ): self.job.restore_state(state_data) assert not self.job.enabled calls = [mock.call(job_runs[i]) for i in xrange(len(job_runs))] 
self.job.watcher.watch.assert_has_calls(calls) + calls = [mock.call(job_runs[i], jobrun.JobRun.NOTIFY_DONE) for i in xrange(len(job_runs))] + self.job_scheduler.watch.assert_has_calls(calls) assert_equal(self.job.job_state.state_data, state_data[0]) self.job.job_runs.restore_state.assert_called_once_with( sorted(run_data, key=lambda data: data['run_num'], reverse=True), @@ -117,6 +122,7 @@ def test_restore_state(self): self.job.context, self.job.node_pool ) + self.job.job_runs.get_run_numbers.assert_called_once_with() self.job.job_scheduler.restore_state.assert_called_once_with() self.job.event.ok.assert_called_with('restored') @@ -500,6 +506,35 @@ def mock_eventloop(self): def teardown_job(self): event.EventManager.reset() + def test_restore_state_scheduled(self): + mock_scheduled = [mock.Mock(), mock.Mock()] + with contextlib.nested( + mock.patch.object(self.job_scheduler.job_runs, 'get_scheduled', + return_value=iter(mock_scheduled)), + mock.patch.object(self.job_scheduler, 'schedule'), + mock.patch.object(self.job_scheduler, '_set_callback') + ) as (get_patch, sched_patch, back_patch): + self.job_scheduler.restore_state() + get_patch.assert_called_once_with() + calls = [mock.call(m) for m in mock_scheduled] + back_patch.assert_has_calls(calls) + sched_patch.assert_called_once_with() + + def test_restore_state_queued(self): + queued = mock.Mock() + with contextlib.nested( + mock.patch.object(self.job_scheduler.job_runs, 'get_scheduled', + return_value=iter([])), + mock.patch.object(self.job_scheduler.job_runs, 'get_first_queued', + return_value=queued), + mock.patch.object(self.job_scheduler, 'schedule'), + mock.patch.object(job.eventloop, 'call_later') + ) as (get_patch, queue_patch, sched_patch, later_patch): + self.job_scheduler.restore_state() + get_patch.assert_called_once_with() + later_patch.assert_called_once_with(0, self.job_scheduler.run_job, queued, run_queued=True) + sched_patch.assert_called_once_with() + def test_schedule(self): with 
mock.patch.object(self.job_scheduler.job_state, 'is_enabled', new=True): diff --git a/tests/mcp_reconfigure_test.py b/tests/mcp_reconfigure_test.py index 3ad3b1316..184031d6c 100644 --- a/tests/mcp_reconfigure_test.py +++ b/tests/mcp_reconfigure_test.py @@ -143,6 +143,7 @@ def teardown_mcp(self): event.EventManager.reset() filehandler.OutputPath(self.test_dir).delete() filehandler.FileHandleManager.reset() + self.mcp.state_watcher.shutdown() def reconfigure(self): config = {schema.MASTER_NAMESPACE: self._get_config(1, self.test_dir)} diff --git a/tests/mcp_test.py b/tests/mcp_test.py index d73b07031..706b2f7e7 100644 --- a/tests/mcp_test.py +++ b/tests/mcp_test.py @@ -45,10 +45,9 @@ class MasterControlProgramTestCase(TestCase): def setup_mcp(self): self.working_dir = tempfile.mkdtemp() self.config_path = tempfile.mkdtemp() - self.mcp = mcp.MasterControlProgram( + with mock.patch('tron.serialize.runstate.statemanager.StateChangeWatcher', autospec=True): + self.mcp = mcp.MasterControlProgram( self.working_dir, self.config_path) - self.mcp.state_watcher = mock.create_autospec( - statemanager.StateChangeWatcher) @teardown def teardown_mcp(self): @@ -134,11 +133,12 @@ class MasterControlProgramRestoreStateTestCase(TestCase): def setup_mcp(self): self.working_dir = tempfile.mkdtemp() self.config_path = tempfile.mkdtemp() - self.mcp = mcp.MasterControlProgram( - self.working_dir, self.config_path) - self.mcp.jobs = mock.create_autospec(job.JobCollection) - self.mcp.services = mock.create_autospec(service.ServiceCollection) - self.mcp.state_watcher = mock.create_autospec(statemanager.StateChangeWatcher) + with mock.patch('tron.serialize.runstate.statemanager.StateChangeWatcher', autospec=True): + self.mcp = mcp.MasterControlProgram( + self.working_dir, self.config_path) + self.mcp.jobs = mock.create_autospec(job.JobCollection) + self.mcp.services = mock.create_autospec(service.ServiceCollection) + self.mcp.state_watcher = 
mock.create_autospec(statemanager.StateChangeWatcher) @teardown def teardown_mcp(self): diff --git a/tests/serialize/runstate/statemanager_test.py b/tests/serialize/runstate/statemanager_test.py index f8ea1ded1..891193b26 100644 --- a/tests/serialize/runstate/statemanager_test.py +++ b/tests/serialize/runstate/statemanager_test.py @@ -1,32 +1,15 @@ -import os import mock +import contextlib from testify import TestCase, assert_equal, setup, run from tests.assertions import assert_raises from tests.testingutils import autospec_method -from tron.config import schema from tron.serialize import runstate -from tron.serialize.runstate.shelvestore import ShelveStateStore from tron.serialize.runstate.statemanager import PersistentStateManager, StateChangeWatcher from tron.serialize.runstate.statemanager import StateSaveBuffer from tron.serialize.runstate.statemanager import StateMetadata from tron.serialize.runstate.statemanager import PersistenceStoreError from tron.serialize.runstate.statemanager import VersionMismatchError -from tron.serialize.runstate.statemanager import PersistenceManagerFactory - - -class PersistenceManagerFactoryTestCase(TestCase): - - def test_from_config_shelve(self): - thefilename = 'thefilename' - config = schema.ConfigState( - store_type='shelve', name=thefilename, buffer_size=0, - connection_details=None) - manager = PersistenceManagerFactory.from_config(config) - store = manager._impl - assert_equal(store.filename, config.name) - assert isinstance(store, ShelveStateStore) - os.unlink(thefilename) class StateMetadataTestCase(TestCase): @@ -73,13 +56,19 @@ class PersistentStateManagerTestCase(TestCase): @setup def setup_manager(self): - self.store = mock.Mock() - self.store.build_key.side_effect = lambda t, i: '%s%s' % (t, i) - self.buffer = StateSaveBuffer(1) - self.manager = PersistentStateManager(self.store, self.buffer) + with mock.patch('tron.serialize.runstate.statemanager.ParallelStore', autospec=True) \ + as self.store_patch: + 
self.store = self.store_patch.return_value + self.build_patch = mock.Mock(side_effect=lambda t, i: '%s%s' % (t, i)) + self.store_patch.return_value.configure_mock(build_key=self.build_patch) + self.buffer = StateSaveBuffer(1) + self.manager = PersistentStateManager() + self.manager._buffer = self.buffer def test__init__(self): - assert_equal(self.manager._impl, self.store) + self.store_patch.assert_called_once_with() + self.build_patch.assert_called_once_with(runstate.MCP_STATE, StateMetadata.name) + assert_equal(self.manager.metadata_key, self.manager._impl.build_key(runstate.MCP_STATE, StateMetadata.name)) def test_keys_for_items(self): names = ['namea', 'nameb'] @@ -137,14 +126,43 @@ def test_disabled_nested(self): pass assert not self.manager.enabled + def test_update_config_success(self): + new_config = mock.Mock(buffer_size=5) + self.store.load_config.configure_mock(return_value=True) + with contextlib.nested( + mock.patch.object(self.manager, '_save_from_buffer'), + mock.patch('tron.serialize.runstate.statemanager.StateSaveBuffer', autospec=True) + ) as (save_patch, buffer_patch): + assert_equal(self.manager.update_from_config(new_config), True) + save_patch.assert_called_once_with() + self.store.load_config.assert_called_once_with(new_config) + buffer_patch.assert_called_once_with(new_config.buffer_size) + + def test_update_config_failure(self): + new_config = mock.Mock(buffer_size=5) + self.store.load_config.configure_mock(return_value=False) + with contextlib.nested( + mock.patch.object(self.manager, '_save_from_buffer'), + mock.patch('tron.serialize.runstate.statemanager.StateSaveBuffer', autospec=True) + ) as (save_patch, buffer_patch): + assert_equal(self.manager.update_from_config(new_config), False) + save_patch.assert_called_once_with() + self.store.load_config.assert_called_once_with(new_config) + assert not buffer_patch.called + class StateChangeWatcherTestCase(TestCase): @setup def setup_watcher(self): - self.watcher = StateChangeWatcher() - 
self.state_manager = mock.create_autospec(PersistentStateManager) - self.watcher.state_manager = self.state_manager + with mock.patch('tron.serialize.runstate.statemanager.PersistentStateManager', autospec=True) \ + as self.persistence_patch: + self.watcher = StateChangeWatcher() + self.state_manager = mock.create_autospec(PersistentStateManager) + self.watcher.state_manager = self.state_manager + + def test__init__(self): + self.persistence_patch.assert_called_once_with() def test_update_from_config_no_change(self): self.watcher.config = state_config = mock.Mock() @@ -153,17 +171,39 @@ def test_update_from_config_no_change(self): assert_equal(self.watcher.state_manager, self.state_manager) assert not self.watcher.shutdown.mock_calls - @mock.patch('tron.serialize.runstate.statemanager.PersistenceManagerFactory', - autospec=True) - def test_update_from_config_changed(self, mock_factory): - state_config = mock.Mock() - autospec_method(self.watcher.shutdown) + def test_update_from_config_success(self): + state_config = mock.Mock(store_type="shelve") assert self.watcher.update_from_config(state_config) assert_equal(self.watcher.config, state_config) - self.watcher.shutdown.assert_called_with() - assert_equal(self.watcher.state_manager, - mock_factory.from_config.return_value) - mock_factory.from_config.assert_called_with(state_config) + self.state_manager.update_from_config.assert_called_once_with(state_config) + + def test_update_from_config_failure_same_config(self): + state_config = self.watcher.config + assert not self.watcher.update_from_config(state_config) + assert_equal(self.watcher.config, state_config) + assert not self.state_manager.update_from_config.called + + def test_update_from_config_failure_from_state_manager(self): + self.state_manager.update_from_config.configure_mock(return_value=False) + state_config = self.watcher.config + fake_config = mock.Mock(store_type="shelve") + assert not self.watcher.update_from_config(fake_config) + 
assert_equal(self.watcher.config, state_config) + self.state_manager.update_from_config.assert_called_once_with(fake_config) + + def test_update_from_config_failure_bad_store_type(self): + state_config = self.watcher.config + fake_config = mock.Mock(store_type="hue_hue_hue") + assert_raises(PersistenceStoreError, self.watcher.update_from_config, fake_config) + assert_equal(self.watcher.config, state_config) + assert not self.state_manager.update_from_config.called + + def test_update_from_config_failure_bad_db_type(self): + state_config = self.watcher.config + fake_config = mock.Mock(store_type="sql", store_method="make_it_rain") + assert_raises(PersistenceStoreError, self.watcher.update_from_config, fake_config) + assert_equal(self.watcher.config, state_config) + assert not self.state_manager.update_from_config.called def test_save_job(self): mock_job = mock.Mock() diff --git a/tests/serialize/runstate/tronstore/__init__.py b/tests/serialize/runstate/tronstore/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/serialize/runstate/tronstore/parallelstore_test.py b/tests/serialize/runstate/tronstore/parallelstore_test.py new file mode 100644 index 000000000..05cf6091f --- /dev/null +++ b/tests/serialize/runstate/tronstore/parallelstore_test.py @@ -0,0 +1,111 @@ +import contextlib +import mock +from testify import TestCase, run, setup_teardown, assert_equal +from tron.serialize import runstate +from tron.serialize.runstate.tronstore.parallelstore import ParallelStore, ParallelKey +from tron.serialize.runstate.tronstore import msg_enums + + +class ParallelStoreTestCase(TestCase): + + @setup_teardown + def setup_store(self): + self.config = mock.Mock( + name='test_config', + transport_method='pickle', + store_type='shelve', + connection_details=None, + db_store_method=None, + buffer_size=1 + ) + with mock.patch('tron.serialize.runstate.tronstore.parallelstore.StoreProcessProtocol', autospec=True) \ + as (self.process_patch): + self.store = 
ParallelStore() + yield + + def test__init__(self): + self.process_patch.assert_called_once_with() + assert self.store.request_factory + + def test_build_key(self): + key_type = runstate.JOB_STATE + key_name = 'the_fun_ends_here' + assert_equal(self.store.build_key(key_type, key_name), ParallelKey(key_type, key_name)) + + def test_save(self): + key_value_pairs = [ + (self.store.build_key(runstate.JOB_STATE, 'riki_the_pubstar'), + {'butterfly': 'time_to_buy_mkb'}), + (self.store.build_key(runstate.JOB_STATE, 'you_died_30_seconds_in'), + {'it_was_lag': 'i_swear'}) + ] + with mock.patch.object(self.store.request_factory, 'build') as build_patch: + self.store.save(key_value_pairs) + for key, state_data in key_value_pairs: + build_patch.assert_any_call(msg_enums.REQUEST_SAVE, key.type, (key.key, state_data)) + assert self.store.process.send_request.called + + def test_restore_single_success(self): + key = self.store.build_key(runstate.JOB_STATE, 'zeus_ult') + fake_response = mock.Mock(data=10, success=True) + with contextlib.nested( + mock.patch.object(self.store.request_factory, 'build'), + mock.patch.object(self.store.process, 'send_request_get_response', return_value=fake_response), + ) as (build_patch, send_patch): + assert_equal(self.store.restore_single(key), fake_response.data) + build_patch.assert_called_once_with(msg_enums.REQUEST_RESTORE, key.type, key.key) + assert send_patch.called + + def test_restore_single_failure(self): + key = self.store.build_key(runstate.JOB_STATE, 'rip_ryan_davis') + fake_response = mock.Mock(data=777, success=False) + with contextlib.nested( + mock.patch.object(self.store.request_factory, 'build'), + mock.patch.object(self.store.process, 'send_request_get_response', return_value=fake_response), + ) as (build_patch, send_patch): + assert not self.store.restore_single(key) + build_patch.assert_called_once_with(msg_enums.REQUEST_RESTORE, key.type, key.key) + assert send_patch.called + + def test_restore(self): + keys = 
[self.store.build_key(runstate.JOB_STATE, 'true_steel'), + self.store.build_key(runstate.JOB_STATE, 'the_test')] + fake_response = mock.Mock() + response_dict = dict((key, fake_response) for key in keys) + with mock.patch.object(self.store, 'restore_single', return_value=fake_response) as restore_patch: + assert_equal(self.store.restore(keys), response_dict) + for key in keys: + restore_patch.assert_any_call(key) + + def test_cleanup(self): + with mock.patch.object(self.store, 'cleanup') as clean_patch: + self.store.cleanup() + clean_patch.assert_called_once_with() + + def test_load_config_success(self): + new_config = mock.Mock() + config_req = mock.Mock() + with contextlib.nested( + mock.patch.object(self.store.process, 'update_config'), + mock.patch.object(self.store.request_factory, 'build', return_value=config_req), + mock.patch.object(self.store.process, 'send_request_get_response', + return_value=mock.Mock(success=True)) + ) as (update_patch, build_patch, send_patch): + self.store.load_config(new_config) + build_patch.assert_called_once_with(msg_enums.REQUEST_CONFIG, '', new_config) + send_patch.assert_called_once_with(config_req) + update_patch.assert_called_once_with(new_config) + + def test_load_config_failure(self): + new_config = mock.Mock() + config_req = mock.Mock() + with contextlib.nested( + mock.patch.object(self.store.process, 'update_config'), + mock.patch.object(self.store.request_factory, 'build', return_value=config_req), + mock.patch.object(self.store.process, 'send_request_get_response', + return_value=mock.Mock(success=False)) + ) as (update_patch, build_patch, send_patch): + self.store.load_config(new_config) + build_patch.assert_called_once_with(msg_enums.REQUEST_CONFIG, '', new_config) + send_patch.assert_called_once_with(config_req) + assert not update_patch.called diff --git a/tests/serialize/runstate/tronstore/process_test.py b/tests/serialize/runstate/tronstore/process_test.py new file mode 100644 index 000000000..503d6913e --- 
/dev/null +++ b/tests/serialize/runstate/tronstore/process_test.py @@ -0,0 +1,254 @@ +import contextlib +import mock +import signal +import os +from testify import TestCase, assert_equal, assert_raises, setup_teardown +from tron.serialize.runstate.tronstore import tronstore +from tron.serialize.runstate.tronstore.process import StoreProcessProtocol, TronStoreError + +class StoreProcessProtocolTestCase(TestCase): + + @setup_teardown + def setup_process(self): + self.test_pipe_a = mock.Mock() + self.test_pipe_b = mock.Mock() + pipe_return = mock.Mock(return_value=(self.test_pipe_a, self.test_pipe_b)) + with contextlib.nested( + mock.patch('tron.serialize.runstate.tronstore.process.Process', + autospec=True), + mock.patch('tron.serialize.runstate.tronstore.process.Pipe', + new=pipe_return), + mock.patch('tron.serialize.runstate.tronstore.process.StoreResponseFactory') + ) as (self.process_patch, self.pipe_setup_patch, self.factory_patch): + self.factory = self.factory_patch.return_value + self.process = StoreProcessProtocol() + yield + + def test__init__(self): + assert not self.process.config + self.factory_patch.assert_called_once_with() + assert_equal(self.process.orphaned_responses, {}) + assert not self.process.is_shutdown + assert self.process.pipe + + def test_start_process(self): + self.pipe_setup_patch.assert_called_once_with() + self.process_patch.assert_called_once_with(target=tronstore.main, args=(self.process.config, self.test_pipe_b)) + assert self.process_patch.daemon + self.process.process.start.assert_called_once_with() + + def test_verify_is_alive_while_dead(self): + with contextlib.nested( + mock.patch.object(self.process.process, 'is_alive', return_value=False), + mock.patch.object(self.process, '_start_process'), + ) as (alive_patch, start_patch): + assert_raises(TronStoreError, self.process._verify_is_alive) + alive_patch.assert_called_with() + assert_equal(alive_patch.call_count, 2) + start_patch.assert_called_once_with() + + def 
test_verify_is_alive_while_alive(self): + with contextlib.nested( + mock.patch.object(self.process.process, 'is_alive', return_value=True), + mock.patch.object(self.process, '_start_process'), + ) as (alive_patch, start_patch): + self.process._verify_is_alive() + alive_patch.assert_called_once_with() + assert not start_patch.called + + def test_send_request_running(self): + self.process.is_shutdown = False + fake_id = 77 + test_request = mock.Mock(serialized='sunny_sausalito', id=fake_id) + with contextlib.nested( + mock.patch.object(self.process, '_verify_is_alive'), + mock.patch.object(self.process.pipe, 'send_bytes') + ) as (verify_patch, pipe_patch): + self.process.send_request(test_request) + verify_patch.assert_called_once_with() + pipe_patch.assert_called_once_with(test_request.serialized) + + def test_send_request_shutdown(self): + self.process.is_shutdown = True + fake_id = 77 + test_request = mock.Mock(serialized='whiskey_media', id=fake_id) + with contextlib.nested( + mock.patch.object(self.process, '_verify_is_alive'), + mock.patch.object(self.process.pipe, 'send_bytes') + ) as (verify_patch, pipe_patch): + self.process.send_request(test_request) + assert not verify_patch.called + assert not pipe_patch.called + + def test_send_request_get_response_running_with_response(self): + self.process.is_shutdown = False + fake_id = 77 + test_request = mock.Mock(serialized='objection', id=fake_id) + test_response = mock.Mock(id=fake_id, data='overruled', success=True) + with contextlib.nested( + mock.patch.object(self.process, '_verify_is_alive'), + mock.patch.object(self.process.pipe, 'send_bytes'), + mock.patch.object(self.process, '_poll_for_response', return_value=test_response) + ) as (verify_patch, pipe_patch, poll_patch): + assert_equal(self.process.send_request_get_response(test_request), test_response) + verify_patch.assert_called_once_with() + pipe_patch.assert_called_once_with(test_request.serialized) + poll_patch.assert_called_once_with(fake_id, 
self.process.POLL_TIMEOUT) + + def test_send_request_get_response_running_no_response(self): + self.process.is_shutdown = False + fake_id = 77 + test_request = mock.Mock(serialized='maaaaaagiiiiccc', id=fake_id) + with contextlib.nested( + mock.patch.object(self.process, '_verify_is_alive'), + mock.patch.object(self.process.pipe, 'send_bytes'), + mock.patch.object(self.process, '_poll_for_response', return_value=None) + ) as (verify_patch, pipe_patch, poll_patch): + assert_equal(self.process.send_request_get_response(test_request), + self.process.response_factory.build(False, fake_id, '')) + verify_patch.assert_called_once_with() + pipe_patch.assert_called_once_with(test_request.serialized) + poll_patch.assert_called_once_with(fake_id, self.process.POLL_TIMEOUT) + + def test_send_request_get_response_shutdown(self): + self.process.is_shutdown = True + fake_id = 77 + test_request = mock.Mock(serialized='i_wish_for_the_nile', id=fake_id) + test_response = mock.Mock(id=fake_id, data='no_way', success=True) + with contextlib.nested( + mock.patch.object(self.process, '_verify_is_alive'), + mock.patch.object(self.process.pipe, 'send_bytes'), + mock.patch.object(self.process, '_poll_for_response', return_value=test_response) + ) as (verify_patch, pipe_patch, poll_patch): + assert_equal(self.process.send_request_get_response(test_request), + self.process.response_factory.build(False, fake_id, '')) + assert not verify_patch.called + assert not pipe_patch.called + assert not poll_patch.called + + def test_send_request_shutdown_not_shutdown(self): + self.process.is_shutdown = False + fake_id = 77 + test_request = mock.Mock(serialized='ghost_truck', id=fake_id) + with contextlib.nested( + mock.patch.object(self.process.process, 'is_alive', return_value=True), + mock.patch.object(self.process.pipe, 'close'), + mock.patch.object(self.process.pipe, 'send_bytes'), + mock.patch.object(self.process, '_poll_for_response', return_value=mock.Mock()), + mock.patch.object(os, 'kill') + ) 
as (alive_patch, close_patch, send_patch, poll_patch, kill_patch): + self.process.send_request_shutdown(test_request) + alive_patch.assert_called_once_with() + assert self.process.is_shutdown + send_patch.assert_called_once_with(test_request.serialized) + poll_patch.assert_called_once_with(fake_id, self.process.SHUTDOWN_TIMEOUT) + close_patch.assert_called_once_with() + kill_patch.assert_called_once_with(self.process.process.pid, signal.SIGKILL) + + def test_send_request_shutdown_is_shutdown(self): + self.process.is_shutdown = True + fake_id = 77 + test_request = mock.Mock(serialized='thats_million_bucks', id=fake_id) + with contextlib.nested( + mock.patch.object(self.process.process, 'is_alive', return_value=True), + mock.patch.object(self.process.pipe, 'close'), + mock.patch.object(self.process.pipe, 'send_bytes'), + mock.patch.object(self.process, '_poll_for_response', return_value=mock.Mock()), + mock.patch.object(self.process.process, 'terminate') + ) as (alive_patch, close_patch, send_patch, poll_patch, terminate_patch): + self.process.send_request_shutdown(test_request) + assert not alive_patch.called # should have short circuited + close_patch.assert_called_once_with() + assert self.process.is_shutdown + assert not send_patch.called + assert not poll_patch.called + assert not terminate_patch.called + + def test_send_request_shutdown_not_shutdown_but_dead(self): + self.process.is_shutdown = False + fake_id = 77 + test_request = mock.Mock(serialized='thats_million_bucks', id=fake_id) + with contextlib.nested( + mock.patch.object(self.process.process, 'is_alive', return_value=False), + mock.patch.object(self.process.pipe, 'close'), + mock.patch.object(self.process.pipe, 'send_bytes'), + mock.patch.object(self.process, '_poll_for_response', return_value=mock.Mock()), + mock.patch.object(self.process.process, 'terminate') + ) as (alive_patch, close_patch, send_patch, poll_patch, terminate_patch): + self.process.send_request_shutdown(test_request) + 
alive_patch.assert_called_once_with() + close_patch.assert_called_once_with() + assert self.process.is_shutdown + assert not send_patch.called + assert not poll_patch.called + assert not terminate_patch.called + + def test_update_config(self): + new_config = mock.Mock() + self.process.update_config(new_config) + assert_equal(self.process.config, new_config) + + def test_poll_for_response_has_response_makes_orphaned(self): + self.process.orphaned_responses = {} + fake_id = 77 + fake_timeout = 0.05 + fake_response_serial = ['first'] + fake_response_matching = mock.Mock(serialized=fake_response_serial, id=fake_id) + fake_id_other = 96943 + fake_response_other = mock.Mock(serialized='oliver', id=fake_id_other) + + def recv_change_response(): + ret = fake_response_serial[0] + fake_response_serial[0] = 'second' + return ret + + def get_fake_response(fake_response_serial): + if fake_response_serial == 'first': + return fake_response_other + else: + return fake_response_matching + + with contextlib.nested( + mock.patch.object(self.process.pipe, 'poll', return_value=True), + mock.patch.object(self.process.pipe, 'recv_bytes', side_effect=recv_change_response), + mock.patch.object(self.process.response_factory, 'from_msg', side_effect=get_fake_response) + ) as (poll_patch, recv_patch, from_msg_patch): + assert_equal(self.process._poll_for_response(fake_id, fake_timeout), fake_response_matching) + assert_equal(self.process.orphaned_responses, {fake_id_other: fake_response_other}) + poll_patch.assert_called_with(fake_timeout) + assert_equal(poll_patch.call_count, 2) + recv_patch.assert_called_with() + assert_equal(recv_patch.call_count, 2) + from_msg_patch.assert_called_with('second') + assert_equal(from_msg_patch.call_count, 2) + + def test_poll_for_response_has_orphaned(self): + fake_id = 77 + fake_timeout = 0.05 + fake_response = mock.Mock(serialized='wherein_there_is_dotes', id=fake_id) + self.process.orphaned_responses = {fake_id: fake_response} + with contextlib.nested( + 
mock.patch.object(self.process.pipe, 'poll', return_value=True), + mock.patch.object(self.process.pipe, 'recv_bytes'), + mock.patch.object(self.process.response_factory, 'from_msg') + ) as (poll_patch, recv_patch, from_msg_patch): + assert_equal(self.process._poll_for_response(fake_id, fake_timeout), fake_response) + assert_equal(self.process.orphaned_responses, {}) + assert not poll_patch.called + assert not recv_patch.called + assert not from_msg_patch.called + + def test_poll_for_response_no_response(self): + fake_id = 77 + fake_timeout = 0.05 + self.process.orphaned_responses = {} + with contextlib.nested( + mock.patch.object(self.process.pipe, 'poll', return_value=False), + mock.patch.object(self.process.pipe, 'recv_bytes'), + mock.patch.object(self.process.response_factory, 'from_msg') + ) as (poll_patch, recv_patch, from_msg_patch): + assert_equal(self.process._poll_for_response(fake_id, fake_timeout), None) + assert_equal(self.process.orphaned_responses, {}) + poll_patch.assert_called_once_with(fake_timeout) + assert not recv_patch.called + assert not from_msg_patch.called diff --git a/tests/serialize/runstate/tronstore/store_test.py b/tests/serialize/runstate/tronstore/store_test.py new file mode 100644 index 000000000..1bc8ce567 --- /dev/null +++ b/tests/serialize/runstate/tronstore/store_test.py @@ -0,0 +1,301 @@ +import os +import shelve +import tempfile +import mock +import contextlib +from testify import TestCase, run, setup, assert_equal, teardown +from tron.serialize.runstate.tronstore.store import ShelveStore, SQLStore, MongoStore, YamlStore, SyncStore, NullStore +from tron.serialize.runstate.tronstore.serialize import cPickleSerializer +from tron.serialize import runstate + + +class ShelveStoreTestCase(TestCase): + + @setup + def setup_store(self): + self.filename = os.path.join(tempfile.gettempdir(), 'tmp_shelve.state') + self.store = ShelveStore(self.filename, None, None) + + @teardown + def teardown_store(self): + self.store.cleanup() + 
os.unlink(self.filename) + + def test__init__(self): + assert_equal(self.filename, self.store.fname) + + def test_save(self): + data_type = runstate.JOB_STATE + key_value_pairs = [ + ("one", {'some': 'data'}), + ("two", {'its': 'fake'}) + ] + for key, value in key_value_pairs: + self.store.save(key, value, data_type) + self.store.cleanup() + stored = shelve.open(self.filename) + for key, value in key_value_pairs: + assert_equal(stored['(%s__%s)' % (data_type, key)], value) + + def test_restore_success(self): + data_type = runstate.JOB_STATE + keys = ["three", "four"] + value = {'some': 'data'} + store = shelve.open(self.filename) + for key in keys: + store['(%s__%s)' % (data_type, key)] = value + store.close() + + for key in keys: + assert_equal((True, value), self.store.restore(key, data_type)) + + def test_restore_failure(self): + keys = ["nope", "theyre not there"] + for key in keys: + assert_equal((False, None), self.store.restore(key, 'data_type')) + + +class SQLStoreTestCase(TestCase): + + @setup + def setup_store(self): + details = 'sqlite:///:memory:' + self.store = SQLStore('name', details, cPickleSerializer) + + @teardown + def teardown_store(self): + self.store.cleanup() + + def test_create_engine(self): + assert_equal(self.store.engine.url.database, ':memory:') + + def test_create_tables(self): + assert self.store.job_state_table.name + assert self.store.job_run_table.name + assert self.store.service_table.name + assert self.store.metadata_table.name + + def test_save(self): + data_type = runstate.SERVICE_STATE + key = 'dotes' + state_data = {'the_true_victim_is': 'roshan'} + self.store.save(key, state_data, data_type) + + rows = self.store.engine.execute(self.store.service_table.select()) + assert_equal(rows.fetchone(), ('dotes', self.store.serializer.serialize(state_data), 'pickle')) + + def test_restore_success(self): + data_type = runstate.JOB_STATE + key = '20minbf' + state_data = {'ogre_magi': 'pure_skill'} + + self.store.save(key, state_data, 
data_type) + assert_equal((True, state_data), self.store.restore(key, data_type)) + + def test_restore_failure(self): + data_type = runstate.JOB_RUN_STATE + key = 'someone_get_gem' + + assert_equal((False, None), self.store.restore(key, data_type)) + + +class MongoStoreTestCase(TestCase): + + store = None + + @setup + def setup_store(self): + import mock + self.db_name = 'test_base' + details = "hostname=localhost&port=5555" + with mock.patch('pymongo.Connection', autospec=True): + self.store = MongoStore(self.db_name, details, None) + + # Since we mocked the pymongo connection, a teardown isn't needed. + + def _create_doc(self, key, doc, data_type): + import pymongo + db = pymongo.Connection()[self.db_name] + doc['_id'] = key + db[self.store.TYPE_TO_COLLECTION_MAP[data_type]].save(doc) + db.connection.disconnect() + + def test__init__(self): + assert_equal(self.store.db_name, self.db_name) + + def test_parse_connection_details(self): + details = "hostname=mongoserver&port=55555" + params = self.store._parse_connection_details(details) + assert_equal(params, {'hostname': 'mongoserver', 'port': '55555'}) + + def test_parse_connection_details_with_user_creds(self): + details = "hostname=mongoserver&port=55555&username=ted&password=sam" + params = self.store._parse_connection_details(details) + expected = { + 'hostname': 'mongoserver', + 'port': '55555', + 'username': 'ted', + 'password': 'sam'} + assert_equal(params, expected) + + def test_parse_connection_details_none(self): + params = self.store._parse_connection_details(None) + assert_equal(params, {}) + + def test_parse_connection_details_empty(self): + params = self.store._parse_connection_details("") + assert_equal(params, {}) + + def test_save(self): + import mock + collection = mock.Mock() + key = 'gotta_have_that_dotes' + state_data = {'skywrath_mage': 'more_like_early_game_page'} + data_type = runstate.JOB_STATE + with mock.patch.object(self.store, 'db', + new={self.store.TYPE_TO_COLLECTION_MAP[data_type]: 
+ collection} + ): + self.store.save(key, state_data, data_type) + state_data['_id'] = key + collection.save.assert_called_once_with(state_data) + + def test_restore_success(self): + import mock + key = 'stop_feeding' + state_data = {'0_and_7': 'only_10_minutes_in'} + data_type = runstate.JOB_RUN_STATE + collection = mock.Mock() + collection.find_one = mock.Mock(return_value=state_data) + with mock.patch.object(self.store, 'db', + new={self.store.TYPE_TO_COLLECTION_MAP[data_type]: + collection} + ): + assert_equal(self.store.restore(key, data_type), (True, state_data)) + collection.find_one.assert_called_once_with(key) + + def test_restore_failure(self): + import mock + key = 'gg_team_fed' + data_type = runstate.SERVICE_STATE + collection = mock.Mock() + collection.find_one = mock.Mock(return_value=None) + with mock.patch.object(self.store, 'db', + new={self.store.TYPE_TO_COLLECTION_MAP[data_type]: + collection} + ): + assert_equal(self.store.restore(key, data_type), (False, None)) + collection.find_one.assert_called_once_with(key) + + +class YamlStoreTestCase(TestCase): + + @setup + def setup_store(self): + self.filename = os.path.join(tempfile.gettempdir(), 'yaml_state') + self.store = YamlStore(self.filename, None, None) + self.test_data = { + self.store.TYPE_MAPPING[runstate.JOB_STATE]: {'a': 1}, + self.store.TYPE_MAPPING[runstate.JOB_RUN_STATE]: {'b': 2}, + self.store.TYPE_MAPPING[runstate.SERVICE_STATE]: {'c': 3} + } + + @teardown + def teardown_store(self): + try: + os.unlink(self.filename) + except OSError: + pass + + def test_restore_success(self): + import yaml + with open(self.filename, 'w') as fh: + yaml.dump(self.test_data, fh) + self.store = YamlStore(self.filename, None, None) + + data_types = [runstate.JOB_STATE, runstate.JOB_RUN_STATE, runstate.SERVICE_STATE] + for data_type in data_types: + for key in self.test_data[self.store.TYPE_MAPPING[data_type]].keys(): + success, value = self.store.restore(key, data_type) + assert success + 
assert_equal(self.test_data[self.store.TYPE_MAPPING[data_type]][key], value) + + def test_restore_failure(self): + assert_equal(self.store.restore('gg_stick_pro_build', runstate.JOB_STATE), (False, None)) + + def test_save(self): + import yaml + job_data = {'euls_on_sk': 'sounds_legit'} + run_data = {'phantom_cancer': 'needs_diffusal_level_2'} + service_data = {'everyone_go_dagon': 'hey_look_we_won'} + expected = { + self.store.TYPE_MAPPING[runstate.JOB_STATE]: job_data, + self.store.TYPE_MAPPING[runstate.JOB_RUN_STATE]: run_data, + self.store.TYPE_MAPPING[runstate.SERVICE_STATE]: service_data, + } + self.store.save(job_data.keys()[0], job_data.values()[0], runstate.JOB_STATE) + self.store.save(run_data.keys()[0], run_data.values()[0], runstate.JOB_RUN_STATE) + self.store.save(service_data.keys()[0], service_data.values()[0], runstate.SERVICE_STATE) + + assert_equal(self.store.buffer, expected) + with open(self.filename, 'r') as fh: + actual = yaml.load(fh) + assert_equal(actual, expected) + + +class SyncStoreTestCase(TestCase): + + @setup + def setup_sync_store(self): + self.fake_config = mock.Mock( + name='we_must_be_swift_as_a_coursing_river', + store_type='with_all_the_force_of_a_great_typhoon', + connection_details='with_all_the_strength_of_a_raging_fire', + db_store_method='mysterious_as_the_dark_side_of_the_moon') + self.store_class = mock.Mock() + with contextlib.nested( + mock.patch.object(runstate.tronstore.store, 'build_store', return_value=self.store_class), + mock.patch('tron.serialize.runstate.tronstore.store.Lock', autospec=True) + ) as (self.build_patch, self.lock_patch): + self.store = SyncStore(self.fake_config) + self.lock = self.lock_patch.return_value + + def test__init__(self): + self.lock_patch.assert_called_once_with() + self.build_patch.assert_called_once_with( + self.fake_config.name, + self.fake_config.store_type, + self.fake_config.connection_details, + self.fake_config.db_store_method + ) + assert_equal(self.store_class, 
self.store.store) + + def test__init__null_config(self): + store = SyncStore(None) + assert isinstance(store.store, NullStore) + + def test_save(self): + fake_arg = 'catch_a_ride' + fake_kwarg = 'no_refunds' + self.store.save(fake_arg, fake_kwarg=fake_kwarg) + self.lock.__enter__.assert_called_once_with() + self.lock.__exit__.assert_called_once_with(None, None, None) + self.store_class.save.assert_called_once_with(fake_arg, fake_kwarg=fake_kwarg) + + def test_restore(self): + fake_arg = 'catch_a_ride' + fake_kwarg = 'no_refunds' + self.store.restore(fake_arg, fake_kwarg=fake_kwarg) + self.lock.__enter__.assert_called_once_with() + self.lock.__exit__.assert_called_once_with(None, None, None) + self.store_class.restore.assert_called_once_with(fake_arg, fake_kwarg=fake_kwarg) + + def test_cleanup(self): + self.store.cleanup() + self.lock.__enter__.assert_called_once_with() + self.lock.__exit__.assert_called_once_with(None, None, None) + self.store_class.cleanup.assert_called_once_with() + + +if __name__ == "__main__": + run() diff --git a/tests/serialize/runstate/tronstore/tronstore_test.py b/tests/serialize/runstate/tronstore/tronstore_test.py new file mode 100644 index 000000000..494687cd8 --- /dev/null +++ b/tests/serialize/runstate/tronstore/tronstore_test.py @@ -0,0 +1,421 @@ +import contextlib +import mock +import signal +from Queue import Queue, Empty +from testify import TestCase, assert_equal, assert_raises, setup_teardown, setup, run + +from tron.serialize.runstate.tronstore import tronstore, msg_enums + +class TronstoreMainTestCase(TestCase): + + @setup_teardown + def setup_main(self): + self.config = mock.Mock() + self.pipe = mock.Mock() + self.store_class = mock.Mock() + self.thread_pool = mock.Mock() + self.request_factory = mock.Mock() + self.response_factory = mock.Mock() + + def echo_single_request(request): + return request + self.request_factory.from_msg = echo_single_request + + def echo_requests(not_used): + return self.requests + + def 
raise_to_exit(exitcode): + raise SystemError + + with contextlib.nested( + mock.patch('tron.serialize.runstate.tronstore.store.SyncStore', + new=mock.Mock(return_value=self.store_class)), + mock.patch('tron.serialize.runstate.tronstore.tronstore.TronstorePool', + new=mock.Mock(return_value=self.thread_pool)), + mock.patch('tron.serialize.runstate.tronstore.tronstore.SyncPipe', + new=mock.Mock(return_value=self.pipe)), + mock.patch('tron.serialize.runstate.tronstore.tronstore.StoreRequestFactory', + new=mock.Mock(return_value=self.request_factory)), + mock.patch('tron.serialize.runstate.tronstore.tronstore.StoreResponseFactory', + new=mock.Mock(return_value=self.response_factory)), + mock.patch.object(tronstore.os, '_exit', + autospec=True) + ) as ( + self.store_patch, + self.thread_patch, + self.pipe_patch, + self.request_patch, + self.response_patch, + self.exit_patch + ): + self.main = tronstore.TronstoreMain(self.config, self.pipe) + yield + + def test__init__(self): + self.store_patch.assert_called_once_with(self.config) + self.request_patch.assert_called_once_with() + self.pipe_patch.assert_called_once_with(self.pipe) + self.response_patch.assert_called_once_with() + self.thread_patch.assert_called_once_with(self.response_factory, self.pipe, self.store_class) + assert_equal(self.main.config, self.config) + assert not self.main.is_shutdown + assert not self.main.shutdown_req_id + + def test_get_all_from_pipe(self): + fake_data = 'fuego' + self.pipe.recv_bytes = mock.Mock(return_value=fake_data) + self.pipe.poll = mock.Mock(side_effect=iter([True, False])) + assert_equal(self.main._get_all_from_pipe(), [fake_data]) + self.pipe.recv_bytes.assert_called_once_with() + assert_equal(self.pipe.poll.call_count, 2) + + def test_reconfigure_success(self): + fake_id = 77 + fake_data = mock.Mock() + request = mock.Mock(req_type=msg_enums.REQUEST_CONFIG, id=fake_id, data=fake_data) + + self.main._reconfigure(request) + self.thread_pool.stop.assert_called_once_with() + 
self.thread_pool.start.assert_called_once_with() + self.store_class.cleanup.assert_called_once_with() + self.store_patch.assert_any_call(fake_data) + self.thread_patch.assert_any_call(self.response_factory, self.pipe, self.store_class) + assert_equal(self.thread_patch.call_count, 2) + assert_equal(self.main.config, fake_data) + self.response_factory.build.assert_called_once_with(True, fake_id, '') + self.pipe.send_bytes.assert_called_once_with(self.response_factory.build().serialized) + + def test_reconfigure_failure(self): + fake_id = 77 + fake_data = mock.Mock() + request = mock.Mock(req_type=msg_enums.REQUEST_CONFIG, id=fake_id, data=fake_data) + self.store_patch.configure_mock(side_effect=iter([SystemError, lambda x: None])) + + self.main._reconfigure(request) + assert_equal(self.store_patch.call_count, 3) + self.store_patch.assert_any_call(fake_data) + self.store_patch.assert_any_call(self.config) + self.thread_patch.assert_any_call(self.response_factory, self.pipe, + self.store_class) + assert_equal(self.thread_patch.call_count, 2) + self.thread_pool.stop.assert_called_once_with() + self.store_class.cleanup.assert_called_once_with() + self.thread_pool.start.assert_called_once_with() + self.response_factory.build.assert_called_once_with(False, fake_id, '') + self.pipe.send_bytes.assert_called_once_with(self.response_factory.build().serialized) + + def test_handle_request_save(self): + fake_id = 77 + request = mock.Mock(req_type=msg_enums.REQUEST_SAVE, id=fake_id) + self.main._handle_request(request) + self.thread_pool.enqueue_work.assert_called_once_with(request) + + def test_handle_request_restore(self): + fake_id = 77 + request = mock.Mock(req_type=msg_enums.REQUEST_RESTORE, id=fake_id) + self.main._handle_request(request) + self.thread_pool.enqueue_work.assert_called_once_with(request) + + def test_handle_request_shutdown(self): + fake_id = 77 + request = mock.Mock(req_type=msg_enums.REQUEST_SHUTDOWN, id=fake_id) + self.main._handle_request(request) + 
assert self.main.is_shutdown + assert_equal(fake_id, self.main.shutdown_req_id) + + def test_handle_request_config(self): + fake_id = 77 + request = mock.Mock(req_type=msg_enums.REQUEST_CONFIG, id=fake_id) + with mock.patch.object(self.main, '_reconfigure') as reconf_patch: + self.main._handle_request(request) + reconf_patch.assert_called_once_with(request) + + def test_shutdown_has_id(self): + fake_id = 77 + self.main.shutdown_req_id = fake_id + self.main._shutdown() + self.thread_pool.stop.assert_called_once_with() + self.store_class.cleanup.assert_called_once_with() + self.response_factory.build.assert_called_once_with(True, fake_id, '') + self.pipe.send_bytes.assert_called_once_with(self.response_factory.build().serialized) + self.exit_patch.assert_called_once_with(0) + + def test_shutdown_no_id(self): + self.main.shutdown_req_id = None + self.main._shutdown() + self.thread_pool.stop.assert_called_once_with() + self.store_class.cleanup.assert_called_once_with() + assert not self.response_factory.build.called + assert not self.pipe.send_bytes.called + self.exit_patch.assert_called_once_with(0) + + def test_main_loop_handle_requests(self): + self.main.is_shutdown = False + self.pipe.poll = mock.Mock(return_value=True) + requests = [mock.Mock(), mock.Mock()] + self.request_factory.from_msg = mock.Mock(side_effect=lambda x: x) + with contextlib.nested( + mock.patch.object(self.main, '_get_all_from_pipe', return_value=requests), + mock.patch.object(self.main, '_handle_request', + side_effect=iter([None, SystemError])) + ) as (all_patch, handle_patch): + assert_raises(SystemError, self.main.main_loop) + self.thread_pool.start.assert_called_once_with() + self.pipe.poll.assert_called_once_with(self.main.POLL_TIMEOUT) + all_patch.assert_called_once_with() + self.request_factory.from_msg.assert_has_calls( + [mock.call(requests[i]) for i in xrange(len(requests))]) + handle_patch.assert_has_calls( + [mock.call(requests[i]) for i in xrange(len(requests))]) + + def 
test_main_loop_is_shutdown(self): + self.main.is_shutdown = True + self.pipe.poll.configure_mock(return_value=False) + with mock.patch.object(self.main, '_shutdown', + side_effect=SystemError) as shutdown_patch: + assert_raises(SystemError, self.main.main_loop) + self.thread_pool.start.assert_called_once_with() + self.pipe.poll.assert_called_once_with(self.main.POLL_TIMEOUT) + shutdown_patch.assert_called_once_with() + + def test_main_loop_trond_check(self): + fake_id = 77 + self.main.is_shutdown = False + self.pipe.poll = mock.Mock(side_effect=iter([False, SystemError])) + with contextlib.nested( + mock.patch.object(tronstore.os, 'kill', side_effect=TypeError), + mock.patch.object(tronstore.os, 'getppid', return_value=fake_id) + ) as (kill_patch, ppid_patch): + assert_raises(SystemError, self.main.main_loop) + assert self.main.is_shutdown + ppid_patch.assert_called_once_with() + kill_patch.assert_called_once_with(fake_id, 0) + assert_equal(self.pipe.poll.call_count, 2) + + +class TronstoreHandleRequestsTestCase(TestCase): + + @setup + def setup_args(self): + self.queue = mock.Mock() + self.store_class = mock.Mock() + self.pipe = mock.Mock() + self.factory = mock.Mock() + self.do_work = mock.Mock(val=False) + + self.queue.empty.configure_mock(side_effect=iter([False, True])) + + def test_handle_requests_save(self): + fake_id = 3090 + request_data = ('fantastic', 'voyage') + data_type = 'lakeside' + request = mock.Mock(req_type=msg_enums.REQUEST_SAVE, data=request_data, + data_type=data_type, id=fake_id) + self.queue.get.configure_mock(return_value=request) + + tronstore.handle_requests(self.queue, self.factory, self.pipe, + self.store_class, self.do_work) + + self.store_class.save.assert_called_once_with(request_data[0], request_data[1], + data_type) + assert_equal(self.queue.empty.call_count, 2) + self.queue.get.assert_called_once_with(block=True, timeout=1.0) + + def test_handle_requests_restore(self): + fake_id = 53045 + request_data = 'edgeworth' + fake_success 
= ('steel_samurai_fan', 'or_maybe_its_ironic') + data_type = 'lawyer' + self.store_class.restore = mock.Mock(return_value=fake_success) + request = mock.Mock(req_type=msg_enums.REQUEST_RESTORE, data=request_data, + data_type=data_type, id=fake_id) + self.queue.get.configure_mock(return_value=request) + + tronstore.handle_requests(self.queue, self.factory, self.pipe, + self.store_class, self.do_work) + + self.store_class.restore.assert_called_once_with(request_data, data_type) + self.factory.build.assert_called_once_with(fake_success[0], fake_id, fake_success[1]) + self.pipe.send_bytes.assert_called_once_with(self.factory.build().serialized) + assert_equal(self.queue.empty.call_count, 2) + self.queue.get.assert_called_once_with(block=True, timeout=1.0) + + def test_handle_requests_other(self): + fake_id = 1234567890 + request = mock.Mock(req_type='not_actually_a_request', id=fake_id) + self.queue.get.configure_mock(return_value=request) + + tronstore.handle_requests(self.queue, self.factory, self.pipe, + self.store_class, self.do_work) + + self.factory.build.assert_called_once_with(False, fake_id, '') + self.pipe.send_bytes.assert_called_once_with(self.factory.build().serialized) + assert_equal(self.queue.empty.call_count, 2) + self.queue.get.assert_called_once_with(block=True, timeout=1.0) + + def test_handle_requests_cont_on_empty(self): + self.queue.get.configure_mock(side_effect=Empty) + + tronstore.handle_requests(self.queue, self.factory, self.pipe, + self.store_class, self.do_work) + + assert_equal(self.queue.empty.call_count, 2) + self.queue.get.assert_called_once_with(block=True, timeout=1.0) + assert not self.pipe.send_bytes.called + + +class TronstoreOtherTestCase(TestCase): + + def test_main(self): + config = mock.Mock() + pipe = mock.Mock() + with contextlib.nested( + mock.patch.object(tronstore, '_register_null_handlers'), + mock.patch('tron.serialize.runstate.tronstore.tronstore.TronstoreMain', autospec=True) + ) as (handler_patch, tronstore_patch): + 
tronstore.main(config, pipe) + handler_patch.assert_called_once_with() + tronstore_patch.assert_called_once_with(config, pipe) + tronstore_patch.return_value.main_loop.assert_called_once_with() + + def test_register_null_handlers(self): + with mock.patch.object(tronstore.signal, 'signal') as signal_patch: + tronstore._register_null_handlers() + signal_patch.assert_any_call(signal.SIGINT, tronstore._discard_signal) + signal_patch.assert_any_call(signal.SIGTERM, tronstore._discard_signal) + signal_patch.assert_any_call(signal.SIGHUP, tronstore._discard_signal) + + +class PoolBoolTestCase(TestCase): + + def test__init__(self): + poolbool = tronstore.PoolBool() + assert poolbool._val + assert poolbool.val + assert poolbool.value + + def test__init__invalid(self): + assert_raises(TypeError, tronstore.PoolBool, 'frue') + + def test__init__false(self): + poolbool = tronstore.PoolBool(False) + assert not poolbool._val + assert not poolbool.val + assert not poolbool.value + + def test_set(self): + poolbool = tronstore.PoolBool(False) + poolbool.set(True) + assert poolbool._val + assert poolbool.val + assert poolbool.value + + def test_set_invalid(self): + poolbool = tronstore.PoolBool(False) + assert_raises(TypeError, poolbool.set, 'tralse') + assert not poolbool._val + assert not poolbool.val + assert not poolbool.value + + +class SyncPipeTestCase(TestCase): + + @setup + def setup_pipe(self): + self.pipe = mock.Mock() + self.sync = tronstore.SyncPipe(self.pipe) + self.sync.lock = mock.MagicMock() + + def test__init__(self): + assert_equal(self.pipe, self.sync.pipe) + + def test_poll(self): + fake_arg = 'arrrrrrg' + fake_kwarg = 'no_dont_do_it_nishbot_we_love_you' + self.sync.poll(fake_arg, fake_kwarg=fake_kwarg) + self.pipe.poll.assert_called_once_with(fake_arg, fake_kwarg=fake_kwarg) + assert not self.sync.lock.__enter__.called + assert not self.sync.lock.lock.called + + def test_send_bytes(self): + fake_arg = 'makin_bacon' + fake_kwarg = 'hioh_its_mnc' + fake_return = 
'churros' + self.pipe.send_bytes.configure_mock(return_value=fake_return) + assert_equal(self.sync.send_bytes(fake_arg, fake_kwarg=fake_kwarg), fake_return) + self.pipe.send_bytes.assert_called_once_with(fake_arg, fake_kwarg=fake_kwarg) + self.sync.lock.__enter__.assert_called_once_with() + self.sync.lock.__exit__.assert_called_once_with(None, None, None) + + def test_recv_bytes(self): + fake_arg = 'hey_can_i_have_root' + fake_kwarg = 'pls' + fake_return = 'PPFFFFFFFFAHAHAHAHAHHA' + self.pipe.recv_bytes.configure_mock(return_value=fake_return) + assert_equal(self.sync.recv_bytes(fake_arg, fake_kwarg=fake_kwarg), fake_return) + self.pipe.recv_bytes.assert_called_once_with(fake_arg, fake_kwarg=fake_kwarg) + self.sync.lock.__enter__.assert_called_once_with() + self.sync.lock.__exit__.assert_called_once_with(None, None, None) + + +class TronstorePoolTestCase(TestCase): + + @setup_teardown + def setup_tronstore_pool(self): + self.factory = mock.Mock() + self.pipe = mock.Mock() + self.store = mock.Mock() + with mock.patch('tron.serialize.runstate.tronstore.tronstore.Thread', autospec=True) \ + as self.thread_patch: + self.pool = tronstore.TronstorePool(self.factory, self.pipe, self.store) + yield + + def test__init__(self): + assert isinstance(self.pool.request_queue, tronstore.Queue) + assert_equal(self.pool.response_factory, self.factory) + assert_equal(self.pool.pipe, self.pipe) + assert_equal(self.pool.store_class, self.store) + assert isinstance(self.pool.keep_working, tronstore.PoolBool) + assert self.pool.keep_working.value + assert_equal(self.pool.thread_pool, [self.thread_patch.return_value for i in range(self.pool.POOL_SIZE)]) + self.thread_patch.assert_any_call(target=tronstore.handle_requests, + args=( + self.pool.request_queue, + self.factory, + self.pipe, + self.store, + self.pool.keep_working + )) + + def test_start(self): + self.pool.keep_working.set(False) + self.pool.start() + assert self.pool.keep_working.value + assert not 
self.thread_patch.return_value.daemon + assert_equal(self.thread_patch.return_value.start.call_count, self.pool.POOL_SIZE) + + def test_stop(self): + self.pool.keep_working.set(True) + self.thread_patch.return_value.is_alive.return_value = False + with mock.patch.object(self.pool, 'has_work', return_value=False) \ + as work_patch: + self.pool.stop() + assert not self.pool.keep_working.value + work_patch.assert_called_once_with() + assert_equal(self.thread_patch.return_value.is_alive.call_count, self.pool.POOL_SIZE) + + def test_enqueue_work(self): + fake_work = 'youre_fired' + with mock.patch.object(self.pool.request_queue, 'put') as put_patch: + self.pool.enqueue_work(fake_work) + put_patch.assert_called_once_with(fake_work) + + def test_has_work(self): + with mock.patch.object(self.pool.request_queue, 'empty', return_value=True) \ + as empty_patch: + assert not self.pool.has_work() + empty_patch.assert_called_once_with() + + +if __name__ == "__main__": + run() diff --git a/tools/inspect_serialized_state.py b/tools/inspect_serialized_state.py index 65d6aa951..548846cd0 100644 --- a/tools/inspect_serialized_state.py +++ b/tools/inspect_serialized_state.py @@ -34,9 +34,13 @@ def get_container(config_path): def get_state(container): config = container.get_master().state_persistence - state_manager = statemanager.PersistenceManagerFactory.from_config(config) + state_manager = statemanager.PersistentStateManager() names = container.get_job_and_service_names() - return state_manager.restore(*names) + if not state_manager.update_from_config(config): + raise SystemError('Configuration failed to load correctly.') + data = state_manager.restore(*names) + state_manager.cleanup() + return data def format_date(date_string): @@ -52,10 +56,10 @@ def max_run(item): return max(start_time) if start_time else None def build(name, job): - start_times = (max_run(job_run['runs']) for job_run in job['runs']) + start_times = (max_run(job_run['runs']) for job_run in job[1]) start_times = 
filter(None, start_times) last_run = format_date(max(start_times)) if start_times else None - return format % (name, job['enabled'], len(job['runs']), last_run) + return format % (name, job[0]['enabled'], len(job[1]), last_run) seq = sorted(build(*item) for item in job_states.iteritems()) return header + "".join(seq) @@ -87,4 +91,4 @@ def main(config_path, working_dir): if __name__ == "__main__": opts = parse_options() - main(opts.config_path, opts.working_dir) \ No newline at end of file + main(opts.config_path, opts.working_dir) diff --git a/tools/migration/migrate_state.py b/tools/migration/migrate_state.py index 963f367c6..1d0045125 100644 --- a/tools/migration/migrate_state.py +++ b/tools/migration/migrate_state.py @@ -17,7 +17,7 @@ import optparse from tron.config import manager, schema from tron.serialize import runstate -from tron.serialize.runstate.statemanager import PersistenceManagerFactory +from tron.serialize.runstate.statemanager import PersistentStateManager from tron.utils import tool_utils @@ -49,11 +49,16 @@ def parse_options(): def get_state_manager_from_config(config_path, working_dir): """Return a state manager from the configuration. """ + if not working_dir: + working_dir = config_path config_manager = manager.ConfigManager(config_path) config_container = config_manager.load() state_config = config_container.get_master().state_persistence with tool_utils.working_dir(working_dir): - return PersistenceManagerFactory.from_config(state_config) + ret = PersistentStateManager() + if not ret.update_from_config(state_config): + raise SystemError("%s failed to load." 
% config_path) + return ret def get_current_config(config_path): @@ -90,8 +95,11 @@ def convert_state(opts): job_states = add_namespaces(job_states) service_states = add_namespaces(service_states) - for name, job in job_states.iteritems(): - dest_manager.save(runstate.JOB_STATE, name, job) + for name, (job_state, run_list) in job_states.iteritems(): + dest_manager.save(runstate.JOB_STATE, name, job_state) + for run_data in run_list: + run_name = '%s.%s' % (run_data['job_name'], run_data['run_num']) + dest_manager.save(runstate.JOB_RUN_STATE, run_name, run_data) print "Migrated %s jobs." % len(job_states) for name, service in service_states.iteritems(): @@ -100,7 +108,9 @@ def convert_state(opts): dest_manager.cleanup() + print "Hang on, saving everything to the destination object..." + if __name__ == "__main__": opts, _args = parse_options() - convert_state(opts) \ No newline at end of file + convert_state(opts) diff --git a/tools/migration/migrate_state_from_0.6.1_to_0.6.2.py b/tools/migration/migrate_state_from_0.6.1_to_0.6.2.py new file mode 100644 index 000000000..403d7ff12 --- /dev/null +++ b/tools/migration/migrate_state_from_0.6.1_to_0.6.2.py @@ -0,0 +1,224 @@ +"""Usage: %prog [options] + +This is a script to convert old state storing containers into the new +objects used by Tron v0.6.2 and tronstore. The script will use the same +mechanism for storing state as specified in the Tron configuration file. +Config elements can be overriden via command line options, which allows for +full configuration of the mechanism used to store the new state object. + +Please ensure that you have Tron v0.6.2 before running this script. Also note +that migrate_state.py will NOT work again until running this script, as it has +been changed to work with v0.6.2's method of state storing. + +The working dir should generally be the same as the one used when launching +trond, but should contain the file pointed to by the configuration file. 
+The script attempts to load a configuration from /config by +default, or whatever -f was set to. + +***IMPORTANT*** +When using SQLAlchemy/MongoDB storing mechanisms, the -c option for setting +connection detail parameters MUST be set. + +HOWEVER, THE SCRIPT DOES NOT CHECK WHETHER OR NOT THE CONNECTION DETAILS +ARE THE SAME, NOR IF YOU ARE GOING TO CLOBBER YOUR OLD DATABASE WITH THE GIVEN +CONNECTION AND CONFIGURATION PARAMETERS. + +Please especially ensure that you are not connecting to the exact +same SQL database that holds your old state_data, or you are likely to run +into a large number of strange problems and inconsistencies. +***IMPORTANT*** + + +Command line options: + -c str Set new connection details to str for SQL/MongoDB storing. This + is REQUIRED for using SQL/MongoDB as the new state store. + + -m str Set a new mechanism for storing the new state objects. + Defaults to whatever store_type was set to in the Tron + configuration file. + Options for str are sql, mongo, yaml, and shelve. + + -d str Set a new method for storing state data within an SQL database. + Defaults to whatever was set to db_store_method in the Tron + configuration file, or json if it isn't set. Only used if + SQLAlchemy is the storing mechanism. + Options for str are pickle, yaml, msgpack, and json. + + -f str Set the path for the configuration dir to str. 
This defaults to + /config +""" + +import sys +import os +import copy + +from tron.commands import cmd_utils +from tron.config import ConfigError +from tron.config.schema import StatePersistenceTypes +from tron.config.manager import ConfigManager +from tron.serialize import runstate +from tron.serialize.runstate.shelvestore import ShelveStateStore +from tron.serialize.runstate.mongostore import MongoStateStore +from tron.serialize.runstate.yamlstore import YamlStateStore +from tron.serialize.runstate.sqlalchemystore import SQLAlchemyStateStore +from tron.serialize.runstate.tronstore.parallelstore import ParallelStore +from tron.serialize.runstate.statemanager import StateMetadata +from tron.serialize.runstate.tronstore.serialize import MsgPackSerializer + +def parse_options(): + usage = "usage: %prog [options] " + parser = cmd_utils.build_option_parser(usage) + parser.add_option("-c", type="string", + help="Set new connection details for db connections", + dest="new_connection_details", default=None) + parser.add_option("-m", type="string", + help="Set new state storing mechanism (store_type)", + dest="store_type", default=None) + parser.add_option("-d", type="string", + help="Set new SQL db serialization method (db_store_method)", + dest="db_store_method", default=None) + parser.add_option("-f", type="string", + help="Set path to Tron configuration file", + dest="conf_dir", default=None) + options, args = parser.parse_args(sys.argv) + return options, args[1], args[2] + +def parse_config(conf_dir): + if conf_dir: + manager = ConfigManager(conf_dir) + else: + manager = ConfigManager('config') + return manager.load() + +def get_old_state_store(state_info): + name = state_info.name + connection_details = state_info.connection_details + store_type = state_info.store_type + + if store_type == StatePersistenceTypes.shelve: + return ShelveStateStore(name) + + if store_type == StatePersistenceTypes.sql: + return SQLAlchemyStateStore(name, connection_details) + + if 
store_type == StatePersistenceTypes.mongo: + return MongoStateStore(name, connection_details) + + if store_type == StatePersistenceTypes.yaml: + return YamlStateStore(name) + +def compile_new_info(options, state_info, new_file): + new_state_info = copy.deepcopy(state_info) + + new_state_info = new_state_info._replace(name=new_file) + + if options.store_type: + new_state_info = new_state_info._replace(store_type=options.store_type) + + if options.db_store_method: + new_state_info = new_state_info._replace(db_store_method=options.db_store_method) + + if options.new_connection_details: + new_state_info = new_state_info._replace(connection_details=options.new_connection_details) + elif new_state_info.store_type in ('sql', 'mongo'): + raise ConfigError('Must specify connection_details using -c to use %s' + % new_state_info.store_type) + + return new_state_info + +def assert_copied(new_store, data, key): + """A small function to counter race conditions. It's possible that + tronstore will serve the restore request BEFORE the save request, which + will result in an Exception. We simply retry 10 times (which should be more + than enough time for tronstore to serve the save request).""" + + if new_store.process.config.store_type == 'mongo': + data['_id'] = key.key + for i in range(10): + try: + new_data = new_store.restore([key])[key] + except: + continue + + if data == new_data: + return + + try: + if MsgPackSerializer.deserialize(MsgPackSerializer.serialize(data)) == new_data: + return + except: + continue + + raise AssertionError('The value %s failed to copy.' 
% key.iden) + +def copy_metadata(old_store, new_store): + meta_key_old = old_store.build_key(runstate.MCP_STATE, StateMetadata.name) + old_metadata_dict = old_store.restore([meta_key_old]) + if old_metadata_dict: + old_metadata = old_metadata_dict[meta_key_old] + if 'version' in old_metadata: + old_metadata['version'] = (0, 6, 2, 0) + meta_key_new = new_store.build_key(runstate.MCP_STATE, StateMetadata.name) + new_store.save([(meta_key_new, old_metadata)]) + assert_copied(new_store, old_metadata, meta_key_new) + +def copy_services(old_store, new_store, service_names): + for service in service_names: + service_key_old = old_store.build_key(runstate.SERVICE_STATE, service) + old_service_dict = old_store.restore([service_key_old]) + if old_service_dict: + old_service_data = old_service_dict[service_key_old] + service_key_new = new_store.build_key(runstate.SERVICE_STATE, service) + new_store.save([(service_key_new, old_service_data)]) + assert_copied(new_store, old_service_data, service_key_new) + +def copy_jobs(old_store, new_store, job_names): + for job in job_names: + job_key_old = old_store.build_key(runstate.JOB_STATE, job) + old_job_dict = old_store.restore([job_key_old]) + if old_job_dict: + old_job_data = old_job_dict[job_key_old] + job_state_key = new_store.build_key(runstate.JOB_STATE, job) + + run_ids = [] + for job_run in old_job_data['runs']: + run_ids.append(job_run['run_num']) + job_run_key = new_store.build_key(runstate.JOB_RUN_STATE, + job + ('.%s' % job_run['run_num'])) + new_store.save([(job_run_key, job_run)]) + assert_copied(new_store, job_run, job_run_key) + + run_ids = sorted(run_ids, reverse=True) + job_state_data = {'enabled': old_job_data['enabled'], 'run_ids': run_ids} + new_store.save([(job_state_key, job_state_data)]) + assert_copied(new_store, job_state_data, job_state_key) + + +def main(): + print('Parsing options...') + (options, working_dir, new_fname) = parse_options() + os.chdir(working_dir) + print('Parsing configuration file...') + 
config = parse_config(options.conf_dir) + state_info = config.get_master().state_persistence + print('Setting up the old state storing object...') + old_store = get_old_state_store(state_info) + print('Setting up the new state storing object...') + new_state_info = compile_new_info(options, state_info, new_fname) + new_store = ParallelStore() + if not new_store.load_config(new_state_info): + raise AssertionError("Invalid configuration.") + + print('Copying metadata...') + copy_metadata(old_store, new_store) + print('Copying service data...') + copy_services(old_store, new_store, config.get_services().keys()) + print('Converting job data...') + copy_jobs(old_store, new_store, config.get_jobs().keys()) + print('Done copying. All data has been verified.') + print('Cleaning up, just a sec...') + old_store.cleanup() + new_store.cleanup() + +if __name__ == "__main__": + main() diff --git a/tools/migration/migrate_state_pre_0_6_2.py b/tools/migration/migrate_state_pre_0_6_2.py new file mode 100644 index 000000000..963f367c6 --- /dev/null +++ b/tools/migration/migrate_state_pre_0_6_2.py @@ -0,0 +1,106 @@ +""" + Migrate a state file/database from one StateStore implementation to another. It + may also be used to add namespace names to jobs/services when upgrading + from pre-0.5.2 to version 0.5.2. + + Usage: + python tools/migration/migrate_state.py \ + -s -d [ --namespace ] + + old_config.yaml and new_config.yaml should be configuration files with valid + state_persistence sections. The state_persistence section configures the + StateStore. + + Pre 0.5 state files can be read by the YamlStateStore. See the configuration + documentation for more details on how to create state_persistence sections. 
+""" +import optparse +from tron.config import manager, schema +from tron.serialize import runstate +from tron.serialize.runstate.statemanager import PersistenceManagerFactory +from tron.utils import tool_utils + + +def parse_options(): + parser = optparse.OptionParser() + parser.add_option('-s', '--source', + help="The source configuration path which contains a state_persistence " + "section configured for the state file/database.") + parser.add_option('-d', '--dest', + help="The destination configuration path which contains a " + "state_persistence section configured for the state file/database.") + parser.add_option('--source-working-dir', + help="The working directory for source dir to resolve relative paths.") + parser.add_option('--dest-working-dir', + help="The working directory for dest dir to resolve relative paths.") + parser.add_option('--namespace', action='store_true', + help="Move jobs/services which are missing a namespace to the MASTER") + + opts, args = parser.parse_args() + + if not opts.source: + parser.error("--source is required") + if not opts.dest: + parser.error("--dest is required.") + + return opts, args + + +def get_state_manager_from_config(config_path, working_dir): + """Return a state manager from the configuration. 
+ """ + config_manager = manager.ConfigManager(config_path) + config_container = config_manager.load() + state_config = config_container.get_master().state_persistence + with tool_utils.working_dir(working_dir): + return PersistenceManagerFactory.from_config(state_config) + + +def get_current_config(config_path): + config_manager = manager.ConfigManager(config_path) + return config_manager.load() + + +def add_namespaces(state_data): + return dict(('%s.%s' % (schema.MASTER_NAMESPACE, name), data) + for (name, data) in state_data.iteritems()) + +def strip_namespace(names): + return [name.split('.', 1)[1] for name in names] + + +def convert_state(opts): + source_manager = get_state_manager_from_config(opts.source, opts.source_working_dir) + dest_manager = get_state_manager_from_config(opts.dest, opts.dest_working_dir) + container = get_current_config(opts.source) + + msg = "Migrating state from %s to %s" + print msg % (source_manager._impl, dest_manager._impl) + + job_names, service_names = container.get_job_and_service_names() + if opts.namespace: + job_names = strip_namespace(job_names) + service_names = strip_namespace(service_names) + + job_states, service_states = source_manager.restore( + job_names, service_names, skip_validation=True) + source_manager.cleanup() + + if opts.namespace: + job_states = add_namespaces(job_states) + service_states = add_namespaces(service_states) + + for name, job in job_states.iteritems(): + dest_manager.save(runstate.JOB_STATE, name, job) + print "Migrated %s jobs." % len(job_states) + + for name, service in service_states.iteritems(): + dest_manager.save(runstate.SERVICE_STATE, name, service) + print "Migrated %s services." 
% len(service_states) + + dest_manager.cleanup() + + +if __name__ == "__main__": + opts, _args = parse_options() + convert_state(opts) \ No newline at end of file diff --git a/tron/__init__.py b/tron/__init__.py index bf67cb197..e161cf4c1 100644 --- a/tron/__init__.py +++ b/tron/__init__.py @@ -1,4 +1,4 @@ -__version_info__ = (0, 6, 1, 1) +__version_info__ = (0, 6, 2, 0) __version__ = ".".join("%s" % v for v in __version_info__) __author__ = 'Yelp ' diff --git a/tron/config/config_parse.py b/tron/config/config_parse.py index 897b4c673..d8c9151a7 100644 --- a/tron/config/config_parse.py +++ b/tron/config/config_parse.py @@ -392,6 +392,7 @@ class ValidateStatePersistence(Validator): defaults = { 'buffer_size': 1, 'connection_details': None, + 'db_store_method': 'json', } validators = { @@ -400,6 +401,8 @@ class ValidateStatePersistence(Validator): schema.StatePersistenceTypes), 'connection_details': valid_string, 'buffer_size': valid_int, + 'db_store_method': config_utils.build_enum_validator( + schema.StateSerializationTypes), } def post_validation(self, config, config_context): @@ -426,7 +429,7 @@ def validate_jobs_and_services(config, config_context): config_utils.unique_names(fmt_string, config['jobs'], config['services']) -DEFAULT_STATE_PERSISTENCE = ConfigState('tron_state', 'shelve', None, 1) +DEFAULT_STATE_PERSISTENCE = ConfigState('tron_state', 'shelve', None, 1, 'json') DEFAULT_NODE = ValidateNode().do_shortcut('localhost') diff --git a/tron/config/manager.py b/tron/config/manager.py index 68e4dcb4c..6b28f8712 100644 --- a/tron/config/manager.py +++ b/tron/config/manager.py @@ -134,4 +134,4 @@ def create_new_config(path, master_content): manager = ConfigManager(path) manager.manifest.create() filename = manager.get_filename_from_manifest(schema.MASTER_NAMESPACE) - write_raw(filename , master_content) \ No newline at end of file + write_raw(filename , master_content) diff --git a/tron/config/schema.py b/tron/config/schema.py index ebb385bf9..94c1f50bd 100644 
--- a/tron/config/schema.py +++ b/tron/config/schema.py @@ -92,7 +92,8 @@ def config_object_factory(name, required=None, optional=None): 'store_type', ],[ 'connection_details', - 'buffer_size' + 'buffer_size', + 'db_store_method', ]) @@ -151,6 +152,7 @@ def config_object_factory(name, required=None, optional=None): StatePersistenceTypes = Enum.create('shelve', 'sql', 'mongo', 'yaml') +StateSerializationTypes = Enum.create('json', 'pickle', 'msgpack', 'yaml') ActionRunnerTypes = Enum.create('none', 'subprocess') diff --git a/tron/core/job.py b/tron/core/job.py index 206a05ae2..619da3283 100644 --- a/tron/core/job.py +++ b/tron/core/job.py @@ -58,7 +58,9 @@ def status(self, job_runs): """Current status of the job. Takes a JobRunCollection as an argument.""" if not self.enabled: return self.STATUS_DISABLED - if job_runs.get_run_by_state(ActionRun.STATE_RUNNING): + + if (job_runs.get_run_by_state(ActionRun.STATE_RUNNING) or + job_runs.get_run_by_state(ActionRun.STATE_STARTING)): return self.STATUS_RUNNING if (job_runs.get_run_by_state(ActionRun.STATE_SCHEDULED) or @@ -128,9 +130,12 @@ def __init__(self, job_runs, job_config, job_state, scheduler, actiongraph, def restore_state(self): """Restore the job state and schedule any JobRuns.""" - scheduled = self.job_runs.get_scheduled() - for job_run in scheduled: - self._set_callback(job_run) + scheduled = list(self.job_runs.get_scheduled()) + if scheduled: + for job_run in scheduled: + self._set_callback(job_run) + else: + self._run_first_queued() # Ensure we have at least 1 scheduled run self.schedule() @@ -228,6 +233,13 @@ def _queue_or_cancel_active(self, job_run): job_run.cancel() self.schedule() + def _run_first_queued(self): + # TODO: this should only start runs on the same node if this is an + # all_nodes job, but that is currently not possible + queued_run = self.job_runs.get_first_queued() + if queued_run: + eventloop.call_later(0, self.run_job, queued_run, run_queued=True) + def handle_job_events(self, 
_observable, event): """Handle notifications from observables. If a JobRun has completed look for queued JobRuns that may need to start now. @@ -235,11 +247,7 @@ def handle_job_events(self, _observable, event): if event != jobrun.JobRun.NOTIFY_DONE: return - # TODO: this should only start runs on the same node if this is an - # all_nodes job, but that is currently not possible - queued_run = self.job_runs.get_first_queued() - if queued_run: - eventloop.call_later(0, self.run_job, queued_run, run_queued=True) + self._run_first_queued() # Attempt to schedule a new run. This will only schedule a run if the # previous run was cancelled from a scheduled state, or if the job @@ -439,8 +447,10 @@ def restore_state(self, state_data): self.node_pool) for run in job_runs: self.watcher.watch(run) + self.job_scheduler.watch(run, jobrun.JobRun.NOTIFY_DONE) self.job_state.restore_state(job_state_data) self.job_scheduler.restore_state() + self.job_state.set_run_ids(self.job_runs.get_run_numbers()) # consistency self.event.ok('restored') def update_from_job(self, job): diff --git a/tron/mcp.py b/tron/mcp.py index 5b26e79bf..3aa1bf2a4 100644 --- a/tron/mcp.py +++ b/tron/mcp.py @@ -112,7 +112,8 @@ def update_state_watcher_config(self, state_config): """ if self.state_watcher.update_from_config(state_config): for job_container in self.jobs: - self.state_watcher.save_job_run(job_container.get_job_runs()) + for job_run in job_container.get_job_runs(): + self.state_watcher.save_job_run(job_run) self.state_watcher.save_job(job_container.get_job_state()) for service in self.services: self.state_watcher.save_service(service) diff --git a/tron/serialize/runstate/statemanager.py b/tron/serialize/runstate/statemanager.py index e4451eb89..e503faece 100644 --- a/tron/serialize/runstate/statemanager.py +++ b/tron/serialize/runstate/statemanager.py @@ -6,10 +6,7 @@ from tron.config import schema from tron.core import job, jobrun, service from tron.serialize import runstate -from 
tron.serialize.runstate.mongostore import MongoStateStore -from tron.serialize.runstate.shelvestore import ShelveStateStore -from tron.serialize.runstate.sqlalchemystore import SQLAlchemyStateStore -from tron.serialize.runstate.yamlstore import YamlStateStore +from tron.serialize.runstate.tronstore.parallelstore import ParallelStore from tron.utils import observer log = logging.getLogger(__name__) @@ -21,37 +18,6 @@ class VersionMismatchError(ValueError): class PersistenceStoreError(ValueError): """Raised if the store can not be created or fails a read or write.""" - -class PersistenceManagerFactory(object): - """Create a PersistentStateManager.""" - - @classmethod - def from_config(cls, persistence_config): - store_type = persistence_config.store_type - name = persistence_config.name - connection_details = persistence_config.connection_details - buffer_size = persistence_config.buffer_size - store = None - - if store_type not in schema.StatePersistenceTypes: - raise PersistenceStoreError("Unknown store type: %s" % store_type) - - if store_type == schema.StatePersistenceTypes.shelve: - store = ShelveStateStore(name) - - if store_type == schema.StatePersistenceTypes.sql: - store = SQLAlchemyStateStore(name, connection_details) - - if store_type == schema.StatePersistenceTypes.mongo: - store = MongoStateStore(name, connection_details) - - if store_type == schema.StatePersistenceTypes.yaml: - store = YamlStateStore(name) - - buffer = StateSaveBuffer(buffer_size) - return PersistentStateManager(store, buffer) - - class StateMetadata(object): """A data object for saving state metadata. Conforms to the same RunState interface as Jobs and Services. 
@@ -78,6 +44,11 @@ def validate_metadata(cls, metadata): return version = metadata['version'] + if not isinstance(version, tuple): + try: + version = tuple(version) + except: + raise PersistenceStoreError("Stored metadata looks corrupted.") # Names (and state keys) changed in 0.5.2, requires migration # see tools/migration/migrate_state_to_namespace if version > cls.version or version < (0, 5, 2): @@ -110,6 +81,18 @@ def __iter__(self): self.buffer.clear() +class NullSaveBuffer(object): + buffer_size = 0 + buffer = {} + counter = 0 + + def save(self, key, state_data): + return False + + def __iter__(self): + return iter([]) + + class PersistentStateManager(object): """Provides an interface to persist the state of Tron. @@ -127,15 +110,19 @@ def restore(self, keys): def save(self, key, state_data): pass + def load_config(self, new_config): + pass + def cleanup(self): pass """ + # TODO: Rename things here, as ParallelStore is always used - def __init__(self, persistence_impl, buffer): + def __init__(self): self.enabled = True - self._buffer = buffer - self._impl = persistence_impl + self._buffer = NullSaveBuffer() + self._impl = ParallelStore() self.metadata_key = self._impl.build_key( runstate.MCP_STATE, StateMetadata.name) @@ -200,6 +187,14 @@ def _save_from_buffer(self): log.warn(msg) raise PersistenceStoreError(msg) + def update_from_config(self, new_state_config): + self._save_from_buffer() + if self._impl.load_config(new_state_config): + self._buffer = StateSaveBuffer(new_state_config.buffer_size) + return True + else: + return False + def cleanup(self): self._save_from_buffer() self._impl.cleanup() @@ -244,17 +239,25 @@ class StateChangeWatcher(observer.Observer): """Observer of stateful objects.""" def __init__(self): - self.state_manager = NullStateManager + self.state_manager = PersistentStateManager() self.config = None def update_from_config(self, state_config): if self.config == state_config: return False - self.shutdown() - self.state_manager = 
PersistenceManagerFactory.from_config(state_config) - self.config = state_config - return True + if state_config.store_type not in schema.StatePersistenceTypes: + raise PersistenceStoreError("Unknown store type: %s" % state_config.store_type) + + if state_config.db_store_method not in schema.StateSerializationTypes \ + and state_config.store_type in ('sql', 'mongo'): + raise PersistenceStoreError("Unknown db store method: %s" % state_config.db_store_method) + + if not self.state_manager.update_from_config(state_config): + return False + else: + self.config = state_config + return True def handler(self, observable, _event): """Handle a state change in an observable by saving its state.""" diff --git a/tron/serialize/runstate/tronstore/__init__.py b/tron/serialize/runstate/tronstore/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tron/serialize/runstate/tronstore/__init__.py @@ -0,0 +1 @@ + diff --git a/tron/serialize/runstate/tronstore/messages.py b/tron/serialize/runstate/tronstore/messages.py new file mode 100644 index 000000000..4e3be9bef --- /dev/null +++ b/tron/serialize/runstate/tronstore/messages.py @@ -0,0 +1,107 @@ +from tron.serialize.runstate.tronstore.serialize import cPickleSerializer + +# a simple max integer to prevent ids from growing indefinitely +MAX_MSG_ID = 2**32 - 1 + + +class StoreRequestFactory(object): + """A factory to generate requests by giving each a unique id. + The serialization method should usually be cPickle. However, you can simply + change what serialization class is used in the __init__ method- just + make sure to change it in the StoreResponseFactory as well! + """ + + def __init__(self): + self.serializer = cPickleSerializer + self.id_counter = 1 + + def _increment_counter(self): + """A simple method to make sure that we don't indefinitely increase + the id assigned to StoreRequests. 
+ """ + return self.id_counter+1 if not self.id_counter == MAX_MSG_ID else 0 + + def build(self, req_type, data_type, data): + new_request = StoreRequest(self.id_counter, req_type, data_type, data, self.serializer) + self.id_counter = self._increment_counter() + return new_request + + def from_msg(self, msg): + return StoreRequest.from_message(self.serializer.deserialize(msg), self.serializer) + + def get_method(self): + return self.serializer + + +class StoreResponseFactory(object): + """A factory to generate responses that need to be converted to serialized + strings and back.. + """ + + def __init__(self): + self.serializer = cPickleSerializer + + def build(self, success, req_id, data): + new_request = StoreResponse(req_id, success, data, self.serializer) + return new_request + + def from_msg(self, msg): + return StoreResponse.from_message(self.serializer.deserialize(msg), self.serializer) + + def get_method(self): + return self.serializer + + +class StoreRequest(object): + """An object representing a request to tronstore. The request has four + essential attributes: + id - an integer identifier, used for matching requests with responses + req_type - the request type from msg_enums.py, such as save/restore + data_type - the type of data the request is for. there are four kinds + of saved state_data: job, jobrun, service, and meta state_data. 
+ data - the data required for the request, like a name or state_data + """ + + def __init__(self, req_id, req_type, data_type, data, method): + self.id = req_id + self.req_type = req_type + self.data = data + self.data_type = data_type + self.method = method + + @classmethod + def from_message(cls, msg_data, method): + req_id, req_type, data_type, data = msg_data + return cls(req_id, req_type, data_type, data, method) + + @property + def serialized(self): + return self.method.serialize(( + self.id, + self.req_type, + self.data_type, + self.data)) + + +class StoreResponse(object): + """An object representing a response from tronstore. The response has three + essential attributes: + id - matches the id of some request so this can be matched with it + success - shows if the request matching this response was successful + data - data requested by a request, if any + """ + + def __init__(self, req_id, success, data, method): + self.id = req_id + self.success = success + self.data = data + self.method = method + + @classmethod + def from_message(cls, msg_data, method): + req_id, success, data = msg_data + return cls(req_id, success, data, method) + + @property + def serialized(self): + return self.method.serialize((self.id, self.success, self.data)) diff --git a/tron/serialize/runstate/tronstore/msg_enums.py b/tron/serialize/runstate/tronstore/msg_enums.py new file mode 100644 index 000000000..b7de4cee6 --- /dev/null +++ b/tron/serialize/runstate/tronstore/msg_enums.py @@ -0,0 +1,4 @@ +REQUEST_SAVE = 10 +REQUEST_RESTORE = 11 +REQUEST_CONFIG = 12 +REQUEST_SHUTDOWN = 13 diff --git a/tron/serialize/runstate/tronstore/parallelstore.py b/tron/serialize/runstate/tronstore/parallelstore.py new file mode 100644 index 000000000..9e9ab65b1 --- /dev/null +++ b/tron/serialize/runstate/tronstore/parallelstore.py @@ -0,0 +1,83 @@ +import itertools +import operator +import logging +import os + +from tron.serialize.runstate.tronstore.process import StoreProcessProtocol +from 
log = logging.getLogger(__name__)


class ParallelKey(object):
    """Composite (data_type, identifier) key used to address tronstore data."""

    __slots__ = ['type', 'iden']

    def __init__(self, type, iden):
        self.type = type
        self.iden = iden

    @property
    def key(self):
        # Only the identifier travels to the store; type selects the table.
        return str(self.iden)

    def __eq__(self, other):
        return (self.type, self.iden) == (other.type, other.iden)

    def __hash__(self):
        # NOTE: hashes only the identifier; keys of different types may
        # collide in hash but still compare unequal, which is fine for dicts.
        return hash(self.key)

    def __str__(self):
        return "%s %s" % (self.type, self.iden)


class ParallelStore(object):
    """Persist state using a parallel storing mechanism, tronstore. This uses
    the python multiprocessing module to run the tronstore executable in
    another process, and handles all communication between trond and tronstore.

    This class builds every message that needs to be sent to tronstore based
    on whatever method was called.
    """

    def __init__(self):
        self.request_factory = StoreRequestFactory()
        self.process = StoreProcessProtocol()

    def build_key(self, type, iden):
        return ParallelKey(type, iden)

    def save(self, key_value_pairs):
        """Fire-and-forget save of each (key, state_data) pair."""
        for key, state_data in key_value_pairs:
            request = self.request_factory.build(
                msg_enums.REQUEST_SAVE, key.type, (key.key, state_data))
            self.process.send_request(request)

    def restore_single(self, key):
        """Restore one key; returns None on failure."""
        request = self.request_factory.build(
            msg_enums.REQUEST_RESTORE, key.type, key.key)
        response = self.process.send_request_get_response(request)
        return response.data if response.success else None

    def restore(self, keys):
        """Restore many keys, returning {key: data} for those found."""
        items = itertools.izip(keys, (self.restore_single(k) for k in keys))
        return dict(itertools.ifilter(operator.itemgetter(1), items))

    def cleanup(self):
        shutdown_req = self.request_factory.build(msg_enums.REQUEST_SHUTDOWN, '', '')
        self.process.send_request_shutdown(shutdown_req)
    shutdown = cleanup

    def load_config(self, new_config):
        """Reconfigure the storing mechanism to use a new configuration
        by shutting down and restarting tronstore. THIS MUST BE CALLED
        AT LEAST ONCE, as tronstore is started with a null configuration
        whenever a ParallelStore object is created."""
        config_req = self.request_factory.build(
            msg_enums.REQUEST_CONFIG, '', new_config)
        response = self.process.send_request_get_response(config_req)
        if not response.success:
            return False
        self.process.update_config(new_config)
        return True

    def __repr__(self):
        store = self.process.config.store_type if self.process.config else None
        return "ParallelStore(%s)" % store
log = logging.getLogger(__name__)


class TronStoreError(Exception):
    """Raised whenever tronstore exits for an unknown reason."""

    def __init__(self, code):
        self.code = code

    def __str__(self):
        return repr(self.code)


class StoreProcessProtocol(object):
    """The class that actually spawns and handles the tronstore process.

    This class uses the python multiprocessing module. Upon creation, it
    starts tronstore with a null configuration. A reconfiguration request
    must be sent to tronstore via one of the supplied object methods before
    it will be able to actually perform saves and restores. Calling
    update_config on this object will simply update the saved configuration
    object- it won't actually update the configuration that the tronstore
    process itself is using unless a _verify_is_alive fails and tronstore is
    restarted.

    Communication with the process is handled by a Pipe object, which can
    simply pass entire Python objects via Pickle. Despite this, we still
    serialize all requests with cPickle before sending them, as cPickle
    is much faster than, and effectively equivalent to, plain pickle.
    """
    # This timeout MUST be longer than the POLL_TIMEOUT in tronstore!
    SHUTDOWN_TIMEOUT = 100.0
    POLL_TIMEOUT = 10.0

    def __init__(self):
        self.config = None
        self.response_factory = StoreResponseFactory()
        # Responses whose id didn't match the request we were polling for.
        self.orphaned_responses = {}
        self.is_shutdown = False
        self._start_process()

    def _start_process(self):
        """Spawn the tronstore process with the saved configuration."""
        self.pipe, child_pipe = Pipe()
        store_args = (self.config, child_pipe)

        self.process = Process(target=tronstore.main, args=store_args)
        self.process.daemon = True
        self.process.start()

    def _verify_is_alive(self):
        """A check to verify that tronstore is alive. Attempts to restart
        tronstore if it finds that it exited for some reason.
        """
        if not self.process.is_alive():
            code = self.process.exitcode
            log.warn("tronstore exited prematurely with status code %d. Attempting to restart." % code)
            self._start_process()
            if not self.process.is_alive():
                raise TronStoreError("tronstore crashed with status code %d and failed to restart" % code)

    def send_request(self, request):
        """Send a StoreRequest to tronstore and immediately return without
        waiting for tronstore's response.
        """
        if self.is_shutdown:
            log.warn('attempted to send a request of type %s while shut down!' % request.req_type)
            return
        self._verify_is_alive()

        self.pipe.send_bytes(request.serialized)

    def _poll_for_response(self, req_id, timeout):
        """Polls for a response to the request with identifier req_id.
        Throws any responses that it isn't looking for into a dict, and
        tries to retrieve a matching response from this dict before pulling
        new responses. Returns None if nothing matched within timeout seconds.
        """
        if req_id in self.orphaned_responses:
            return self.orphaned_responses.pop(req_id)

        while self.pipe.poll(timeout):
            response = self.response_factory.from_msg(self.pipe.recv_bytes())
            if response.id == req_id:
                return response
            else:
                self.orphaned_responses[response.id] = response
        return None

    def send_request_get_response(self, request):
        """Send a StoreRequest to tronstore, and block until tronstore responds
        with the appropriate data. The StoreResponse is returned as is, with no
        modifications. Blocks for POLL_TIMEOUT seconds until returning a
        failure response.
        """
        if self.is_shutdown:
            log.warn('attempted to send a request of type %s while shut down!' % request.req_type)
            return self.response_factory.build(False, request.id, '')
        self._verify_is_alive()

        self.pipe.send_bytes(request.serialized)
        response = self._poll_for_response(request.id, self.POLL_TIMEOUT)
        if not response:
            # BUG FIX: message previously read "...to a""request" (no space).
            log.warn(("tronstore took longer than %d seconds to respond to a "
                "request, and it was dropped.") % self.POLL_TIMEOUT)
            return self.response_factory.build(False, request.id, '')
        else:
            return response

    def send_request_shutdown(self, request):
        """Shut down the process protocol. Waits for SHUTDOWN_TIMEOUT seconds
        for tronstore to send a shutdown response, killing both pipes and the
        process itself if no shutdown response was returned.

        Calling this prevents ANY further requests from being made to tronstore
        as the process will be killed.
        """
        if self.is_shutdown or not self.process.is_alive():
            self.pipe.close()
            self.is_shutdown = True
            return
        self.is_shutdown = True

        self.pipe.send_bytes(request.serialized)
        response = self._poll_for_response(request.id, self.SHUTDOWN_TIMEOUT)

        if not response or not response.success:
            log.error("tronstore failed to shut down cleanly.")

        self.pipe.close()
        # We can't actually use process.terminate(), as that sends a SIGTERM
        # to the process, which unfortunately is registered to do nothing
        # (as the process depends on trond to shut itself down, and shuts
        # itself down if trond is dead anyway.)
        # We want a hard kill regardless. The only way we should get to
        # this code is if tronstore is about to call os._exit(0) itself.
        try:
            os.kill(self.process.pid, signal.SIGKILL)
        except OSError:
            # The process already exited between the poll and the kill;
            # narrowed from a bare except so real errors aren't hidden.
            pass

    def update_config(self, new_config):
        """Update the configuration. Needed to make sure that tronstore
        is restarted with the correct configuration upon exiting
        prematurely."""
        self.config = new_config
+ """ + try: + if b'__tuple__' in obj: + return tuple(custom_decode(o) for o in obj['items']) + elif b'__datetime__' in obj: + obj = datetime.datetime.strptime(obj["as_str"], "%Y%m%dT%H:%M:%S.%f") + return obj + except: + return obj + + +def custom_encode(obj): + """A custom encoder for datetime and tuple objects.""" + if isinstance(obj, tuple): + return {'__tuple__': True, 'items': [custom_encode(e) for e in obj]} + elif isinstance(obj, datetime.datetime): + return {'__datetime__': True, 'as_str': obj.strftime("%Y%m%dT%H:%M:%S.%f")} + return obj + + +class SerializerModuleError(Exception): + """Raised if a serialization module is used without it being installed.""" + def __init__(self, code): + self.code = code + + def __str__(self): + return repr(self.code) + + +class JSONSerializer(object): + name = 'json' + + @classmethod + def serialize(cls, data): + return json.dumps(data, default=custom_encode, tuple_as_array=False) + + @classmethod + def deserialize(cls, data_str): + return json.loads(data_str, object_hook=custom_decode) + + +class cPickleSerializer(object): + name = 'pickle' + + @classmethod + def serialize(cls, data): + return pickle.dumps(data) + + @classmethod + def deserialize(cls, data_str): + return pickle.loads(data_str) + + +class MsgPackSerializer(object): + name = 'msgpack' + + @classmethod + def serialize(cls, data): + if no_msgpack: + raise SerializerModuleError('MessagePack not installed.') + return msgpack.packb(data, default=custom_encode) + + @classmethod + def deserialize(cls, data_str): + if no_msgpack: + raise SerializerModuleError('MessagePack not installed.') + return msgpack.unpackb(data_str, object_hook=custom_decode, use_list=0) + + +class YamlSerializer(object): + name = 'yaml' + + @classmethod + def serialize(cls, data): + if no_yaml: + raise SerializerModuleError('PyYaml not installed.') + return yaml.dump(data) + + @classmethod + def deserialize(cls, data_str): + if no_yaml: + raise SerializerModuleError('PyYaml not installed.') 
class NullStore(object):
    """No-op store used when no persistence backend is configured."""

    def save(self, key, state_data, data_type):
        return False

    def restore(self, key, data_type):
        return (False, None)

    def cleanup(self):
        pass

    def __repr__(self):
        return "NullStateStore"


class ShelveStore(object):
    """Store state using python's built-in shelve module."""

    def __init__(self, name, connection_details=None, serializer=None):
        # connection_details and serializer are accepted for interface
        # parity with the other store classes; shelve uses neither.
        self.fname = name
        self.shelve = shelve.open(self.fname)

    def _shelve_key(self, key, data_type):
        # Keys are namespaced by data type so identical identifiers of
        # different types can't collide.
        return '(%s__%s)' % (data_type, key)

    def save(self, key, state_data, data_type):
        self.shelve[self._shelve_key(key, data_type)] = state_data
        self.shelve.sync()
        return True

    def restore(self, key, data_type):
        """Return (True, data) if the key exists, else (False, None).

        Uses a membership test rather than truthiness so that saved-but-falsy
        state data (e.g. an empty dict) is still reported as present.
        """
        skey = self._shelve_key(key, data_type)
        if skey in self.shelve:
            return (True, self.shelve[skey])
        return (False, None)

    def cleanup(self):
        self.shelve.close()

    def __repr__(self):
        # BUG FIX: the attribute is fname; self.filename raised AttributeError.
        return "ShelveStateStore('%s')" % self.fname
class SQLStore(object):
    """Store state in a SQL database via SQLAlchemy.

    Creates the needed tables if they don't already exist. State blobs are
    serialized with the configured serializer before being written, and the
    serializer's name is stored beside each row so rows written with a
    different db_store_method can still be read back.
    """

    def __init__(self, name, connection_details, serializer):
        # Lazy import so SQLAlchemy is only required when this store is
        # configured. Declare the global BEFORE the assignment to avoid
        # the "assigned before global declaration" SyntaxWarning.
        global sql
        import sqlalchemy as sql
        assert sql

        self.name = name
        self._connection = None
        self.serializer = serializer
        # check_same_thread=False + StaticPool let the single connection be
        # shared by the tronstore worker threads (SQLite-oriented settings).
        self.engine = sql.create_engine(connection_details,
            connect_args={'check_same_thread': False},
            poolclass=sql.pool.StaticPool,
            encoding='ascii')
        self.engine.raw_connection().connection.text_factory = str
        self._setup_tables()

    def _build_table(self, table_name):
        """All four state tables share the same
        (key, state_data, serial_method) schema."""
        return sql.Table(table_name, self._metadata,
            sql.Column('key', sql.String(MAX_IDENTIFIER_LENGTH), primary_key=True),
            sql.Column('state_data', sql.LargeBinary),
            sql.Column('serial_method', sql.String(MAX_IDENTIFIER_LENGTH)))

    def _setup_tables(self):
        """Define the four state tables and create any that are missing."""
        self._metadata = sql.MetaData()
        self.job_state_table = self._build_table('job_state_data')
        self.service_table = self._build_table('service_data')
        self.job_run_table = self._build_table('job_run_data')
        self.metadata_table = self._build_table('metadata_table')
        # Map data types to tables once instead of an if/elif chain per call.
        self._table_map = {
            runstate.JOB_STATE: self.job_state_table,
            runstate.JOB_RUN_STATE: self.job_run_table,
            runstate.SERVICE_STATE: self.service_table,
            runstate.MCP_STATE: self.metadata_table,
        }
        self._metadata.create_all(self.engine)

    @contextmanager
    def connect(self):
        """Yield a live connection, reconnecting if the old one was closed."""
        if not self._connection or self._connection.closed:
            self._connection = self.engine.connect()
        yield self._connection

    def _get_table(self, data_type):
        """Return the table for data_type, or None if it is unknown."""
        return self._table_map.get(data_type)

    def save(self, key, state_data, data_type):
        """Serialize state_data and upsert it under key.

        Returns True on success, False for an unknown data_type.
        """
        with self.connect() as conn:
            table = self._get_table(data_type)
            if table is None:
                return False
            state_data = self.serializer.serialize(state_data)
            serial_method = self.serializer.name
            update_result = conn.execute(
                table.update()
                .where(table.c.key == key)
                .values(state_data=state_data,
                        serial_method=serial_method))
            # No row matched the key: fall back to an insert (manual upsert).
            if not update_result.rowcount:
                conn.execute(
                    table.insert()
                    .values(key=key, state_data=state_data,
                            serial_method=serial_method))
            return True

    def restore(self, key, data_type):
        """Return (True, deserialized_data) if key exists, else (False, None).

        Rows written with a different serializer are decoded with the
        serializer recorded beside them.
        """
        with self.connect() as conn:
            table = self._get_table(data_type)
            if table is None:
                return (False, None)
            result = conn.execute(sql.sql.select(
                [table.c.state_data, table.c.serial_method],
                table.c.key == key)
            ).fetchone()
            if not result:
                return (False, None)
            elif result[1] != self.serializer.name:
                # TODO: If/when we have logging in the Tronstore process,
                # log here that the db_store_method was different
                serializer = serialize_class_map[result[1]]
                return (True, serializer.deserialize(result[0]))
            else:
                return (True, self.serializer.deserialize(result[0]))

    def cleanup(self):
        """Close the cached connection, if one was ever opened."""
        if self._connection:
            self._connection.close()

    def __repr__(self):
        return "SQLStore(%s)" % self.name
class MongoStore(object):
    """Store state using mongoDB."""

    JOB_COLLECTION = 'job_state_collection'
    JOB_RUN_COLLECTION = 'job_run_state_collection'
    SERVICE_COLLECTION = 'service_state_collection'
    METADATA_COLLECTION = 'metadata_collection'

    TYPE_TO_COLLECTION_MAP = {
        runstate.JOB_STATE: JOB_COLLECTION,
        runstate.JOB_RUN_STATE: JOB_RUN_COLLECTION,
        runstate.SERVICE_STATE: SERVICE_COLLECTION,
        runstate.MCP_STATE: METADATA_COLLECTION
    }

    def __init__(self, name, connection_details, serializer=None):
        # Lazy import: pymongo is only required when this store is used.
        # global comes before the assignment to avoid a SyntaxWarning.
        global pymongo
        import pymongo
        assert pymongo

        self.db_name = name
        connection_params = self._parse_connection_details(connection_details)
        self._connect(connection_params)

    def _connect(self, params):
        """Open the connection and authenticate if credentials were given."""
        hostname = params.get('hostname')
        # Bug fix: int(params.get('port')) raised TypeError when no port was
        # configured; pass None through and let pymongo use its default.
        port = int(params['port']) if 'port' in params else None
        username = params.get('username')
        password = params.get('password')
        self.connection = pymongo.Connection(hostname, port)
        self.db = self.connection[self.db_name]
        if username and password:
            self.db.authenticate(username, password)

    def _parse_connection_details(self, connection_details):
        """Parse 'key=value&key2=value2' connection details into a dict."""
        return dict(urlparse.parse_qsl(connection_details)) if connection_details else {}

    def save(self, key, state_data, data_type):
        """Upsert state_data under key; returns True."""
        collection = self.db[self.TYPE_TO_COLLECTION_MAP[data_type]]
        # Copy before adding '_id' so the caller's dict isn't mutated.
        document = dict(state_data)
        document['_id'] = key
        collection.save(document)
        return True

    def restore(self, key, data_type):
        """Return (True, document) if key exists, else (False, None)."""
        value = self.db[self.TYPE_TO_COLLECTION_MAP[data_type]].find_one(key)
        return (True, value) if value else (False, None)

    def cleanup(self):
        self.connection.disconnect()

    def __repr__(self):
        return "MongoStore(%s)" % self.db_name


class YamlStore(object):
    # TODO: Deprecate this, it's bad
    """Store state in a local YAML file.

    WARNING: Using this is NOT recommended, even moreso than the previous
    version of this (yamlstore.py), since key/value pairs are now saved
    INDIVIDUALLY rather than in batches, meaning saves are SLOOOOOOOW.

    How slow, you ask? Converting a standard Shelve store from 0.6.1 into
    this object with test_config.yaml (and service_0 enabled) took about 4
    minutes. Going to a Shelve object instead took less than 5 seconds.

    Seriously, you shouldn't use this unless you're doing something
    really trivial and/or want a readable Yaml file.
    """

    TYPE_MAPPING = {
        runstate.JOB_STATE: 'jobs',
        runstate.JOB_RUN_STATE: 'job_runs',
        runstate.SERVICE_STATE: 'services',
        runstate.MCP_STATE: runstate.MCP_STATE
    }

    def __init__(self, filename, connection_details=None, serializer=None):
        # Lazy import: PyYAML is only required when this store is used.
        global yaml
        import yaml
        assert yaml

        self.filename = filename
        if not os.path.exists(self.filename):
            self.buffer = {}
        else:
            with open(self.filename, 'r') as fh:
                # NOTE(review): yaml.load trusts the file completely; the
                # state file must not be attacker-writable.
                self.buffer = yaml.load(fh)

    def save(self, key, state_data, data_type):
        """Record state_data in the buffer and rewrite the whole file."""
        self.buffer.setdefault(self.TYPE_MAPPING[data_type], {})[key] = state_data
        self._write_buffer()
        return True

    def _write_buffer(self):
        # Rewrites the entire file on every save; see the class docstring
        # for why this is painfully slow.
        with open(self.filename, 'w') as fh:
            yaml.dump(self.buffer, fh)

    def restore(self, key, data_type):
        """Return (True, value) if the key exists, else (False, None)."""
        value = self.buffer.get(self.TYPE_MAPPING[data_type], {}).get(key)
        return (True, value) if value else (False, None)

    def cleanup(self):
        pass

    def __repr__(self):
        return "YamlStore('%s')" % self.filename


# Maps the configured store_type to its implementation.
store_class_map = {
    "sql": SQLStore,
    "shelve": ShelveStore,
    "mongo": MongoStore,
    "yaml": YamlStore
}


def build_store(name, store_type, connection_details, db_store_method):
    """Construct the configured store.

    db_store_method is only meaningful for the sql store. The literal
    string "None" disables serialization (NOTE(review): this assumes the
    config layer passes the string "None", not the None object — confirm).
    """
    serial_class = serialize_class_map[db_store_method] if db_store_method != "None" else None
    return store_class_map[store_type](name, connection_details, serial_class)
+ """ + + def __init__(self, config): + """Parse the configuration file and set up the store class.""" + self.lock = Lock() + if not config: + self.store = NullStore() + + else: + name = config.name + store_type = config.store_type + connection_details = config.connection_details + db_store_method = config.db_store_method + + self.store = build_store(name, store_type, connection_details, + db_store_method) + + def save(self, *args, **kwargs): + with self.lock: + return self.store.save(*args, **kwargs) + + def restore(self, *args, **kwargs): + with self.lock: + return self.store.restore(*args, **kwargs) + + def cleanup(self): + with self.lock: + self.store.cleanup() + + def __repr__(self): + return "SyncStore('%s')" % self.store.__repr__() diff --git a/tron/serialize/runstate/tronstore/tronstore.py b/tron/serialize/runstate/tronstore/tronstore.py new file mode 100644 index 000000000..ea0639eb6 --- /dev/null +++ b/tron/serialize/runstate/tronstore/tronstore.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python +import time +import signal +import os +from threading import Thread, Lock +from Queue import Queue, Empty + +from tron.serialize.runstate.tronstore.messages import StoreRequestFactory, StoreResponseFactory +from tron.serialize.runstate.tronstore import store +from tron.serialize.runstate.tronstore import msg_enums + + +def _discard_signal(signum, frame): + pass + + +def _register_null_handlers(): + signal.signal(signal.SIGINT, _discard_signal) + signal.signal(signal.SIGHUP, _discard_signal) + signal.signal(signal.SIGTERM, _discard_signal) + + +def handle_requests(request_queue, resp_factory, pipe, store_class, do_work): + """Handle requests by acting on store_class with the appropriate action. + Requests are taken from request_queue until do_work.val (it should be a + PoolBool) is False. + This is run in a separate thread. 
+ """ + + # This should probably be lower rather than higher + WORK_TIMEOUT = 1.0 + + while do_work.val or not request_queue.empty(): + try: + request = request_queue.get(block=True, timeout=WORK_TIMEOUT) + + if request.req_type == msg_enums.REQUEST_SAVE: + store_class.save(request.data[0], request.data[1], request.data_type) + + elif request.req_type == msg_enums.REQUEST_RESTORE: + success, data = store_class.restore(request.data, request.data_type) + pipe.send_bytes(resp_factory.build(success, request.id, data).serialized) + + else: + pipe.send_bytes(resp_factory.build(False, request.id, '').serialized) + + except Empty: + continue + + +class SyncPipe(object): + """An object to handle synchronization over pipe operations. In particular, + the send and recv functions have a mutex as they are subject to + race conditions. + """ + + def __init__(self, pipe): + self.lock = Lock() + self.pipe = pipe + + # None is actually a valid timeout (blocks forever), so we need to use + # something different for checking for a non-supplied kwarg + def poll(self, *args, **kwargs): + return self.pipe.poll(*args, **kwargs) + + def send_bytes(self, *args, **kwargs): + with self.lock: + return self.pipe.send_bytes(*args, **kwargs) + + def recv_bytes(self, *args, **kwargs): + with self.lock: + return self.pipe.recv_bytes(*args, **kwargs) + + +class PoolBool(object): + """The PoolBool(TM) is a mutable boolean wrapper used for signaling.""" + + def __init__(self, value=True): + if not value in (True, False): + raise TypeError('expected boolean, got %r' % value) + self._val = value + + @property + def value(self): + return self._val + val = value + + def set(self, value): + if not value in (True, False): + raise TypeError('expected boolean, got %r' % value) + self._val = value + + +class TronstorePool(object): + """A thread pool with POOL_SIZE workers for handling requests. 
class TronstorePool(object):
    """A fixed-size pool of worker threads that consume save/restore
    requests from a shared queue and respond over the pipe.
    """

    POOL_SIZE = 16

    def __init__(self, resp_fact, pipe, store):
        """Set up the pool. Build a brand new pool whenever any of the
        objects handed in here changes.
        """
        self.request_queue = Queue()
        self.response_factory = resp_fact
        self.pipe = pipe
        self.store_class = store
        self.keep_working = PoolBool(True)
        worker_args = (self.request_queue, self.response_factory, self.pipe,
                       self.store_class, self.keep_working)
        self.thread_pool = [Thread(target=handle_requests, args=worker_args)
                            for _ in range(self.POOL_SIZE)]

    def start(self):
        """Mark the pool as active and launch every worker thread."""
        self.keep_working.set(True)
        for worker in self.thread_pool:
            worker.daemon = False
            worker.start()

    def stop(self):
        """Tell the workers to wind down, then block until the queue is
        drained and no worker thread is still running.
        """
        self.keep_working.set(False)
        while self.has_work() or any(w.is_alive() for w in self.thread_pool):
            time.sleep(0.5)

    def enqueue_work(self, work):
        """Hand a request off to the worker threads."""
        self.request_queue.put(work)

    def has_work(self):
        """Return True while unconsumed requests remain in the queue."""
        return not self.request_queue.empty()
+ """ + self.pipe = SyncPipe(pipe) + self.request_factory = StoreRequestFactory() + self.response_factory = StoreResponseFactory() + self.store_class = store.SyncStore(config) + self.thread_pool = TronstorePool(self.response_factory, self.pipe, + self.store_class) + self.is_shutdown = False + self.shutdown_req_id = None + self.config = config + + def _get_all_from_pipe(self): + """Gets all of the requests from the pipe, returning an array of serialized + requests (they still need to be decoded). + """ + requests = [] + while self.pipe.poll(): + requests.append(self.pipe.recv_bytes()) + return requests + + def _reconfigure(self, request): + """Reconfigures Tronstore by attempting to make a new store object + from the recieved configuration. If anything goes wrong, we revert + back to the old configuration. + """ + self.thread_pool.stop() + self.store_class.cleanup() + try: + self.store_class = store.SyncStore(request.data) + self.thread_pool = TronstorePool(self.response_factory, self.pipe, + self.store_class) + self.thread_pool.start() + self.config = request.data + self.pipe.send_bytes(self.response_factory.build(True, request.id, '').serialized) + except: + self.store_class = store.SyncStore(self.config) + self.thread_pool = TronstorePool(self.response_factory, self.pipe, + self.store_class) + self.thread_pool.start() + self.pipe.send_bytes(self.response_factory.build(False, request.id, '').serialized) + + def _handle_request(self, request): + """Handle a request by either doing something with it ourselves + (in the case of shutdown/config), or passing it to a worker in the + thread pool (for save/restore). + """ + if request.req_type == msg_enums.REQUEST_SHUTDOWN: + self.is_shutdown = True + self.shutdown_req_id = request.id + + elif request.req_type == msg_enums.REQUEST_CONFIG: + self._reconfigure(request) + + else: + self.thread_pool.enqueue_work(request) + + def _shutdown(self): + """Shutdown Tronstore. 
Calls os._exit, and should only be called + once all work has been completed. + """ + self.thread_pool.stop() + self.store_class.cleanup() + if self.shutdown_req_id: + shutdown_resp = self.response_factory.build(True, self.shutdown_req_id, '') + self.pipe.send_bytes(shutdown_resp.serialized) + os._exit(0) # Hard exit- should kill everything. + + def main_loop(self): + """The main Tronstore event loop. Starts the thread pool and then + simply polls for requests until a shutdown request is recieved, after + which it cleans up and exits. + """ + self.thread_pool.start() + + while True: + try: + if self.pipe.poll(self.POLL_TIMEOUT): + requests = self._get_all_from_pipe() + requests = map(self.request_factory.from_msg, requests) + for request in requests: + self._handle_request(request) + + elif self.is_shutdown: + self._shutdown() + + else: + # Did tron die? + try: + os.kill(os.getppid(), 0) + except: + self.is_shutdown = True + except IOError, e: + # Error #4 is a system interrupt, caused by ^C + if e.errno != 4: + raise + + +def main(config, pipe): + """The main method to start Tronstore with. Simply takes the configuration + and pipe objects, and then registers some null signal handlers before + passing everything off to TronstoreMain. + + This process is spawned by trond in order to offload state save/load + operations such that trond can focus on the more important things without + blocking for chunks of time. + + Messages are sent via Pipes (also part of python's multiprocessing module). + This allows for easy polling and no need to handle chunking of messages. + + The process intercepts the two shutdown signals (SIGINT and SIGTERM) in order + to prevent the process from exiting early when trond wants to do some final + shutdown things (realistically, trond should be handling all shutdown + operations, as this is a child process.) 
+ """ + + _register_null_handlers() + tronstore = TronstoreMain(config, pipe) + tronstore.main_loop() diff --git a/tron/trondaemon.py b/tron/trondaemon.py index bdbb508f4..077637714 100644 --- a/tron/trondaemon.py +++ b/tron/trondaemon.py @@ -154,6 +154,7 @@ def _build_context(self, options, context_class): signal.SIGINT: self._handle_graceful_shutdown, signal.SIGTERM: self._handle_shutdown, } + pidfile = PIDFile(options.pid_file) return context_class( working_directory=options.working_dir, @@ -199,7 +200,8 @@ def _run_mcp(self): def _run_reactor(self): """Run the twisted reactor.""" - self.reactor.run() + # Not setting this flag caused me 9 painful hours of debugging =( + self.reactor.run(installSignalHandlers=0) def _handle_shutdown(self, sig_num, stack_frame): log.info("Shutdown requested: sig %s" % sig_num)