feat: implement batch hash utils #384

Merged · 3 commits · Jul 17, 2024
5 changes: 4 additions & 1 deletion packages/persistent-merkle-tree/package.json
@@ -20,7 +20,9 @@
"clean": "rm -rf lib",
"build": "tsc",
"lint": "eslint --color --ext .ts src/",
"benchmark": "node --max-old-space-size=4096 --expose-gc -r ts-node/register ./node_modules/.bin/benchmark 'test/perf/*.perf.ts'",
"lint:fix": "yarn run lint --fix",
"benchmark:files": "node --max-old-space-size=4096 --expose-gc -r ts-node/register ../../node_modules/.bin/benchmark",
"benchmark": "yarn benchmark:files 'test/perf/*.test.ts'",
"benchmark:local": "yarn benchmark --local",
"test": "mocha -r ts-node/register 'test/unit/**/*.test.ts'"
},
@@ -45,6 +47,7 @@
"homepage": "https://github.com/ChainSafe/persistent-merkle-tree#readme",
"dependencies": {
"@chainsafe/as-sha256": "0.4.2",
"@chainsafe/hashtree": "1.0.1",
"@noble/hashes": "^1.3.0"
}
}
138 changes: 136 additions & 2 deletions packages/persistent-merkle-tree/src/hasher/as-sha256.ts
@@ -1,7 +1,141 @@
import {digest2Bytes32, digest64HashObjects} from "@chainsafe/as-sha256";
import {
digest2Bytes32,
digest64HashObjectsInto,
digest64HashObjects,
batchHash4HashObjectInputs,
hashInto,
} from "@chainsafe/as-sha256";
import type {Hasher} from "./types";
import {HashComputation, Node} from "../node";
import {doDigestNLevel, doMerkleizeInto} from "./util";

export const hasher: Hasher = {
name: "as-sha256",
digest64: digest2Bytes32,
digest64HashObjects,
digest64HashObjects: digest64HashObjectsInto,
merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
return doMerkleizeInto(data, padFor, output, offset, hashInto);
},
digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
return doDigestNLevel(data, nLevel, hashInto);
},
executeHashComputations: (hashComputations: HashComputation[][]) => {
for (let level = hashComputations.length - 1; level >= 0; level--) {
const hcArr = hashComputations[level];
if (!hcArr) {
// should not happen
throw Error(`no hash computations for level ${level}`);
}

if (hcArr.length === 0) {
// nothing to hash
continue;
}

// HashComputations of the same level are safe to batch
let src0_0: Node | null = null;
let src1_0: Node | null = null;
let dest0: Node | null = null;
let src0_1: Node | null = null;
let src1_1: Node | null = null;
let dest1: Node | null = null;
let src0_2: Node | null = null;
let src1_2: Node | null = null;
let dest2: Node | null = null;
let src0_3: Node | null = null;
let src1_3: Node | null = null;
let dest3: Node | null = null;

for (const [i, hc] of hcArr.entries()) {
const indexInBatch = i % 4;

switch (indexInBatch) {
case 0:
src0_0 = hc.src0;
src1_0 = hc.src1;
dest0 = hc.dest;
break;
case 1:
src0_1 = hc.src0;
src1_1 = hc.src1;
dest1 = hc.dest;
break;
case 2:
src0_2 = hc.src0;
src1_2 = hc.src1;
dest2 = hc.dest;
break;
case 3:
src0_3 = hc.src0;
src1_3 = hc.src1;
dest3 = hc.dest;

if (
src0_0 !== null &&
src1_0 !== null &&
dest0 !== null &&
src0_1 !== null &&
src1_1 !== null &&
dest1 !== null &&
src0_2 !== null &&
src1_2 !== null &&
dest2 !== null &&
src0_3 !== null &&
src1_3 !== null &&
dest3 !== null
) {
// TODO - batch: find a way to avoid allocating here
const [o0, o1, o2, o3] = batchHash4HashObjectInputs([
src0_0,
src1_0,
src0_1,
src1_1,
src0_2,
src1_2,
src0_3,
src1_3,
]);
if (o0 == null || o1 == null || o2 == null || o3 == null) {
throw Error(`batchHash4HashObjectInputs returned null or undefined at batch ${i} level ${level}`);
}
dest0.applyHash(o0);
dest1.applyHash(o1);
dest2.applyHash(o2);
dest3.applyHash(o3);

// reset for next batch
src0_0 = null;
src1_0 = null;
dest0 = null;
src0_1 = null;
src1_1 = null;
dest1 = null;
src0_2 = null;
src1_2 = null;
dest2 = null;
src0_3 = null;
src1_3 = null;
dest3 = null;
}
break;
default:
throw Error(`Unexpected indexInBatch ${indexInBatch}`);
}
}

// remaining
if (src0_0 !== null && src1_0 !== null && dest0 !== null) {
dest0.applyHash(digest64HashObjects(src0_0, src1_0));
}
if (src0_1 !== null && src1_1 !== null && dest1 !== null) {
dest1.applyHash(digest64HashObjects(src0_1, src1_1));
}
if (src0_2 !== null && src1_2 !== null && dest2 !== null) {
dest2.applyHash(digest64HashObjects(src0_2, src1_2));
}
if (src0_3 !== null && src1_3 !== null && dest3 !== null) {
dest3.applyHash(digest64HashObjects(src0_3, src1_3));
}
}
},
};
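Note: the following is a minimal, self-contained sketch of the 4-lane batching pattern used by executeHashComputations above. It is illustrative only: hash4, hashPair, and the string "hashes" are hypothetical stand-ins for batchHash4HashObjectInputs, digest64HashObjects, and real HashObjects.

type Pair = {src0: string; src1: string; dest: {value?: string}};

// Hypothetical stand-in for batchHash4HashObjectInputs: hashes 4 (src0, src1) pairs in one call.
function hash4(pairs: [string, string][]): string[] {
  return pairs.map(([a, b]) => `H(${a},${b})`);
}

// Hypothetical single-pair fallback, mirroring digest64HashObjects.
function hashPair(a: string, b: string): string {
  return `H(${a},${b})`;
}

function executeBatched(items: Pair[]): void {
  let lane: Pair[] = [];
  for (const item of items) {
    lane.push(item);
    if (lane.length === 4) {
      // a full batch of 4: hash all lanes in a single call, then reset
      const outputs = hash4(lane.map((p) => [p.src0, p.src1] as [string, string]));
      lane.forEach((p, i) => (p.dest.value = outputs[i]));
      lane = [];
    }
  }
  // remaining (< 4) computations fall back to single hashing, as in the code above
  for (const p of lane) {
    p.dest.value = hashPair(p.src0, p.src1);
  }
}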
124 changes: 124 additions & 0 deletions packages/persistent-merkle-tree/src/hasher/hashtree.ts
@@ -0,0 +1,124 @@
import {hashInto} from "@chainsafe/hashtree";
import {Hasher, HashObject} from "./types";
import {HashComputation, Node} from "../node";
import {byteArrayIntoHashObject} from "@chainsafe/as-sha256/lib/hashObject";
import {doDigestNLevel, doMerkleizeInto} from "./util";

/**
* The best SIMD implementation operates on 512 bits = 64 bytes at a time;
* if SIMD is unavailable, hashtree falls back to an internal loop.
* Since sha256 works on 4-byte (32-bit) words, a 512-bit register holds 16 lanes, so we can hash 16 inputs at once.
* Each input is 64 bytes (two 32-byte nodes).
*/
const PARALLEL_FACTOR = 16;
const MAX_INPUT_SIZE = PARALLEL_FACTOR * 64;
const uint8Input = new Uint8Array(MAX_INPUT_SIZE);
const uint32Input = new Uint32Array(uint8Input.buffer);
const uint8Output = new Uint8Array(PARALLEL_FACTOR * 32);
// extracting a Uint32Array view of the output would use more memory, so it is left commented out
// const uint32Output = new Uint32Array(uint8Output.buffer);
// convenient reusable Uint8Array for hash64
const hash64Input = uint8Input.subarray(0, 64);
const hash64Output = uint8Output.subarray(0, 32);

export const hasher: Hasher = {
name: "hashtree",
digest64(obj1: Uint8Array, obj2: Uint8Array): Uint8Array {
if (obj1.length !== 32 || obj2.length !== 32) {
throw new Error("Invalid input length");
}
hash64Input.set(obj1, 0);
hash64Input.set(obj2, 32);
hashInto(hash64Input, hash64Output);
return hash64Output.slice();
},
digest64HashObjects(left: HashObject, right: HashObject, parent: HashObject): void {
hashObjectsToUint32Array(left, right, uint32Input);
hashInto(hash64Input, hash64Output);
byteArrayIntoHashObject(hash64Output, 0, parent);
},
merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
return doMerkleizeInto(data, padFor, output, offset, hashInto);
},
digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
return doDigestNLevel(data, nLevel, hashInto);
},
executeHashComputations(hashComputations: HashComputation[][]): void {
for (let level = hashComputations.length - 1; level >= 0; level--) {
const hcArr = hashComputations[level];
if (!hcArr) {
// should not happen
throw Error(`no hash computations for level ${level}`);
}

if (hcArr.length === 0) {
// nothing to hash
continue;
}

// size input array to 2 HashObject per computation * 32 bytes per object
// const input: Uint8Array = Uint8Array.from(new Array(hcArr.length * 2 * 32));
let destNodes: Node[] = [];

// hash every 16 inputs at once to avoid memory allocation
for (const [i, {src0, src1, dest}] of hcArr.entries()) {
const indexInBatch = i % PARALLEL_FACTOR;
const offset = indexInBatch * 16;

hashObjectToUint32Array(src0, uint32Input, offset);
hashObjectToUint32Array(src1, uint32Input, offset + 8);
destNodes.push(dest);
if (indexInBatch === PARALLEL_FACTOR - 1) {
hashInto(uint8Input, uint8Output);
for (const [j, destNode] of destNodes.entries()) {
byteArrayIntoHashObject(uint8Output, j * 32, destNode);
}
destNodes = [];
}
}

const remaining = hcArr.length % PARALLEL_FACTOR;
// we prepared data in input, now hash the remaining
if (remaining > 0) {
const remainingInput = uint8Input.subarray(0, remaining * 64);
const remainingOutput = uint8Output.subarray(0, remaining * 32);
hashInto(remainingInput, remainingOutput);
// destNodes was prepared above
for (const [i, destNode] of destNodes.entries()) {
byteArrayIntoHashObject(remainingOutput, i * 32, destNode);
}
}
}
},
};

function hashObjectToUint32Array(obj: HashObject, arr: Uint32Array, offset: number): void {
arr[offset] = obj.h0;
arr[offset + 1] = obj.h1;
arr[offset + 2] = obj.h2;
arr[offset + 3] = obj.h3;
arr[offset + 4] = obj.h4;
arr[offset + 5] = obj.h5;
arr[offset + 6] = obj.h6;
arr[offset + 7] = obj.h7;
}

// note: going through uint32ArrayToHashObject would allocate more memory, so fields are written directly
function hashObjectsToUint32Array(obj1: HashObject, obj2: HashObject, arr: Uint32Array): void {
arr[0] = obj1.h0;
arr[1] = obj1.h1;
arr[2] = obj1.h2;
arr[3] = obj1.h3;
arr[4] = obj1.h4;
arr[5] = obj1.h5;
arr[6] = obj1.h6;
arr[7] = obj1.h7;
arr[8] = obj2.h0;
arr[9] = obj2.h1;
arr[10] = obj2.h2;
arr[11] = obj2.h3;
arr[12] = obj2.h4;
arr[13] = obj2.h5;
arr[14] = obj2.h6;
arr[15] = obj2.h7;
}
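The executeHashComputations above packs up to 16 (src0, src1) pairs into a preallocated input buffer and hashes them with a single hashInto call, reusing the same buffers across batches. Below is a minimal sketch of that buffer-reuse pattern, with a fake hashInto stand-in (the real one is @chainsafe/hashtree) and byte-array inputs instead of HashObjects.

const LANES = 16;
const input = new Uint8Array(LANES * 64); // reused across batches, no per-batch allocation
const output = new Uint8Array(LANES * 32);

// Stand-in for hashInto: folds each 64-byte input down to 32 bytes with XOR.
function fakeHashInto(inp: Uint8Array, out: Uint8Array): void {
  for (let i = 0; i < out.length; i++) out[i] = inp[2 * i] ^ inp[2 * i + 1];
}

function hashPairs(pairs: Array<[Uint8Array, Uint8Array]>): Uint8Array[] {
  const roots: Uint8Array[] = [];
  for (let i = 0; i < pairs.length; i += LANES) {
    const batch = pairs.slice(i, i + LANES);
    // pack up to 16 (left, right) pairs into the reused input buffer
    batch.forEach(([left, right], j) => {
      input.set(left, j * 64);
      input.set(right, j * 64 + 32);
    });
    // hash only the filled portion; this also covers the final partial batch
    fakeHashInto(input.subarray(0, batch.length * 64), output.subarray(0, batch.length * 32));
    for (let j = 0; j < batch.length; j++) {
      roots.push(output.slice(j * 32, (j + 1) * 32));
    }
  }
  return roots;
}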
20 changes: 18 additions & 2 deletions packages/persistent-merkle-tree/src/hasher/index.ts
@@ -1,12 +1,12 @@
import {Hasher} from "./types";
import {hasher as nobleHasher} from "./noble";
import type {HashComputation} from "../node";

export {HashObject} from "@chainsafe/as-sha256/lib/hashObject";
export * from "./types";
export * from "./util";

/**
* Hasher used across the SSZ codebase
* Hasher used across the SSZ codebase. By default, this does not support batch hashing.
*/
export let hasher: Hasher = nobleHasher;

@@ -18,3 +18,19 @@ export let hasher: Hasher = nobleHasher;
export function setHasher(newHasher: Hasher): void {
hasher = newHasher;
}

export function digest64(a: Uint8Array, b: Uint8Array): Uint8Array {
return hasher.digest64(a, b);
}

export function digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
return hasher.digestNLevel(data, nLevel);
}

export function merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
hasher.merkleizeInto(data, padFor, output, offset);
}

export function executeHashComputations(hashComputations: HashComputation[][]): void {
hasher.executeHashComputations(hashComputations);
}
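A possible usage of the new module-level wrappers added here, assuming the package root re-exports ./hasher; the argument values are arbitrary examples.

import {digest64, merkleizeInto, hasher} from "@chainsafe/persistent-merkle-tree";

const a = new Uint8Array(32).fill(1);
const b = new Uint8Array(32).fill(2);

// hash two 32-byte roots with whichever hasher is currently set (noble by default)
const root = digest64(a, b);

// merkleize 2 chunks (64 bytes) padded to 4 leaves, writing the root into `out` at offset 0
const data = new Uint8Array(64);
const out = new Uint8Array(32);
merkleizeInto(data, 4, out, 0);

console.log(hasher.name, root, out);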
44 changes: 42 additions & 2 deletions packages/persistent-merkle-tree/src/hasher/noble.ts
@@ -1,10 +1,50 @@
import {sha256} from "@noble/hashes/sha256";
import {digest64HashObjects, byteArrayIntoHashObject} from "@chainsafe/as-sha256";
import type {Hasher} from "./types";
import {hashObjectToUint8Array, uint8ArrayToHashObject} from "./util";
import {doDigestNLevel, doMerkleizeInto, hashObjectToUint8Array} from "./util";

const digest64 = (a: Uint8Array, b: Uint8Array): Uint8Array => sha256.create().update(a).update(b).digest();
const hashInto = (input: Uint8Array, output: Uint8Array): void => {
if (input.length % 64 !== 0) {
throw new Error(`Invalid input length ${input.length}`);
}
if (input.length !== output.length * 2) {
throw new Error(`Invalid output length ${output.length}`);
}

const count = Math.floor(input.length / 64);
for (let i = 0; i < count; i++) {
const offset = i * 64;
const in1 = input.subarray(offset, offset + 32);
const in2 = input.subarray(offset + 32, offset + 64);
const out = digest64(in1, in2);
output.set(out, i * 32);
}
};

export const hasher: Hasher = {
name: "noble",
digest64,
digest64HashObjects: (a, b) => uint8ArrayToHashObject(digest64(hashObjectToUint8Array(a), hashObjectToUint8Array(b))),
digest64HashObjects: (left, right, parent) => {
byteArrayIntoHashObject(digest64(hashObjectToUint8Array(left), hashObjectToUint8Array(right)), 0, parent);
},
merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
return doMerkleizeInto(data, padFor, output, offset, hashInto);
},
digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
return doDigestNLevel(data, nLevel, hashInto);
},
executeHashComputations: (hashComputations) => {
for (let level = hashComputations.length - 1; level >= 0; level--) {
const hcArr = hashComputations[level];
if (!hcArr) {
// should not happen
throw Error(`no hash computations for level ${level}`);
}

for (const hc of hcArr) {
hc.dest.applyHash(digest64HashObjects(hc.src0, hc.src1));
}
}
},
};
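All three hashers now implement the in-place digest64HashObjects(left, right, parent) signature instead of returning a new HashObject. A small sketch of calling it, assuming the default (noble) hasher is re-exported from the package root and the HashObject shape {h0..h7} from @chainsafe/as-sha256:

import {hasher} from "@chainsafe/persistent-merkle-tree"; // assumed root re-export of ./hasher
import type {HashObject} from "@chainsafe/as-sha256/lib/hashObject";

const zero = (): HashObject => ({h0: 0, h1: 0, h2: 0, h3: 0, h4: 0, h5: 0, h6: 0, h7: 0});
const left = {...zero(), h0: 1};
const right = {...zero(), h0: 2};
const parent = zero();

// the result is written into `parent` rather than returned, avoiding an allocation per hash
hasher.digest64HashObjects(left, right, parent);
console.log(parent.h0 >>> 0, parent.h7 >>> 0);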