Skip to content

Commit

Permalink
feat: weighted metrics (#26)
Browse files Browse the repository at this point in the history
* feat(confusionmatrix): added weight metric

* feat(confusionmatrix): added more options to `getShortStats()`
  • Loading branch information
Berkmann18 authored Jul 24, 2019
1 parent 2a41f1b commit 35464db
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 20 deletions.
116 changes: 106 additions & 10 deletions src/__tests__/confusionMatrix.js
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ test('Predicted Negatives', () => {
expect(cm.getPredNegative('other')).toStrictEqual(11)
})

test('Support', () => {
const cm = new CM(CATEGORIES, M0)
expect(cm.getSupport('bug')).toStrictEqual(6)
expect(cm.getSupport('code')).toStrictEqual(3)
expect(cm.getSupport('other')).toStrictEqual(11)
})

describe('Accuracy', () => {
const cm = new CM(CATEGORIES, M0)
test('Accuracy', () => {
Expand All @@ -158,12 +165,19 @@ describe('Accuracy', () => {
expect(cm.getAccuracy('other')).toStrictEqual(0.8)
})

test('Macro accuracy', () => {
test('Macro Accuracy', () => {
expect(cm.getMacroAccuracy()).toStrictEqual(M[56])
})
test('Micro accuracy', () => {

test('Micro Accuracy', () => {
expect(cm.getMicroAccuracy()).toStrictEqual(0.75)
})

test('Weighted Accuracy', () => {
expect(
Math.round(cm.getWeightedAccuracy() * 100000) / 100000,
).toStrictEqual(0.83)
})
})

describe('Recall', () => {
Expand All @@ -174,13 +188,19 @@ describe('Recall', () => {
expect(cm.getRecall('other')).toStrictEqual(M[811]) //.727
})

test('Macro recall', () => {
test('Macro Recall', () => {
expect(cm.getMacroRecall()).toStrictEqual((3 / 2 + M[811]) / 3) //~.742
})

test('Micro recall', () => {
test('Micro Recall', () => {
expect(cm.getMicroRecall()).toStrictEqual(0.75)
})

test('Weighted Recall', () => {
expect(Math.round(cm.getWeightedRecall() * 100000) / 100000).toStrictEqual(
0.75,
)
})
})

describe('Precision', () => {
Expand All @@ -191,13 +211,19 @@ describe('Precision', () => {
expect(cm.getPrecision('other')).toStrictEqual(M[89]) //.889
})

test('Macro precision', () => {
test('Macro Precision', () => {
expect(cm.getMacroPrecision()).toStrictEqual((M[56] + 0.4 + M[89]) / 3) //~.707
})

test('Micro precision', () => {
test('Micro Precision', () => {
expect(cm.getMicroPrecision()).toStrictEqual(0.75)
})

test('Weighted Precision', () => {
expect(
Math.round(cm.getWeightedPrecision() * 100000) / 100000,
).toStrictEqual(0.79889)
})
})

describe('F1', () => {
Expand All @@ -215,6 +241,12 @@ describe('F1', () => {
test('Micro F1', () => {
expect(cm.getMicroF1()).toStrictEqual(0.75)
})

test('Weighted F1', () => {
expect(Math.round(cm.getWeightedF1() * 100000) / 100000).toStrictEqual(
0.765,
)
})
})

describe('MissRate', () => {
Expand All @@ -232,6 +264,12 @@ describe('MissRate', () => {
test('Micro MissRate', () => {
expect(cm.getMicroMissRate()).toStrictEqual(0.25)
})

test('Weighted MissRate', () => {
expect(
Math.round(cm.getWeightedMissRate() * 100000) / 100000,
).toStrictEqual(0.25)
})
})

describe('FallOut', () => {
Expand All @@ -250,6 +288,12 @@ describe('FallOut', () => {
test('Micro FallOut', () => {
expect(cm.getMicroFallOut()).toStrictEqual(0.125)
})

test('Weighted FallOut', () => {
expect(Math.round(cm.getWeightedFallOut() * 100000) / 100000).toStrictEqual(
0.10901,
)
})
})

describe('Specificity', () => {
Expand All @@ -268,6 +312,12 @@ describe('Specificity', () => {
test('Micro Specificity', () => {
expect(cm.getMicroSpecificity()).toStrictEqual(0.875)
})

test('Weighted Specificity', () => {
expect(
Math.round(cm.getWeightedSpecificity() * 100000) / 100000,
).toStrictEqual(0.89099)
})
})

describe('Prevalence', () => {
Expand All @@ -285,6 +335,12 @@ describe('Prevalence', () => {
test('Micro Prevalence', () => {
expect(cm.getMicroPrevalence()).toStrictEqual(M[13])
})

test('Weighted Prevalence', () => {
expect(
Math.round(cm.getWeightedPrevalence() * 100000) / 100000,
).toStrictEqual(0.415)
})
})

describe('fromData', () => {
Expand Down Expand Up @@ -416,16 +472,42 @@ describe('toString', () => {
})
})

test('shortStats', () => {
const cm = new CM(CATEGORIES, M0)
const ss = `Total: 20
describe('shortStats', () => {
test('default', () => {
const cm = new CM(CATEGORIES, M0)
const ss = `Total: 20
True: 15
False: 5
Accuracy: 75%
Precision: 75%
Recall: 75%
F1: 75%`
expect(cm.getShortStats()).toStrictEqual(ss)
expect(cm.getShortStats()).toStrictEqual(ss)
})

test('macro', () => {
const cm = new CM(CATEGORIES, M0)
const ss = `Total: 20
True: 15
False: 5
Accuracy: 83.33333333333334%
Precision: 70.74074074074073%
Recall: 74.24242424242425%
F1: 71.11111111111111%`
expect(cm.getShortStats('macro')).toStrictEqual(ss)
})

test('weighted', () => {
const cm = new CM(CATEGORIES, M0)
const ss = `Total: 20
True: 15
False: 5
Accuracy: 83%
Precision: 79.88888888888889%
Recall: 75%
F1: 76.49999999999999%`
expect(cm.getShortStats('weighted')).toStrictEqual(ss)
})
})

describe('Long stats', () => {
Expand All @@ -439,6 +521,7 @@ describe('Long stats', () => {
classes: CATEGORIES,
microAvg: {},
macroAvg: {},
weightedAvg: {},
results: {
bug: {},
code: {},
Expand Down Expand Up @@ -473,6 +556,19 @@ describe('Long stats', () => {
})
})

it('has weighted details', () => {
expect(stats.weightedAvg).toMatchObject({
accuracy: 0.83 + 1e-16,
f1: 0.7649999999999999,
fallOut: 0.10901027077497664,
missRate: 0.25,
precision: 0.7988888888888889,
prevalence: 0.41500000000000004,
recall: 0.75,
specificity: 0.8909897292250232,
})
})

it('has class details', () => {
const bugStats = stats.results.bug
expect(bugStats).toMatchObject({
Expand Down
Loading

0 comments on commit 35464db

Please sign in to comment.