Skip to content

Commit

Permalink
fix typo and test error
Browse files Browse the repository at this point in the history
  • Loading branch information
kannkyo committed Dec 8, 2024
1 parent a684fd9 commit 1fbb0e7
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 70 deletions.
87 changes: 51 additions & 36 deletions src/epss_api/epss.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from __future__ import annotations

import csv
from functools import cached_property
from gzip import GzipFile
from urllib.request import urlopen
from bisect import bisect_left
from bisect import bisect_right, bisect_left


class Score(object):
""" EPSS Score Object"""

def __init__(self, cve: str, epss: str, percentile: str):
"""Initialize EPSS Score Object
Args:
cve (str): CVE-yyyy-n
epss (str): [0-1]
percentile (str): [0-1]
"""
self.cve = cve
self.epss = float(epss)
self.percentile = float(percentile)
Expand All @@ -30,9 +36,15 @@ def __init__(self) -> None:

scores = [row for row in csv.DictReader(self._download[1:])]

self._byCVE = {row['cve'] : Score(row['cve'], row['epss'], row['percentile']) for row in scores}
self._byCVE = {row['cve']:
Score(row['cve'], row['epss'], row['percentile'])
for row in scores}

self._sortedScores = sorted(self._byCVE.values(),key=lambda x:x.percentile)
self._sortedScoresByPercentile = sorted(self._byCVE.values(),
key=lambda x: x.percentile)

self._sortedScoresByEpss = sorted(self._byCVE.values(),
key=lambda x: x.epss)

def scores(self) -> list[Score]:
"""Get all CVE's EPSS scores (downloaded data is cached in memory)
Expand All @@ -48,7 +60,7 @@ def scores(self) -> list[Score]:
Returns:
list[Score]: EPSS score's csv list
"""
return list(self._sortedScores)
return list(self._sortedScoresByEpss)

def score(self, cve_id: str) -> Score:
"""Get EPSS score and percentile
Expand All @@ -63,7 +75,7 @@ def score(self, cve_id: str) -> Score:
Score | None: EPSS score percentile
"""

return self._byCVE.get(cve_id,None)
return self._byCVE.get(cve_id, None)

def epss(self, cve_id: str) -> float:
"""Get EPSS score
Expand All @@ -74,8 +86,8 @@ def epss(self, cve_id: str) -> float:
Returns:
float | None: EPSS score (0.0-1.0)
"""
score = self._byCVE.get(cve_id,None)

score = self._byCVE.get(cve_id, None)
if score is None:
return None
else:
Expand All @@ -90,63 +102,67 @@ def percentile(self, cve_id: str) -> float:
Returns:
float | None: EPSS percentile (0.0-1.0)
"""
score = self._byCVE.get(cve_id,None)
score = self._byCVE.get(cve_id, None)
if score is None:
return None
else:
return score.percentile

def epss_gt(self, max: float) -> list[Score]:
"""Get CVEs with EPSS score greater or equal than the parameter
def epss_gt(self, min: float) -> list[Score]:
"""Get CVEs with EPSS score greater than the parameter
Args:
max (float): limit of EPSS score
min (float): limit of EPSS score
Returns:
list[Score] | None: EPSS score object list
"""
i = bisect_left(self._sortedScores,min,key=lambda x:x.epss)

return list(self._sortedScores[i:])
i = bisect_right(self._sortedScoresByEpss, min,
key=lambda x: x.epss)

def percentile_gt(self, max: float) -> list[Score]:
"""Get CVEs with percentile greater or equal than the parameter
return list(self._sortedScoresByEpss[i:])

def percentile_gt(self, min: float) -> list[Score]:
"""Get CVEs with percentile greater than the parameter
Args:
max (float): limit of EPSS score
min (float): limit of percentile
Returns:
list[Score] | None: EPSS score object list
"""
i = bisect_left(self._sortedScores,min,key=lambda x:x.percentile)

return list(self._sortedScores[i:])
i = bisect_right(self._sortedScoresByPercentile, min,
key=lambda x: x.percentile)

def epss_lt(self, min: float) -> list[Score]:
"""Get CVEs with EPSS score lower or equal than the parameter
return list(self._sortedScoresByPercentile[i:])

def epss_lt(self, max: float) -> list[Score]:
"""Get CVEs with EPSS score lower than the parameter
Args:
min (float): limit of EPSS score
max (float): limit of EPSS score
Returns:
list[Score] | None: EPSS score object list
list[Score] | []: EPSS score object list
"""
i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.epss)

return list(self._sortedScores[:len(self.sortedScores)-i])
i = bisect_left(self._sortedScoresByEpss, max,
key=lambda x: x.epss)

return list(self._sortedScoresByEpss[:i])

def percentile_lt(self, min: float) -> list[Score]:
"""Get CVEs with percentile lower or equal than the parameter
def percentile_lt(self, max: float) -> list[Score]:
"""Get CVEs with percentile lower than the parameter
Args:
min (float): limit of EPSS score
max (float): limit of percentile
Returns:
list[Score] | None: EPSS score object list
list[Score] | []: EPSS score object list
"""
i = bisect_left(self._sortedScores[::-1],min,key=lambda x:1-x.percentile)

return list(self._sortedScores[:len(self.sortedScores)-i])
i = bisect_left(self._sortedScoresByPercentile, max,
key=lambda x: x.percentile)

return list(self._sortedScoresByPercentile[:i])

def csv(self) -> list[str]:
"""Get csv data containing all epss scores.
Expand All @@ -162,4 +178,3 @@ def csv(self) -> list[str]:
list[str]: csv data
"""
return self._download

76 changes: 42 additions & 34 deletions tests/epss_api/test_epss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,55 +24,63 @@ def test_csv():
assert len(rows[2:]) >= 1000


@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2])
def test_epss_gt(max):
value = epss.epss_gt(max)
for s in value:
assert s.epss >= max


@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2])
def test_percentile_gt(max):
value = epss.percentile_gt(max)
for s in value:
assert s.percentile >= max


@pytest.mark.parametrize("min", [-1, 0, 0.5, 1, 2])
def test_epss_lt(min):
value = epss.epss_lt(min)
for s in value:
def test_epss_gt(min):
scores = epss.epss_gt(min)
for s in scores:
assert s.epss > min
for s in list(set(scores) - set(epss.scores())):
assert s.epss <= min


@pytest.mark.parametrize("min", [-1, 0, 0.5, 1, 2])
def test_percentile_lt(min):
value = epss.percentile_lt(min)
for s in value:
def test_percentile_gt(min):
scores = epss.percentile_gt(min)
for s in scores:
assert s.percentile > min
for s in list(set(scores) - set(epss.scores())):
assert s.percentile <= min


@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2])
def test_epss_lt(max):
scores = epss.epss_lt(max)
for s in scores:
assert s.epss < max
for s in list(set(scores) - set(epss.scores())):
assert s.epss >= max


@pytest.mark.parametrize("max", [-1, 0, 0.5, 1, 2])
def test_percentile_lt(max):
scores = epss.percentile_lt(max)
for s in scores:
assert s.percentile < max
for s in list(set(scores) - set(epss.scores())):
assert s.percentile >= max


def test_score():
value = epss.score(cve_id='CVE-2022-0669')
assert value.cve.startswith('CVE-')
assert 0 <= value.epss <= 1
assert 0 <= value.percentile <= 1
value = epss.score(cve_id='CVE-1000-123')
assert value is None
score = epss.score(cve_id='CVE-2022-0669')
assert score.cve.startswith('CVE-')
assert 0 <= score.epss <= 1
assert 0 <= score.percentile <= 1
score = epss.score(cve_id='CVE-1000-123')
assert score is None


def test_epss():
value = epss.epss(cve_id='CVE-2022-0669')
assert 0 <= value <= 1
value = epss.epss(cve_id='CVE-1000-123')
assert value is None
score = epss.epss(cve_id='CVE-2022-0669')
assert 0 <= score <= 1
score = epss.epss(cve_id='CVE-1000-123')
assert score is None


def test_percentile():
value = epss.percentile(cve_id='CVE-2022-0669')
assert 0 <= value <= 1
value = epss.percentile(cve_id='CVE-1000-123')
assert value is None
score = epss.percentile(cve_id='CVE-2022-0669')
assert 0 <= score <= 1
score = epss.percentile(cve_id='CVE-1000-123')
assert score is None


def test_score_dict():
Expand Down

0 comments on commit 1fbb0e7

Please sign in to comment.