Skip to content

Commit

Permalink
Add tests to check for failures in MLKEM/Kyber
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 24, 2024
1 parent 2f7af63 commit 8425422
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 51 deletions.
9 changes: 0 additions & 9 deletions src/kyber_py/drbg/aes256_ctr_drbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ def ctr_drbg_update(self, provided_data):
self.key = tmp[:32]
self.V = tmp[32:]

def reseed(self, additional_information=b""):
"""
Reseed the DRBG for when reseed_ctr hits the
limit.
"""
seed_material = self.__instantiate(additional_information)
self.ctr_drbg_update(seed_material)
self.reseed_ctr = 1

def random_bytes(self, num_bytes, additional=None):
if self.reseed_ctr >= self.reseed_interval:
raise Warning("The DRBG has been exhausted! Reseed!")
Expand Down
18 changes: 1 addition & 17 deletions src/kyber_py/kyber/kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,9 @@ def set_drbg_seed(self, seed):
self.random_bytes = self._drbg.random_bytes
except ImportError as e:
print(f"Error importing AES from pycryptodome: {e = }")
print(
"Have you tried installing requirements: pip -r install requirements"
)

def reseed_drbg(self, seed):
"""
Reseeds the DRBG, errors if a DRBG is not set.
Note:
currently requires pycryptodome for AES impl.
:param bytes seed: random bytes to use as a new seed of the DRBG
"""
if self._drbg is None:
raise Warning(
"Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`"
"Cannot set DRBG seed due to missing dependencies, try installing requirements: pip -r install requirements"
)
else:
self._drbg.reseed(seed)

@staticmethod
def _xof(bytes32, i, j):
Expand Down
45 changes: 20 additions & 25 deletions src/kyber_py/ml_kem/ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,9 @@ def set_drbg_seed(self, seed):
self.random_bytes = self._drbg.random_bytes
except ImportError as e:
print(f"Error importing AES from pycryptodome: {e = }")
print(
"Have you tried installing requirements: pip -r install requirements"
)

def reseed_drbg(self, seed):
"""
Reseeds the DRBG, errors if a DRBG is not set.
Note:
currently requires pycryptodome for AES impl.
:param bytes seed: random bytes to use as a new seed of the DRBG
"""
if self._drbg is None:
raise Warning(
"Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`"
"Cannot set DRBG seed due to missing dependencies, try installing requirements: pip -r install requirements"
)
else:
self._drbg.reseed(seed)

@staticmethod
def _xof(bytes32, i, j):
Expand Down Expand Up @@ -201,12 +185,13 @@ def _pke_encrypt(self, ek_pke, m, r):

# NOTE:
# Perform the input validation checks for ML-KEM
assert (
len(ek_pke) == 384 * self.k + 32
), "Type check failed, ek_pke has the wrong length"
assert (
t_hat.encode(12) == t_hat_bytes
), "Modulus check failed, t_hat does not encode correctly"
if len(ek_pke) != 384 * self.k + 32:
raise ValueError("Type check failed, ek_pke has the wrong length")

if t_hat.encode(12) != t_hat_bytes:
raise ValueError(
"Modulus check failed, t_hat does not encode correctly"
)

# Generate A_hat^T from seed rho
A_hat_T = self._generate_matrix_from_seed(rho, transpose=True)
Expand Down Expand Up @@ -286,8 +271,14 @@ def encaps(self, ek):
m = self.random_bytes(32)
K, r = self._G(m + self._H(ek))

# Perform the underlying pke encryption
c = self._pke_encrypt(ek, m, r)
# Perform the underlying pke encryption, raises a ValueError if
# ek fails either the TypeCheck or ModulusCheck
try:
c = self._pke_encrypt(ek, m, r)
except ValueError as e:
raise ValueError(
f"Valildation of encapsulation key failed: {e = }"
)

return (K, c)

Expand Down Expand Up @@ -325,6 +316,10 @@ def decaps(self, c, dk):
# Re-encrypt the recovered message
K_prime, r_prime = self._G(m_prime + h)
K_bar = self._J(z + c)

# Here the public encapsulation key is read from the private
# key and so we never expect this to fail the TypeCheck or
# ModulusCheck
c_prime = self._pke_encrypt(ek_pke, m_prime, r_prime)

# If c != c_prime, return K_bar as garbage
Expand Down
13 changes: 13 additions & 0 deletions tests/test_kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,16 @@ def generic_test_kyber(self, Kyber, count):
pk, sk = Kyber.keygen()
for _ in range(count):
key, c = Kyber.encaps(pk)

# Correct decaps works
_key = Kyber.decaps(c, sk)
self.assertEqual(key, _key)

# Incorrect ct does not work
_bad_ct = bytes([0] * len(c))
_bad = Kyber.decaps(_bad_ct, sk)
self.assertNotEqual(key, _bad)

def test_kyber512(self):
self.generic_test_kyber(Kyber512, 5)

Expand All @@ -46,6 +53,12 @@ def test_kyber768(self):
def test_kyber1024(self):
self.generic_test_kyber(Kyber1024, 5)

def test_xof_failure(self):
self.assertRaises(ValueError, lambda: Kyber512._xof(b"1", b"2", b"3"))

def test_prf_failure(self):
self.assertRaises(ValueError, lambda: Kyber512._prf(b"1", b"2", 32))


class TestKyberDeterministic(unittest.TestCase):
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/test_ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ def test_ML_KEM_768(self):
def test_ML_KEM_1024(self):
self.generic_test_ML_KEM(ML_KEM_1024, 5)

def test_encaps_type_check_failure(self):
"""
Send an ecaps key of the wrong length
"""
self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(b"1"))

def test_encaps_modulus_check_failure(self):
"""
We create a vector of polynomials with non-canonical values for
coefficents to fail the modulus check
"""
(ek, _) = ML_KEM_512.keygen()
rho = ek[-32:]

bad_f_hat = ML_KEM_512.R([3329] * 256)
bad_t_hat = ML_KEM_512.M.vector([bad_f_hat, bad_f_hat])
bad_t_hat_bytes = bad_t_hat.encode(12)

bad_ek = bad_t_hat_bytes + rho

self.assertEqual(len(bad_ek), len(ek))
self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(bad_ek))

def test_xof_failure(self):
self.assertRaises(
ValueError, lambda: ML_KEM_512._xof(b"1", b"2", b"3")
)

def test_prf_failure(self):
self.assertRaises(ValueError, lambda: ML_KEM_512._prf(2, b"1", b"2"))


# As there are 1000 KATs in the file, execution of all of them takes
# a lot of time, run just 100
Expand Down

0 comments on commit 8425422

Please sign in to comment.