Skip to content

Commit

Permalink
allow vm.Call to modify state if ReaderWriter interface is satisfied
Browse files Browse the repository at this point in the history
  • Loading branch information
rian committed Aug 3, 2024
1 parent 9f14fff commit 3a5ce53
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
20 changes: 14 additions & 6 deletions vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use blockifier::blockifier::block::{
};
use blockifier::bouncer::BouncerConfig;
use blockifier::fee::{fee_utils, gas_usage};
use blockifier::state::cached_state::{CachedState, MutRefState};
use blockifier::state::state_api::{State, UpdatableState};
use blockifier::state::cached_state::CachedState;
use blockifier::state::state_api::State;
use blockifier::transaction::objects::GasVector;
use blockifier::{
context::{BlockContext, ChainInfo, FeeTokenAddresses, TransactionContext},
Expand All @@ -33,7 +33,6 @@ use blockifier::{
},
versioned_constants::VersionedConstants,
};
use std::borrow::Borrow;
use std::{
collections::HashMap,
ffi::{c_char, c_longlong, c_uchar, c_ulonglong, c_void, CStr, CString},
Expand Down Expand Up @@ -101,6 +100,7 @@ pub extern "C" fn cairoVMCall(
chain_id: *const c_char,
max_steps: c_ulonglong,
concurrency_mode: c_uchar,
is_mutable: c_uchar,
) {
let block_info = unsafe { *block_info_ptr };
let call_info = unsafe { *call_info_ptr };
Expand Down Expand Up @@ -143,13 +143,21 @@ pub extern "C" fn cairoVMCall(
initial_gas: get_versioned_constants(block_info.version).tx_initial_gas(),
};

let mut juno_state = JunoStateReader::new(reader_handle, block_info.block_number);
let mut state: Box<dyn State>;

let juno_reader = JunoStateReader::new(reader_handle, block_info.block_number);
if is_mutable == 1 {
state = Box::new(JunoStateReader::new(reader_handle, block_info.block_number));
} else {
state = Box::new(CachedState::new(juno_reader));
}

let concurrency_mode = concurrency_mode == 1;
let mut resources = ExecutionResources::default();
let context = EntryPointExecutionContext::new_invoke(
Arc::new(TransactionContext {
block_context: build_block_context(
&mut juno_state,
&mut *state,
&block_info,
chain_id_str,
Some(max_steps),
Expand All @@ -164,7 +172,7 @@ pub extern "C" fn cairoVMCall(
return;
}

match entry_point.execute(&mut juno_state, &mut resources, &mut context.unwrap()) {
match entry_point.execute(&mut *state, &mut resources, &mut context.unwrap()) {
Err(e) => report_error(reader_handle, e.to_string().as_str(), -1),
Ok(t) => {
for data in t.execution.retdata.0 {
Expand Down
6 changes: 5 additions & 1 deletion vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ typedef struct BlockInfo {
} BlockInfo;
extern void cairoVMCall(CallInfo* call_info_ptr, BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id,
unsigned long long max_steps, unsigned char concurrency_mode);
unsigned long long max_steps, unsigned char concurrency_mode, unsigned char is_mutable);
extern void cairoVMExecute(char* txns_json, char* classes_json, char* paid_fees_on_l1_json,
BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id,
Expand Down Expand Up @@ -270,6 +270,8 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead
if v.concurrencyMode {
concurrencyModeByte = 1
}
_, isMutableState := context.state.(StateReadWriter)
mutableStateByte := makeByteFromBool(isMutableState)
C.setVersionedConstants(C.CString("my_json"))

cCallInfo, callInfoPinner := makeCCallInfo(callInfo)
Expand All @@ -282,6 +284,8 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead
chainID,
C.ulonglong(maxSteps), //nolint:gocritic
C.uchar(concurrencyModeByte), //nolint:gocritic
C.uchar(mutableStateByte), //nolint:gocritic

)
callInfoPinner.Unpin()
C.free(unsafe.Pointer(chainID))
Expand Down

0 comments on commit 3a5ce53

Please sign in to comment.