From dc77acd7673dba235cdbe10f3c667c9390a98af0 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 8 Jan 2025 09:42:05 +0100 Subject: [PATCH] fix tests --- cshake.go | 89 +++++++++++++++++--------------------------------- cshake_test.go | 28 ++++++++-------- shims.h | 2 -- 3 files changed, 44 insertions(+), 75 deletions(-) diff --git a/cshake.go b/cshake.go index dea2f1e..25205cb 100644 --- a/cshake.go +++ b/cshake.go @@ -35,48 +35,22 @@ func SumSHAKE256(data []byte, length int) []byte { return out } -// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is -// supported. -func SupportsSHAKE128() bool { - return supportsSHAKE(128) -} - -// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is -// supported. -func SupportsSHAKE256() bool { - return supportsSHAKE(256) -} - -// SupportsCSHAKE128 returns true if the CSHAKE128 extendable output function is -// supported. -func SupportsCSHAKE128() bool { - return false -} - -// SupportsCSHAKE256 returns true if the CSHAKE256 extendable output function is -// supported. -func SupportsCSHAKE256() bool { - return false -} - -// cacheSHAKESupported is a cache of SHAKE size support. -var cacheSHAKESupported sync.Map - -// SupportsSHAKE returns true if the SHAKE extendable output function is -// supported. -func supportsSHAKE(size int) bool { +// SupportsSHAKE returns true if the SHAKE extendable output functions +// with the given securityBits are supported. +func SupportsSHAKE(securityBits int) bool { if vMajor == 1 || (vMajor == 3 && vMinor < 3) { // SHAKE MD's are supported since OpenSSL 1.1.1, // but EVP_DigestSqueeze is only supported since 3.3, // and we need it to implement [sha3.SHAKE]. return false } - if v, ok := cacheSHAKESupported.Load(size); ok { - return v.(bool) - } - supported := loadShake(size) != nil - cacheSHAKESupported.Store(size, supported) - return supported + return loadShake(securityBits) != nil +} + +// SupportsCSHAKE returns true if the CSHAKE extendable output functions +// with the given securityBits are supported. +func SupportsCSHAKE(securityBits int) bool { + return false } // SHAKE is an instance of a SHAKE extendable output function. @@ -203,35 +177,32 @@ type shakeAlgorithm struct { } // loadShake converts a crypto.Hash to a EVP_MD. -func loadShake(xofLength int) *shakeAlgorithm { - if v, ok := cacheMD.Load(xofLength); ok { +func loadShake(securityBits int) (alg *shakeAlgorithm) { + if v, ok := cacheMD.Load(securityBits); ok { return v.(*shakeAlgorithm) } + defer func() { + cacheMD.Store(securityBits, alg) + }() - var shake shakeAlgorithm - switch xofLength { + var name *C.char + switch securityBits { case 128: - if versionAtOrAbove(1, 1, 0) { - shake.md = C.go_openssl_EVP_shake128() - } + name = C.CString("SHAKE-128") case 256: - if versionAtOrAbove(1, 1, 0) { - shake.md = C.go_openssl_EVP_shake256() - } - } - if shake.md == nil { - cacheMD.Store(xofLength, (*hashAlgorithm)(nil)) + name = C.CString("SHAKE-256") + default: return nil } - shake.blockSize = int(C.go_openssl_EVP_MD_get_block_size(shake.md)) - if vMajor == 3 { - md := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(shake.md), nil) - // Don't overwrite md in case it can't be fetched, as the md may still be used - // outside of EVP_MD_CTX. - if md != nil { - shake.md = md - } + defer C.free(unsafe.Pointer(name)) + + md := C.go_openssl_EVP_MD_fetch(nil, name, nil) + if md == nil { + return nil } - cacheMD.Store(xofLength, &shake) - return &shake + + alg = new(shakeAlgorithm) + alg.md = md + alg.blockSize = int(C.go_openssl_EVP_MD_get_block_size(md)) + return alg } diff --git a/cshake_test.go b/cshake_test.go index 89c762d..59c4d71 100644 --- a/cshake_test.go +++ b/cshake_test.go @@ -30,13 +30,13 @@ func skipCSHAKEIfNotSupported(t *testing.T, algo string) { var supported bool switch algo { case "SHAKE128": - supported = openssl.SupportsSHAKE128() + supported = openssl.SupportsSHAKE(128) case "SHAKE256": - supported = openssl.SupportsSHAKE256() + supported = openssl.SupportsSHAKE(256) case "CSHAKE128": - supported = openssl.SupportsCSHAKE128() + supported = openssl.SupportsCSHAKE(128) case "CSHAKE256": - supported = openssl.SupportsCSHAKE256() + supported = openssl.SupportsCSHAKE(256) } if !supported { t.Skip("skipping: not supported") @@ -94,7 +94,7 @@ func TestCSHAKEReset(t *testing.T) { skipCSHAKEIfNotSupported(t, algo) // Calculate hash for the first time - c := v.constructor(nil, []byte{0x99, 0x98}) + c := v.constructor(nil, []byte(v.defCustomStr)) c.Write(sequentialBytes(0x100)) c.Read(out1) @@ -112,14 +112,14 @@ func TestCSHAKEReset(t *testing.T) { func TestCSHAKEAccumulated(t *testing.T) { t.Run("CSHAKE128", func(t *testing.T) { - if !openssl.SupportsSHAKE128() { + if !openssl.SupportsCSHAKE(128) { t.Skip("skipping: not supported") } testCSHAKEAccumulated(t, openssl.NewCSHAKE128, (1600-256)/8, "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") }) t.Run("CSHAKE256", func(t *testing.T) { - if !openssl.SupportsSHAKE256() { + if !openssl.SupportsCSHAKE(256) { t.Skip("skipping: not supported") } testCSHAKEAccumulated(t, openssl.NewCSHAKE256, (1600-512)/8, @@ -158,7 +158,7 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *openssl.SH } func TestCSHAKELargeS(t *testing.T) { - if !openssl.SupportsSHAKE128() { + if !openssl.SupportsCSHAKE(128) { t.Skip("skipping: not supported") } const s = (1<<32)/8 + 1000 // s * 8 > 2^32 @@ -178,11 +178,11 @@ func TestCSHAKELargeS(t *testing.T) { func TestCSHAKESum(t *testing.T) { const testString = "hello world" - t.Run("CSHAKE128", func(t *testing.T) { - if !openssl.SupportsSHAKE128() { + t.Run("SHAKE128", func(t *testing.T) { + if !openssl.SupportsSHAKE(128) { t.Skip("skipping: not supported") } - h := openssl.NewCSHAKE128(nil, nil) + h := openssl.NewSHAKE128() h.Write([]byte(testString[:5])) h.Write([]byte(testString[5:])) want := make([]byte, 32) @@ -192,11 +192,11 @@ func TestCSHAKESum(t *testing.T) { t.Errorf("got:%x want:%x", got, want) } }) - t.Run("CSHAKE256", func(t *testing.T) { - if !openssl.SupportsSHAKE256() { + t.Run("SHAKE256", func(t *testing.T) { + if !openssl.SupportsSHAKE(256) { t.Skip("skipping: not supported") } - h := openssl.NewCSHAKE256(nil, nil) + h := openssl.NewSHAKE256() h.Write([]byte(testString[:5])) h.Write([]byte(testString[5:])) want := make([]byte, 32) diff --git a/shims.h b/shims.h index c95c4f9..df51d37 100644 --- a/shims.h +++ b/shims.h @@ -252,8 +252,6 @@ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_224, (void), ()) \ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_256, (void), ()) \ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_384, (void), ()) \ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_512, (void), ()) \ -DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_shake128, (void), ()) \ -DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_shake256, (void), ()) \ DEFINEFUNC_LEGACY_1_0(void, HMAC_CTX_init, (GO_HMAC_CTX_PTR arg0), (arg0)) \ DEFINEFUNC_LEGACY_1_0(void, HMAC_CTX_cleanup, (GO_HMAC_CTX_PTR arg0), (arg0)) \ DEFINEFUNC_LEGACY_1(int, HMAC_Init_ex, (GO_HMAC_CTX_PTR arg0, const void *arg1, int arg2, const GO_EVP_MD_PTR arg3, GO_ENGINE_PTR arg4), (arg0, arg1, arg2, arg3, arg4)) \