Skip to content

Commit

Permalink
Merge pull request #1071 from EnMAP-Box/1070-classification-layer-acc…
Browse files Browse the repository at this point in the history
…uracy-report-keyerror-accuracy

resolved #1070
  • Loading branch information
janzandr authored Jan 20, 2025
2 parents 5c3a667 + d074094 commit abd1e02
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from enmapboxprocessing.rasterreader import RasterReader
from enmapboxprocessing.reportwriter import MultiReportWriter, HtmlReportWriter, CsvReportWriter
from enmapboxprocessing.utils import Utils
from qgis.core import QgsProcessingContext, QgsProcessingFeedback, QgsRasterLayer, QgsVectorLayer
from qgis.core import QgsProcessingContext, QgsProcessingFeedback, QgsRasterLayer, QgsVectorLayer, \
QgsProcessingException


@typechecked
Expand Down Expand Up @@ -124,6 +125,8 @@ def processAlgorithm(
feedback.pushInfo('Estimate statistics and create report')
classValues = [c.value for c in categoriesReference]
classNames = [c.name for c in categoriesReference]
if not np.isin(yMap, classValues).all():
raise QgsProcessingException('Predicted values not matching reference classes.')
stats = accuracyAssessment(yReference, yMap, classNames, classValues)

self.writeReport(filename, stats, classNamesMatching)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enmapboxprocessing.typing import Category
from enmapboxprocessing.utils import Utils
from enmapboxtestdata import landcover_map_l3
from qgis.core import QgsRasterLayer, QgsMapLayer
from qgis.core import QgsRasterLayer, QgsMapLayer, QgsProcessingException


class TestClassificationPerformanceSimpleAlgorithm(TestCase):
Expand Down Expand Up @@ -129,3 +129,33 @@ def test_twoClass(self):
self.assertEqual(77, int(stats['overallAccuracy'] * 100))
self.assertListEqual([100, 33], [int(v * 100) for v in stats['usersAccuracy']])
self.assertListEqual([75, 100], [int(v * 100) for v in stats['producersAccuracy']])

def test_issue1070(self):
# handle unclassified data
categories = [Category(1, 'c1', '#000000'), Category(2, 'c2', '#000000')]
writer1 = self.rasterFromArray([[[1, 2]]], 'observation.tif')
writer1.close()
raster1 = QgsRasterLayer(writer1.source())
renderer = Utils().palettedRasterRendererFromCategories(raster1.dataProvider(), 1, categories)
raster1.setRenderer(renderer.clone())
raster1.saveDefaultStyle(QgsMapLayer.StyleCategory.AllStyleCategories)

writer2 = self.rasterFromArray([[[0, 2]]], 'prediction.tif') # introduce unclassified data
writer2.close()
raster2 = QgsRasterLayer(writer2.source())
renderer = Utils().palettedRasterRendererFromCategories(raster2.dataProvider(), 1, categories)
raster2.setRenderer(renderer.clone())
raster2.saveDefaultStyle(QgsMapLayer.StyleCategory.AllStyleCategories)

alg = ClassificationPerformanceSimpleAlgorithm()
alg.initAlgorithm()
parameters = {
alg.P_CLASSIFICATION: writer2.source(),
alg.P_REFERENCE: writer1.source(),
alg.P_OPEN_REPORT: self.openReport,
alg.P_OUTPUT_REPORT: self.filename('report.html'),
}
try:
self.runalg(alg, parameters)
except QgsProcessingException as error:
assert str(error) == 'Predicted values not matching reference classes.'

0 comments on commit abd1e02

Please sign in to comment.