Skip to content

Commit

Permalink
update trace api
Browse files Browse the repository at this point in the history
  • Loading branch information
joshklop committed Nov 14, 2023
1 parent 623a24c commit dfee656
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 119 deletions.
8 changes: 4 additions & 4 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions node/throttled_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ func (tvm *ThrottledVM) Call(contractAddr, selector *felt.Felt, calldata []felt.

func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee bool, gasPrice *felt.Felt, legacyTraceJSON bool,
skipChargeFee bool, gasPrice *felt.Felt,
) ([]*felt.Felt, []json.RawMessage, error) {
var ret []*felt.Felt
var traces []json.RawMessage
throttler := (*utils.Throttler[vm.VM])(tvm)
return ret, traces, throttler.Do(func(vm *vm.VM) error {
var err error
ret, traces, err = (*vm).Execute(txns, declaredClasses, blockNumber, blockTimestamp, sequencerAddress,
state, network, paidFeesOnL1, skipChargeFee, gasPrice, legacyTraceJSON)
state, network, paidFeesOnL1, skipChargeFee, gasPrice)
return err
})
}
51 changes: 14 additions & 37 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ const (

type traceCacheKey struct {
blockHash felt.Felt
legacy bool
}

type Handler struct {
Expand Down Expand Up @@ -1236,14 +1235,10 @@ func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *
// It follows the specification defined here:
// https://github.com/starkware-libs/starknet-specs/blob/1ae810e0137cc5d175ace4554892a4f43052be56/api/starknet_trace_api_openrpc.json#L11
func (h *Handler) TraceTransaction(ctx context.Context, hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
return h.traceTransaction(ctx, &hash, false)
return h.traceTransaction(ctx, &hash)
}

func (h *Handler) LegacyTraceTransaction(ctx context.Context, hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
return h.traceTransaction(ctx, &hash, true)
}

func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, legacyTraceJSON bool) (json.RawMessage, *jsonrpc.Error) {
func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt) (json.RawMessage, *jsonrpc.Error) {
_, _, blockNumber, err := h.bcReader.Receipt(hash)
if err != nil {
return nil, ErrInvalidTxHash
Expand All @@ -1261,7 +1256,7 @@ func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, legacyT
return nil, ErrTxnHashNotFound
}

traceResults, traceBlockErr := h.traceBlockTransactions(ctx, block, legacyTraceJSON)
traceResults, traceBlockErr := h.traceBlockTransactions(ctx, block)
if traceBlockErr != nil {
return nil, traceBlockErr
}
Expand All @@ -1272,17 +1267,11 @@ func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, legacyT
func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags, false)
}

func (h *Handler) LegacySimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags, true)
return h.simulateTransactions(id, transactions, simulationFlags)
}

func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag, legacyTraceJSON bool,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
if slices.Contains(simulationFlags, SkipValidateFlag) {
return nil, jsonrpc.Err(jsonrpc.InvalidParams, "Skip validate is not supported")
Expand Down Expand Up @@ -1334,7 +1323,7 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra
sequencerAddress = core.NetworkBlockHashMetaInfo(h.network).FallBackSequencerAddress
}
overallFees, traces, err := h.vm.Execute(txns, classes, blockNumber, header.Timestamp, sequencerAddress,
state, h.network, paidFeesOnL1, skipFeeCharge, header.GasPrice, legacyTraceJSON)
state, h.network, paidFeesOnL1, skipFeeCharge, header.GasPrice)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrUnexpectedError.CloneWithData(err.Error())
Expand Down Expand Up @@ -1364,35 +1353,24 @@ func (h *Handler) TraceBlockTransactions(ctx context.Context, id BlockID) ([]Tra
return nil, ErrBlockNotFound
}

return h.traceBlockTransactions(ctx, block, false)
}

func (h *Handler) LegacyTraceBlockTransactions(ctx context.Context, hash felt.Felt) ([]TracedBlockTransaction, *jsonrpc.Error) {
block, err := h.bcReader.BlockByHash(&hash)
if err != nil {
return nil, ErrInvalidBlockHash
}

return h.traceBlockTransactions(ctx, block, true)
return h.traceBlockTransactions(ctx, block)
}

var traceFallbackVersion = semver.MustParse("0.12.2")

func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, //nolint: gocyclo
legacyJSON bool,
) ([]TracedBlockTransaction, *jsonrpc.Error) {
isPending := block.Hash == nil
if !isPending {
if blockVer, err := core.ParseBlockVersion(block.ProtocolVersion); err != nil {
return nil, ErrUnexpectedError.CloneWithData(err.Error())
} else if blockVer.Compare(traceFallbackVersion) != 1 {
// version <= 0.12.2
return h.fetchTraces(ctx, block.Hash, legacyJSON)
return h.fetchTraces(ctx, block.Hash)
}

if trace, hit := h.blockTraceCache.Get(traceCacheKey{
blockHash: *block.Hash,
legacy: legacyJSON,
}); hit {
return trace, nil
}
Expand Down Expand Up @@ -1452,7 +1430,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block,
}

_, traces, err := h.vm.Execute(block.Transactions, classes, blockNumber, header.Timestamp,
sequencerAddress, state, h.network, paidFeesOnL1, false, header.GasPrice, legacyJSON)
sequencerAddress, state, h.network, paidFeesOnL1, false, header.GasPrice)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrUnexpectedError.CloneWithData(err.Error())
Expand All @@ -1471,14 +1449,13 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block,
if !isPending {
h.blockTraceCache.Add(traceCacheKey{
blockHash: *block.Hash,
legacy: legacyJSON,
}, result)
}

return result, nil
}

func (h *Handler) fetchTraces(ctx context.Context, blockHash *felt.Felt, legacyTrace bool) ([]TracedBlockTransaction, *jsonrpc.Error) {
func (h *Handler) fetchTraces(ctx context.Context, blockHash *felt.Felt) ([]TracedBlockTransaction, *jsonrpc.Error) {
rpcBlock, err := h.BlockWithTxs(BlockID{
Hash: blockHash, // known non-nil
})
Expand All @@ -1491,7 +1468,7 @@ func (h *Handler) fetchTraces(ctx context.Context, blockHash *felt.Felt, legacyT
return nil, ErrUnexpectedError.CloneWithData(fErr.Error())
}

traces, aErr := adaptBlockTrace(rpcBlock, blockTrace, legacyTrace)
traces, aErr := adaptBlockTrace(rpcBlock, blockTrace)
if aErr != nil {
return nil, ErrUnexpectedError.CloneWithData(aErr.Error())
}
Expand Down Expand Up @@ -1861,17 +1838,17 @@ func (h *Handler) LegacyMethods() ([]jsonrpc.Method, string) { //nolint: funlen
{
Name: "starknet_traceTransaction",
Params: []jsonrpc.Parameter{{Name: "transaction_hash"}},
Handler: h.TraceTransaction, // TODO legacy meaning needs to change
Handler: h.TraceTransaction,
},
{
Name: "starknet_simulateTransactions",
Params: []jsonrpc.Parameter{{Name: "block_id"}, {Name: "transactions"}, {Name: "simulation_flags"}},
Handler: h.SimulateTransactions, // TODO legacy meaning needs to change
Handler: h.SimulateTransactions,
},
{
Name: "starknet_traceBlockTransactions",
Params: []jsonrpc.Parameter{{Name: "block_id"}},
Handler: h.TraceBlockTransactions, // TODO legacy meaning needs to change
Handler: h.TraceBlockTransactions,
},
{
Name: "starknet_specVersion",
Expand Down
12 changes: 6 additions & 6 deletions rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2038,10 +2038,10 @@ func TestEstimateMessageFee(t *testing.T) {

expectedGasConsumed := new(felt.Felt).SetUint64(37)
mockVM.EXPECT().Execute(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), utils.MAINNET, gomock.Any(), gomock.Any(), latestHeader.GasPrice, gomock.Any()).DoAndReturn(
gomock.Any(), utils.MAINNET, gomock.Any(), gomock.Any(), latestHeader.GasPrice).DoAndReturn(
func(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee bool, gasPrice *felt.Felt, legacyTraceJson bool,
skipChargeFee bool, gasPrice *felt.Felt,
) ([]*felt.Felt, []json.RawMessage, error) {
require.Len(t, txns, 1)
assert.NotNil(t, txns[0].(*core.L1HandlerTransaction))
Expand Down Expand Up @@ -2121,7 +2121,7 @@ func TestTraceTransaction(t *testing.T) {
"fee_transfer_invocation": {"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3", "entry_point_type": "EXTERNAL", "call_type": "CALL", "result": ["0x1"], "calls": [{"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0", "entry_point_type": "EXTERNAL", "call_type": "DELEGATE", "result": ["0x1"], "calls": [], "events": [{"keys": ["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"], "data": ["0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"]}], "messages": []}], "events": [], "messages": []}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
nil, utils.MAINNET, []*felt.Felt{}, false, gomock.Any(), false).Return(nil, []json.RawMessage{vmTrace}, nil)
nil, utils.MAINNET, []*felt.Felt{}, false, gomock.Any()).Return(nil, []json.RawMessage{vmTrace}, nil)

trace, err := handler.TraceTransaction(context.Background(), *hash)
require.Nil(t, err)
Expand Down Expand Up @@ -2151,7 +2151,7 @@ func TestSimulateTransactions(t *testing.T) {
mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil)

sequencerAddress := core.NetworkBlockHashMetaInfo(network).FallBackSequencerAddress
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, nil, false).
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, nil).
Return([]*felt.Felt{}, []json.RawMessage{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
Expand Down Expand Up @@ -2220,7 +2220,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation": {}
}`)
mockVM.EXPECT().Execute(block.Transactions, []core.Class{declaredClass.Class}, height+1, header.Timestamp, sequencerAddress,
state, network, paidL1Fees, false, header.GasPrice, false).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)
state, network, paidL1Fees, false, header.GasPrice).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)

result, err := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
require.Nil(t, err)
Expand Down Expand Up @@ -2264,7 +2264,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation":{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3","entry_point_type":"EXTERNAL","call_type":"CALL","result":["0x1"],"calls":[{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0","entry_point_type":"EXTERNAL","call_type":"DELEGATE","result":["0x1"],"calls":[],"events":[{"keys":["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"],"data":["0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"]}],"messages":[]}],"events":[],"messages":[]}}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
nil, network, []*felt.Felt{}, false, header.GasPrice, false).Return(nil, []json.RawMessage{vmTrace}, nil)
nil, network, []*felt.Felt{}, false, header.GasPrice).Return(nil, []json.RawMessage{vmTrace}, nil)

expectedResult := []rpc.TracedBlockTransaction{
{
Expand Down
27 changes: 9 additions & 18 deletions rpc/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type OrderedL2toL1Message struct {
MsgToL1
}

func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace, legacyTrace bool) ([]TracedBlockTransaction, error) {
func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace) ([]TracedBlockTransaction, error) {
if blockTrace == nil {
return nil, nil
}
Expand All @@ -60,14 +60,12 @@ func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace, legac
for index := range blockTrace.Traces {
feederTrace := &blockTrace.Traces[index]
trace := TransactionTrace{}
if !legacyTrace {
trace.Type = block.Transactions[index].Type
}
trace.Type = block.Transactions[index].Type

trace.FeeTransferInvocation = adaptFunctionInvocation(feederTrace.FeeTransferInvocation, legacyTrace)
trace.ValidateInvocation = adaptFunctionInvocation(feederTrace.ValidateInvocation, legacyTrace)
trace.FeeTransferInvocation = adaptFunctionInvocation(feederTrace.FeeTransferInvocation)
trace.ValidateInvocation = adaptFunctionInvocation(feederTrace.ValidateInvocation)

fnInvocation := adaptFunctionInvocation(feederTrace.FunctionInvocation, legacyTrace)
fnInvocation := adaptFunctionInvocation(feederTrace.FunctionInvocation)
switch block.Transactions[index].Type {
case TxnDeploy:
trace.ConstructorInvocation = fnInvocation
Expand Down Expand Up @@ -97,18 +95,11 @@ func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace, legac
return traces, nil
}

func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation, legacyTrace bool) *FunctionInvocation {
func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation) *FunctionInvocation {
if snFnInvocation == nil {
return nil
}

orderPtr := func(o uint64) *uint64 {
if legacyTrace {
return nil
}
return &o
}

fnInvocation := FunctionInvocation{
ContractAddress: snFnInvocation.ContractAddress,
EntryPointSelector: snFnInvocation.Selector,
Expand All @@ -123,12 +114,12 @@ func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation, legacy
Messages: make([]OrderedL2toL1Message, 0, len(snFnInvocation.Messages)),
}
for index := range snFnInvocation.InternalCalls {
fnInvocation.Calls = append(fnInvocation.Calls, *adaptFunctionInvocation(&snFnInvocation.InternalCalls[index], legacyTrace))
fnInvocation.Calls = append(fnInvocation.Calls, *adaptFunctionInvocation(&snFnInvocation.InternalCalls[index]))
}
for index := range snFnInvocation.Events {
snEvent := &snFnInvocation.Events[index]
fnInvocation.Events = append(fnInvocation.Events, OrderedEvent{
Order: orderPtr(snEvent.Order),
Order: &snEvent.Order,
Event: Event{
Keys: utils.Map(snEvent.Keys, utils.Ptr[felt.Felt]),
Data: utils.Map(snEvent.Data, utils.Ptr[felt.Felt]),
Expand All @@ -138,7 +129,7 @@ func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation, legacy
for index := range snFnInvocation.Messages {
snMessage := &snFnInvocation.Messages[index]
fnInvocation.Messages = append(fnInvocation.Messages, OrderedL2toL1Message{
Order: orderPtr(snMessage.Order),
Order: &snMessage.Order,
MsgToL1: MsgToL1{
Payload: utils.Map(snMessage.Payload, utils.Ptr[felt.Felt]),
To: common.HexToAddress(snMessage.ToAddr),
Expand Down
36 changes: 0 additions & 36 deletions vm/rust/src/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,28 +94,6 @@ struct DeclaredClass {
compiled_class_hash: StarkFelt,
}

impl TransactionTrace {
pub fn make_legacy(&mut self) {
self.state_diff = None;
self.r#type = None;
if let Some(invocation) = &mut self.validate_invocation {
invocation.make_legacy()
}
if let Some(ExecuteInvocation::Ok(fn_invocation)) = &mut self.execute_invocation {
fn_invocation.make_legacy()
}
if let Some(invocation) = &mut self.fee_transfer_invocation {
invocation.make_legacy()
}
if let Some(invocation) = &mut self.constructor_invocation {
invocation.make_legacy()
}
if let Some(invocation) = &mut self.function_invocation {
invocation.make_legacy()
}
}
}

#[derive(Serialize)]
#[serde(untagged)]
pub enum ExecuteInvocation {
Expand Down Expand Up @@ -257,20 +235,6 @@ pub struct FunctionInvocation {
pub execution_resources: ExecutionResources,
}

impl FunctionInvocation {
fn make_legacy(&mut self) {
for indx in 0..self.events.len() {
self.events[indx].order = None;
}
for indx in 0..self.messages.len() {
self.messages[indx].order = None;
}
for indx in 0..self.calls.len() {
self.calls[indx].make_legacy();
}
}
}

type BlockifierCallInfo = blockifier::execution::entry_point::CallInfo;
impl From<BlockifierCallInfo> for FunctionInvocation {
fn from(val: BlockifierCallInfo) -> Self {
Expand Down
Loading

0 comments on commit dfee656

Please sign in to comment.