diff --git a/core/trie/key.go b/core/trie/key.go index 7f0e6af609..2d94c4ad73 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -3,13 +3,14 @@ package trie import ( "bytes" "encoding/hex" - "errors" "fmt" "math/big" "github.com/NethermindEth/juno/core/felt" ) +var NilKey = &Key{len: 0, bitset: [32]byte{}} + type Key struct { len uint8 bitset [32]byte @@ -24,26 +25,6 @@ func NewKey(length uint8, keyBytes []byte) Key { return k } -func (k *Key) SubKey(n uint8) (*Key, error) { - if n > k.len { - return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len)) - } - - newKey := &Key{len: n} - copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:mnd - - // Shift right by the number of bits that are not needed - shift := k.len - n - for i := len(newKey.bitset) - 1; i >= 0; i-- { - newKey.bitset[i] >>= shift - if i > 0 { - newKey.bitset[i] |= newKey.bitset[i-1] << (8 - shift) - } - } - - return newKey, nil -} - func (k *Key) bytesNeeded() uint { const byteBits = 8 return (uint(k.len) + (byteBits - 1)) / byteBits @@ -96,24 +77,30 @@ func (k *Key) Equal(other *Key) bool { return k.len == other.len && k.bitset == other.bitset } -func (k *Key) Test(bit uint8) bool { +// IsBitSet returns whether the bit at the given position is 1. +// Position 0 represents the least significant (rightmost) bit. +func (k *Key) IsBitSet(position uint8) bool { const LSB = uint8(0x1) - byteIdx := bit / 8 + byteIdx := position / 8 byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1] - bitIdx := bit % 8 + bitIdx := position % 8 return ((byteAtIdx >> bitIdx) & LSB) != 0 } -func (k *Key) String() string { - return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:])) -} - -// DeleteLSB right shifts and shortens the key -func (k *Key) DeleteLSB(n uint8) { +// shiftRight removes n least significant bits from the key by performing a right shift +// operation and reducing the key length. For example, if the key contains bits +// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4). +// +// The operation is destructive - it modifies the key in place. +func (k *Key) shiftRight(n uint8) { if k.len < n { panic("deleting more bits than there are") } + if n == 0 { + return + } + var bigInt big.Int bigInt.SetBytes(k.bitset[:]) bigInt.Rsh(&bigInt, uint(n)) @@ -121,6 +108,17 @@ func (k *Key) DeleteLSB(n uint8) { k.len -= n } +// MostSignificantBits returns a new key with the most significant n bits of the current key. +func (k *Key) MostSignificantBits(n uint8) (*Key, error) { + if n > k.len { + return nil, fmt.Errorf("cannot get more bits than the key length") + } + + keyCopy := k.Copy() + keyCopy.shiftRight(k.len - n) + return &keyCopy, nil +} + // Truncate truncates key to `length` bits by clearing the remaining upper bits func (k *Key) Truncate(length uint8) { k.len = length @@ -136,20 +134,53 @@ func (k *Key) Truncate(length uint8) { } } -func (k *Key) RemoveLastBit() { - if k.len == 0 { - return - } +func (k *Key) String() string { + return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:])) +} - k.len-- +// Copy returns a deep copy of the key +func (k *Key) Copy() Key { + newKey := Key{len: k.len} + copy(newKey.bitset[:], k.bitset[:]) + return newKey +} - unusedBytes := k.unusedBytes() - clear(unusedBytes) +func (k *Key) Bytes() [32]byte { + var result [32]byte + copy(result[:], k.bitset[:]) + return result +} - // clear upper bits on the last used byte - inUseBytes := k.inUseBytes() - unusedBitsCount := 8 - (k.len % 8) - if unusedBitsCount != 8 && len(inUseBytes) > 0 { - inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount +// findCommonKey finds the set of common MSB bits in two key bitsets. +func findCommonKey(longerKey, shorterKey *Key) (Key, bool) { + divergentBit := findDivergentBit(longerKey, shorterKey) + + if divergentBit == 0 { + return *NilKey, false } + + commonKey := *shorterKey + commonKey.shiftRight(shorterKey.Len() - divergentBit + 1) + return commonKey, divergentBit == shorterKey.Len()+1 +} + +// findDivergentBit finds the first bit that is different between two keys, +// starting from the most significant bit of both keys. +func findDivergentBit(longerKey, shorterKey *Key) uint8 { + divergentBit := uint8(0) + for divergentBit <= shorterKey.Len() && + longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) { + divergentBit++ + } + return divergentBit +} + +func isSubset(longerKey, shorterKey *Key) bool { + divergentBit := findDivergentBit(longerKey, shorterKey) + return divergentBit == shorterKey.Len()+1 +} + +func FeltToKey(length uint8, key *felt.Felt) Key { + keyBytes := key.Bytes() + return NewKey(length, keyBytes[:]) } diff --git a/core/trie/key_test.go b/core/trie/key_test.go index 8d56a31e0c..3867678e6e 100644 --- a/core/trie/key_test.go +++ b/core/trie/key_test.go @@ -68,47 +68,6 @@ func BenchmarkKeyEncoding(b *testing.B) { } } -func TestKeyTest(t *testing.T) { - key := trie.NewKey(44, []byte{0x10, 0x02}) - for i := 0; i < int(key.Len()); i++ { - assert.Equal(t, i == 1 || i == 12, key.Test(uint8(i)), i) - } -} - -func TestDeleteLSB(t *testing.T) { - key := trie.NewKey(16, []byte{0xF3, 0x04}) - - tests := map[string]struct { - shiftAmount uint8 - expectedKey trie.Key - }{ - "delete 0 bits": { - shiftAmount: 0, - expectedKey: key, - }, - "delete 4 bits": { - shiftAmount: 4, - expectedKey: trie.NewKey(12, []byte{0x0F, 0x30}), - }, - "delete 8 bits": { - shiftAmount: 8, - expectedKey: trie.NewKey(8, []byte{0xF3}), - }, - "delete 9 bits": { - shiftAmount: 9, - expectedKey: trie.NewKey(7, []byte{0x79}), - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - copyKey := key - copyKey.DeleteLSB(test.shiftAmount) - assert.Equal(t, test.expectedKey, copyKey) - }) - } -} - func TestTruncate(t *testing.T) { tests := map[string]struct { key trie.Key @@ -153,3 +112,118 @@ func TestTruncate(t *testing.T) { }) } } + +func TestKeyTest(t *testing.T) { + key := trie.NewKey(44, []byte{0x10, 0x02}) + for i := 0; i < int(key.Len()); i++ { + assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i) + } +} + +func TestIsBitSet(t *testing.T) { + tests := map[string]struct { + key trie.Key + position uint8 + expected bool + }{ + "single byte, LSB set": { + key: trie.NewKey(8, []byte{0x01}), + position: 0, + expected: true, + }, + "single byte, MSB set": { + key: trie.NewKey(8, []byte{0x80}), + position: 7, + expected: true, + }, + "single byte, middle bit set": { + key: trie.NewKey(8, []byte{0x10}), + position: 4, + expected: true, + }, + "single byte, bit not set": { + key: trie.NewKey(8, []byte{0xFE}), + position: 0, + expected: false, + }, + "multiple bytes, LSB set": { + key: trie.NewKey(16, []byte{0x00, 0x02}), + position: 1, + expected: true, + }, + "multiple bytes, MSB set": { + key: trie.NewKey(16, []byte{0x01, 0x00}), + position: 8, + expected: true, + }, + "multiple bytes, no bits set": { + key: trie.NewKey(16, []byte{0x00, 0x00}), + position: 7, + expected: false, + }, + "check all bits in pattern": { + key: trie.NewKey(8, []byte{0xA5}), // 10100101 + position: 0, + expected: true, + }, + } + + // Additional test for 0xA5 pattern + key := trie.NewKey(8, []byte{0xA5}) // 10100101 + expectedBits := []bool{true, false, true, false, false, true, false, true} + for i, expected := range expectedBits { + assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i) + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + result := tc.key.IsBitSet(tc.position) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestMostSignificantBits(t *testing.T) { + tests := []struct { + name string + key trie.Key + n uint8 + want trie.Key + expectErr bool + }{ + { + name: "Valid case", + key: trie.NewKey(8, []byte{0b11110000}), + n: 4, + want: trie.NewKey(4, []byte{0b00001111}), + expectErr: false, + }, + { + name: "Request more bits than available", + key: trie.NewKey(8, []byte{0b11110000}), + n: 10, + want: trie.Key{}, + expectErr: true, + }, + { + name: "Zero bits requested", + key: trie.NewKey(8, []byte{0b11110000}), + n: 0, + want: trie.NewKey(0, []byte{}), + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.key.MostSignificantBits(tt.n) + if (err != nil) != tt.expectErr { + t.Errorf("MostSignificantBits() error = %v, expectErr %v", err, tt.expectErr) + return + } + if !tt.expectErr && !got.Equal(&tt.want) { + t.Errorf("MostSignificantBits() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/trie/node.go b/core/trie/node.go index db9cb85206..2ef176f92a 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "errors" + "fmt" "github.com/NethermindEth/juno/core/felt" ) @@ -138,3 +139,55 @@ func (n *Node) UnmarshalBinary(data []byte) error { n.RightHash.SetBytes(data[:felt.Bytes]) return nil } + +func (n *Node) String() string { + return fmt.Sprintf("Node{Value: %s, Left: %s, Right: %s, LeftHash: %s, RightHash: %s}", n.Value, n.Left, n.Right, n.LeftHash, n.RightHash) +} + +// Update the receiver with non-nil fields from the `other` Node. +// If a field is non-nil in both Nodes, they must be equal, or an error is returned. +// +// This method modifies the receiver in-place and returns an error if any field conflicts are detected. +// +//nolint:gocyclo +func (n *Node) Update(other *Node) error { + // First validate all fields for conflicts + if n.Value != nil && other.Value != nil && !n.Value.Equal(other.Value) { + return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value) + } + + if n.Left != nil && other.Left != nil && !n.Left.Equal(NilKey) && !other.Left.Equal(NilKey) && !n.Left.Equal(other.Left) { + return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left) + } + + if n.Right != nil && other.Right != nil && !n.Right.Equal(NilKey) && !other.Right.Equal(NilKey) && !n.Right.Equal(other.Right) { + return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right) + } + + if n.LeftHash != nil && other.LeftHash != nil && !n.LeftHash.Equal(other.LeftHash) { + return fmt.Errorf("conflicting LeftHash: %v != %v", n.LeftHash, other.LeftHash) + } + + if n.RightHash != nil && other.RightHash != nil && !n.RightHash.Equal(other.RightHash) { + return fmt.Errorf("conflicting RightHash: %v != %v", n.RightHash, other.RightHash) + } + + // After validation, perform all updates + if other.Value != nil { + n.Value = other.Value + } + if other.Left != nil && !other.Left.Equal(NilKey) { + n.Left = other.Left + } + if other.Right != nil && !other.Right.Equal(NilKey) { + n.Right = other.Right + } + if other.LeftHash != nil { + n.LeftHash = other.LeftHash + } + if other.RightHash != nil { + n.RightHash = other.RightHash + } + + return nil +} diff --git a/core/trie/proof.go b/core/trie/proof.go index 6dcbe3960c..bc4b66d0d9 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -4,18 +4,21 @@ import ( "errors" "fmt" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" ) -var ( - ErrUnknownProofNode = errors.New("unknown proof node") - ErrChildHashNotFound = errors.New("can't determine the child hash from the parent and child") -) +type ProofNodeSet = utils.OrderedSet[felt.Felt, ProofNode] + +func NewProofNodeSet() *ProofNodeSet { + return utils.NewOrderedSet[felt.Felt, ProofNode]() +} type ProofNode interface { Hash(hash hashFunc) *felt.Felt Len() uint8 - PrettyPrint() + String() string } type Binary struct { @@ -31,10 +34,8 @@ func (b *Binary) Len() uint8 { return 1 } -func (b *Binary) PrettyPrint() { - fmt.Printf(" Binary:\n") - fmt.Printf(" LeftHash: %v\n", b.LeftHash) - fmt.Printf(" RightHash: %v\n", b.RightHash) +func (b *Binary) String() string { + return fmt.Sprintf("Binary: %v:\n\tLeftHash: %v\n\tRightHash: %v\n", b.Hash(crypto.Pedersen), b.LeftHash, b.RightHash) } type Edge struct { @@ -54,623 +55,585 @@ func (e *Edge) Len() uint8 { return e.Path.Len() } -func (e *Edge) PrettyPrint() { - fmt.Printf(" Edge:\n") - fmt.Printf(" Child: %v\n", e.Child) - fmt.Printf(" Path: %v\n", e.Path) +func (e *Edge) String() string { + return fmt.Sprintf("Edge: %v:\n\tChild: %v\n\tPath: %v\n", e.Hash(crypto.Pedersen), e.Child, e.Path) } -func GetBoundaryProofs(leftBoundary, rightBoundary *Key, tri *Trie) ([2][]ProofNode, error) { - proofs := [2][]ProofNode{} - leftProof, err := GetProof(leftBoundary, tri) - if err != nil { - return proofs, err - } - rightProof, err := GetProof(rightBoundary, tri) +// Prove generates a Merkle proof for a given key in the trie. +// The result contains the proof nodes on the path from the root to the leaf. +// The value is included in the proof if the key is present in the trie. +// If the key is not present, the proof will contain the nodes on the path to the closest ancestor. +func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { + k := t.FeltToKey(key) + + nodesFromRoot, err := t.nodesFromRoot(&k) if err != nil { - return proofs, err + return err } - proofs[0] = leftProof - proofs[1] = rightProof - return proofs, nil -} -func isEdge(parentKey *Key, sNode StorageNode) bool { - sNodeLen := sNode.key.len - if parentKey == nil { // Root - return sNodeLen != 0 - } - return sNodeLen-parentKey.len > 1 -} + var parentKey *Key -// Note: we need to account for the fact that Junos Trie has nodes that are Binary AND Edge, -// whereas the protocol requires nodes that are Binary XOR Edge -func transformNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { - isEdgeBool := isEdge(parentKey, sNode) + for i, sNode := range nodesFromRoot { + sNodeEdge, sNodeBinary, err := storageNodeToProofNode(t, parentKey, sNode) + if err != nil { + return err + } + isLeaf := sNode.key.len == t.height - var edge *Edge - if isEdgeBool { - edgePath := path(sNode.key, parentKey) - edge = &Edge{ - Path: &edgePath, - Child: sNode.node.Value, + if sNodeEdge != nil && !isLeaf { // Internal Edge + proof.Put(*sNodeEdge.Hash(t.hash), sNodeEdge) + proof.Put(*sNodeBinary.Hash(t.hash), sNodeBinary) + } else if sNodeEdge == nil && !isLeaf { // Internal Binary + proof.Put(*sNodeBinary.Hash(t.hash), sNodeBinary) + } else if sNodeEdge != nil && isLeaf { // Leaf Edge + proof.Put(*sNodeEdge.Hash(t.hash), sNodeEdge) + } else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf + break } + parentKey = nodesFromRoot[i].key } - if sNode.key.len == tri.height { // Leaf - return edge, nil, nil - } - lNode, err := tri.GetNodeFromKey(sNode.node.Left) - if err != nil { - return nil, nil, err - } - rNode, err := tri.GetNodeFromKey(sNode.node.Right) + return nil +} + +// GetRangeProof generates a range proof for the given range of keys. +// The proof contains the proof nodes on the path from the root to the closest ancestor of the left and right keys. +func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSet) error { + err := t.Prove(leftKey, proofSet) if err != nil { - return nil, nil, err + return err } - rightHash := rNode.Value - if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) { - edgePath := path(sNode.node.Right, sNode.key) - rEdge := &Edge{ - Path: &edgePath, - Child: rNode.Value, - } - rightHash = rEdge.Hash(tri.hash) - } - leftHash := lNode.Value - if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) { - edgePath := path(sNode.node.Left, sNode.key) - lEdge := &Edge{ - Path: &edgePath, - Child: lNode.Value, - } - leftHash = lEdge.Hash(tri.hash) + // If they are the same key, don't need to generate the proof again + if leftKey.Equal(rightKey) { + return nil } - binary := &Binary{ - LeftHash: leftHash, - RightHash: rightHash, - } - - return edge, binary, nil -} -// pathSplitOccurredCheck checks if there happens at most one split in the merged path -// loops through the merged paths if left and right hashes of a node exist in the nodeHashes -// then a split happened in case of multiple splits it returns an error -func pathSplitOccurredCheck(mergedPath []ProofNode, nodeHashes map[felt.Felt]ProofNode) error { - splitHappened := false - for _, node := range mergedPath { - switch node := node.(type) { - case *Edge: - continue - case *Binary: - _, leftExists := nodeHashes[*node.LeftHash] - _, rightExists := nodeHashes[*node.RightHash] - if leftExists && rightExists { - if splitHappened { - return errors.New("split happened more than once") - } - splitHappened = true - } - default: - return fmt.Errorf("%w: %T", ErrUnknownProofNode, node) - } + err = t.Prove(rightKey, proofSet) + if err != nil { + return err } + return nil } -func rootNodeExistsCheck(rootHash *felt.Felt, nodeHashes map[felt.Felt]ProofNode) (ProofNode, error) { - currNode, rootExists := nodeHashes[*rootHash] - if !rootExists { - return currNode, errors.New("root hash not found in the merged path") - } +// VerifyProof verifies that a proof path is valid for a given key in a binary trie. +// It walks through the proof nodes, verifying each step matches the expected path to reach the key. +// +// The verification process: +// 1. Starts at the root hash and retrieves the corresponding proof node +// 2. For each proof node: +// - Verifies the node's computed hash matches the expected hash +// - For Binary nodes: +// -- Uses the next unprocessed bit in the key to choose left/right path +// -- If key bit is 0, takes left path; if 1, takes right path +// - For Edge nodes: +// -- Verifies the compressed path matches the corresponding bits in the key +// -- Moves to the child node if paths match +// +// 3. Continues until all bits in the key are processed +// +// The proof is considered invalid if: +// - Any proof node is missing from the OrderedSet +// - Any node's computed hash doesn't match its expected hash +// - The path bits don't match the key bits +// - The proof ends before processing all key bits +func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash hashFunc) (*felt.Felt, error) { + key := FeltToKey(globalTrieHeight, keyFelt) + expectedHash := root + keyLen := key.Len() - return currNode, nil -} + var curPos uint8 + for { + proofNode, ok := proof.Get(*expectedHash) + if !ok { + return nil, fmt.Errorf("proof node not found, expected hash: %s", expectedHash.String()) + } -// traverseNodes traverses the merged proof path starting at `currNode` -// and adds nodes to `path` slice. It stops when the split node is added -// or the path is exhausted, and `currNode` children are not included -// in the path (nodeHashes) -func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Felt]ProofNode) { - *path = append(*path, currNode) + // Verify the hash matches + if !proofNode.Hash(hash).Equal(expectedHash) { + return nil, fmt.Errorf("proof node hash mismatch, expected hash: %s, got hash: %s", expectedHash.String(), proofNode.Hash(hash).String()) + } - switch currNode := currNode.(type) { - case *Binary: - nodeLeft, leftExist := nodeHashes[*currNode.LeftHash] - nodeRight, rightExist := nodeHashes[*currNode.RightHash] + switch node := proofNode.(type) { + case *Binary: // Binary nodes represent left/right choices + if key.Len() <= curPos { + return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", key.Len(), curPos) + } + // Determine the next node to traverse based on the next bit position + expectedHash = node.LeftHash + if key.IsBitSet(keyLen - curPos - 1) { + expectedHash = node.RightHash + } + curPos++ + case *Edge: // Edge nodes represent paths between binary nodes + if !verifyEdgePath(&key, node.Path, curPos) { + return &felt.Zero, nil + } - if leftExist && rightExist { - return - } else if leftExist { - traverseNodes(nodeLeft, path, nodeHashes) - } else if rightExist { - traverseNodes(nodeRight, path, nodeHashes) + // Move to the immediate child node + curPos += node.Path.Len() + expectedHash = node.Child } - case *Edge: - edgeNode, exist := nodeHashes[*currNode.Child] - if exist { - traverseNodes(edgeNode, path, nodeHashes) + + // We've consumed all bits in our path + if curPos >= keyLen { + return expectedHash, nil } } } -// MergeProofPaths removes duplicates and merges proof paths into a single path -// merges paths in the specified order [commonNodes..., leftNodes..., rightNodes...] -// ordering of the merged path is not important -// since SplitProofPath can discover the left and right paths using the merged path and the rootHash -func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNode, *felt.Felt, error) { - merged := []ProofNode{} - minLen := min(len(leftPath), len(rightPath)) - - if len(leftPath) == 0 || len(rightPath) == 0 { - return merged, nil, errors.New("empty proof paths") - } - - if !leftPath[0].Hash(hash).Equal(rightPath[0].Hash(hash)) { - return merged, nil, errors.New("roots of the proof paths are different") +// VerifyRangeProof checks the validity of given key-value pairs and range proof against a provided root hash. +// The key-value pairs should be consecutive (no gaps) and monotonically increasing. +// The range proof contains two edge proofs: one for the first key and another for the last key. +// Both edge proofs can be for existent or non-existent keys. +// This function handles the following special cases: +// +// - All elements proof: The proof can be nil if the range includes all leaves in the trie. +// - Single element proof: Both left and right edge proofs are identical, and the range contains only one element. +// - Zero element proof: A single edge proof suffices for verification. The proof is invalid if there are additional elements. +// +// The function returns a boolean indicating if there are more elements and an error if the range proof is invalid. +// +// TODO(weiihann): Given a binary leaf and a left-sibling first key, if the right sibling is removed, the proof would still be valid. +// Conversely, given a binary leaf and a right-sibling last key, if the left sibling is removed, the proof would still be valid. +// Range proof should not be valid for both of these cases, but currently is, which is an attack vector. +// The problem probably lies in how we do root hash calculation. +func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof *ProofNodeSet) (bool, error) { //nolint:funlen,gocyclo + // Ensure the number of keys and values are the same + if len(keys) != len(values) { + return false, fmt.Errorf("inconsistent length of proof data, keys: %d, values: %d", len(keys), len(values)) } - rootHash := leftPath[0].Hash(hash) - - // Get duplicates and insert by one - i := 0 - for i = 0; i < minLen; i++ { - leftNode := leftPath[i] - rightNode := rightPath[i] + // Ensure all keys are monotonically increasing and values contain no deletions + for i := 0; i < len(keys); i++ { + if i < len(keys)-1 && keys[i].Cmp(keys[i+1]) > 0 { + return false, errors.New("keys are not monotonic increasing") + } - if leftNode.Hash(hash).Equal(rightNode.Hash(hash)) { - merged = append(merged, leftNode) - } else { - break + if values[i] == nil || values[i].Equal(&felt.Zero) { + return false, errors.New("range contains empty leaf") } } - // Add rest of the nodes - merged = append(merged, leftPath[i:]...) - merged = append(merged, rightPath[i:]...) - - return merged, rootHash, nil -} - -// SplitProofPath splits the merged proof path into two paths (left and right), which were merged before -// it first validates that the merged path is not circular, the split happens at most once and rootHash exists -// then calls traverseNodes to split the path to left and right paths -func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash hashFunc) ([]ProofNode, []ProofNode, error) { - commonPath := []ProofNode{} - leftPath := []ProofNode{} - rightPath := []ProofNode{} - nodeHashes := make(map[felt.Felt]ProofNode) - - for _, node := range mergedPath { - nodeHash := node.Hash(hash) - _, nodeExists := nodeHashes[*nodeHash] - - if nodeExists { - return leftPath, rightPath, errors.New("duplicate node in the merged path") + // Special case: no edge proof provided; the given range contains all leaves in the trie + if proof == nil { + tr, err := buildTrie(globalTrieHeight, nil, nil, keys, values) + if err != nil { + return false, err } - nodeHashes[*nodeHash] = node - } - if len(mergedPath) == 0 { - return leftPath, rightPath, nil - } + recomputedRoot, err := tr.Root() + if err != nil { + return false, err + } - currNode, err := rootNodeExistsCheck(rootHash, nodeHashes) - if err != nil { - return leftPath, rightPath, err - } + if !recomputedRoot.Equal(root) { + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) + } - if err := pathSplitOccurredCheck(mergedPath, nodeHashes); err != nil { - return leftPath, rightPath, err + return false, nil // no more elements available } - traverseNodes(currNode, &commonPath, nodeHashes) - - leftPath = append(leftPath, commonPath...) - rightPath = append(rightPath, commonPath...) + nodes := NewStorageNodeSet() + firstKey := FeltToKey(globalTrieHeight, first) - currNode = commonPath[len(commonPath)-1] - - leftNode := nodeHashes[*currNode.(*Binary).LeftHash] - rightNode := nodeHashes[*currNode.(*Binary).RightHash] - - traverseNodes(leftNode, &leftPath, nodeHashes) - traverseNodes(rightNode, &rightPath, nodeHashes) + // Special case: there is a provided proof but no key-value pairs, make sure regenerated trie has no more values + // Empty range proof with more elements on the right is not accepted in this function. + // This is due to snap sync specification detail, where the responder must send an existing key (if any) if the requested range is empty. + if len(keys) == 0 { + rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + if err != nil { + return false, err + } - return leftPath, rightPath, nil -} + if val != nil || hasRightElement(rootKey, &firstKey, nodes) { + return false, errors.New("more entries available") + } -// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L514 -// GetProof generates a set of proof nodes from the root to the leaf. -// The proof never contains the leaf node if it is set, as we already know it's hash. -func GetProof(key *Key, tri *Trie) ([]ProofNode, error) { - nodesFromRoot, err := tri.nodesFromRoot(key) - if err != nil { - return nil, err + return false, nil } - proofNodes := []ProofNode{} - var parentKey *Key + last := keys[len(keys)-1] + lastKey := FeltToKey(globalTrieHeight, last) - for i, sNode := range nodesFromRoot { - sNodeEdge, sNodeBinary, err := transformNode(tri, parentKey, sNode) + // Special case: there is only one element and two edge keys are the same + if len(keys) == 1 && firstKey.Equal(&lastKey) { + rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) if err != nil { - return nil, err + return false, err } - isLeaf := sNode.key.len == tri.height - if sNodeEdge != nil && !isLeaf { // Internal Edge - proofNodes = append(proofNodes, sNodeEdge, sNodeBinary) - } else if sNodeEdge == nil && !isLeaf { // Internal Binary - proofNodes = append(proofNodes, sNodeBinary) - } else if sNodeEdge != nil && isLeaf { // Leaf Edge - proofNodes = append(proofNodes, sNodeEdge) - } else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf - break + elementKey := FeltToKey(globalTrieHeight, keys[0]) + if !firstKey.Equal(&elementKey) { + return false, errors.New("correct proof but invalid key") } - parentKey = nodesFromRoot[i].key - } - return proofNodes, nil -} -// VerifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes` -// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 -func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool { - expectedHash := root - remainingPath := NewKey(key.len, key.bitset[:]) - for i, proofNode := range proofs { - if !proofNode.Hash(hash).Equal(expectedHash) { - return false + if val == nil || !values[0].Equal(val) { + return false, errors.New("correct proof but invalid value") } - switch proofNode := proofNode.(type) { - case *Binary: - if remainingPath.Test(remainingPath.Len() - 1) { - expectedHash = proofNode.RightHash - } else { - expectedHash = proofNode.LeftHash - } - remainingPath.RemoveLastBit() - case *Edge: - subKey, err := remainingPath.SubKey(proofNode.Path.Len()) - if err != nil { - return false - } - - // Todo: - // If we are verifying the key doesn't exist, then we should - // update subKey to point in the other direction - if value == nil && i == len(proofs)-1 { - return true - } - - if !proofNode.Path.Equal(subKey) { - return false - } - expectedHash = proofNode.Child - remainingPath.Truncate(251 - proofNode.Path.Len()) //nolint:mnd - } + return hasRightElement(rootKey, &firstKey, nodes), nil } - return expectedHash.Equal(value) -} - -// VerifyRangeProof verifies the range proof for the given range of keys. -// This is achieved by constructing a trie from the boundary proofs, and the supplied key-values. -// If the root of the reconstructed trie matches the supplied root, then the verification passes. -// If the trie is constructed incorrectly then the root will have an incorrect key(len,path), and value, -// and therefore it's hash won't match the expected root. -// ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484 -func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*Key, proofValues [2]*felt.Felt, - proofs [2][]ProofNode, hash hashFunc, -) (bool, error) { - // Step 0: checks - if len(keys) != len(values) { - return false, fmt.Errorf("inconsistent proof data, number of keys: %d, number of values: %d", len(keys), len(values)) + // In all other cases, we require two edge paths available. + // First, ensure that the last key is greater than the first key + if last.Cmp(first) <= 0 { + return false, errors.New("last key is less than first key") } - // Ensure all keys are monotonic increasing - if err := ensureMonotonicIncreasing(proofKeys, keys); err != nil { + rootKey, _, err := proofToPath(root, &firstKey, proof, nodes) + if err != nil { return false, err } - // Ensure the inner values contain no deletions - for _, value := range values { - if value.Equal(&felt.Zero) { - return false, errors.New("range contains deletion") - } + lastRootKey, _, err := proofToPath(root, &lastKey, proof, nodes) + if err != nil { + return false, err } - // Step 1: Verify proofs, and get proof paths - var proofPaths [2][]StorageNode - var err error - for i := 0; i < 2; i++ { - if proofs[i] != nil { - if !VerifyProof(root, proofKeys[i], proofValues[i], proofs[i], hash) { - return false, fmt.Errorf("invalid proof for key %x", proofKeys[i].String()) - } - - proofPaths[i], err = ProofToPath(proofs[i], proofKeys[i], hash) - if err != nil { - return false, err - } - } + if !rootKey.Equal(lastRootKey) { + return false, errors.New("first and last root keys do not match") } - // Step 2: Build trie from proofPaths and keys - tmpTrie, err := BuildTrie(proofPaths[0], proofPaths[1], keys, values) + // Build the trie from the proof paths + tr, err := buildTrie(globalTrieHeight, rootKey, nodes.List(), keys, values) if err != nil { return false, err } // Verify that the recomputed root hash matches the provided root hash - recomputedRoot, err := tmpTrie.Root() + recomputedRoot, err := tr.Root() if err != nil { return false, err } + if !recomputedRoot.Equal(root) { - return false, errors.New("root hash mismatch") + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) } - return true, nil + return hasRightElement(rootKey, &lastKey, nodes), nil } -func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error { - if proofKeys[0] != nil { - leftProofFelt := proofKeys[0].Felt() - if leftProofFelt.Cmp(keys[0]) >= 0 { - return errors.New("range is not monotonically increasing") - } - } - if proofKeys[1] != nil { - rightProofFelt := proofKeys[1].Felt() - if keys[len(keys)-1].Cmp(&rightProofFelt) >= 0 { - return errors.New("range is not monotonically increasing") - } - } - if len(keys) >= 2 { - for i := 0; i < len(keys)-1; i++ { - if keys[i].Cmp(keys[i+1]) >= 0 { - return errors.New("range is not monotonically increasing") - } - } +// isEdge checks if the storage node is an edge node. +func isEdge(parentKey *Key, sNode StorageNode) bool { + sNodeLen := sNode.key.len + if parentKey == nil { // Root + return sNodeLen != 0 } - return nil + return sNodeLen-parentKey.len > 1 } -// compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key -func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) { - parent := proofNodes[idx] - - if idx == len(proofNodes)-1 { - if _, ok := parent.(*Edge); ok { - return 1, parent.Len(), nil +// storageNodeToProofNode converts a StorageNode to the ProofNode(s). +// Juno's Trie has nodes that are Binary AND Edge, whereas the protocol requires nodes that are Binary XOR Edge. +// We need to convert the former to the latter for proof generation. +func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { + var edge *Edge + if isEdge(parentKey, sNode) { + edgePath := path(sNode.key, parentKey) + edge = &Edge{ + Path: &edgePath, + Child: sNode.node.Value, } - return 0, parent.Len(), nil + } + if sNode.key.len == tri.height { // Leaf + return edge, nil, nil + } + lNode, err := tri.GetNodeFromKey(sNode.node.Left) + if err != nil { + return nil, nil, err + } + rNode, err := tri.GetNodeFromKey(sNode.node.Right) + if err != nil { + return nil, nil, err } - child := proofNodes[idx+1] - _, isChildBinary := child.(*Binary) - isChildEdge := !isChildBinary - switch parent := parent.(type) { - case *Edge: - if isChildEdge { - break - } - return 1, parent.Len(), nil - case *Binary: - if isChildBinary { - break + rightHash := rNode.Value + if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) { + edgePath := path(sNode.node.Right, sNode.key) + rEdge := &Edge{ + Path: &edgePath, + Child: rNode.Value, } - childHash := child.Hash(hashF) - if parent.LeftHash.Equal(childHash) || parent.RightHash.Equal(childHash) { - return 1, child.Len(), nil + rightHash = rEdge.Hash(tri.hash) + } + leftHash := lNode.Value + if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) { + edgePath := path(sNode.node.Left, sNode.key) + lEdge := &Edge{ + Path: &edgePath, + Child: lNode.Value, } - return 0, 0, ErrChildHashNotFound + leftHash = lEdge.Hash(tri.hash) + } + binary := &Binary{ + LeftHash: leftHash, + RightHash: rightHash, } - return 0, 1, nil + return edge, binary, nil } -func assignChild(i, compressedParent int, parentNode *Node, - nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF hashFunc, -) (*Key, error) { - childInd := i + compressedParent + 1 - childKey, err := getChildKey(childInd, parentKey, leafKey, nilKey, proofNodes, hashF) +// proofToPath converts a Merkle proof to trie node path. All necessary nodes will be resolved and leave the remaining +// as hashes. The given edge proof can be existent or non-existent. +func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageNodeSet) (*Key, *felt.Felt, error) { + rootKey, val, err := buildPath(root, key, 0, nil, proof, nodes) if err != nil { - return nil, err - } - if leafKey.Test(leafKey.len - parentKey.len - 1) { - parentNode.Right = childKey - parentNode.Left = nilKey - } else { - parentNode.Right = nilKey - parentNode.Left = childKey + return nil, nil, err } - return childKey, nil -} -// ProofToPath returns a set of storage nodes from the root to the end of the proof path. -// The storage nodes will have the hashes of the children, but only the key of the child -// along the path outlined by the proof. -func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]StorageNode, error) { - pathNodes := []StorageNode{} - - // Child keys that can't be derived are set to nilKey, so that we can store the node - zeroFeltBytes := new(felt.Felt).Bytes() - nilKey := NewKey(0, zeroFeltBytes[:]) - - for i, pNode := range proofNodes { - // Keep moving along the path (may need to skip nodes that were compressed into the last path node) - if i != 0 { - if skipNode(pNode, pathNodes, hashF) { - continue - } + // Special case: non-existent key at the root + // We must include the root node in the node set. + // We will only get the following two cases: + // 1. The root node is an edge node only where path.len == key.len (single key trie) + // 2. The root node is an edge node + binary node (double key trie) + if nodes.Size() == 0 { + proofNode, ok := proof.Get(*root) + if !ok { + return nil, nil, fmt.Errorf("root proof node not found: %s", root) } - var parentKey *Key - parentNode := Node{} - - // Set the key of the current node - compressParent, compressParentOffset, err := compressNode(i, proofNodes, hashF) - if err != nil { - return nil, err - } - parentKey, err = getParentKey(i, compressParentOffset, leafKey, pNode, pathNodes, proofNodes) - if err != nil { - return nil, err + edge, ok := proofNode.(*Edge) + if !ok { + return nil, nil, fmt.Errorf("expected edge node at root, got: %T", proofNode) } - // Don't store leafs along proof paths - if parentKey.len == 251 { //nolint:mnd - break - } + sn := NewPartialStorageNode(edge.Path, edge.Child) - // Set the value of the current node - parentNode.Value = pNode.Hash(hashF) + // Handle leaf edge case (single key trie) + if edge.Path.Len() == key.Len() { + if err := nodes.Put(*sn.key, sn); err != nil { + return nil, nil, fmt.Errorf("failed to store leaf edge: %w", err) + } + return sn.Key(), sn.Value(), nil + } - // Set the child key of the current node. - childKey, err := assignChild(i, compressParent, &parentNode, &nilKey, leafKey, parentKey, proofNodes, hashF) - if err != nil { - return nil, err + // Handle edge + binary case (double key trie) + child, ok := proof.Get(*edge.Child) + if !ok { + return nil, nil, fmt.Errorf("edge child not found: %s", edge.Child) } - // Set the LeftHash and RightHash values - parentNode.LeftHash, parentNode.RightHash, err = getLeftRightHash(i, proofNodes) - if err != nil { - return nil, err + binary, ok := child.(*Binary) + if !ok { + return nil, nil, fmt.Errorf("expected binary node as child, got: %T", child) } - pathNodes = append(pathNodes, StorageNode{key: parentKey, node: &parentNode}) + sn.node.LeftHash = binary.LeftHash + sn.node.RightHash = binary.RightHash - // break early since we don't store leafs along proof paths, or if no more nodes exist along the proof paths - if childKey.len == 0 || childKey.len == 251 { - break + if err := nodes.Put(*sn.key, sn); err != nil { + return nil, nil, fmt.Errorf("failed to store edge+binary: %w", err) } + rootKey = sn.Key() } - return pathNodes, nil + return rootKey, val, nil } -func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF hashFunc) bool { - lastNode := pathNodes[len(pathNodes)-1].node - noLeftMatch, noRightMatch := false, false - if lastNode.LeftHash != nil && !pNode.Hash(hashF).Equal(lastNode.LeftHash) { - noLeftMatch = true - } - if lastNode.RightHash != nil && !pNode.Hash(hashF).Equal(lastNode.RightHash) { - noRightMatch = true - } - if noLeftMatch && noRightMatch { - return true +// buildPath recursively builds the path for a given node hash, key, and current position. +// It returns the current node's key and any leaf value found along this path. +func buildPath( + nodeHash *felt.Felt, + key *Key, + curPos uint8, + curNode *StorageNode, + proof *ProofNodeSet, + nodes *StorageNodeSet, +) (*Key, *felt.Felt, error) { + // We reached the leaf + if curPos == key.Len() { + leafKey := key.Copy() + leafNode := NewPartialStorageNode(&leafKey, nodeHash) + if err := nodes.Put(leafKey, leafNode); err != nil { + return nil, nil, err + } + return leafNode.Key(), leafNode.Value(), nil } - return false -} -func getLeftRightHash(parentInd int, proofNodes []ProofNode) (*felt.Felt, *felt.Felt, error) { - parent := proofNodes[parentInd] + proofNode, ok := proof.Get(*nodeHash) + if !ok { // non-existent proof node + return NilKey, nil, nil + } - switch parent := parent.(type) { + switch pn := proofNode.(type) { case *Binary: - return parent.LeftHash, parent.RightHash, nil + return handleBinaryNode(pn, nodeHash, key, curPos, curNode, proof, nodes) case *Edge: - if parentInd+1 > len(proofNodes)-1 { - return nil, nil, errors.New("cant get hash of children from proof node, out of range") - } - parentBinary := proofNodes[parentInd+1].(*Binary) - return parentBinary.LeftHash, parentBinary.RightHash, nil - default: - return nil, nil, fmt.Errorf("%w: %T", ErrUnknownProofNode, parent) + return handleEdgeNode(pn, key, curPos, proof, nodes) } -} -func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key, - pNode ProofNode, pathNodes []StorageNode, proofNodes []ProofNode, -) (*Key, error) { - var crntKey *Key - var err error + return nil, nil, nil +} - var height uint8 - if len(pathNodes) > 0 { - if p, ok := proofNodes[idx].(*Edge); ok { - height = pathNodes[len(pathNodes)-1].key.len + p.Path.len - } else { - height = pathNodes[len(pathNodes)-1].key.len + 1 +// handleBinaryNode processes a binary node in the proof path by creating/updating a storage node, +// setting its left/right hashes, and recursively building the path for the appropriate child direction. +// It returns the current node's key and any leaf value found along this path. +func handleBinaryNode( + binary *Binary, + nodeHash *felt.Felt, + key *Key, + curPos uint8, + curNode *StorageNode, + proof *ProofNodeSet, + nodes *StorageNodeSet, +) (*Key, *felt.Felt, error) { + // If curNode is nil, it means that this current binary node is the root node. + // Or, it's an internal binary node and the parent is also a binary node. + // A standalone binary proof node always corresponds to a single storage node. + // If curNode is not nil, it means that the parent node is an edge node. + // In this case, the key of the storage node is based on the parent edge node. + if curNode == nil { + nodeKey, err := key.MostSignificantBits(curPos) + if err != nil { + return nil, nil, err } + curNode = NewPartialStorageNode(nodeKey, nodeHash) } + curNode.node.LeftHash = binary.LeftHash + curNode.node.RightHash = binary.RightHash - if _, ok := pNode.(*Binary); ok { - crntKey, err = leafKey.SubKey(height) + // Calculate next position and determine to take left or right path + nextPos := curPos + 1 + isRightPath := key.IsBitSet(key.Len() - nextPos) + nextHash := binary.LeftHash + if isRightPath { + nextHash = binary.RightHash + } + + childKey, val, err := buildPath(nextHash, key, nextPos, nil, proof, nodes) + if err != nil { + return nil, nil, err + } + + // Set child reference + if isRightPath { + curNode.node.Right = childKey } else { - crntKey, err = leafKey.SubKey(height + compressedParentOffset) + curNode.node.Left = childKey + } + + if err := nodes.Put(*curNode.key, curNode); err != nil { + return nil, nil, fmt.Errorf("failed to store binary node: %w", err) } - return crntKey, err + + return curNode.Key(), val, nil } -func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF hashFunc) (*Key, error) { - if childIdx > len(proofNodes)-1 { - return nilKey, nil +// handleEdgeNode processes an edge node in the proof path by verifying the edge path matches +// the key path and either creating a leaf node or continuing to traverse the trie. It returns +// the current node's key and any leaf value found along this path. +func handleEdgeNode( + edge *Edge, + key *Key, + curPos uint8, + proof *ProofNodeSet, + nodes *StorageNodeSet, +) (*Key, *felt.Felt, error) { + // Verify the edge path matches the key path + if !verifyEdgePath(key, edge.Path, curPos) { + return NilKey, nil, nil + } + + // The next node position is the end of the edge path + nextPos := curPos + edge.Path.Len() + nodeKey, err := key.MostSignificantBits(nextPos) + if err != nil { + return nil, nil, fmt.Errorf("failed to get MSB for internal edge: %w", err) } + curNode := NewPartialStorageNode(nodeKey, edge.Child) - compressChild, compressChildOffset, err := compressNode(childIdx, proofNodes, hashF) + // This is an edge leaf, stop traversing the trie + if nextPos == key.Len() { + if err := nodes.Put(*curNode.key, curNode); err != nil { + return nil, nil, fmt.Errorf("failed to store edge leaf: %w", err) + } + return curNode.Key(), curNode.Value(), nil + } + + _, val, err := buildPath(edge.Child, key, nextPos, curNode, proof, nodes) if err != nil { - return nil, err + return nil, nil, fmt.Errorf("failed to build child path: %w", err) } - if crntKey.len+uint8(compressChild)+compressChildOffset == 251 { //nolint:mnd - return nilKey, nil + if err := nodes.Put(*curNode.key, curNode); err != nil { + return nil, nil, fmt.Errorf("failed to store internal edge: %w", err) } - return leafKey.SubKey(crntKey.len + uint8(compressChild) + compressChildOffset) + return curNode.Key(), val, nil } -// BuildTrie builds a trie using the proof paths (including inner nodes), and then sets all the keys-values (leaves) -func BuildTrie(leftProofPath, rightProofPath []StorageNode, keys, values []*felt.Felt) (*Trie, error) { //nolint:gocyclo - tempTrie, err := NewTriePedersen(newMemStorage(), 251) //nolint:mnd +// verifyEdgePath checks if the edge path matches the key path at the current position. +func verifyEdgePath(key, edgePath *Key, curPos uint8) bool { + if key.Len() < curPos+edgePath.Len() { + return false + } + + // Ensure the bits between segment of the key and the node path match + start := key.Len() - curPos - edgePath.Len() + end := key.Len() - curPos + for i := start; i < end; i++ { + if key.IsBitSet(i) != edgePath.IsBitSet(i-start) { + return false // paths diverge - this proves non-membership + } + } + return true +} + +// buildTrie builds a trie from a list of storage nodes and a list of keys and values. +func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { + tr, err := NewTriePedersen(newMemStorage(), height) if err != nil { return nil, err } - // merge proof paths - for i := range min(len(leftProofPath), len(rightProofPath)) { - // Can't store nil keys so stop merging - if leftProofPath[i].node.Left == nil || leftProofPath[i].node.Right == nil || - rightProofPath[i].node.Left == nil || rightProofPath[i].node.Right == nil { - break - } - if leftProofPath[i].key.Equal(rightProofPath[i].key) { - leftProofPath[i].node.Right = rightProofPath[i].node.Right - rightProofPath[i].node.Left = leftProofPath[i].node.Left - } else { - break + tr.setRootKey(rootKey) + + // Nodes are inserted in reverse order because the leaf nodes are placed at the front of the list. + // We would want to insert root node first so the root key is set first. + for i := len(nodes) - 1; i >= 0; i-- { + if err := tr.PutInner(nodes[i].key, nodes[i].node); err != nil { + return nil, err } } - for _, sNode := range leftProofPath { - if sNode.node.Left == nil || sNode.node.Right == nil { - break - } - _, err := tempTrie.PutInner(sNode.key, sNode.node) + for index, key := range keys { + _, err = tr.PutWithProof(key, values[index], nodes) if err != nil { return nil, err } } - for _, sNode := range rightProofPath { - if sNode.node.Left == nil || sNode.node.Right == nil { - break + return tr, nil +} + +// hasRightElement checks if there is a right sibling for the given key in the trie. +// This function assumes that the entire path has been resolved. +func hasRightElement(rootKey, key *Key, nodes *StorageNodeSet) bool { + cur := rootKey + for cur != nil && !cur.Equal(NilKey) { + sn, ok := nodes.Get(*cur) + if !ok { + return false } - _, err := tempTrie.PutInner(sNode.key, sNode.node) - if err != nil { - return nil, err + + // We resolved the entire path, no more elements + if key.Equal(cur) { + return false } - } - for i := range len(keys) { - _, err := tempTrie.PutWithProof(keys[i], values[i], leftProofPath, rightProofPath) - if err != nil { - return nil, err + // If we're taking a left path and there's a right sibling, + // then there are elements with larger values + bitPos := key.Len() - cur.Len() - 1 + isLeft := !key.IsBitSet(bitPos) + if isLeft && sn.node.RightHash != nil { + return true + } + + // Move to next node based on the path + cur = sn.node.Right + if isLeft { + cur = sn.node.Left } } - return tempTrie, nil + + return false } diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index e6d8576d83..94eaabc549 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -1,6 +1,8 @@ package trie_test import ( + "math/rand" + "sort" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -8,1132 +10,818 @@ import ( "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/utils" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func buildSimpleTrie(t *testing.T) *trie.Trie { - // (250, 0, x1) edge - // | - // (0,0,x1) binary - // / \ - // (2) (3) - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) +func TestProve(t *testing.T) { + t.Parallel() - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - value1 := new(felt.Felt).SetUint64(2) - value2 := new(felt.Felt).SetUint64(3) + n := 1000 + tempTrie, records := nonRandomTrie(t, n) - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) + for _, record := range records { + proofSet := trie.NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) - return tempTrie + val, err := trie.VerifyProof(root, record.key, proofSet, crypto.Pedersen) + if err != nil { + t.Fatalf("failed for key %s", record.key.String()) + } + require.Equal(t, record.value, val) + } } -func buildSimpleBinaryRootTrie(t *testing.T) *trie.Trie { - // PF - // (0, 0, x) - // / \ - // (250, 0, cc) (250, 11111.., dd) - // | | - // (cc) (dd) - - // JUNO - // (0, 0, x) - // / \ - // (251, 0, cc) (251, 11111.., dd) - - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) +func TestProveNonExistent(t *testing.T) { + t.Parallel() - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) - require.NoError(t, err) + n := 1000 + tempTrie, _ := nonRandomTrie(t, n) - key1 := new(felt.Felt).SetUint64(0) - key2 := utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - value1 := utils.HexToFelt(t, "0xcc") - value2 := utils.HexToFelt(t, "0xdd") + for i := 1; i < n+1; i++ { + keyFelt := new(felt.Felt).SetUint64(uint64(i + n)) - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) + proofSet := trie.NewProofNodeSet() + err := tempTrie.Prove(keyFelt, proofSet) + require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) - return tempTrie + val, err := trie.VerifyProof(root, keyFelt, proofSet, crypto.Pedersen) + if err != nil { + t.Fatalf("failed for key %s", keyFelt.String()) + } + require.Equal(t, &felt.Zero, val) + } } -func buildSimpleDoubleBinaryTrie(t *testing.T) (*trie.Trie, []trie.ProofNode) { - // (249,0,x3) // Edge - // | - // (0, 0, x3) // Binary - // / \ - // (0,0,x1) // B (1, 1, 5) // Edge leaf - // / \ | - // (2) (3) (5) - - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) - require.NoError(t, err) - - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - key3 := new(felt.Felt).SetUint64(3) - value1 := new(felt.Felt).SetUint64(2) - value2 := new(felt.Felt).SetUint64(3) - value3 := new(felt.Felt).SetUint64(5) - - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) +func TestProveRandom(t *testing.T) { + t.Parallel() + tempTrie, records := randomTrie(t, 1000) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) + for _, record := range records { + proofSet := trie.NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) - _, err = tempTrie.Put(key3, value3) - require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) + val, err := trie.VerifyProof(root, record.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + require.Equal(t, record.value, val) + } +} - zero := trie.NewKey(249, []byte{0}) - key3Bytes := new(felt.Felt).SetUint64(1).Bytes() - path3 := trie.NewKey(1, key3Bytes[:]) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x055C81F6A791FD06FC2E2CCAD922397EC76C3E35F2E06C0C0D43D551005A8DEA"), +func TestProveCustom(t *testing.T) { + t.Parallel() + + tests := []testTrie{ + { + name: "simple binary", + buildFn: buildSimpleTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(1), + expected: new(felt.Felt).SetUint64(3), + }, + }, }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - RightHash: utils.HexToFelt(t, "0x07C5BC1CC68B7BC8CA2F632DE98297E6DA9594FA23EDE872DD2ABEAFDE353B43"), + { + name: "simple double binary", + buildFn: buildSimpleDoubleBinaryTrie, + testKeys: []testKey{ + { + name: "prove existing key 0", + key: new(felt.Felt).SetUint64(0), + expected: new(felt.Felt).SetUint64(2), + }, + { + name: "prove existing key 3", + key: new(felt.Felt).SetUint64(3), + expected: new(felt.Felt).SetUint64(5), + }, + { + name: "prove non-existent key 2", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(0), + }, + { + name: "prove non-existent key 123", + key: new(felt.Felt).SetUint64(123), + expected: new(felt.Felt).SetUint64(0), + }, + }, }, - &trie.Edge{ - Path: &path3, - Child: value3, + { + name: "simple binary root", + buildFn: buildSimpleBinaryRootTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(0), + expected: utils.HexToFelt(t, "0xcc"), + }, + }, + }, + { + name: "left-right edge", + buildFn: func(t *testing.T) (*trie.Trie, []*keyValue) { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tr, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + require.NoError(t, err) + + records := []*keyValue{ + {key: utils.HexToFelt(t, "0xff"), value: utils.HexToFelt(t, "0xaa")}, + } + + for _, record := range records { + _, err = tr.Put(record.key, record.value) + require.NoError(t, err) + } + require.NoError(t, tr.Commit()) + return tr, records + }, + testKeys: []testKey{ + { + name: "prove existing key", + key: utils.HexToFelt(t, "0xff"), + expected: utils.HexToFelt(t, "0xaa"), + }, + }, + }, + { + name: "three key trie", + buildFn: build3KeyTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(6), + }, + }, }, } - return tempTrie, expectedProofNodes -} - -func build3KeyTrie(t *testing.T) *trie.Trie { - // Starknet - // -------- - // - // Edge - // | - // Binary with len 249 parent - // / \ - // Binary (250) Edge with len 250 - // / \ / - // 0x4 0x5 0x6 child - - // Juno - // ---- - // - // Node (path 249) - // / \ - // Node (binary) \ - // / \ / - // 0x4 0x5 0x6 - - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) - - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - key3 := new(felt.Felt).SetUint64(2) - value1 := new(felt.Felt).SetUint64(4) - value2 := new(felt.Felt).SetUint64(5) - value3 := new(felt.Felt).SetUint64(6) - - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) - - _, err = tempTrie.Put(key3, value3) - require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) - - require.NoError(t, tempTrie.Commit()) - return tempTrie -} + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() -func build4KeyTrie(t *testing.T) *trie.Trie { - // Juno - // 248 - // / \ - // 249 \ - // / \ \ - // 250 \ \ - // / \ /\ /\ - // 0 1 2 4 - - // Juno - should be able to reconstruct this from proofs - // 248 - // / \ - // 249 // Note we cant derive the right key, but need to store it's hash - // / \ - // 250 \ - // / \ / (Left hash set, no key) - // 0 - - // Pathfinder (???) - // 0 Edge - // | - // 248 Binary - // / \ - // 249 \ Binary Edge ?? - // / \ \ - // 250 250 250 Binary Edge ?? - // / \ / / - // 0 1 2 4 - - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) + tr, _ := test.buildFn(t) - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) + for _, tc := range test.testKeys { + t.Run(tc.name, func(t *testing.T) { + proofSet := trie.NewProofNodeSet() + err := tr.Prove(tc.key, proofSet) + require.NoError(t, err) - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - key3 := new(felt.Felt).SetUint64(2) - key5 := new(felt.Felt).SetUint64(4) - value1 := new(felt.Felt).SetUint64(4) - value2 := new(felt.Felt).SetUint64(5) - value3 := new(felt.Felt).SetUint64(6) - value5 := new(felt.Felt).SetUint64(7) - - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) - - _, err = tempTrie.Put(key3, value3) - require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) - _, err = tempTrie.Put(key5, value5) - require.NoError(t, err) - - require.NoError(t, tempTrie.Commit()) - - return tempTrie -} - -func noDuplicates(proofNodes []trie.ProofNode) bool { - seen := make(map[felt.Felt]bool) - for _, pNode := range proofNodes { - if _, ok := seen[*pNode.Hash(crypto.Pedersen)]; ok { - return false - } - seen[*pNode.Hash(crypto.Pedersen)] = true - } - return true -} + root, err := tr.Root() + require.NoError(t, err) -// containsAll checks that subsetProofNodes is a subset of proofNodes -func containsAll(proofNodes, subsetProofNodes []trie.ProofNode) bool { - for _, pNode := range subsetProofNodes { - found := false - for _, p := range proofNodes { - if p.Hash(crypto.Pedersen).Equal(pNode.Hash(crypto.Pedersen)) { - found = true - break + val, err := trie.VerifyProof(root, tc.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + require.Equal(t, tc.expected, val) + }) } - } - if !found { - return false - } - } - return true -} - -func isSameProofPath(proofNodes, expectedProofNodes []trie.ProofNode) bool { - if len(proofNodes) != len(expectedProofNodes) { - return false - } - for i := range proofNodes { - if !proofNodes[i].Hash(crypto.Pedersen).Equal(expectedProofNodes[i].Hash(crypto.Pedersen)) { - return false - } + }) } - return true } -func newBinaryProofNode() *trie.Binary { - return &trie.Binary{ - LeftHash: new(felt.Felt).SetUint64(1), - RightHash: new(felt.Felt).SetUint64(2), - } -} +// TestRangeProof tests normal range proof with both edge proofs +func TestRangeProof(t *testing.T) { + t.Parallel() -func TestGetProof(t *testing.T) { - t.Run("GP Simple Trie - simple binary", func(t *testing.T) { - tempTrie := buildSimpleTrie(t) + n := 500 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - zero := trie.NewKey(250, []byte{0}) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - }, + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), - }, - } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) require.NoError(t, err) - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - simple double binary", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) - - expectedProofNodes[2] = &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) + _, err = trie.VerifyRangeProof(root, records[start].key, keys, values, proof) require.NoError(t, err) + } +} - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - simple double binary edge", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) - leafFelt := new(felt.Felt).SetUint64(3).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) +// TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs +func TestRangeProofWithNonExistentProof(t *testing.T) { + t.Parallel() - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) + n := 500 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - t.Run("GP Simple Trie - simple binary root", func(t *testing.T) { - tempTrie := buildSimpleBinaryRootTrie(t) + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 - key1Bytes := new(felt.Felt).SetUint64(0).Bytes() - path1 := trie.NewKey(250, key1Bytes[:]) - expectedProofNodes := []trie.ProofNode{ - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x06E08BF82793229338CE60B65D1845F836C8E2FBFE2BC59FF24AEDBD8BA219C4"), - RightHash: utils.HexToFelt(t, "0x04F9B8E66212FB528C0C1BD02F43309C53B895AA7D9DC91180001BDD28A588FA"), - }, - &trie.Edge{ - Path: &path1, - Child: utils.HexToFelt(t, "0xcc"), - }, + first := decrementFelt(records[start].key) + if start != 0 && first.Equal(records[start-1].key) { + continue } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - left-right edge", func(t *testing.T) { - // (251,0xff,0xaa) - // / - // \ - // (0xaa) - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) - require.NoError(t, err) - - key1 := utils.HexToFelt(t, "0xff") - value1 := utils.HexToFelt(t, "0xaa") - - _, err = tempTrie.Put(key1, value1) + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(first, records[end-1].key, proof) require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) - - key1Bytes := key1.Bytes() - path1 := trie.NewKey(251, key1Bytes[:]) - - child := utils.HexToFelt(t, "0x00000000000000000000000000000000000000000000000000000000000000AA") - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &path1, - Child: child, - }, + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) - - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - proof for non-set key", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) - leafFelt := new(felt.Felt).SetUint64(123).Bytes() // The (root) edge node would have a shorter len if this key was set - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) + _, err = trie.VerifyRangeProof(root, first, keys, values, proof) require.NoError(t, err) + } +} - // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } - require.Equal(t, expectedProofNodes[0:2], proofNodes) - }) +// TestRangeProofWithInvalidNonExistentProof tests range proof with invalid non-existent proofs. +// One scenario is when there is a gap between the first element and the left edge proof. +func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { + t.Parallel() - t.Run("GP Simple Trie - proof for inner key", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) + n := 500 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - innerFelt := new(felt.Felt).SetUint64(2).Bytes() - innerKey := trie.NewKey(123, innerFelt[:]) // The (root) edge node has len 249 which shows this doesn't exist - proofNodes, err := trie.GetProof(&innerKey, tempTrie) - require.NoError(t, err) + start, end := 100, 200 + first := decrementFelt(records[start].key) - // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } - require.Equal(t, expectedProofNodes[0:2], proofNodes) - }) - - t.Run("GP Simple Trie - proof for non-set key, with leafs set to right and left", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) - leafFelt := new(felt.Felt).SetUint64(2).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) + start = 105 // Gap created + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } - // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } - require.Equal(t, expectedProofNodes, proofNodes) - }) + _, err = trie.VerifyRangeProof(root, first, keys, values, proof) + require.Error(t, err) } -func TestVerifyProof(t *testing.T) { - // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2137 - t.Run("VP Simple binary trie", func(t *testing.T) { - tempTrie := buildSimpleTrie(t) - zero := trie.NewKey(250, []byte{0}) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), - }, - } - - root, err := tempTrie.Root() - require.NoError(t, err) - val1 := new(felt.Felt).SetUint64(2) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) - }) - - // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2167 - t.Run("VP Simple double binary trie", func(t *testing.T) { - tempTrie, _ := buildSimpleDoubleBinaryTrie(t) - zero := trie.NewKey(249, []byte{0}) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x055C81F6A791FD06FC2E2CCAD922397EC76C3E35F2E06C0C0D43D551005A8DEA"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - RightHash: utils.HexToFelt(t, "0x07C5BC1CC68B7BC8CA2F632DE98297E6DA9594FA23EDE872DD2ABEAFDE353B43"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), - }, - } +func TestOneElementRangeProof(t *testing.T) { + t.Parallel() - root, err := tempTrie.Root() - require.NoError(t, err) - val1 := new(felt.Felt).SetUint64(2) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) - }) + n := 1000 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - t.Run("VP three key trie", func(t *testing.T) { - tempTrie := build3KeyTrie(t) - zero := trie.NewKey(249, []byte{0}) - felt2 := new(felt.Felt).SetUint64(0).Bytes() - lastPath := trie.NewKey(1, felt2[:]) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x0768DEB8D0795D80AAAC2E5E326141F33044759F97A1BF092D8EB9C4E4BE9234"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x057166F9476D0A2D6875124251841EB85A9AE37462FAE3CBF7304BCD593938E7"), - RightHash: utils.HexToFelt(t, "0x060FBDE29F96F706498EFD132DC7F312A4C99A9AE051BF152C2AF2B3CAF31E5B"), - }, - &trie.Edge{ - Path: &lastPath, - Child: utils.HexToFelt(t, "0x6"), - }, - } + t.Run("both edge proofs with the same key", func(t *testing.T) { + t.Parallel() - root, err := tempTrie.Root() + start := 100 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[start].key, records[start].key, proof) require.NoError(t, err) - val6 := new(felt.Felt).SetUint64(6) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - leafkey := trie.NewKey(251, twoFeltBytes[:]) - gotProof, err := trie.GetProof(&leafkey, tempTrie) + _, err = trie.VerifyRangeProof(root, records[start].key, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) require.NoError(t, err) - require.Equal(t, expectedProofNodes, gotProof) - - assert.True(t, trie.VerifyProof(root, &leafkey, val6, expectedProofNodes, crypto.Pedersen)) }) - t.Run("VP non existent key - less than root edge", func(t *testing.T) { - tempTrie, _ := buildSimpleDoubleBinaryTrie(t) + t.Run("left non-existent edge proof", func(t *testing.T) { + t.Parallel() - nonExistentKey := trie.NewKey(123, []byte{0}) // Diverges before the root node (len root node = 249) - nonExistentKeyValue := new(felt.Felt).SetUint64(2) - proofNodes, err := trie.GetProof(&nonExistentKey, tempTrie) + start := 100 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(decrementFelt(records[start].key), records[start].key, proof) require.NoError(t, err) - root, err := tempTrie.Root() + _, err = trie.VerifyRangeProof(root, decrementFelt(records[start].key), []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) require.NoError(t, err) - - require.False(t, trie.VerifyProof(root, &nonExistentKey, nonExistentKeyValue, proofNodes, crypto.Pedersen)) }) - t.Run("VP non existent leaf key", func(t *testing.T) { - tempTrie, _ := buildSimpleDoubleBinaryTrie(t) + t.Run("right non-existent edge proof", func(t *testing.T) { + t.Parallel() - nonExistentKeyByte := new(felt.Felt).SetUint64(2).Bytes() // Key not set - nonExistentKey := trie.NewKey(251, nonExistentKeyByte[:]) - nonExistentKeyValue := new(felt.Felt).SetUint64(2) - proofNodes, err := trie.GetProof(&nonExistentKey, tempTrie) + end := 100 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[end].key, incrementFelt(records[end].key), proof) require.NoError(t, err) - root, err := tempTrie.Root() + _, err = trie.VerifyRangeProof(root, records[end].key, []*felt.Felt{records[end].key}, []*felt.Felt{records[end].value}, proof) require.NoError(t, err) - - require.False(t, trie.VerifyProof(root, &nonExistentKey, nonExistentKeyValue, proofNodes, crypto.Pedersen)) }) -} -func TestProofToPath(t *testing.T) { - t.Run("PTP Proof To Path Simple binary trie proof to path", func(t *testing.T) { - tempTrie := buildSimpleTrie(t) - zeroFeltByte := new(felt.Felt).Bytes() - zero := trie.NewKey(250, zeroFeltByte[:]) - leafValue := utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002") - siblingValue := utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003") - proofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - }, - &trie.Binary{ - LeftHash: leafValue, - RightHash: siblingValue, - }, - } + t.Run("both non-existent edge proofs", func(t *testing.T) { + t.Parallel() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) + start := 100 + first, last := decrementFelt(records[start].key), incrementFelt(records[start].key) + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(first, last, proof) require.NoError(t, err) - rootKey := tempTrie.RootKey() - - require.Equal(t, 1, len(sns)) - require.Equal(t, rootKey.Len(), sns[0].Key().Len()) - require.Equal(t, leafValue.String(), sns[0].Node().LeftHash.String()) - require.Equal(t, siblingValue.String(), sns[0].Node().RightHash.String()) - }) - - t.Run("PTP Simple double binary trie proof to path", func(t *testing.T) { - tempTrie := buildSimpleBinaryRootTrie(t) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - path1 := trie.NewKey(250, zeroFeltBytes[:]) - proofNodes := []trie.ProofNode{ - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x06E08BF82793229338CE60B65D1845F836C8E2FBFE2BC59FF24AEDBD8BA219C4"), - RightHash: utils.HexToFelt(t, "0x04F9B8E66212FB528C0C1BD02F43309C53B895AA7D9DC91180001BDD28A588FA"), - }, - &trie.Edge{ - Path: &path1, - Child: utils.HexToFelt(t, "0xcc"), - }, - } - - siblingValue := utils.HexToFelt(t, "0xdd") - sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) - require.NoError(t, err) - rootKey := tempTrie.RootKey() - rootNode, err := tempTrie.GetNodeFromKey(rootKey) + _, err = trie.VerifyRangeProof(root, first, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) require.NoError(t, err) - leftNode, err := tempTrie.GetNodeFromKey(rootNode.Left) - require.NoError(t, err) - require.Equal(t, 1, len(sns)) - require.Equal(t, rootKey.Len(), sns[0].Key().Len()) - require.Equal(t, leftNode.HashFromParent(rootKey, rootNode.Left, crypto.Pedersen).String(), sns[0].Node().LeftHash.String()) - require.NotEqual(t, siblingValue.String(), sns[0].Node().RightHash.String()) }) - t.Run("PTP boundary proofs with three key trie", func(t *testing.T) { - tri := build3KeyTrie(t) - rootKey := tri.RootKey() - rootNode, err := tri.GetNodeFromKey(rootKey) - require.NoError(t, err) + t.Run("1 key trie", func(t *testing.T) { + t.Parallel() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - zeroLeafValue := new(felt.Felt).SetUint64(4) - oneLeafValue := new(felt.Felt).SetUint64(5) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri) + tr, records := build1KeyTrie(t) + root, err := tr.Root() require.NoError(t, err) - // Test 1 - leftProofPath, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen) - require.Equal(t, 2, len(leftProofPath)) - require.NoError(t, err) - left, err := tri.GetNodeFromKey(rootNode.Left) - require.NoError(t, err) - right, err := tri.GetNodeFromKey(rootNode.Right) + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(&felt.Zero, records[0].key, proof) require.NoError(t, err) - require.Equal(t, rootKey, leftProofPath[0].Key()) - require.Equal(t, left.HashFromParent(rootKey, rootNode.Left, crypto.Pedersen).String(), leftProofPath[0].Node().LeftHash.String()) - require.Equal(t, right.HashFromParent(rootKey, rootNode.Right, crypto.Pedersen).String(), leftProofPath[0].Node().RightHash.String()) - require.Equal(t, rootNode.Left, leftProofPath[1].Key()) - require.Equal(t, zeroLeafValue.String(), leftProofPath[1].Node().LeftHash.String()) - require.Equal(t, oneLeafValue.String(), leftProofPath[1].Node().RightHash.String()) - - // Test 2 - rightProofPath, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen) - require.Equal(t, 1, len(rightProofPath)) + + _, err = trie.VerifyRangeProof(root, records[0].key, []*felt.Felt{records[0].key}, []*felt.Felt{records[0].value}, proof) require.NoError(t, err) - require.Equal(t, rootKey, rightProofPath[0].Key()) - require.NotEqual(t, rootNode.Right, rightProofPath[0].Node().Right) - require.NotEqual(t, uint8(0), rightProofPath[0].Node().Right) - require.Equal(t, right.HashFromParent(rootKey, rootNode.Right, crypto.Pedersen).String(), rightProofPath[0].Node().RightHash.String()) }) } -func TestBuildTrie(t *testing.T) { - t.Run("Simple binary trie proof to path", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - - tri := build3KeyTrie(t) - rootKey := tri.RootKey() - rootCommitment, err := tri.Root() - require.NoError(t, err) - rootNode, err := tri.GetNodeFromKey(rootKey) - require.NoError(t, err) - leftNode, err := tri.GetNodeFromKey(rootNode.Left) - require.NoError(t, err) - leftleftNode, err := tri.GetNodeFromKey(leftNode.Left) - require.NoError(t, err) - leftrightNode, err := tri.GetNodeFromKey(leftNode.Right) - require.NoError(t, err) +// TestAllElementsProof tests the range proof with all elements and nil proof. +func TestAllElementsRangeProof(t *testing.T) { + t.Parallel() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri) - require.NoError(t, err) - - leftProof, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen) - require.NoError(t, err) + n := 1000 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - rightProof, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen) - require.NoError(t, err) + keys := make([]*felt.Felt, n) + values := make([]*felt.Felt, n) + for i, record := range records { + keys[i] = record.key + values[i] = record.value + } - keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} - values := []*felt.Felt{new(felt.Felt).SetUint64(5)} - builtTrie, err := trie.BuildTrie(leftProof, rightProof, keys, values) - require.NoError(t, err) + _, err = trie.VerifyRangeProof(root, nil, keys, values, nil) + require.NoError(t, err) - builtRootKey := builtTrie.RootKey() - builtRootNode, err := builtTrie.GetNodeFromKey(builtRootKey) - require.NoError(t, err) - builtLeftNode, err := builtTrie.GetNodeFromKey(builtRootNode.Left) - require.NoError(t, err) - builtLeftRightNode, err := builtTrie.GetNodeFromKey(builtLeftNode.Right) - require.NoError(t, err) + // Should also work with proof + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[0].key, records[n-1].key, proof) + require.NoError(t, err) - // Assert the structure / keys correct - require.Equal(t, rootKey, builtRootKey) - require.Equal(t, rootNode.Left, builtRootNode.Left, "left fail") - require.Equal(t, leftrightNode.Right, builtLeftRightNode.Right, "right fail") - require.Equal(t, uint8(0), builtRootNode.Right.Len(), "right fail") - require.Equal(t, uint8(0), builtLeftNode.Left.Len(), "left left fail") + _, err = trie.VerifyRangeProof(root, keys[0], keys, values, proof) + require.NoError(t, err) +} - // Assert the leaf nodes have the correct values - require.Equal(t, leftleftNode.Value.String(), builtLeftNode.LeftHash.String(), "should be 0x4") - require.Equal(t, leftrightNode.Value.String(), builtLeftRightNode.Value.String(), "should be 0x5") +// TestSingleSideRangeProof tests the range proof starting with zero. +func TestSingleSideRangeProof(t *testing.T) { + t.Parallel() - // Given the above two asserts pass, we should be able to reconstruct the correct commitment - reconstructedRootCommitment, err := builtTrie.Root() - require.NoError(t, err) - require.Equal(t, rootCommitment.String(), reconstructedRootCommitment.String(), "root commitment not equal") - }) -} + tr, records := randomTrie(t, 1000) + root, err := tr.Root() + require.NoError(t, err) -func TestVerifyRangeProof(t *testing.T) { - t.Run("VPR two proofs, single key trie", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} - values := []*felt.Felt{new(felt.Felt).SetUint64(5)} - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} - proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(6)} - rootCommitment, err := tri.Root() - require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) + for i := 0; i < len(records); i += 100 { + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(&felt.Zero, records[i].key, proof) require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) - t.Run("VPR all keys provided, no proofs needed", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} - values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} - proofKeys := [2]*trie.Key{} - proofValues := [2]*felt.Felt{} - proofs := [2][]trie.ProofNode{} - rootCommitment, err := tri.Root() - require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) + keys := make([]*felt.Felt, i+1) + values := make([]*felt.Felt, i+1) + for j := 0; j < i+1; j++ { + keys[j] = records[j].key + values[j] = records[j].value + } - t.Run("VPR left proof, all right keys", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} - values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} - proofKeys := [2]*trie.Key{&zeroLeafkey} - proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4)} - leftProof, err := trie.GetProof(proofKeys[0], tri) - require.NoError(t, err) - proofs := [2][]trie.ProofNode{leftProof} - rootCommitment, err := tri.Root() - require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + _, err = trie.VerifyRangeProof(root, &felt.Zero, keys, values, proof) require.NoError(t, err) - require.True(t, verif) - }) + } +} - t.Run("VPR right proof, all left keys", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1)} - values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5)} - proofKeys := [2]*trie.Key{nil, &twoLeafkey} - proofValues := [2]*felt.Felt{nil, new(felt.Felt).SetUint64(6)} - rightProof, err := trie.GetProof(proofKeys[1], tri) - require.NoError(t, err) - proofs := [2][]trie.ProofNode{nil, rightProof} - rootCommitment, err := tri.Root() - require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) +func TestGappedRangeProof(t *testing.T) { + t.Parallel() + t.Skip("gapped keys will sometimes succeed, the current proof format is not able to handle this") - t.Run("VPR left proof, all inner keys, right proof with non-set key", func(t *testing.T) { - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + tr, records := nonRandomTrie(t, 5) + root, err := tr.Root() + require.NoError(t, err) - threeFeltBytes := new(felt.Felt).SetUint64(3).Bytes() - threeLeafkey := trie.NewKey(251, threeFeltBytes[:]) + first, last := 1, 4 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[first].key, records[last].key, proof) + require.NoError(t, err) - tri := build4KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} - values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} - proofKeys := [2]*trie.Key{&zeroLeafkey, &threeLeafkey} - proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), nil} - leftProof, err := trie.GetProof(proofKeys[0], tri) - require.NoError(t, err) - rightProof, err := trie.GetProof(proofKeys[1], tri) - require.NoError(t, err) + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := first; i <= last; i++ { + if i == (first+last)/2 { + continue + } - proofs := [2][]trie.ProofNode{leftProof, rightProof} - rootCommitment, err := tri.Root() - require.NoError(t, err) + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) + _, err = trie.VerifyRangeProof(root, records[first].key, keys, values, proof) + require.Error(t, err) } -func TestMergeProofPaths(t *testing.T) { - t.Run("3Key Trie no duplicates and all values exist in merged path", func(t *testing.T) { - tri := build3KeyTrie(t) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +func TestEmptyRangeProof(t *testing.T) { + t.Parallel() - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + tr, records := randomTrie(t, 1000) + root, err := tr.Root() + require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + cases := []struct { + pos int + err bool + }{ + {len(records) - 1, false}, + {500, true}, + } - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) + for _, c := range cases { + proof := trie.NewProofNodeSet() + first := incrementFelt(records[c.pos].key) + err = tr.GetRangeProof(first, first, proof) require.NoError(t, err) - require.True(t, containsAll(mergedProofs, proofs[0])) - require.True(t, containsAll(mergedProofs, proofs[1])) - require.True(t, noDuplicates(mergedProofs)) - }) - - t.Run("4Key Trie two common ancestors", func(t *testing.T) { - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + _, err := trie.VerifyRangeProof(root, first, nil, nil, proof) + if c.err { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } +} - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +func TestHasRightElement(t *testing.T) { + t.Parallel() - tri := build4KeyTrie(t) + tr, records := randomTrie(t, 500) + root, err := tr.Root() + require.NoError(t, err) - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + cases := []struct { + start int + end int + hasMore bool + }{ + {-1, 1, true}, // single element with non-existent left proof + {0, 1, true}, // single element with existent left proof + {0, 100, true}, // start to middle + {50, 100, true}, // middle only + {50, len(records), false}, // middle to end + {len(records) - 1, len(records), false}, // Single last element with two existent proofs(point to same key) + {0, len(records), false}, // The whole set with existent left proof + {-1, len(records), false}, // The whole set with non-existent left proof + } - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + for _, c := range cases { + var ( + first *felt.Felt + start = c.start + end = c.end + proof = trie.NewProofNodeSet() + ) + if start == -1 { + first = &felt.Zero + start = 0 + } else { + first = records[start].key + } - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) + err := tr.GetRangeProof(first, records[end-1].key, proof) require.NoError(t, err) - require.True(t, containsAll(mergedProofs, proofs[0])) - require.True(t, containsAll(mergedProofs, proofs[1])) - require.True(t, noDuplicates(mergedProofs)) - }) + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } - t.Run("Trie 4Key one ancestor", func(t *testing.T) { - tri := build4KeyTrie(t) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + hasMore, err := trie.VerifyRangeProof(root, first, keys, values, proof) + require.NoError(t, err) + require.Equal(t, c.hasMore, hasMore) + } +} - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +// TestBadRangeProof generates random bad proof scenarios and verifies that the proof is invalid. +func TestBadRangeProof(t *testing.T) { + t.Parallel() - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + tr, records := randomTrie(t, 1000) + root, err := tr.Root() + require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + for i := 0; i < 100; i++ { + start := rand.Intn(len(records)) + end := rand.Intn(len(records)-start) + start + 1 - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) require.NoError(t, err) - require.True(t, containsAll(mergedProofs, proofs[0])) - require.True(t, containsAll(mergedProofs, proofs[1])) - require.True(t, noDuplicates(mergedProofs)) - }) - - t.Run("Empty proof path", func(t *testing.T) { - tri := build4KeyTrie(t) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + keys := []*felt.Felt{} + values := []*felt.Felt{} + for j := start; j < end; j++ { + keys = append(keys, records[j].key) + values = append(values, records[j].value) + } - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + first := keys[0] + testCase := rand.Intn(5) + + index := rand.Intn(end - start) + switch testCase { + case 0: // modified key + keys[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 1: // modified value + values[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 2: // out of order + index2 := rand.Intn(end - start) + if index2 == index { + continue + } + keys[index], keys[index2] = keys[index2], keys[index] + values[index], values[index2] = values[index2], values[index] + case 3: // set random key to empty + keys[index] = &felt.Zero + case 4: // set random value to empty + values[index] = &felt.Zero + // TODO(weiihann): gapped proof will fail sometimes + // case 5: // gapped + // if end-start < 100 || index == 0 || index == end-start-1 { + // continue + // } + // keys = append(keys[:index], keys[index+1:]...) + // values = append(values[:index], values[index+1:]...) + } + _, err = trie.VerifyRangeProof(root, first, keys, values, proof) + if err == nil { + t.Fatalf("expected error for test case %d, index %d, start %d, end %d", testCase, index, start, end) + } + } +} - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} +func BenchmarkProve(b *testing.B) { + tr, records := randomTrie(b, 1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + proof := trie.NewProofNodeSet() + key := records[i%len(records)].key + if err := tr.Prove(key, proof); err != nil { + b.Fatal(err) + } + } +} - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) +func BenchmarkVerifyProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root, err := tr.Root() + require.NoError(b, err) - emptyPath := []trie.ProofNode{} + var proofs []*trie.ProofNodeSet + for _, record := range records { + proof := trie.NewProofNodeSet() + if err := tr.Prove(record.key, proof); err != nil { + b.Fatal(err) + } + proofs = append(proofs, proof) + } - _, _, err = trie.MergeProofPaths(proofs[0], emptyPath, crypto.Pedersen) - require.Error(t, err) - }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + index := i % len(records) + if _, err := trie.VerifyProof(root, records[index].key, proofs[index], crypto.Pedersen); err != nil { + b.Fatal(err) + } + } +} - t.Run("Root of the proof paths are different", func(t *testing.T) { - tri := build4KeyTrie(t) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() +func BenchmarkVerifyRangeProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root, err := tr.Root() + require.NoError(b, err) - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + start := 2 + end := start + 500 - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(b, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } - _, _, err = trie.MergeProofPaths(proofs[0], proofs[1][1:], crypto.Pedersen) - require.Error(t, err) - }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := trie.VerifyRangeProof(root, keys[0], keys, values, proof) + require.NoError(b, err) + } } -func TestSplitProofPaths(t *testing.T) { - t.Run("3Key Trie retrieved right and left proofs are same with the merged ones", func(t *testing.T) { - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - tri := build3KeyTrie(t) - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} +func buildTrie(t *testing.T, records []*keyValue) *trie.Trie { + if len(records) == 0 { + t.Fatal("records must have at least one element") + } - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) - leftSplit, rightSplit, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) + for _, record := range records { + _, err = tempTrie.Put(record.key, record.value) require.NoError(t, err) + } - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) - }) + require.NoError(t, tempTrie.Commit()) - t.Run("4Key Trie two common ancestors retrieved right and left proofs are same with the merged ones", func(t *testing.T) { - tri := build4KeyTrie(t) + return tempTrie +} - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() +func build1KeyTrie(t *testing.T) (*trie.Trie, []*keyValue) { + return nonRandomTrie(t, 1) +} - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +func buildSimpleTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // (250, 0, x1) edge + // | + // (0,0,x1) binary + // / \ + // (2) (3) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + } - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + return buildTrie(t, records), records +} - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) +func buildSimpleBinaryRootTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // PF + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) + // JUNO + // (0, 0, x) + // / \ + // (251, 0, cc) (251, 11111.., dd) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: utils.HexToFelt(t, "0xcc")}, + {key: utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), value: utils.HexToFelt(t, "0xdd")}, + } + return buildTrie(t, records), records +} - leftSplit, rightSplit, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.NoError(t, err) +//nolint:dupl +func buildSimpleDoubleBinaryTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // (249,0,x3) // Edge + // | + // (0, 0, x3) // Binary + // / \ + // (0,0,x1) // B (1, 1, 5) // Edge leaf + // / \ | + // (2) (3) (5) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + {key: new(felt.Felt).SetUint64(3), value: new(felt.Felt).SetUint64(5)}, + } + return buildTrie(t, records), records +} - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) - }) +//nolint:dupl +func build3KeyTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // Starknet + // -------- + // + // Edge + // | + // Binary with len 249 parent + // / \ + // Binary (250) Edge with len 250 + // / \ / + // 0x4 0x5 0x6 child - t.Run("4Key Trie one common ancestor retrieved right and left proofs are same with the merged ones", func(t *testing.T) { - tri := build4KeyTrie(t) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) + // Juno + // ---- + // + // Node (path 249) + // / \ + // Node (binary) \ + // / \ / + // 0x4 0x5 0x6 + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(4)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(5)}, + {key: new(felt.Felt).SetUint64(2), value: new(felt.Felt).SetUint64(6)}, + } - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + return buildTrie(t, records), records +} - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) +func nonRandomTrie(t *testing.T, numKeys int) (*trie.Trie, []*keyValue) { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) - leftSplit, rightSplit, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) + records := make([]*keyValue, numKeys) + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + _, err := tempTrie.Put(key, key) require.NoError(t, err) + } - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 }) - t.Run("4Key Trie reversed merge path", func(t *testing.T) { - tri := build4KeyTrie(t) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} - - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) - - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) - - for i := 0; i < len(mergedProofs)/2; i++ { - j := len(mergedProofs) - 1 - i - mergedProofs[i], mergedProofs[j] = mergedProofs[j], mergedProofs[i] - } - - leftSplit, rightSplit, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.NoError(t, err) + require.NoError(t, tempTrie.Commit()) - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) - }) + return tempTrie, records +} - t.Run("Roothash does not exist", func(t *testing.T) { - tri := build4KeyTrie(t) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) +func randomTrie(t testing.TB, n int) (*trie.Trie, []*keyValue) { + rrand := rand.New(rand.NewSource(3)) - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) + records := make([]*keyValue, n) + for i := 0; i < n; i++ { + key := new(felt.Felt).SetUint64(uint64(rrand.Uint32() + 1)) + records[i] = &keyValue{key: key, value: key} + _, err := tempTrie.Put(key, key) require.NoError(t, err) + } - rootHashFalse := new(felt.Felt).SetUint64(0) + require.NoError(t, tempTrie.Commit()) - _, _, err = trie.SplitProofPath(mergedProofs, rootHashFalse, crypto.Pedersen) - require.Error(t, err) + // Sort records by key + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 }) - t.Run("Two splits in the merged path", func(t *testing.T) { - p1 := newBinaryProofNode() - p2 := newBinaryProofNode() - p3 := newBinaryProofNode() - p4 := newBinaryProofNode() - p5 := newBinaryProofNode() - - p4.LeftHash = new(felt.Felt).SetUint64(3) - p2.RightHash = new(felt.Felt).SetUint64(4) - - p3.RightHash = p5.Hash(crypto.Pedersen) - p3.LeftHash = p4.Hash(crypto.Pedersen) - p1.RightHash = p3.Hash(crypto.Pedersen) - p1.LeftHash = p2.Hash(crypto.Pedersen) - - mergedProofs := []trie.ProofNode{p1, p2, p3, p4, p5} - rootHash := p1.Hash(crypto.Pedersen) + return tempTrie, records +} - _, _, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.Error(t, err) - }) +func decrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Sub(f, new(felt.Felt).SetUint64(1)) +} - t.Run("Duplicate nodes in the merged path", func(t *testing.T) { - p1 := newBinaryProofNode() - p2 := newBinaryProofNode() - p3 := newBinaryProofNode() - p4 := newBinaryProofNode() - p5 := newBinaryProofNode() +func incrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Add(f, new(felt.Felt).SetUint64(1)) +} - p3.RightHash = p5.Hash(crypto.Pedersen) - p3.LeftHash = p4.Hash(crypto.Pedersen) - p1.RightHash = p3.Hash(crypto.Pedersen) - p1.LeftHash = p2.Hash(crypto.Pedersen) +type testKey struct { + name string + key *felt.Felt + expected *felt.Felt +} - mergedProofs := []trie.ProofNode{p1, p2, p3, p4, p5} - rootHash := p1.Hash(crypto.Pedersen) +type testTrie struct { + name string + buildFn func(*testing.T) (*trie.Trie, []*keyValue) + testKeys []testKey +} - _, _, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.Error(t, err) - }) +type keyValue struct { + key *felt.Felt + value *felt.Felt } diff --git a/core/trie/trie.go b/core/trie/trie.go index ff978c8709..0dd6f4c77c 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -11,8 +11,11 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils" ) +const globalTrieHeight = 251 // TODO(weiihann): this is declared in core also, should be moved to a common place + type hashFunc func(*felt.Felt, *felt.Felt) *felt.Felt // Trie is a dense Merkle Patricia Trie (i.e., all internal nodes have two children). @@ -95,31 +98,8 @@ func RunOnTempTriePoseidon(height uint8, do func(*Trie) error) error { // feltToKey Converts a key, given in felt, to a trie.Key which when followed on a [Trie], // leads to the corresponding [Node] -func (t *Trie) feltToKey(k *felt.Felt) Key { - kBytes := k.Bytes() - return NewKey(t.height, kBytes[:]) -} - -// findCommonKey finds the set of common MSB bits in two key bitsets. -func findCommonKey(longerKey, shorterKey *Key) (Key, bool) { - divergentBit := findDivergentBit(longerKey, shorterKey) - commonKey := *shorterKey - commonKey.DeleteLSB(shorterKey.Len() - divergentBit + 1) - return commonKey, divergentBit == shorterKey.Len()+1 -} - -func findDivergentBit(longerKey, shorterKey *Key) uint8 { - divergentBit := uint8(0) - for divergentBit <= shorterKey.Len() && - longerKey.Test(longerKey.Len()-divergentBit) == shorterKey.Test(shorterKey.Len()-divergentBit) { - divergentBit++ - } - return divergentBit -} - -func isSubset(longerKey, shorterKey *Key) bool { - divergentBit := findDivergentBit(longerKey, shorterKey) - return divergentBit == shorterKey.Len()+1 +func (t *Trie) FeltToKey(k *felt.Felt) Key { + return FeltToKey(t.height, k) } // path returns the path as mentioned in the [specification] for commitment calculations. @@ -145,14 +125,97 @@ func (sn *StorageNode) Key() *Key { return sn.key } -func (sn *StorageNode) Node() *Node { - return sn.node +func (sn *StorageNode) Value() *felt.Felt { + return sn.node.Value +} + +func (sn *StorageNode) String() string { + return fmt.Sprintf("StorageNode{key: %s, node: %s}", sn.key, sn.node) +} + +func (sn *StorageNode) Update(other *StorageNode) error { + // First validate all fields for conflicts + if sn.key != nil && other.key != nil && !sn.key.Equal(NilKey) && !other.key.Equal(NilKey) { + if !sn.key.Equal(other.key) { + return fmt.Errorf("keys do not match: %s != %s", sn.key, other.key) + } + } + + // Validate node updates + if sn.node != nil && other.node != nil { + if err := sn.node.Update(other.node); err != nil { + return err + } + } + + // After validation, perform update + if other.key != nil && !other.key.Equal(NilKey) { + sn.key = other.key + } + + return nil } func NewStorageNode(key *Key, node *Node) *StorageNode { return &StorageNode{key: key, node: node} } +// NewPartialStorageNode creates a new StorageNode with a given key and value, +// where the right and left children are nil. +func NewPartialStorageNode(key *Key, value *felt.Felt) *StorageNode { + return &StorageNode{ + key: key, + node: &Node{ + Value: value, + Left: NilKey, + Right: NilKey, + }, + } +} + +// StorageNodeSet wraps OrderedSet to provide specific functionality for StorageNodes +type StorageNodeSet struct { + set *utils.OrderedSet[Key, *StorageNode] +} + +func NewStorageNodeSet() *StorageNodeSet { + return &StorageNodeSet{ + set: utils.NewOrderedSet[Key, *StorageNode](), + } +} + +func (s *StorageNodeSet) Get(key Key) (*StorageNode, bool) { + return s.set.Get(key) +} + +// Put adds a new StorageNode or updates an existing one. +func (s *StorageNodeSet) Put(key Key, node *StorageNode) error { + if node == nil { + return fmt.Errorf("cannot put nil node") + } + + // If key exists, update the node + if existingNode, exists := s.set.Get(key); exists { + if err := existingNode.Update(node); err != nil { + return fmt.Errorf("failed to update node for key %v: %w", key, err) + } + return nil + } + + // Add new node if key doesn't exist + s.set.Put(key, node) + return nil +} + +// List returns the list of StorageNodes in the set. +func (s *StorageNodeSet) List() []*StorageNode { + return s.set.List() +} + +func (s *StorageNodeSet) Size() int { + return s.set.Size() +} + // nodesFromRoot enumerates the set of [Node] objects that need to be traversed from the root // of the Trie to the node which is given by the key. // The [storageNode]s are returned in descending order beginning with the root. @@ -180,7 +243,7 @@ func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { return nodes, nil } - if key.Test(key.Len() - cur.Len() - 1) { + if key.IsBitSet(key.Len() - cur.Len() - 1) { cur = node.Right } else { cur = node.Left @@ -192,7 +255,7 @@ func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { // Get the corresponding `value` for a `key` func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { - storageKey := t.feltToKey(key) + storageKey := t.FeltToKey(key) value, err := t.storage.Get(&storageKey) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { @@ -261,6 +324,7 @@ func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent S } } +// TODO(weiihann): not a good idea to couple proof verification logic with trie logic func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { commonKey, _ := findCommonKey(nodeKey, sibling.key) @@ -274,7 +338,7 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode if err != nil { return err } - if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { newParent.Right = nodeKey newParent.RightHash = node.Hash(nodeKey, t.hash) } else { @@ -286,7 +350,7 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode } t.dirtyNodes = append(t.dirtyNodes, &commonKey) } else { - if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { newParent.Left, newParent.Right = sibling.key, nodeKey leftChild, rightChild = sibling.node, node } else { @@ -328,7 +392,7 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { } old := felt.Zero - nodeKey := t.feltToKey(key) + nodeKey := t.FeltToKey(key) node := &Node{ Value: value, } @@ -373,13 +437,13 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { } // Put updates the corresponding `value` for a `key` -func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []StorageNode) (*felt.Felt, error) { +func (t *Trie) PutWithProof(key, value *felt.Felt, proof []*StorageNode) (*felt.Felt, error) { if key.Cmp(t.maxKey) > 0 { return nil, fmt.Errorf("key %s exceeds trie height %d", key, t.height) } old := felt.Zero - nodeKey := t.feltToKey(key) + nodeKey := t.FeltToKey(key) node := &Node{ Value: value, } @@ -417,24 +481,14 @@ func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []Stor } // override the sibling to be the parent if it's a proof - parentIsProof, found := false, false - for _, proof := range lProofPath { - if proof.key.Equal(sibling.key) { - sibling = proof + parentIsProof := false + for _, proofNode := range proof { + if proofNode.key.Equal(sibling.key) { + sibling = *proofNode parentIsProof = true - found = true break } } - if !found { - for _, proof := range rProofPath { - if proof.key.Equal(sibling.key) { - sibling = proof - parentIsProof = true - break - } - } - } err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling, parentIsProof) if err != nil { @@ -445,14 +499,11 @@ func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []Stor } // Put updates the corresponding `value` for a `key` -func (t *Trie) PutInner(key *Key, node *Node) (*felt.Felt, error) { +func (t *Trie) PutInner(key *Key, node *Node) error { if err := t.storage.Put(key, node); err != nil { - return nil, err - } - if t.rootKey == nil { - t.setRootKey(key) + return err } - return &felt.Zero, nil + return nil } func (t *Trie) setRootKey(newRootKey *Key) { @@ -461,9 +512,6 @@ func (t *Trie) setRootKey(newRootKey *Key) { } func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo - zeroFeltBytes := new(felt.Felt).Bytes() - nilKey := NewKey(0, zeroFeltBytes[:]) - node, err := t.storage.Get(key) if err != nil { return nil, err @@ -485,9 +533,9 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo } // Update inner proof nodes - if node.Left.Equal(&nilKey) && node.Right.Equal(&nilKey) { // leaf + if node.Left.Equal(NilKey) && node.Right.Equal(NilKey) { // leaf shouldUpdate = false - } else if node.Left.Equal(&nilKey) || node.Right.Equal(&nilKey) { // inner + } else if node.Left.Equal(NilKey) || node.Right.Equal(NilKey) { // inner shouldUpdate = true } if !shouldUpdate { @@ -496,11 +544,11 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo var leftIsProof, rightIsProof bool var leftHash, rightHash *felt.Felt - if node.Left.Equal(&nilKey) { + if node.Left.Equal(NilKey) { // key could be nil but hash cannot be leftIsProof = true leftHash = node.LeftHash } - if node.Right.Equal(&nilKey) { + if node.Right.Equal(NilKey) { rightIsProof = true rightHash = node.RightHash } @@ -698,12 +746,33 @@ func (t *Trie) dump(level int, parentP *Key) { } defer nodePool.Put(root) path := path(t.rootKey, parentP) - fmt.Printf("%sstorage : \"%s\" %d spec: \"%s\" %d bottom: \"%s\" \n", + + left := "" + right := "" + leftHash := "" + rightHash := "" + + if root.Left != nil { + left = root.Left.String() + } + if root.Right != nil { + right = root.Right.String() + } + if root.LeftHash != nil { + leftHash = root.LeftHash.String() + } + if root.RightHash != nil { + rightHash = root.RightHash.String() + } + + fmt.Printf("%skey : \"%s\" path: \"%s\" left: \"%s\" right: \"%s\" LH: \"%s\" RH: \"%s\" value: \"%s\" \n", strings.Repeat("\t", level), t.rootKey.String(), - t.rootKey.Len(), path.String(), - path.Len(), + left, + right, + leftHash, + rightHash, root.Value.String(), ) (&Trie{ diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 87f1801f78..5426cbcafa 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -26,7 +26,7 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) assert.Equal(t, val, value, "key-val not match") - assert.Equal(t, tempTrie.feltToKey(key), *tempTrie.rootKey, "root key not match single node's key") + assert.Equal(t, tempTrie.FeltToKey(key), *tempTrie.rootKey, "root key not match single node's key") }) t.Run("put a left then a right node", func(t *testing.T) { @@ -53,8 +53,8 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) // Check parent and its left right children - l := tempTrie.feltToKey(leftKey) - r := tempTrie.feltToKey(rightKey) + l := tempTrie.FeltToKey(leftKey) + r := tempTrie.FeltToKey(rightKey) commonKey, isSame := findCommonKey(&l, &r) require.False(t, isSame) @@ -69,8 +69,8 @@ func TestTrieKeys(t *testing.T) { parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(leftKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(rightKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) }) t.Run("put a right node then a left node", func(t *testing.T) { @@ -96,8 +96,8 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) // Check parent and its left right children - l := tempTrie.feltToKey(leftKey) - r := tempTrie.feltToKey(rightKey) + l := tempTrie.FeltToKey(leftKey) + r := tempTrie.FeltToKey(rightKey) commonKey, isSame := findCommonKey(&l, &r) require.False(t, isSame) @@ -108,8 +108,8 @@ func TestTrieKeys(t *testing.T) { parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(leftKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(rightKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) }) t.Run("Add new key to different branches", func(t *testing.T) { @@ -142,8 +142,8 @@ func TestTrieKeys(t *testing.T) { commonKey := NewKey(250, []byte{0x2}) parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) - assert.Equal(t, tempTrie.feltToKey(leftKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(newKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) }) //nolint: dupl t.Run("Add to right branch", func(t *testing.T) { @@ -153,8 +153,8 @@ func TestTrieKeys(t *testing.T) { commonKey := NewKey(250, []byte{0x3}) parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) - assert.Equal(t, tempTrie.feltToKey(newKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(rightKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) }) t.Run("Add new node as parent sibling", func(t *testing.T) { newKeyNum, err := strconv.ParseUint("000", 2, 64) @@ -170,7 +170,7 @@ func TestTrieKeys(t *testing.T) { parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(newKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) expectRightKey := NewKey(249, []byte{0x1}) @@ -246,8 +246,8 @@ func TestTrieKeysAfterDeleteSubtree(t *testing.T) { rootNode, err := tempTrie.storage.Get(&newRootKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(rightKey), *rootNode.Right) - assert.Equal(t, tempTrie.feltToKey(test.expectLeft), *rootNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *rootNode.Right) + assert.Equal(t, tempTrie.FeltToKey(test.expectLeft), *rootNode.Left) }) } } diff --git a/db/pebble/db.go b/db/pebble/db.go index 5974edf720..77aed603d7 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -60,7 +60,7 @@ func NewMem() (db.DB, error) { } // NewMemTest opens a new in-memory database, panics on error -func NewMemTest(t *testing.T) db.DB { +func NewMemTest(t testing.TB) db.DB { memDB, err := NewMem() if err != nil { t.Fatalf("create in-memory db: %v", err) diff --git a/utils/orderedset.go b/utils/orderedset.go new file mode 100644 index 0000000000..e6e7e2d948 --- /dev/null +++ b/utils/orderedset.go @@ -0,0 +1,67 @@ +package utils + +import ( + "sync" +) + +// OrderedSet is a thread-safe data structure that maintains both uniqueness and insertion order of elements. +// It combines the benefits of both maps and slices: +// - Uses a map for O(1) lookups and to ensure element uniqueness +// - Uses a slice to maintain insertion order and enable ordered iteration +// The data structure is safe for concurrent access through the use of a read-write mutex. +type OrderedSet[K comparable, V any] struct { + itemPos map[K]int // position of the node in the list + items []V + size int + lock sync.RWMutex +} + +func NewOrderedSet[K comparable, V any]() *OrderedSet[K, V] { + return &OrderedSet[K, V]{ + itemPos: make(map[K]int), + } +} + +func (ps *OrderedSet[K, V]) Put(key K, value V) { + ps.lock.Lock() + defer ps.lock.Unlock() + + // Update existing entry + if pos, exists := ps.itemPos[key]; exists { + ps.items[pos] = value + return + } + + // Insert new entry + ps.itemPos[key] = len(ps.items) + ps.items = append(ps.items, value) + ps.size++ +} + +func (ps *OrderedSet[K, V]) Get(key K) (V, bool) { + ps.lock.RLock() + defer ps.lock.RUnlock() + + if pos, ok := ps.itemPos[key]; ok { + return ps.items[pos], true + } + var zero V + return zero, false +} + +func (ps *OrderedSet[K, V]) Size() int { + ps.lock.RLock() + defer ps.lock.RUnlock() + + return ps.size +} + +// List returns a shallow copy of the proof set's value list. +func (ps *OrderedSet[K, V]) List() []V { + ps.lock.RLock() + defer ps.lock.RUnlock() + + values := make([]V, len(ps.items)) + copy(values, ps.items) + return values +}