Skip to content

Commit

Permalink
fix trace
Browse files Browse the repository at this point in the history
  • Loading branch information
joshklop committed Nov 23, 2023
1 parent dbf939f commit 0441098
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 34 deletions.
8 changes: 4 additions & 4 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions node/throttled_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ func (tvm *ThrottledVM) Call(contractAddr, classHash, selector *felt.Felt, calld

func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
) ([]*felt.Felt, []json.RawMessage, error) {
var ret []*felt.Felt
var traces []json.RawMessage
throttler := (*utils.Throttler[vm.VM])(tvm)
return ret, traces, throttler.Do(func(vm *vm.VM) error {
var err error
ret, traces, err = (*vm).Execute(txns, declaredClasses, blockNumber, blockTimestamp, sequencerAddress,
state, network, paidFeesOnL1, skipChargeFee, skipValidate, gasPriceWEI, gasPriceSTRK)
state, network, paidFeesOnL1, skipChargeFee, skipValidate, gasPriceWEI, gasPriceSTRK, legacyTraceJSON)
return err
})
}
47 changes: 35 additions & 12 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1316,10 +1316,18 @@ func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *
// It follows the specification defined here:
// https://github.com/starkware-libs/starknet-specs/blob/1ae810e0137cc5d175ace4554892a4f43052be56/api/starknet_trace_api_openrpc.json#L11
func (h *Handler) TraceTransaction(ctx context.Context, hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
return h.traceTransaction(ctx, &hash)
return h.traceTransaction(ctx, &hash, false)
}

func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt) (json.RawMessage, *jsonrpc.Error) {
// TraceTransaction returns the trace for a given executed transaction, including internal calls
//
// It follows the specification defined here:
// https://github.com/starkware-libs/starknet-specs/blob/1ae810e0137cc5d175ace4554892a4f43052be56/api/starknet_trace_api_openrpc.json#L11
func (h *Handler) LegacyTraceTransaction(ctx context.Context, hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
return h.traceTransaction(ctx, &hash, true)
}

func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, legacyTraceJSON bool) (json.RawMessage, *jsonrpc.Error) {
_, _, blockNumber, err := h.bcReader.Receipt(hash)
if err != nil {
return nil, ErrInvalidTxHash
Expand All @@ -1337,7 +1345,7 @@ func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt) (json.R
return nil, ErrTxnHashNotFound
}

traceResults, traceBlockErr := h.traceBlockTransactions(ctx, block)
traceResults, traceBlockErr := h.traceBlockTransactions(ctx, block, legacyTraceJSON)
if traceBlockErr != nil {
return nil, traceBlockErr
}
Expand All @@ -1348,11 +1356,17 @@ func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt) (json.R
func (h *Handler) SimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags)
return h.simulateTransactions(id, transactions, simulationFlags, false)
}

func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTransaction,
func (h *Handler) LegacySimulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag,
) ([]SimulatedTransaction, *jsonrpc.Error) {
return h.simulateTransactions(id, transactions, simulationFlags, true)
}

func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTransaction,
simulationFlags []SimulationFlag, legacyTraceJSON bool,
) ([]SimulatedTransaction, *jsonrpc.Error) {
skipFeeCharge := slices.Contains(simulationFlags, SkipFeeChargeFlag)
skipValidate := slices.Contains(simulationFlags, SkipValidateFlag)
Expand Down Expand Up @@ -1402,7 +1416,7 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra
sequencerAddress = core.NetworkBlockHashMetaInfo(h.network).FallBackSequencerAddress
}
overallFees, traces, err := h.vm.Execute(txns, classes, blockNumber, header.Timestamp, sequencerAddress,
state, h.network, paidFeesOnL1, skipFeeCharge, skipValidate, header.GasPrice, header.GasPriceSTRK)
state, h.network, paidFeesOnL1, skipFeeCharge, skipValidate, header.GasPrice, header.GasPriceSTRK, legacyTraceJSON)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrUnexpectedError.CloneWithData(err.Error())
Expand Down Expand Up @@ -1432,7 +1446,16 @@ func (h *Handler) TraceBlockTransactions(ctx context.Context, id BlockID) ([]Tra
return nil, ErrBlockNotFound
}

return h.traceBlockTransactions(ctx, block)
return h.traceBlockTransactions(ctx, block, false)
}

func (h *Handler) LegacyTraceBlockTransactions(ctx context.Context, id BlockID) ([]TracedBlockTransaction, *jsonrpc.Error) {
block, err := h.blockByID(&id)
if block == nil || err != nil {
return nil, ErrBlockNotFound
}

return h.traceBlockTransactions(ctx, block, true)
}

var traceFallbackVersion = semver.MustParse("0.12.2")
Expand All @@ -1450,7 +1473,7 @@ func prependBlockHashToState(bc blockchain.Reader, blockNumber uint64, state cor
), nil
}

func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, //nolint: gocyclo
func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, legacyTraceJSON bool, //nolint: gocyclo
) ([]TracedBlockTransaction, *jsonrpc.Error) {
isPending := block.Hash == nil
if !isPending {
Expand Down Expand Up @@ -1522,7 +1545,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block,
}

_, traces, err := h.vm.Execute(block.Transactions, classes, blockNumber, block.Header.Timestamp,
sequencerAddress, state, h.network, paidFeesOnL1, false, false, block.Header.GasPrice, block.Header.GasPriceSTRK)
sequencerAddress, state, h.network, paidFeesOnL1, false, false, block.Header.GasPrice, block.Header.GasPriceSTRK, legacyTraceJSON)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrUnexpectedError.CloneWithData(err.Error())
Expand Down Expand Up @@ -1932,17 +1955,17 @@ func (h *Handler) LegacyMethods() ([]jsonrpc.Method, string) { //nolint: funlen
{
Name: "starknet_traceTransaction",
Params: []jsonrpc.Parameter{{Name: "transaction_hash"}},
Handler: h.TraceTransaction,
Handler: h.LegacyTraceTransaction,
},
{
Name: "starknet_simulateTransactions",
Params: []jsonrpc.Parameter{{Name: "block_id"}, {Name: "transactions"}, {Name: "simulation_flags"}},
Handler: h.SimulateTransactions,
Handler: h.LegacySimulateTransactions,
},
{
Name: "starknet_traceBlockTransactions",
Params: []jsonrpc.Parameter{{Name: "block_id"}},
Handler: h.TraceBlockTransactions,
Handler: h.LegacyTraceBlockTransactions,
},
{
Name: "starknet_specVersion",
Expand Down
14 changes: 7 additions & 7 deletions rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2810,10 +2810,10 @@ func TestEstimateMessageFee(t *testing.T) {

expectedGasConsumed := new(felt.Felt).SetUint64(37)
mockVM.EXPECT().Execute(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), utils.Mainnet, gomock.Any(), gomock.Any(), gomock.Any(), latestHeader.GasPrice, latestHeader.GasPriceSTRK).DoAndReturn(
gomock.Any(), utils.Mainnet, gomock.Any(), gomock.Any(), gomock.Any(), latestHeader.GasPrice, latestHeader.GasPriceSTRK, false).DoAndReturn(
func(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWEI *felt.Felt, _ *felt.Felt,
skipChargeFee, skipValidate bool, gasPriceWEI *felt.Felt, _ *felt.Felt, _ bool,
) ([]*felt.Felt, []json.RawMessage, error) {
require.Len(t, txns, 1)
assert.NotNil(t, txns[0].(*core.L1HandlerTransaction))
Expand Down Expand Up @@ -2893,7 +2893,7 @@ func TestTraceTransaction(t *testing.T) {
"fee_transfer_invocation": {"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3", "entry_point_type": "EXTERNAL", "call_type": "CALL", "result": ["0x1"], "calls": [{"contract_address": "0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", "entry_point_selector": "0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e", "calldata": ["0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"], "caller_address": "0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "class_hash": "0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0", "entry_point_type": "EXTERNAL", "call_type": "DELEGATE", "result": ["0x1"], "calls": [], "events": [{"keys": ["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"], "data": ["0xd747220b2744d8d8d48c8a52bd3869fb98aea915665ab2485d5eadb49def6a", "0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8", "0x2cb6", "0x0"]}], "messages": []}], "events": [], "messages": []}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
gomock.Any(), utils.Mainnet, []*felt.Felt{}, false, false, gomock.Any(), gomock.Any()).Return(nil, []json.RawMessage{vmTrace}, nil)
gomock.Any(), utils.Mainnet, []*felt.Felt{}, false, false, gomock.Any(), gomock.Any(), false).Return(nil, []json.RawMessage{vmTrace}, nil)

trace, err := handler.TraceTransaction(context.Background(), *hash)
require.Nil(t, err)
Expand All @@ -2920,7 +2920,7 @@ func TestSimulateTransactions(t *testing.T) {
mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil)

sequencerAddress := core.NetworkBlockHashMetaInfo(network).FallBackSequencerAddress
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, false, nil, nil).
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, true, false, nil, nil, false).
Return([]*felt.Felt{}, []json.RawMessage{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
Expand All @@ -2935,7 +2935,7 @@ func TestSimulateTransactions(t *testing.T) {
mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil)

sequencerAddress := core.NetworkBlockHashMetaInfo(network).FallBackSequencerAddress
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, false, true, nil, nil).
mockVM.EXPECT().Execute(nil, nil, uint64(0), uint64(0), sequencerAddress, mockState, network, []*felt.Felt{}, false, true, nil, nil, false).
Return([]*felt.Felt{}, []json.RawMessage{}, nil)

_, err := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipValidateFlag})
Expand Down Expand Up @@ -3004,7 +3004,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation": {}
}`)
mockVM.EXPECT().Execute(block.Transactions, []core.Class{declaredClass.Class}, height+1, header.Timestamp, sequencerAddress,
gomock.Any(), network, paidL1Fees, false, false, header.GasPrice, header.GasPriceSTRK).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)
gomock.Any(), network, paidL1Fees, false, false, header.GasPrice, header.GasPriceSTRK, false).Return(nil, []json.RawMessage{vmTrace, vmTrace}, nil)

result, err := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
require.Nil(t, err)
Expand Down Expand Up @@ -3048,7 +3048,7 @@ func TestTraceBlockTransactions(t *testing.T) {
"fee_transfer_invocation":{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0xd0e183745e9dae3e4e78a8ffedcce0903fc4900beace4e0abf192d4c202da3","entry_point_type":"EXTERNAL","call_type":"CALL","result":["0x1"],"calls":[{"entry_point_selector":"0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e","calldata":["0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"],"caller_address":"0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","class_hash":"0x2760f25d5a4fb2bdde5f561fd0b44a3dee78c28903577d37d669939d97036a0","entry_point_type":"EXTERNAL","call_type":"DELEGATE","result":["0x1"],"calls":[],"events":[{"keys":["0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9"],"data":["0xdac9bcffb3d967f19a7fe21002c98c984d5a9458a88e6fc5d1c478a97ed412","0x5dcd266a80b8a5f29f04d779c6b166b80150c24f2180a75e82427242dab20a9","0x15be","0x0"]}],"messages":[]}],"events":[],"messages":[]}}
}`)
mockVM.EXPECT().Execute([]core.Transaction{tx}, []core.Class{declaredClass.Class}, header.Number, header.Timestamp, header.SequencerAddress,
gomock.Any(), network, []*felt.Felt{}, false, false, header.GasPrice, header.GasPriceSTRK).Return(nil, []json.RawMessage{vmTrace}, nil)
gomock.Any(), network, []*felt.Felt{}, false, false, header.GasPrice, header.GasPriceSTRK, false).Return(nil, []json.RawMessage{vmTrace}, nil)

expectedResult := []rpc.TracedBlockTransaction{
{
Expand Down
37 changes: 35 additions & 2 deletions vm/rust/src/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ pub struct TransactionTrace {
state_diff: Option<StateDiff>,
}


impl TransactionTrace {
pub fn make_legacy(&mut self) {
self.state_diff = None;
self.r#type = None;
if let Some(invocation) = &mut self.validate_invocation {
invocation.make_legacy()
}
if let Some(ExecuteInvocation::Ok(fn_invocation)) = &mut self.execute_invocation {
fn_invocation.make_legacy()
}
if let Some(invocation) = &mut self.fee_transfer_invocation {
invocation.make_legacy()
}
if let Some(invocation) = &mut self.constructor_invocation {
invocation.make_legacy()
}
if let Some(invocation) = &mut self.function_invocation {
invocation.make_legacy()
}
}
}

#[derive(Serialize)]
struct StateDiff {
storage_diffs: Vec<StorageDiff>,
Expand Down Expand Up @@ -227,7 +250,17 @@ pub struct FunctionInvocation {
pub calls: Vec<FunctionInvocation>,
pub events: Vec<OrderedEvent>,
pub messages: Vec<OrderedMessage>,
pub execution_resources: ExecutionResources,
#[serde(skip_serializing_if = "Option::is_none")]
pub execution_resources: Option<ExecutionResources>,
}

impl FunctionInvocation {
fn make_legacy(&mut self) {
self.execution_resources = None;
for call in self.calls.iter_mut() {
call.make_legacy();
}
}
}

use blockifier::execution::call_info::CallInfo as BlockifierCallInfo;
Expand Down Expand Up @@ -260,7 +293,7 @@ impl From<BlockifierCallInfo> for FunctionInvocation {
ordered_message
})
.collect(),
execution_resources: val.vm_resources.into(),
execution_resources: Some(val.vm_resources.into()),
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ pub extern "C" fn cairoVMExecute(
skip_validate: c_uchar,
gas_price_wei: *const c_uchar,
gas_price_strk: *const c_uchar,
legacy_json: c_uchar,
) {
let reader = JunoStateReader::new(reader_handle, block_number);
let chain_id_str = unsafe { CStr::from_ptr(chain_id) }.to_str().unwrap();
Expand Down Expand Up @@ -300,7 +301,7 @@ pub extern "C" fn cairoVMExecute(
}

let actual_fee = t.actual_fee.0.into();
let trace =
let mut trace =
jsonrpc::new_transaction_trace(txn_and_query_bit.txn, t, &mut txn_state);
if trace.is_err() {
report_error(
Expand All @@ -317,6 +318,9 @@ pub extern "C" fn cairoVMExecute(
unsafe {
JunoAppendActualFee(reader_handle, felt_to_byte_array(&actual_fee).as_ptr());
}
if legacy_json == 1 {
trace.as_mut().unwrap().make_legacy()
}
append_trace(reader_handle, trace.as_ref().unwrap(), &mut trace_buffer);
}
}
Expand Down
Loading

0 comments on commit 0441098

Please sign in to comment.