From 832cac902a3d4853d3b77357dd2cef8e7aba4c03 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Tue, 26 Nov 2024 09:48:54 +0100 Subject: [PATCH] fix memory leak --- ec.go | 39 +++++++++++++++++++++++++++++++-------- ecdh.go | 27 +++++---------------------- ecdsa.go | 5 +---- evp.go | 5 +---- 4 files changed, 38 insertions(+), 38 deletions(-) diff --git a/ec.go b/ec.go index ccaf5a4c..efc004e0 100644 --- a/ec.go +++ b/ec.go @@ -4,26 +4,26 @@ package openssl // #include "goopenssl.h" import "C" +import "errors" -func curveNID(curve string) (C.int, error) { +func curveNID(curve string) C.int { switch curve { case "P-224": - return C.GO_NID_secp224r1, nil + return C.GO_NID_secp224r1 case "P-256": - return C.GO_NID_X9_62_prime256v1, nil + return C.GO_NID_X9_62_prime256v1 case "P-384": - return C.GO_NID_secp384r1, nil + return C.GO_NID_secp384r1 case "P-521": - return C.GO_NID_secp521r1, nil + return C.GO_NID_secp521r1 + default: + panic("openssl: unknown curve " + curve) } - return 0, errUnknownCurve } // curveSize returns the size of the curve in bytes. func curveSize(curve string) int { switch curve { - default: - panic("openssl: unknown curve " + curve) case "P-224": return 224 / 8 case "P-256": @@ -32,6 +32,8 @@ func curveSize(curve string) int { return 384 / 8 case "P-521": return (521 + 7) / 8 + default: + panic("openssl: unknown curve " + curve) } } @@ -65,3 +67,24 @@ func generateAndEncodeEcPublicKey(nid C.int, newPubKeyPointFn func(group C.GO_EC defer C.go_openssl_EC_POINT_free(pt) return encodeEcPoint(group, pt) } + +func checkPkey(pkey C.GO_EVP_PKEY_PTR, isPrivate bool) error { + ctx := C.go_openssl_EVP_PKEY_CTX_new(pkey, nil) + if ctx == nil { + return newOpenSSLError("EVP_PKEY_CTX_new") + } + defer C.go_openssl_EVP_PKEY_CTX_free(ctx) + if isPrivate { + if C.go_openssl_EVP_PKEY_private_check(ctx) != 1 { + // Match upstream error message. + return errors.New("invalid private key") + } + } else { + // Upstream Go does a partial check here, so do we. + if C.go_openssl_EVP_PKEY_public_check_quick(ctx) != 1 { + // Match upstream error message. + return errors.New("invalid public key") + } + } + return nil +} diff --git a/ecdh.go b/ecdh.go index 0ea8fb59..d8c72f62 100644 --- a/ecdh.go +++ b/ecdh.go @@ -107,10 +107,7 @@ func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { } func newECDHPkey(curve string, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, error) { - nid, err := curveNID(curve) - if err != nil { - return nil, err - } + nid := curveNID(curve) switch vMajor { case 1: return newECDHPkey1(nid, bytes, isPrivate) @@ -214,24 +211,10 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e if err != nil { return nil, err } - ctx := C.go_openssl_EVP_PKEY_CTX_new(pkey, nil) - if ctx == nil { - return nil, newOpenSSLError("EVP_PKEY_CTX_new") - } - defer C.go_openssl_EVP_PKEY_CTX_free(ctx) - if isPrivate { - if C.go_openssl_EVP_PKEY_private_check(ctx) != 1 { - C.go_openssl_EVP_PKEY_free(pkey) - // Match upstream error message. - return nil, errors.New("crypto/ecdh: invalid private key") - } - } else { - // Upstream Go does a partial check here, so do we. - if C.go_openssl_EVP_PKEY_public_check_quick(ctx) != 1 { - C.go_openssl_EVP_PKEY_free(pkey) - // Match upstream error message. - return nil, errors.New("crypto/ecdh: invalid public key") - } + + if err := checkPkey(pkey, isPrivate); err != nil { + C.go_openssl_EVP_PKEY_free(pkey) + return nil, errors.New("crypto/ecdh: " + err.Error()) } return pkey, nil } diff --git a/ecdsa.go b/ecdsa.go index f85782a6..bc5f1117 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -122,10 +122,7 @@ func HashVerifyECDSA(pub *PublicKeyECDSA, h crypto.Hash, msg, sig []byte) bool { } func newECDSAKey(curve string, x, y, d BigInt) (C.GO_EVP_PKEY_PTR, error) { - nid, err := curveNID(curve) - if err != nil { - return nil, err - } + nid := curveNID(curve) var bx, by, bd C.GO_BIGNUM_PTR defer func() { C.go_openssl_BN_free(bx) diff --git a/evp.go b/evp.go index 91296a93..17d040a4 100644 --- a/evp.go +++ b/evp.go @@ -175,10 +175,7 @@ func generateEVPPKey(id C.int, bits int, curve string) (C.GO_EVP_PKEY_PTR, error } } if curve != "" { - nid, err := curveNID(curve) - if err != nil { - return nil, err - } + nid := curveNID(curve) if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, id, -1, C.GO_EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID, nid, nil) != 1 { return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed") }