From 324719855f39dae3a8846790ec33f942d4a929a3 Mon Sep 17 00:00:00 2001 From: Giacomo Pope Date: Tue, 23 Jul 2024 13:24:21 +0100 Subject: [PATCH] test generic module things --- src/kyber_py/modules/modules.py | 2 +- src/kyber_py/modules/modules_generic.py | 23 +++++- tests/test_module_generic.py | 102 ++++++++++++++++++++++++ tests/test_polynomial_generic.py | 2 +- 4 files changed, 124 insertions(+), 5 deletions(-) create mode 100644 tests/test_module_generic.py diff --git a/src/kyber_py/modules/modules.py b/src/kyber_py/modules/modules.py index 62c3b68..67b2e51 100644 --- a/src/kyber_py/modules/modules.py +++ b/src/kyber_py/modules/modules.py @@ -35,7 +35,7 @@ def __init__(self, parent, matrix_data, transpose=False): self.parent = parent self._data = matrix_data self._transpose = transpose - if not self.check_dimensions(): + if not self._check_dimensions(): raise ValueError("Inconsistent row lengths in matrix") def encode(self, d): diff --git a/src/kyber_py/modules/modules_generic.py b/src/kyber_py/modules/modules_generic.py index 28957c9..2710865 100644 --- a/src/kyber_py/modules/modules_generic.py +++ b/src/kyber_py/modules/modules_generic.py @@ -3,6 +3,12 @@ def __init__(self, ring): self.ring = ring self.matrix = Matrix + def random_element(self, m, n): + elements = [ + [self.ring.random_element() for _ in range(n)] for _ in range(m) + ] + return self(elements) + def __repr__(self): return f"Module over the commutative ring: {self.ring}" @@ -51,19 +57,20 @@ def __init__(self, parent, matrix_data, transpose=False): self.parent = parent self._data = matrix_data self._transpose = transpose - if not self.check_dimensions(): + if not self._check_dimensions(): raise ValueError("Inconsistent row lengths in matrix") def dim(self): """ Return the dimensions of the matrix with m rows - and n columns""" + and n columns + """ if not self._transpose: return len(self._data), len(self._data[0]) else: return len(self._data[0]), len(self._data) - def check_dimensions(self): + def _check_dimensions(self): """ Ensure that the matrix is rectangular """ @@ -114,6 +121,16 @@ def __eq__(self, other): [self[i, j] == other[i, j] for i in range(m) for j in range(n)] ) + def __neg__(self): + """ + Returns -self, by negating all elements + """ + m, n = self.dim() + return self.parent( + [[-self[i, j] for j in range(n)] for i in range(m)], + self._transpose, + ) + def __add__(self, other): if not isinstance(other, type(self)): raise TypeError("Can only add matrices to other matrices") diff --git a/tests/test_module_generic.py b/tests/test_module_generic.py new file mode 100644 index 0000000..350cc52 --- /dev/null +++ b/tests/test_module_generic.py @@ -0,0 +1,102 @@ +import unittest +from random import randint +from kyber_py.polynomials.polynomials_generic import PolynomialRing +from kyber_py.modules.modules_generic import Module + + +class TestModule(unittest.TestCase): + R = PolynomialRing(11, 5) + M = Module(R) + + def test_random_element(self): + for _ in range(100): + m = randint(1, 5) + n = randint(1, 5) + A = self.M.random_element(m, n) + self.assertEqual(type(A), self.M.matrix) + self.assertEqual(type(A[0, 0]), self.R.element) + self.assertEqual(A.dim(), (m, n)) + + +class TestMatrix(unittest.TestCase): + R = PolynomialRing(11, 5) + M = Module(R) + + def test_matrix_add(self): + zero = self.R(0) + Z = self.M([[zero, zero], [zero, zero]]) + for _ in range(100): + A = self.M.random_element(2, 2) + B = self.M.random_element(2, 2) + C = self.M.random_element(2, 2) + + self.assertEqual(A + Z, A) + self.assertEqual(A + B, B + A) + self.assertEqual(A + (B + C), (A + B) + C) + + def test_matrix_sub(self): + zero = self.R(0) + Z = self.M([[zero, zero], [zero, zero]]) + for _ in range(100): + A = self.M.random_element(2, 2) + B = self.M.random_element(2, 2) + C = self.M.random_element(2, 2) + + self.assertEqual(A - Z, A) + self.assertEqual(A - B, -(B - A)) + self.assertEqual(A - (B - C), (A - B) + C) + + def test_matrix_mul_square(self): + zero = self.R(0) + one = self.R(1) + Z = self.M([[zero, zero], [zero, zero]]) + I = self.M([[one, zero], [zero, one]]) + for _ in range(100): + A = self.M.random_element(2, 2) + B = self.M.random_element(2, 2) + C = self.M.random_element(2, 2) + d = self.R.random_element() + D = self.M([[d, zero], [zero, d]]) + + self.assertEqual(A @ Z, Z) + self.assertEqual(A @ I, A) + self.assertEqual(A @ D, D @ A) # Diagonal matrices commute + self.assertEqual(A @ (B + C), A @ B + A @ C) + + def test_matrix_mul_rectangle(self): + for _ in range(100): + A = self.M.random_element(7, 3) + B = self.M.random_element(3, 2) + C = self.M.random_element(3, 2) + + self.assertEqual(A @ (B + C), A @ B + A @ C) + + def test_matrix_transpose_id(self): + zero = self.R(0) + one = self.R(1) + I = self.M([[one, zero], [zero, one]]) + + self.assertEqual(I, I.transpose()) + + def test_matrix_transpose(self): + for _ in range(100): + A = self.M.random_element(7, 3) + At = A.transpose() + AAt = A @ At + + self.assertEqual(AAt, AAt.transpose()) + + def test_matrix_dot(self): + for _ in range(100): + u = [self.R.random_element() for _ in range(5)] + v = [self.R.random_element() for _ in range(5)] + dot = sum([ui * vi for ui, vi in zip(u, v)]) + + U = self.M.vector(u) + V = self.M.vector(v) + + self.assertEqual(dot, U.dot(V)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_polynomial_generic.py b/tests/test_polynomial_generic.py index 488bedf..83631c9 100644 --- a/tests/test_polynomial_generic.py +++ b/tests/test_polynomial_generic.py @@ -29,7 +29,7 @@ def test_is_constant(self): self.assertTrue(self.R(1).is_constant()) self.assertFalse(self.R.gen().is_constant()) - def test_reduce_coefficents(self): + def test_reduce_coefficients(self): for _ in range(100): # Create non-canonical coefficients coeffs = [