Skip to content

Commit

Permalink
add MSBs() and rename Truncate to LSBs
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Dec 18, 2024
1 parent 95d759f commit 3e4d4c8
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 34 deletions.
39 changes: 24 additions & 15 deletions core/trie/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const (
)

var (
maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}
maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}
emptyBitArray = new(BitArray)
)

Expand Down Expand Up @@ -119,18 +119,18 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool {
return long.Rsh(long, long.len-short.len).Equal(short)
}

// Truncate sets b to the first 'length' bits of x (starting from the least significant bit).
// If length >= x.len, b is an exact copy of x.
// LSBs sets b to the least significant 'n' bits of x.
// If n >= x.len, b is an exact copy of x.
// Any bits beyond the specified length are cleared to zero.
// For example:
//
// x = 11001011 (len=8)
// Truncate(x, 4) = 1011 (len=4)
// Truncate(x, 10) = 11001011 (len=8, original x)
// Truncate(x, 0) = 0 (len=0)
// LSBs(x, 4) = 1011 (len=4)
// LSBs(x, 10) = 11001011 (len=8, original x)
// LSBs(x, 0) = 0 (len=0)
//
//nolint:mnd
func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray {
func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray {
if length >= x.len {
return b.Set(x)
}
Expand Down Expand Up @@ -165,6 +165,23 @@ func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray {
return b
}

// MSBs sets b to the most significant 'n' bits of x.
// If n >= x.len, b is an exact copy of x.
// Any bits beyond the specified length are cleared to zero.
// For example:
//
// x = 11001011 (len=8)
// MSBs(x, 4) = 1100 (len=4)
// MSBs(x, 10) = 11001011 (len=8, original x)
// MSBs(x, 0) = 0 (len=0)
func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray {
if n >= x.len {
return b.Set(x)
}

return b.Rsh(x, x.len-n)
}

// CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays.
// For example:
//
Expand Down Expand Up @@ -277,14 +294,6 @@ func (b *BitArray) IsBitSet(n uint8) bool {
return (b.words[n/64] & (1 << (n % 64))) != 0
}

func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray {
if n >= x.len {
return b.Set(x)
}

return b.Rsh(x, x.len-n)
}

// Write serialises the BitArray into a bytes buffer in the following format:
// - First byte: length of the bit array (0-255)
// - Remaining bytes: the necessary bytes included in big endian order
Expand Down
155 changes: 136 additions & 19 deletions core/trie/bitarray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"github.com/stretchr/testify/require"
)

const (
ones63 = 0x7FFFFFFFFFFFFFFF
)

func TestBytes(t *testing.T) {
tests := []struct {
name string
Expand All @@ -19,12 +23,12 @@ func TestBytes(t *testing.T) {
}{
{
name: "length == 0",
ba: BitArray{len: 0, words: maxBitArray},
ba: BitArray{len: 0, words: maxBits},
want: [32]byte{},
},
{
name: "length < 64",
ba: BitArray{len: 38, words: maxBitArray},
ba: BitArray{len: 38, words: maxBits},
want: func() [32]byte {
var b [32]byte
binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF)
Expand All @@ -33,7 +37,7 @@ func TestBytes(t *testing.T) {
},
{
name: "64 <= length < 128",
ba: BitArray{len: 100, words: maxBitArray},
ba: BitArray{len: 100, words: maxBits},
want: func() [32]byte {
var b [32]byte
binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF)
Expand All @@ -43,7 +47,7 @@ func TestBytes(t *testing.T) {
},
{
name: "128 <= length < 192",
ba: BitArray{len: 130, words: maxBitArray},
ba: BitArray{len: 130, words: maxBits},
want: func() [32]byte {
var b [32]byte
binary.BigEndian.PutUint64(b[8:16], 0x3)
Expand All @@ -54,7 +58,7 @@ func TestBytes(t *testing.T) {
},
{
name: "192 <= length < 255",
ba: BitArray{len: 201, words: maxBitArray},
ba: BitArray{len: 201, words: maxBits},
want: func() [32]byte {
var b [32]byte
binary.BigEndian.PutUint64(b[0:8], 0x1FF)
Expand All @@ -66,7 +70,7 @@ func TestBytes(t *testing.T) {
},
{
name: "length == 254",
ba: BitArray{len: 254, words: maxBitArray},
ba: BitArray{len: 254, words: maxBits},
want: func() [32]byte {
var b [32]byte
binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF)
Expand All @@ -78,10 +82,10 @@ func TestBytes(t *testing.T) {
},
{
name: "length == 255",
ba: BitArray{len: 255, words: maxBitArray},
ba: BitArray{len: 255, words: maxBits},
want: func() [32]byte {
var b [32]byte
binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF)
binary.BigEndian.PutUint64(b[0:8], ones63)
binary.BigEndian.PutUint64(b[8:16], maxUint64)
binary.BigEndian.PutUint64(b[16:24], maxUint64)
binary.BigEndian.PutUint64(b[24:32], maxUint64)
Expand Down Expand Up @@ -180,7 +184,7 @@ func TestRsh(t *testing.T) {
name: "shift by 127",
initial: &BitArray{
len: 255,
words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF},
words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63},
},
shiftBy: 127,
expected: &BitArray{
Expand Down Expand Up @@ -342,7 +346,7 @@ func TestPrefixEqual(t *testing.T) {
}
}

func TestTruncate(t *testing.T) {
func TestLSBs(t *testing.T) {
tests := []struct {
name string
initial BitArray
Expand Down Expand Up @@ -497,14 +501,127 @@ func TestTruncate(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := new(BitArray).Truncate(&tt.initial, tt.length)
result := new(BitArray).LSBs(&tt.initial, tt.length)
if !result.Equal(&tt.expected) {
t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected)
}
})
}
}

func TestMSBs(t *testing.T) {
tests := []struct {
name string
x *BitArray
n uint8
want *BitArray
}{
{
name: "empty array",
x: emptyBitArray,
n: 0,
want: emptyBitArray,
},
{
name: "get all bits",
x: &BitArray{
len: 64,
words: [4]uint64{maxUint64, 0, 0, 0},
},
n: 64,
want: &BitArray{
len: 64,
words: [4]uint64{maxUint64, 0, 0, 0},
},
},
{
name: "get more bits than available",
x: &BitArray{
len: 32,
words: [4]uint64{0xFFFFFFFF, 0, 0, 0},
},
n: 64,
want: &BitArray{
len: 32,
words: [4]uint64{0xFFFFFFFF, 0, 0, 0},
},
},
{
name: "get half of available bits",
x: &BitArray{
len: 64,
words: [4]uint64{maxUint64, 0, 0, 0},
},
n: 32,
want: &BitArray{
len: 32,
words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0},
},
},
{
name: "get MSBs across word boundary",
x: &BitArray{
len: 128,
words: [4]uint64{maxUint64, maxUint64, 0, 0},
},
n: 100,
want: &BitArray{
len: 100,
words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0},
},
},
{
name: "get MSBs from max length array",
x: &BitArray{
len: 255,
words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63},
},
n: 64,
want: &BitArray{
len: 64,
words: [4]uint64{maxUint64, 0, 0, 0},
},
},
{
name: "get zero bits",
x: &BitArray{
len: 64,
words: [4]uint64{maxUint64, 0, 0, 0},
},
n: 0,
want: &BitArray{
len: 0,
words: [4]uint64{0, 0, 0, 0},
},
},
{
name: "sparse bits",
x: &BitArray{
len: 128,
words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0},
},
n: 64,
want: &BitArray{
len: 64,
words: [4]uint64{0x5555555555555555, 0, 0, 0},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := new(BitArray).MSBs(tt.x, tt.n)
if !got.Equal(tt.want) {
t.Errorf("MSBs() = %v, want %v", got, tt.want)
}

if got.len != tt.want.len {
t.Errorf("MSBs() = %v, want %v", got, tt.want)
}
})
}
}

func TestWriteAndUnmarshalBinary(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -671,22 +788,22 @@ func TestCommonPrefix(t *testing.T) {
name: "different lengths with common prefix - multiple words",
x: &BitArray{
len: 255,
words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF},
words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63},
},
y: &BitArray{
len: 127,
words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0},
words: [4]uint64{maxUint64, ones63, 0, 0},
},
want: &BitArray{
len: 127,
words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0},
words: [4]uint64{maxUint64, ones63, 0, 0},
},
},
{
name: "different at first bit",
x: &BitArray{
len: 64,
words: [4]uint64{0x7FFFFFFFFFFFFFFF, 0, 0, 0},
words: [4]uint64{ones63, 0, 0, 0},
},
y: &BitArray{
len: 64,
Expand Down Expand Up @@ -776,7 +893,7 @@ func TestCommonPrefix(t *testing.T) {
name: "max length difference",
x: &BitArray{
len: 255,
words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF},
words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63},
},
y: &BitArray{
len: 1,
Expand Down Expand Up @@ -961,12 +1078,12 @@ func TestFeltConversion(t *testing.T) {
want: "0xffffffffffffffffffffffffffffffffffffffffffffffff",
},
{
name: "max length (251 bits)",
name: "251 bits",
ba: BitArray{
len: 255,
len: 251,
words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF},
},
length: 255,
length: 251,
want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
},
{
Expand Down

0 comments on commit 3e4d4c8

Please sign in to comment.