diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 11a55edae5..a2ee47b05d 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -16,12 +16,9 @@ const ( bits8 = 8 ) -var ( - maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - emptyBitArray = new(BitArray) -) +var emptyBitArray = new(BitArray) -// BitArray is a structure that represents a bit array with length representing the number of used bits. +// Represents a bit array with length representing the number of used bits. // It uses a little endian representation to do bitwise operations of the words efficiently. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. // The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. @@ -35,6 +32,7 @@ func NewBitArray(length uint8, val uint64) *BitArray { return new(BitArray).SetUint64(length, val) } +// Returns the felt representation of the bit array. func (b *BitArray) Felt() felt.Felt { var f felt.Felt f.SetBytes(b.Bytes()) @@ -45,7 +43,7 @@ func (b *BitArray) Len() uint8 { return b.len } -// Bytes returns the bytes representation of the bit array in big endian format +// Returns the bytes representation of the bit array in big endian format // //nolint:mnd func (b *BitArray) Bytes() []byte { @@ -83,42 +81,7 @@ func (b *BitArray) Bytes() []byte { return res[:] } -// EqualMSBs checks if two bit arrays share the same most significant bits, where the length of -// the check is determined by the shorter array. Returns true if either array has -// length 0, or if the first min(b.len, x.len) MSBs are identical. -// -// For example: -// -// a = 1101 (len=4) -// b = 11010111 (len=8) -// a.EqualMSBs(b) = true // First 4 MSBs match -// -// a = 1100 (len=4) -// b = 1101 (len=4) -// a.EqualMSBs(b) = false // All bits compared, not equal -// -// a = 1100 (len=4) -// b = [] (len=0) -// a.EqualMSBs(b) = true // Zero length is always a prefix match -func (b *BitArray) EqualMSBs(x *BitArray) bool { - if b.len == x.len { - return b.Equal(x) - } - - if b.len == 0 || x.len == 0 { - return true - } - - // Compare only the first min(b.len, x.len) bits - minLen := b.len - if x.len < minLen { - minLen = x.len - } - - return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) -} - -// LSBs sets b to the least significant 'n' bits of x. +// Sets b to the least significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: @@ -164,7 +127,42 @@ func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray { return b } -// MSBs sets b to the most significant 'n' bits of x. +// Checks if the current bit array share the same most significant bits with another, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *BitArray) EqualMSBs(x *BitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + if b.len == 0 || x.len == 0 { + return true + } + + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len + } + + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) +} + +// Sets b to the most significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: @@ -181,7 +179,7 @@ func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { return b.Rsh(x, x.len-n) } -// CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays. +// Sets b to the longest sequence of matching most significant bits between two bit arrays. // For example: // // x = 1101 0111 (len=8) @@ -219,7 +217,7 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { return b.Rsh(short, divergentBit) } -// Rsh sets b = x >> n and returns b. +// Sets b = x >> n and returns b. // //nolint:mnd func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { @@ -264,7 +262,7 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { return b } -// Xor sets b = x ^ y and returns b. +// Sets b = x ^ y and returns b. func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] b.words[1] = x.words[1] ^ y.words[1] @@ -273,7 +271,7 @@ func (b *BitArray) Xor(x, y *BitArray) *BitArray { return b } -// Eq checks if two bit arrays are equal +// Checks if two bit arrays are equal func (b *BitArray) Equal(x *BitArray) bool { // TODO(weiihann): this is really not a good thing to do... if b == nil && x == nil { @@ -289,7 +287,7 @@ func (b *BitArray) Equal(x *BitArray) bool { b.words[3] == x.words[3] } -// IsBitSit returns true if bit n-th is set, where n = 0 is LSB. +// Returns true if bit n-th is set, where n = 0 is LSB. // The n must be <= 255. func (b *BitArray) IsBitSet(n uint8) bool { if n >= b.len { @@ -299,7 +297,7 @@ func (b *BitArray) IsBitSet(n uint8) bool { return (b.words[n/64] & (1 << (n % 64))) != 0 } -// Write serialises the BitArray into a bytes buffer in the following format: +// Serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -314,7 +312,7 @@ func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { return n + 1, err } -// UnmarshalBinary deserialises the BitArray from a bytes buffer in the following format: +// Deserialises the BitArray from a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -328,6 +326,7 @@ func (b *BitArray) UnmarshalBinary(data []byte) { b.setBytes32(bs[:]) } +// Sets b to the same value as x. func (b *BitArray) Set(x *BitArray) *BitArray { b.len = x.len b.words[0] = x.words[0] @@ -337,40 +336,48 @@ func (b *BitArray) Set(x *BitArray) *BitArray { return b } +// Sets b to the bytes representation of a felt. func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { b.setFelt(f) b.len = length return b } +// Sets b to the bytes representation of a felt with length 251. func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { b.setFelt(f) b.len = 251 return b } +// Interprets the data as the big-endian bytes, sets b to that value and returns b. +// If the data is larger than 32 bytes, only the first 32 bytes are used. func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { b.setBytes32(data) b.len = length return b } +// Sets b to the uint64 representation of a bit array. func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data b.len = length return b } +// Returns the length of the encoded bit array in bytes. func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 } +// Returns a deep copy of b. func (b *BitArray) Copy() BitArray { var res BitArray res.Set(b) return res } +// Returns a string representation of the bit array. func (b *BitArray) String() string { return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index c90223ab6a..479df49fd1 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "encoding/binary" + "math" "math/bits" "testing" @@ -11,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + const ( ones63 = 0x7FFFFFFFFFFFFFFF )