From a96b663b5f6435ceb8199da936c7109ffd0c1a75 Mon Sep 17 00:00:00 2001 From: Shezad Khan Date: Thu, 16 Nov 2023 00:33:21 +0000 Subject: [PATCH] fix: use column names for database mappers (#335) * fix: use column names for database mappers When writing results where the metrics include 'confusion_matrix', only the first column name is written. In the case of the confusion_matrix it is "true_positive". The desired behaviour is to write all column values. * Apply same practice for CBPE mapper * Add tests dealing with result components --------- Co-authored-by: Niels Nuyttens --- nannyml/io/db/mappers.py | 6 ++++-- tests/io/test_writers.py | 38 ++++++++++++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/nannyml/io/db/mappers.py b/nannyml/io/db/mappers.py index 2254124de..65030ee2c 100644 --- a/nannyml/io/db/mappers.py +++ b/nannyml/io/db/mappers.py @@ -234,7 +234,9 @@ def _parse( res: List[DbMetric] = [] - for metric in [metric.column_name for metric in result.metrics]: + column_names = [column_name for metric in result.metrics for column_name in metric.column_names] + + for metric in column_names: res += ( result.filter(partition='analysis') .to_df()[ @@ -288,7 +290,7 @@ def _parse( res: List[Metric] = [] - for metric in [component[1] for metric in result.metrics for component in metric.components]: + for metric in [column_name for metric in result.metrics for column_name in metric.column_names]: res += ( result.filter(period='analysis') .to_df()[ diff --git a/tests/io/test_writers.py b/tests/io/test_writers.py index b05379ef1..932c79469 100644 --- a/tests/io/test_writers.py +++ b/tests/io/test_writers.py @@ -112,7 +112,7 @@ def realized_performance_for_binary_classification_result(): y_true='work_home_actual', problem_type='classification_binary', timestamp_column_name='timestamp', - metrics=['roc_auc', 'f1'], + metrics=['roc_auc', 'f1', 'confusion_matrix'], ).fit(reference_df) result = calc.calculate(analysis_df.merge(analysis_targets_df, on='identifier')) return result @@ -160,7 +160,7 @@ def cbpe_estimated_performance_for_binary_classification_result(): y_true='work_home_actual', problem_type='classification_binary', timestamp_column_name='timestamp', - metrics=['roc_auc', 'f1'], + metrics=['roc_auc', 'f1', 'confusion_matrix'], ).fit(reference_df) result = calc.estimate(analysis_df.merge(analysis_targets_df, on='identifier')) return result @@ -355,14 +355,14 @@ def test_pickle_file_writer_raises_no_exceptions_when_writing(result): 'data_reconstruction_feature_drift_metrics', 10, ), - (lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 40), + (lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 120), ( lazy_fixture('realized_performance_for_multiclass_classification_result'), 'realized_performance_metrics', 40, ), (lazy_fixture('realized_performance_for_regression_result'), 'realized_performance_metrics', 40), - (lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 20), + (lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 60), ( lazy_fixture('cbpe_estimated_performance_for_multiclass_classification_result'), 'cbpe_performance_metrics', @@ -389,3 +389,33 @@ def test_database_writer_exports_correctly(result, table_name, expected_row_coun finally: os.remove('test.db') + + +@pytest.mark.parametrize( + 'result, table_name', + [ + (lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics'), + (lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics'), + ], +) +def test_database_writer_deals_with_metric_components(result, table_name): + try: + writer = DatabaseWriter(connection_string='sqlite:///test.db', model_name='test') + writer.write(result.filter(metrics=['confusion_matrix'])) + + import sqlite3 + + with sqlite3.connect("test.db", uri=True) as db: + res = db.cursor().execute(f"SELECT DISTINCT metric_name FROM {table_name}").fetchall() + sut = [row[0] for row in res] + + assert 'true_positive' in sut + assert 'false_positive' in sut + assert 'true_negative' in sut + assert 'false_negative' in sut + + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + finally: + os.remove('test.db')