Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(zktrie): fix deletion proofs and collect them in commiting phase #263

Merged
merged 14 commits into from
May 8, 2023
Merged
86 changes: 86 additions & 0 deletions core/state/state_prove.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package state

import (
"errors"
"fmt"

zkt "github.com/scroll-tech/zktrie/types"

zktrie "github.com/scroll-tech/go-ethereum/trie"

"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/ethdb"
)

type TrieProve interface {
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error
}

type ZktrieProofTracer struct {
*zktrie.ProofTracer
}

// MarkDeletion overwrite the underlayer method with secure key
func (t ZktrieProofTracer) MarkDeletion(key common.Hash) {
key_s, _ := zkt.ToSecureKeyBytes(key.Bytes())
t.ProofTracer.MarkDeletion(key_s.Bytes())
}
noel2004 marked this conversation as resolved.
Show resolved Hide resolved

// Merge overwrite underlayer method with proper argument
func (t ZktrieProofTracer) Merge(another ZktrieProofTracer) {
t.ProofTracer.Merge(another.ProofTracer)
}

func (t ZktrieProofTracer) Available() bool {
return t.ProofTracer != nil
}

// NewProofTracer is not in Db interface and used explictily for reading proof in storage trie (not updated by the dirty value)
func (s *StateDB) NewProofTracer(trieS Trie) ZktrieProofTracer {
if s.IsZktrie() {
zkTrie := trieS.(*zktrie.ZkTrie)
if zkTrie == nil {
panic("unexpected trie type for zktrie")
}
return ZktrieProofTracer{zkTrie.NewProofTracer()}
}
return ZktrieProofTracer{}
}

// GetStorageTrieForProof is not in Db interface and used explictily for reading proof in storage trie (not updated by the dirty value)
func (s *StateDB) GetStorageTrieForProof(addr common.Address) (Trie, error) {

// try the trie in stateObject first, else we would create one
stateObject := s.getStateObject(addr)
if stateObject == nil {
return nil, errors.New("storage trie for requested address does not exist")
noel2004 marked this conversation as resolved.
Show resolved Hide resolved
}

trie := stateObject.trie
var err error
if trie == nil {
// use a new, temporary trie
trie, err = s.db.OpenStorageTrie(stateObject.addrHash, stateObject.data.Root)
noel2004 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, fmt.Errorf("can't create storage trie on root %s: %v ", stateObject.data.Root, err)
}
}

return trie, nil
}

// GetSecureTrieProof handle any interface with Prove (should be a Trie in most case) and
// deliver the proof in bytes
func (s *StateDB) GetSecureTrieProof(trieProve TrieProve, key common.Hash) ([][]byte, error) {

var proof proofList
var err error
if s.IsZktrie() {
key_s, _ := zkt.ToSecureKeyBytes(key.Bytes())
err = trieProve.Prove(key_s.Bytes(), 0, &proof)
} else {
err = trieProve.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
}
return proof, err
}
47 changes: 2 additions & 45 deletions core/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,56 +350,13 @@ func (s *StateDB) GetRootHash() common.Hash {
return s.trie.Hash()
}

// StorageTrieProof is not in Db interface and used explictily for reading proof in storage trie (not the dirty value)
// For zktrie it also provide required data for predict the deletion, else it just fallback to GetStorageProof
func (s *StateDB) GetStorageTrieProof(a common.Address, key common.Hash) ([][]byte, []byte, error) {

// try the trie in stateObject first, else we would create one
stateObject := s.getStateObject(a)
if stateObject == nil {
return nil, nil, errors.New("storage trie for requested address does not exist")
}

trieS := stateObject.trie
var err error
if trieS == nil {
// use a new, temporary trie
trieS, err = s.db.OpenStorageTrie(stateObject.addrHash, stateObject.data.Root)
if err != nil {
return nil, nil, fmt.Errorf("can't create storage trie on root %s: %v ", stateObject.data.Root, err)
}
}

var proof proofList
var sibling []byte
if s.IsZktrie() {
zkTrie := trieS.(*trie.ZkTrie)
if zkTrie == nil {
panic("unexpected trie type for zktrie")
}
key_s, _ := zkt.ToSecureKeyBytes(key.Bytes())
sibling, err = zkTrie.ProveWithDeletion(key_s.Bytes(), 0, &proof)
} else {
err = trieS.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
}
return proof, sibling, err
}

// GetStorageProof returns the Merkle proof for given storage slot.
func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) {
var proof proofList
trie := s.StorageTrie(a)
if trie == nil {
return proof, errors.New("storage trie for requested address does not exist")
return nil, errors.New("storage trie for requested address does not exist")
}
var err error
if s.IsZktrie() {
key_s, _ := zkt.ToSecureKeyBytes(key.Bytes())
err = trie.Prove(key_s.Bytes(), 0, &proof)
} else {
err = trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
}
return proof, err
return s.GetSecureTrieProof(trie, key)
}

// GetCommittedState retrieves a value from the given account's committed storage trie.
Expand Down
49 changes: 45 additions & 4 deletions eth/tracers/api_blocktrace.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tracers

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -42,6 +43,8 @@ type traceEnv struct {
// this lock is used to protect StorageTrace's read and write mutual exclusion.
sMu sync.Mutex
*types.StorageTrace
// zktrie tracer is used for zktrie storage to build additional deletion proof
zkTrieTracer map[string]state.ZktrieProofTracer
executionResults []*types.ExecutionResult
}

Expand Down Expand Up @@ -119,6 +122,7 @@ func (api *API) createTraceEnv(ctx context.Context, config *TraceConfig, block *
Proofs: make(map[string][]hexutil.Bytes),
StorageProofs: make(map[string]map[string][]hexutil.Bytes),
},
zkTrieTracer: make(map[string]state.ZktrieProofTracer),
executionResults: make([]*types.ExecutionResult, block.Transactions().Len()),
}

Expand Down Expand Up @@ -189,6 +193,18 @@ func (api *API) getBlockTrace(block *types.Block, env *traceEnv) (*types.BlockTr
close(jobs)
pend.Wait()

// after all tx has been traced, collect "deletion proof" for zktrie
for _, tracer := range env.zkTrieTracer {
delProofs, err := tracer.GetDeletionProofs()
if err != nil {
log.Error("deletion proof failure", "error", err)
noel2004 marked this conversation as resolved.
Show resolved Hide resolved
} else {
for _, proof := range delProofs {
env.DeletionProofs = append(env.DeletionProofs, proof)
}
Thegaram marked this conversation as resolved.
Show resolved Hide resolved
}
}

// If execution failed in between, abort
select {
case err := <-errCh:
Expand Down Expand Up @@ -299,22 +315,47 @@ func (api *API) getTxResult(env *traceEnv, state *state.StateDB, index int, bloc

proofStorages := tracer.UpdatedStorages()
for addr, keys := range proofStorages {
for key := range keys {
env.sMu.Lock()
trie, err := state.GetStorageTrieForProof(addr)
if err != nil {
// but we still continue to next address
noel2004 marked this conversation as resolved.
Show resolved Hide resolved
log.Error("Storage trie not available", "error", err, "address", addr)
env.sMu.Unlock()
continue
}
zktrieTracer := state.NewProofTracer(trie)
env.sMu.Unlock()

for key, values := range keys {
addrStr := addr.String()
keyStr := key.String()
isDelete := bytes.Equal(values.Bytes(), common.Hash{}.Bytes())
noel2004 marked this conversation as resolved.
Show resolved Hide resolved

env.sMu.Lock()
m, existed := env.StorageProofs[addrStr]
if !existed {
m = make(map[string][]hexutil.Bytes)
env.StorageProofs[addrStr] = m
if zktrieTracer.Available() {
env.zkTrieTracer[addrStr] = zktrieTracer
}
} else if _, existed := m[keyStr]; existed {
// still need to touch tracer for deletion
if isDelete && zktrieTracer.Available() {
env.zkTrieTracer[addrStr].MarkDeletion(key)
}
env.sMu.Unlock()
continue
}
env.sMu.Unlock()

proof, sibling, err := state.GetStorageTrieProof(addr, key)
var proof [][]byte
var err error
if zktrieTracer.Available() {
proof, err = state.GetSecureTrieProof(zktrieTracer, key)
} else {
proof, err = state.GetSecureTrieProof(trie, key)
}
noel2004 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
log.Error("Storage proof not available", "error", err, "address", addrStr, "key", keyStr)
// but we still mark the proofs map with nil array
Expand All @@ -325,8 +366,8 @@ func (api *API) getTxResult(env *traceEnv, state *state.StateDB, index int, bloc
}
env.sMu.Lock()
m[keyStr] = wrappedProof
if sibling != nil {
env.DeletionProofs = append(env.DeletionProofs, sibling)
if zktrieTracer.Available() {
env.zkTrieTracer[addrStr].Merge(zktrieTracer)
noel2004 marked this conversation as resolved.
Show resolved Hide resolved
}
env.sMu.Unlock()
}
Expand Down
52 changes: 16 additions & 36 deletions trie/zk_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,49 +174,29 @@ func (t *ZkTrie) NodeIterator(start []byte) NodeIterator {
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
func (t *ZkTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error {
// omit sibling, which is not required for proving only
_, err := t.ProveWithDeletion(key, fromLevel, proofDb)
return err
}

// ProveWithDeletion is the implement of Prove, it also return possible sibling node
// (if there is, i.e. the node of key exist and is not the only node in trie)
// so witness generator can predict the final state root after deletion of this key
// the returned sibling node has no key along with it for witness generator must decode
// the node for its purpose
func (t *ZkTrie) ProveWithDeletion(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) (sibling []byte, err error) {
err = t.ZkTrie.ProveWithDeletion(key, fromLevel,
func(n *zktrie.Node) error {
nodeHash, err := n.NodeHash()
if err != nil {
return err
}
err := t.ZkTrie.Prove(key, fromLevel, func(n *zktrie.Node) error {
nodeHash, err := n.NodeHash()
if err != nil {
return err
}

if n.Type == zktrie.NodeTypeLeaf {
preImage := t.GetKey(n.NodeKey.Bytes())
if len(preImage) > 0 {
n.KeyPreimage = &zkt.Byte32{}
copy(n.KeyPreimage[:], preImage)
//return fmt.Errorf("key preimage not found for [%x] ref %x", n.NodeKey.Bytes(), k.Bytes())
}
}
return proofDb.Put(nodeHash[:], n.Value())
},
func(_ *zktrie.Node, n *zktrie.Node) {
// the sibling for each leaf should be unique except for EmptyNode
if n != nil && n.Type != zktrie.NodeTypeEmpty {
sibling = n.Value()
if n.Type == zktrie.NodeTypeLeaf {
preImage := t.GetKey(n.NodeKey.Bytes())
if len(preImage) > 0 {
n.KeyPreimage = &zkt.Byte32{}
copy(n.KeyPreimage[:], preImage)
//return fmt.Errorf("key preimage not found for [%x] ref %x", n.NodeKey.Bytes(), k.Bytes())
}
},
)
}
return proofDb.Put(nodeHash[:], n.Value())
})
if err != nil {
return
return err
}

// we put this special kv pair in db so we can distinguish the type and
// make suitable Proof
err = proofDb.Put(magicHash, zktrie.ProofMagicBytes())
return
return proofDb.Put(magicHash, zktrie.ProofMagicBytes())
}

// VerifyProof checks merkle proofs. The given proof must contain the value for
Expand Down
60 changes: 53 additions & 7 deletions trie/zk_trie_proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func randomZktrie(t *testing.T, n int) (*ZkTrie, map[string]*kv) {
return tr, vals
}

// Tests that new "proof with deletion" feature
// Tests that new "proof trace" feature
func TestProofWithDeletion(t *testing.T) {
tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New())))
mt := &zkTrieImplTestWrapper{tr.Tree()}
Expand All @@ -217,20 +217,66 @@ func TestProofWithDeletion(t *testing.T) {
s_key1, err := zkt.ToSecureKeyBytes(key1)
assert.NoError(t, err)

sibling1, err := tr.ProveWithDeletion(s_key1.Bytes(), 0, proof)
proofTracer := tr.NewProofTracer()

err = proofTracer.Prove(s_key1.Bytes(), 0, proof)
assert.NoError(t, err)
nd, err := tr.TryGet(key2)
assert.NoError(t, err)
l := len(sibling1)

s_key2, err := zkt.ToSecureKeyBytes(bytes.Repeat([]byte("x"), 32))
assert.NoError(t, err)

err = proofTracer.Prove(s_key2.Bytes(), 0, proof)
assert.NoError(t, err)
// assert.Equal(t, len(sibling1), len(delTracer.GetProofs()))

siblings, err := proofTracer.GetDeletionProofs()
assert.NoError(t, err)
assert.Equal(t, 0, len(siblings))

proofTracer.MarkDeletion(s_key1.Bytes())
siblings, err = proofTracer.GetDeletionProofs()
assert.NoError(t, err)
assert.Equal(t, 1, len(siblings))
l := len(siblings[0])
// a hacking to grep the value part directly from the encoded leaf node,
// notice the sibling of key `k*32`` is just the leaf of key `m*32`
assert.Equal(t, sibling1[l-33:l-1], nd)
assert.Equal(t, siblings[0][l-33:l-1], nd)

s_key2, err := zkt.ToSecureKeyBytes(bytes.Repeat([]byte("x"), 32))
// no effect
proofTracer.MarkDeletion(s_key2.Bytes())
siblings, err = proofTracer.GetDeletionProofs()
assert.NoError(t, err)
assert.Equal(t, 1, len(siblings))

key3 := bytes.Repeat([]byte("x"), 32)
err = mt.UpdateWord(
zkt.NewByte32FromBytesPaddingZero(key3),
zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("z"), 32)),
)
assert.NoError(t, err)

proofTracer = tr.NewProofTracer()
err = proofTracer.Prove(s_key1.Bytes(), 0, proof)
assert.NoError(t, err)
err = proofTracer.Prove(s_key2.Bytes(), 0, proof)
assert.NoError(t, err)

proofTracer.MarkDeletion(s_key1.Bytes())
siblings, err = proofTracer.GetDeletionProofs()
assert.NoError(t, err)
assert.Equal(t, 1, len(siblings))

sibling2, err := tr.ProveWithDeletion(s_key2.Bytes(), 0, proof)
proofTracer.MarkDeletion(s_key2.Bytes())
siblings, err = proofTracer.GetDeletionProofs()
assert.NoError(t, err)
assert.Nil(t, sibling2)
assert.Equal(t, 2, len(siblings))

// one of the siblings is just leaf for key2, while
// another one must be a middle node
match1 := bytes.Equal(siblings[0][l-33:l-1], nd)
match2 := bytes.Equal(siblings[1][l-33:l-1], nd)
assert.True(t, match1 || match2)
assert.False(t, match1 && match2)
}
Loading