Skip to content

Commit

Permalink
test generic module things
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 23, 2024
1 parent 7a66d63 commit 3247198
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/kyber_py/modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, parent, matrix_data, transpose=False):
self.parent = parent
self._data = matrix_data
self._transpose = transpose
if not self.check_dimensions():
if not self._check_dimensions():
raise ValueError("Inconsistent row lengths in matrix")

def encode(self, d):
Expand Down
23 changes: 20 additions & 3 deletions src/kyber_py/modules/modules_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ def __init__(self, ring):
self.ring = ring
self.matrix = Matrix

def random_element(self, m, n):
elements = [
[self.ring.random_element() for _ in range(n)] for _ in range(m)
]
return self(elements)

def __repr__(self):
return f"Module over the commutative ring: {self.ring}"

Expand Down Expand Up @@ -51,19 +57,20 @@ def __init__(self, parent, matrix_data, transpose=False):
self.parent = parent
self._data = matrix_data
self._transpose = transpose
if not self.check_dimensions():
if not self._check_dimensions():
raise ValueError("Inconsistent row lengths in matrix")

def dim(self):
"""
Return the dimensions of the matrix with m rows
and n columns"""
and n columns
"""
if not self._transpose:
return len(self._data), len(self._data[0])
else:
return len(self._data[0]), len(self._data)

def check_dimensions(self):
def _check_dimensions(self):
"""
Ensure that the matrix is rectangular
"""
Expand Down Expand Up @@ -114,6 +121,16 @@ def __eq__(self, other):
[self[i, j] == other[i, j] for i in range(m) for j in range(n)]
)

def __neg__(self):
"""
Returns -self, by negating all elements
"""
m, n = self.dim()
return self.parent(
[[-self[i, j] for j in range(n)] for i in range(m)],
self._transpose,
)

def __add__(self, other):
if not isinstance(other, type(self)):
raise TypeError("Can only add matrices to other matrices")
Expand Down
102 changes: 102 additions & 0 deletions tests/test_module_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import unittest
from random import randint
from kyber_py.polynomials.polynomials_generic import PolynomialRing
from kyber_py.modules.modules_generic import Module


class TestModule(unittest.TestCase):
R = PolynomialRing(11, 5)
M = Module(R)

def test_random_element(self):
for _ in range(100):
m = randint(1, 5)
n = randint(1, 5)
A = self.M.random_element(m, n)
self.assertEqual(type(A), self.M.matrix)
self.assertEqual(type(A[0, 0]), self.R.element)
self.assertEqual(A.dim(), (m, n))


class TestMatrix(unittest.TestCase):
R = PolynomialRing(11, 5)
M = Module(R)

def test_matrix_add(self):
zero = self.R(0)
Z = self.M([[zero, zero], [zero, zero]])
for _ in range(100):
A = self.M.random_element(2, 2)
B = self.M.random_element(2, 2)
C = self.M.random_element(2, 2)

self.assertEqual(A + Z, A)
self.assertEqual(A + B, B + A)
self.assertEqual(A + (B + C), (A + B) + C)

def test_matrix_sub(self):
zero = self.R(0)
Z = self.M([[zero, zero], [zero, zero]])
for _ in range(100):
A = self.M.random_element(2, 2)
B = self.M.random_element(2, 2)
C = self.M.random_element(2, 2)

self.assertEqual(A - Z, A)
self.assertEqual(A - B, -(B - A))
self.assertEqual(A - (B - C), (A - B) + C)

def test_matrix_mul_square(self):
zero = self.R(0)
one = self.R(1)
Z = self.M([[zero, zero], [zero, zero]])
I = self.M([[one, zero], [zero, one]])
for _ in range(100):
A = self.M.random_element(2, 2)
B = self.M.random_element(2, 2)
C = self.M.random_element(2, 2)
d = self.R.random_element()
D = self.M([[d, zero], [zero, d]])

self.assertEqual(A @ Z, Z)
self.assertEqual(A @ I, A)
self.assertEqual(A @ D, D @ A) # Diagonal matrices commute
self.assertEqual(A @ (B + C), A @ B + A @ C)

def test_matrix_mul_rectangle(self):
for _ in range(100):
A = self.M.random_element(7, 3)
B = self.M.random_element(3, 2)
C = self.M.random_element(3, 2)

self.assertEqual(A @ (B + C), A @ B + A @ C)

def test_matrix_transpose_id(self):
zero = self.R(0)
one = self.R(1)
I = self.M([[one, zero], [zero, one]])

self.assertEqual(I, I.transpose())

def test_matrix_transpose(self):
for _ in range(100):
A = self.M.random_element(7, 3)
At = A.transpose()
AAt = A @ At

self.assertEqual(AAt, AAt.transpose())

def test_matrix_dot(self):
for _ in range(100):
u = [self.R.random_element() for _ in range(5)]
v = [self.R.random_element() for _ in range(5)]
dot = sum([ui * vi for ui, vi in zip(u, v)])

U = self.M.vector(u)
V = self.M.vector(v)

self.assertEqual(dot, U.dot(V))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_polynomial_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_is_constant(self):
self.assertTrue(self.R(1).is_constant())
self.assertFalse(self.R.gen().is_constant())

def test_reduce_coefficents(self):
def test_reduce_coefficients(self):
for _ in range(100):
# Create non-canonical coefficients
coeffs = [
Expand Down

0 comments on commit 3247198

Please sign in to comment.