Skip to content

Commit

Permalink
Fix and refactor trie proof logics (#2252)
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann authored Dec 12, 2024
1 parent 2b1b219 commit 65b7507
Show file tree
Hide file tree
Showing 9 changed files with 1,536 additions and 1,591 deletions.
115 changes: 73 additions & 42 deletions core/trie/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package trie
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"math/big"

"github.com/NethermindEth/juno/core/felt"
)

var NilKey = &Key{len: 0, bitset: [32]byte{}}

type Key struct {
len uint8
bitset [32]byte
Expand All @@ -24,26 +25,6 @@ func NewKey(length uint8, keyBytes []byte) Key {
return k
}

func (k *Key) SubKey(n uint8) (*Key, error) {
if n > k.len {
return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len))
}

newKey := &Key{len: n}
copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:mnd

// Shift right by the number of bits that are not needed
shift := k.len - n
for i := len(newKey.bitset) - 1; i >= 0; i-- {
newKey.bitset[i] >>= shift
if i > 0 {
newKey.bitset[i] |= newKey.bitset[i-1] << (8 - shift)
}
}

return newKey, nil
}

func (k *Key) bytesNeeded() uint {
const byteBits = 8
return (uint(k.len) + (byteBits - 1)) / byteBits
Expand Down Expand Up @@ -96,31 +77,48 @@ func (k *Key) Equal(other *Key) bool {
return k.len == other.len && k.bitset == other.bitset
}

func (k *Key) Test(bit uint8) bool {
// IsBitSet returns whether the bit at the given position is 1.
// Position 0 represents the least significant (rightmost) bit.
func (k *Key) IsBitSet(position uint8) bool {
const LSB = uint8(0x1)
byteIdx := bit / 8
byteIdx := position / 8
byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1]
bitIdx := bit % 8
bitIdx := position % 8
return ((byteAtIdx >> bitIdx) & LSB) != 0
}

func (k *Key) String() string {
return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:]))
}

// DeleteLSB right shifts and shortens the key
func (k *Key) DeleteLSB(n uint8) {
// shiftRight removes n least significant bits from the key by performing a right shift
// operation and reducing the key length. For example, if the key contains bits
// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4).
//
// The operation is destructive - it modifies the key in place.
func (k *Key) shiftRight(n uint8) {
if k.len < n {
panic("deleting more bits than there are")
}

if n == 0 {
return
}

var bigInt big.Int
bigInt.SetBytes(k.bitset[:])
bigInt.Rsh(&bigInt, uint(n))
bigInt.FillBytes(k.bitset[:])
k.len -= n
}

// MostSignificantBits returns a new key with the most significant n bits of the current key.
func (k *Key) MostSignificantBits(n uint8) (*Key, error) {
if n > k.len {
return nil, fmt.Errorf("cannot get more bits than the key length")
}

keyCopy := k.Copy()
keyCopy.shiftRight(k.len - n)
return &keyCopy, nil
}

// Truncate truncates key to `length` bits by clearing the remaining upper bits
func (k *Key) Truncate(length uint8) {
k.len = length
Expand All @@ -136,20 +134,53 @@ func (k *Key) Truncate(length uint8) {
}
}

func (k *Key) RemoveLastBit() {
if k.len == 0 {
return
}
func (k *Key) String() string {
return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:]))
}

k.len--
// Copy returns a deep copy of the key
func (k *Key) Copy() Key {
newKey := Key{len: k.len}
copy(newKey.bitset[:], k.bitset[:])
return newKey
}

unusedBytes := k.unusedBytes()
clear(unusedBytes)
func (k *Key) Bytes() [32]byte {
var result [32]byte
copy(result[:], k.bitset[:])
return result
}

// clear upper bits on the last used byte
inUseBytes := k.inUseBytes()
unusedBitsCount := 8 - (k.len % 8)
if unusedBitsCount != 8 && len(inUseBytes) > 0 {
inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount
// findCommonKey finds the set of common MSB bits in two key bitsets.
func findCommonKey(longerKey, shorterKey *Key) (Key, bool) {
divergentBit := findDivergentBit(longerKey, shorterKey)

if divergentBit == 0 {
return *NilKey, false
}

commonKey := *shorterKey
commonKey.shiftRight(shorterKey.Len() - divergentBit + 1)
return commonKey, divergentBit == shorterKey.Len()+1
}

// findDivergentBit finds the first bit that is different between two keys,
// starting from the most significant bit of both keys.
func findDivergentBit(longerKey, shorterKey *Key) uint8 {
divergentBit := uint8(0)
for divergentBit <= shorterKey.Len() &&
longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) {
divergentBit++
}
return divergentBit
}

func isSubset(longerKey, shorterKey *Key) bool {
divergentBit := findDivergentBit(longerKey, shorterKey)
return divergentBit == shorterKey.Len()+1
}

func FeltToKey(length uint8, key *felt.Felt) Key {
keyBytes := key.Bytes()
return NewKey(length, keyBytes[:])
}
156 changes: 115 additions & 41 deletions core/trie/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,47 +68,6 @@ func BenchmarkKeyEncoding(b *testing.B) {
}
}

func TestKeyTest(t *testing.T) {
key := trie.NewKey(44, []byte{0x10, 0x02})
for i := 0; i < int(key.Len()); i++ {
assert.Equal(t, i == 1 || i == 12, key.Test(uint8(i)), i)
}
}

func TestDeleteLSB(t *testing.T) {
key := trie.NewKey(16, []byte{0xF3, 0x04})

tests := map[string]struct {
shiftAmount uint8
expectedKey trie.Key
}{
"delete 0 bits": {
shiftAmount: 0,
expectedKey: key,
},
"delete 4 bits": {
shiftAmount: 4,
expectedKey: trie.NewKey(12, []byte{0x0F, 0x30}),
},
"delete 8 bits": {
shiftAmount: 8,
expectedKey: trie.NewKey(8, []byte{0xF3}),
},
"delete 9 bits": {
shiftAmount: 9,
expectedKey: trie.NewKey(7, []byte{0x79}),
},
}

for desc, test := range tests {
t.Run(desc, func(t *testing.T) {
copyKey := key
copyKey.DeleteLSB(test.shiftAmount)
assert.Equal(t, test.expectedKey, copyKey)
})
}
}

func TestTruncate(t *testing.T) {
tests := map[string]struct {
key trie.Key
Expand Down Expand Up @@ -153,3 +112,118 @@ func TestTruncate(t *testing.T) {
})
}
}

func TestKeyTest(t *testing.T) {
key := trie.NewKey(44, []byte{0x10, 0x02})
for i := 0; i < int(key.Len()); i++ {
assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i)
}
}

func TestIsBitSet(t *testing.T) {
tests := map[string]struct {
key trie.Key
position uint8
expected bool
}{
"single byte, LSB set": {
key: trie.NewKey(8, []byte{0x01}),
position: 0,
expected: true,
},
"single byte, MSB set": {
key: trie.NewKey(8, []byte{0x80}),
position: 7,
expected: true,
},
"single byte, middle bit set": {
key: trie.NewKey(8, []byte{0x10}),
position: 4,
expected: true,
},
"single byte, bit not set": {
key: trie.NewKey(8, []byte{0xFE}),
position: 0,
expected: false,
},
"multiple bytes, LSB set": {
key: trie.NewKey(16, []byte{0x00, 0x02}),
position: 1,
expected: true,
},
"multiple bytes, MSB set": {
key: trie.NewKey(16, []byte{0x01, 0x00}),
position: 8,
expected: true,
},
"multiple bytes, no bits set": {
key: trie.NewKey(16, []byte{0x00, 0x00}),
position: 7,
expected: false,
},
"check all bits in pattern": {
key: trie.NewKey(8, []byte{0xA5}), // 10100101
position: 0,
expected: true,
},
}

// Additional test for 0xA5 pattern
key := trie.NewKey(8, []byte{0xA5}) // 10100101
expectedBits := []bool{true, false, true, false, false, true, false, true}
for i, expected := range expectedBits {
assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i)
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result := tc.key.IsBitSet(tc.position)
assert.Equal(t, tc.expected, result)
})
}
}

func TestMostSignificantBits(t *testing.T) {
tests := []struct {
name string
key trie.Key
n uint8
want trie.Key
expectErr bool
}{
{
name: "Valid case",
key: trie.NewKey(8, []byte{0b11110000}),
n: 4,
want: trie.NewKey(4, []byte{0b00001111}),
expectErr: false,
},
{
name: "Request more bits than available",
key: trie.NewKey(8, []byte{0b11110000}),
n: 10,
want: trie.Key{},
expectErr: true,
},
{
name: "Zero bits requested",
key: trie.NewKey(8, []byte{0b11110000}),
n: 0,
want: trie.NewKey(0, []byte{}),
expectErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.key.MostSignificantBits(tt.n)
if (err != nil) != tt.expectErr {
t.Errorf("MostSignificantBits() error = %v, expectErr %v", err, tt.expectErr)
return
}
if !tt.expectErr && !got.Equal(&tt.want) {
t.Errorf("MostSignificantBits() = %v, want %v", got, tt.want)
}
})
}
}
53 changes: 53 additions & 0 deletions core/trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package trie
import (
"bytes"
"errors"
"fmt"

"github.com/NethermindEth/juno/core/felt"
)
Expand Down Expand Up @@ -138,3 +139,55 @@ func (n *Node) UnmarshalBinary(data []byte) error {
n.RightHash.SetBytes(data[:felt.Bytes])
return nil
}

func (n *Node) String() string {
return fmt.Sprintf("Node{Value: %s, Left: %s, Right: %s, LeftHash: %s, RightHash: %s}", n.Value, n.Left, n.Right, n.LeftHash, n.RightHash)
}

// Update the receiver with non-nil fields from the `other` Node.
// If a field is non-nil in both Nodes, they must be equal, or an error is returned.
//
// This method modifies the receiver in-place and returns an error if any field conflicts are detected.
//
//nolint:gocyclo
func (n *Node) Update(other *Node) error {
// First validate all fields for conflicts
if n.Value != nil && other.Value != nil && !n.Value.Equal(other.Value) {
return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value)
}

if n.Left != nil && other.Left != nil && !n.Left.Equal(NilKey) && !other.Left.Equal(NilKey) && !n.Left.Equal(other.Left) {
return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left)
}

if n.Right != nil && other.Right != nil && !n.Right.Equal(NilKey) && !other.Right.Equal(NilKey) && !n.Right.Equal(other.Right) {
return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right)
}

if n.LeftHash != nil && other.LeftHash != nil && !n.LeftHash.Equal(other.LeftHash) {
return fmt.Errorf("conflicting LeftHash: %v != %v", n.LeftHash, other.LeftHash)
}

if n.RightHash != nil && other.RightHash != nil && !n.RightHash.Equal(other.RightHash) {
return fmt.Errorf("conflicting RightHash: %v != %v", n.RightHash, other.RightHash)
}

// After validation, perform all updates
if other.Value != nil {
n.Value = other.Value
}
if other.Left != nil && !other.Left.Equal(NilKey) {
n.Left = other.Left
}
if other.Right != nil && !other.Right.Equal(NilKey) {
n.Right = other.Right
}
if other.LeftHash != nil {
n.LeftHash = other.LeftHash
}
if other.RightHash != nil {
n.RightHash = other.RightHash
}

return nil
}
Loading

0 comments on commit 65b7507

Please sign in to comment.