Skip to content

Commit

Permalink
mlkem768: move parsing and errors to high-level functions
Browse files Browse the repository at this point in the history
This is preparatory work to enable a more efficient high-level API.
  • Loading branch information
FiloSottile committed Apr 10, 2024
1 parent 23bb5fd commit a99ada4
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 109 deletions.
220 changes: 123 additions & 97 deletions mlkem768.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,18 @@ func pkeKeyGen(d []byte) (ek, dk []byte) {
return ek, dk
}

type EncapsulationKey struct {
ek [EncapsulationKeySize]byte
Hek [32]byte // H(ek)
encryptionKey
}

// encryptionKey is the parsed and expanded form of a PKE encryption key.
type encryptionKey struct {
t [k]nttElement // ByteDecode₁₂(ek[:384k])
A [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
}

// Encapsulate generates a shared key and an associated ciphertext from an
// encapsulation key, drawing random bytes from crypto/rand.
// If the encapsulation key is not valid, Encapsulate returns an error.
Expand All @@ -170,11 +182,19 @@ func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err err
if len(encapsulationKey) != EncapsulationKeySize {
return nil, nil, errors.New("mlkem768: invalid encapsulation key length")
}
m := make([]byte, messageSize)
if _, err := rand.Read(m); err != nil {
var m [messageSize]byte
if _, err := rand.Read(m[:]); err != nil {
return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
}
ciphertext, sharedKey, err = kemEncaps(encapsulationKey, m)
ek, err := parseEK(encapsulationKey)
if err != nil {
return nil, nil, err
}
ciphertext, sharedKey, err = kemEncaps(&EncapsulationKey{
ek: [EncapsulationKeySize]byte(encapsulationKey),
Hek: sha3.Sum256(encapsulationKey),
encryptionKey: *ek,
}, &m)
if err != nil {
return nil, nil, err
}
Expand All @@ -184,48 +204,52 @@ func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err err
// kemEncaps generates a shared key and an associated ciphertext.
//
// It implements ML-KEM.Encaps according to FIPS 203 (DRAFT), Algorithm 16.
func kemEncaps(ek, m []byte) (c, K []byte, err error) {
H := sha3.Sum256(ek)
func kemEncaps(ek *EncapsulationKey, m *[messageSize]byte) (c, K []byte, err error) {
g := sha3.New512()
g.Write(m)
g.Write(H[:])
g.Write(m[:])
g.Write(ek.Hek[:])
G := g.Sum(nil)
K, r := G[:SharedKeySize], G[SharedKeySize:]
c, err = pkeEncrypt(ek, m, r)
c = pkeEncrypt(&ek.encryptionKey, m, r)
return c, K, err
}

// pkeEncrypt encrypt a plaintext message. It expects ek (the encryption key) to
// be 1184 bytes, and m (the message) and rnd (the randomness) to be 32 bytes.
// parseEK parses an encryption key from its encoded form.
//
// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13.
func pkeEncrypt(ek, m, rnd []byte) ([]byte, error) {
if len(ek) != encryptionKeySize {
// It implements the initial stages of K-PKE.Encrypt according to FIPS 203
// (DRAFT), Algorithm 13.
func parseEK(ekPKE []byte) (*encryptionKey, error) {
if len(ekPKE) != encryptionKeySize {
return nil, errors.New("mlkem768: invalid encryption key length")
}
if len(m) != messageSize {
return nil, errors.New("mlkem768: invalid messages length")
}
ek := &encryptionKey{}

t := make([]nttElement, k)
for i := range t {
for i := range ek.t {
var err error
t[i], err = polyByteDecode[nttElement](ek[:encodingSize12])
ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
if err != nil {
return nil, err
}
ek = ek[encodingSize12:]
ekPKE = ekPKE[encodingSize12:]
}
ρ := ek
ρ := ekPKE

AT := make([]nttElement, k*k)
for i := byte(0); i < k; i++ {
for j := byte(0); j < k; j++ {
// Note that i and j are inverted, as we need the transposed of A.
AT[i*k+j] = sampleNTT(ρ, i, j)
// See the note in pkeKeyGen about the order of the indices being
// consistent with Kyber round 3.
ek.A[i*k+j] = sampleNTT(ρ, j, i)
}
}

return ek, nil
}

// pkeEncrypt encrypt a plaintext message.
//
// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13,
// although the computation of t and AT is done in parseEK.
func pkeEncrypt(ek *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
var N byte
r, e1 := make([]nttElement, k), make([]ringElement, k)
for i := range r {
Expand All @@ -242,18 +266,16 @@ func pkeEncrypt(ek, m, rnd []byte) ([]byte, error) {
for i := range u {
u[i] = e1[i]
for j := range r {
u[i] = polyAdd(u[i], inverseNTT(nttMul(AT[i*k+j], r[j])))
// Note that i and j are inverted, as we need the transposed of A.
u[i] = polyAdd(u[i], inverseNTT(nttMul(ek.A[j*k+i], r[j])))
}
}

μ, err := ringDecodeAndDecompress1(m)
if err != nil {
return nil, err
}
μ := ringDecodeAndDecompress1(m)

var vNTT nttElement // t⊺ ◦ r
for i := range t {
vNTT = polyAdd(vNTT, nttMul(t[i], r[i]))
for i := range ek.t {
vNTT = polyAdd(vNTT, nttMul(ek.t[i], r[i]))
}
v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

Expand All @@ -263,7 +285,18 @@ func pkeEncrypt(ek, m, rnd []byte) ([]byte, error) {
}
c = ringCompressAndEncode4(c, v)

return c, nil
return c
}

type DecapsulationKey struct {
dk [DecapsulationKeySize]byte
encryptionKey
decryptionKey
}

// decryptionKey is the parsed and expanded form of a PKE decryption key.
type decryptionKey struct {
s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize])
}

// Decapsulate generates a shared key from a ciphertext and a decapsulation key.
Expand All @@ -278,89 +311,90 @@ func Decapsulate(decapsulationKey, ciphertext []byte) (sharedKey []byte, err err
if len(ciphertext) != CiphertextSize {
return nil, errors.New("mlkem768: invalid ciphertext length")
}
return kemDecaps(decapsulationKey, ciphertext)
dkPKE := decapsulationKey[:decryptionKeySize]
dk, err := parseDK(dkPKE)
if err != nil {
return nil, err
}
ekPKE := decapsulationKey[decryptionKeySize : decryptionKeySize+encryptionKeySize]
ek, err := parseEK(ekPKE)
if err != nil {
return nil, err
}
return kemDecaps(&DecapsulationKey{
dk: [DecapsulationKeySize]byte(decapsulationKey),
encryptionKey: *ek,
decryptionKey: *dk,
}, (*[CiphertextSize]byte)(ciphertext)), nil
}

// kemDecaps produces a shared key from a ciphertext.
//
// It implements ML-KEM.Decaps according to FIPS 203 (DRAFT), Algorithm 17.
func kemDecaps(dk, c []byte) (K []byte, err error) {
dkPKE := dk[:decryptionKeySize]
ekPKE := dk[decryptionKeySize : decryptionKeySize+encryptionKeySize]
h := dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32]
z := dk[decryptionKeySize+encryptionKeySize+32:]
func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) {
h := dk.dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32]
z := dk.dk[decryptionKeySize+encryptionKeySize+32:]

m, err := pkeDecrypt(dkPKE, c)
if err != nil {
// This is only reachable if the ciphertext or the decryption key are
// encoded incorrectly, so it leaks no information about the message.
return nil, err
}
m := pkeDecrypt(&dk.decryptionKey, c)
g := sha3.New512()
g.Write(m)
g.Write(m[:])
g.Write(h)
G := g.Sum(nil)
Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
J := sha3.NewShake256()
J.Write(z)
J.Write(c)
J.Write(c[:])
Kout := make([]byte, SharedKeySize)
J.Read(Kout)
c1, err := pkeEncrypt(ekPKE, m, r)
if err != nil {
// Likewise, this is only reachable if the encryption key is encoded
// incorrectly, so it leaks no secret information through timing.
return nil, err
}
c1 := pkeEncrypt(&dk.encryptionKey, (*[32]byte)(m), r)

subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c, c1), Kout, Kprime)
return Kout, nil
subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
return Kout
}

// pkeDecrypt decrypts a ciphertext. It expects dk (the decryption key) to
// be 1152 bytes, and c (the ciphertext) to be 1088 bytes.
// parseDK parses a decryption key from its encoded form.
//
// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14.
func pkeDecrypt(dk, c []byte) ([]byte, error) {
if len(dk) != decryptionKeySize {
// It implements the computation of s from K-PKE.Decrypt according to FIPS 203
// (DRAFT), Algorithm 14.
func parseDK(dkPKE []byte) (*decryptionKey, error) {
if len(dkPKE) != decryptionKeySize {
return nil, errors.New("mlkem768: invalid decryption key length")
}
if len(c) != CiphertextSize {
return nil, errors.New("mlkem768: invalid ciphertext length")
}
dk := &decryptionKey{}

u := make([]ringElement, k)
for i := range u {
f, err := ringDecodeAndDecompress10(c[:encodingSize10])
for i := range dk.s {
f, err := polyByteDecode[nttElement](dkPKE[:encodingSize12])
if err != nil {
return nil, err
}
u[i] = f
c = c[encodingSize10:]
dk.s[i] = f
dkPKE = dkPKE[encodingSize12:]
}

v, err := ringDecodeAndDecompress4(c)
if err != nil {
return nil, err
}
return dk, nil
}

s := make([]nttElement, k)
for i := range s {
f, err := polyByteDecode[nttElement](dk[:encodingSize12])
if err != nil {
return nil, err
}
s[i] = f
dk = dk[encodingSize12:]
// pkeDecrypt decrypts a ciphertext.
//
// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14,
// although the computation of s is done in parseDK.
func pkeDecrypt(dk *decryptionKey, c *[CiphertextSize]byte) []byte {
u := make([]ringElement, k)
for i := range u {
b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
u[i] = ringDecodeAndDecompress10(b)
}

v := ringDecodeAndDecompress4(
(*[encodingSize4]byte)(c[encodingSize10*k:]))

var mask nttElement // s⊺ ◦ NTT(u)
for i := range s {
mask = polyAdd(mask, nttMul(s[i], ntt(u[i])))
for i := range dk.s {
mask = polyAdd(mask, nttMul(dk.s[i], ntt(u[i])))
}
w := polySub(v, inverseNTT(mask))

return ringCompressAndEncode1(nil, w), nil
return ringCompressAndEncode1(nil, w)
}

// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced.
Expand Down Expand Up @@ -558,17 +592,14 @@ func ringCompressAndEncode1(s []byte, f ringElement) []byte {
//
// It implements ByteDecode₁, according to FIPS 203 (DRAFT), Algorithm 5,
// followed by Decompress₁, according to FIPS 203 (DRAFT), Definition 4.6.
func ringDecodeAndDecompress1(b []byte) (ringElement, error) {
if len(b) != encodingSize1 {
return ringElement{}, errors.New("mlkem768: invalid message length")
}
func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement {
var f ringElement
for i := range f {
b_i := b[i/8] >> (i % 8) & 1
const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203 (DRAFT), Section 2.3
f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋
}
return f, nil
return f
}

// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s,
Expand All @@ -589,16 +620,13 @@ func ringCompressAndEncode4(s []byte, f ringElement) []byte {
//
// It implements ByteDecode₄, according to FIPS 203 (DRAFT), Algorithm 5,
// followed by Decompress₄, according to FIPS 203 (DRAFT), Definition 4.6.
func ringDecodeAndDecompress4(b []byte) (ringElement, error) {
if len(b) != encodingSize4 {
return ringElement{}, errors.New("mlkem768: invalid encoding length")
}
func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement {
var f ringElement
for i := 0; i < n; i += 2 {
f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4))
f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4))
}
return f, nil
return f
}

// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s,
Expand Down Expand Up @@ -629,10 +657,8 @@ func ringCompressAndEncode10(s []byte, f ringElement) []byte {
//
// It implements ByteDecode₁₀, according to FIPS 203 (DRAFT), Algorithm 5,
// followed by Decompress₁₀, according to FIPS 203 (DRAFT), Definition 4.6.
func ringDecodeAndDecompress10(b []byte) (ringElement, error) {
if len(b) != encodingSize10 {
return ringElement{}, errors.New("mlkem768: invalid encoding length")
}
func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement {
b := bb[:]
var f ringElement
for i := 0; i < n; i += 4 {
x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32
Expand All @@ -642,7 +668,7 @@ func ringDecodeAndDecompress10(b []byte) (ringElement, error) {
f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10))
f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10))
}
return f, nil
return f
}

// samplePolyCBD draws a ringElement from the special Dη distribution given a
Expand Down
Loading

0 comments on commit a99ada4

Please sign in to comment.