From 3a5ce53d8644ce34b28383e6860de5485cf9fe19 Mon Sep 17 00:00:00 2001 From: rian Date: Sat, 3 Aug 2024 16:44:22 +0300 Subject: [PATCH] allow vm.Call to modify state if ReaderWriter interface is satisfied --- vm/rust/src/lib.rs | 20 ++++++++++++++------ vm/vm.go | 6 +++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/vm/rust/src/lib.rs b/vm/rust/src/lib.rs index fc4ea0c22d..ac3f41b98b 100644 --- a/vm/rust/src/lib.rs +++ b/vm/rust/src/lib.rs @@ -14,8 +14,8 @@ use blockifier::blockifier::block::{ }; use blockifier::bouncer::BouncerConfig; use blockifier::fee::{fee_utils, gas_usage}; -use blockifier::state::cached_state::{CachedState, MutRefState}; -use blockifier::state::state_api::{State, UpdatableState}; +use blockifier::state::cached_state::CachedState; +use blockifier::state::state_api::State; use blockifier::transaction::objects::GasVector; use blockifier::{ context::{BlockContext, ChainInfo, FeeTokenAddresses, TransactionContext}, @@ -33,7 +33,6 @@ use blockifier::{ }, versioned_constants::VersionedConstants, }; -use std::borrow::Borrow; use std::{ collections::HashMap, ffi::{c_char, c_longlong, c_uchar, c_ulonglong, c_void, CStr, CString}, @@ -101,6 +100,7 @@ pub extern "C" fn cairoVMCall( chain_id: *const c_char, max_steps: c_ulonglong, concurrency_mode: c_uchar, + is_mutable: c_uchar, ) { let block_info = unsafe { *block_info_ptr }; let call_info = unsafe { *call_info_ptr }; @@ -143,13 +143,21 @@ pub extern "C" fn cairoVMCall( initial_gas: get_versioned_constants(block_info.version).tx_initial_gas(), }; - let mut juno_state = JunoStateReader::new(reader_handle, block_info.block_number); + let mut state: Box; + + let juno_reader = JunoStateReader::new(reader_handle, block_info.block_number); + if is_mutable == 1 { + state = Box::new(JunoStateReader::new(reader_handle, block_info.block_number)); + } else { + state = Box::new(CachedState::new(juno_reader)); + } + let concurrency_mode = concurrency_mode == 1; let mut resources = ExecutionResources::default(); let context = EntryPointExecutionContext::new_invoke( Arc::new(TransactionContext { block_context: build_block_context( - &mut juno_state, + &mut *state, &block_info, chain_id_str, Some(max_steps), @@ -164,7 +172,7 @@ pub extern "C" fn cairoVMCall( return; } - match entry_point.execute(&mut juno_state, &mut resources, &mut context.unwrap()) { + match entry_point.execute(&mut *state, &mut resources, &mut context.unwrap()) { Err(e) => report_error(reader_handle, e.to_string().as_str(), -1), Ok(t) => { for data in t.execution.retdata.0 { diff --git a/vm/vm.go b/vm/vm.go index 0ff0604276..3bd634a652 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -29,7 +29,7 @@ typedef struct BlockInfo { } BlockInfo; extern void cairoVMCall(CallInfo* call_info_ptr, BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id, - unsigned long long max_steps, unsigned char concurrency_mode); + unsigned long long max_steps, unsigned char concurrency_mode, unsigned char is_mutable); extern void cairoVMExecute(char* txns_json, char* classes_json, char* paid_fees_on_l1_json, BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id, @@ -270,6 +270,8 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead if v.concurrencyMode { concurrencyModeByte = 1 } + _, isMutableState := context.state.(StateReadWriter) + mutableStateByte := makeByteFromBool(isMutableState) C.setVersionedConstants(C.CString("my_json")) cCallInfo, callInfoPinner := makeCCallInfo(callInfo) @@ -282,6 +284,8 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead chainID, C.ulonglong(maxSteps), //nolint:gocritic C.uchar(concurrencyModeByte), //nolint:gocritic + C.uchar(mutableStateByte), //nolint:gocritic + ) callInfoPinner.Unpin() C.free(unsafe.Pointer(chainID))