From 1fbb0e794fd1bddea72cb3966c072b5dc6112470 Mon Sep 17 00:00:00 2001 From: kannkyo <15080890+kannkyo@users.noreply.github.com> Date: Mon, 9 Dec 2024 01:43:53 +0900 Subject: [PATCH] fix typo and test error --- src/epss_api/epss.py | 87 ++++++++++++++++++++++--------------- tests/epss_api/test_epss.py | 76 +++++++++++++++++--------------- 2 files changed, 93 insertions(+), 70 deletions(-) diff --git a/src/epss_api/epss.py b/src/epss_api/epss.py index 432c12f..84346f1 100644 --- a/src/epss_api/epss.py +++ b/src/epss_api/epss.py @@ -1,16 +1,22 @@ from __future__ import annotations import csv -from functools import cached_property from gzip import GzipFile from urllib.request import urlopen -from bisect import bisect_left +from bisect import bisect_right, bisect_left class Score(object): """ EPSS Score Object""" def __init__(self, cve: str, epss: str, percentile: str): + """Initialize EPSS Score Object + + Args: + cve (str): CVE-yyyy-n + epss (str): [0-1] + percentile (str): [0-1] + """ self.cve = cve self.epss = float(epss) self.percentile = float(percentile) @@ -30,9 +36,15 @@ def __init__(self) -> None: scores = [row for row in csv.DictReader(self._download[1:])] - self._byCVE = {row['cve'] : Score(row['cve'], row['epss'], row['percentile']) for row in scores} + self._byCVE = {row['cve']: + Score(row['cve'], row['epss'], row['percentile']) + for row in scores} - self._sortedScores = sorted(self._byCVE.values(),key=lambda x:x.percentile) + self._sortedScoresByPercentile = sorted(self._byCVE.values(), + key=lambda x: x.percentile) + + self._sortedScoresByEpss = sorted(self._byCVE.values(), + key=lambda x: x.epss) def scores(self) -> list[Score]: """Get all CVE's EPSS scores (downloaded data is cached in memory) @@ -48,7 +60,7 @@ def scores(self) -> list[Score]: Returns: list[Score]: EPSS score's csv list """ - return list(self._sortedScores) + return list(self._sortedScoresByEpss) def score(self, cve_id: str) -> Score: """Get EPSS score and percentile @@ -63,7 +75,7 @@ def score(self, cve_id: str) -> Score: Score | None: EPSS score percentile """ - return self._byCVE.get(cve_id,None) + return self._byCVE.get(cve_id, None) def epss(self, cve_id: str) -> float: """Get EPSS score @@ -74,8 +86,8 @@ def epss(self, cve_id: str) -> float: Returns: float | None: EPSS score (0.0-1.0) """ - - score = self._byCVE.get(cve_id,None) + + score = self._byCVE.get(cve_id, None) if score is None: return None else: @@ -90,63 +102,67 @@ def percentile(self, cve_id: str) -> float: Returns: float | None: EPSS percentile (0.0-1.0) """ - score = self._byCVE.get(cve_id,None) + score = self._byCVE.get(cve_id, None) if score is None: return None else: return score.percentile - def epss_gt(self, max: float) -> list[Score]: - """Get CVEs with EPSS score greater or equal than the parameter + def epss_gt(self, min: float) -> list[Score]: + """Get CVEs with EPSS score greater than the parameter Args: - max (float): limit of EPSS score + min (float): limit of EPSS score Returns: list[Score] | None: EPSS score object list """ - i = bisect_left(self._sortedScores,min,key=lambda x:x.epss) - - return list(self._sortedScores[i:]) + i = bisect_right(self._sortedScoresByEpss, min, + key=lambda x: x.epss) - def percentile_gt(self, max: float) -> list[Score]: - """Get CVEs with percentile greater or equal than the parameter + return list(self._sortedScoresByEpss[i:]) + + def percentile_gt(self, min: float) -> list[Score]: + """Get CVEs with percentile greater than the parameter Args: - max (float): limit of EPSS score + min (float): limit of percentile Returns: list[Score] | None: EPSS score object list """ - i = bisect_left(self._sortedScores,min,key=lambda x:x.percentile) - - return list(self._sortedScores[i:]) + i = bisect_right(self._sortedScoresByPercentile, min, + key=lambda x: x.percentile) - def epss_lt(self, min: float) -> list[Score]: - """Get CVEs with EPSS score lower or equal than the parameter + return list(self._sortedScoresByPercentile[i:]) + + def epss_lt(self, max: float) -> list[Score]: + """Get CVEs with EPSS score lower than the parameter Args: - min (float): limit of EPSS score + max (float): limit of EPSS score Returns: - list[Score] | None: EPSS score object list + list[Score] | []: EPSS score object list """ - i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.epss) - - return list(self._sortedScores[:len(self.sortedScores)-i]) + i = bisect_left(self._sortedScoresByEpss, max, + key=lambda x: x.epss) + + return list(self._sortedScoresByEpss[:i]) - def percentile_lt(self, min: float) -> list[Score]: - """Get CVEs with percentile lower or equal than the parameter + def percentile_lt(self, max: float) -> list[Score]: + """Get CVEs with percentile lower than the parameter Args: - min (float): limit of EPSS score + max (float): limit of percentile Returns: - list[Score] | None: EPSS score object list + list[Score] | []: EPSS score object list """ - i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.percentile) - - return list(self._sortedScores[:len(self.sortedScores)-i]) + i = bisect_left(self._sortedScoresByPercentile, max, + key=lambda x: x.percentile) + + return list(self._sortedScoresByPercentile[:i]) def csv(self) -> list[str]: """Get csv data containing all epss scores. @@ -162,4 +178,3 @@ def csv(self) -> list[str]: list[str]: csv data """ return self._download - diff --git a/tests/epss_api/test_epss.py b/tests/epss_api/test_epss.py index 9e20464..f15c29d 100644 --- a/tests/epss_api/test_epss.py +++ b/tests/epss_api/test_epss.py @@ -24,55 +24,63 @@ def test_csv(): assert len(rows[2:]) >= 1000 -@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2]) -def test_epss_gt(max): - value = epss.epss_gt(max) - for s in value: - assert s.epss >= max - - -@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2]) -def test_percentile_gt(max): - value = epss.percentile_gt(max) - for s in value: - assert s.percentile >= max - - @pytest.mark.parametrize("min", [-1, 0, 0.5, 1, 2]) -def test_epss_lt(min): - value = epss.epss_lt(min) - for s in value: +def test_epss_gt(min): + scores = epss.epss_gt(min) + for s in scores: + assert s.epss > min + for s in list(set(scores) - set(epss.scores())): assert s.epss <= min @pytest.mark.parametrize("min", [-1, 0, 0.5, 1, 2]) -def test_percentile_lt(min): - value = epss.percentile_lt(min) - for s in value: +def test_percentile_gt(min): + scores = epss.percentile_gt(min) + for s in scores: + assert s.percentile > min + for s in list(set(scores) - set(epss.scores())): assert s.percentile <= min +@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2]) +def test_epss_lt(max): + scores = epss.epss_lt(max) + for s in scores: + assert s.epss < max + for s in list(set(scores) - set(epss.scores())): + assert s.epss >= max + + +@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2]) +def test_percentile_lt(max): + scores = epss.percentile_lt(max) + for s in scores: + assert s.percentile < max + for s in list(set(scores) - set(epss.scores())): + assert s.percentile >= max + + def test_score(): - value = epss.score(cve_id='CVE-2022-0669') - assert value.cve.startswith('CVE-') - assert 0 <= value.epss <= 1 - assert 0 <= value.percentile <= 1 - value = epss.score(cve_id='CVE-1000-123') - assert value is None + score = epss.score(cve_id='CVE-2022-0669') + assert score.cve.startswith('CVE-') + assert 0 <= score.epss <= 1 + assert 0 <= score.percentile <= 1 + score = epss.score(cve_id='CVE-1000-123') + assert score is None def test_epss(): - value = epss.epss(cve_id='CVE-2022-0669') - assert 0 <= value <= 1 - value = epss.epss(cve_id='CVE-1000-123') - assert value is None + score = epss.epss(cve_id='CVE-2022-0669') + assert 0 <= score <= 1 + score = epss.epss(cve_id='CVE-1000-123') + assert score is None def test_percentile(): - value = epss.percentile(cve_id='CVE-2022-0669') - assert 0 <= value <= 1 - value = epss.percentile(cve_id='CVE-1000-123') - assert value is None + score = epss.percentile(cve_id='CVE-2022-0669') + assert 0 <= score <= 1 + score = epss.percentile(cve_id='CVE-1000-123') + assert score is None def test_score_dict():