diff --git a/tests/core/service_test.py b/tests/core/service_test.py index 847ff7113..01f12371e 100644 --- a/tests/core/service_test.py +++ b/tests/core/service_test.py @@ -149,6 +149,28 @@ def test_restore_state(self): self.instances.restore_state.return_value) self.service.enable.assert_called_with() + def test_update_node_pool_enabled(self): + autospec_method(self.service.repair) + self.service.enabled = True + node_pool = mock.Mock() + + self.service.update_node_pool(node_pool) + + self.service.instances.update_node_pool.assert_called_once_with(node_pool) + self.service.instances.clear_extra.assert_called_once_with() + self.service.repair.assert_called_once_with() + + def test_update_node_pool_disabled(self): + autospec_method(self.service.repair) + self.service.enabled = False + node_pool = mock.Mock() + + self.service.update_node_pool(node_pool) + + self.service.instances.update_node_pool.assert_called_once_with(node_pool) + self.service.instances.clear_extra.assert_called_once_with() + assert not self.service.repair.called + class ServiceCollectionTestCase(TestCase): @@ -162,25 +184,64 @@ def _add_service(self): self.collection.services.update( (serv.name, serv) for serv in self.service_list) + def test_build_with_new_config(self): + new_config = mock.Mock( + name='i_come_from_the_land_of_ice_and_snow', + count=42) + new_service = mock.Mock(config=new_config) + old_service = mock.Mock() + with mock.patch.object(self.collection, 'get_by_name', return_value=old_service) \ + as get_patch: + assert not self.collection._build(new_service) + assert not old_service.update_node_pool.called + get_patch.assert_called_once_with(new_config.name) + + def test_build_with_same_config(self): + config = mock.Mock( + name='hamsteak', + count=413) + old_service = mock.Mock(config=config) + new_service = mock.Mock(config=config) + with mock.patch.object(self.collection, 'get_by_name', return_value=old_service) \ + as get_patch: + assert self.collection._build(new_service) + get_patch.assert_called_once_with(config.name) + old_service.update_node_pool.assert_called_once_with(new_service.instances.node_pool) + assert_equal(old_service.instances.context, new_service.instances.context) + + def test_build_with_diff_count(self): + name = 'ni' + old_config = mock.Mock( + count=77) + new_config = mock.Mock( + count=1111111111111) + new_eq = lambda s, o: (s.name == o.name and s.count == o.count) + old_config.__eq__ = new_eq + new_config.__eq__ = new_eq + # We have to do this, since name is an actual kwarg for mock.Mock(). + old_config.name = name + new_config.name = name + old_service = mock.Mock(config=old_config) + new_service = mock.Mock(config=new_config) + with mock.patch.object(self.collection, 'get_by_name', return_value=old_service) \ + as get_patch: + assert self.collection._build(new_service) + get_patch.assert_called_once_with(new_service.config.name) + old_service.update_node_pool.assert_called_once_with(new_service.instances.node_pool) + assert_equal(old_service.instances.context, new_service.instances.context) + @mock.patch('tron.core.service.Service', autospec=True) def test_load_from_config(self, mock_service): autospec_method(self.collection.get_names) - autospec_method(self.collection.add) + autospec_method(self.collection.services.add) service_configs = {'a': mock.Mock(), 'b': mock.Mock()} context = mock.create_autospec(command_context.CommandContext) result = list(self.collection.load_from_config(service_configs, context)) expected = [mock.call(config, context) for config in service_configs.itervalues()] assert_mock_calls(expected, mock_service.from_config.mock_calls) - expected = [mock.call(s) for s in result] - assert_mock_calls(expected, self.collection.add.mock_calls) - - def test_add(self): - self.collection.services = mock.MagicMock() - service = mock.Mock() - result = self.collection.add(service) - self.collection.services.replace.assert_called_with(service) - assert_equal(result, self.collection.services.replace.return_value) + expected = [mock.call(s, self.collection._build) for s in result] + assert_mock_calls(expected, self.collection.services.add.mock_calls) def test_restore_state(self): state_count = 2 diff --git a/tests/core/serviceinstance_test.py b/tests/core/serviceinstance_test.py index ff75002e3..e4846a557 100644 --- a/tests/core/serviceinstance_test.py +++ b/tests/core/serviceinstance_test.py @@ -502,6 +502,48 @@ def test_get_by_number(self): instance = self.collection.get_by_number(3) assert_equal(instance, instances[3]) + def test_update_node_pool_same_pool(self): + self.collection.update_node_pool(self.collection.node_pool) + assert not self.collection.node_pool.get_by_name.called + + def test_update_node_pool_diff_pool_same_nodes(self): + new_instances = [mock.Mock(), mock.Mock()] + self.collection.instances = new_instances + nodes = [instance.node for instance in new_instances] + node_pool = mock.Mock(get_by_name=mock.Mock(side_effect=iter(nodes))) + + self.collection.update_node_pool(node_pool) + + assert_equal(self.collection.node_pool, node_pool) + calls = [mock.call(instance.node.name) for instance in new_instances] + node_pool.get_by_name.assert_calls(calls) + assert not any([instance.stop.called for instance in new_instances]) + assert_equal(self.collection.instances, new_instances) + + def test_update_node_pool_diff_everything(self): + new_instances = [mock.Mock(), mock.Mock()] + self.collection.instances = [mock.Mock(), mock.Mock()] + nodes = [instance.node for instance in new_instances] + node_pool = mock.Mock(get_by_name=mock.Mock(side_effect=iter(nodes))) + + self.collection.update_node_pool(node_pool) + + assert_equal(self.collection.node_pool, node_pool) + calls = [mock.call(instance.node.name) for instance in self.collection.instances] + node_pool.get_by_name.assert_calls(calls) + assert all([instance.stop.called for instance in self.collection.instances]) + assert_equal(self.collection.instances, []) + + def test_clear_extra(self): + instance_a = mock.Mock() + instance_b = mock.Mock() + instance_c = mock.Mock() + self.collection.instances = [instance_a, instance_b, instance_c] + self.collection.config.count = 2 + self.collection.clear_extra() + assert_equal(self.collection.instances, [instance_a, instance_b]) + instance_c.stop.assert_called_once_with() + if __name__ == "__main__": run() diff --git a/tests/trond_test.py b/tests/trond_test.py index 57cdff306..3aae26b17 100644 --- a/tests/trond_test.py +++ b/tests/trond_test.py @@ -153,7 +153,7 @@ def test_node_reconfig(self): self.sandbox.tronfig(second_config) sandbox.wait_on_state(self.client.service, service_url, - service.ServiceState.DISABLED) + service.ServiceState.FAILED) job_url = self.client.get_url('MASTER.a_job') def wait_on_next_run(): diff --git a/tron/core/service.py b/tron/core/service.py index c9493f3b7..1034a4c2f 100644 --- a/tron/core/service.py +++ b/tron/core/service.py @@ -156,6 +156,12 @@ def restore_state(self, state_data): (self.enable if state_data.get('enabled') else self.disable)() self.event_recorder.info("restored") + def update_node_pool(self, node_pool): + self.instances.update_node_pool(node_pool) + self.instances.clear_extra() + if self.enabled: + self.repair() + class ServiceCollection(object): """A collection of services.""" @@ -163,21 +169,46 @@ class ServiceCollection(object): def __init__(self): self.services = collections.MappingCollection('services') + def _build(self, new_service): + """A method to be used as an update function for MappingCollection.add. + This function attempts to load an old Service object, and if one + exists, see if we don't actually have to use an entirely new + Service object on reconfiguration. + + To do this, we first check if the number of instances (config.count) is + different, as we have a method to fix this when updating the service's + node pool. Then, if the configs are now equal, we can simply update + the node pool of the old Service object and be done- no need for the + new Service object. Otherwise, we use the new object as normal. + """ + old_service = self.get_by_name(new_service.config.name) + + if not old_service: + log.debug("Building new service %s", new_service.config.name) + return False + + if old_service.config.count != new_service.config.count: + old_service.config.count = new_service.config.count + + if old_service.config == new_service.config: + log.debug("Updating service %s\'s node pool" % new_service.config.name) + old_service.instances.context = new_service.instances.context + old_service.update_node_pool(new_service.instances.node_pool) + return True + else: + log.debug("Building new service %s", new_service.config.name) + old_service.disable() + return False + def load_from_config(self, service_configs, context): """Apply a configuration to this collection and return a generator of services which were added. """ self.services.filter_by_name(service_configs.keys()) - def build(config): - log.debug("Building new service %s", config.name) - return Service.from_config(config, context) - - seq = (build(config) for config in service_configs.itervalues()) - return itertools.ifilter(self.add, seq) - - def add(self, service): - return self.services.replace(service) + seq = (Service.from_config(config, context) + for config in service_configs.itervalues()) + return itertools.ifilter(lambda e: self.services.add(e, self._build), seq) def restore_state(self, service_state_data): self.services.restore_state(service_state_data) diff --git a/tron/core/serviceinstance.py b/tron/core/serviceinstance.py index d56e0218f..abc34b8e6 100644 --- a/tron/core/serviceinstance.py +++ b/tron/core/serviceinstance.py @@ -461,6 +461,30 @@ def get_by_number(self, instance_number): if instance.instance_number == instance_number: return instance + def update_node_pool(self, node_pool): + """Attempt to load a new node pool from the NodePoolRepository, and + remove instances that no longer have their node in the NodePool.""" + if node_pool == self.node_pool: + return + + self.node_pool = node_pool + + def _trim_old_nodes(): + for instance in self.instances: + new_node = self.node_pool.get_by_name(instance.node.name) + if new_node != instance.node: + instance.stop() + else: + yield instance + + self.instances = list(_trim_old_nodes()) + + def clear_extra(self): + """Clear out instances if too many exist.""" + for i in range(0, self.missing, -1): + instance = self.instances.pop() + instance.stop() + @property def missing(self): return self.config.count - len(self.instances)