Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Jan 8, 2025
1 parent 6b09b11 commit dc77acd
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 75 deletions.
89 changes: 30 additions & 59 deletions cshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,48 +35,22 @@ func SumSHAKE256(data []byte, length int) []byte {
return out
}

// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is
// supported.
func SupportsSHAKE128() bool {
return supportsSHAKE(128)
}

// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is
// supported.
func SupportsSHAKE256() bool {
return supportsSHAKE(256)
}

// SupportsCSHAKE128 returns true if the CSHAKE128 extendable output function is
// supported.
func SupportsCSHAKE128() bool {
return false
}

// SupportsCSHAKE256 returns true if the CSHAKE256 extendable output function is
// supported.
func SupportsCSHAKE256() bool {
return false
}

// cacheSHAKESupported is a cache of SHAKE size support.
var cacheSHAKESupported sync.Map

// SupportsSHAKE returns true if the SHAKE extendable output function is
// supported.
func supportsSHAKE(size int) bool {
// SupportsSHAKE returns true if the SHAKE extendable output functions
// with the given securityBits are supported.
func SupportsSHAKE(securityBits int) bool {
if vMajor == 1 || (vMajor == 3 && vMinor < 3) {
// SHAKE MD's are supported since OpenSSL 1.1.1,
// but EVP_DigestSqueeze is only supported since 3.3,
// and we need it to implement [sha3.SHAKE].
return false
}
if v, ok := cacheSHAKESupported.Load(size); ok {
return v.(bool)
}
supported := loadShake(size) != nil
cacheSHAKESupported.Store(size, supported)
return supported
return loadShake(securityBits) != nil
}

// SupportsCSHAKE returns true if the CSHAKE extendable output functions
// with the given securityBits are supported.
func SupportsCSHAKE(securityBits int) bool {
return false
}

// SHAKE is an instance of a SHAKE extendable output function.
Expand Down Expand Up @@ -203,35 +177,32 @@ type shakeAlgorithm struct {
}

// loadShake converts a crypto.Hash to a EVP_MD.
func loadShake(xofLength int) *shakeAlgorithm {
if v, ok := cacheMD.Load(xofLength); ok {
func loadShake(securityBits int) (alg *shakeAlgorithm) {
if v, ok := cacheMD.Load(securityBits); ok {
return v.(*shakeAlgorithm)
}
defer func() {
cacheMD.Store(securityBits, alg)
}()

var shake shakeAlgorithm
switch xofLength {
var name *C.char
switch securityBits {
case 128:
if versionAtOrAbove(1, 1, 0) {
shake.md = C.go_openssl_EVP_shake128()
}
name = C.CString("SHAKE-128")
case 256:
if versionAtOrAbove(1, 1, 0) {
shake.md = C.go_openssl_EVP_shake256()
}
}
if shake.md == nil {
cacheMD.Store(xofLength, (*hashAlgorithm)(nil))
name = C.CString("SHAKE-256")
default:
return nil
}
shake.blockSize = int(C.go_openssl_EVP_MD_get_block_size(shake.md))
if vMajor == 3 {
md := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(shake.md), nil)
// Don't overwrite md in case it can't be fetched, as the md may still be used
// outside of EVP_MD_CTX.
if md != nil {
shake.md = md
}
defer C.free(unsafe.Pointer(name))

md := C.go_openssl_EVP_MD_fetch(nil, name, nil)
if md == nil {
return nil
}
cacheMD.Store(xofLength, &shake)
return &shake

alg = new(shakeAlgorithm)
alg.md = md
alg.blockSize = int(C.go_openssl_EVP_MD_get_block_size(md))
return alg
}
28 changes: 14 additions & 14 deletions cshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ func skipCSHAKEIfNotSupported(t *testing.T, algo string) {
var supported bool
switch algo {
case "SHAKE128":
supported = openssl.SupportsSHAKE128()
supported = openssl.SupportsSHAKE(128)
case "SHAKE256":
supported = openssl.SupportsSHAKE256()
supported = openssl.SupportsSHAKE(256)
case "CSHAKE128":
supported = openssl.SupportsCSHAKE128()
supported = openssl.SupportsCSHAKE(128)
case "CSHAKE256":
supported = openssl.SupportsCSHAKE256()
supported = openssl.SupportsCSHAKE(256)
}
if !supported {
t.Skip("skipping: not supported")
Expand Down Expand Up @@ -94,7 +94,7 @@ func TestCSHAKEReset(t *testing.T) {
skipCSHAKEIfNotSupported(t, algo)

// Calculate hash for the first time
c := v.constructor(nil, []byte{0x99, 0x98})
c := v.constructor(nil, []byte(v.defCustomStr))
c.Write(sequentialBytes(0x100))
c.Read(out1)

Expand All @@ -112,14 +112,14 @@ func TestCSHAKEReset(t *testing.T) {

func TestCSHAKEAccumulated(t *testing.T) {
t.Run("CSHAKE128", func(t *testing.T) {
if !openssl.SupportsSHAKE128() {
if !openssl.SupportsCSHAKE(128) {
t.Skip("skipping: not supported")
}
testCSHAKEAccumulated(t, openssl.NewCSHAKE128, (1600-256)/8,
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
})
t.Run("CSHAKE256", func(t *testing.T) {
if !openssl.SupportsSHAKE256() {
if !openssl.SupportsCSHAKE(256) {
t.Skip("skipping: not supported")
}
testCSHAKEAccumulated(t, openssl.NewCSHAKE256, (1600-512)/8,
Expand Down Expand Up @@ -158,7 +158,7 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *openssl.SH
}

func TestCSHAKELargeS(t *testing.T) {
if !openssl.SupportsSHAKE128() {
if !openssl.SupportsCSHAKE(128) {
t.Skip("skipping: not supported")
}
const s = (1<<32)/8 + 1000 // s * 8 > 2^32
Expand All @@ -178,11 +178,11 @@ func TestCSHAKELargeS(t *testing.T) {

func TestCSHAKESum(t *testing.T) {
const testString = "hello world"
t.Run("CSHAKE128", func(t *testing.T) {
if !openssl.SupportsSHAKE128() {
t.Run("SHAKE128", func(t *testing.T) {
if !openssl.SupportsSHAKE(128) {
t.Skip("skipping: not supported")
}
h := openssl.NewCSHAKE128(nil, nil)
h := openssl.NewSHAKE128()
h.Write([]byte(testString[:5]))
h.Write([]byte(testString[5:]))
want := make([]byte, 32)
Expand All @@ -192,11 +192,11 @@ func TestCSHAKESum(t *testing.T) {
t.Errorf("got:%x want:%x", got, want)
}
})
t.Run("CSHAKE256", func(t *testing.T) {
if !openssl.SupportsSHAKE256() {
t.Run("SHAKE256", func(t *testing.T) {
if !openssl.SupportsSHAKE(256) {
t.Skip("skipping: not supported")
}
h := openssl.NewCSHAKE256(nil, nil)
h := openssl.NewSHAKE256()
h.Write([]byte(testString[:5]))
h.Write([]byte(testString[5:]))
want := make([]byte, 32)
Expand Down
2 changes: 0 additions & 2 deletions shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,6 @@ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_224, (void), ()) \
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_256, (void), ()) \
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_384, (void), ()) \
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_512, (void), ()) \
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_shake128, (void), ()) \
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_shake256, (void), ()) \
DEFINEFUNC_LEGACY_1_0(void, HMAC_CTX_init, (GO_HMAC_CTX_PTR arg0), (arg0)) \
DEFINEFUNC_LEGACY_1_0(void, HMAC_CTX_cleanup, (GO_HMAC_CTX_PTR arg0), (arg0)) \
DEFINEFUNC_LEGACY_1(int, HMAC_Init_ex, (GO_HMAC_CTX_PTR arg0, const void *arg1, int arg2, const GO_EVP_MD_PTR arg3, GO_ENGINE_PTR arg4), (arg0, arg1, arg2, arg3, arg4)) \
Expand Down

0 comments on commit dc77acd

Please sign in to comment.