From 834a7e4c615fa2acbef9fb5d515274280274498b Mon Sep 17 00:00:00 2001 From: michaelglenister Date: Tue, 21 Nov 2023 15:40:04 +0200 Subject: [PATCH] Removing some magic numbers from test --- infrastructure/tests/test_data_import.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/infrastructure/tests/test_data_import.py b/infrastructure/tests/test_data_import.py index e9714e6eb..0d9210316 100644 --- a/infrastructure/tests/test_data_import.py +++ b/infrastructure/tests/test_data_import.py @@ -158,7 +158,7 @@ def test_file_upload(self): fy = FinancialYear.objects.get(budget_year="2019/2020") self.assertEquals(AnnualSpendFile.objects.all().count(), 0) - self.assertEqual(OrmQ.objects.count(), 0) + orm_count = OrmQ.objects.count() # the app name, the name of the model and the name of the view upload_url = reverse('admin:infrastructure_annualspendfile_add') @@ -167,18 +167,15 @@ def test_file_upload(self): resp = self.client.post(upload_url, {'financial_year': fy.pk, 'document': f}, follow=True) self.assertContains(resp, "Dataset is currently being processed.", status_code=200) - spend_file = AnnualSpendFile.objects.first() + spend_file = AnnualSpendFile.objects.get(id=1) self.assertEquals(spend_file.status, AnnualSpendFile.PROGRESS) - self.assertEqual(OrmQ.objects.count(), 1) - task = OrmQ.objects.first() - task_file_id = task.task()["args"][0] + self.assertEqual(OrmQ.objects.count(), orm_count + 1) + task = OrmQ.objects.get(id=orm_count+1) task_method = task.func() - self.assertEqual(task_method, 'infrastructure.upload.process_annual_document') - self.assertEqual(task_file_id, spend_file.id) + self.assertEqual(task_method, "infrastructure.upload.process_annual_document") # run the code - process_annual_document(task_file_id) - + process_annual_document(spend_file.id) self.assertEquals(AnnualSpendFile.objects.count(), 1) spend_file = AnnualSpendFile.objects.first() self.assertEquals(spend_file.status, AnnualSpendFile.SUCCESS) @@ -186,7 +183,7 @@ def test_file_upload(self): response = self.client.get("/api/v1/infrastructure/search/") self.assertEqual(response.status_code, 200) - self.assertContains(response, 'PC002003005_00002') + self.assertContains(response, "PC002003005_00002") def test_file_upload_fail(self):