diff --git a/src/epss_api/epss.py b/src/epss_api/epss.py index 432c12f..dd9d6a8 100644 --- a/src/epss_api/epss.py +++ b/src/epss_api/epss.py @@ -4,7 +4,6 @@ from functools import cached_property from gzip import GzipFile from urllib.request import urlopen -from bisect import bisect_left class Score(object): @@ -18,21 +17,7 @@ def __init__(self, cve: str, epss: str, percentile: str): class EPSS(object): def __init__(self) -> None: - - url = 'https://epss.cyentia.com/epss_scores-current.csv.gz' - - with urlopen(url) as res: - dec = GzipFile(fileobj=res) - epss_scores_str: str = dec.read().decode("utf-8") - epss_scores_list = epss_scores_str.split('\n') - - self._download = epss_scores_list - - scores = [row for row in csv.DictReader(self._download[1:])] - - self._byCVE = {row['cve'] : Score(row['cve'], row['epss'], row['percentile']) for row in scores} - - self._sortedScores = sorted(self._byCVE.values(),key=lambda x:x.percentile) + pass def scores(self) -> list[Score]: """Get all CVE's EPSS scores (downloaded data is cached in memory) @@ -48,7 +33,9 @@ def scores(self) -> list[Score]: Returns: list[Score]: EPSS score's csv list """ - return list(self._sortedScores) + scores = [row for row in csv.DictReader(self._download[1:])] + return [Score(row['cve'], row['epss'], row['percentile']) + for row in scores] def score(self, cve_id: str) -> Score: """Get EPSS score and percentile @@ -62,8 +49,11 @@ def score(self, cve_id: str) -> Score: Returns: Score | None: EPSS score percentile """ - - return self._byCVE.get(cve_id,None) + rows = self._filter_by_cve_id(cve_id) + if len(rows) == 1: + return rows[0] + else: + return None def epss(self, cve_id: str) -> float: """Get EPSS score @@ -74,12 +64,11 @@ def epss(self, cve_id: str) -> float: Returns: float | None: EPSS score (0.0-1.0) """ - - score = self._byCVE.get(cve_id,None) - if score is None: - return None + rows = self._filter_by_cve_id(cve_id) + if len(rows) == 1: + return rows[0].epss else: - return score.epss + return None def percentile(self, cve_id: str) -> float: """Get EPSS percentile @@ -90,11 +79,11 @@ def percentile(self, cve_id: str) -> float: Returns: float | None: EPSS percentile (0.0-1.0) """ - score = self._byCVE.get(cve_id,None) - if score is None: - return None + rows = self._filter_by_cve_id(cve_id) + if len(rows) == 1: + return rows[0].percentile else: - return score.percentile + return None def epss_gt(self, max: float) -> list[Score]: """Get CVEs with EPSS score greater or equal than the parameter @@ -105,9 +94,8 @@ def epss_gt(self, max: float) -> list[Score]: Returns: list[Score] | None: EPSS score object list """ - i = bisect_left(self._sortedScores,min,key=lambda x:x.epss) - - return list(self._sortedScores[i:]) + rows = [r for r in filter(lambda x: x.epss >= max, self.scores())] + return rows def percentile_gt(self, max: float) -> list[Score]: """Get CVEs with percentile greater or equal than the parameter @@ -118,9 +106,9 @@ def percentile_gt(self, max: float) -> list[Score]: Returns: list[Score] | None: EPSS score object list """ - i = bisect_left(self._sortedScores,min,key=lambda x:x.percentile) - - return list(self._sortedScores[i:]) + rows = [r for r in + filter(lambda x: x.percentile >= max, self.scores())] + return rows def epss_lt(self, min: float) -> list[Score]: """Get CVEs with EPSS score lower or equal than the parameter @@ -131,9 +119,8 @@ def epss_lt(self, min: float) -> list[Score]: Returns: list[Score] | None: EPSS score object list """ - i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.epss) - - return list(self._sortedScores[:len(self.sortedScores)-i]) + rows = [r for r in filter(lambda x: x.epss <= min, self.scores())] + return rows def percentile_lt(self, min: float) -> list[Score]: """Get CVEs with percentile lower or equal than the parameter @@ -144,9 +131,9 @@ def percentile_lt(self, min: float) -> list[Score]: Returns: list[Score] | None: EPSS score object list """ - i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.percentile) - - return list(self._sortedScores[:len(self.sortedScores)-i]) + rows = [r for r in + filter(lambda x: x.percentile <= min, self.scores())] + return rows def csv(self) -> list[str]: """Get csv data containing all epss scores. @@ -163,3 +150,18 @@ def csv(self) -> list[str]: """ return self._download + def _filter_by_cve_id(self, cve_id: str): + cve_filter = filter(lambda x: x.cve == cve_id, self.scores()) + rows = [row for row in cve_filter] + return rows + + @cached_property + def _download(self): + url = 'https://epss.cyentia.com/epss_scores-current.csv.gz' + + with urlopen(url) as res: + dec = GzipFile(fileobj=res) + epss_scores_str: str = dec.read().decode("utf-8") + epss_scores_list = epss_scores_str.split('\n') + + return epss_scores_list