Skip to content

Commit

Permalink
speed up NewShaX (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal authored Jan 3, 2025
1 parent d9e21e3 commit 313c54f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 36 deletions.
100 changes: 64 additions & 36 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,27 +251,42 @@ func newEvpHash(ch crypto.Hash) *evpHash {
if alg == nil {
panic("openssl: unsupported hash function: " + strconv.Itoa(int(ch)))
}
ctx := C.go_openssl_EVP_MD_CTX_new()
if C.go_openssl_EVP_DigestInit_ex(ctx, alg.md, nil) != 1 {
C.go_openssl_EVP_MD_CTX_free(ctx)
panic(newOpenSSLError("EVP_DigestInit_ex"))
}
ctx2 := C.go_openssl_EVP_MD_CTX_new()
h := &evpHash{
alg: alg,
ctx: ctx,
ctx2: ctx2,
}
runtime.SetFinalizer(h, (*evpHash).finalize)
h := &evpHash{alg: alg}
// Don't call init() yet, it would be wasteful
// if the caller only wants to know the hash type. This
// is a common pattern in this package, as some functions
// accept a `func() hash.Hash` parameter and call it just
// to know the hash type.
return h
}

func (h *evpHash) finalize() {
C.go_openssl_EVP_MD_CTX_free(h.ctx)
C.go_openssl_EVP_MD_CTX_free(h.ctx2)
if h.ctx != nil {
C.go_openssl_EVP_MD_CTX_free(h.ctx)
}
if h.ctx2 != nil {
C.go_openssl_EVP_MD_CTX_free(h.ctx2)
}
}

func (h *evpHash) init() {
if h.ctx != nil {
return
}
h.ctx = C.go_openssl_EVP_MD_CTX_new()
if C.go_openssl_EVP_DigestInit_ex(h.ctx, h.alg.md, nil) != 1 {
C.go_openssl_EVP_MD_CTX_free(h.ctx)
panic(newOpenSSLError("EVP_DigestInit_ex"))
}
h.ctx2 = C.go_openssl_EVP_MD_CTX_new()
runtime.SetFinalizer(h, (*evpHash).finalize)
}

func (h *evpHash) Reset() {
if h.ctx == nil {
// The hash is not initialized yet, no need to reset.
return
}
// There is no need to reset h.ctx2 because it is always reset after
// use in evpHash.sum.
if C.go_openssl_EVP_DigestInit_ex(h.ctx, nil, nil) != 1 {
Expand All @@ -281,22 +296,31 @@ func (h *evpHash) Reset() {
}

func (h *evpHash) Write(p []byte) (int, error) {
if len(p) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&*addr(p)), C.size_t(len(p))) != 1 {
if len(p) == 0 {
return 0, nil
}
h.init()
if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&*addr(p)), C.size_t(len(p))) != 1 {
panic(newOpenSSLError("EVP_DigestUpdate"))
}
runtime.KeepAlive(h)
return len(p), nil
}

func (h *evpHash) WriteString(s string) (int, error) {
if len(s) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 {
if len(s) == 0 {
return 0, nil
}
h.init()
if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 {
panic("openssl: EVP_DigestUpdate failed")
}
runtime.KeepAlive(h)
return len(s), nil
}

func (h *evpHash) WriteByte(c byte) error {
h.init()
if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&c), 1) == 0 {
panic("openssl: EVP_DigestUpdate failed")
}
Expand All @@ -313,38 +337,38 @@ func (h *evpHash) BlockSize() int {
}

func (h *evpHash) Sum(in []byte) []byte {
defer runtime.KeepAlive(h)
h.init()
out := make([]byte, h.Size(), maxHashSize) // explicit cap to allow stack allocation
if C.go_hash_sum(h.ctx, h.ctx2, base(out)) != 1 {
panic(newOpenSSLError("go_hash_sum"))
}
runtime.KeepAlive(h)
return append(in, out...)
}

// Clone returns a new evpHash object that is a deep clone of itself.
// The duplicate object contains all state and data contained in the
// original object at the point of duplication.
func (h *evpHash) Clone() (hash.Hash, error) {
ctx := C.go_openssl_EVP_MD_CTX_new()
if ctx == nil {
return nil, newOpenSSLError("EVP_MD_CTX_new")
}
if C.go_openssl_EVP_MD_CTX_copy_ex(ctx, h.ctx) != 1 {
C.go_openssl_EVP_MD_CTX_free(ctx)
return nil, newOpenSSLError("EVP_MD_CTX_copy")
}
ctx2 := C.go_openssl_EVP_MD_CTX_new()
if ctx2 == nil {
C.go_openssl_EVP_MD_CTX_free(ctx)
return nil, newOpenSSLError("EVP_MD_CTX_new")
}
cloned := &evpHash{
alg: h.alg,
ctx: ctx,
ctx2: ctx2,
h2 := &evpHash{alg: h.alg}
if h.ctx != nil {
h2.ctx = C.go_openssl_EVP_MD_CTX_new()
if h2.ctx == nil {
return nil, newOpenSSLError("EVP_MD_CTX_new")
}
if C.go_openssl_EVP_MD_CTX_copy_ex(h2.ctx, h.ctx) != 1 {
C.go_openssl_EVP_MD_CTX_free(h2.ctx)
return nil, newOpenSSLError("EVP_MD_CTX_copy")
}
h2.ctx2 = C.go_openssl_EVP_MD_CTX_new()
if h2.ctx2 == nil {
C.go_openssl_EVP_MD_CTX_free(h2.ctx)
return nil, newOpenSSLError("EVP_MD_CTX_new")
}
runtime.SetFinalizer(h2, (*evpHash).finalize)
}
runtime.SetFinalizer(cloned, (*evpHash).finalize)
return cloned, nil
runtime.KeepAlive(h)
return h2, nil
}

// hashState returns a pointer to the internal hash structure.
Expand Down Expand Up @@ -384,6 +408,8 @@ func (d *evpHash) MarshalBinary() ([]byte, error) {
}

func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) {
defer runtime.KeepAlive(d)
d.init()
if !d.alg.marshallable {
return nil, errors.New("openssl: hash state is not marshallable")
}
Expand Down Expand Up @@ -419,6 +445,8 @@ func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) {
}

func (d *evpHash) UnmarshalBinary(b []byte) error {
defer runtime.KeepAlive(d)
d.init()
if !d.alg.marshallable {
return errors.New("openssl: hash state is not marshallable")
}
Expand Down
7 changes: 7 additions & 0 deletions hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@ func BenchmarkSHA256(b *testing.B) {
}
}

func BenchmarkNewSHA256(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
openssl.NewSHA256()
}
}

// stubHash is a hash.Hash implementation that does nothing.
type stubHash struct{}

Expand Down

0 comments on commit 313c54f

Please sign in to comment.