Skip to content

Commit

Permalink
mlkem768: add exhaustive tests for compress and decompress (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 authored Jan 17, 2024
1 parent 1507764 commit 344d5ee
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions mlkem768_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"flag"
"math/big"
"os"
"strconv"
"strings"
"testing"

Expand Down Expand Up @@ -79,6 +80,72 @@ func TestDecompressCompress(t *testing.T) {
}
}

func CompressRat(x fieldElement, d uint8) uint16 {
if x < 0 || x >= q {
panic("x out of range")
}
if d <= 0 || d >= 12 {
panic("d out of range")
}

precise := big.NewRat((1<<d)*int64(x), q) // (2ᵈ / q) * x == (2ᵈ * x) / q

// FloatString rounds halves away from 0, and our result should always be positive,
// so it should work as we expect. (There's no direct way to round a Rat.)
rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
if err != nil {
panic(err)
}

// If we rounded up, `rounded` may be equal to 2ᵈ, so we perform a final reduction.
return uint16(rounded % (1 << d))
}

func TestCompress(t *testing.T) {
for d := 1; d < 12; d++ {
for n := 0; n < q; n++ {
expected := CompressRat(fieldElement(n), uint8(d))
result := compress(fieldElement(n), uint8(d))
if result != expected {
t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
}
}
}
}

func DecompressRat(y uint16, d uint8) fieldElement {
if y < 0 || y >= 1<<d {
panic("y out of range")
}
if d <= 0 || d >= 12 {
panic("d out of range")
}

precise := big.NewRat(q*int64(y), 1<<d) // (q / 2ᵈ) * y == (q * y) / 2ᵈ

// FloatString rounds halves away from 0, and our result should always be positive,
// so it should work as we expect. (There's no direct way to round a Rat.)
rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
if err != nil {
panic(err)
}

// If we rounded up, `rounded` may be equal to q, so we perform a final reduction.
return fieldElement(rounded % q)
}

func TestDecompress(t *testing.T) {
for d := 1; d < 12; d++ {
for n := 0; n < (1 << d); n++ {
expected := DecompressRat(uint16(n), uint8(d))
result := decompress(uint16(n), uint8(d))
if result != expected {
t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
}
}
}
}

func BitRev7(n uint8) uint8 {
if n>>7 != 0 {
panic("not 7 bits")
Expand Down

0 comments on commit 344d5ee

Please sign in to comment.