Skip to content

Commit

Permalink
Add tests dealing with result components
Browse files Browse the repository at this point in the history
  • Loading branch information
nnansters committed Nov 16, 2023
1 parent 30c117e commit 15876c6
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions tests/io/test_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def realized_performance_for_binary_classification_result():
y_true='work_home_actual',
problem_type='classification_binary',
timestamp_column_name='timestamp',
metrics=['roc_auc', 'f1'],
metrics=['roc_auc', 'f1', 'confusion_matrix'],
).fit(reference_df)
result = calc.calculate(analysis_df.merge(analysis_targets_df, on='identifier'))
return result
Expand Down Expand Up @@ -160,7 +160,7 @@ def cbpe_estimated_performance_for_binary_classification_result():
y_true='work_home_actual',
problem_type='classification_binary',
timestamp_column_name='timestamp',
metrics=['roc_auc', 'f1'],
metrics=['roc_auc', 'f1', 'confusion_matrix'],
).fit(reference_df)
result = calc.estimate(analysis_df.merge(analysis_targets_df, on='identifier'))
return result
Expand Down Expand Up @@ -355,14 +355,14 @@ def test_pickle_file_writer_raises_no_exceptions_when_writing(result):
'data_reconstruction_feature_drift_metrics',
10,
),
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 40),
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 120),
(
lazy_fixture('realized_performance_for_multiclass_classification_result'),
'realized_performance_metrics',
40,
),
(lazy_fixture('realized_performance_for_regression_result'), 'realized_performance_metrics', 40),
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 20),
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 60),
(
lazy_fixture('cbpe_estimated_performance_for_multiclass_classification_result'),
'cbpe_performance_metrics',
Expand All @@ -389,3 +389,33 @@ def test_database_writer_exports_correctly(result, table_name, expected_row_coun

finally:
os.remove('test.db')


@pytest.mark.parametrize(
'result, table_name',
[
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics'),
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics'),
],
)
def test_database_writer_deals_with_metric_components(result, table_name):
try:
writer = DatabaseWriter(connection_string='sqlite:///test.db', model_name='test')
writer.write(result.filter(metrics=['confusion_matrix']))

import sqlite3

with sqlite3.connect("test.db", uri=True) as db:
res = db.cursor().execute(f"SELECT DISTINCT metric_name FROM {table_name}").fetchall()
sut = [row[0] for row in res]

assert 'true_positive' in sut
assert 'false_positive' in sut
assert 'true_negative' in sut
assert 'false_negative' in sut

except Exception as exc:
pytest.fail(f"an unexpected exception occurred: {exc}")

finally:
os.remove('test.db')

0 comments on commit 15876c6

Please sign in to comment.