From 8425422af7c7f942ff6b6a54691d3c84c1cb16e0 Mon Sep 17 00:00:00 2001 From: Giacomo Pope Date: Wed, 24 Jul 2024 12:35:10 +0100 Subject: [PATCH] Add tests to check for failures in MLKEM/Kyber --- src/kyber_py/drbg/aes256_ctr_drbg.py | 9 ------ src/kyber_py/kyber/kyber.py | 18 +---------- src/kyber_py/ml_kem/ml_kem.py | 45 +++++++++++++--------------- tests/test_kyber.py | 13 ++++++++ tests/test_ml_kem.py | 31 +++++++++++++++++++ 5 files changed, 65 insertions(+), 51 deletions(-) diff --git a/src/kyber_py/drbg/aes256_ctr_drbg.py b/src/kyber_py/drbg/aes256_ctr_drbg.py index 8f2422a..db28e18 100644 --- a/src/kyber_py/drbg/aes256_ctr_drbg.py +++ b/src/kyber_py/drbg/aes256_ctr_drbg.py @@ -68,15 +68,6 @@ def ctr_drbg_update(self, provided_data): self.key = tmp[:32] self.V = tmp[32:] - def reseed(self, additional_information=b""): - """ - Reseed the DRBG for when reseed_ctr hits the - limit. - """ - seed_material = self.__instantiate(additional_information) - self.ctr_drbg_update(seed_material) - self.reseed_ctr = 1 - def random_bytes(self, num_bytes, additional=None): if self.reseed_ctr >= self.reseed_interval: raise Warning("The DRBG has been exhausted! Reseed!") diff --git a/src/kyber_py/kyber/kyber.py b/src/kyber_py/kyber/kyber.py index 1762945..f6ed40b 100644 --- a/src/kyber_py/kyber/kyber.py +++ b/src/kyber_py/kyber/kyber.py @@ -46,25 +46,9 @@ def set_drbg_seed(self, seed): self.random_bytes = self._drbg.random_bytes except ImportError as e: print(f"Error importing AES from pycryptodome: {e = }") - print( - "Have you tried installing requirements: pip -r install requirements" - ) - - def reseed_drbg(self, seed): - """ - Reseeds the DRBG, errors if a DRBG is not set. - - Note: - currently requires pycryptodome for AES impl. - - :param bytes seed: random bytes to use as a new seed of the DRBG - """ - if self._drbg is None: raise Warning( - "Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`" + "Cannot set DRBG seed due to missing dependencies, try installing requirements: pip -r install requirements" ) - else: - self._drbg.reseed(seed) @staticmethod def _xof(bytes32, i, j): diff --git a/src/kyber_py/ml_kem/ml_kem.py b/src/kyber_py/ml_kem/ml_kem.py index f56f277..3708d6a 100644 --- a/src/kyber_py/ml_kem/ml_kem.py +++ b/src/kyber_py/ml_kem/ml_kem.py @@ -51,25 +51,9 @@ def set_drbg_seed(self, seed): self.random_bytes = self._drbg.random_bytes except ImportError as e: print(f"Error importing AES from pycryptodome: {e = }") - print( - "Have you tried installing requirements: pip -r install requirements" - ) - - def reseed_drbg(self, seed): - """ - Reseeds the DRBG, errors if a DRBG is not set. - - Note: - currently requires pycryptodome for AES impl. - - :param bytes seed: random bytes to use as a new seed of the DRBG - """ - if self._drbg is None: raise Warning( - "Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`" + "Cannot set DRBG seed due to missing dependencies, try installing requirements: pip -r install requirements" ) - else: - self._drbg.reseed(seed) @staticmethod def _xof(bytes32, i, j): @@ -201,12 +185,13 @@ def _pke_encrypt(self, ek_pke, m, r): # NOTE: # Perform the input validation checks for ML-KEM - assert ( - len(ek_pke) == 384 * self.k + 32 - ), "Type check failed, ek_pke has the wrong length" - assert ( - t_hat.encode(12) == t_hat_bytes - ), "Modulus check failed, t_hat does not encode correctly" + if len(ek_pke) != 384 * self.k + 32: + raise ValueError("Type check failed, ek_pke has the wrong length") + + if t_hat.encode(12) != t_hat_bytes: + raise ValueError( + "Modulus check failed, t_hat does not encode correctly" + ) # Generate A_hat^T from seed rho A_hat_T = self._generate_matrix_from_seed(rho, transpose=True) @@ -286,8 +271,14 @@ def encaps(self, ek): m = self.random_bytes(32) K, r = self._G(m + self._H(ek)) - # Perform the underlying pke encryption - c = self._pke_encrypt(ek, m, r) + # Perform the underlying pke encryption, raises a ValueError if + # ek fails either the TypeCheck or ModulusCheck + try: + c = self._pke_encrypt(ek, m, r) + except ValueError as e: + raise ValueError( + f"Valildation of encapsulation key failed: {e = }" + ) return (K, c) @@ -325,6 +316,10 @@ def decaps(self, c, dk): # Re-encrypt the recovered message K_prime, r_prime = self._G(m_prime + h) K_bar = self._J(z + c) + + # Here the public encapsulation key is read from the private + # key and so we never expect this to fail the TypeCheck or + # ModulusCheck c_prime = self._pke_encrypt(ek_pke, m_prime, r_prime) # If c != c_prime, return K_bar as garbage diff --git a/tests/test_kyber.py b/tests/test_kyber.py index eefad76..586ec6c 100644 --- a/tests/test_kyber.py +++ b/tests/test_kyber.py @@ -34,9 +34,16 @@ def generic_test_kyber(self, Kyber, count): pk, sk = Kyber.keygen() for _ in range(count): key, c = Kyber.encaps(pk) + + # Correct decaps works _key = Kyber.decaps(c, sk) self.assertEqual(key, _key) + # Incorrect ct does not work + _bad_ct = bytes([0] * len(c)) + _bad = Kyber.decaps(_bad_ct, sk) + self.assertNotEqual(key, _bad) + def test_kyber512(self): self.generic_test_kyber(Kyber512, 5) @@ -46,6 +53,12 @@ def test_kyber768(self): def test_kyber1024(self): self.generic_test_kyber(Kyber1024, 5) + def test_xof_failure(self): + self.assertRaises(ValueError, lambda: Kyber512._xof(b"1", b"2", b"3")) + + def test_prf_failure(self): + self.assertRaises(ValueError, lambda: Kyber512._prf(b"1", b"2", 32)) + class TestKyberDeterministic(unittest.TestCase): """ diff --git a/tests/test_ml_kem.py b/tests/test_ml_kem.py index f061807..2a1b31d 100644 --- a/tests/test_ml_kem.py +++ b/tests/test_ml_kem.py @@ -59,6 +59,37 @@ def test_ML_KEM_768(self): def test_ML_KEM_1024(self): self.generic_test_ML_KEM(ML_KEM_1024, 5) + def test_encaps_type_check_failure(self): + """ + Send an ecaps key of the wrong length + """ + self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(b"1")) + + def test_encaps_modulus_check_failure(self): + """ + We create a vector of polynomials with non-canonical values for + coefficents to fail the modulus check + """ + (ek, _) = ML_KEM_512.keygen() + rho = ek[-32:] + + bad_f_hat = ML_KEM_512.R([3329] * 256) + bad_t_hat = ML_KEM_512.M.vector([bad_f_hat, bad_f_hat]) + bad_t_hat_bytes = bad_t_hat.encode(12) + + bad_ek = bad_t_hat_bytes + rho + + self.assertEqual(len(bad_ek), len(ek)) + self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(bad_ek)) + + def test_xof_failure(self): + self.assertRaises( + ValueError, lambda: ML_KEM_512._xof(b"1", b"2", b"3") + ) + + def test_prf_failure(self): + self.assertRaises(ValueError, lambda: ML_KEM_512._prf(2, b"1", b"2")) + # As there are 1000 KATs in the file, execution of all of them takes # a lot of time, run just 100