Skip to content

Commit

Permalink
Merge pull request apel#333 from tofu-rocketry/fix-mysql-tests
Browse files Browse the repository at this point in the history
Fix mysql tests
  • Loading branch information
tofu-rocketry authored Oct 26, 2023
2 parents 4351d5a + d8a4cdd commit c3261bf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
42 changes: 21 additions & 21 deletions test/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ def setUp(self):
subprocess.call(['mysql', '-u', 'root', 'apel_unittest'], stdin=schema_handle)
schema_handle.close()

self.db = apel.db.apeldb.ApelDb('mysql', 'localhost', 3306, 'root', '',
'apel_unittest')
self.apel_db = apel.db.apeldb.ApelDb('mysql', 'localhost', 3306, 'root', '',
'apel_unittest')

# This method seems to run really slowly on Travis CI
#def tearDown(self):
# query = "DROP DATABASE apel_unittest;"
# subprocess.call(['mysql', '-u', 'root', '-e', query])
def tearDown(self):
self.apel_db.db.close()
query = "DROP DATABASE apel_unittest;"
subprocess.call(['mysql', '-u', 'root', '-e', query])

def test_test_connection(self):
"""Basic check that test_connection works without error."""
self.db.test_connection()
self.apel_db.test_connection()

def test_bad_connection(self):
"""Check that initialising ApelDb fails if a bad password is used."""
Expand All @@ -53,15 +53,15 @@ def test_lost_connection(self):
Simulate the lost connection by changing the host.
"""
self.db._db_host = 'badhost'
self.apel_db._db_host = 'badhost'
self.assertRaises(apel.db.apeldb.ApelDbException,
self.db.test_connection)
self.apel_db.test_connection)

def test_bad_loads(self):
"""Check that empty loads return None and bad types raise exception."""
self.assertTrue(self.db.load_records([], source='testDN') is None)
self.assertIsNone(self.apel_db.load_records([], source='testDN'))
self.assertRaises(apel.db.apeldb.ApelDbException,
self.db.load_records, [1234], source='testDN')
self.apel_db.load_records, [1234], source='testDN')

def test_load_and_get_job(self):
job = apel.db.records.job.JobRecord()
Expand All @@ -76,9 +76,9 @@ def test_load_and_get_job(self):
record_list = [job]
# load_records changes the 'job' job record as it calls _check_fields
# which adds placeholders to empty fields
self.db.load_records(record_list, source='testDN')
self.apel_db.load_records(record_list, source='testDN')

records_out = self.db.get_records(apel.db.records.job.JobRecord)
records_out = self.apel_db.get_records(apel.db.records.job.JobRecord)
items_out = list(records_out)[0][0]._record_content.items()
# Check that items_in is a subset of items_out
# Can't use 'all()' rather than comparing the length as Python 2.4
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_load_and_get_cloud(self):
# load_records changes the 'cloud' cloud record as it calls _check_fields
# which adds placeholders to empty fields
try:
self.db.load_records(record_list, source='testDN')
self.apel_db.load_records(record_list, source='testDN')
except apel.db.apeldb.ApelDbException as err:
self.fail(err.message)

Expand All @@ -138,7 +138,7 @@ def test_load_and_get_cloud(self):
# a cloud 0.2 message has Benchmark 'None', which gets saved to 0.0
# in the database

# records_out = self.db.get_records(apel.db.records.cloud.CloudRecord)
# records_out = self.apel_db.get_records(apel.db.records.cloud.CloudRecord)
# # record_out_list is a list of lists, i.e. [[record0.2],[[record0.4]]
# record_out_list = list(records_out)
# items_out = []
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_mixed_load(self):
record_list = [job, summary]

self.assertRaises(apel.db.apeldb.ApelDbException,
self.db.load_records, record_list, source='testDN')
self.apel_db.load_records, record_list, source='testDN')

def test_mixed_storage_records(self):
"""
Expand Down Expand Up @@ -205,8 +205,8 @@ def test_mixed_storage_records(self):

# Try loading both with and without a source set. Both record types
# should ignore that field.
self.db.load_records(record_list, source='testDN')
self.db.load_records(record_list)
self.apel_db.load_records(record_list, source='testDN')
self.apel_db.load_records(record_list)

def test_last_update(self):
"""
Expand All @@ -215,9 +215,9 @@ def test_last_update(self):
It should not be set initially, so should return None, then should
return a time after being set.
"""
self.assertTrue(self.db.get_last_updated() is None)
self.assertTrue(self.db.set_updated())
self.assertTrue(type(self.db.get_last_updated()) is datetime.datetime)
self.assertIsNone(self.apel_db.get_last_updated())
self.assertTrue(self.apel_db.set_updated())
self.assertIs(type(self.apel_db.get_last_updated()), datetime.datetime)

CLOUD2 = '''VMUUID: 12345 Site1 vm-1
SiteName: Site1
Expand Down
4 changes: 4 additions & 0 deletions test/test_republish.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def test_cloud_republish(self):
# Now check the database to see which record has been saved.
self._check_measurement_time_equals(expected_measurement_time)

# Clean up DB connection and schema.
database.db.close()
call(['mysql', '-u', 'root', '-e', "DROP DATABASE apel_unittest;"])

def _check_measurement_time_equals(self, expected_measurement_time):
"""
Check MeasurementTime in database is what we would expect.
Expand Down

0 comments on commit c3261bf

Please sign in to comment.