diff --git a/tests/core/service_test.py b/tests/core/service_test.py index 24d2a4afa..5868e190f 100644 --- a/tests/core/service_test.py +++ b/tests/core/service_test.py @@ -152,18 +152,36 @@ def test_restore_state(self): def test_update_node_pool_enabled(self): autospec_method(self.service.repair) self.service.enabled = True - self.service.update_node_pool() - self.service.instances.update_node_pool.assert_called_once_with() - self.service.instances.clear_extra.assert_called_once_with() - self.service.repair.assert_called_once_with() + node_pool = mock.Mock() + + with mock.patch('tron.node.NodePoolRepository', autospec=True) as pool_patch: + get_mock = pool_patch.get_instance().get_by_name + get_mock.configure_mock(return_value=node_pool) + + self.service.update_node_pool() + + assert_equal(pool_patch.get_instance.call_count, 2) + get_mock.assert_called_once_with(self.service.config.node) + self.service.instances.update_node_pool.assert_called_once_with(node_pool) + self.service.instances.clear_extra.assert_called_once_with() + self.service.repair.assert_called_once_with() def test_update_node_pool_disabled(self): autospec_method(self.service.repair) self.service.enabled = False - self.service.update_node_pool() - self.service.instances.update_node_pool.assert_called_once_with() - self.service.instances.clear_extra.assert_called_once_with() - assert not self.service.repair.called + node_pool = mock.Mock() + + with mock.patch('tron.node.NodePoolRepository', autospec=True) as pool_patch: + get_mock = pool_patch.get_instance().get_by_name + get_mock.configure_mock(return_value=node_pool) + + self.service.update_node_pool() + + assert_equal(pool_patch.get_instance.call_count, 2) + get_mock.assert_called_once_with(self.service.config.node) + self.service.instances.update_node_pool.assert_called_once_with(node_pool) + self.service.instances.clear_extra.assert_called_once_with() + assert not self.service.repair.called class ServiceCollectionTestCase(TestCase): @@ -227,7 +245,7 @@ def test_build_with_diff_count(self): def test_load_from_config(self): autospec_method(self.collection.get_names) - autospec_method(self.collection.add) + autospec_method(self.collection.services.replace) autospec_method(self.collection._build) service_configs = {'a': mock.Mock(), 'b': mock.Mock()} context = mock.create_autospec(command_context.CommandContext) @@ -237,14 +255,7 @@ def test_load_from_config(self): for config in service_configs.itervalues()] build_patch.assert_calls(expected) expected = [mock.call(s) for s in result] - assert_mock_calls(expected, self.collection.add.mock_calls) - - def test_add(self): - self.collection.services = mock.MagicMock() - service = mock.Mock() - result = self.collection.add(service) - self.collection.services.replace.assert_called_with(service) - assert_equal(result, self.collection.services.replace.return_value) + assert_mock_calls(expected, self.collection.services.replace.mock_calls) def test_restore_state(self): state_count = 2 diff --git a/tests/core/serviceinstance_test.py b/tests/core/serviceinstance_test.py index 72ed3db1e..e4846a557 100644 --- a/tests/core/serviceinstance_test.py +++ b/tests/core/serviceinstance_test.py @@ -503,12 +503,8 @@ def test_get_by_number(self): assert_equal(instance, instances[3]) def test_update_node_pool_same_pool(self): - with mock.patch('tron.node.NodePoolRepository', autospec=True) as pool_patch: - get_mock = pool_patch.get_instance().get_by_name - get_mock.configure_mock(return_value=self.collection.node_pool) - self.collection.update_node_pool() - assert_equal(pool_patch.get_instance.call_count, 2) - get_mock.assert_called_once_with(self.collection.config.node) + self.collection.update_node_pool(self.collection.node_pool) + assert not self.collection.node_pool.get_by_name.called def test_update_node_pool_diff_pool_same_nodes(self): new_instances = [mock.Mock(), mock.Mock()] @@ -516,19 +512,13 @@ def test_update_node_pool_diff_pool_same_nodes(self): nodes = [instance.node for instance in new_instances] node_pool = mock.Mock(get_by_name=mock.Mock(side_effect=iter(nodes))) - with mock.patch('tron.node.NodePoolRepository', autospec=True) as pool_patch: - get_mock = pool_patch.get_instance().get_by_name - get_mock.configure_mock(return_value=node_pool) + self.collection.update_node_pool(node_pool) - self.collection.update_node_pool() - - assert_equal(pool_patch.get_instance.call_count, 2) - get_mock.assert_called_once_with(self.collection.config.node) - assert_equal(self.collection.node_pool, node_pool) - calls = [mock.call(instance.node.name) for instance in new_instances] - node_pool.get_by_name.assert_calls(calls) - assert not any([instance.stop.called for instance in new_instances]) - assert_equal(self.collection.instances, new_instances) + assert_equal(self.collection.node_pool, node_pool) + calls = [mock.call(instance.node.name) for instance in new_instances] + node_pool.get_by_name.assert_calls(calls) + assert not any([instance.stop.called for instance in new_instances]) + assert_equal(self.collection.instances, new_instances) def test_update_node_pool_diff_everything(self): new_instances = [mock.Mock(), mock.Mock()] @@ -536,19 +526,13 @@ def test_update_node_pool_diff_everything(self): nodes = [instance.node for instance in new_instances] node_pool = mock.Mock(get_by_name=mock.Mock(side_effect=iter(nodes))) - with mock.patch('tron.node.NodePoolRepository', autospec=True) as pool_patch: - get_mock = pool_patch.get_instance().get_by_name - get_mock.configure_mock(return_value=node_pool) - - self.collection.update_node_pool() + self.collection.update_node_pool(node_pool) - assert_equal(pool_patch.get_instance.call_count, 2) - get_mock.assert_called_once_with(self.collection.config.node) - assert_equal(self.collection.node_pool, node_pool) - calls = [mock.call(instance.node.name) for instance in self.collection.instances] - node_pool.get_by_name.assert_calls(calls) - assert all([instance.stop.called for instance in self.collection.instances]) - assert_equal(self.collection.instances, []) + assert_equal(self.collection.node_pool, node_pool) + calls = [mock.call(instance.node.name) for instance in self.collection.instances] + node_pool.get_by_name.assert_calls(calls) + assert all([instance.stop.called for instance in self.collection.instances]) + assert_equal(self.collection.instances, []) def test_clear_extra(self): instance_a = mock.Mock() diff --git a/tron/core/service.py b/tron/core/service.py index 58c010266..94e37332b 100644 --- a/tron/core/service.py +++ b/tron/core/service.py @@ -157,7 +157,10 @@ def restore_state(self, state_data): self.event_recorder.info("restored") def update_node_pool(self): - self.instances.update_node_pool() + node_repo = node.NodePoolRepository.get_instance() + node_pool = node_repo.get_by_name(self.config.node) + + self.instances.update_node_pool(node_pool) self.instances.clear_extra() if self.enabled: self.repair() @@ -190,10 +193,7 @@ def load_from_config(self, service_configs, context): seq = (self._build(config, context) for config in service_configs.itervalues()) - return itertools.ifilter(self.add, seq) - - def add(self, service): - return self.services.replace(service) + return itertools.ifilter(self.services.replace, seq) def restore_state(self, service_state_data): self.services.restore_state(service_state_data) diff --git a/tron/core/serviceinstance.py b/tron/core/serviceinstance.py index e260973c5..5e063a397 100644 --- a/tron/core/serviceinstance.py +++ b/tron/core/serviceinstance.py @@ -461,12 +461,9 @@ def get_by_number(self, instance_number): if instance.instance_number == instance_number: return instance - def update_node_pool(self): + def update_node_pool(self, node_pool): """Attempt to load a new node pool from the NodePoolRepository, and remove instances that no longer have their node in the NodePool.""" - node_repo = node.NodePoolRepository.get_instance() - node_pool = node_repo.get_by_name(self.config.node) - if node_pool != self.node_pool: self.node_pool = node_pool needs_new_node = []