diff --git a/fiasco/conftest.py b/fiasco/conftest.py index 5843f5ab..a5e2e6e2 100644 --- a/fiasco/conftest.py +++ b/fiasco/conftest.py @@ -1,6 +1,9 @@ +import numpy as np import pathlib import pytest +from packaging.version import Version + from fiasco.util import check_database, read_chianti_version # Force MPL to use non-gui backends for testing. @@ -160,14 +163,23 @@ def requires_dbase_version(request, dbase_version): # NOTE: Fixtures that depend on other fixtures are awkward to implement. # See this SO answer: https://stackoverflow.com/a/28198398 if marker := request.node.get_closest_marker('requires_dbase_version'): - version_condition = marker.args[0] - # NOTE: This currently only works for major versions. If we - # need tests that discriminate between minor versions or patches, - # this logic will need to be added. - conditional = f'{dbase_version["major"]}{version_condition}' - allowed_dbase_version = eval(conditional) + # NOTE: This has to have a space between the operator and the target + if len(marker.args) != 2: + raise ValueError("Arguments must contain a condition and a version number, e.g. '<', '8.0.7'") + operator, target_version = marker.args + op_dict = {'<': np.less, + '<=': np.less_equal, + '>': np.greater, + '>=': np.greater_equal, + '=': np.equal, + '!=': np.not_equal} + if operator not in op_dict: + raise ValueError(f'''{operator} is not a supported comparison operation. + Must be one of {list(op_dict.keys())}.''') + target_version = Version(target_version) + allowed_dbase_version = op_dict[operator](dbase_version, target_version) if not allowed_dbase_version: - pytest.skip(f'Test skipped because current dbase version {conditional}.') + pytest.skip(f'Skip because database version {dbase_version} is not {operator} {target_version}.') def pytest_configure(config): diff --git a/fiasco/io/sources/tests/test_sources.py b/fiasco/io/sources/tests/test_sources.py index 02d039b2..bd1d291d 100644 --- a/fiasco/io/sources/tests/test_sources.py +++ b/fiasco/io/sources/tests/test_sources.py @@ -22,8 +22,8 @@ 'fe_2.trparams', 'fe_12.drparams', 'al_3.diparams', - pytest.param('fe_23.auto', marks=pytest.mark.requires_dbase_version('>=9')), - pytest.param('fe_23.rrlvl', marks=pytest.mark.requires_dbase_version('>=9')), + pytest.param('fe_23.auto', marks=pytest.mark.requires_dbase_version('>=', '9')), + pytest.param('fe_23.rrlvl', marks=pytest.mark.requires_dbase_version('>=', '9')), ]) def test_ion_sources(ascii_dbase_root, filename,): parser = fiasco.io.Parser(filename, ascii_dbase_root=ascii_dbase_root) diff --git a/fiasco/tests/test_ion.py b/fiasco/tests/test_ion.py index 13c5c8ae..2e42f227 100644 --- a/fiasco/tests/test_ion.py +++ b/fiasco/tests/test_ion.py @@ -244,7 +244,7 @@ def test_level_populations_normalized(pops_no_correction, pops_with_correction): assert u.allclose(pops_no_correction.sum(axis=1), 1, atol=None, rtol=1e-15) -@pytest.mark.requires_dbase_version('<=8') +@pytest.mark.requires_dbase_version('<=','8') def test_level_populations_correction(fe20, pops_no_correction, pops_with_correction): # Test level-resolved correction applied to correct levels i_corrected = np.unique(np.concatenate([fe20._cilvl['upper_level'], fe20._reclvl['upper_level']]))