Skip to content

Commit

Permalink
fix: use column names for database mappers (#335)
Browse files Browse the repository at this point in the history
* fix: use column names for database mappers

When writing results whose metrics include 'confusion_matrix', only the
first column name of each metric is written — for the confusion matrix
that is "true_positive". The desired behaviour is to write the values of
all columns.

* Apply same practice for CBPE mapper

* Add tests dealing with result components

---------

Co-authored-by: Niels Nuyttens <niels@nannyml.com>
  • Loading branch information
shezadkhan137 and nnansters authored Nov 16, 2023
1 parent bb28916 commit a96b663
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
6 changes: 4 additions & 2 deletions nannyml/io/db/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def _parse(

res: List[DbMetric] = []

for metric in [metric.column_name for metric in result.metrics]:
column_names = [column_name for metric in result.metrics for column_name in metric.column_names]

for metric in column_names:
res += (
result.filter(partition='analysis')
.to_df()[
Expand Down Expand Up @@ -288,7 +290,7 @@ def _parse(

res: List[Metric] = []

for metric in [component[1] for metric in result.metrics for component in metric.components]:
for metric in [column_name for metric in result.metrics for column_name in metric.column_names]:
res += (
result.filter(period='analysis')
.to_df()[
Expand Down
38 changes: 34 additions & 4 deletions tests/io/test_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def realized_performance_for_binary_classification_result():
y_true='work_home_actual',
problem_type='classification_binary',
timestamp_column_name='timestamp',
metrics=['roc_auc', 'f1'],
metrics=['roc_auc', 'f1', 'confusion_matrix'],
).fit(reference_df)
result = calc.calculate(analysis_df.merge(analysis_targets_df, on='identifier'))
return result
Expand Down Expand Up @@ -160,7 +160,7 @@ def cbpe_estimated_performance_for_binary_classification_result():
y_true='work_home_actual',
problem_type='classification_binary',
timestamp_column_name='timestamp',
metrics=['roc_auc', 'f1'],
metrics=['roc_auc', 'f1', 'confusion_matrix'],
).fit(reference_df)
result = calc.estimate(analysis_df.merge(analysis_targets_df, on='identifier'))
return result
Expand Down Expand Up @@ -355,14 +355,14 @@ def test_pickle_file_writer_raises_no_exceptions_when_writing(result):
'data_reconstruction_feature_drift_metrics',
10,
),
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 40),
(lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics', 120),
(
lazy_fixture('realized_performance_for_multiclass_classification_result'),
'realized_performance_metrics',
40,
),
(lazy_fixture('realized_performance_for_regression_result'), 'realized_performance_metrics', 40),
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 20),
(lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics', 60),
(
lazy_fixture('cbpe_estimated_performance_for_multiclass_classification_result'),
'cbpe_performance_metrics',
Expand All @@ -389,3 +389,33 @@ def test_database_writer_exports_correctly(result, table_name, expected_row_coun

finally:
os.remove('test.db')


@pytest.mark.parametrize(
    'result, table_name',
    [
        (lazy_fixture('realized_performance_for_binary_classification_result'), 'realized_performance_metrics'),
        (lazy_fixture('cbpe_estimated_performance_for_binary_classification_result'), 'cbpe_performance_metrics'),
    ],
)
def test_database_writer_deals_with_metric_components(result, table_name):
    """Verify the DatabaseWriter writes a row per metric *component*.

    Multi-component metrics such as the confusion matrix expose several
    column names ('true_positive', 'false_positive', ...); the mapper must
    emit all of them as distinct metric names, not just the first one
    (regression test for #335).
    """
    import sqlite3

    try:
        writer = DatabaseWriter(connection_string='sqlite:///test.db', model_name='test')
        writer.write(result.filter(metrics=['confusion_matrix']))

        with sqlite3.connect("test.db", uri=True) as db:
            res = db.cursor().execute(f"SELECT DISTINCT metric_name FROM {table_name}").fetchall()
            sut = [row[0] for row in res]

        # All four confusion-matrix components must have been written.
        assert 'true_positive' in sut
        assert 'false_positive' in sut
        assert 'true_negative' in sut
        assert 'false_negative' in sut

    except Exception as exc:
        pytest.fail(f"an unexpected exception occurred: {exc}")

    finally:
        # Guard the cleanup: if the writer failed before the DB file was
        # created, an unconditional remove would raise FileNotFoundError
        # and mask the original test failure.
        if os.path.exists('test.db'):
            os.remove('test.db')

0 comments on commit a96b663

Please sign in to comment.