diff --git a/hash.go b/hash.go index e0ea56d8..120fc412 100644 --- a/hash.go +++ b/hash.go @@ -231,6 +231,34 @@ func (h *evpHash) sum(out []byte) { runtime.KeepAlive(h) } +// clone returns a new evpHash object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *evpHash) clone() (*evpHash, error) { + ctx := C.go_openssl_EVP_MD_CTX_new() + if ctx == nil { + return nil, newOpenSSLError("EVP_MD_CTX_new") + } + if C.go_openssl_EVP_MD_CTX_copy_ex(ctx, h.ctx) != 1 { + C.go_openssl_EVP_MD_CTX_free(ctx) + return nil, newOpenSSLError("EVP_MD_CTX_copy_ex") + } + ctx2 := C.go_openssl_EVP_MD_CTX_new() + if ctx2 == nil { + C.go_openssl_EVP_MD_CTX_free(ctx) + return nil, newOpenSSLError("EVP_MD_CTX_new") + } + cloned := &evpHash{ + ctx: ctx, + ctx2: ctx2, + size: h.size, + blockSize: h.blockSize, + marshallable: h.marshallable, + } + runtime.SetFinalizer(cloned, (*evpHash).finalize) + return cloned, nil +} + var testNotMarshalable bool // Used in tests. // hashState returns a pointer to the internal hash structure. @@ -576,6 +604,17 @@ func (h *sha256Hash) UnmarshalBinary(b []byte) error { return nil } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha256Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha256Hash{evpHash: c}, nil +} + // NewSHA384 returns a new SHA384 hash. func NewSHA384() hash.Hash { return &sha384Hash{ @@ -588,6 +627,17 @@ type sha384Hash struct { out [384 / 8]byte } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha384Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha384Hash{evpHash: c}, nil +} + func (h *sha384Hash) Sum(in []byte) []byte { h.sum(h.out[:]) return append(in, h.out[:]...) @@ -731,6 +781,17 @@ func (h *sha512Hash) UnmarshalBinary(b []byte) error { return nil } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha512Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha512Hash{evpHash: c}, nil +} + // NewSHA3_224 returns a new SHA3-224 hash. func NewSHA3_224() hash.Hash { return &sha3_224Hash{ diff --git a/hash_test.go b/hash_test.go index dcf01c01..c7103e28 100644 --- a/hash_test.go +++ b/hash_test.go @@ -62,6 +62,38 @@ func TestHashNotMarshalable(t *testing.T) { } } +func TestHash_Clone(t *testing.T) { + msg := []byte("testing") + for _, ch := range []crypto.Hash{crypto.SHA256, crypto.SHA384, crypto.SHA512} { + ch := ch + t.Run(ch.String(), func(t *testing.T) { + t.Parallel() + if !openssl.SupportsHash(ch) { + t.Skip("skipping: not supported") + } + h := cryptoToHash(ch)() + if _, ok := h.(encoding.BinaryMarshaler); !ok { + t.Skip("skipping: not supported") + } + _, err := h.Write(msg) + if err != nil { + t.Fatal(err) + } + // We don't define an interface for the Clone method to avoid other + // packages from depending on it. Use type assertion to call it. + h2, err := h.(interface{ Clone() (hash.Hash, error) }).Clone() + if err != nil { + t.Fatal(err) + } + h.Write(msg) + h2.Write(msg) + if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { + t.Errorf("%s(%q) = 0x%x != cloned 0x%x", ch.String(), msg, actual, actual2) + } + }) + } +} + func TestHash(t *testing.T) { msg := []byte("testing") var tests = []struct { diff --git a/shims.h b/shims.h index 6ea5bc6d..57c2b36c 100644 --- a/shims.h +++ b/shims.h @@ -199,6 +199,7 @@ DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \ DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \ DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \ DEFINEFUNC(int, EVP_MD_CTX_copy, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ +DEFINEFUNC(int, EVP_MD_CTX_copy_ex, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ DEFINEFUNC(int, EVP_Digest, (const void *data, size_t count, unsigned char *md, unsigned int *size, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (data, count, md, size, type, impl)) \ DEFINEFUNC(int, EVP_DigestInit_ex, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (ctx, type, impl)) \ DEFINEFUNC(int, EVP_DigestInit, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type), (ctx, type)) \